use async_trait::async_trait;
use chrono::{DateTime, Utc};
use juncture_core::state::messages::Message;
use serde::{Deserialize, Serialize};
use crate::llm::{CallOptions, ChatModel, LlmError};
use crate::memory::MemoryError;
/// A single extracted claim together with where it came from and how much it is trusted.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Fact {
    /// Short subject label for the claim.
    pub topic: String,
    /// The claim itself, as free text.
    pub claim: String,
    /// Where the claim came from (may be empty when unknown).
    pub source: String,
    /// Confidence score; guaranteed to be within `0.0..=1.0` when built via [`Fact::new`]
    /// (deserialization does not re-check it).
    pub confidence: f64,
    /// Creation time; [`Fact::new`] stamps it with the current UTC time.
    pub timestamp: DateTime<Utc>,
}

impl Fact {
    /// Builds a fact and stamps it with the current UTC time.
    ///
    /// # Panics
    ///
    /// Panics if `confidence` is not within `0.0..=1.0` (this includes `NaN`).
    #[must_use]
    pub fn new(topic: String, claim: String, source: String, confidence: f64) -> Self {
        let in_range = (0.0..=1.0).contains(&confidence);
        assert!(in_range, "confidence must be between 0.0 and 1.0, got {confidence}");

        let timestamp = Utc::now();
        Self {
            topic,
            claim,
            source,
            confidence,
            timestamp,
        }
    }
}
/// Turns free-form text into a list of structured [`Fact`]s.
#[async_trait]
pub trait FactExtractor: 'static + Send + Sync {
    /// Extracts zero or more facts from `text`.
    ///
    /// # Errors
    ///
    /// Implementations report failures as [`MemoryError`].
    async fn extract(&self, text: &str) -> Result<Vec<Fact>, MemoryError>;
}
/// A [`FactExtractor`] that prompts a [`ChatModel`] to emit facts as a JSON array.
#[derive(Clone, Debug)]
pub struct LlmFactExtractor<M>
where
    M: ChatModel,
{
    /// The chat model that is queried for every extraction.
    model: M,
}
/// Wire shape of one fact as emitted by the model, before conversion into a [`Fact`].
#[derive(Deserialize)]
struct RawFact {
    topic: String,
    claim: String,
    /// Optional in the model output; a missing key deserializes to an empty string.
    #[serde(default)]
    source: String,
    /// Taken verbatim from the model output; deserialization applies no range check.
    confidence: f64,
}
impl<M> LlmFactExtractor<M>
where
    M: ChatModel,
{
    /// Wraps `model` in a fact extractor.
    #[must_use]
    pub const fn new(model: M) -> Self {
        Self { model }
    }

    /// Removes an optional Markdown code fence from a model response.
    ///
    /// First trims surrounding whitespace. Then it drops a leading "```json" or
    /// (failing that) a leading "```", and a single trailing "```". Finally it
    /// trims whitespace again and returns an owned string.
    fn clean_json_response(raw: &str) -> String {
        let outer = raw.trim();
        let without_open = outer
            .strip_prefix("```json")
            .or_else(|| outer.strip_prefix("```"))
            .unwrap_or(outer);
        let without_close = without_open.strip_suffix("```").unwrap_or(without_open);
        without_close.trim().to_owned()
    }
}
#[async_trait]
impl<M: ChatModel> FactExtractor for LlmFactExtractor<M> {
async fn extract(&self, text: &str) -> Result<Vec<Fact>, MemoryError> {
let trimmed = text.trim();
if trimmed.len() < 50 {
return Ok(Vec::new());
}
let system_prompt = Message::system(
"Extract key facts from the following text. Return a JSON array of objects \
with fields: topic, claim, source, confidence (0.0-1.0). Output ONLY the \
JSON array, no other text.",
);
let user_msg = Message::human(trimmed);
let options = CallOptions {
max_tokens: Some(2000),
..Default::default()
};
let response = juncture_core::wasm_send::force_send(
self.model
.invoke(&[system_prompt, user_msg], Some(&options)),
)
.await
.map_err(|e| match e {
LlmError::InvalidResponse(msg) => {
MemoryError::ExtractionFailed(format!("invalid response: {msg}"))
}
#[cfg(any(feature = "anthropic", feature = "openai", feature = "ollama"))]
LlmError::NetworkError(e) => {
MemoryError::ExtractionFailed(format!("network error: {e}"))
}
_ => MemoryError::ExtractionFailed(format!("LLM error: {e}")),
})?;
let raw_response = response.content_text();
let json_str = Self::clean_json_response(raw_response);
let raw_facts: Vec<RawFact> = serde_json::from_str(&json_str).map_err(|e| {
MemoryError::ExtractionFailed(format!(
"failed to parse facts as JSON array: {e}\nResponse: {json_str}"
))
})?;
let facts = raw_facts
.into_iter()
.map(|raw| Fact::new(raw.topic, raw.claim, raw.source, raw.confidence))
.collect();
Ok(facts)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the fence-stripping helper pinned to the mock model type.
    fn clean(raw: &str) -> String {
        LlmFactExtractor::<crate::llm::MockChatModel>::clean_json_response(raw)
    }

    /// Builds a throwaway fact whose only interesting field is `confidence`.
    fn fact_with_confidence(confidence: f64) -> Fact {
        Fact::new(
            "Test".to_string(),
            "Claim".to_string(),
            "source".to_string(),
            confidence,
        )
    }

    #[test]
    fn test_fact_construction() {
        let fact = Fact::new(
            "Geography".to_string(),
            "Paris is the capital of France".to_string(),
            "test_document.txt".to_string(),
            0.95,
        );

        assert_eq!(fact.topic, "Geography");
        assert_eq!(fact.claim, "Paris is the capital of France");
        assert_eq!(fact.source, "test_document.txt");
        assert!((fact.confidence - 0.95).abs() < f64::EPSILON);
        assert!(fact.timestamp <= Utc::now());
    }

    #[test]
    #[should_panic(expected = "confidence must be between 0.0 and 1.0")]
    fn test_fact_invalid_confidence_high() {
        let _ = fact_with_confidence(1.5);
    }

    #[test]
    #[should_panic(expected = "confidence must be between 0.0 and 1.0")]
    fn test_fact_invalid_confidence_low() {
        let _ = fact_with_confidence(-0.1);
    }

    #[test]
    fn test_fact_serialization() {
        let original = Fact::new(
            "Science".to_string(),
            "Water boils at 100°C".to_string(),
            "chemistry.txt".to_string(),
            0.99,
        );

        let encoded = serde_json::to_string(&original).expect("serialization failed");
        let decoded: Fact = serde_json::from_str(&encoded).expect("deserialization failed");

        assert_eq!(decoded.topic, original.topic);
        assert_eq!(decoded.claim, original.claim);
        assert_eq!(decoded.source, original.source);
        assert!((decoded.confidence - original.confidence).abs() < f64::EPSILON);
    }

    #[test]
    fn test_llm_fact_extractor_construction() {
        let _extractor = LlmFactExtractor::new(crate::llm::MockChatModel::new("test-model"));
    }

    #[test]
    fn test_clean_json_response() {
        let expected = "[{\"topic\": \"test\"}]";
        let inputs = [
            "```json\n[{\"topic\": \"test\"}]\n```",
            "```\n[{\"topic\": \"test\"}]\n```",
            "[{\"topic\": \"test\"}]",
        ];
        for raw in inputs {
            assert_eq!(clean(raw), expected, "input: {raw:?}");
        }
    }
}