use std::collections::{HashMap, HashSet, VecDeque};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::intelligence::fact_extraction::Fact;
/// A query pattern over `(subject, predicate, object)` triplets.
///
/// Every component is optional. A component that is `None` places no
/// constraint on that position when the pattern is passed to
/// `TripletMatcher::match_pattern`.
#[derive(Debug, Clone, Default)]
pub struct TripletPattern {
    /// Constraint on the subject position, if any.
    pub subject: Option<String>,
    /// Constraint on the predicate position, if any.
    pub predicate: Option<String>,
    /// Constraint on the object position, if any.
    pub object: Option<String>,
}

impl TripletPattern {
    /// Creates a pattern with no constraint on any component.
    pub fn any() -> Self {
        Self {
            subject: None,
            predicate: None,
            object: None,
        }
    }

    /// Returns this pattern with the subject constraint set to `subject`.
    pub fn with_subject(self, subject: impl Into<String>) -> Self {
        Self {
            subject: Some(subject.into()),
            ..self
        }
    }

    /// Returns this pattern with the predicate constraint set to `predicate`.
    pub fn with_predicate(self, predicate: impl Into<String>) -> Self {
        Self {
            predicate: Some(predicate.into()),
            ..self
        }
    }

    /// Returns this pattern with the object constraint set to `object`.
    pub fn with_object(self, object: impl Into<String>) -> Self {
        Self {
            object: Some(object.into()),
            ..self
        }
    }
}
/// A single hop of an inference path, corresponding to one row of the
/// `facts` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceStep {
/// Subject of the underlying fact, as stored in the database.
pub subject: String,
/// Predicate that was followed for this hop.
pub predicate: String,
/// Object of the underlying fact, as stored in the database.
pub object: String,
/// `id` of the fact row this step was derived from.
pub source_fact_id: i64,
}
/// A chain of facts produced by `TripletMatcher::infer_transitive`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferencePath {
/// The hops of the path, in traversal order starting from the queried subject.
pub steps: Vec<InferenceStep>,
/// Product of the confidences of all facts along the path.
pub confidence: f64,
}
/// Aggregate statistics over the `facts` table, as returned by
/// `TripletMatcher::knowledge_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeStats {
/// Total number of rows in `facts`.
pub total_facts: i64,
/// Number of distinct `subject` values.
pub unique_subjects: i64,
/// Number of distinct `predicate` values.
pub unique_predicates: i64,
/// Number of distinct `object` values.
pub unique_objects: i64,
/// Up to ten `(predicate, fact count)` pairs, highest count first.
pub top_predicates: Vec<(String, i64)>,
/// Up to ten `(subject, fact count)` pairs, highest count first.
pub top_subjects: Vec<(String, i64)>,
}
/// Stateless namespace for queries over the `facts` table; every operation is
/// an associated function that takes the database connection explicitly.
pub struct TripletMatcher;
impl TripletMatcher {
pub fn match_pattern(conn: &Connection, pattern: &TripletPattern) -> Result<Vec<Fact>> {
const SQL: &str = "
SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
FROM facts
WHERE (?1 IS NULL OR lower(subject) LIKE lower(?1))
AND (?2 IS NULL OR lower(predicate) LIKE lower(?2))
AND (?3 IS NULL OR lower(object) LIKE lower(?3))
ORDER BY id ASC
";
let subject: Option<&str> = pattern.subject.as_deref();
let predicate: Option<&str> = pattern.predicate.as_deref();
let object: Option<&str> = pattern.object.as_deref();
let mut stmt = conn.prepare(SQL)?;
let facts = stmt
.query_map(params![subject, predicate, object], map_row)?
.collect::<std::result::Result<Vec<Fact>, _>>()?;
Ok(facts)
}
/// Follows `predicate` edges transitively, starting from `subject`.
///
/// All facts whose predicate equals `predicate` (ignoring case) are loaded
/// into an adjacency map keyed by lower-cased subject, and the graph is then
/// explored breadth-first from `subject`. Node names are compared
/// case-insensitively, and an edge is never followed if its object already
/// appears in the path being extended, so cycles terminate.
///
/// Only maximal paths are returned: a path is emitted once it has reached
/// `max_hops` steps, once its end node has no outgoing edges, or once every
/// outgoing edge would revisit a node already on the path. The confidence of
/// a path is the product of the confidences of its facts. Each step records
/// `predicate` exactly as it was passed in.
///
/// Returns an empty vector when `max_hops` is `0` or when `subject` has no
/// outgoing `predicate` edge.
///
/// # Errors
///
/// Returns an error if the facts query cannot be prepared or executed, or if
/// a row cannot be read.
pub fn infer_transitive(
    conn: &Connection,
    subject: &str,
    predicate: &str,
    max_hops: usize,
) -> Result<Vec<InferencePath>> {
    if max_hops == 0 {
        return Ok(Vec::new());
    }

    // lower-cased subject -> outgoing edges as (fact id, subject, object, confidence).
    let mut adj: HashMap<String, Vec<(i64, String, String, f64)>> = HashMap::new();
    {
        let mut stmt = conn.prepare(
            "SELECT id, subject, object, confidence
FROM facts
WHERE lower(predicate) = lower(?1)
ORDER BY id ASC",
        )?;
        let rows = stmt.query_map(params![predicate], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                row.get::<_, f64>(3)?,
            ))
        })?;
        for row in rows {
            let (fact_id, subj, obj, conf) = row?;
            adj.entry(subj.to_lowercase())
                .or_default()
                .push((fact_id, subj, obj, conf));
        }
    }

    let mut results: Vec<InferencePath> = Vec::new();
    // Each queue entry is (lower-cased end node, steps so far, confidence so far).
    let mut queue: VecDeque<(String, Vec<InferenceStep>, f64)> = VecDeque::new();
    queue.push_back((subject.to_lowercase(), Vec::new(), 1.0));

    while let Some((current, path, path_conf)) = queue.pop_front() {
        // A path that has used up its hop budget is not expanded any further.
        let neighbors = if path.len() < max_hops {
            adj.get(&current)
        } else {
            None
        };

        let mut branched = false;
        if let Some(neighbors) = neighbors {
            // Every node on the path except `current` is the subject of some step.
            let visited_in_path: HashSet<String> =
                path.iter().map(|s| s.subject.to_lowercase()).collect();

            for (fact_id, subj_orig, obj, conf) in neighbors {
                let obj_lower = obj.to_lowercase();
                if obj_lower == current || visited_in_path.contains(&obj_lower) {
                    continue;
                }
                let mut next_path = Vec::with_capacity(path.len() + 1);
                next_path.extend(path.iter().cloned());
                next_path.push(InferenceStep {
                    subject: subj_orig.clone(),
                    predicate: predicate.to_string(),
                    object: obj.clone(),
                    source_fact_id: *fact_id,
                });
                queue.push_back((obj_lower, next_path, path_conf * *conf));
                branched = true;
            }
        }

        // Nothing to extend with: the path is maximal, so report it. The
        // empty path (start node without usable edges) is never reported.
        if !branched && !path.is_empty() {
            results.push(InferencePath {
                steps: path,
                confidence: path_conf,
            });
        }
    }

    Ok(results)
}
/// Looks up facts that mention any entity named in a free-text question.
///
/// An entity is a whitespace-delimited word that, once leading and trailing
/// non-alphanumeric characters are stripped, starts with an uppercase letter
/// and is at least two characters long. Stripping happens before the
/// uppercase check so that quoted or parenthesised names such as `"Alice"`
/// or `(Alice)` are still recognised.
///
/// A fact is returned when its subject or its object matches an entity under
/// SQL `LIKE` after lower-casing both sides; the comparison therefore ignores
/// case, and `%` or `_` inside an entity act as SQL wildcards. Each fact is
/// returned at most once and the result is ordered by ascending `id`.
///
/// Returns an empty vector when the text contains no entity.
///
/// # Errors
///
/// Returns an error if the statement cannot be prepared or executed, or if a
/// row cannot be converted into a `Fact`.
pub fn query_knowledge(conn: &Connection, natural_language: &str) -> Result<Vec<Fact>> {
    let mut entities: Vec<&str> = natural_language
        .split_whitespace()
        .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|w| w.chars().next().map_or(false, char::is_uppercase))
        // Count characters, not bytes, so a lone multi-byte letter is rejected.
        .filter(|w| w.chars().nth(1).is_some())
        .collect();
    // Query each distinct entity once, in a deterministic order.
    entities.sort_unstable();
    entities.dedup();

    if entities.is_empty() {
        return Ok(Vec::new());
    }

    let mut seen_ids: HashSet<i64> = HashSet::new();
    let mut all_facts: Vec<Fact> = Vec::new();
    let mut stmt = conn.prepare(
        "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
FROM facts
WHERE lower(subject) LIKE lower(?1) OR lower(object) LIKE lower(?1)
ORDER BY id ASC",
    )?;

    for entity in entities {
        for row in stmt.query_map(params![entity], map_row)? {
            let fact = row?;
            // A fact may match several entities; keep only its first occurrence.
            if seen_ids.insert(fact.id) {
                all_facts.push(fact);
            }
        }
    }

    // Ids are unique after de-duplication, so an unstable sort is sufficient.
    all_facts.sort_unstable_by_key(|f| f.id);
    Ok(all_facts)
}
/// Computes aggregate statistics over the `facts` table.
///
/// Reports the total row count, the number of distinct subjects, predicates
/// and objects, and the ten most frequent predicates and subjects together
/// with their fact counts (highest count first; the queries define no order
/// among entries with equal counts).
///
/// # Errors
///
/// Returns an error if any of the underlying queries fails.
pub fn knowledge_stats(conn: &Connection) -> Result<KnowledgeStats> {
    // Runs a query that yields exactly one integer.
    let scalar = |sql: &str| conn.query_row(sql, [], |row| row.get::<_, i64>(0));

    // Runs a `(name, count)` query and collects all of its rows.
    let ranked = |sql: &str| -> Result<Vec<(String, i64)>> {
        let mut stmt = conn.prepare(sql)?;
        let pairs = stmt
            .query_map([], |row| {
                Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
            })?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(pairs)
    };

    let total_facts = scalar("SELECT COUNT(*) FROM facts")?;
    let unique_subjects = scalar("SELECT COUNT(DISTINCT subject) FROM facts")?;
    let unique_predicates = scalar("SELECT COUNT(DISTINCT predicate) FROM facts")?;
    let unique_objects = scalar("SELECT COUNT(DISTINCT object) FROM facts")?;

    let top_predicates = ranked(
        "SELECT predicate, COUNT(*) as cnt
FROM facts
GROUP BY predicate
ORDER BY cnt DESC
LIMIT 10",
    )?;
    let top_subjects = ranked(
        "SELECT subject, COUNT(*) as cnt
FROM facts
GROUP BY subject
ORDER BY cnt DESC
LIMIT 10",
    )?;

    Ok(KnowledgeStats {
        total_facts,
        unique_subjects,
        unique_predicates,
        unique_objects,
        top_predicates,
        top_subjects,
    })
}
}
/// Converts one result row into a `Fact`.
///
/// Columns are read by position and must appear in the order
/// `id, subject, predicate, object, confidence, source_memory_id, created_at`.
fn map_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Fact> {
    let id = row.get(0)?;
    let subject = row.get(1)?;
    let predicate = row.get(2)?;
    let object = row.get(3)?;
    let confidence = row.get(4)?;
    let source_memory_id = row.get(5)?;
    let created_at = row.get(6)?;

    Ok(Fact {
        id,
        subject,
        predicate,
        object,
        confidence,
        source_memory_id,
        created_at,
    })
}
// Unit tests for `TripletMatcher`, run against an in-memory SQLite database.
#[cfg(test)]
mod tests {
use super::*;
use rusqlite::Connection;
// Schema of the `facts` table used by every test in this module.
const CREATE_TABLE: &str = r#"
CREATE TABLE IF NOT EXISTS facts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
subject TEXT NOT NULL,
predicate TEXT NOT NULL,
object TEXT NOT NULL,
confidence REAL NOT NULL DEFAULT 0.8,
source_memory_id INTEGER,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(subject);
"#;
// Opens a fresh in-memory database and creates the `facts` table in it.
fn setup() -> Connection {
let conn = Connection::open_in_memory().expect("in-memory db");
conn.execute_batch(CREATE_TABLE).expect("create table");
conn
}
// Inserts one fact with a fixed `created_at` timestamp.
fn insert(conn: &Connection, subject: &str, predicate: &str, object: &str, confidence: f64) {
conn.execute(
"INSERT INTO facts (subject, predicate, object, confidence, created_at)
VALUES (?1, ?2, ?3, ?4, '2026-01-01T00:00:00Z')",
params![subject, predicate, object, confidence],
)
.expect("insert fact");
}
// Seeds five facts: five distinct subjects, three distinct predicates and
// four distinct objects ("Google" is the object of two facts).
fn seed_graph(conn: &Connection) {
insert(conn, "Alice", "works_at", "Google", 0.9);
insert(conn, "Bob", "works_at", "Google", 0.85);
insert(conn, "Google", "located_in", "California", 0.95);
insert(conn, "Carol", "lives_in", "London", 0.8);
insert(conn, "Dave", "located_in", "Paris", 0.75);
}
// All three components set: exactly one seeded fact matches.
#[test]
fn test_match_all_filters() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any()
.with_subject("Alice")
.with_predicate("works_at")
.with_object("Google");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match all filters");
assert_eq!(facts.len(), 1);
assert_eq!(facts[0].subject, "Alice");
}
// Subject and object set, predicate left unconstrained.
#[test]
fn test_match_partial_filters_subject_object() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any()
.with_subject("Alice")
.with_object("Google");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match partial filters");
assert_eq!(facts.len(), 1);
assert_eq!(facts[0].predicate, "works_at");
}
// An unconstrained pattern returns every seeded fact.
#[test]
fn test_match_no_filters_returns_all() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any();
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match no filters");
assert_eq!(facts.len(), 5, "all facts returned when no filters");
}
// Filtering on subject alone.
#[test]
fn test_match_by_subject() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any().with_subject("Alice");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert_eq!(facts.len(), 1);
assert_eq!(facts[0].subject, "Alice");
assert_eq!(facts[0].predicate, "works_at");
assert_eq!(facts[0].object, "Google");
}
// Filtering on predicate alone: Alice and Bob both have `works_at`.
#[test]
fn test_match_by_predicate() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any().with_predicate("works_at");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert_eq!(facts.len(), 2);
let subjects: Vec<&str> = facts.iter().map(|f| f.subject.as_str()).collect();
assert!(subjects.contains(&"Alice"));
assert!(subjects.contains(&"Bob"));
}
// Filtering on object alone: two seeded facts have "Google" as object.
#[test]
fn test_match_by_object() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any().with_object("Google");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert_eq!(facts.len(), 2);
}
// NOTE(review): exercises the same scenario as
// `test_match_no_filters_returns_all` above.
#[test]
fn test_wildcard_match_returns_all() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any();
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert_eq!(facts.len(), 5);
}
// A lower-case pattern value still matches the stored "Alice".
#[test]
fn test_match_case_insensitive() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any().with_subject("alice");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert_eq!(facts.len(), 1);
}
// A subject that is not in the table yields no facts.
#[test]
fn test_no_matches_returns_empty() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any().with_subject("Nonexistent");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert!(facts.is_empty());
}
// Subject and predicate set, object left unconstrained.
#[test]
fn test_match_subject_and_predicate() {
let conn = setup();
seed_graph(&conn);
let pattern = TripletPattern::any()
.with_subject("Google")
.with_predicate("located_in");
let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
assert_eq!(facts.len(), 1);
assert_eq!(facts[0].object, "California");
}
// Alice -> Google -> Alphabet over the same predicate gives a two-step path.
#[test]
fn test_transitive_inference_two_hops() {
let conn = setup();
insert(&conn, "Alice", "works_at", "Google", 0.9);
insert(&conn, "Google", "works_at", "Alphabet", 0.8);
let paths = TripletMatcher::infer_transitive(&conn, "Alice", "works_at", 3).expect("infer");
assert!(!paths.is_empty(), "expected at least one inference path");
let longest = paths.iter().max_by_key(|p| p.steps.len()).unwrap();
assert_eq!(longest.steps.len(), 2);
assert_eq!(longest.steps[0].subject, "Alice");
assert_eq!(longest.steps[0].object, "Google");
assert_eq!(longest.steps[1].subject, "Google");
assert_eq!(longest.steps[1].object, "Alphabet");
}
// Alice has no `lives_in` fact in the seed data, so no path starts from her.
#[test]
fn test_transitive_inference_no_matching_predicate() {
let conn = setup();
seed_graph(&conn);
let paths = TripletMatcher::infer_transitive(&conn, "Alice", "lives_in", 3).expect("infer");
assert!(paths.is_empty());
}
// A hop budget of zero returns no paths.
#[test]
fn test_transitive_inference_max_hops_zero() {
let conn = setup();
seed_graph(&conn);
let paths = TripletMatcher::infer_transitive(&conn, "Alice", "works_at", 0).expect("infer");
assert!(paths.is_empty());
}
// Path confidence is the product of the per-fact confidences.
#[test]
fn test_transitive_confidence_product() {
let conn = setup();
insert(&conn, "Alice", "rel", "B", 0.9);
insert(&conn, "B", "rel", "C", 0.8);
let paths = TripletMatcher::infer_transitive(&conn, "Alice", "rel", 5).expect("infer");
let two_hop = paths.iter().find(|p| p.steps.len() == 2);
assert!(two_hop.is_some(), "expected 2-hop path");
let conf = two_hop.unwrap().confidence;
assert!(
(conf - 0.9 * 0.8).abs() < 1e-9,
"expected confidence ~0.72, got {conf}"
);
}
// A capitalised name inside a question is picked up as an entity.
#[test]
fn test_query_knowledge_by_entity() {
let conn = setup();
seed_graph(&conn);
let facts = TripletMatcher::query_knowledge(&conn, "What does Alice do?").expect("query");
assert!(
facts.iter().any(|f| f.subject == "Alice"),
"expected Alice fact, got: {:?}",
facts
);
}
// Text without any capitalised word contains no entity.
#[test]
fn test_query_knowledge_no_entities_returns_empty() {
let conn = setup();
seed_graph(&conn);
let facts =
TripletMatcher::query_knowledge(&conn, "what does everyone do?").expect("query");
assert!(facts.is_empty());
}
// Facts for every entity mentioned in the text are returned.
#[test]
fn test_query_knowledge_multiple_entities() {
let conn = setup();
seed_graph(&conn);
let facts =
TripletMatcher::query_knowledge(&conn, "Tell me about Alice and Carol").expect("query");
let subjects: Vec<&str> = facts.iter().map(|f| f.subject.as_str()).collect();
assert!(subjects.contains(&"Alice"), "expected Alice fact");
assert!(subjects.contains(&"Carol"), "expected Carol fact");
}
// On an empty table every count is zero and both top lists are empty.
#[test]
fn test_knowledge_stats_empty_table() {
let conn = setup();
let stats = TripletMatcher::knowledge_stats(&conn).expect("stats");
assert_eq!(stats.total_facts, 0);
assert_eq!(stats.unique_subjects, 0);
assert_eq!(stats.unique_predicates, 0);
assert_eq!(stats.unique_objects, 0);
assert!(stats.top_predicates.is_empty());
assert!(stats.top_subjects.is_empty());
}
// Counts over the seeded graph.
// NOTE(review): `works_at` and `located_in` both occur twice in the seed data
// and the stats query orders by count only, so the assertion on
// `top_predicates[0].0` appears to depend on the order in which SQLite
// returns tied rows -- confirm this is stable.
#[test]
fn test_knowledge_stats_with_data() {
let conn = setup();
seed_graph(&conn);
let stats = TripletMatcher::knowledge_stats(&conn).expect("stats");
assert_eq!(stats.total_facts, 5);
assert_eq!(stats.unique_subjects, 5);
assert_eq!(stats.unique_predicates, 3);
assert_eq!(stats.unique_objects, 4);
assert!(!stats.top_predicates.is_empty());
assert_eq!(stats.top_predicates[0].0, "works_at");
assert_eq!(stats.top_predicates[0].1, 2);
}
}