1use std::collections::HashMap;
16
17use chrono::Utc;
18use once_cell::sync::Lazy;
19use regex::Regex;
20use rusqlite::{params, Connection};
21use serde::{Deserialize, Serialize};
22
23use crate::error::Result;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Fact {
32 pub id: i64,
34 pub subject: String,
36 pub predicate: String,
38 pub object: String,
40 pub confidence: f32,
42 pub source_memory_id: Option<i64>,
44 pub created_at: String,
46}
47
/// A candidate fact produced by a [`FactExtractor`], not yet stored.
///
/// `PartialEq` is derived so callers and tests can compare extraction
/// results directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedFact {
    /// Entity the fact is about.
    pub subject: String,
    /// Relation name, e.g. `"works_at"` or a structured key such as `"role"`.
    pub predicate: String,
    /// Value of the relation.
    pub object: String,
    /// Extraction confidence assigned by the extractor.
    pub confidence: f32,
}
56
57pub trait FactExtractor: Send + Sync {
63 fn extract_facts(&self, text: &str) -> Vec<ExtractedFact>;
64}
65
66static IS_A_PATTERN: Lazy<Regex> = Lazy::new(|| {
72 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+is\s+an?\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)")
73 .expect("valid regex")
74});
75
76static IS_PATTERN: Lazy<Regex> = Lazy::new(|| {
78 Regex::new(
79 r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+is\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)",
80 )
81 .expect("valid regex")
82});
83
84static WORKS_AT_PATTERN: Lazy<Regex> = Lazy::new(|| {
86 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+works?\s+at\s+([A-Za-z0-9][A-Za-z0-9\s\.\-]{0,60}?)\b(?:[,\.\!]|$)")
87 .expect("valid regex")
88});
89
90static LIVES_IN_PATTERN: Lazy<Regex> = Lazy::new(|| {
92 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+lives?\s+in\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)")
93 .expect("valid regex")
94});
95
96static LIKES_PATTERN: Lazy<Regex> = Lazy::new(|| {
98 Regex::new(
99 r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+likes?\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)",
100 )
101 .expect("valid regex")
102});
103
104static BORN_IN_PATTERN: Lazy<Regex> = Lazy::new(|| {
106 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+was\s+born\s+in\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)")
107 .expect("valid regex")
108});
109
110static MANAGES_PATTERN: Lazy<Regex> = Lazy::new(|| {
112 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+manages?\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)")
113 .expect("valid regex")
114});
115
116static REPORTS_TO_PATTERN: Lazy<Regex> = Lazy::new(|| {
118 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+reports?\s+to\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)")
119 .expect("valid regex")
120});
121
122static CREATED_PATTERN: Lazy<Regex> = Lazy::new(|| {
124 Regex::new(r"(?i)\b([A-Za-z][A-Za-z\s]{0,40}?)\s+created?\s+([A-Za-z][A-Za-z\s]{0,60}?)\b(?:[,\.\!]|$)")
125 .expect("valid regex")
126});
127
128static STRUCTURED_PATTERN: Lazy<Regex> = Lazy::new(|| {
130 Regex::new(
131 r"(?m)^(?:Name|Role|Location|Title|Company|Organization|Department|Team)\s*:\s*(.+)$",
132 )
133 .expect("valid regex")
134});
135
/// Regex-driven [`FactExtractor`] with no configuration or state.
///
/// It recognises a fixed set of `Key: value` lines plus simple English
/// relation phrases ("X works at Y", "X lives in Y", ...).
#[derive(Debug, Clone, Default)]
pub struct RuleBasedExtractor;

