1#![allow(missing_docs)]
2use std::collections::HashMap;
10
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use crate::memory::embedding::{EmbeddingProvider, EmbeddingVector};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25#[derive(Default)]
26pub enum SonaMode {
27 RealTime,
29 #[default]
31 Balanced,
32 Research,
34 Edge,
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41pub enum Verdict {
42 Success,
44 PartialFailure,
46 Failure,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct TrajectoryStep {
53 pub input: String,
55 pub output: String,
57 pub duration_ms: u64,
59 pub confidence: f32,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Trajectory {
66 pub id: String,
68 pub steps: Vec<TrajectoryStep>,
70 pub verdict: Verdict,
72 pub domain: String,
74 pub created_at: DateTime<Utc>,
76 #[serde(skip)]
78 pub embedding: Option<EmbeddingVector>,
79}
80
81impl Trajectory {
82 pub fn new(steps: Vec<TrajectoryStep>, verdict: Verdict, domain: &str) -> Self {
84 Self {
85 id: Uuid::new_v4().to_string(),
86 steps,
87 verdict,
88 domain: domain.to_string(),
89 created_at: Utc::now(),
90 embedding: None,
91 }
92 }
93
94 pub fn total_duration_ms(&self) -> u64 {
96 self.steps.iter().map(|s| s.duration_ms).sum()
97 }
98
99 pub fn avg_confidence(&self) -> f32 {
101 if self.steps.is_empty() {
102 return 0.0;
103 }
104 self.steps.iter().map(|s| s.confidence).sum::<f32>() / self.steps.len() as f32
105 }
106
107 pub fn input_text(&self) -> String {
109 self.steps
110 .iter()
111 .map(|s| s.input.as_str())
112 .collect::<Vec<_>>()
113 .join(" ")
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct LearnedPattern {
120 pub id: String,
122 pub source_trajectories: Vec<String>,
124 pub strategy: String,
126 pub domain: String,
128 pub confidence: f32,
130 pub support_count: usize,
132 #[serde(skip)]
134 pub embedding: Option<EmbeddingVector>,
135}
136
137pub struct SonaEngine {
146 mode: SonaMode,
148 trajectories: RwLock<Vec<Trajectory>>,
150 learned_patterns: RwLock<Vec<LearnedPattern>>,
152 embedding: std::sync::Arc<dyn EmbeddingProvider>,
154 max_trajectories: usize,
156}
157
158impl std::fmt::Debug for SonaEngine {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("SonaEngine")
161 .field("mode", &self.mode)
162 .field("trajectory_count", &self.trajectories.read().len())
163 .field("pattern_count", &self.learned_patterns.read().len())
164 .finish()
165 }
166}
167
168impl SonaEngine {
169 pub fn new(mode: SonaMode, embedding: std::sync::Arc<dyn EmbeddingProvider>) -> Self {
171 let max_trajectories = match mode {
172 SonaMode::RealTime => 100,
173 SonaMode::Balanced => 500,
174 SonaMode::Research => 5000,
175 SonaMode::Edge => 50,
176 };
177
178 Self {
179 mode,
180 trajectories: RwLock::new(Vec::new()),
181 learned_patterns: RwLock::new(Vec::new()),
182 embedding,
183 max_trajectories,
184 }
185 }
186
187 pub fn mode(&self) -> SonaMode {
189 self.mode
190 }
191
192 pub async fn record(&self, mut trajectory: Trajectory) -> Result<String, anyhow::Error> {
197 if trajectory.id.is_empty() {
198 trajectory.id = Uuid::new_v4().to_string();
199 }
200
201 let text = trajectory.input_text();
203 if !text.is_empty() {
204 let embedding = self.embedding.embed(&text).await?;
205 trajectory.embedding = Some(embedding);
206 }
207
208 let id = trajectory.id.clone();
209
210 let mut trajs = self.trajectories.write();
211
212 if trajs.len() >= self.max_trajectories {
214 let remove_count = trajs.len() - self.max_trajectories + 1;
216 let mut removed = 0;
217 trajs.retain(|t| {
218 if removed >= remove_count {
219 return true;
220 }
221 if t.verdict == Verdict::Failure {
222 removed += 1;
223 false
224 } else {
225 true
226 }
227 });
228 while trajs.len() >= self.max_trajectories {
230 trajs.remove(0);
231 }
232 }
233
234 trajs.push(trajectory);
235 Ok(id)
236 }
237
238 pub async fn distill(&self) -> Result<Vec<LearnedPattern>, anyhow::Error> {
243 let domain_groups: HashMap<String, Vec<Trajectory>> = {
245 let trajs = self.trajectories.read();
246 let mut groups: HashMap<String, Vec<Trajectory>> = HashMap::new();
247 for traj in trajs.iter() {
248 if traj.verdict == Verdict::Success {
249 groups
250 .entry(traj.domain.clone())
251 .or_default()
252 .push(traj.clone());
253 }
254 }
255 groups
256 }; let mut new_patterns = Vec::new();
259
260 for (domain, group) in &domain_groups {
261 if group.len() < 2 {
262 continue; }
264
265 let mut strategy_parts: Vec<String> = Vec::new();
268 for traj in group {
269 let summary: String = traj
270 .steps
271 .iter()
272 .take(3) .map(|s| s.input.clone())
274 .collect::<Vec<_>>()
275 .join(" → ");
276 strategy_parts.push(summary);
277 }
278
279 let combined = strategy_parts.join("; ");
281 let strategy = if combined.len() > 500 {
282 format!("{}...", &combined[..500])
283 } else {
284 combined
285 };
286
287 let embedding = self.embedding.embed(&strategy).await?;
288
289 let source_ids: Vec<String> = group.iter().map(|t| t.id.clone()).collect();
290
291 let pattern = LearnedPattern {
292 id: Uuid::new_v4().to_string(),
293 source_trajectories: source_ids,
294 strategy,
295 domain: domain.clone(),
296 confidence: (group.len() as f32 * 0.2).min(1.0),
297 support_count: group.len(),
298 embedding: Some(embedding),
299 };
300
301 new_patterns.push(pattern);
302 }
303
304 {
306 let mut patterns = self.learned_patterns.write();
307 for pattern in &new_patterns {
308 let is_dup = patterns
310 .iter()
311 .any(|p| p.strategy == pattern.strategy && p.domain == pattern.domain);
312 if !is_dup {
313 patterns.push(pattern.clone());
314 }
315 }
316 }
317
318 tracing::info!(
319 new_patterns = new_patterns.len(),
320 "SONA distillation complete"
321 );
322 Ok(new_patterns)
323 }
324
325 pub async fn adapt(&self, query: &str) -> Result<Option<LearnedPattern>, anyhow::Error> {
330 let query_embedding = self.embedding.embed(query).await?;
331
332 let patterns = self.learned_patterns.read();
333 let mut best: Option<(&LearnedPattern, f64)> = None;
334
335 for pattern in patterns.iter() {
336 if let Some(ref emb) = pattern.embedding {
337 let sim = query_embedding.cosine_similarity(emb);
338 match best {
339 Some((_, best_sim)) if sim <= best_sim => {}
340 _ => best = Some((pattern, sim)),
341 }
342 }
343 }
344
345 Ok(best.filter(|(_, sim)| *sim > 0.3).map(|(p, sim)| {
346 let mut adapted = p.clone();
347 adapted.confidence = (p.confidence * sim as f32).min(1.0);
348 adapted
349 }))
350 }
351
352 pub fn counts(&self) -> (usize, usize) {
354 let traj_count = self.trajectories.read().len();
355 let pattern_count = self.learned_patterns.read().len();
356 (traj_count, pattern_count)
357 }
358
359 pub fn get_learned_patterns(&self) -> Vec<LearnedPattern> {
361 self.learned_patterns.read().clone()
362 }
363
364 pub fn load_learned_patterns(&self, patterns: Vec<LearnedPattern>) {
366 let mut existing = self.learned_patterns.write();
367 *existing = patterns;
368 }
369
370 pub fn trajectories_by_verdict(&self, verdict: Verdict) -> Vec<Trajectory> {
372 self.trajectories
373 .read()
374 .iter()
375 .filter(|t| t.verdict == verdict)
376 .cloned()
377 .collect()
378 }
379
380 #[cfg(feature = "sqlite-memory")]
384 pub fn persist_to_sqlite(
385 &self,
386 store: &crate::memory::sqlite::store::SqliteMemoryStore,
387 ) -> anyhow::Result<()> {
388 let patterns = self.learned_patterns.read();
389 for pattern in patterns.iter() {
390 let data = serde_json::to_string(pattern)?;
391 store.save_pattern(
392 &pattern.id,
393 "sona",
394 Some(&pattern.domain),
395 pattern.confidence,
396 &data,
397 )?;
398 }
399 tracing::debug!(count = patterns.len(), "SONA patterns persisted to SQLite");
400 Ok(())
401 }
402
403 #[cfg(feature = "sqlite-memory")]
407 pub fn restore_from_sqlite(
408 &self,
409 store: &crate::memory::sqlite::store::SqliteMemoryStore,
410 ) -> anyhow::Result<()> {
411 let rows = store.load_patterns()?;
412 let sona_rows: Vec<_> = rows.into_iter().filter(|r| r.strategy == "sona").collect();
413
414 let mut patterns = Vec::new();
415 for row in &sona_rows {
416 if let Ok(pattern) = serde_json::from_str::<LearnedPattern>(&row.data) {
417 patterns.push(pattern);
418 }
419 }
420
421 *self.learned_patterns.write() = patterns;
422 tracing::debug!(
423 count = sona_rows.len(),
424 "SONA patterns restored from SQLite"
425 );
426 Ok(())
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::embedding::TfIdfEmbeddingProvider;

    /// Builds an engine in `mode` backed by the TF-IDF embedding provider.
    fn engine_in(mode: SonaMode) -> SonaEngine {
        SonaEngine::new(mode, std::sync::Arc::new(TfIdfEmbeddingProvider))
    }

    /// Builds a 10 ms step with confidence 0.9.
    fn make_step(input: &str, output: &str) -> TrajectoryStep {
        TrajectoryStep {
            input: input.to_owned(),
            output: output.to_owned(),
            duration_ms: 10,
            confidence: 0.9,
        }
    }

    /// Builds a fixed two-step trajectory for `domain` with the given verdict.
    fn make_trajectory(domain: &str, verdict: Verdict) -> Trajectory {
        let steps = vec![
            make_step("analyze input", "parsed"),
            make_step("execute plan", "completed"),
        ];
        Trajectory::new(steps, verdict, domain)
    }

    #[tokio::test]
    async fn test_record_trajectory() {
        let engine = engine_in(SonaMode::Balanced);

        let id = engine
            .record(make_trajectory("testing", Verdict::Success))
            .await
            .unwrap();
        assert!(!id.is_empty());

        let (stored, _) = engine.counts();
        assert_eq!(stored, 1);
    }

    #[tokio::test]
    async fn test_distill_patterns() {
        let engine = engine_in(SonaMode::Balanced);

        // Three identical successes in one domain are enough for one pattern.
        for _ in 0..3 {
            engine
                .record(make_trajectory("security", Verdict::Success))
                .await
                .unwrap();
        }

        let distilled = engine.distill().await.unwrap();
        assert!(
            !distilled.is_empty(),
            "Should distill patterns from 3+ successful trajectories"
        );

        let (_, stored_patterns) = engine.counts();
        assert!(stored_patterns > 0);
    }

    #[tokio::test]
    async fn test_distill_needs_multiple_successes() {
        let engine = engine_in(SonaMode::Balanced);

        let lone_success = make_trajectory("testing", Verdict::Success);
        engine.record(lone_success).await.unwrap();

        let distilled = engine.distill().await.unwrap();
        assert!(distilled.is_empty(), "Need 2+ trajectories to distill");
    }

    #[tokio::test]
    async fn test_distill_ignores_failures() {
        let engine = engine_in(SonaMode::Balanced);

        for _ in 0..2 {
            engine
                .record(make_trajectory("testing", Verdict::Failure))
                .await
                .unwrap();
        }

        let distilled = engine.distill().await.unwrap();
        assert!(distilled.is_empty(), "Failures should not produce patterns");
    }

    #[tokio::test]
    async fn test_adapt_finds_similar_pattern() {
        let engine = engine_in(SonaMode::Balanced);

        for _ in 0..3 {
            let mut traj = make_trajectory("security", Verdict::Success);
            traj.steps[0].input =
                "scan for SQL injection vulnerabilities in the codebase".to_string();
            engine.record(traj).await.unwrap();
        }
        engine.distill().await.unwrap();

        let found = engine
            .adapt("check for SQL injection security issues")
            .await
            .unwrap();
        assert!(found.is_some(), "Should find a matching pattern");

        let pattern = found.unwrap();
        assert_eq!(pattern.domain, "security");
        assert!(pattern.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_adapt_no_match_below_threshold() {
        // No patterns have been learned, so nothing can match.
        let engine = engine_in(SonaMode::Balanced);

        let found = engine
            .adapt("completely unrelated query about cooking")
            .await
            .unwrap();
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_capacity_limit() {
        // Edge mode caps the buffer at 50 trajectories.
        let engine = engine_in(SonaMode::Edge);

        for i in 0..55 {
            let mut traj = make_trajectory("testing", Verdict::Success);
            traj.id = format!("traj-{}", i);
            engine.record(traj).await.unwrap();
        }

        let (count, _) = engine.counts();
        assert!(count <= 50, "Should not exceed capacity: got {}", count);
    }

    #[test]
    fn test_trajectory_total_duration() {
        let steps = vec![make_step("a", "b"), make_step("c", "d")];
        let traj = Trajectory::new(steps, Verdict::Success, "testing");
        assert_eq!(traj.total_duration_ms(), 20);
    }

    #[test]
    fn test_trajectory_avg_confidence() {
        let first = TrajectoryStep {
            input: "a".into(),
            output: "b".into(),
            duration_ms: 10,
            confidence: 0.8,
        };
        let second = TrajectoryStep {
            input: "c".into(),
            output: "d".into(),
            duration_ms: 10,
            confidence: 0.6,
        };
        let traj = Trajectory::new(vec![first, second], Verdict::Success, "testing");
        assert!((traj.avg_confidence() - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_sona_mode_default() {
        assert_eq!(SonaMode::default(), SonaMode::Balanced);
    }
}