Skip to main content

oxios_memory/memory/
sona.rs

1#![allow(missing_docs)]
2//! SONA — Self-Optimizing Neural Architecture (simplified).
3//!
4//! Tracks execution trajectories, distills successful patterns,
5//! and adapts future behavior based on learned experience.
6//!
7//! Performance target: adaptation in < 0.05ms (in-memory lookup).
8
9use std::collections::HashMap;
10
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use crate::memory::embedding::{EmbeddingProvider, EmbeddingVector};
17
18// ---------------------------------------------------------------------------
19// Data types
20// ---------------------------------------------------------------------------
21
22/// Operating mode for the SONA engine.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25#[derive(Default)]
26pub enum SonaMode {
27    /// Real-time adaptation (< 0.05ms target).
28    RealTime,
29    /// Balanced between speed and depth.
30    #[default]
31    Balanced,
32    /// Research mode — deep analysis, no time constraints.
33    Research,
34    /// Edge device — minimal memory footprint.
35    Edge,
36}
37
38/// Verdict for a trajectory outcome.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41pub enum Verdict {
42    /// Fully successful.
43    Success,
44    /// Partially successful (some steps failed).
45    PartialFailure,
46    /// Completely failed.
47    Failure,
48}
49
50/// A single step within a trajectory.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct TrajectoryStep {
53    /// Step description / input.
54    pub input: String,
55    /// Step result / output.
56    pub output: String,
57    /// Duration of this step in milliseconds.
58    pub duration_ms: u64,
59    /// Confidence score (0.0–1.0).
60    pub confidence: f32,
61}
62
63/// A trajectory — a sequence of steps with a final verdict.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Trajectory {
66    /// Unique ID.
67    pub id: String,
68    /// Ordered steps.
69    pub steps: Vec<TrajectoryStep>,
70    /// Final verdict.
71    pub verdict: Verdict,
72    /// Domain or task type.
73    pub domain: String,
74    /// Creation timestamp.
75    pub created_at: DateTime<Utc>,
76    /// Embedding of the trajectory's combined input text.
77    #[serde(skip)]
78    pub embedding: Option<EmbeddingVector>,
79}
80
81impl Trajectory {
82    /// Create a new trajectory with the given steps and verdict.
83    pub fn new(steps: Vec<TrajectoryStep>, verdict: Verdict, domain: &str) -> Self {
84        Self {
85            id: Uuid::new_v4().to_string(),
86            steps,
87            verdict,
88            domain: domain.to_string(),
89            created_at: Utc::now(),
90            embedding: None,
91        }
92    }
93
94    /// Total duration across all steps.
95    pub fn total_duration_ms(&self) -> u64 {
96        self.steps.iter().map(|s| s.duration_ms).sum()
97    }
98
99    /// Average confidence across all steps.
100    pub fn avg_confidence(&self) -> f32 {
101        if self.steps.is_empty() {
102            return 0.0;
103        }
104        self.steps.iter().map(|s| s.confidence).sum::<f32>() / self.steps.len() as f32
105    }
106
107    /// Concatenated input text for embedding.
108    pub fn input_text(&self) -> String {
109        self.steps
110            .iter()
111            .map(|s| s.input.as_str())
112            .collect::<Vec<_>>()
113            .join(" ")
114    }
115}
116
117/// A distilled pattern extracted from successful trajectories.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct LearnedPattern {
120    /// Unique ID.
121    pub id: String,
122    /// Source trajectory IDs.
123    pub source_trajectories: Vec<String>,
124    /// Distilled strategy description.
125    pub strategy: String,
126    /// Domain category.
127    pub domain: String,
128    /// Confidence in this pattern (based on number of supporting trajectories).
129    pub confidence: f32,
130    /// Number of trajectories this was distilled from.
131    pub support_count: usize,
132    /// Embedding for similarity matching.
133    #[serde(skip)]
134    pub embedding: Option<EmbeddingVector>,
135}
136
137// ---------------------------------------------------------------------------
138// SonaEngine
139// ---------------------------------------------------------------------------
140
141/// The SONA engine for trajectory-based self-learning.
142///
143/// Records execution trajectories, distills patterns from successful ones,
144/// and adapts future behavior by matching against learned patterns.
145pub struct SonaEngine {
146    /// Operating mode.
147    mode: SonaMode,
148    /// Recorded trajectories.
149    trajectories: RwLock<Vec<Trajectory>>,
150    /// Distilled patterns from successful trajectories.
151    learned_patterns: RwLock<Vec<LearnedPattern>>,
152    /// Embedding provider.
153    embedding: std::sync::Arc<dyn EmbeddingProvider>,
154    /// Maximum trajectories to keep (mode-dependent).
155    max_trajectories: usize,
156}
157
158impl std::fmt::Debug for SonaEngine {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("SonaEngine")
161            .field("mode", &self.mode)
162            .field("trajectory_count", &self.trajectories.read().len())
163            .field("pattern_count", &self.learned_patterns.read().len())
164            .finish()
165    }
166}
167
168impl SonaEngine {
169    /// Create a new SONA engine with the given mode and embedding provider.
170    pub fn new(mode: SonaMode, embedding: std::sync::Arc<dyn EmbeddingProvider>) -> Self {
171        let max_trajectories = match mode {
172            SonaMode::RealTime => 100,
173            SonaMode::Balanced => 500,
174            SonaMode::Research => 5000,
175            SonaMode::Edge => 50,
176        };
177
178        Self {
179            mode,
180            trajectories: RwLock::new(Vec::new()),
181            learned_patterns: RwLock::new(Vec::new()),
182            embedding,
183            max_trajectories,
184        }
185    }
186
187    /// Return the current operating mode.
188    pub fn mode(&self) -> SonaMode {
189        self.mode
190    }
191
192    /// Record a new trajectory.
193    ///
194    /// Generates an embedding for the trajectory's input text
195    /// and stores it for future distillation.
196    pub async fn record(&self, mut trajectory: Trajectory) -> Result<String, anyhow::Error> {
197        if trajectory.id.is_empty() {
198            trajectory.id = Uuid::new_v4().to_string();
199        }
200
201        // Generate embedding
202        let text = trajectory.input_text();
203        if !text.is_empty() {
204            let embedding = self.embedding.embed(&text).await?;
205            trajectory.embedding = Some(embedding);
206        }
207
208        let id = trajectory.id.clone();
209
210        let mut trajs = self.trajectories.write();
211
212        // Enforce capacity limit
213        if trajs.len() >= self.max_trajectories {
214            // Remove oldest failure trajectories first
215            let remove_count = trajs.len() - self.max_trajectories + 1;
216            let mut removed = 0;
217            trajs.retain(|t| {
218                if removed >= remove_count {
219                    return true;
220                }
221                if t.verdict == Verdict::Failure {
222                    removed += 1;
223                    false
224                } else {
225                    true
226                }
227            });
228            // If still over capacity, remove oldest
229            while trajs.len() >= self.max_trajectories {
230                trajs.remove(0);
231            }
232        }
233
234        trajs.push(trajectory);
235        Ok(id)
236    }
237
238    /// Distill learned patterns from successful trajectories.
239    ///
240    /// Groups successful trajectories by domain and extracts common
241    /// patterns. Returns the newly distilled patterns.
242    pub async fn distill(&self) -> Result<Vec<LearnedPattern>, anyhow::Error> {
243        // Collect data under lock, then release before any .await
244        let domain_groups: HashMap<String, Vec<Trajectory>> = {
245            let trajs = self.trajectories.read();
246            let mut groups: HashMap<String, Vec<Trajectory>> = HashMap::new();
247            for traj in trajs.iter() {
248                if traj.verdict == Verdict::Success {
249                    groups
250                        .entry(traj.domain.clone())
251                        .or_default()
252                        .push(traj.clone());
253                }
254            }
255            groups
256        }; // lock dropped here
257
258        let mut new_patterns = Vec::new();
259
260        for (domain, group) in &domain_groups {
261            if group.len() < 2 {
262                continue; // Need at least 2 trajectories to distill
263            }
264
265            // Simple distillation: extract common step patterns
266            // For each trajectory, get the strategy from concatenated inputs
267            let mut strategy_parts: Vec<String> = Vec::new();
268            for traj in group {
269                let summary: String = traj
270                    .steps
271                    .iter()
272                    .take(3) // Use first 3 steps as summary
273                    .map(|s| s.input.clone())
274                    .collect::<Vec<_>>()
275                    .join(" → ");
276                strategy_parts.push(summary);
277            }
278
279            // Combine strategies into a distilled pattern
280            let combined = strategy_parts.join("; ");
281            let strategy = if combined.len() > 500 {
282                format!("{}...", &combined[..500])
283            } else {
284                combined
285            };
286
287            let embedding = self.embedding.embed(&strategy).await?;
288
289            let source_ids: Vec<String> = group.iter().map(|t| t.id.clone()).collect();
290
291            let pattern = LearnedPattern {
292                id: Uuid::new_v4().to_string(),
293                source_trajectories: source_ids,
294                strategy,
295                domain: domain.clone(),
296                confidence: (group.len() as f32 * 0.2).min(1.0),
297                support_count: group.len(),
298                embedding: Some(embedding),
299            };
300
301            new_patterns.push(pattern);
302        }
303
304        // Store new patterns
305        {
306            let mut patterns = self.learned_patterns.write();
307            for pattern in &new_patterns {
308                // Don't duplicate — check by strategy similarity (simplified: exact match)
309                let is_dup = patterns
310                    .iter()
311                    .any(|p| p.strategy == pattern.strategy && p.domain == pattern.domain);
312                if !is_dup {
313                    patterns.push(pattern.clone());
314                }
315            }
316        }
317
318        tracing::info!(
319            new_patterns = new_patterns.len(),
320            "SONA distillation complete"
321        );
322        Ok(new_patterns)
323    }
324
325    /// Adapt to a new query by finding the most similar learned pattern.
326    ///
327    /// Returns the best matching pattern if similarity exceeds threshold.
328    /// Target: < 0.05ms for in-memory lookup.
329    pub async fn adapt(&self, query: &str) -> Result<Option<LearnedPattern>, anyhow::Error> {
330        let query_embedding = self.embedding.embed(query).await?;
331
332        let patterns = self.learned_patterns.read();
333        let mut best: Option<(&LearnedPattern, f64)> = None;
334
335        for pattern in patterns.iter() {
336            if let Some(ref emb) = pattern.embedding {
337                let sim = query_embedding.cosine_similarity(emb);
338                match best {
339                    Some((_, best_sim)) if sim <= best_sim => {}
340                    _ => best = Some((pattern, sim)),
341                }
342            }
343        }
344
345        Ok(best.filter(|(_, sim)| *sim > 0.3).map(|(p, sim)| {
346            let mut adapted = p.clone();
347            adapted.confidence = (p.confidence * sim as f32).min(1.0);
348            adapted
349        }))
350    }
351
352    /// Return counts of trajectories and patterns.
353    pub fn counts(&self) -> (usize, usize) {
354        let traj_count = self.trajectories.read().len();
355        let pattern_count = self.learned_patterns.read().len();
356        (traj_count, pattern_count)
357    }
358
359    /// Get all learned patterns for persistence.
360    pub fn get_learned_patterns(&self) -> Vec<LearnedPattern> {
361        self.learned_patterns.read().clone()
362    }
363
364    /// Load learned patterns from persistence.
365    pub fn load_learned_patterns(&self, patterns: Vec<LearnedPattern>) {
366        let mut existing = self.learned_patterns.write();
367        *existing = patterns;
368    }
369
370    /// Get trajectories filtered by verdict.
371    pub fn trajectories_by_verdict(&self, verdict: Verdict) -> Vec<Trajectory> {
372        self.trajectories
373            .read()
374            .iter()
375            .filter(|t| t.verdict == verdict)
376            .cloned()
377            .collect()
378    }
379
380    /// Persist learned patterns to SQLite.
381    ///
382    /// Saves all distilled patterns to the `patterns` table.
383    #[cfg(feature = "sqlite-memory")]
384    pub fn persist_to_sqlite(
385        &self,
386        store: &crate::memory::sqlite::store::SqliteMemoryStore,
387    ) -> anyhow::Result<()> {
388        let patterns = self.learned_patterns.read();
389        for pattern in patterns.iter() {
390            let data = serde_json::to_string(pattern)?;
391            store.save_pattern(
392                &pattern.id,
393                "sona",
394                Some(&pattern.domain),
395                pattern.confidence,
396                &data,
397            )?;
398        }
399        tracing::debug!(count = patterns.len(), "SONA patterns persisted to SQLite");
400        Ok(())
401    }
402
403    /// Restore learned patterns from SQLite.
404    ///
405    /// Loads all SONA patterns from the `patterns` table.
406    #[cfg(feature = "sqlite-memory")]
407    pub fn restore_from_sqlite(
408        &self,
409        store: &crate::memory::sqlite::store::SqliteMemoryStore,
410    ) -> anyhow::Result<()> {
411        let rows = store.load_patterns()?;
412        let sona_rows: Vec<_> = rows.into_iter().filter(|r| r.strategy == "sona").collect();
413
414        let mut patterns = Vec::new();
415        for row in &sona_rows {
416            if let Ok(pattern) = serde_json::from_str::<LearnedPattern>(&row.data) {
417                patterns.push(pattern);
418            }
419        }
420
421        *self.learned_patterns.write() = patterns;
422        tracing::debug!(
423            count = sona_rows.len(),
424            "SONA patterns restored from SQLite"
425        );
426        Ok(())
427    }
428}
429
430// ---------------------------------------------------------------------------
431// Tests
432// ---------------------------------------------------------------------------
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::embedding::TfIdfEmbeddingProvider;

    /// Engine in `mode` backed by the TF-IDF embedding provider.
    fn engine_in(mode: SonaMode) -> SonaEngine {
        SonaEngine::new(mode, std::sync::Arc::new(TfIdfEmbeddingProvider))
    }

    /// Step with a fixed 10 ms duration and 0.9 confidence.
    fn make_step(input: &str, output: &str) -> TrajectoryStep {
        TrajectoryStep {
            input: input.to_owned(),
            output: output.to_owned(),
            duration_ms: 10,
            confidence: 0.9,
        }
    }

    /// Two-step ("analyze input", "execute plan") trajectory in `domain`.
    fn make_trajectory(domain: &str, verdict: Verdict) -> Trajectory {
        let steps = vec![
            make_step("analyze input", "parsed"),
            make_step("execute plan", "completed"),
        ];
        Trajectory::new(steps, verdict, domain)
    }

    #[tokio::test]
    async fn test_record_trajectory() {
        let engine = engine_in(SonaMode::Balanced);

        let id = engine
            .record(make_trajectory("testing", Verdict::Success))
            .await
            .unwrap();
        assert!(!id.is_empty());
        assert_eq!(engine.counts().0, 1);
    }

    #[tokio::test]
    async fn test_distill_patterns() {
        let engine = engine_in(SonaMode::Balanced);

        // Three successes in one domain clear the two-trajectory minimum.
        for _ in 0..3 {
            engine
                .record(make_trajectory("security", Verdict::Success))
                .await
                .unwrap();
        }

        let distilled = engine.distill().await.unwrap();
        assert!(
            !distilled.is_empty(),
            "Should distill patterns from 3+ successful trajectories"
        );
        assert!(engine.counts().1 > 0);
    }

    #[tokio::test]
    async fn test_distill_needs_multiple_successes() {
        let engine = engine_in(SonaMode::Balanced);
        engine
            .record(make_trajectory("testing", Verdict::Success))
            .await
            .unwrap();

        let distilled = engine.distill().await.unwrap();
        assert!(distilled.is_empty(), "Need 2+ trajectories to distill");
    }

    #[tokio::test]
    async fn test_distill_ignores_failures() {
        let engine = engine_in(SonaMode::Balanced);
        for _ in 0..2 {
            engine
                .record(make_trajectory("testing", Verdict::Failure))
                .await
                .unwrap();
        }

        let distilled = engine.distill().await.unwrap();
        assert!(distilled.is_empty(), "Failures should not produce patterns");
    }

    #[tokio::test]
    async fn test_adapt_finds_similar_pattern() {
        let engine = engine_in(SonaMode::Balanced);

        for _ in 0..3 {
            let mut traj = make_trajectory("security", Verdict::Success);
            traj.steps[0].input =
                String::from("scan for SQL injection vulnerabilities in the codebase");
            engine.record(traj).await.unwrap();
        }
        engine.distill().await.unwrap();

        let matched = engine
            .adapt("check for SQL injection security issues")
            .await
            .unwrap();
        assert!(matched.is_some(), "Should find a matching pattern");
        let pattern = matched.unwrap();
        assert_eq!(pattern.domain, "security");
        assert!(pattern.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_adapt_no_match_below_threshold() {
        let engine = engine_in(SonaMode::Balanced);

        // Nothing has been distilled, so there is nothing to match.
        let matched = engine
            .adapt("completely unrelated query about cooking")
            .await
            .unwrap();
        assert!(matched.is_none());
    }

    #[tokio::test]
    async fn test_capacity_limit() {
        // Edge mode caps the store at 50 trajectories.
        let engine = engine_in(SonaMode::Edge);

        for i in 0..55 {
            let mut traj = make_trajectory("testing", Verdict::Success);
            traj.id = format!("traj-{i}");
            engine.record(traj).await.unwrap();
        }

        let (count, _) = engine.counts();
        assert!(count <= 50, "Should not exceed capacity: got {}", count);
    }

    #[test]
    fn test_trajectory_total_duration() {
        let steps = vec![make_step("a", "b"), make_step("c", "d")];
        let traj = Trajectory::new(steps, Verdict::Success, "testing");
        assert_eq!(traj.total_duration_ms(), 20);
    }

    #[test]
    fn test_trajectory_avg_confidence() {
        let step = |input: &str, output: &str, confidence: f32| TrajectoryStep {
            input: input.into(),
            output: output.into(),
            duration_ms: 10,
            confidence,
        };
        let traj = Trajectory::new(
            vec![step("a", "b", 0.8), step("c", "d", 0.6)],
            Verdict::Success,
            "testing",
        );
        assert!((traj.avg_confidence() - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_sona_mode_default() {
        assert_eq!(SonaMode::default(), SonaMode::Balanced);
    }
}