1use super::truth::{BehavioralTruth, PendingTruthSubmission, TruthCategory, TruthFeedback};
7use anyhow::Result;
8use rusqlite::{Connection, OptionalExtension, params};
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::{Arc, Mutex};
12
13pub struct BehavioralKnowledgeCache {
15 conn: Arc<Mutex<Connection>>,
17
18 truths: HashMap<String, BehavioralTruth>,
20
21 pub last_sync: i64,
23
24 pending_submissions: Vec<PendingTruthSubmission>,
26
27 pending_feedback: Vec<TruthFeedback>,
29
30 max_queue_size: usize,
32}
33
34impl BehavioralKnowledgeCache {
35 pub fn new<P: AsRef<Path>>(db_path: P, max_queue_size: usize) -> Result<Self> {
37 let conn = Connection::open(db_path)?;
38 Self::init_schema(&conn)?;
39
40 let mut cache = Self {
41 conn: Arc::new(Mutex::new(conn)),
42 truths: HashMap::new(),
43 last_sync: 0,
44 pending_submissions: Vec::new(),
45 pending_feedback: Vec::new(),
46 max_queue_size,
47 };
48
49 cache.load_from_db()?;
51
52 Ok(cache)
53 }
54
55 pub fn in_memory(max_queue_size: usize) -> Result<Self> {
57 let conn = Connection::open_in_memory()?;
58 Self::init_schema(&conn)?;
59
60 Ok(Self {
61 conn: Arc::new(Mutex::new(conn)),
62 truths: HashMap::new(),
63 last_sync: 0,
64 pending_submissions: Vec::new(),
65 pending_feedback: Vec::new(),
66 max_queue_size,
67 })
68 }
69
70 fn init_schema(conn: &Connection) -> Result<()> {
72 conn.execute_batch(
73 r#"
74 CREATE TABLE IF NOT EXISTS truths (
75 id TEXT PRIMARY KEY,
76 category TEXT NOT NULL,
77 context_pattern TEXT NOT NULL,
78 rule TEXT NOT NULL,
79 rationale TEXT NOT NULL,
80 confidence REAL NOT NULL,
81 reinforcements INTEGER NOT NULL DEFAULT 0,
82 contradictions INTEGER NOT NULL DEFAULT 0,
83 last_used INTEGER NOT NULL,
84 created_at INTEGER NOT NULL,
85 created_by TEXT,
86 source TEXT NOT NULL,
87 version INTEGER NOT NULL DEFAULT 1,
88 deleted INTEGER NOT NULL DEFAULT 0
89 );
90
91 CREATE INDEX IF NOT EXISTS idx_truths_context ON truths(context_pattern);
92 CREATE INDEX IF NOT EXISTS idx_truths_category ON truths(category);
93 CREATE INDEX IF NOT EXISTS idx_truths_confidence ON truths(confidence);
94
95 CREATE TABLE IF NOT EXISTS pending_submissions (
96 id INTEGER PRIMARY KEY AUTOINCREMENT,
97 truth_json TEXT NOT NULL,
98 queued_at INTEGER NOT NULL,
99 attempts INTEGER NOT NULL DEFAULT 0,
100 last_error TEXT
101 );
102
103 CREATE TABLE IF NOT EXISTS pending_feedback (
104 id INTEGER PRIMARY KEY AUTOINCREMENT,
105 truth_id TEXT NOT NULL,
106 is_reinforcement INTEGER NOT NULL,
107 context TEXT,
108 timestamp INTEGER NOT NULL
109 );
110
111 CREATE TABLE IF NOT EXISTS sync_state (
112 key TEXT PRIMARY KEY,
113 value TEXT NOT NULL
114 );
115 "#,
116 )?;
117
118 Ok(())
119 }
120
121 fn load_from_db(&mut self) -> Result<()> {
123 let conn = self
124 .conn
125 .lock()
126 .expect("knowledge cache connection lock poisoned");
127
128 self.last_sync = conn
130 .query_row(
131 "SELECT value FROM sync_state WHERE key = 'last_sync'",
132 [],
133 |row| row.get::<_, String>(0),
134 )
135 .optional()?
136 .and_then(|s| s.parse().ok())
137 .unwrap_or(0);
138
139 let mut stmt = conn.prepare(
141 "SELECT id, category, context_pattern, rule, rationale, confidence,
142 reinforcements, contradictions, last_used, created_at,
143 created_by, source, version, deleted
144 FROM truths WHERE deleted = 0",
145 )?;
146
147 let truths = stmt.query_map([], |row| {
148 Ok(BehavioralTruth {
149 id: row.get(0)?,
150 category: serde_json::from_str(&format!("\"{}\"", row.get::<_, String>(1)?))
151 .unwrap_or(TruthCategory::CommandUsage),
152 context_pattern: row.get(2)?,
153 rule: row.get(3)?,
154 rationale: row.get(4)?,
155 confidence: row.get(5)?,
156 reinforcements: row.get(6)?,
157 contradictions: row.get(7)?,
158 last_used: row.get(8)?,
159 created_at: row.get(9)?,
160 created_by: row.get(10)?,
161 source: serde_json::from_str(&format!("\"{}\"", row.get::<_, String>(11)?))
162 .unwrap_or(super::truth::TruthSource::ExplicitCommand),
163 version: row.get::<_, i64>(12)? as u64,
164 deleted: row.get::<_, i32>(13)? != 0,
165 })
166 })?;
167
168 for truth in truths {
169 let truth = truth?;
170 self.truths.insert(truth.id.clone(), truth);
171 }
172
173 let mut stmt = conn.prepare(
175 "SELECT truth_json, queued_at, attempts, last_error FROM pending_submissions",
176 )?;
177
178 let submissions = stmt.query_map([], |row| {
179 let json: String = row.get(0)?;
180 let truth: BehavioralTruth = serde_json::from_str(&json).map_err(|e| {
181 rusqlite::Error::FromSqlConversionFailure(
182 0,
183 rusqlite::types::Type::Text,
184 Box::new(e),
185 )
186 })?;
187 Ok(PendingTruthSubmission {
188 truth,
189 queued_at: row.get(1)?,
190 attempts: row.get(2)?,
191 last_error: row.get(3)?,
192 })
193 })?;
194
195 for submission in submissions {
196 self.pending_submissions.push(submission?);
197 }
198
199 let mut stmt = conn.prepare(
201 "SELECT truth_id, is_reinforcement, context, timestamp FROM pending_feedback",
202 )?;
203
204 let feedback = stmt.query_map([], |row| {
205 Ok(TruthFeedback {
206 truth_id: row.get(0)?,
207 is_reinforcement: row.get::<_, i32>(1)? != 0,
208 context: row.get(2)?,
209 timestamp: row.get(3)?,
210 })
211 })?;
212
213 for fb in feedback {
214 self.pending_feedback.push(fb?);
215 }
216
217 Ok(())
218 }
219
220 fn save_truth_to_db(&self, truth: &BehavioralTruth) -> Result<()> {
222 let conn = self
223 .conn
224 .lock()
225 .expect("knowledge cache connection lock poisoned");
226 let category = serde_json::to_string(&truth.category)?
227 .trim_matches('"')
228 .to_string();
229 let source = serde_json::to_string(&truth.source)?
230 .trim_matches('"')
231 .to_string();
232
233 conn.execute(
234 r#"INSERT OR REPLACE INTO truths
235 (id, category, context_pattern, rule, rationale, confidence,
236 reinforcements, contradictions, last_used, created_at,
237 created_by, source, version, deleted)
238 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)"#,
239 params![
240 truth.id,
241 category,
242 truth.context_pattern,
243 truth.rule,
244 truth.rationale,
245 truth.confidence,
246 truth.reinforcements,
247 truth.contradictions,
248 truth.last_used,
249 truth.created_at,
250 truth.created_by,
251 source,
252 truth.version as i64,
253 truth.deleted as i32,
254 ],
255 )?;
256
257 Ok(())
258 }
259
260 pub fn set_last_sync(&mut self, timestamp: i64) -> Result<()> {
262 self.last_sync = timestamp;
263 let conn = self
264 .conn
265 .lock()
266 .expect("knowledge cache connection lock poisoned");
267 conn.execute(
268 "INSERT OR REPLACE INTO sync_state (key, value) VALUES ('last_sync', ?1)",
269 params![timestamp.to_string()],
270 )?;
271 Ok(())
272 }
273
274 pub fn add_truth(&mut self, truth: BehavioralTruth) -> Result<()> {
276 self.save_truth_to_db(&truth)?;
277 self.truths.insert(truth.id.clone(), truth);
278 Ok(())
279 }
280
281 pub fn update_truth(&mut self, truth: BehavioralTruth) -> Result<()> {
283 self.save_truth_to_db(&truth)?;
284 self.truths.insert(truth.id.clone(), truth);
285 Ok(())
286 }
287
288 pub fn get_truth(&self, id: &str) -> Option<&BehavioralTruth> {
290 self.truths.get(id)
291 }
292
293 pub fn get_truth_mut(&mut self, id: &str) -> Option<&mut BehavioralTruth> {
295 self.truths.get_mut(id)
296 }
297
298 pub fn remove_truth(&mut self, id: &str) -> Result<bool> {
300 if let Some(truth) = self.truths.get_mut(id) {
301 truth.delete();
302 } else {
303 return Ok(false);
304 }
305
306 if let Some(truth) = self.truths.get(id) {
308 self.save_truth_to_db(truth)?;
309 }
310 Ok(true)
311 }
312
313 pub fn all_truths(&self) -> impl Iterator<Item = &BehavioralTruth> {
315 self.truths.values().filter(|t| !t.deleted)
316 }
317
318 pub fn truths_by_category(&self, category: TruthCategory) -> Vec<&BehavioralTruth> {
320 self.truths
321 .values()
322 .filter(|t| !t.deleted && t.category == category)
323 .collect()
324 }
325
326 pub fn get_matching_truths(&self, context: &str) -> Vec<&BehavioralTruth> {
328 let context_lower = context.to_lowercase();
329 self.truths
330 .values()
331 .filter(|t| {
332 !t.deleted
333 && t.context_pattern
334 .to_lowercase()
335 .split_whitespace()
336 .any(|word| context_lower.contains(word))
337 })
338 .collect()
339 }
340
341 pub fn get_matching_truths_with_scores(
344 &self,
345 context: &str,
346 min_confidence: f32,
347 limit: usize,
348 ) -> Result<Vec<(&BehavioralTruth, f32)>> {
349 let context_lower = context.to_lowercase();
350 let context_words: Vec<&str> = context_lower.split_whitespace().collect();
351
352 let mut matches: Vec<(&BehavioralTruth, f32)> = self
353 .truths
354 .values()
355 .filter(|t| !t.deleted && t.confidence >= min_confidence)
356 .filter_map(|truth| {
357 let pattern_lower = truth.context_pattern.to_lowercase();
359 let pattern_words: Vec<&str> = pattern_lower.split_whitespace().collect();
360
361 let mut score = 0.0f32;
362 for pattern_word in &pattern_words {
363 for context_word in &context_words {
364 if context_word.contains(pattern_word)
365 || pattern_word.contains(context_word)
366 {
367 score += 1.0;
368 }
369 }
370 }
371
372 if score > 0.0 {
373 let normalized_score = (score / pattern_words.len() as f32) * truth.confidence;
375 Some((truth, normalized_score))
376 } else {
377 None
378 }
379 })
380 .collect();
381
382 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
384
385 matches.truncate(limit);
387
388 Ok(matches)
389 }
390
391 pub fn get_reliable_truths(
393 &self,
394 min_confidence: f32,
395 decay_days: u32,
396 ) -> Vec<&BehavioralTruth> {
397 self.truths
398 .values()
399 .filter(|t| !t.deleted && t.is_reliable(min_confidence, decay_days))
400 .collect()
401 }
402
403 pub fn queue_submission(&mut self, truth: BehavioralTruth) -> Result<bool> {
405 if self.pending_submissions.len() >= self.max_queue_size {
406 return Ok(false);
407 }
408
409 let submission = PendingTruthSubmission::new(truth);
410 let json = serde_json::to_string(&submission.truth)?;
411
412 let conn = self
413 .conn
414 .lock()
415 .expect("knowledge cache connection lock poisoned");
416 conn.execute(
417 "INSERT INTO pending_submissions (truth_json, queued_at, attempts) VALUES (?1, ?2, ?3)",
418 params![json, submission.queued_at, submission.attempts],
419 )?;
420
421 self.pending_submissions.push(submission);
422 Ok(true)
423 }
424
425 pub fn pending_submissions(&self) -> &[PendingTruthSubmission] {
427 &self.pending_submissions
428 }
429
430 pub fn clear_pending_submissions(&mut self) -> Result<()> {
432 self.pending_submissions.clear();
433 let conn = self
434 .conn
435 .lock()
436 .expect("knowledge cache connection lock poisoned");
437 conn.execute("DELETE FROM pending_submissions", [])?;
438 Ok(())
439 }
440
441 pub fn queue_feedback(&mut self, feedback: TruthFeedback) -> Result<bool> {
443 if self.pending_feedback.len() >= self.max_queue_size {
444 return Ok(false);
445 }
446
447 let conn = self
448 .conn
449 .lock()
450 .expect("knowledge cache connection lock poisoned");
451 conn.execute(
452 "INSERT INTO pending_feedback (truth_id, is_reinforcement, context, timestamp)
453 VALUES (?1, ?2, ?3, ?4)",
454 params![
455 feedback.truth_id,
456 feedback.is_reinforcement as i32,
457 feedback.context,
458 feedback.timestamp,
459 ],
460 )?;
461
462 self.pending_feedback.push(feedback);
463 Ok(true)
464 }
465
466 pub fn pending_feedback(&self) -> &[TruthFeedback] {
468 &self.pending_feedback
469 }
470
471 pub fn clear_pending_feedback(&mut self) -> Result<()> {
473 self.pending_feedback.clear();
474 let conn = self
475 .conn
476 .lock()
477 .expect("knowledge cache connection lock poisoned");
478 conn.execute("DELETE FROM pending_feedback", [])?;
479 Ok(())
480 }
481
482 pub fn merge_from_server(
484 &mut self,
485 server_truths: Vec<BehavioralTruth>,
486 ) -> Result<MergeResult> {
487 let mut added = 0;
488 let mut updated = 0;
489 let mut conflicts = 0;
490
491 for server_truth in server_truths {
492 if let Some(local_truth) = self.truths.get(&server_truth.id) {
493 if server_truth.version > local_truth.version {
495 self.save_truth_to_db(&server_truth)?;
497 self.truths.insert(server_truth.id.clone(), server_truth);
498 updated += 1;
499 } else if server_truth.version < local_truth.version {
500 conflicts += 1;
502 }
503 } else {
505 self.save_truth_to_db(&server_truth)?;
507 self.truths.insert(server_truth.id.clone(), server_truth);
508 added += 1;
509 }
510 }
511
512 Ok(MergeResult {
513 added,
514 updated,
515 conflicts,
516 })
517 }
518
519 pub fn apply_decay(&mut self, decay_start_days: u32) -> Result<u32> {
521 let mut decayed = 0;
522
523 for truth in self.truths.values_mut() {
524 let old_confidence = truth.confidence;
525 truth.apply_decay(decay_start_days);
526 if (truth.confidence - old_confidence).abs() > 0.001 {
527 decayed += 1;
528 }
529 }
530
531 if decayed > 0 {
533 for truth in self.truths.values() {
534 self.save_truth_to_db(truth)?;
535 }
536 }
537
538 Ok(decayed)
539 }
540
541 pub fn stats(&self) -> CacheStats {
543 let mut by_category: HashMap<TruthCategory, u32> = HashMap::new();
544 let mut total_confidence = 0.0f32;
545 let mut count = 0u32;
546
547 for truth in self.truths.values().filter(|t| !t.deleted) {
548 *by_category.entry(truth.category).or_insert(0) += 1;
549 total_confidence += truth.confidence;
550 count += 1;
551 }
552
553 CacheStats {
554 total_truths: count,
555 by_category,
556 avg_confidence: if count > 0 {
557 total_confidence / count as f32
558 } else {
559 0.0
560 },
561 pending_submissions: self.pending_submissions.len(),
562 pending_feedback: self.pending_feedback.len(),
563 last_sync: self.last_sync,
564 }
565 }
566}
567
/// Tally of what a server merge did with a batch of truths.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (for example
/// in assertions) and `Default` so an all-zero result is easy to build.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MergeResult {
    /// Server truths that had no local counterpart and were inserted.
    pub added: u32,
    /// Local truths replaced by a server copy with a higher version.
    pub updated: u32,
    /// Server truths ignored because the local copy had a higher version.
    pub conflicts: u32,
}
578
579#[derive(Debug, Clone)]
581pub struct CacheStats {
582 pub total_truths: u32,
584 pub by_category: HashMap<TruthCategory, u32>,
586 pub avg_confidence: f32,
588 pub pending_submissions: usize,
590 pub pending_feedback: usize,
592 pub last_sync: i64,
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::bks_pks::truth::TruthSource;

    /// Queue limit for tests that never fill a queue.
    const ROOMY: usize = 100;

    /// Builds a `CommandUsage` truth that comes from an explicit command.
    fn create_test_truth(context: &str, rule: &str) -> BehavioralTruth {
        let category = TruthCategory::CommandUsage;
        let source = TruthSource::ExplicitCommand;
        BehavioralTruth::new(
            category,
            String::from(context),
            String::from(rule),
            String::from("Test rationale"),
            source,
            None,
        )
    }

    /// Fresh in-memory cache with the given queue limit.
    fn empty_cache(max_queue_size: usize) -> BehavioralKnowledgeCache {
        BehavioralKnowledgeCache::in_memory(max_queue_size).unwrap()
    }

    #[test]
    fn test_cache_creation() {
        let cache = empty_cache(ROOMY);

        assert_eq!(cache.last_sync, 0);
        assert_eq!(cache.all_truths().count(), 0);
    }

    #[test]
    fn test_add_and_get_truth() {
        let mut cache = empty_cache(ROOMY);

        let truth = create_test_truth("pm2 logs", "Use --nostream");
        let truth_id = truth.id.clone();
        cache.add_truth(truth).unwrap();

        let stored = cache.get_truth(&truth_id).unwrap();
        assert_eq!(stored.rule, "Use --nostream");
    }

    #[test]
    fn test_matching_truths() {
        let mut cache = empty_cache(ROOMY);

        let pm2 = create_test_truth("pm2 logs", "Use --nostream");
        cache.add_truth(pm2).unwrap();
        let cargo = create_test_truth("cargo build", "Use cargo-watch");
        cache.add_truth(cargo).unwrap();

        let found = cache.get_matching_truths("pm2 logs myapp");
        assert_eq!(found.len(), 1);
        assert!(found[0].rule.contains("--nostream"));
    }

    #[test]
    fn test_truths_by_category() {
        let mut cache = empty_cache(ROOMY);

        let command_truth = create_test_truth("test1", "rule1");
        cache.add_truth(command_truth).unwrap();

        let mut strategy_truth = create_test_truth("test2", "rule2");
        strategy_truth.category = TruthCategory::TaskStrategy;
        cache.add_truth(strategy_truth).unwrap();

        let commands = cache.truths_by_category(TruthCategory::CommandUsage);
        assert_eq!(commands.len(), 1);

        let strategies = cache.truths_by_category(TruthCategory::TaskStrategy);
        assert_eq!(strategies.len(), 1);
    }

    #[test]
    fn test_queue_submission() {
        // A limit of two means the third submission must be turned away.
        let mut cache = empty_cache(2);

        let first = create_test_truth("test1", "rule1");
        let second = create_test_truth("test2", "rule2");
        let third = create_test_truth("test3", "rule3");

        assert!(cache.queue_submission(first).unwrap());
        assert!(cache.queue_submission(second).unwrap());
        assert!(!cache.queue_submission(third).unwrap());
        assert_eq!(cache.pending_submissions().len(), 2);
    }

    #[test]
    fn test_merge_from_server() {
        let mut cache = empty_cache(ROOMY);

        // Local copy at version 1.
        let mut local = create_test_truth("local", "local rule");
        local.version = 1;
        let shared_id = local.id.clone();
        cache.add_truth(local).unwrap();

        // The server sends one unknown truth and a newer copy of the local one.
        let brand_new = create_test_truth("new", "new rule");
        let mut newer = create_test_truth("local", "updated rule");
        newer.id = shared_id.clone();
        newer.version = 2;

        let outcome = cache.merge_from_server(vec![brand_new, newer]).unwrap();

        assert_eq!(outcome.added, 1);
        assert_eq!(outcome.updated, 1);
        assert_eq!(outcome.conflicts, 0);

        let merged = cache.get_truth(&shared_id).unwrap();
        assert_eq!(merged.rule, "updated rule");
    }

    #[test]
    fn test_stats() {
        let mut cache = empty_cache(ROOMY);

        let first = create_test_truth("test1", "rule1");
        cache.add_truth(first).unwrap();
        let second = create_test_truth("test2", "rule2");
        cache.add_truth(second).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.total_truths, 2);

        let command_count = *stats.by_category.get(&TruthCategory::CommandUsage).unwrap();
        assert_eq!(command_count, 2);
        assert!(stats.avg_confidence > 0.0);
    }
}