Skip to main content

oxios_memory/memory/
sona.rs

1#![allow(missing_docs)]
//! SONA — Self-Optimizing Neural Architecture (simplified).
//!
//! Tracks execution trajectories, distills patterns from the successful
//! ones, and adapts future behavior based on that learned experience.
//!
//! Performance target: adaptation in < 0.05 ms (in-memory lookup).
8
9use std::collections::HashMap;
10
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use crate::memory::embedding::{EmbeddingProvider, EmbeddingVector};
17
18// ---------------------------------------------------------------------------
19// Data types
20// ---------------------------------------------------------------------------
21
22/// Operating mode for the SONA engine.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25#[derive(Default)]
26pub enum SonaMode {
27    /// Real-time adaptation (< 0.05ms target).
28    RealTime,
29    /// Balanced between speed and depth.
30    #[default]
31    Balanced,
32    /// Research mode — deep analysis, no time constraints.
33    Research,
34    /// Edge device — minimal memory footprint.
35    Edge,
36}
37
38/// Verdict for a trajectory outcome.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41pub enum Verdict {
42    /// Fully successful.
43    Success,
44    /// Partially successful (some steps failed).
45    PartialFailure,
46    /// Completely failed.
47    Failure,
48}
49
50/// A single step within a trajectory.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct TrajectoryStep {
53    /// Step description / input.
54    pub input: String,
55    /// Step result / output.
56    pub output: String,
57    /// Duration of this step in milliseconds.
58    pub duration_ms: u64,
59    /// Confidence score (0.0–1.0).
60    pub confidence: f32,
61}
62
63/// A trajectory — a sequence of steps with a final verdict.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Trajectory {
66    /// Unique ID.
67    pub id: String,
68    /// Ordered steps.
69    pub steps: Vec<TrajectoryStep>,
70    /// Final verdict.
71    pub verdict: Verdict,
72    /// Domain or task type.
73    pub domain: String,
74    /// Creation timestamp.
75    pub created_at: DateTime<Utc>,
76    /// Embedding of the trajectory's combined input text.
77    #[serde(skip)]
78    pub embedding: Option<EmbeddingVector>,
79}
80
81impl Trajectory {
82    /// Create a new trajectory with the given steps and verdict.
83    pub fn new(steps: Vec<TrajectoryStep>, verdict: Verdict, domain: &str) -> Self {
84        Self {
85            id: Uuid::new_v4().to_string(),
86            steps,
87            verdict,
88            domain: domain.to_string(),
89            created_at: Utc::now(),
90            embedding: None,
91        }
92    }
93
94    /// Total duration across all steps.
95    pub fn total_duration_ms(&self) -> u64 {
96        self.steps.iter().map(|s| s.duration_ms).sum()
97    }
98
99    /// Average confidence across all steps.
100    pub fn avg_confidence(&self) -> f32 {
101        if self.steps.is_empty() {
102            return 0.0;
103        }
104        self.steps.iter().map(|s| s.confidence).sum::<f32>() / self.steps.len() as f32
105    }
106
107    /// Concatenated input text for embedding.
108    pub fn input_text(&self) -> String {
109        self.steps
110            .iter()
111            .map(|s| s.input.as_str())
112            .collect::<Vec<_>>()
113            .join(" ")
114    }
115}
116
117/// A distilled pattern extracted from successful trajectories.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct LearnedPattern {
120    /// Unique ID.
121    pub id: String,
122    /// Source trajectory IDs.
123    pub source_trajectories: Vec<String>,
124    /// Distilled strategy description.
125    pub strategy: String,
126    /// Domain category.
127    pub domain: String,
128    /// Confidence in this pattern (based on number of supporting trajectories).
129    pub confidence: f32,
130    /// Number of trajectories this was distilled from.
131    pub support_count: usize,
132    /// Embedding for similarity matching.
133    #[serde(skip)]
134    pub embedding: Option<EmbeddingVector>,
135}
136
137// ---------------------------------------------------------------------------
138// SonaEngine
139// ---------------------------------------------------------------------------
140
141/// The SONA engine for trajectory-based self-learning.
142///
143/// Records execution trajectories, distills patterns from successful ones,
144/// and adapts future behavior by matching against learned patterns.
145pub struct SonaEngine {
146    /// Operating mode.
147    mode: SonaMode,
148    /// Recorded trajectories.
149    trajectories: RwLock<Vec<Trajectory>>,
150    /// Distilled patterns from successful trajectories.
151    learned_patterns: RwLock<Vec<LearnedPattern>>,
152    /// Embedding provider.
153    embedding: std::sync::Arc<dyn EmbeddingProvider>,
154    /// Maximum trajectories to keep (mode-dependent).
155    max_trajectories: usize,
156}
157
158impl std::fmt::Debug for SonaEngine {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("SonaEngine")
161            .field("mode", &self.mode)
162            .field("trajectory_count", &self.trajectories.read().len())
163            .field("pattern_count", &self.learned_patterns.read().len())
164            .finish()
165    }
166}
167
168impl SonaEngine {
169    /// Create a new SONA engine with the given mode and embedding provider.
170    pub fn new(mode: SonaMode, embedding: std::sync::Arc<dyn EmbeddingProvider>) -> Self {
171        let max_trajectories = match mode {
172            SonaMode::RealTime => 100,
173            SonaMode::Balanced => 500,
174            SonaMode::Research => 5000,
175            SonaMode::Edge => 50,
176        };
177
178        Self {
179            mode,
180            trajectories: RwLock::new(Vec::new()),
181            learned_patterns: RwLock::new(Vec::new()),
182            embedding,
183            max_trajectories,
184        }
185    }
186
187    /// Return the current operating mode.
188    pub fn mode(&self) -> SonaMode {
189        self.mode
190    }
191
192    /// Record a new trajectory.
193    ///
194    /// Generates an embedding for the trajectory's input text
195    /// and stores it for future distillation.
196    pub async fn record(&self, mut trajectory: Trajectory) -> Result<String, anyhow::Error> {
197        if trajectory.id.is_empty() {
198            trajectory.id = Uuid::new_v4().to_string();
199        }
200
201        // Generate embedding
202        let text = trajectory.input_text();
203        if !text.is_empty() {
204            let embedding = self.embedding.embed(&text).await?;
205            trajectory.embedding = Some(embedding);
206        }
207
208        let id = trajectory.id.clone();
209
210        let mut trajs = self.trajectories.write();
211
212        // Enforce capacity limit
213        if trajs.len() >= self.max_trajectories {
214            // Remove oldest failure trajectories first
215            let remove_count = trajs.len() - self.max_trajectories + 1;
216            let mut removed = 0;
217            trajs.retain(|t| {
218                if removed >= remove_count {
219                    return true;
220                }
221                if t.verdict == Verdict::Failure {
222                    removed += 1;
223                    false
224                } else {
225                    true
226                }
227            });
228            // If still over capacity, remove oldest
229            while trajs.len() >= self.max_trajectories {
230                trajs.remove(0);
231            }
232        }
233
234        trajs.push(trajectory);
235        Ok(id)
236    }
237
238    /// Distill learned patterns from successful trajectories.
239    ///
240    /// Groups successful trajectories by domain and extracts common
241    /// patterns. Returns the newly distilled patterns.
242    pub async fn distill(&self) -> Result<Vec<LearnedPattern>, anyhow::Error> {
243        // Collect data under lock, then release before any .await
244        let domain_groups: HashMap<String, Vec<Trajectory>> = {
245            let trajs = self.trajectories.read();
246            let mut groups: HashMap<String, Vec<Trajectory>> = HashMap::new();
247            for traj in trajs.iter() {
248                if traj.verdict == Verdict::Success {
249                    groups
250                        .entry(traj.domain.clone())
251                        .or_default()
252                        .push(traj.clone());
253                }
254            }
255            groups
256        }; // lock dropped here
257
258        let mut new_patterns = Vec::new();
259
260        for (domain, group) in &domain_groups {
261            if group.len() < 2 {
262                continue; // Need at least 2 trajectories to distill
263            }
264
265            // Simple distillation: extract common step patterns
266            // For each trajectory, get the strategy from concatenated inputs
267            let mut strategy_parts: Vec<String> = Vec::new();
268            for traj in group {
269                let summary: String = traj
270                    .steps
271                    .iter()
272                    .take(3) // Use first 3 steps as summary
273                    .map(|s| s.input.clone())
274                    .collect::<Vec<_>>()
275                    .join(" → ");
276                strategy_parts.push(summary);
277            }
278
279            // Combine strategies into a distilled pattern
280            let combined = strategy_parts.join("; ");
281            let strategy = if combined.chars().count() > 500 {
282                // Truncate by chars, not bytes. `&combined[..500]` panics when
283                // byte index 500 falls inside a multi-byte UTF-8 sequence
284                // (Korean, emoji) — common for this Korean-first codebase.
285                let truncated: String = combined.chars().take(500).collect();
286                format!("{truncated}...")
287            } else {
288                combined
289            };
290
291            let embedding = self.embedding.embed(&strategy).await?;
292
293            let source_ids: Vec<String> = group.iter().map(|t| t.id.clone()).collect();
294
295            let pattern = LearnedPattern {
296                id: Uuid::new_v4().to_string(),
297                source_trajectories: source_ids,
298                strategy,
299                domain: domain.clone(),
300                confidence: (group.len() as f32 * 0.2).min(1.0),
301                support_count: group.len(),
302                embedding: Some(embedding),
303            };
304
305            new_patterns.push(pattern);
306        }
307
308        // Store new patterns
309        {
310            let mut patterns = self.learned_patterns.write();
311            for pattern in &new_patterns {
312                // Don't duplicate — check by strategy similarity (simplified: exact match)
313                let is_dup = patterns
314                    .iter()
315                    .any(|p| p.strategy == pattern.strategy && p.domain == pattern.domain);
316                if !is_dup {
317                    patterns.push(pattern.clone());
318                }
319            }
320        }
321
322        tracing::info!(
323            new_patterns = new_patterns.len(),
324            "SONA distillation complete"
325        );
326        Ok(new_patterns)
327    }
328
329    /// Adapt to a new query by finding the most similar learned pattern.
330    ///
331    /// Returns the best matching pattern if similarity exceeds threshold.
332    /// Target: < 0.05ms for in-memory lookup.
333    pub async fn adapt(&self, query: &str) -> Result<Option<LearnedPattern>, anyhow::Error> {
334        let query_embedding = self.embedding.embed(query).await?;
335
336        let patterns = self.learned_patterns.read();
337        let mut best: Option<(&LearnedPattern, f64)> = None;
338
339        for pattern in patterns.iter() {
340            if let Some(ref emb) = pattern.embedding {
341                let sim = query_embedding.cosine_similarity(emb);
342                match best {
343                    Some((_, best_sim)) if sim <= best_sim => {}
344                    _ => best = Some((pattern, sim)),
345                }
346            }
347        }
348
349        Ok(best.filter(|(_, sim)| *sim > 0.3).map(|(p, sim)| {
350            let mut adapted = p.clone();
351            adapted.confidence = (p.confidence * sim as f32).min(1.0);
352            adapted
353        }))
354    }
355
356    /// Return counts of trajectories and patterns.
357    pub fn counts(&self) -> (usize, usize) {
358        let traj_count = self.trajectories.read().len();
359        let pattern_count = self.learned_patterns.read().len();
360        (traj_count, pattern_count)
361    }
362
363    /// Get all learned patterns for persistence.
364    pub fn get_learned_patterns(&self) -> Vec<LearnedPattern> {
365        self.learned_patterns.read().clone()
366    }
367
368    /// Load learned patterns from persistence.
369    pub fn load_learned_patterns(&self, patterns: Vec<LearnedPattern>) {
370        let mut existing = self.learned_patterns.write();
371        *existing = patterns;
372    }
373
374    /// Get trajectories filtered by verdict.
375    pub fn trajectories_by_verdict(&self, verdict: Verdict) -> Vec<Trajectory> {
376        self.trajectories
377            .read()
378            .iter()
379            .filter(|t| t.verdict == verdict)
380            .cloned()
381            .collect()
382    }
383
384    /// Persist learned patterns to SQLite.
385    ///
386    /// Saves all distilled patterns to the `patterns` table.
387    #[cfg(feature = "sqlite-memory")]
388    pub fn persist_to_sqlite(
389        &self,
390        store: &crate::memory::sqlite::store::SqliteMemoryStore,
391    ) -> anyhow::Result<()> {
392        let patterns = self.learned_patterns.read();
393        for pattern in patterns.iter() {
394            let data = serde_json::to_string(pattern)?;
395            store.save_pattern(
396                &pattern.id,
397                "sona",
398                Some(&pattern.domain),
399                pattern.confidence,
400                &data,
401            )?;
402        }
403        tracing::debug!(count = patterns.len(), "SONA patterns persisted to SQLite");
404        Ok(())
405    }
406
407    /// Restore learned patterns from SQLite.
408    ///
409    /// Loads all SONA patterns from the `patterns` table.
410    #[cfg(feature = "sqlite-memory")]
411    pub fn restore_from_sqlite(
412        &self,
413        store: &crate::memory::sqlite::store::SqliteMemoryStore,
414    ) -> anyhow::Result<()> {
415        let rows = store.load_patterns()?;
416        let sona_rows: Vec<_> = rows.into_iter().filter(|r| r.strategy == "sona").collect();
417
418        let mut patterns = Vec::new();
419        for row in &sona_rows {
420            if let Ok(pattern) = serde_json::from_str::<LearnedPattern>(&row.data) {
421                patterns.push(pattern);
422            }
423        }
424
425        *self.learned_patterns.write() = patterns;
426        tracing::debug!(
427            count = sona_rows.len(),
428            "SONA patterns restored from SQLite"
429        );
430        Ok(())
431    }
432}
433
434// ---------------------------------------------------------------------------
435// Tests
436// ---------------------------------------------------------------------------
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::embedding::TfIdfEmbeddingProvider;

    /// Engine in the given mode backed by the TF-IDF embedding provider.
    fn engine_with(mode: SonaMode) -> SonaEngine {
        SonaEngine::new(mode, std::sync::Arc::new(TfIdfEmbeddingProvider))
    }

    /// Step with fixed duration (10 ms) and confidence (0.9).
    fn make_step(input: &str, output: &str) -> TrajectoryStep {
        TrajectoryStep {
            input: input.to_owned(),
            output: output.to_owned(),
            duration_ms: 10,
            confidence: 0.9,
        }
    }

    /// Two-step trajectory in `domain` with the given verdict.
    fn make_trajectory(domain: &str, verdict: Verdict) -> Trajectory {
        let steps = vec![
            make_step("analyze input", "parsed"),
            make_step("execute plan", "completed"),
        ];
        Trajectory::new(steps, verdict, domain)
    }

    #[tokio::test]
    async fn test_record_trajectory() {
        let engine = engine_with(SonaMode::Balanced);

        let id = engine
            .record(make_trajectory("testing", Verdict::Success))
            .await
            .unwrap();
        assert!(!id.is_empty());

        let (traj_count, _) = engine.counts();
        assert_eq!(traj_count, 1);
    }

    #[tokio::test]
    async fn test_distill_patterns() {
        let engine = engine_with(SonaMode::Balanced);

        // Several successes in one domain give distillation something to work with.
        for _ in 0..3 {
            engine
                .record(make_trajectory("security", Verdict::Success))
                .await
                .unwrap();
        }

        let patterns = engine.distill().await.unwrap();
        assert!(
            !patterns.is_empty(),
            "Should distill patterns from 3+ successful trajectories"
        );

        let (_, pattern_count) = engine.counts();
        assert!(pattern_count > 0);
    }

    #[tokio::test]
    async fn test_distill_needs_multiple_successes() {
        let engine = engine_with(SonaMode::Balanced);

        let lone_success = make_trajectory("testing", Verdict::Success);
        engine.record(lone_success).await.unwrap();

        let patterns = engine.distill().await.unwrap();
        assert!(patterns.is_empty(), "Need 2+ trajectories to distill");
    }

    #[tokio::test]
    async fn test_distill_ignores_failures() {
        let engine = engine_with(SonaMode::Balanced);

        for _ in 0..2 {
            engine
                .record(make_trajectory("testing", Verdict::Failure))
                .await
                .unwrap();
        }

        let patterns = engine.distill().await.unwrap();
        assert!(patterns.is_empty(), "Failures should not produce patterns");
    }

    #[tokio::test]
    async fn test_adapt_finds_similar_pattern() {
        let engine = engine_with(SonaMode::Balanced);

        // Record three similar successes, then distill them into a pattern.
        for _ in 0..3 {
            let mut traj = make_trajectory("security", Verdict::Success);
            traj.steps[0].input =
                "scan for SQL injection vulnerabilities in the codebase".to_string();
            engine.record(traj).await.unwrap();
        }
        engine.distill().await.unwrap();

        // A related query should be matched to that pattern.
        let result = engine
            .adapt("check for SQL injection security issues")
            .await
            .unwrap();
        assert!(result.is_some(), "Should find a matching pattern");
        let pattern = result.unwrap();
        assert_eq!(pattern.domain, "security");
        assert!(pattern.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_adapt_no_match_below_threshold() {
        // Nothing has been learned, so no query can match.
        let engine = engine_with(SonaMode::Balanced);

        let result = engine
            .adapt("completely unrelated query about cooking")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_capacity_limit() {
        // Edge mode keeps at most 50 trajectories.
        let engine = engine_with(SonaMode::Edge);

        for i in 0..55 {
            let mut traj = make_trajectory("testing", Verdict::Success);
            traj.id = format!("traj-{i}");
            engine.record(traj).await.unwrap();
        }

        let (count, _) = engine.counts();
        assert!(count <= 50, "Should not exceed capacity: got {}", count);
    }

    #[test]
    fn test_trajectory_total_duration() {
        let steps = vec![make_step("a", "b"), make_step("c", "d")];
        let traj = Trajectory::new(steps, Verdict::Success, "testing");
        assert_eq!(traj.total_duration_ms(), 20);
    }

    #[test]
    fn test_trajectory_avg_confidence() {
        let steps = vec![
            TrajectoryStep {
                confidence: 0.8,
                ..make_step("a", "b")
            },
            TrajectoryStep {
                confidence: 0.6,
                ..make_step("c", "d")
            },
        ];
        let traj = Trajectory::new(steps, Verdict::Success, "testing");
        assert!((traj.avg_confidence() - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_sona_mode_default() {
        assert_eq!(SonaMode::default(), SonaMode::Balanced);
    }
}