Skip to main content

brainwires_knowledge/knowledge/bks_pks/
cache.rs

1//! Local cache for behavioral truths with SQLite persistence
2//!
3//! Maintains a local copy of truths synced from the server, with offline
4//! queue support for when the server is unavailable.
5
6use super::truth::{BehavioralTruth, PendingTruthSubmission, TruthCategory, TruthFeedback};
7use anyhow::Result;
8use rusqlite::{Connection, OptionalExtension, params};
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::{Arc, Mutex};
12
/// Local cache of behavioral truths synced from server
///
/// Holds an in-memory map of truths plus two offline queues (truth submissions
/// and feedback), all of which are mirrored into a SQLite database through
/// `conn`.
pub struct BehavioralKnowledgeCache {
    /// SQLite connection
    ///
    /// Every method that touches the database locks this mutex first and
    /// panics if the lock is poisoned.
    conn: Arc<Mutex<Connection>>,

    /// In-memory cache for fast access
    ///
    /// Keyed by truth id. Soft-deleted truths stay in the map with their
    /// `deleted` flag set; most read accessors filter them out.
    truths: HashMap<String, BehavioralTruth>,

    /// Timestamp of last successful sync with server
    ///
    /// Starts at 0 and is restored from the `sync_state` table when loading
    /// from an existing database.
    pub last_sync: i64,

    /// Queue of truths waiting to be submitted to server
    pending_submissions: Vec<PendingTruthSubmission>,

    /// Queue of feedback waiting to be sent to server
    pending_feedback: Vec<TruthFeedback>,

