use std::future::Future;
use std::ops::Deref;
use serde::{Deserialize, Serialize};
use crate::llm::{LlmError, LlmProvider};
/// One extracted relationship, read as `subject` —`relation`→ `object`.
///
/// When a reply omits `confidence`, deserialization fills in `default_confidence()`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Triple {
    /// Entity the relationship starts from.
    pub subject: String,
    /// Relationship label as returned by the model (the default prompt asks for a
    /// short verb phrase).
    pub relation: String,
    /// Entity the relationship points to.
    pub object: String,
    /// Certainty reported by the model; the default prompt asks for `0.0..=1.0`.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
}
/// Serde fallback for a triple's `confidence` when the reply leaves it out:
/// treat the triple as fully certain.
fn default_confidence() -> f32 {
    1.0_f32
}
/// The triples parsed from a single LLM reply.
///
/// Deserializes from an object of the form `{"triples": [...]}`; a missing
/// `triples` key yields an empty set.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
pub struct TripleSet {
    #[serde(default)]
    triples: Vec<Triple>,
}
impl TripleSet {
    /// Parses a raw LLM reply into a [`TripleSet`].
    ///
    /// The reply is trimmed and size-checked. The JSON object located by
    /// `crate::llm::locate_json_object` is then deserialized, so prose or markdown
    /// fences around the object are tolerated. A triple without `confidence` gets
    /// the serde default.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::Parse`] when any of the following holds:
    /// - the trimmed reply is empty;
    /// - it has more than [`TRIPLE_REPLY_MAX_CHARS`] characters;
    /// - it contains no JSON object the locator accepts;
    /// - that object does not deserialize into a `TripleSet`.
    pub fn try_new(raw: &str) -> Result<Self, LlmError> {
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            return Err(LlmError::Parse("empty llm reply".to_string()));
        }
        // The limit is expressed in characters, but `len()` counts UTF-8 bytes;
        // comparing bytes alone rejected non-ASCII replies far below the limit.
        // Byte length is always >= char count, so the byte test is a cheap fast
        // path. `nth` stops after at most `MAX + 1` chars, which keeps the check
        // bounded even for enormous replies.
        if trimmed.len() > TRIPLE_REPLY_MAX_CHARS
            && trimmed.chars().nth(TRIPLE_REPLY_MAX_CHARS).is_some()
        {
            return Err(LlmError::Parse(format!(
                "reply too long: len={} > max={TRIPLE_REPLY_MAX_CHARS}",
                trimmed.len()
            )));
        }
        let json_slice = crate::llm::locate_json_object(trimmed).ok_or_else(|| {
            LlmError::Parse(format!(
                "no balanced json object found in len={}",
                trimmed.len()
            ))
        })?;
        serde_json::from_str(json_slice).map_err(|err| {
            LlmError::Parse(format!(
                "json deserialize failed at len={}: {err}",
                json_slice.len()
            ))
        })
    }

    /// Consumes the set and returns the triples in their original order.
    #[must_use]
    pub fn into_inner(self) -> Vec<Triple> {
        self.triples
    }
}
/// Lets a `TripleSet` be used as a read-only slice of triples (`len`, indexing, `iter`).
impl Deref for TripleSet {
    type Target = [Triple];

    fn deref(&self) -> &[Triple] {
        self.triples.as_slice()
    }
}
/// Consumes the set, yielding each [`Triple`] by value.
impl IntoIterator for TripleSet {
    type Item = Triple;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> std::vec::IntoIter<Triple> {
        let Self { triples } = self;
        triples.into_iter()
    }
}
/// Borrowing iteration, so `for t in &set` yields `&Triple`.
impl<'set> IntoIterator for &'set TripleSet {
    type Item = &'set Triple;
    type IntoIter = std::slice::Iter<'set, Triple>;

