memoir_core/graph/
synthesis.rs1use std::future::Future;
20
21use crate::embedding::{EmbeddingError, EmbeddingModel};
22
23use super::cosine::cosine_similarity;
24use super::{Triple, TripleSet};
25
/// Default minimum cosine similarity a rendered triple must reach against at
/// least one fact embedding to count as corroborated.
///
/// `EmbeddingSynthesizer::new` starts from this value, and the comparison
/// against it is inclusive (`>=`).
pub const MIN_CORROBORATION_SIMILARITY: f32 = 0.6;
34
/// A piece of free-text knowledge that candidate triples are checked against.
///
/// Derives `Eq` and `Hash` alongside `PartialEq` (the only field is a `String`)
/// so facts can be de-duplicated or used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SemanticFact {
    /// The fact's text; this is the string `EmbeddingSynthesizer` hands to its embedder.
    pub content: String,
}
45
/// Strategy that turns a candidate [`TripleSet`] into the set that should be kept.
///
/// Implementors must be `Send + Sync + 'static`, and the future returned by
/// [`Synthesizer::synthesize`] must be `Send`.
pub trait Synthesizer: Send + Sync + 'static {
    /// Consumes `triples` and resolves to the resulting [`TripleSet`].
    ///
    /// `facts` is supplied as supporting context; how (or whether) it is used is
    /// up to the implementation.
    ///
    /// # Errors
    ///
    /// Resolves to a [`SynthesisError`] when the implementation fails.
    fn synthesize(
        &self,
        triples: TripleSet,
        facts: &[SemanticFact],
    ) -> impl Future<Output = Result<TripleSet, SynthesisError>> + Send;
}
67
/// Errors a [`Synthesizer`] can report.
#[derive(Debug, thiserror::Error)]
pub enum SynthesisError {
    /// The embedding model returned an error.
    ///
    /// Converted automatically from [`EmbeddingError`] via `#[from]`, so `?`
    /// can be used directly on embedder calls.
    #[error("synthesis embedding failed: {0}")]
    Embed(#[from] EmbeddingError),
}
75
/// A [`Synthesizer`] that performs no filtering: every triple it is given is
/// returned unchanged.
#[derive(Clone, Copy, Debug, Default)]
pub struct PassthroughSynthesizer;

impl PassthroughSynthesizer {
    /// Creates the stateless passthrough synthesizer.
    pub fn new() -> Self {
        PassthroughSynthesizer
    }
}
90
91impl Synthesizer for PassthroughSynthesizer {
92 async fn synthesize(&self, triples: TripleSet, _facts: &[SemanticFact]) -> Result<TripleSet, SynthesisError> {
93 Ok(triples)
94 }
95}
96
/// A [`Synthesizer`] that keeps only the triples corroborated by at least one
/// fact, judged by cosine similarity between embeddings.
///
/// `Debug` and `Clone` are derived (available whenever `E` provides them) so the
/// type can be logged and duplicated like the other synthesizer in this module.
#[derive(Debug, Clone)]
pub struct EmbeddingSynthesizer<E> {
    /// Model used to embed both fact text and rendered triples.
    embedder: E,
    /// Inclusive similarity threshold a triple must reach against some fact.
    min_similarity: f32,
}
112
113impl<E: EmbeddingModel> EmbeddingSynthesizer<E> {
114 pub fn new(embedder: E) -> Self {
116 Self {
117 embedder,
118 min_similarity: MIN_CORROBORATION_SIMILARITY,
119 }
120 }
121
122 #[must_use]
124 pub fn with_min_similarity(mut self, min_similarity: f32) -> Self {
125 self.min_similarity = min_similarity;
126 self
127 }
128}
129
130impl<E: EmbeddingModel> Synthesizer for EmbeddingSynthesizer<E> {
131 async fn synthesize(&self, triples: TripleSet, facts: &[SemanticFact]) -> Result<TripleSet, SynthesisError> {
132 if facts.is_empty() {
133 return Ok(TripleSet::default());
134 }
135
136 let mut fact_embeddings = Vec::with_capacity(facts.len());
137 for fact in facts {
138 fact_embeddings.push(self.embedder.embed(&fact.content).await?);
139 }
140
141 let mut kept = Vec::new();
142 for triple in triples {
143 let rendered = render_triple(&triple);
144 let triple_embedding = self.embedder.embed(&rendered).await?;
145 let corroborated = fact_embeddings
146 .iter()
147 .filter_map(|fact| cosine_similarity(&triple_embedding, fact))
148 .any(|score| score >= self.min_similarity);
149 if corroborated {
150 kept.push(triple);
151 }
152 }
153
154 Ok(kept.into_iter().collect())
155 }
156}
157
158fn render_triple(triple: &Triple) -> String {
160 format!("{} {} {}", triple.subject, triple.relation, triple.object)
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Triple`] from string parts with a fixed confidence of `0.9`.
    fn triple(subject: &str, relation: &str, object: &str) -> Triple {
        Triple {
            subject: subject.to_string(),
            relation: relation.to_string(),
            object: object.to_string(),
            confidence: 0.9,
        }
    }

    /// Collects the given triples into a [`TripleSet`].
    fn triples(items: Vec<Triple>) -> TripleSet {
        items.into_iter().collect()
    }

    /// Builds a [`SemanticFact`] holding `content`.
    fn fact(content: &str) -> SemanticFact {
        SemanticFact {
            content: content.to_string(),
        }
    }

    /// Deterministic stand-in embedder: text mentioning "Acme", text mentioning
    /// "Globex", and everything else map to three mutually orthogonal unit vectors.
    struct FakeEmbedding;

    impl EmbeddingModel for FakeEmbedding {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            let vector = if text.contains("Acme") {
                vec![1.0, 0.0, 0.0]
            } else if text.contains("Globex") {
                vec![0.0, 1.0, 0.0]
            } else {
                vec![0.0, 0.0, 1.0]
            };
            Ok(vector)
        }

        fn dimensions(&self) -> usize {
            3
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_pass_all_triples_through_passthrough() {
        let synth = PassthroughSynthesizer::new();
        let input = triples(vec![triple("Alice", "works at", "Acme"), triple("Bob", "likes", "tea")]);

        // `input` is not needed afterwards, so it is moved rather than cloned.
        let out = synth.synthesize(input, &[]).await.unwrap();

        assert_eq!(out.len(), 2);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_keep_corroborated_triple() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![triple("Alice", "works at", "Acme")]);

        let out = synth.synthesize(input, &[fact("Alice works at Acme Corp")]).await.unwrap();

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].object, "Acme");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_veto_uncorroborated_triple() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        // "Globex" embeds orthogonally to the "Acme" fact, so nothing corroborates it.
        let input = triples(vec![triple("Alice", "works at", "Globex")]);

        let out = synth.synthesize(input, &[fact("Alice works at Acme Corp")]).await.unwrap();

        assert!(out.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_veto_everything_when_no_facts() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![triple("Alice", "works at", "Acme")]);

        let out = synth.synthesize(input, &[]).await.unwrap();

        assert!(out.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_return_empty_when_no_triples() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);

        let out = synth.synthesize(triples(Vec::new()), &[fact("Alice works at Acme")]).await.unwrap();

        assert!(out.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_keep_only_corroborated_among_mixed() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![
            triple("Alice", "works at", "Acme"),
            triple("Alice", "works at", "Globex"),
        ]);

        let out = synth.synthesize(input, &[fact("Alice works at Acme")]).await.unwrap();

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].object, "Acme");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_veto_match_when_threshold_is_unreachable() {
        // A cosine similarity cannot reach 1.5, so even an exact match is rejected
        // once the threshold is raised that high.
        let synth = EmbeddingSynthesizer::new(FakeEmbedding).with_min_similarity(1.5);
        let input = triples(vec![triple("Alice", "works at", "Acme")]);

        let out = synth.synthesize(input, &[fact("Alice works at Acme")]).await.unwrap();

        assert!(out.is_empty());
    }

    #[test]
    fn should_render_triple_as_subject_relation_object() {
        assert_eq!(render_triple(&triple("Alice", "works at", "Acme")), "Alice works at Acme");
    }
}