use std::collections::{HashMap, HashSet, VecDeque};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::intelligence::fact_extraction::Fact;
/// A query pattern over the `(subject, predicate, object)` positions of a fact.
///
/// Each field is either `None` (unconstrained) or `Some(value)`.
/// `TripletMatcher::match_pattern` compares a `Some` value against its column
/// case-insensitively using SQL `LIKE`, so `%` and `_` inside a value act as
/// wildcards.
#[derive(Debug, Clone, Default)]
pub struct TripletPattern {
    /// Constraint on the fact's subject; `None` leaves it unconstrained.
    pub subject: Option<String>,
    /// Constraint on the fact's predicate; `None` leaves it unconstrained.
    pub predicate: Option<String>,
    /// Constraint on the fact's object; `None` leaves it unconstrained.
    pub object: Option<String>,
}

impl TripletPattern {
    /// Creates a pattern with no constraints.
    #[must_use]
    pub fn any() -> Self {
        Self::default()
    }

    /// Returns this pattern with its subject constraint set to `subject`,
    /// replacing any previous subject constraint.
    #[must_use]
    pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
        self.subject = Some(subject.into());
        self
    }

    /// Returns this pattern with its predicate constraint set to `predicate`,
    /// replacing any previous predicate constraint.
    #[must_use]
    pub fn with_predicate(mut self, predicate: impl Into<String>) -> Self {
        self.predicate = Some(predicate.into());
        self
    }

    /// Returns this pattern with its object constraint set to `object`,
    /// replacing any previous object constraint.
    #[must_use]
    pub fn with_object(mut self, object: impl Into<String>) -> Self {
        self.object = Some(object.into());
        self
    }
}
/// One hop of an inference chain: a single stored fact read as
/// `subject --predicate--> object`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceStep {
/// Subject of the fact. `TripletMatcher::infer_transitive` fills this with
/// the subject string exactly as stored (original casing).
pub subject: String,
/// Predicate of the hop. As populated by `TripletMatcher::infer_transitive`,
/// this is the caller's `predicate` argument, not the stored column value,
/// so its casing may differ from the database.
pub predicate: String,
/// Object of the fact, as stored (original casing).
pub object: String,
/// `id` of the `facts` row this hop was derived from.
pub source_fact_id: i64,
}
/// An ordered chain of facts in which each step's object is the next step's
/// subject (compared case-insensitively by `TripletMatcher::infer_transitive`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferencePath {
/// Hops in traversal order, starting from the queried subject.
pub steps: Vec<InferenceStep>,
/// Combined confidence of the chain. `TripletMatcher::infer_transitive`
/// computes it as the product of the per-fact confidences, starting from 1.0.
pub confidence: f64,
}
/// Aggregate counts over the `facts` table, as produced by
/// `TripletMatcher::knowledge_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeStats {
/// Total number of rows in `facts`.
pub total_facts: i64,
/// Number of distinct `subject` values. Values are compared as stored (no
/// `lower()`), so "Alice" and "alice" count separately.
pub unique_subjects: i64,
/// Number of distinct `predicate` values, compared as stored.
pub unique_predicates: i64,
/// Number of distinct `object` values, compared as stored.
pub unique_objects: i64,
/// Up to 10 `(predicate, fact_count)` pairs, most frequent first.
pub top_predicates: Vec<(String, i64)>,
/// Up to 10 `(subject, fact_count)` pairs, most frequent first.
pub top_subjects: Vec<(String, i64)>,
}
pub struct TripletMatcher;
impl TripletMatcher {
pub fn match_pattern(conn: &Connection, pattern: &TripletPattern) -> Result<Vec<Fact>> {
let mut conditions: Vec<String> = Vec::new();
let mut bind_values: Vec<String> = Vec::new();
if let Some(ref s) = pattern.subject {
conditions.push(format!(
"lower(subject) LIKE lower(?{})",
conditions.len() + 1
));
bind_values.push(s.clone());
}
if let Some(ref p) = pattern.predicate {
conditions.push(format!(
"lower(predicate) LIKE lower(?{})",
conditions.len() + 1
));
bind_values.push(p.clone());
}
if let Some(ref o) = pattern.object {
conditions.push(format!(
"lower(object) LIKE lower(?{})",
conditions.len() + 1
));
bind_values.push(o.clone());
}
let where_clause = if conditions.is_empty() {
String::new()
} else {
format!("WHERE {}", conditions.join(" AND "))
};
let sql = format!(
"SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
FROM facts
{where_clause}
ORDER BY id ASC"
);
let mut stmt = conn.prepare(&sql)?;
let facts = match bind_values.len() {
0 => stmt
.query_map([], map_row)?
.collect::<std::result::Result<Vec<Fact>, _>>()?,
1 => stmt
.query_map(params![bind_values[0]], map_row)?
.collect::<std::result::Result<Vec<Fact>, _>>()?,
2 => stmt
.query_map(params![bind_values[0], bind_values[1]], map_row)?
.collect::<std::result::Result<Vec<Fact>, _>>()?,
3 => stmt
.query_map(
params![bind_values[0], bind_values[1], bind_values[2]],
map_row,
)?
.collect::<std::result::Result<Vec<Fact>, _>>()?,
_ => unreachable!("pattern has at most 3 fields"),
};
Ok(facts)
}
pub fn infer_transitive(
conn: &Connection,
subject: &str,
predicate: &str,
max_hops: usize,
) -> Result<Vec<InferencePath>> {
if max_hops == 0 {
return Ok(Vec::new());
}
let mut adj: HashMap<String, Vec<(i64, String, String, f64)>> = HashMap::new();
{
let mut stmt = conn.prepare(
"SELECT id, subject, object, confidence
FROM facts
WHERE lower(predicate) = lower(?1)
ORDER BY id ASC",
)?;
let rows = stmt.query_map(params![predicate], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, f64>(3)?,
))
})?;
for row in rows {
let (fact_id, subj, obj, conf) = row?;
adj.entry(subj.to_lowercase())
.or_default()
.push((fact_id, subj, obj, conf));
}
}
let mut results: Vec<InferencePath> = Vec::new();
let mut queue: VecDeque<(String, Vec<InferenceStep>, f64)> = VecDeque::new();
queue.push_back((subject.to_lowercase(), Vec::new(), 1.0));
while let Some((current, path, path_conf)) = queue.pop_front() {
if path.len() >= max_hops {
if !path.is_empty() {
results.push(InferencePath {
steps: path,
confidence: path_conf,
});
}
continue;
}
let neighbors = match adj.get(¤t) {
Some(n) => n.clone(),
None => {
if !path.is_empty() {
results.push(InferencePath {
steps: path,
confidence: path_conf,
});
}
continue;
}
};
let visited_in_path: HashSet<String> =
path.iter().map(|s| s.subject.to_lowercase()).collect();
let mut branched = false;
for (fact_id, subj_orig, obj, conf) in neighbors {
let obj_lower = obj.to_lowercase();
if visited_in_path.contains(&obj_lower) || obj_lower == current {
continue;
}
let step = InferenceStep {
subject: subj_orig,
predicate: predicate.to_string(),
object: obj.clone(),
source_fact_id: fact_id,
};
let mut next_path = path.clone();
next_path.push(step);
let next_conf = path_conf * conf;
queue.push_back((obj_lower, next_path, next_conf));
branched = true;
}
if !branched && !path.is_empty() {
results.push(InferencePath {
steps: path,
confidence: path_conf,
});
}
}
Ok(results)
}
pub fn query_knowledge(conn: &Connection, natural_language: &str) -> Result<Vec<Fact>> {
let entities: Vec<String> = natural_language
.split_whitespace()
.filter(|w| w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false))
.map(|w| {
w.trim_end_matches(|c: char| !c.is_alphanumeric())
.to_string()
})
.filter(|w| w.len() >= 2)
.collect::<std::collections::HashSet<_>>() .into_iter()
.collect();
if entities.is_empty() {
return Ok(Vec::new());
}
let mut seen_ids: HashSet<i64> = HashSet::new();
let mut all_facts: Vec<Fact> = Vec::new();
let mut stmt = conn.prepare(
"SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
FROM facts
WHERE lower(subject) LIKE lower(?1) OR lower(object) LIKE lower(?1)
ORDER BY id ASC",
)?;
for entity in &entities {
let rows = stmt
.query_map(params![entity], map_row)?
.collect::<std::result::Result<Vec<Fact>, _>>()?;
for fact in rows {
if seen_ids.insert(fact.id) {
all_facts.push(fact);
}
}
}
all_facts.sort_by_key(|f| f.id);
Ok(all_facts)
}
pub fn knowledge_stats(conn: &Connection) -> Result<KnowledgeStats> {
let total_facts: i64 =
conn.query_row("SELECT COUNT(*) FROM facts", [], |row| row.get(0))?;
let unique_subjects: i64 =
conn.query_row("SELECT COUNT(DISTINCT subject) FROM facts", [], |row| {
row.get(0)
})?;
let unique_predicates: i64 =
conn.query_row("SELECT COUNT(DISTINCT predicate) FROM facts", [], |row| {
row.get(0)
})?;
let unique_objects: i64 =
conn.query_row("SELECT COUNT(DISTINCT object) FROM facts", [], |row| {
row.get(0)
})?;
let mut pred_stmt = conn.prepare(
"SELECT predicate, COUNT(*) as cnt
FROM facts
GROUP BY predicate
ORDER BY cnt DESC
LIMIT 10",
)?;
let top_predicates: Vec<(String, i64)> = pred_stmt
.query_map([], |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
})?
.collect::<std::result::Result<Vec<_>, _>>()?;
let mut subj_stmt = conn.prepare(
"SELECT subject, COUNT(*) as cnt
FROM facts
GROUP BY subject
ORDER BY cnt DESC
LIMIT 10",
)?;
let top_subjects: Vec<(String, i64)> = subj_stmt
.query_map([], |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
})?
.collect::<std::result::Result<Vec<_>, _>>()?;
Ok(KnowledgeStats {
total_facts,
unique_subjects,
unique_predicates,
unique_objects,
top_predicates,
top_subjects,
})
}
}
/// Decodes one row of the seven-column fact projection into a [`Fact`].
///
/// Expects the column order `id, subject, predicate, object, confidence,
/// source_memory_id, created_at`, which is the projection selected by
/// `TripletMatcher::match_pattern` and `TripletMatcher::query_knowledge`.
fn map_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Fact> {
    // Columns are read strictly left to right, so the first decode failure
    // is the one reported.
    let id = row.get(0)?;
    let subject = row.get(1)?;
    let predicate = row.get(2)?;
    let object = row.get(3)?;
    let confidence = row.get(4)?;
    let source_memory_id = row.get(5)?;
    let created_at = row.get(6)?;
    Ok(Fact {
        id,
        subject,
        predicate,
        object,
        confidence,
        source_memory_id,
        created_at,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Minimal `facts` schema covering every column the matcher reads.
    const CREATE_TABLE: &str = r#"
        CREATE TABLE IF NOT EXISTS facts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            subject TEXT NOT NULL,
            predicate TEXT NOT NULL,
            object TEXT NOT NULL,
            confidence REAL NOT NULL DEFAULT 0.8,
            source_memory_id INTEGER,
            created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
        );
        CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(subject);
    "#;

    /// Opens a fresh in-memory database containing an empty `facts` table.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().expect("in-memory db");
        conn.execute_batch(CREATE_TABLE).expect("create table");
        conn
    }

    /// Inserts one fact with a fixed `created_at` so rows are deterministic.
    fn insert(conn: &Connection, subject: &str, predicate: &str, object: &str, confidence: f64) {
        conn.execute(
            "INSERT INTO facts (subject, predicate, object, confidence, created_at)
             VALUES (?1, ?2, ?3, ?4, '2026-01-01T00:00:00Z')",
            params![subject, predicate, object, confidence],
        )
        .expect("insert fact");
    }

    /// Seeds five facts. Predicate counts tie: `works_at` and `located_in`
    /// appear twice each, `lives_in` once.
    fn seed_graph(conn: &Connection) {
        insert(conn, "Alice", "works_at", "Google", 0.9);
        insert(conn, "Bob", "works_at", "Google", 0.85);
        insert(conn, "Google", "located_in", "California", 0.95);
        insert(conn, "Carol", "lives_in", "London", 0.8);
        insert(conn, "Dave", "located_in", "Paris", 0.75);
    }

    #[test]
    fn test_match_by_subject() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any().with_subject("Alice");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].subject, "Alice");
        assert_eq!(facts[0].predicate, "works_at");
        assert_eq!(facts[0].object, "Google");
    }

    #[test]
    fn test_match_by_predicate() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any().with_predicate("works_at");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 2);
        let subjects: Vec<&str> = facts.iter().map(|f| f.subject.as_str()).collect();
        assert!(subjects.contains(&"Alice"));
        assert!(subjects.contains(&"Bob"));
    }

    #[test]
    fn test_match_by_object() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any().with_object("Google");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 2);
    }

    #[test]
    fn test_wildcard_match_returns_all() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any();
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 5);
    }

    #[test]
    fn test_match_case_insensitive() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any().with_subject("alice");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 1);
    }

    #[test]
    fn test_no_matches_returns_empty() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any().with_subject("Nonexistent");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert!(facts.is_empty());
    }

    #[test]
    fn test_match_subject_and_predicate() {
        let conn = setup();
        seed_graph(&conn);
        let pattern = TripletPattern::any()
            .with_subject("Google")
            .with_predicate("located_in");
        let facts = TripletMatcher::match_pattern(&conn, &pattern).expect("match");
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].object, "California");
    }

    #[test]
    fn test_transitive_inference_two_hops() {
        let conn = setup();
        insert(&conn, "Alice", "works_at", "Google", 0.9);
        insert(&conn, "Google", "works_at", "Alphabet", 0.8);
        let paths =
            TripletMatcher::infer_transitive(&conn, "Alice", "works_at", 3).expect("infer");
        assert!(!paths.is_empty(), "expected at least one inference path");
        let longest = paths.iter().max_by_key(|p| p.steps.len()).unwrap();
        assert_eq!(longest.steps.len(), 2);
        assert_eq!(longest.steps[0].subject, "Alice");
        assert_eq!(longest.steps[0].object, "Google");
        assert_eq!(longest.steps[1].subject, "Google");
        assert_eq!(longest.steps[1].object, "Alphabet");
    }

    #[test]
    fn test_transitive_inference_no_matching_predicate() {
        let conn = setup();
        seed_graph(&conn);
        let paths =
            TripletMatcher::infer_transitive(&conn, "Alice", "lives_in", 3).expect("infer");
        assert!(paths.is_empty());
    }

    #[test]
    fn test_transitive_inference_max_hops_zero() {
        let conn = setup();
        seed_graph(&conn);
        let paths =
            TripletMatcher::infer_transitive(&conn, "Alice", "works_at", 0).expect("infer");
        assert!(paths.is_empty());
    }

    #[test]
    fn test_transitive_inference_respects_max_hops() {
        let conn = setup();
        insert(&conn, "A", "rel", "B", 0.9);
        insert(&conn, "B", "rel", "C", 0.9);
        insert(&conn, "C", "rel", "D", 0.9);
        let paths = TripletMatcher::infer_transitive(&conn, "A", "rel", 2).expect("infer");
        // The chain is cut at two hops; only that maximal path is reported.
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].steps.len(), 2);
        assert_eq!(paths[0].steps[1].object, "C");
    }

    #[test]
    fn test_transitive_inference_stops_on_cycle() {
        let conn = setup();
        insert(&conn, "A", "rel", "B", 0.9);
        insert(&conn, "B", "rel", "A", 0.9);
        let paths = TripletMatcher::infer_transitive(&conn, "A", "rel", 10).expect("infer");
        // B -> A would revisit the start node, so the walk ends after one hop.
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].steps.len(), 1);
        assert_eq!(paths[0].steps[0].object, "B");
    }

    #[test]
    fn test_transitive_confidence_product() {
        let conn = setup();
        insert(&conn, "Alice", "rel", "B", 0.9);
        insert(&conn, "B", "rel", "C", 0.8);
        let paths = TripletMatcher::infer_transitive(&conn, "Alice", "rel", 5).expect("infer");
        let two_hop = paths.iter().find(|p| p.steps.len() == 2);
        assert!(two_hop.is_some(), "expected 2-hop path");
        let conf = two_hop.unwrap().confidence;
        assert!(
            (conf - 0.9 * 0.8).abs() < 1e-9,
            "expected confidence ~0.72, got {conf}"
        );
    }

    #[test]
    fn test_query_knowledge_by_entity() {
        let conn = setup();
        seed_graph(&conn);
        let facts = TripletMatcher::query_knowledge(&conn, "What does Alice do?").expect("query");
        assert!(
            facts.iter().any(|f| f.subject == "Alice"),
            "expected Alice fact, got: {:?}",
            facts
        );
    }

    #[test]
    fn test_query_knowledge_no_entities_returns_empty() {
        let conn = setup();
        seed_graph(&conn);
        let facts =
            TripletMatcher::query_knowledge(&conn, "what does everyone do?").expect("query");
        assert!(facts.is_empty());
    }

    #[test]
    fn test_query_knowledge_multiple_entities() {
        let conn = setup();
        seed_graph(&conn);
        let facts =
            TripletMatcher::query_knowledge(&conn, "Tell me about Alice and Carol").expect("query");
        let subjects: Vec<&str> = facts.iter().map(|f| f.subject.as_str()).collect();
        assert!(subjects.contains(&"Alice"), "expected Alice fact");
        assert!(subjects.contains(&"Carol"), "expected Carol fact");
    }

    #[test]
    fn test_knowledge_stats_empty_table() {
        let conn = setup();
        let stats = TripletMatcher::knowledge_stats(&conn).expect("stats");
        assert_eq!(stats.total_facts, 0);
        assert_eq!(stats.unique_subjects, 0);
        assert_eq!(stats.unique_predicates, 0);
        assert_eq!(stats.unique_objects, 0);
        assert!(stats.top_predicates.is_empty());
        assert!(stats.top_subjects.is_empty());
    }

    #[test]
    fn test_knowledge_stats_with_data() {
        let conn = setup();
        seed_graph(&conn);
        let stats = TripletMatcher::knowledge_stats(&conn).expect("stats");
        assert_eq!(stats.total_facts, 5);
        assert_eq!(stats.unique_subjects, 5);
        assert_eq!(stats.unique_predicates, 3);
        assert_eq!(stats.unique_objects, 4);
        assert!(!stats.top_predicates.is_empty());
        // `works_at` and `located_in` tie at two facts each, so check the
        // ranking's contents and that counts never increase, rather than which
        // tied predicate happens to come first.
        assert_eq!(stats.top_predicates[0].1, 2);
        assert!(stats
            .top_predicates
            .windows(2)
            .all(|pair| pair[0].1 >= pair[1].1));
        let mut ranked = stats.top_predicates.clone();
        ranked.sort();
        assert_eq!(
            ranked,
            vec![
                ("lives_in".to_string(), 1),
                ("located_in".to_string(), 2),
                ("works_at".to_string(), 2),
            ]
        );
        assert_eq!(stats.top_subjects.len(), 5);
        assert!(stats.top_subjects.iter().all(|(_, count)| *count == 1));
    }
}