1use std::collections::{HashMap, HashSet, VecDeque};
21
22use rusqlite::{params, Connection};
23use serde::{Deserialize, Serialize};
24
25use crate::error::Result;
26use crate::intelligence::fact_extraction::Fact;
27
/// A subject / predicate / object filter in which every component is optional.
///
/// A component left as `None` places no constraint on that part of a triple.
/// Build one fluently, e.g. `TripletPattern::any().with_subject("Alice")`.
#[derive(Debug, Clone, Default)]
pub struct TripletPattern {
    /// Constraint on the subject, or `None` to accept any subject.
    pub subject: Option<String>,
    /// Constraint on the predicate, or `None` to accept any predicate.
    pub predicate: Option<String>,
    /// Constraint on the object, or `None` to accept any object.
    pub object: Option<String>,
}

impl TripletPattern {
    /// Creates a pattern with all three components unset.
    pub fn any() -> Self {
        Self {
            subject: None,
            predicate: None,
            object: None,
        }
    }

    /// Returns the pattern with its subject constraint replaced by `subject`.
    pub fn with_subject(self, subject: impl Into<String>) -> Self {
        Self {
            subject: Some(subject.into()),
            ..self
        }
    }

    /// Returns the pattern with its predicate constraint replaced by `predicate`.
    pub fn with_predicate(self, predicate: impl Into<String>) -> Self {
        Self {
            predicate: Some(predicate.into()),
            ..self
        }
    }