impl RuleBasedExtractor {
    /// Creates the extractor. Equivalent to `RuleBasedExtractor::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
156
157impl FactExtractor for RuleBasedExtractor {
158 fn extract_facts(&self, text: &str) -> Vec<ExtractedFact> {
159 let text = text.trim();
160 if text.is_empty() {
161 return Vec::new();
162 }
163
164 let mut facts = Vec::new();
165
166 for cap in STRUCTURED_PATTERN.captures_iter(text) {
168 if let (Some(key_m), Some(val_m)) = (cap.get(0), cap.get(1)) {
169 let full = key_m.as_str();
171 let colon_pos = full.find(':').unwrap_or(full.len());
172 let key = full[..colon_pos].trim().to_lowercase().replace(' ', "_");
173 let value = val_m.as_str().trim().to_string();
174
175 if !key.is_empty() && !value.is_empty() {
176 facts.push(ExtractedFact {
179 subject: "entity".to_string(),
180 predicate: key,
181 object: value,
182 confidence: 0.9,
183 });
184 }
185 }
186 }
187
188 apply_pattern(&IS_A_PATTERN, text, "is_a", 0.8, &mut facts);
190
191 apply_pattern(&IS_PATTERN, text, "is", 0.8, &mut facts);
193
194 apply_pattern(&WORKS_AT_PATTERN, text, "works_at", 0.8, &mut facts);
196
197 apply_pattern(&LIVES_IN_PATTERN, text, "lives_in", 0.8, &mut facts);
199
200 apply_pattern(&LIKES_PATTERN, text, "likes", 0.8, &mut facts);
202
203 apply_pattern(&BORN_IN_PATTERN, text, "born_in", 0.8, &mut facts);
205
206 apply_pattern(&MANAGES_PATTERN, text, "manages", 0.8, &mut facts);
208
209 apply_pattern(&REPORTS_TO_PATTERN, text, "reports_to", 0.8, &mut facts);
211
212 apply_pattern(&CREATED_PATTERN, text, "created", 0.8, &mut facts);
214
215 facts
216 }
217}
218
219fn apply_pattern(
221 pattern: &Regex,
222 text: &str,
223 predicate: &str,
224 confidence: f32,
225 out: &mut Vec<ExtractedFact>,
226) {
227 for cap in pattern.captures_iter(text) {
228 let subject_raw = match cap.get(1) {
229 Some(m) => m.as_str().trim(),
230 None => continue,
231 };
232 let object_raw = match cap.get(2) {
233 Some(m) => m.as_str().trim(),
234 None => continue,
235 };
236
237 if subject_raw.is_empty() || object_raw.is_empty() {
238 continue;
239 }
240
241 if subject_raw.len() < 2 || object_raw.len() < 2 {
243 continue;
244 }
245
246 out.push(ExtractedFact {
247 subject: title_case(subject_raw),
248 predicate: predicate.to_string(),
249 object: object_raw.to_string(),
250 confidence,
251 });
252 }
253}
254
/// Upper-cases the first character of each whitespace-separated word and
/// lower-cases the rest, joining the words with single spaces.
///
/// Leading, trailing and repeated whitespace is therefore dropped. Case
/// mapping is Unicode-aware (e.g. `ß` upper-cases to `SS`).
fn title_case(s: &str) -> String {
    let mut titled = String::with_capacity(s.len());
    for (position, word) in s.split_whitespace().enumerate() {
        if position > 0 {
            titled.push(' ');
        }
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            titled.extend(head.to_uppercase());
            titled.push_str(&rest.as_str().to_lowercase());
        }
    }
    titled
}
268
269pub struct ConversationProcessor {
275 extractor: Box<dyn FactExtractor>,
276}
277
278impl ConversationProcessor {
279 pub fn new(extractor: Box<dyn FactExtractor>) -> Self {
280 Self { extractor }
281 }
282
283 pub fn process_text(&self, text: &str, source_memory_id: Option<i64>) -> Vec<ExtractedFact> {
286 let raw = self.extractor.extract_facts(text);
287 let _ = source_memory_id; dedup_facts(raw)
289 }
290
291 pub fn process_conversation(
293 &self,
294 messages: &[&str],
295 source_memory_id: Option<i64>,
296 ) -> Vec<ExtractedFact> {
297 let _ = source_memory_id;
298 let raw: Vec<ExtractedFact> = messages
299 .iter()
300 .flat_map(|msg| self.extractor.extract_facts(msg))
301 .collect();
302 dedup_facts(raw)
303 }
304}
305
306fn dedup_facts(facts: Vec<ExtractedFact>) -> Vec<ExtractedFact> {
308 let mut map: HashMap<(String, String, String), ExtractedFact> = HashMap::new();
309 for fact in facts {
310 let key = (
311 fact.subject.clone(),
312 fact.predicate.clone(),
313 fact.object.clone(),
314 );
315 map.entry(key)
316 .and_modify(|existing| {
317 if fact.confidence > existing.confidence {
318 existing.confidence = fact.confidence;
319 }
320 })
321 .or_insert(fact);
322 }
323 map.into_values().collect()
324}
325
/// DDL for the `facts` table and its indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be re-run safely
/// (e.g. with `Connection::execute_batch` on start-up). The
/// `UNIQUE(subject, predicate, object)` constraint is what lets `create_fact`
/// rely on `INSERT OR IGNORE` to skip duplicate triples.
pub const CREATE_FACTS_TABLE: &str = r"
    CREATE TABLE IF NOT EXISTS facts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        subject TEXT NOT NULL,
        predicate TEXT NOT NULL,
        object TEXT NOT NULL,
        confidence REAL NOT NULL DEFAULT 0.5,
        source_memory_id INTEGER,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
        UNIQUE(subject, predicate, object)
    );
    CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(subject);
    CREATE INDEX IF NOT EXISTS idx_facts_source ON facts(source_memory_id);