    /// Maximum size of offline queue
    ///
    /// The limit is checked against each queue separately.
    max_queue_size: usize,
}
33
34impl BehavioralKnowledgeCache {
35    /// Create a new cache with SQLite persistence
36    pub fn new<P: AsRef<Path>>(db_path: P, max_queue_size: usize) -> Result<Self> {
37        let conn = Connection::open(db_path)?;
38        Self::init_schema(&conn)?;
39
40        let mut cache = Self {
41            conn: Arc::new(Mutex::new(conn)),
42            truths: HashMap::new(),
43            last_sync: 0,
44            pending_submissions: Vec::new(),
45            pending_feedback: Vec::new(),
46            max_queue_size,
47        };
48
49        // Load existing data from database
50        cache.load_from_db()?;
51
52        Ok(cache)
53    }
54
55    /// Create an in-memory cache (for testing)
56    pub fn in_memory(max_queue_size: usize) -> Result<Self> {
57        let conn = Connection::open_in_memory()?;
58        Self::init_schema(&conn)?;
59
60        Ok(Self {
61            conn: Arc::new(Mutex::new(conn)),
62            truths: HashMap::new(),
63            last_sync: 0,
64            pending_submissions: Vec::new(),
65            pending_feedback: Vec::new(),
66            max_queue_size,
67        })
68    }
69
70    /// Initialize database schema
71    fn init_schema(conn: &Connection) -> Result<()> {
72        conn.execute_batch(
73            r#"
74            CREATE TABLE IF NOT EXISTS truths (
75                id TEXT PRIMARY KEY,
76                category TEXT NOT NULL,
77                context_pattern TEXT NOT NULL,
78                rule TEXT NOT NULL,
79                rationale TEXT NOT NULL,
80                confidence REAL NOT NULL,
81                reinforcements INTEGER NOT NULL DEFAULT 0,
82                contradictions INTEGER NOT NULL DEFAULT 0,
83                last_used INTEGER NOT NULL,
84                created_at INTEGER NOT NULL,
85                created_by TEXT,
86                source TEXT NOT NULL,
87                version INTEGER NOT NULL DEFAULT 1,
88                deleted INTEGER NOT NULL DEFAULT 0
89            );
90
91            CREATE INDEX IF NOT EXISTS idx_truths_context ON truths(context_pattern);
92            CREATE INDEX IF NOT EXISTS idx_truths_category ON truths(category);
93            CREATE INDEX IF NOT EXISTS idx_truths_confidence ON truths(confidence);
94
95            CREATE TABLE IF NOT EXISTS pending_submissions (
96                id INTEGER PRIMARY KEY AUTOINCREMENT,
97                truth_json TEXT NOT NULL,
98                queued_at INTEGER NOT NULL,
99                attempts INTEGER NOT NULL DEFAULT 0,
100                last_error TEXT
101            );
102
103            CREATE TABLE IF NOT EXISTS pending_feedback (
104                id INTEGER PRIMARY KEY AUTOINCREMENT,
105                truth_id TEXT NOT NULL,
106                is_reinforcement INTEGER NOT NULL,
107                context TEXT,
108                timestamp INTEGER NOT NULL
109            );
110
111            CREATE TABLE IF NOT EXISTS sync_state (
112                key TEXT PRIMARY KEY,
113                value TEXT NOT NULL
114            );
115            "#,
116        )?;
117
118        Ok(())
119    }
120
121    /// Load truths and state from database
122    fn load_from_db(&mut self) -> Result<()> {
123        let conn = self
124            .conn
125            .lock()
126            .expect("knowledge cache connection lock poisoned");
127
128        // Load last sync timestamp
129        self.last_sync = conn
130            .query_row(
131                "SELECT value FROM sync_state WHERE key = 'last_sync'",
132                [],
133                |row| row.get::<_, String>(0),
134            )
135            .optional()?
136            .and_then(|s| s.parse().ok())
137            .unwrap_or(0);
138
139        // Load truths
140        let mut stmt = conn.prepare(
141            "SELECT id, category, context_pattern, rule, rationale, confidence,
142                    reinforcements, contradictions, last_used, created_at,
143                    created_by, source, version, deleted
144             FROM truths WHERE deleted = 0",
145        )?;
146
147        let truths = stmt.query_map([], |row| {
148            Ok(BehavioralTruth {
149                id: row.get(0)?,
150                category: serde_json::from_str(&format!("\"{}\"", row.get::<_, String>(1)?))
151                    .unwrap_or(TruthCategory::CommandUsage),
152                context_pattern: row.get(2)?,
153                rule: row.get(3)?,
154                rationale: row.get(4)?,
155                confidence: row.get(5)?,
156                reinforcements: row.get(6)?,
157                contradictions: row.get(7)?,
158                last_used: row.get(8)?,
159                created_at: row.get(9)?,
160                created_by: row.get(10)?,
161                source: serde_json::from_str(&format!("\"{}\"", row.get::<_, String>(11)?))
162                    .unwrap_or(super::truth::TruthSource::ExplicitCommand),
163                version: row.get::<_, i64>(12)? as u64,
164                deleted: row.get::<_, i32>(13)? != 0,
165            })
166        })?;
167
168        for truth in truths {
169            let truth = truth?;
170            self.truths.insert(truth.id.clone(), truth);
171        }
172
173        // Load pending submissions
174        let mut stmt = conn.prepare(
175            "SELECT truth_json, queued_at, attempts, last_error FROM pending_submissions",
176        )?;
177
178        let submissions = stmt.query_map([], |row| {
179            let json: String = row.get(0)?;
180            let truth: BehavioralTruth = serde_json::from_str(&json).map_err(|e| {
181                rusqlite::Error::FromSqlConversionFailure(
182                    0,
183                    rusqlite::types::Type::Text,
184                    Box::new(e),
185                )
186            })?;
187            Ok(PendingTruthSubmission {
188                truth,
189                queued_at: row.get(1)?,
190                attempts: row.get(2)?,
191                last_error: row.get(3)?,
192            })
193        })?;
194
195        for submission in submissions {
196            self.pending_submissions.push(submission?);
197        }
198
199        // Load pending feedback
200        let mut stmt = conn.prepare(
201            "SELECT truth_id, is_reinforcement, context, timestamp FROM pending_feedback",
202        )?;
203
204        let feedback = stmt.query_map([], |row| {
205            Ok(TruthFeedback {
206                truth_id: row.get(0)?,
207                is_reinforcement: row.get::<_, i32>(1)? != 0,
208                context: row.get(2)?,
209                timestamp: row.get(3)?,
210            })
211        })?;
212
213        for fb in feedback {
214            self.pending_feedback.push(fb?);
215        }
216
217        Ok(())
218    }
219
220    /// Save a truth to the database
221    fn save_truth_to_db(&self, truth: &BehavioralTruth) -> Result<()> {
222        let conn = self
223            .conn
224            .lock()
225            .expect("knowledge cache connection lock poisoned");
226        let category = serde_json::to_string(&truth.category)?
227            .trim_matches('"')
228            .to_string();
229        let source = serde_json::to_string(&truth.source)?
230            .trim_matches('"')
231            .to_string();
232
233        conn.execute(
234            r#"INSERT OR REPLACE INTO truths
235               (id, category, context_pattern, rule, rationale, confidence,
236                reinforcements, contradictions, last_used, created_at,
237                created_by, source, version, deleted)
238               VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)"#,
239            params![
240                truth.id,
241                category,
242                truth.context_pattern,
243                truth.rule,
244                truth.rationale,
245                truth.confidence,
246                truth.reinforcements,
247                truth.contradictions,
248                truth.last_used,
249                truth.created_at,
250                truth.created_by,
251                source,
252                truth.version as i64,
253                truth.deleted as i32,
254            ],
255        )?;
256
257        Ok(())
258    }
259
260    /// Update last sync timestamp
261    pub fn set_last_sync(&mut self, timestamp: i64) -> Result<()> {
262        self.last_sync = timestamp;
263        let conn = self
264            .conn
265            .lock()
266            .expect("knowledge cache connection lock poisoned");
267        conn.execute(
268            "INSERT OR REPLACE INTO sync_state (key, value) VALUES ('last_sync', ?1)",
269            params![timestamp.to_string()],
270        )?;
271        Ok(())
272    }
273
274    /// Add a new truth to the cache
275    pub fn add_truth(&mut self, truth: BehavioralTruth) -> Result<()> {
276        self.save_truth_to_db(&truth)?;
277        self.truths.insert(truth.id.clone(), truth);
278        Ok(())
279    }
280
281    /// Update an existing truth
282    pub fn update_truth(&mut self, truth: BehavioralTruth) -> Result<()> {
283        self.save_truth_to_db(&truth)?;
284        self.truths.insert(truth.id.clone(), truth);
285        Ok(())
286    }
287
288    /// Get a truth by ID
289    pub fn get_truth(&self, id: &str) -> Option<&BehavioralTruth> {
290        self.truths.get(id)
291    }
292
293    /// Get a mutable reference to a truth by ID
294    pub fn get_truth_mut(&mut self, id: &str) -> Option<&mut BehavioralTruth> {
295        self.truths.get_mut(id)
296    }
297
298    /// Remove a truth (soft delete)
299    pub fn remove_truth(&mut self, id: &str) -> Result<bool> {
300        if let Some(truth) = self.truths.get_mut(id) {
301            truth.delete();
302        } else {
303            return Ok(false);
304        }
305
306        // Save after releasing the mutable borrow
307        if let Some(truth) = self.truths.get(id) {
308            self.save_truth_to_db(truth)?;
309        }
310        Ok(true)
311    }
312
313    /// Get all active truths
314    pub fn all_truths(&self) -> impl Iterator<Item = &BehavioralTruth> {
315        self.truths.values().filter(|t| !t.deleted)
316    }
317
318    /// Get truths by category
319    pub fn truths_by_category(&self, category: TruthCategory) -> Vec<&BehavioralTruth> {
320        self.truths
321            .values()
322            .filter(|t| !t.deleted && t.category == category)
323            .collect()
324    }
325
326    /// Get truths matching a context pattern (simple substring match)
327    pub fn get_matching_truths(&self, context: &str) -> Vec<&BehavioralTruth> {
328        let context_lower = context.to_lowercase();
329        self.truths
330            .values()
331            .filter(|t| {
332                !t.deleted
333                    && t.context_pattern
334                        .to_lowercase()
335                        .split_whitespace()
336                        .any(|word| context_lower.contains(word))
337            })
338            .collect()
339    }
340
341    /// Get truths matching a context pattern with relevance scores
342    /// Returns (truth, score) tuples sorted by relevance
343    pub fn get_matching_truths_with_scores(
344        &self,
345        context: &str,
346        min_confidence: f32,
347        limit: usize,
348    ) -> Result<Vec<(&BehavioralTruth, f32)>> {
349        let context_lower = context.to_lowercase();
350        let context_words: Vec<&str> = context_lower.split_whitespace().collect();
351
352        let mut matches: Vec<(&BehavioralTruth, f32)> = self
353            .truths
354            .values()
355            .filter(|t| !t.deleted && t.confidence >= min_confidence)
356            .filter_map(|truth| {
357                // Calculate relevance score based on word overlap
358                let pattern_lower = truth.context_pattern.to_lowercase();
359                let pattern_words: Vec<&str> = pattern_lower.split_whitespace().collect();
360
361                let mut score = 0.0f32;
362                for pattern_word in &pattern_words {
363                    for context_word in &context_words {
364                        if context_word.contains(pattern_word)
365                            || pattern_word.contains(context_word)
366                        {
367                            score += 1.0;
368                        }
369                    }
370                }
371
372                if score > 0.0 {
373                    // Normalize by pattern length and boost by confidence
374                    let normalized_score = (score / pattern_words.len() as f32) * truth.confidence;
375                    Some((truth, normalized_score))
376                } else {
377                    None
378                }
379            })
380            .collect();
381
382        // Sort by score descending
383        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
384
385        // Take top N
386        matches.truncate(limit);
387
388        Ok(matches)
389    }
390
391    /// Get truths above a confidence threshold
392    pub fn get_reliable_truths(
393        &self,
394        min_confidence: f32,
395        decay_days: u32,
396    ) -> Vec<&BehavioralTruth> {
397        self.truths
398            .values()
399            .filter(|t| !t.deleted && t.is_reliable(min_confidence, decay_days))
400            .collect()
401    }
402
403    /// Queue a truth for submission to server
404    pub fn queue_submission(&mut self, truth: BehavioralTruth) -> Result<bool> {
405        if self.pending_submissions.len() >= self.max_queue_size {
406            return Ok(false);
407        }
408
409        let submission = PendingTruthSubmission::new(truth);
410        let json = serde_json::to_string(&submission.truth)?;
411
412        let conn = self
413            .conn
414            .lock()
415            .expect("knowledge cache connection lock poisoned");
416        conn.execute(
417            "INSERT INTO pending_submissions (truth_json, queued_at, attempts) VALUES (?1, ?2, ?3)",
418            params![json, submission.queued_at, submission.attempts],
419        )?;
420
421        self.pending_submissions.push(submission);
422        Ok(true)
423    }
424
425    /// Get pending submissions
426    pub fn pending_submissions(&self) -> &[PendingTruthSubmission] {
427        &self.pending_submissions
428    }
429
430    /// Clear all pending submissions (after successful sync)
431    pub fn clear_pending_submissions(&mut self) -> Result<()> {
432        self.pending_submissions.clear();
433        let conn = self
434            .conn
435            .lock()
436            .expect("knowledge cache connection lock poisoned");
437        conn.execute("DELETE FROM pending_submissions", [])?;
438        Ok(())
439    }
440
441    /// Queue feedback for sending to server
442    pub fn queue_feedback(&mut self, feedback: TruthFeedback) -> Result<bool> {
443        if self.pending_feedback.len() >= self.max_queue_size {
444            return Ok(false);
445        }
446
447        let conn = self
448            .conn
449            .lock()
450            .expect("knowledge cache connection lock poisoned");
451        conn.execute(
452            "INSERT INTO pending_feedback (truth_id, is_reinforcement, context, timestamp)
453             VALUES (?1, ?2, ?3, ?4)",
454            params![
455                feedback.truth_id,
456                feedback.is_reinforcement as i32,
457                feedback.context,
458                feedback.timestamp,
459            ],
460        )?;
461
462        self.pending_feedback.push(feedback);
463        Ok(true)
464    }
465
466    /// Get pending feedback
467    pub fn pending_feedback(&self) -> &[TruthFeedback] {
468        &self.pending_feedback
469    }
470
471    /// Clear all pending feedback (after successful sync)
472    pub fn clear_pending_feedback(&mut self) -> Result<()> {
473        self.pending_feedback.clear();
474        let conn = self
475            .conn
476            .lock()
477            .expect("knowledge cache connection lock poisoned");
478        conn.execute("DELETE FROM pending_feedback", [])?;
479        Ok(())
480    }
481
482    /// Merge truths from server (handles version conflicts)
483    pub fn merge_from_server(
484        &mut self,
485        server_truths: Vec<BehavioralTruth>,
486    ) -> Result<MergeResult> {
487        let mut added = 0;
488        let mut updated = 0;
489        let mut conflicts = 0;
490
491        for server_truth in server_truths {
492            if let Some(local_truth) = self.truths.get(&server_truth.id) {
493                // Check for version conflict
494                if server_truth.version > local_truth.version {
495                    // Server wins - update local
496                    self.save_truth_to_db(&server_truth)?;
497                    self.truths.insert(server_truth.id.clone(), server_truth);
498                    updated += 1;
499                } else if server_truth.version < local_truth.version {
500                    // Local is newer - conflict (should be rare)
501                    conflicts += 1;
502                }
503                // Equal versions - no action needed
504            } else {
505                // New truth from server
506                self.save_truth_to_db(&server_truth)?;
507                self.truths.insert(server_truth.id.clone(), server_truth);
508                added += 1;
509            }
510        }
511
512        Ok(MergeResult {
513            added,
514            updated,
515            conflicts,
516        })
517    }
518
519    /// Apply decay to all truths
520    pub fn apply_decay(&mut self, decay_start_days: u32) -> Result<u32> {
521        let mut decayed = 0;
522
523        for truth in self.truths.values_mut() {
524            let old_confidence = truth.confidence;
525            truth.apply_decay(decay_start_days);
526            if (truth.confidence - old_confidence).abs() > 0.001 {
527                decayed += 1;
528            }
529        }
530
531        // Save decayed truths to database
532        if decayed > 0 {
533            for truth in self.truths.values() {
534                self.save_truth_to_db(truth)?;
535            }
536        }
537
538        Ok(decayed)
539    }
540
541    /// Get statistics about the cache
542    pub fn stats(&self) -> CacheStats {
543        let mut by_category: HashMap<TruthCategory, u32> = HashMap::new();
544        let mut total_confidence = 0.0f32;
545        let mut count = 0u32;
546
547        for truth in self.truths.values().filter(|t| !t.deleted) {
548            *by_category.entry(truth.category).or_insert(0) += 1;
549            total_confidence += truth.confidence;
550            count += 1;
551        }
552
553        CacheStats {
554            total_truths: count,
555            by_category,
556            avg_confidence: if count > 0 {
557                total_confidence / count as f32
558            } else {
559                0.0
560            },
561            pending_submissions: self.pending_submissions.len(),
562            pending_feedback: self.pending_feedback.len(),
563            last_sync: self.last_sync,
564        }
565    }
566}
567
/// Result of merging truths from server
///
/// Derives `Default`, `PartialEq` and `Eq` so that results can be built
/// empty and compared directly in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MergeResult {
    /// Number of new truths added.
    pub added: u32,
    /// Number of existing truths updated.
    pub updated: u32,
    /// Number of merge conflicts (the local version was newer than the server's).
    pub conflicts: u32,
}
578
579/// Statistics about the cache
580#[derive(Debug, Clone)]
581pub struct CacheStats {
582    /// Total number of cached truths.
583    pub total_truths: u32,
584    /// Counts by category.
585    pub by_category: HashMap<TruthCategory, u32>,
586    /// Average confidence score.
587    pub avg_confidence: f32,
588    /// Number of pending truth submissions.
589    pub pending_submissions: usize,
590    /// Number of pending feedback reports.
591    pub pending_feedback: usize,
592    /// Unix timestamp of last sync.
593    pub last_sync: i64,
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::bks_pks::truth::TruthSource;

    /// Build a `CommandUsage` truth with the given context pattern and rule.
    fn create_test_truth(context: &str, rule: &str) -> BehavioralTruth {
        BehavioralTruth::new(
            TruthCategory::CommandUsage,
            context.to_string(),
            rule.to_string(),
            "Test rationale".to_string(),
            TruthSource::ExplicitCommand,
            None,
        )
    }

    /// Build an empty in-memory cache with the given queue limit.
    fn empty_cache(max_queue_size: usize) -> BehavioralKnowledgeCache {
        BehavioralKnowledgeCache::in_memory(max_queue_size).unwrap()
    }

    /// A fresh cache has never synced and holds no truths.
    #[test]
    fn test_cache_creation() {
        let cache = empty_cache(100);

        assert_eq!(cache.last_sync, 0);
        assert_eq!(cache.all_truths().count(), 0);
    }

    /// A truth that was added can be fetched again by its id.
    #[test]
    fn test_add_and_get_truth() {
        let mut cache = empty_cache(100);
        let truth = create_test_truth("pm2 logs", "Use --nostream");
        let id = truth.id.clone();

        cache.add_truth(truth).unwrap();

        let stored = cache.get_truth(&id).unwrap();
        assert_eq!(stored.rule, "Use --nostream");
    }

    /// Only the truth whose pattern words occur in the context is returned.
    #[test]
    fn test_matching_truths() {
        let mut cache = empty_cache(100);
        for (context, rule) in [
            ("pm2 logs", "Use --nostream"),
            ("cargo build", "Use cargo-watch"),
        ] {
            cache.add_truth(create_test_truth(context, rule)).unwrap();
        }

        let matches = cache.get_matching_truths("pm2 logs myapp");

        assert_eq!(matches.len(), 1);
        assert!(matches[0].rule.contains("--nostream"));
    }

    /// Truths are bucketed by their category.
    #[test]
    fn test_truths_by_category() {
        let mut cache = empty_cache(100);

        cache
            .add_truth(create_test_truth("test1", "rule1"))
            .unwrap();

        let mut strategy_truth = create_test_truth("test2", "rule2");
        strategy_truth.category = TruthCategory::TaskStrategy;
        cache.add_truth(strategy_truth).unwrap();

        assert_eq!(
            cache.truths_by_category(TruthCategory::CommandUsage).len(),
            1
        );
        assert_eq!(
            cache.truths_by_category(TruthCategory::TaskStrategy).len(),
            1
        );
    }

    /// Submissions beyond the queue limit are rejected with `Ok(false)`.
    #[test]
    fn test_queue_submission() {
        let mut cache = empty_cache(2);

        let first = create_test_truth("test1", "rule1");
        let second = create_test_truth("test2", "rule2");
        let overflow = create_test_truth("test3", "rule3");

        assert!(cache.queue_submission(first).unwrap());
        assert!(cache.queue_submission(second).unwrap());
        // Queue full
        assert!(!cache.queue_submission(overflow).unwrap());

        assert_eq!(cache.pending_submissions().len(), 2);
    }

    /// A newer server version replaces the local truth; unknown ids are added.
    #[test]
    fn test_merge_from_server() {
        let mut cache = empty_cache(100);

        // Local truth at version 1.
        let mut local = create_test_truth("local", "local rule");
        local.version = 1;
        let local_id = local.id.clone();
        cache.add_truth(local).unwrap();

        // Server sends one brand-new truth and a newer version of the local one.
        let brand_new = create_test_truth("new", "new rule");
        let mut newer = create_test_truth("local", "updated rule");
        newer.id = local_id.clone();
        newer.version = 2;

        let outcome = cache.merge_from_server(vec![brand_new, newer]).unwrap();

        assert_eq!(outcome.added, 1);
        assert_eq!(outcome.updated, 1);
        assert_eq!(outcome.conflicts, 0);

        // Verify update applied
        let merged = cache.get_truth(&local_id).unwrap();
        assert_eq!(merged.rule, "updated rule");
    }

    /// Stats count truths per category and report a positive average confidence.
    #[test]
    fn test_stats() {
        let mut cache = empty_cache(100);
        for (context, rule) in [("test1", "rule1"), ("test2", "rule2")] {
            cache.add_truth(create_test_truth(context, rule)).unwrap();
        }

        let stats = cache.stats();

        assert_eq!(stats.total_truths, 2);
        assert_eq!(
            stats
                .by_category
                .get(&TruthCategory::CommandUsage)
                .copied(),
            Some(2)
        );
        assert!(stats.avg_confidence > 0.0);
    }
}