    /// Returns the pattern with its object constraint replaced by `object`.
    pub fn with_object(self, object: impl Into<String>) -> Self {
        Self {
            object: Some(object.into()),
            ..self
        }
    }
}
69
/// One hop of an inference path: a single subject / predicate / object triple
/// together with the id of the fact it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceStep {
    /// Subject of the triple.
    pub subject: String,
    /// Predicate of the triple.
    pub predicate: String,
    /// Object of the triple.
    pub object: String,
    /// `id` of the `facts` row this step was read from.
    pub source_fact_id: i64,
}
82
/// An ordered chain of [`InferenceStep`]s with an overall confidence score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferencePath {
    /// The hops of the chain, in traversal order.
    pub steps: Vec<InferenceStep>,
    /// Combined confidence of the chain; `TripletMatcher::infer_transitive`
    /// fills this with the product of the confidences of the underlying facts.
    pub confidence: f64,
}
93
/// Aggregate counts over the `facts` table, as produced by
/// `TripletMatcher::knowledge_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeStats {
    /// Total number of rows in `facts`.
    pub total_facts: i64,
    /// Number of distinct `subject` values.
    pub unique_subjects: i64,
    /// Number of distinct `predicate` values.
    pub unique_predicates: i64,
    /// Number of distinct `object` values.
    pub unique_objects: i64,
    /// Up to ten `(predicate, fact count)` pairs, highest count first.
    pub top_predicates: Vec<(String, i64)>,
    /// Up to ten `(subject, fact count)` pairs, highest count first.
    pub top_subjects: Vec<(String, i64)>,
}
110
/// Stateless namespace for queries over the `facts` table.
///
/// The type has no fields; its operations are associated functions that take
/// the database [`Connection`] as an explicit argument.
pub struct TripletMatcher;
117
118impl TripletMatcher {
119 pub fn match_pattern(conn: &Connection, pattern: &TripletPattern) -> Result<Vec<Fact>> {
128 let mut conditions: Vec<String> = Vec::new();
130 let mut bind_values: Vec<String> = Vec::new();
131
132 if let Some(ref s) = pattern.subject {
133 conditions.push(format!(
134 "lower(subject) LIKE lower(?{})",
135 conditions.len() + 1
136 ));
137 bind_values.push(s.clone());
138 }
139 if let Some(ref p) = pattern.predicate {
140 conditions.push(format!(
141 "lower(predicate) LIKE lower(?{})",
142 conditions.len() + 1
143 ));
144 bind_values.push(p.clone());
145 }
146 if let Some(ref o) = pattern.object {
147 conditions.push(format!(
148 "lower(object) LIKE lower(?{})",
149 conditions.len() + 1
150 ));
151 bind_values.push(o.clone());
152 }
153
154 let where_clause = if conditions.is_empty() {
155 String::new()
156 } else {
157 format!("WHERE {}", conditions.join(" AND "))
158 };
159
160 let sql = format!(
161 "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
162 FROM facts
163 {where_clause}
164 ORDER BY id ASC"
165 );
166
167 let mut stmt = conn.prepare(&sql)?;
168
169 let facts = match bind_values.len() {
171 0 => stmt
172 .query_map([], map_row)?
173 .collect::<std::result::Result<Vec<Fact>, _>>()?,
174 1 => stmt
175 .query_map(params![bind_values[0]], map_row)?
176 .collect::<std::result::Result<Vec<Fact>, _>>()?,
177 2 => stmt
178 .query_map(params![bind_values[0], bind_values[1]], map_row)?
179 .collect::<std::result::Result<Vec<Fact>, _>>()?,
180 3 => stmt
181 .query_map(
182 params![bind_values[0], bind_values[1], bind_values[2]],
183 map_row,
184 )?
185 .collect::<std::result::Result<Vec<Fact>, _>>()?,
186 _ => unreachable!("pattern has at most 3 fields"),
187 };
188
189 Ok(facts)
190 }
191
192 pub fn infer_transitive(
202 conn: &Connection,
203 subject: &str,
204 predicate: &str,
205 max_hops: usize,
206 ) -> Result<Vec<InferencePath>> {
207 if max_hops == 0 {
208 return Ok(Vec::new());
209 }
210
211 let mut adj: HashMap<String, Vec<(i64, String, String, f64)>> = HashMap::new();
214
215 {
216 let mut stmt = conn.prepare(
217 "SELECT id, subject, object, confidence
218 FROM facts
219 WHERE lower(predicate) = lower(?1)
220 ORDER BY id ASC",
221 )?;
222
223 let rows = stmt.query_map(params![predicate], |row| {
224 Ok((
225 row.get::<_, i64>(0)?,
226 row.get::<_, String>(1)?,
227 row.get::<_, String>(2)?,
228 row.get::<_, f64>(3)?,
229 ))
230 })?;
231
232 for row in rows {
233 let (fact_id, subj, obj, conf) = row?;
234 adj.entry(subj.to_lowercase())
235 .or_default()
236 .push((fact_id, subj, obj, conf));
237 }
238 }
239
240 let mut results: Vec<InferencePath> = Vec::new();
242 let mut queue: VecDeque<(String, Vec<InferenceStep>, f64)> = VecDeque::new();
243 queue.push_back((subject.to_lowercase(), Vec::new(), 1.0));
244
245 while let Some((current, path, path_conf)) = queue.pop_front() {
246 if path.len() >= max_hops {
247 if !path.is_empty() {
249 results.push(InferencePath {
250 steps: path,
251 confidence: path_conf,
252 });
253 }
254 continue;
255 }
256
257 let neighbors = match adj.get(¤t) {
258 Some(n) => n.clone(),
259 None => {
260 if !path.is_empty() {
262 results.push(InferencePath {
263 steps: path,
264 confidence: path_conf,
265 });
266 }
267 continue;
268 }
269 };
270
271 let visited_in_path: HashSet<String> =
273 path.iter().map(|s| s.subject.to_lowercase()).collect();
274
275 let mut branched = false;
276 for (fact_id, subj_orig, obj, conf) in neighbors {
277 let obj_lower = obj.to_lowercase();
278
279 if visited_in_path.contains(&obj_lower) || obj_lower == current {
281 continue;
282 }
283
284 let step = InferenceStep {
285 subject: subj_orig,
286 predicate: predicate.to_string(),
287 object: obj.clone(),
288 source_fact_id: fact_id,
289 };
290
291 let mut next_path = path.clone();
292 next_path.push(step);
293 let next_conf = path_conf * conf;
294 queue.push_back((obj_lower, next_path, next_conf));
295 branched = true;
296 }
297
298 if !branched && !path.is_empty() {
300 results.push(InferencePath {
301 steps: path,
302 confidence: path_conf,
303 });
304 }
305 }
306
307 Ok(results)
308 }
309
310 pub fn query_knowledge(conn: &Connection, natural_language: &str) -> Result<Vec<Fact>> {
318 let entities: Vec<String> = natural_language
320 .split_whitespace()
321 .filter(|w| w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false))
322 .map(|w| {
323 w.trim_end_matches(|c: char| !c.is_alphanumeric())
325 .to_string()
326 })
327 .filter(|w| w.len() >= 2)
328 .collect::<std::collections::HashSet<_>>() .into_iter()
330 .collect();
331
332 if entities.is_empty() {
333 return Ok(Vec::new());
334 }
335
336 let mut seen_ids: HashSet<i64> = HashSet::new();
340 let mut all_facts: Vec<Fact> = Vec::new();
341
342 let mut stmt = conn.prepare(
343 "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
344 FROM facts
345 WHERE lower(subject) LIKE lower(?1) OR lower(object) LIKE lower(?1)
346 ORDER BY id ASC",
347 )?;
348
349 for entity in &entities {
350 let rows = stmt
351 .query_map(params![entity], map_row)?
352 .collect::<std::result::Result<Vec<Fact>, _>>()?;
353
354 for fact in rows {
355 if seen_ids.insert(fact.id) {
356 all_facts.push(fact);
357 }
358 }
359 }
360
361 all_facts.sort_by_key(|f| f.id);
363 Ok(all_facts)
364 }
365
366 pub fn knowledge_stats(conn: &Connection) -> Result<KnowledgeStats> {
370 let total_facts: i64 =
371 conn.query_row("SELECT COUNT(*) FROM facts", [], |row| row.get(0))?;
372
373 let unique_subjects: i64 =
374 conn.query_row("SELECT COUNT(DISTINCT subject) FROM facts", [], |row| {
375 row.get(0)
376 })?;
377
378 let unique_predicates: i64 =
379 conn.query_row("SELECT COUNT(DISTINCT predicate) FROM facts", [], |row| {
380 row.get(0)
381 })?;
382
383 let unique_objects: i64 =
384 conn.query_row("SELECT COUNT(DISTINCT object) FROM facts", [], |row| {
385 row.get(0)
386 })?;
387
388 let mut pred_stmt = conn.prepare(
390 "SELECT predicate, COUNT(*) as cnt
391 FROM facts
392 GROUP BY predicate
393 ORDER BY cnt DESC
394 LIMIT 10",
395 )?;
396 let top_predicates: Vec<(String, i64)> = pred_stmt
397 .query_map([], |row| {
398 Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
399 })?
400 .collect::<std::result::Result<Vec<_>, _>>()?;
401
402 let mut subj_stmt = conn.prepare(
404 "SELECT subject, COUNT(*) as cnt
405 FROM facts
406 GROUP BY subject
407 ORDER BY cnt DESC
408 LIMIT 10",
409 )?;
410 let top_subjects: Vec<(String, i64)> = subj_stmt
411 .query_map([], |row| {
412 Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
413 })?
414 .collect::<std::result::Result<Vec<_>, _>>()?;
415
416 Ok(KnowledgeStats {
417 total_facts,
418 unique_subjects,
419 unique_predicates,
420 unique_objects,
421 top_predicates,
422 top_subjects,
423 })
424 }
425}
426
427fn map_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Fact> {
432 Ok(Fact {
433 id: row.get(0)?,
434 subject: row.get(1)?,
435 predicate: row.get(2)?,
436 object: row.get(3)?,
437 confidence: row.get(4)?,
438 source_memory_id: row.get(5)?,
439 created_at: row.get(6)?,
440 })
441}
442
// Unit tests; each test runs against its own in-memory SQLite database.
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Schema of the `facts` table used by the tests.
    const CREATE_TABLE: &str = r#"
        CREATE TABLE IF NOT EXISTS facts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            subject TEXT NOT NULL,
            predicate TEXT NOT NULL,
            object TEXT NOT NULL,
            confidence REAL NOT NULL DEFAULT 0.8,
            source_memory_id INTEGER,
            created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
        );
        CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(subject);
    "#;

    /// Opens a fresh in-memory database containing an empty `facts` table.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().expect("in-memory db");
        conn.execute_batch(CREATE_TABLE).expect("create table");
        conn
    }

    /// Inserts one fact with a fixed `created_at` timestamp.
    fn insert(conn: &Connection, subject: &str, predicate: &str, object: &str, confidence: f64) {
        conn.execute(
            "INSERT INTO facts (subject, predicate, object, confidence, created_at)
             VALUES (?1, ?2, ?3, ?4, '2026-01-01T00:00:00Z')",
            params![subject, predicate, object, confidence],
        )
        .expect("insert fact");
    }

    /// Seeds five facts: two `works_at`, two `located_in` and one `lives_in`.
    fn seed_graph(conn: &Connection) {
        insert(conn, "Alice", "works_at", "Google", 0.9);
        insert(conn, "Bob", "works_at", "Google", 0.85);
        insert(conn, "Google", "located_in", "California", 0.95);
        insert(conn, "Carol", "lives_in", "London", 0.8);
        insert(conn, "Dave", "located_in", "Paris", 0.75);
    }

    // ---- match_pattern -------------------------------------------------

    #[test]
    fn test_match_by_subject() {
        let conn = setup();
        seed_graph(&conn);

        let pattern = TripletPattern::any().with_subject("Alice");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].subject, "Alice");
        assert_eq!(facts[0].predicate, "works_at");
        assert_eq!(facts[0].object, "Google");
    }

    #[test]
    fn test_match_by_predicate() {
        let conn = setup();
        seed_graph(&conn);

        let pattern = TripletPattern::any().with_predicate("works_at");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 2);
        let subjects: Vec<&str> = facts.iter().map(|f| f.subject.as_str()).collect();
        assert!(subjects.contains(&"Alice"));
        assert!(subjects.contains(&"Bob"));
    }

    #[test]
    fn test_match_by_object() {
        let conn = setup();
        seed_graph(&conn);

        let pattern = TripletPattern::any().with_object("Google");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 2);
    }

    #[test]
    fn test_wildcard_match_returns_all() {
        let conn = setup();
        seed_graph(&conn);

        let pattern = TripletPattern::any();
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 5);
    }

    #[test]
    fn test_match_case_insensitive() {
        let conn = setup();
        seed_graph(&conn);

        // Stored as "Alice"; the lower-case spelling must still match.
        let pattern = TripletPattern::any().with_subject("alice");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 1);
    }

    #[test]
    fn test_no_matches_returns_empty() {
        let conn = setup();
        seed_graph(&conn);

        let pattern = TripletPattern::any().with_subject("Nonexistent");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert!(facts.is_empty());
    }

    #[test]
    fn test_match_subject_and_predicate() {
        let conn = setup();
        seed_graph(&conn);

        let pattern = TripletPattern::any()
            .with_subject("Google")
            .with_predicate("located_in");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].object, "California");
    }

    // ---- infer_transitive ----------------------------------------------

    #[test]
    fn test_transitive_inference_two_hops() {
        let conn = setup();
        insert(&conn, "Alice", "works_at", "Google", 0.9);
        insert(&conn, "Google", "works_at", "Alphabet", 0.8);

        let paths = TripletMatcher::infer_transitive(&conn, "Alice", "works_at", 3).expect("infer");

        assert!(!paths.is_empty(), "expected at least one inference path");
        let longest = paths.iter().max_by_key(|p| p.steps.len()).unwrap();
        assert_eq!(longest.steps.len(), 2);
        assert_eq!(longest.steps[0].subject, "Alice");
        assert_eq!(longest.steps[0].object, "Google");
        assert_eq!(longest.steps[1].subject, "Google");
        assert_eq!(longest.steps[1].object, "Alphabet");
    }

    #[test]
    fn test_transitive_inference_no_matching_predicate() {
        let conn = setup();
        seed_graph(&conn);

        // Alice has no `lives_in` fact, so there is nothing to follow.
        let paths = TripletMatcher::infer_transitive(&conn, "Alice", "lives_in", 3).expect("infer");
        assert!(paths.is_empty());
    }

    #[test]
    fn test_transitive_inference_max_hops_zero() {
        let conn = setup();
        seed_graph(&conn);

        let paths = TripletMatcher::infer_transitive(&conn, "Alice", "works_at", 0).expect("infer");
        assert!(paths.is_empty());
    }

    #[test]
    fn test_transitive_confidence_product() {
        let conn = setup();
        insert(&conn, "Alice", "rel", "B", 0.9);
        insert(&conn, "B", "rel", "C", 0.8);

        let paths = TripletMatcher::infer_transitive(&conn, "Alice", "rel", 5).expect("infer");
        let two_hop = paths.iter().find(|p| p.steps.len() == 2);
        assert!(two_hop.is_some(), "expected 2-hop path");
        let conf = two_hop.unwrap().confidence;
        assert!(
            (conf - 0.9 * 0.8).abs() < 1e-9,
            "expected confidence ~0.72, got {conf}"
        );
    }

    // ---- query_knowledge -----------------------------------------------

    #[test]
    fn test_query_knowledge_by_entity() {
        let conn = setup();
        seed_graph(&conn);

        let facts = TripletMatcher::query_knowledge(&conn, "What does Alice do?").expect("query");
        assert!(
            facts.iter().any(|f| f.subject == "Alice"),
            "expected Alice fact, got: {:?}",
            facts
        );
    }

    #[test]
    fn test_query_knowledge_no_entities_returns_empty() {
        let conn = setup();
        seed_graph(&conn);

        // No capitalised word, hence no entity to look up.
        let facts =
            TripletMatcher::query_knowledge(&conn, "what does everyone do?").expect("query");
        assert!(facts.is_empty());
    }

    #[test]
    fn test_query_knowledge_multiple_entities() {
        let conn = setup();
        seed_graph(&conn);

        let facts =
            TripletMatcher::query_knowledge(&conn, "Tell me about Alice and Carol").expect("query");
        let subjects: Vec<&str> = facts.iter().map(|f| f.subject.as_str()).collect();
        assert!(subjects.contains(&"Alice"), "expected Alice fact");
        assert!(subjects.contains(&"Carol"), "expected Carol fact");
    }

    // ---- knowledge_stats -----------------------------------------------

    #[test]
    fn test_knowledge_stats_empty_table() {
        let conn = setup();
        let stats = TripletMatcher::knowledge_stats(&conn).expect("stats");
        assert_eq!(stats.total_facts, 0);
        assert_eq!(stats.unique_subjects, 0);
        assert_eq!(stats.unique_predicates, 0);
        assert_eq!(stats.unique_objects, 0);
        assert!(stats.top_predicates.is_empty());
        assert!(stats.top_subjects.is_empty());
    }

    #[test]
    fn test_knowledge_stats_with_data() {
        let conn = setup();
        seed_graph(&conn);

        let stats = TripletMatcher::knowledge_stats(&conn).expect("stats");
        assert_eq!(stats.total_facts, 5);
        assert_eq!(stats.unique_subjects, 5);
        assert_eq!(stats.unique_predicates, 3);
        assert_eq!(stats.unique_objects, 4);

        assert!(!stats.top_predicates.is_empty());
        // "works_at" and "located_in" both occur twice in the seed data, so
        // either may legitimately come first. Check the leading count and the
        // presence of "works_at" rather than relying on how the tie is ordered.
        assert_eq!(stats.top_predicates[0].1, 2);
        assert!(stats
            .top_predicates
            .contains(&("works_at".to_string(), 2)));
    }
}