memoir_core/graph/
extraction.rs1use std::future::Future;
14use std::ops::Deref;
15
16use serde::{Deserialize, Serialize};
17
18use crate::llm::{LlmError, LlmProvider};
19
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct Triple {
26 pub subject: String,
27 pub relation: String,
28 pub object: String,
29 #[serde(default = "default_confidence")]
30 pub confidence: f32,
31}
32
/// Serde fallback for `Triple::confidence` when a triple carries no score.
///
/// Returns `1.0`, the top of the 0.0–1.0 scale the default prompt asks for.
fn default_confidence() -> f32 {
    1.0_f32
}
36
37#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
43pub struct TripleSet {
44 #[serde(default)]
45 triples: Vec<Triple>,
46}
47
48impl TripleSet {
49 pub fn try_new(raw: &str) -> Result<Self, LlmError> {
60 let trimmed = raw.trim();
61 if trimmed.is_empty() {
62 return Err(LlmError::Parse("empty llm reply".to_string()));
63 }
64 if trimmed.len() > TRIPLE_REPLY_MAX_CHARS {
65 return Err(LlmError::Parse(format!(
66 "reply too long: len={} > max={TRIPLE_REPLY_MAX_CHARS}",
67 trimmed.len()
68 )));
69 }
70
71 let json_slice = crate::llm::locate_json_object(trimmed)
72 .ok_or_else(|| LlmError::Parse(format!("no balanced json object found in len={}", trimmed.len())))?;
73
74 serde_json::from_str(json_slice)
75 .map_err(|err| LlmError::Parse(format!("json deserialize failed at len={}: {err}", json_slice.len())))
76 }
77
78 pub fn into_inner(self) -> Vec<Triple> {
80 self.triples
81 }
82}
83
84impl Deref for TripleSet {
85 type Target = [Triple];
86
87 fn deref(&self) -> &Self::Target {
88 &self.triples
89 }
90}
91
92impl IntoIterator for TripleSet {
93 type Item = Triple;
94 type IntoIter = std::vec::IntoIter<Triple>;
95
96 fn into_iter(self) -> Self::IntoIter {
97 self.triples.into_iter()
98 }
99}
100
101impl<'a> IntoIterator for &'a TripleSet {
102 type Item = &'a Triple;
103 type IntoIter = std::slice::Iter<'a, Triple>;
104
105 fn into_iter(self) -> Self::IntoIter {
106 self.triples.iter()
107 }
108}
109
110impl FromIterator<Triple> for TripleSet {
111 fn from_iter<I: IntoIterator<Item = Triple>>(iter: I) -> Self {
112 Self {
113 triples: iter.into_iter().collect(),
114 }
115 }
116}
117
/// Default preamble sent to the provider by `LlmExtractor::new`.
///
/// It asks the model for a single JSON object of the form `{"triples": [...]}`,
/// which is the shape `TripleSet::try_new` deserializes. It also asks for
/// nothing outside that object. `LlmExtractor::with_prompt` replaces it per
/// extractor.
// The leading `\` is a line continuation: the string starts at "You extract".
pub const DEFAULT_TRIPLE_PROMPT: &str = "\
You extract relationships from text as subject-relation-object triples.
Return ONLY a JSON object of the form:
{\"triples\": [{\"subject\": \"...\", \"relation\": \"...\", \"object\": \"...\", \"confidence\": 0.0}]}
Rules:
- subject and object are concrete entities (people, places, organizations, things).
- relation is a short verb phrase in your own words (e.g. \"works at\", \"prefers\", \"lives in\").
- confidence is your certainty from 0.0 to 1.0.
- Extract only relationships the text actually states. Emit an empty list if there are none.
- Do not add commentary outside the JSON object.";
129
/// Upper bound on the length of a trimmed LLM reply that `TripleSet::try_new`
/// will attempt to parse.
///
/// Longer replies are rejected with `LlmError::Parse` before any JSON scanning
/// takes place.
pub const TRIPLE_REPLY_MAX_CHARS: usize = 100_000;
132
133pub trait TripleExtractor: Send + Sync + 'static {
139 fn extract(&self, content: &str) -> impl Future<Output = Result<TripleSet, LlmError>> + Send;
146}
147
/// [`TripleExtractor`] backed by an `LlmProvider`.
///
/// It sends the configured prompt (as the preamble) together with the input
/// text to the provider, then parses the reply with `TripleSet::try_new`.
///
/// `Debug` and `Clone` are derived and apply only when `P` implements them.
#[derive(Debug, Clone)]
pub struct LlmExtractor<P> {
    /// Backend that performs the LLM call.
    provider: P,
    /// Preamble passed to the provider alongside each piece of content.
    prompt: String,
}
156
157impl<P: LlmProvider> LlmExtractor<P> {
158 pub fn new(provider: P) -> Self {
160 Self {
161 provider,
162 prompt: DEFAULT_TRIPLE_PROMPT.to_string(),
163 }
164 }
165
166 #[must_use]
168 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
169 self.prompt = prompt.into();
170 self
171 }
172}
173
174impl<P: LlmProvider> TripleExtractor for LlmExtractor<P> {
175 async fn extract(&self, content: &str) -> Result<TripleSet, LlmError> {
176 let raw = self.provider.extract(&self.prompt, content).await?;
177 TripleSet::try_new(&raw)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_parse_well_formed_triple_reply() {
        let raw = r#"{"triples":[{"subject":"Alice","relation":"works at","object":"Acme","confidence":0.9}]}"#;
        let triples = TripleSet::try_new(raw).unwrap();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].subject, "Alice");
        assert_eq!(triples[0].relation, "works at");
        assert_eq!(triples[0].object, "Acme");
        assert_eq!(triples[0].confidence, 0.9);
    }

    #[test]
    fn should_parse_reply_wrapped_in_prose_and_fences() {
        let raw = "Here are the triples:\n```json\n{\"triples\":[{\"subject\":\"Bob\",\"relation\":\"lives in\",\"object\":\"Paris\"}]}\n```\nDone.";
        let triples = TripleSet::try_new(raw).unwrap();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].object, "Paris");
    }

    #[test]
    fn should_default_confidence_when_absent() {
        let raw = r#"{"triples":[{"subject":"Bob","relation":"likes","object":"tea"}]}"#;
        let triples = TripleSet::try_new(raw).unwrap();
        assert_eq!(triples[0].confidence, 1.0);
    }

    #[test]
    fn should_return_empty_set_for_empty_triple_list() {
        let triples = TripleSet::try_new(r#"{"triples":[]}"#).unwrap();
        assert!(triples.is_empty());
    }

    #[test]
    fn should_treat_missing_triples_key_as_empty_set() {
        // `#[serde(default)]` on `triples` makes the key optional.
        let triples = TripleSet::try_new("{}").unwrap();
        assert!(triples.is_empty());
    }

    #[test]
    fn should_reject_empty_reply() {
        assert!(TripleSet::try_new(" ").is_err());
    }

    #[test]
    fn should_reject_reply_with_no_json() {
        assert!(TripleSet::try_new("no json here").is_err());
    }

    #[test]
    fn should_reject_triple_missing_required_fields() {
        let raw = r#"{"triples":[{"subject":"Alice","relation":"works at"}]}"#;
        assert!(TripleSet::try_new(raw).is_err());
    }

    #[test]
    fn should_reject_reply_longer_than_limit() {
        // ASCII only, so byte length and character count coincide.
        let raw = "a".repeat(TRIPLE_REPLY_MAX_CHARS + 1);
        assert!(TripleSet::try_new(&raw).is_err());
    }

    #[test]
    fn should_round_trip_through_iterators() {
        let triple = Triple {
            subject: "Alice".to_string(),
            relation: "knows".to_string(),
            object: "Bob".to_string(),
            confidence: 0.5,
        };
        let set: TripleSet = std::iter::once(triple.clone()).collect();
        assert_eq!(set.len(), 1);

        let mut seen = 0;
        for borrowed in &set {
            assert_eq!(borrowed, &triple);
            seen += 1;
        }
        assert_eq!(seen, 1);

        assert_eq!(set.into_inner(), vec![triple]);
    }

    struct StubProvider {
        reply: String,
    }

    impl LlmProvider for StubProvider {
        async fn extract(&self, _preamble: &str, _content: &str) -> Result<String, LlmError> {
            Ok(self.reply.clone())
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_extract_triples_through_the_trait() {
        let provider = StubProvider {
            reply: r#"{"triples":[{"subject":"Alice","relation":"works at","object":"Acme","confidence":0.8}]}"#
                .to_string(),
        };
        let extractor = LlmExtractor::new(provider);

        let triples = extractor.extract("Alice works at Acme.").await.unwrap();

        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].subject, "Alice");
        assert_eq!(triples[0].relation, "works at");
    }
}