    fn into_iter(self) -> Self::IntoIter {
        self.triples.as_slice().iter()
    }
}
/// Builds a set from any iterator of triples, preserving iteration order.
impl FromIterator<Triple> for TripleSet {
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = Triple>,
    {
        let triples: Vec<Triple> = iter.into_iter().collect();
        Self { triples }
    }
}
/// Default system prompt installed by [`LlmExtractor::new`].
///
/// It asks the model to reply with a single JSON object shaped like the one
/// `TripleSet` deserializes (`{"triples": [...]}`), with `confidence` in `0.0..=1.0`.
/// The leading `\` continuation drops the first newline, so the prompt starts
/// directly with "You extract".
pub const DEFAULT_TRIPLE_PROMPT: &str = "\
You extract relationships from text as subject-relation-object triples.
Return ONLY a JSON object of the form:
{\"triples\": [{\"subject\": \"...\", \"relation\": \"...\", \"object\": \"...\", \"confidence\": 0.0}]}
Rules:
- subject and object are concrete entities (people, places, organizations, things).
- relation is a short verb phrase in your own words (e.g. \"works at\", \"prefers\", \"lives in\").
- confidence is your certainty from 0.0 to 1.0.
- Extract only relationships the text actually states. Emit an empty list if there are none.
- Do not add commentary outside the JSON object.";
/// Upper bound on the size of a trimmed LLM reply that [`TripleSet::try_new`] will parse.
/// Larger replies are rejected before any JSON extraction is attempted.
pub const TRIPLE_REPLY_MAX_CHARS: usize = 100 * 1_000;
/// Turns free text into a [`TripleSet`].
///
/// Implementors must be `Send + Sync + 'static`, and the future returned by
/// [`TripleExtractor::extract`] must be `Send`.
pub trait TripleExtractor: Send + Sync + 'static {
    /// Extracts subject-relation-object triples from `content`.
    fn extract(
        &self,
        content: &str,
    ) -> impl Future<Output = Result<TripleSet, LlmError>> + Send;
}
/// A [`TripleExtractor`] that sends a prompt plus the input text to an
/// [`LlmProvider`] and parses the reply.
///
/// Derives `Debug` and `Clone` (when `P` does) so a configured extractor can be
/// logged or duplicated. The derives add bounds only on `P` and do not affect
/// existing callers.
#[derive(Debug, Clone)]
pub struct LlmExtractor<P> {
    /// Backend whose `extract` receives `prompt` as the preamble.
    provider: P,
    /// Prompt text passed as the preamble on every extraction.
    prompt: String,
}
impl<P: LlmProvider> LlmExtractor<P> {
    /// Wraps `provider`, using [`DEFAULT_TRIPLE_PROMPT`] as the prompt.
    pub fn new(provider: P) -> Self {
        let prompt = String::from(DEFAULT_TRIPLE_PROMPT);
        Self { provider, prompt }
    }

    /// Returns the same extractor with its prompt replaced by `prompt`.
    #[must_use]
    pub fn with_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            ..self
        }
    }
}
impl<P: LlmProvider> TripleExtractor for LlmExtractor<P> {
    /// Sends the configured prompt (as preamble) and `content` to the provider, then
    /// parses the reply with [`TripleSet::try_new`]. Provider errors are returned unchanged.
    async fn extract(&self, content: &str) -> Result<TripleSet, LlmError> {
        self.provider
            .extract(self.prompt.as_str(), content)
            .await
            .and_then(|reply| TripleSet::try_new(&reply))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_parse_well_formed_triple_reply() {
        let raw = r#"{"triples":[{"subject":"Alice","relation":"works at","object":"Acme","confidence":0.9}]}"#;
        let triples = TripleSet::try_new(raw).unwrap();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].subject, "Alice");
        assert_eq!(triples[0].relation, "works at");
        assert_eq!(triples[0].object, "Acme");
        assert_eq!(triples[0].confidence, 0.9);
    }

    #[test]
    fn should_parse_reply_wrapped_in_prose_and_fences() {
        let raw = "Here are the triples:\n```json\n{\"triples\":[{\"subject\":\"Bob\",\"relation\":\"lives in\",\"object\":\"Paris\"}]}\n```\nDone.";
        let triples = TripleSet::try_new(raw).unwrap();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].object, "Paris");
    }

    #[test]
    fn should_default_confidence_when_absent() {
        let raw = r#"{"triples":[{"subject":"Bob","relation":"likes","object":"tea"}]}"#;
        let triples = TripleSet::try_new(raw).unwrap();
        assert_eq!(triples[0].confidence, 1.0);
    }

    #[test]
    fn should_return_empty_set_for_empty_triple_list() {
        let triples = TripleSet::try_new(r#"{"triples":[]}"#).unwrap();
        assert!(triples.is_empty());
    }

    #[test]
    fn should_reject_empty_reply() {
        assert!(TripleSet::try_new(" ").is_err());
    }

    #[test]
    fn should_reject_reply_with_no_json() {
        assert!(TripleSet::try_new("no json here").is_err());
    }

    #[test]
    fn should_reject_oversized_reply() {
        // ASCII input: one byte per char, so this exceeds the limit under any unit.
        let raw = "x".repeat(TRIPLE_REPLY_MAX_CHARS + 1);
        assert!(TripleSet::try_new(&raw).is_err());
    }

    #[test]
    fn should_collect_and_iterate_triples_in_order() {
        let set: TripleSet = (0..3)
            .map(|i| Triple {
                subject: format!("s{i}"),
                relation: "r".to_string(),
                object: "o".to_string(),
                confidence: 1.0,
            })
            .collect();
        assert_eq!(set.len(), 3);
        assert_eq!((&set).into_iter().count(), 3);
        let subjects: Vec<String> = set.into_iter().map(|t| t.subject).collect();
        assert_eq!(subjects, ["s0", "s1", "s2"]);
    }

    struct StubProvider {
        reply: String,
    }

    impl LlmProvider for StubProvider {
        async fn extract(&self, _preamble: &str, _content: &str) -> Result<String, LlmError> {
            Ok(self.reply.clone())
        }
    }

    /// Succeeds with an empty triple list only when called with the expected preamble.
    struct PromptCheckingProvider {
        expected: &'static str,
    }

    impl LlmProvider for PromptCheckingProvider {
        async fn extract(&self, preamble: &str, _content: &str) -> Result<String, LlmError> {
            if preamble == self.expected {
                Ok(r#"{"triples":[]}"#.to_string())
            } else {
                Err(LlmError::Parse(format!("unexpected preamble: {preamble}")))
            }
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_extract_triples_through_the_trait() {
        let provider = StubProvider {
            reply: r#"{"triples":[{"subject":"Alice","relation":"works at","object":"Acme","confidence":0.8}]}"#
                .to_string(),
        };
        let extractor = LlmExtractor::new(provider);
        let triples = extractor.extract("Alice works at Acme.").await.unwrap();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].subject, "Alice");
        assert_eq!(triples[0].relation, "works at");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_send_default_prompt_as_preamble() {
        let provider = PromptCheckingProvider {
            expected: DEFAULT_TRIPLE_PROMPT,
        };
        let triples = LlmExtractor::new(provider).extract("text").await.unwrap();
        assert!(triples.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_send_custom_prompt_as_preamble() {
        let provider = PromptCheckingProvider { expected: "custom" };
        let extractor = LlmExtractor::new(provider).with_prompt("custom");
        let triples = extractor.extract("text").await.unwrap();
        assert!(triples.is_empty());
    }
}