";
345
346pub fn create_fact(
350 conn: &Connection,
351 fact: &ExtractedFact,
352 source_id: Option<i64>,
353) -> Result<Fact> {
354 let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
355
356 conn.execute(
357 "INSERT OR IGNORE INTO facts (subject, predicate, object, confidence, source_memory_id, created_at)
358 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
359 params![
360 fact.subject,
361 fact.predicate,
362 fact.object,
363 fact.confidence,
364 source_id,
365 now,
366 ],
367 )?;
368
369 let stored = conn.query_row(
370 "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
371 FROM facts
372 WHERE subject = ?1 AND predicate = ?2 AND object = ?3",
373 params![fact.subject, fact.predicate, fact.object],
374 |row| {
375 Ok(Fact {
376 id: row.get(0)?,
377 subject: row.get(1)?,
378 predicate: row.get(2)?,
379 object: row.get(3)?,
380 confidence: row.get(4)?,
381 source_memory_id: row.get(5)?,
382 created_at: row.get(6)?,
383 })
384 },
385 )?;
386
387 Ok(stored)
388}
389
390pub fn list_facts(
394 conn: &Connection,
395 source_memory_id: Option<i64>,
396 limit: usize,
397) -> Result<Vec<Fact>> {
398 let effective_limit = if limit == 0 { i64::MAX } else { limit as i64 };
399
400 let mut stmt = match source_memory_id {
401 Some(sid) => {
402 let mut s = conn.prepare(
403 "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
404 FROM facts
405 WHERE source_memory_id = ?1
406 ORDER BY id ASC
407 LIMIT ?2",
408 )?;
409 let rows = s.query_map(params![sid, effective_limit], map_row)?;
410 return rows
411 .collect::<std::result::Result<Vec<Fact>, _>>()
412 .map_err(Into::into);
413 }
414 None => conn.prepare(
415 "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
416 FROM facts
417 ORDER BY id ASC
418 LIMIT ?1",
419 )?,
420 };
421
422 let rows = stmt.query_map(params![effective_limit], map_row)?;
423 rows.collect::<std::result::Result<Vec<Fact>, _>>()
424 .map_err(Into::into)
425}
426
427pub fn get_fact_graph(conn: &Connection, subject: &str) -> Result<Vec<Fact>> {
429 let mut stmt = conn.prepare(
430 "SELECT id, subject, predicate, object, confidence, source_memory_id, created_at
431 FROM facts
432 WHERE lower(subject) = lower(?1)
433 ORDER BY id ASC",
434 )?;
435 let rows = stmt.query_map(params![subject], map_row)?;
436 rows.collect::<std::result::Result<Vec<Fact>, _>>()
437 .map_err(Into::into)
438}
439
440pub fn delete_facts_for_memory(conn: &Connection, memory_id: i64) -> Result<usize> {
444 let deleted = conn.execute(
445 "DELETE FROM facts WHERE source_memory_id = ?1",
446 params![memory_id],
447 )?;
448 Ok(deleted)
449}
450
451fn map_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Fact> {
453 Ok(Fact {
454 id: row.get(0)?,
455 subject: row.get(1)?,
456 predicate: row.get(2)?,
457 object: row.get(3)?,
458 confidence: row.get(4)?,
459 source_memory_id: row.get(5)?,
460 created_at: row.get(6)?,
461 })
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Fresh rule-based extractor for extraction tests.
    fn make_extractor() -> RuleBasedExtractor {
        RuleBasedExtractor::new()
    }

    /// Processor wrapping the rule-based extractor.
    fn make_processor() -> ConversationProcessor {
        ConversationProcessor::new(Box::new(RuleBasedExtractor::new()))
    }

    /// In-memory SQLite connection with the facts schema applied.
    fn in_memory_conn() -> Connection {
        let conn = Connection::open_in_memory().expect("in-memory db");
        conn.execute_batch(CREATE_FACTS_TABLE)
            .expect("create table");
        conn
    }

    // "X is a Y" should yield a fact whose subject is X, under either the
    // `is_a` or the broader `is` rule.
    #[test]
    fn test_extract_is_pattern() {
        let ex = make_extractor();
        let facts = ex.extract_facts("Alice is a developer");
        assert!(!facts.is_empty(), "expected at least one fact");
        let fact = facts
            .iter()
            .find(|f| f.predicate == "is_a" || f.predicate == "is")
            .expect("expected 'is_a' or 'is' predicate");
        assert!(
            fact.subject.to_lowercase().contains("alice"),
            "subject should be Alice, got {}",
            fact.subject
        );
    }

    #[test]
    fn test_extract_works_at() {
        let ex = make_extractor();
        let facts = ex.extract_facts("Bob works at Google");
        let fact = facts
            .iter()
            .find(|f| f.predicate == "works_at")
            .unwrap_or_else(|| panic!("expected works_at fact, got: {:?}", facts));
        assert!(fact.subject.to_lowercase().contains("bob"));
        assert!(fact.object.to_lowercase().contains("google"));
    }

    #[test]
    fn test_extract_lives_in() {
        let ex = make_extractor();
        let facts = ex.extract_facts("Carol lives in Tokyo");
        let fact = facts
            .iter()
            .find(|f| f.predicate == "lives_in")
            .unwrap_or_else(|| panic!("expected lives_in fact, got: {:?}", facts));
        assert!(fact.subject.to_lowercase().contains("carol"));
        assert!(fact.object.to_lowercase().contains("tokyo"));
    }

    // Structured "Key: value" lines map to lower-cased key predicates.
    #[test]
    fn test_extract_structured_field() {
        let ex = make_extractor();
        let text = "Name: David\nRole: Manager";
        let facts = ex.extract_facts(text);
        let has_name = facts
            .iter()
            .any(|f| f.predicate == "name" && f.object.contains("David"));
        let has_role = facts
            .iter()
            .any(|f| f.predicate == "role" && f.object.contains("Manager"));
        assert!(has_name, "expected name fact, got: {:?}", facts);
        assert!(has_role, "expected role fact, got: {:?}", facts);
    }

    // Several sentences in one text each contribute facts.
    #[test]
    fn test_extract_multiple_facts() {
        let ex = make_extractor();
        let text = "Emma works at Acme. She lives in Paris. Emma likes music.";
        let facts = ex.extract_facts(text);
        assert!(
            facts.len() >= 3,
            "expected at least 3 facts, got {}: {:?}",
            facts.len(),
            facts
        );
    }

    // Identical triples collapse to one, keeping the higher confidence.
    #[test]
    fn test_dedup_same_fact() {
        let facts = vec![
            ExtractedFact {
                subject: "Alice".to_string(),
                predicate: "works_at".to_string(),
                object: "Acme".to_string(),
                confidence: 0.7,
            },
            ExtractedFact {
                subject: "Alice".to_string(),
                predicate: "works_at".to_string(),
                object: "Acme".to_string(),
                confidence: 0.9,
            },
        ];
        let deduped = dedup_facts(facts);
        assert_eq!(deduped.len(), 1);
        assert!(
            (deduped[0].confidence - 0.9).abs() < f32::EPSILON,
            "expected confidence 0.9, got {}",
            deduped[0].confidence
        );
    }

    #[test]
    fn test_empty_text() {
        let ex = make_extractor();
        assert!(ex.extract_facts("").is_empty());
        assert!(ex.extract_facts("   ").is_empty());
        assert!(ex.extract_facts("\n\t\n").is_empty());
    }

    // Repeated messages across a conversation must not produce duplicate facts.
    #[test]
    fn test_conversation_processing() {
        let processor = make_processor();
        let messages = &[
            "Alice works at Google.",
            "Bob lives in London.",
            "Alice works at Google.",
        ];
        let facts = processor.process_conversation(messages, None);
        let alice_google: Vec<_> = facts
            .iter()
            .filter(|f| {
                f.predicate == "works_at"
                    && f.subject.to_lowercase().contains("alice")
                    && f.object.to_lowercase().contains("google")
            })
            .collect();
        assert_eq!(alice_google.len(), 1, "duplicate should be deduped");

        let bob_london = facts.iter().any(|f| {
            f.predicate == "lives_in"
                && f.subject.to_lowercase().contains("bob")
                && f.object.to_lowercase().contains("london")
        });
        assert!(bob_london, "expected Bob lives_in London fact");
    }

    // Round-trip: a created fact is returned with its stored fields and is
    // visible through list_facts.
    #[test]
    fn test_storage_create_and_list() {
        let conn = in_memory_conn();

        let fact = ExtractedFact {
            subject: "Frank".to_string(),
            predicate: "works_at".to_string(),
            object: "Mozilla".to_string(),
            confidence: 0.85,
        };

        let stored = create_fact(&conn, &fact, Some(42)).expect("create fact");
        assert!(stored.id > 0);
        assert_eq!(stored.subject, "Frank");
        assert_eq!(stored.predicate, "works_at");
        assert_eq!(stored.object, "Mozilla");
        assert_eq!(stored.source_memory_id, Some(42));

        let all = list_facts(&conn, None, 100).expect("list facts");
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, stored.id);
    }

    // get_fact_graph returns only the requested subject, case-insensitively.
    #[test]
    fn test_storage_fact_graph() {
        let conn = in_memory_conn();

        let facts_in = vec![
            ExtractedFact {
                subject: "Grace".to_string(),
                predicate: "works_at".to_string(),
                object: "Stripe".to_string(),
                confidence: 0.8,
            },
            ExtractedFact {
                subject: "Grace".to_string(),
                predicate: "lives_in".to_string(),
                object: "Dublin".to_string(),
                confidence: 0.8,
            },
            ExtractedFact {
                subject: "Henry".to_string(),
                predicate: "works_at".to_string(),
                object: "Stripe".to_string(),
                confidence: 0.8,
            },
        ];

        for f in &facts_in {
            create_fact(&conn, f, None).expect("create");
        }

        let graph = get_fact_graph(&conn, "Grace").expect("get graph");
        assert_eq!(graph.len(), 2);
        assert!(graph.iter().all(|f| f.subject == "Grace"));

        let graph2 = get_fact_graph(&conn, "grace").expect("case insensitive");
        assert_eq!(graph2.len(), 2);
    }

    // Deleting by memory id removes only that memory's facts.
    #[test]
    fn test_storage_delete_for_memory() {
        let conn = in_memory_conn();

        let f1 = ExtractedFact {
            subject: "Iris".to_string(),
            predicate: "works_at".to_string(),
            object: "Corp".to_string(),
            confidence: 0.8,
        };
        let f2 = ExtractedFact {
            subject: "Jack".to_string(),
            predicate: "lives_in".to_string(),
            object: "Berlin".to_string(),
            confidence: 0.8,
        };

        create_fact(&conn, &f1, Some(10)).expect("create f1");
        create_fact(&conn, &f2, Some(20)).expect("create f2");

        let deleted = delete_facts_for_memory(&conn, 10).expect("delete");
        assert_eq!(deleted, 1);

        let remaining = list_facts(&conn, None, 100).expect("list");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].subject, "Jack");
    }

    // list_facts with Some(source) filters on source_memory_id.
    #[test]
    fn test_storage_list_filter_by_source() {
        let conn = in_memory_conn();

        for i in 0..3_i64 {
            let f = ExtractedFact {
                subject: format!("Person{}", i),
                predicate: "works_at".to_string(),
                object: "Acme".to_string(),
                confidence: 0.8,
            };
            create_fact(&conn, &f, Some(i + 1)).expect("create");
        }

        let filtered = list_facts(&conn, Some(2), 100).expect("list filtered");
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].subject, "Person1");
    }

    #[test]
    fn test_title_case() {
        assert_eq!(title_case("alice"), "Alice");
        assert_eq!(title_case("alice smith"), "Alice Smith");
        assert_eq!(title_case("ALICE"), "Alice");
        assert_eq!(title_case(""), "");
    }

    // Every emitted confidence must lie in the closed unit interval.
    #[test]
    fn test_confidence_range() {
        let ex = make_extractor();
        let facts = ex.extract_facts("Sam works at Acme. Name: Sam\nRole: Engineer.");
        for f in &facts {
            assert!(
                (0.0..=1.0).contains(&f.confidence),
                "confidence out of range: {}",
                f.confidence
            );
        }
    }
}