use std::future::Future;
use crate::embedding::{EmbeddingError, EmbeddingModel};
use super::cosine::cosine_similarity;
use super::{Triple, TripleSet};
/// Default minimum cosine similarity between a rendered triple and a fact for the
/// triple to count as corroborated.
///
/// [`EmbeddingSynthesizer::new`] installs this value as its threshold; it can be
/// replaced per instance with [`EmbeddingSynthesizer::with_min_similarity`].
pub const MIN_CORROBORATION_SIMILARITY: f32 = 0.6;
/// A fact stated as free text, supplied to a [`Synthesizer`] alongside the triples.
#[derive(Debug, Clone, PartialEq)]
pub struct SemanticFact {
/// Text of the fact; [`EmbeddingSynthesizer`] passes this string to its embedder.
pub content: String,
}
/// Turns a set of candidate triples into a result set, given a slice of semantic facts.
///
/// Implementors must be `Send + Sync + 'static`. This file provides two
/// implementations: [`PassthroughSynthesizer`] and [`EmbeddingSynthesizer`].
pub trait Synthesizer: Send + Sync + 'static {
/// Consumes `triples` and resolves to the resulting [`TripleSet`].
///
/// The return type is written as `impl Future + Send` rather than `async fn`, which
/// requires the future produced by every implementation to be `Send`.
///
/// # Errors
///
/// Resolves to a [`SynthesisError`] when the implementation fails.
fn synthesize(
&self,
triples: TripleSet,
facts: &[SemanticFact],
) -> impl Future<Output = Result<TripleSet, SynthesisError>> + Send;
}
/// Error type returned by [`Synthesizer::synthesize`].
#[derive(Debug, thiserror::Error)]
pub enum SynthesisError {
/// The embedding model reported an error.
///
/// `#[from]` generates `From<EmbeddingError>`, so `?` converts embedder errors
/// into this variant.
#[error("synthesis embedding failed: {0}")]
Embed(#[from] EmbeddingError),
}
/// A stateless synthesizer that applies no filtering of its own.
///
/// It is a unit struct, so every instance is interchangeable and copying it is free.
#[derive(Debug, Default, Clone, Copy)]
pub struct PassthroughSynthesizer;

impl PassthroughSynthesizer {
    /// Creates a pass-through synthesizer.
    ///
    /// Equivalent to `PassthroughSynthesizer::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl Synthesizer for PassthroughSynthesizer {
    /// Hands `triples` back unchanged; the facts are never inspected and the call
    /// cannot fail.
    async fn synthesize(
        &self,
        triples: TripleSet,
        _ignored_facts: &[SemanticFact],
    ) -> Result<TripleSet, SynthesisError> {
        Ok(triples)
    }
}
/// A synthesizer that keeps a triple only when its embedding is similar enough to
/// the embedding of at least one supplied fact.
///
/// `Debug` and `Clone` are derived, so they are available whenever the embedder
/// type `E` implements them.
#[derive(Debug, Clone)]
pub struct EmbeddingSynthesizer<E> {
    /// Model used to embed both the fact texts and the rendered triples.
    embedder: E,
    /// Inclusive lower bound on the similarity score for a triple to be kept.
    min_similarity: f32,
}
impl<E: EmbeddingModel> EmbeddingSynthesizer<E> {
    /// Wraps `embedder`, starting from the default threshold
    /// [`MIN_CORROBORATION_SIMILARITY`].
    pub fn new(embedder: E) -> Self {
        Self {
            min_similarity: MIN_CORROBORATION_SIMILARITY,
            embedder,
        }
    }

