1use engram_llm::{ChatLlm, ChatMessage, LlmError};
17use serde::{Deserialize, Serialize};
18
/// System prompt sent to the chat model by `extract_facts`.
///
/// It instructs the model to emit nothing but a JSON array of objects with the
/// keys `subject`, `predicate`, `object` and `confidence` (or `[]` when there is
/// nothing to extract), and includes one worked example. The text is part of the
/// runtime contract with the model: editing it changes what the model returns.
const FACT_EXTRACTION_PROMPT: &str = r#"You extract atomic factual claims from text as (subject, predicate, object) triples for a memory system.

RULES:
- subject: the entity the claim is ABOUT (a person, place, organization, project, gene, drug, etc). Use the proper name when available.
- predicate: a normalized verb phrase in snake_case. Examples: works_at, lives_in, prefers, owns, graduated_from, born_on, has_role, is_member_of, married_to, founded_by, located_in, treats, inhibits, uses, prefers_language, has_age, has_height.
- object: the value of the claim. Keep it concise (a name, place, number, date, or short noun phrase).
- confidence: 0.0..=1.0. 1.0 = explicit declarative statement; 0.7 = paraphrase or implicit; 0.4 = inferred from indirect mention.

SKIP:
- Greetings, opinions, hedged statements ("might", "I think", "probably").
- Statements about anonymous "I" / "me" / "the user" with no identifiable subject.
- Pure questions or commands.
- Generic facts that aren't about a specific entity ("the sky is blue").

OUTPUT:
A JSON array of objects with keys: subject, predicate, object, confidence.
If there are no extractable facts, output exactly: []
Output ONLY the JSON. No prose, no code fences, no commentary.

EXAMPLE INPUT:
"Ada Example founded Example Labs in 2024 and prefers Rust over Go for CLI tools because of single-binary deployment."

EXAMPLE OUTPUT:
[
 {"subject":"Ada Example","predicate":"founded","object":"Example Labs","confidence":1.0},
 {"subject":"Example Labs","predicate":"founded_in","object":"2024","confidence":1.0},
 {"subject":"Ada Example","predicate":"prefers_language","object":"Rust","confidence":1.0}
]"#;
47
/// One (subject, predicate, object) claim as returned by the extraction model.
///
/// The field names match the JSON keys the extraction prompt asks the model to
/// emit, so a model reply deserializes directly into a `Vec<ExtractedFact>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedFact {
    /// The entity the claim is about.
    pub subject: String,
    /// The relation between subject and object; the prompt requests a
    /// snake_case verb phrase, but nothing here enforces that.
    pub predicate: String,
    /// The value of the claim.
    pub object: String,
    /// The model's confidence in the claim. The prompt requests a value in
    /// `0.0..=1.0`, but the range is not validated on deserialization.
    /// Falls back to `default_confidence()` when the key is absent.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
}
56
/// Serde fallback for the `confidence` field: a fact whose JSON object omits
/// the key is treated as fully confident.
fn default_confidence() -> f32 {
    const FULL_CONFIDENCE: f32 = 1.0;
    FULL_CONFIDENCE
}
60
/// Failure modes of fact extraction.
#[derive(Debug, thiserror::Error)]
pub enum FactExtractionError {
    /// The underlying chat call failed. Converted automatically from
    /// `LlmError` (via `#[from]`), so `?` can be used on LLM results.
    #[error("llm error: {0}")]
    Llm(#[from] LlmError),
    /// The model's reply could not be deserialized as a JSON array of facts.
    /// The payload is a human-readable message describing the failure.
    #[error("could not parse fact JSON: {0}")]
    Parse(String),
}
68
69pub async fn extract_facts<L: ChatLlm + ?Sized>(
72 llm: &L,
73 text: &str,
74) -> Result<Vec<ExtractedFact>, FactExtractionError> {
75 let resp = llm
76 .chat(&[
77 ChatMessage::system(FACT_EXTRACTION_PROMPT),
78 ChatMessage::user(text.to_string()),
79 ])
80 .await?;
81 parse_extraction_output(&resp.content)
82}
83
84pub fn parse_extraction_output(content: &str) -> Result<Vec<ExtractedFact>, FactExtractionError> {
87 let trimmed = content.trim();
88 let json = if let Some(rest) = trimmed.strip_prefix("```json") {
90 rest.trim_end_matches("```").trim()
91 } else if let Some(rest) = trimmed.strip_prefix("```") {
92 rest.trim_end_matches("```").trim()
93 } else {
94 trimmed
95 };
96 let start = json.find('[').unwrap_or(0);
98 let json = &json[start..];
99 let end = json.rfind(']').map(|i| i + 1).unwrap_or(json.len());
100 let json = &json[..end];
101
102 serde_json::from_str::<Vec<ExtractedFact>>(json)
103 .map_err(|e| FactExtractionError::Parse(format!("{e}; raw: {}", content)))
104}
105
/// Canonicalises a string for comparison by removing leading and trailing
/// whitespace and lowercasing what remains.
pub fn normalize(s: &str) -> String {
    let core = s.trim();
    str::to_lowercase(core)
}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the parser on `raw` and fails the test if the input is rejected.
    fn facts_from(raw: &str) -> Vec<ExtractedFact> {
        parse_extraction_output(raw).unwrap()
    }

    /// A reply that is exactly the requested JSON array parses field by field.
    #[test]
    fn parse_bare_array() {
        let facts = facts_from(
            r#"[{"subject":"Alice","predicate":"works_at","object":"Acme","confidence":1.0}]"#,
        );
        assert_eq!(facts.len(), 1);
        let only = &facts[0];
        assert_eq!(only.subject, "Alice");
        assert_eq!(only.predicate, "works_at");
        assert_eq!(only.confidence, 1.0);
    }

    /// A reply wrapped in a Markdown `json` code fence is unwrapped first.
    #[test]
    fn parse_code_fenced_array() {
        let facts = facts_from(
            "```json\n[{\"subject\":\"Bob\",\"predicate\":\"lives_in\",\"object\":\"Berlin\",\"confidence\":0.9}]\n```",
        );
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].object, "Berlin");
    }

    /// Prose before the array is skipped rather than treated as an error.
    #[test]
    fn parse_with_leading_prose_is_tolerant() {
        let facts = facts_from(
            "Here are the extracted facts:\n[{\"subject\":\"X\",\"predicate\":\"is_a\",\"object\":\"Y\",\"confidence\":1.0}]",
        );
        assert_eq!(facts.len(), 1);
    }

    /// The "no facts" reply yields an empty list, not an error.
    #[test]
    fn parse_empty_array() {
        assert!(facts_from("[]").is_empty());
    }

    /// An object without a `confidence` key gets the serde default.
    #[test]
    fn missing_confidence_defaults_to_1() {
        let facts = facts_from(r#"[{"subject":"X","predicate":"is_a","object":"Y"}]"#);
        assert_eq!(facts[0].confidence, 1.0);
    }

    /// `normalize` trims surrounding whitespace and lowercases.
    #[test]
    fn normalize_lowercases_and_trims() {
        let cases = [(" Ada Example ", "ada example"), ("Example Labs", "example labs")];
        for (input, expected) in cases {
            assert_eq!(normalize(input), expected);
        }
    }
}