1#![allow(missing_docs)]
2use std::collections::HashMap;
10
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use crate::memory::embedding::{EmbeddingProvider, EmbeddingVector};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25#[derive(Default)]
26pub enum SonaMode {
27 RealTime,
29 #[default]
31 Balanced,
32 Research,
34 Edge,
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41pub enum Verdict {
42 Success,
44 PartialFailure,
46 Failure,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct TrajectoryStep {
53 pub input: String,
55 pub output: String,
57 pub duration_ms: u64,
59 pub confidence: f32,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Trajectory {
66 pub id: String,
68 pub steps: Vec<TrajectoryStep>,
70 pub verdict: Verdict,
72 pub domain: String,
74 pub created_at: DateTime<Utc>,
76 #[serde(skip)]
78 pub embedding: Option<EmbeddingVector>,
79}
80
81impl Trajectory {
82 pub fn new(steps: Vec<TrajectoryStep>, verdict: Verdict, domain: &str) -> Self {
84 Self {
85 id: Uuid::new_v4().to_string(),
86 steps,
87 verdict,
88 domain: domain.to_string(),
89 created_at: Utc::now(),
90 embedding: None,
91 }
92 }
93
94 pub fn total_duration_ms(&self) -> u64 {
96 self.steps.iter().map(|s| s.duration_ms).sum()
97 }
98
99 pub fn avg_confidence(&self) -> f32 {
101 if self.steps.is_empty() {
102 return 0.0;
103 }
104 self.steps.iter().map(|s| s.confidence).sum::<f32>() / self.steps.len() as f32
105 }
106
107 pub fn input_text(&self) -> String {
109 self.steps
110 .iter()
111 .map(|s| s.input.as_str())
112 .collect::<Vec<_>>()
113 .join(" ")
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct LearnedPattern {
120 pub id: String,
122 pub source_trajectories: Vec<String>,
124 pub strategy: String,
126 pub domain: String,
128 pub confidence: f32,
130 pub support_count: usize,
132 #[serde(skip)]
134 pub embedding: Option<EmbeddingVector>,
135}
136
137pub struct SonaEngine {
146 mode: SonaMode,
148 trajectories: RwLock<Vec<Trajectory>>,
150 learned_patterns: RwLock<Vec<LearnedPattern>>,
152 embedding: std::sync::Arc<dyn EmbeddingProvider>,
154 max_trajectories: usize,
156}
157
158impl std::fmt::Debug for SonaEngine {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("SonaEngine")
161 .field("mode", &self.mode)
162 .field("trajectory_count", &self.trajectories.read().len())
163 .field("pattern_count", &self.learned_patterns.read().len())
164 .finish()
165 }
166}
167
168impl SonaEngine {
169 pub fn new(mode: SonaMode, embedding: std::sync::Arc<dyn EmbeddingProvider>) -> Self {
171 let max_trajectories = match mode {
172 SonaMode::RealTime => 100,
173 SonaMode::Balanced => 500,
174 SonaMode::Research => 5000,
175 SonaMode::Edge => 50,
176 };
177
178 Self {
179 mode,
180 trajectories: RwLock::new(Vec::new()),
181 learned_patterns: RwLock::new(Vec::new()),
182 embedding,
183 max_trajectories,
184 }
185 }
186
187 pub fn mode(&self) -> SonaMode {
189 self.mode
190 }
191
192 pub async fn record(&self, mut trajectory: Trajectory) -> Result<String, anyhow::Error> {
197 if trajectory.id.is_empty() {
198 trajectory.id = Uuid::new_v4().to_string();
199 }
200
201 let text = trajectory.input_text();
203 if !text.is_empty() {
204 let embedding = self.embedding.embed(&text).await?;
205 trajectory.embedding = Some(embedding);
206 }
207
208 let id = trajectory.id.clone();
209
210 let mut trajs = self.trajectories.write();
211
212 if trajs.len() >= self.max_trajectories {
214 let remove_count = trajs.len() - self.max_trajectories + 1;
216 let mut removed = 0;
217 trajs.retain(|t| {
218 if removed >= remove_count {
219 return true;
220 }
221 if t.verdict == Verdict::Failure {
222 removed += 1;
223 false
224 } else {
225 true
226 }
227 });
228 while trajs.len() >= self.max_trajectories {
230 trajs.remove(0);
231 }
232 }
233
234 trajs.push(trajectory);
235 Ok(id)
236 }
237
238 pub async fn distill(&self) -> Result<Vec<LearnedPattern>, anyhow::Error> {
243 let domain_groups: HashMap<String, Vec<Trajectory>> = {
245 let trajs = self.trajectories.read();
246 let mut groups: HashMap<String, Vec<Trajectory>> = HashMap::new();
247 for traj in trajs.iter() {
248 if traj.verdict == Verdict::Success {
249 groups
250 .entry(traj.domain.clone())
251 .or_default()
252 .push(traj.clone());
253 }
254 }
255 groups
256 }; let mut new_patterns = Vec::new();
259
260 for (domain, group) in &domain_groups {
261 if group.len() < 2 {
262 continue; }
264
265 let mut strategy_parts: Vec<String> = Vec::new();
268 for traj in group {
269 let summary: String = traj
270 .steps
271 .iter()
272 .take(3) .map(|s| s.input.clone())
274 .collect::<Vec<_>>()
275 .join(" → ");
276 strategy_parts.push(summary);
277 }
278
279 let combined = strategy_parts.join("; ");
281 let strategy = if combined.chars().count() > 500 {
282 let truncated: String = combined.chars().take(500).collect();
286 format!("{truncated}...")
287 } else {
288 combined
289 };
290
291 let embedding = self.embedding.embed(&strategy).await?;
292
293 let source_ids: Vec<String> = group.iter().map(|t| t.id.clone()).collect();
294
295 let pattern = LearnedPattern {
296 id: Uuid::new_v4().to_string(),
297 source_trajectories: source_ids,
298 strategy,
299 domain: domain.clone(),
300 confidence: (group.len() as f32 * 0.2).min(1.0),
301 support_count: group.len(),
302 embedding: Some(embedding),
303 };
304
305 new_patterns.push(pattern);
306 }
307
308 {
310 let mut patterns = self.learned_patterns.write();
311 for pattern in &new_patterns {
312 let is_dup = patterns
314 .iter()
315 .any(|p| p.strategy == pattern.strategy && p.domain == pattern.domain);
316 if !is_dup {
317 patterns.push(pattern.clone());
318 }
319 }
320 }
321
322 tracing::info!(
323 new_patterns = new_patterns.len(),
324 "SONA distillation complete"
325 );
326 Ok(new_patterns)
327 }
328
329 pub async fn adapt(&self, query: &str) -> Result<Option<LearnedPattern>, anyhow::Error> {
334 let query_embedding = self.embedding.embed(query).await?;
335
336 let patterns = self.learned_patterns.read();
337 let mut best: Option<(&LearnedPattern, f64)> = None;
338
339 for pattern in patterns.iter() {
340 if let Some(ref emb) = pattern.embedding {
341 let sim = query_embedding.cosine_similarity(emb);
342 match best {
343 Some((_, best_sim)) if sim <= best_sim => {}
344 _ => best = Some((pattern, sim)),
345 }
346 }
347 }
348
349 Ok(best.filter(|(_, sim)| *sim > 0.3).map(|(p, sim)| {
350 let mut adapted = p.clone();
351 adapted.confidence = (p.confidence * sim as f32).min(1.0);
352 adapted
353 }))
354 }
355
356 pub fn counts(&self) -> (usize, usize) {
358 let traj_count = self.trajectories.read().len();
359 let pattern_count = self.learned_patterns.read().len();
360 (traj_count, pattern_count)
361 }
362
363 pub fn get_learned_patterns(&self) -> Vec<LearnedPattern> {
365 self.learned_patterns.read().clone()
366 }
367
368 pub fn load_learned_patterns(&self, patterns: Vec<LearnedPattern>) {
370 let mut existing = self.learned_patterns.write();
371 *existing = patterns;
372 }
373
374 pub fn trajectories_by_verdict(&self, verdict: Verdict) -> Vec<Trajectory> {
376 self.trajectories
377 .read()
378 .iter()
379 .filter(|t| t.verdict == verdict)
380 .cloned()
381 .collect()
382 }
383
384 #[cfg(feature = "sqlite-memory")]
388 pub fn persist_to_sqlite(
389 &self,
390 store: &crate::memory::sqlite::store::SqliteMemoryStore,
391 ) -> anyhow::Result<()> {
392 let patterns = self.learned_patterns.read();
393 for pattern in patterns.iter() {
394 let data = serde_json::to_string(pattern)?;
395 store.save_pattern(
396 &pattern.id,
397 "sona",
398 Some(&pattern.domain),
399 pattern.confidence,
400 &data,
401 )?;
402 }
403 tracing::debug!(count = patterns.len(), "SONA patterns persisted to SQLite");
404 Ok(())
405 }
406
407 #[cfg(feature = "sqlite-memory")]
411 pub fn restore_from_sqlite(
412 &self,
413 store: &crate::memory::sqlite::store::SqliteMemoryStore,
414 ) -> anyhow::Result<()> {
415 let rows = store.load_patterns()?;
416 let sona_rows: Vec<_> = rows.into_iter().filter(|r| r.strategy == "sona").collect();
417
418 let mut patterns = Vec::new();
419 for row in &sona_rows {
420 if let Ok(pattern) = serde_json::from_str::<LearnedPattern>(&row.data) {
421 patterns.push(pattern);
422 }
423 }
424
425 *self.learned_patterns.write() = patterns;
426 tracing::debug!(
427 count = sona_rows.len(),
428 "SONA patterns restored from SQLite"
429 );
430 Ok(())
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::embedding::TfIdfEmbeddingProvider;

    /// Engine in the given mode, backed by the TF-IDF embedding provider.
    fn engine_in(mode: SonaMode) -> SonaEngine {
        SonaEngine::new(mode, std::sync::Arc::new(TfIdfEmbeddingProvider))
    }

    /// Step with fixed duration (10 ms) and confidence (0.9).
    fn make_step(input: &str, output: &str) -> TrajectoryStep {
        TrajectoryStep {
            input: input.to_string(),
            output: output.to_string(),
            duration_ms: 10,
            confidence: 0.9,
        }
    }

    /// Two-step trajectory in `domain` with the given verdict.
    fn make_trajectory(domain: &str, verdict: Verdict) -> Trajectory {
        let steps = vec![
            make_step("analyze input", "parsed"),
            make_step("execute plan", "completed"),
        ];
        Trajectory::new(steps, verdict, domain)
    }

    /// Records `count` default trajectories with the given domain and verdict.
    async fn record_many(engine: &SonaEngine, count: usize, domain: &str, verdict: Verdict) {
        for _ in 0..count {
            engine
                .record(make_trajectory(domain, verdict))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_record_trajectory() {
        let engine = engine_in(SonaMode::Balanced);

        let id = engine
            .record(make_trajectory("testing", Verdict::Success))
            .await
            .unwrap();
        assert!(!id.is_empty());

        let (traj_count, _) = engine.counts();
        assert_eq!(traj_count, 1);
    }

    #[tokio::test]
    async fn test_distill_patterns() {
        let engine = engine_in(SonaMode::Balanced);
        record_many(&engine, 3, "security", Verdict::Success).await;

        let patterns = engine.distill().await.unwrap();
        assert!(
            !patterns.is_empty(),
            "Should distill patterns from 3+ successful trajectories"
        );

        let (_, pattern_count) = engine.counts();
        assert!(pattern_count > 0);
    }

    #[tokio::test]
    async fn test_distill_needs_multiple_successes() {
        let engine = engine_in(SonaMode::Balanced);
        record_many(&engine, 1, "testing", Verdict::Success).await;

        let patterns = engine.distill().await.unwrap();
        assert!(patterns.is_empty(), "Need 2+ trajectories to distill");
    }

    #[tokio::test]
    async fn test_distill_ignores_failures() {
        let engine = engine_in(SonaMode::Balanced);
        record_many(&engine, 2, "testing", Verdict::Failure).await;

        let patterns = engine.distill().await.unwrap();
        assert!(patterns.is_empty(), "Failures should not produce patterns");
    }

    #[tokio::test]
    async fn test_adapt_finds_similar_pattern() {
        let engine = engine_in(SonaMode::Balanced);

        for _ in 0..3 {
            let mut traj = make_trajectory("security", Verdict::Success);
            traj.steps[0].input =
                "scan for SQL injection vulnerabilities in the codebase".to_string();
            engine.record(traj).await.unwrap();
        }
        engine.distill().await.unwrap();

        let result = engine
            .adapt("check for SQL injection security issues")
            .await
            .unwrap();
        assert!(result.is_some(), "Should find a matching pattern");

        let pattern = result.unwrap();
        assert_eq!(pattern.domain, "security");
        assert!(pattern.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_adapt_no_match_below_threshold() {
        // No patterns have been distilled, so nothing can match.
        let engine = engine_in(SonaMode::Balanced);

        let result = engine
            .adapt("completely unrelated query about cooking")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_capacity_limit() {
        // Edge mode caps the store at 50 trajectories.
        let engine = engine_in(SonaMode::Edge);

        for i in 0..55 {
            let mut traj = make_trajectory("testing", Verdict::Success);
            traj.id = format!("traj-{}", i);
            engine.record(traj).await.unwrap();
        }

        let (count, _) = engine.counts();
        assert!(count <= 50, "Should not exceed capacity: got {}", count);
    }

    #[test]
    fn test_trajectory_total_duration() {
        let steps = vec![make_step("a", "b"), make_step("c", "d")];
        let traj = Trajectory::new(steps, Verdict::Success, "testing");
        assert_eq!(traj.total_duration_ms(), 20);
    }

    #[test]
    fn test_trajectory_avg_confidence() {
        let steps: Vec<TrajectoryStep> = [("a", "b", 0.8_f32), ("c", "d", 0.6_f32)]
            .into_iter()
            .map(|(input, output, confidence)| TrajectoryStep {
                input: input.into(),
                output: output.into(),
                duration_ms: 10,
                confidence,
            })
            .collect();
        let traj = Trajectory::new(steps, Verdict::Success, "testing");
        assert!((traj.avg_confidence() - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_sona_mode_default() {
        assert_eq!(SonaMode::default(), SonaMode::Balanced);
    }
}