    /// Returns the synthesizer with its similarity threshold replaced by
    /// `min_similarity`; the embedder is carried over untouched.
    ///
    /// The value is stored as given, without validation.
    #[must_use]
    pub fn with_min_similarity(self, min_similarity: f32) -> Self {
        Self {
            min_similarity,
            ..self
        }
    }
}
impl<E: EmbeddingModel> Synthesizer for EmbeddingSynthesizer<E> {
    /// Keeps only the triples corroborated by at least one fact.
    ///
    /// Every fact is embedded once. Each triple is then rendered to text, embedded,
    /// and kept when its similarity to any fact embedding is at least the configured
    /// threshold. Surviving triples are collected in the order they were visited.
    ///
    /// If `facts` or `triples` is empty the result is an empty [`TripleSet`] and the
    /// embedder is never called.
    ///
    /// # Errors
    ///
    /// Returns [`SynthesisError::Embed`] as soon as embedding a fact or a triple fails.
    async fn synthesize(
        &self,
        triples: TripleSet,
        facts: &[SemanticFact],
    ) -> Result<TripleSet, SynthesisError> {
        // Without facts nothing can be corroborated, and without triples there is
        // nothing to corroborate; either way skip the embedder calls entirely.
        if facts.is_empty() || triples.is_empty() {
            return Ok(TripleSet::default());
        }

        // Embed the facts up front so each one is computed once, not once per triple.
        let mut fact_embeddings = Vec::with_capacity(facts.len());
        for fact in facts {
            fact_embeddings.push(self.embedder.embed(&fact.content).await?);
        }

        let mut kept = Vec::new();
        for triple in triples {
            let rendered = render_triple(&triple);
            let triple_embedding = self.embedder.embed(&rendered).await?;
            // `cosine_similarity` yields an `Option`; pairs for which it returns
            // `None` are skipped rather than treated as a match or an error.
            let corroborated = fact_embeddings
                .iter()
                .filter_map(|fact_embedding| cosine_similarity(&triple_embedding, fact_embedding))
                .any(|score| score >= self.min_similarity);
            if corroborated {
                kept.push(triple);
            }
        }
        Ok(kept.into_iter().collect())
    }
}
fn render_triple(triple: &Triple) -> String {
format!("{} {} {}", triple.subject, triple.relation, triple.object)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Triple`] from string slices with a fixed confidence of `0.9`.
    fn triple(subject: &str, relation: &str, object: &str) -> Triple {
        Triple {
            subject: subject.to_string(),
            relation: relation.to_string(),
            object: object.to_string(),
            confidence: 0.9,
        }
    }

    /// Collects the given triples into a [`TripleSet`].
    fn triples(items: Vec<Triple>) -> TripleSet {
        items.into_iter().collect()
    }

    /// Builds a [`SemanticFact`] with the given text.
    fn fact(content: &str) -> SemanticFact {
        SemanticFact {
            content: content.to_string(),
        }
    }

    /// Deterministic embedder for tests: returns a one-hot 3-dimensional vector
    /// chosen by keyword ("Acme", then "Globex", otherwise a third axis), so texts
    /// mentioning different companies get orthogonal embeddings.
    struct FakeEmbedding;

    impl EmbeddingModel for FakeEmbedding {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            let vector = if text.contains("Acme") {
                vec![1.0, 0.0, 0.0]
            } else if text.contains("Globex") {
                vec![0.0, 1.0, 0.0]
            } else {
                vec![0.0, 0.0, 1.0]
            };
            Ok(vector)
        }

        fn dimensions(&self) -> usize {
            3
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_pass_all_triples_through_passthrough() {
        let synth = PassthroughSynthesizer::new();
        let input = triples(vec![
            triple("Alice", "works at", "Acme"),
            triple("Bob", "likes", "tea"),
        ]);
        // `input` is not used again, so it is moved in rather than cloned.
        let out = synth.synthesize(input, &[]).await.unwrap();
        assert_eq!(out.len(), 2);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_keep_corroborated_triple() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![triple("Alice", "works at", "Acme")]);
        let out = synth
            .synthesize(input, &[fact("Alice works at Acme Corp")])
            .await
            .unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].object, "Acme");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_veto_uncorroborated_triple() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![triple("Alice", "works at", "Globex")]);
        let out = synth
            .synthesize(input, &[fact("Alice works at Acme Corp")])
            .await
            .unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_veto_everything_when_no_facts() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![triple("Alice", "works at", "Acme")]);
        let out = synth.synthesize(input, &[]).await.unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_keep_only_corroborated_among_mixed() {
        let synth = EmbeddingSynthesizer::new(FakeEmbedding);
        let input = triples(vec![
            triple("Alice", "works at", "Acme"),
            triple("Alice", "works at", "Globex"),
        ]);
        let out = synth
            .synthesize(input, &[fact("Alice works at Acme")])
            .await
            .unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].object, "Acme");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_veto_matching_triple_when_threshold_is_unreachable() {
        // Assumes `cosine_similarity` never scores above 1.0, so a threshold of 1.5
        // must reject even a triple whose embedding equals the fact's embedding.
        let synth = EmbeddingSynthesizer::new(FakeEmbedding).with_min_similarity(1.5);
        let input = triples(vec![triple("Alice", "works at", "Acme")]);
        let out = synth
            .synthesize(input, &[fact("Alice works at Acme")])
            .await
            .unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn should_render_triple_as_subject_relation_object() {
        assert_eq!(
            render_triple(&triple("Alice", "works at", "Acme")),
            "Alice works at Acme"
        );
    }
}