use crate::error::NeomemxError;
use crate::error::Result;
use crate::extraction::types::FactItem;
use crate::llm::{LlmBase, Message};
use async_trait::async_trait;
use serde::Deserialize;
use std::sync::Arc;
/// System prompt an `LlmFactExtractor` built with `new` uses whenever `extract` is called
/// without a custom prompt.
const DEFAULT_EXTRACTION_PROMPT: &str = concat!(
    "You are a helpful assistant that extracts atomic facts from text. ",
    "Return a JSON array of fact strings."
);
/// Extracts facts from free-form text.
///
/// The `Send + Sync` supertraits presumably exist so one extractor can be shared across
/// async tasks (e.g. behind an `Arc`) — confirm against callers.
#[async_trait]
pub trait FactExtractor: Send + Sync {
    /// Returns the facts found in `text` as a list of strings.
    ///
    /// `custom_prompt`, when provided, lets the caller supply instructions that override the
    /// implementor's defaults. NOTE(review): exactly how the prompt is applied is up to each
    /// implementation.
    async fn extract(&self, text: &str, custom_prompt: Option<&str>) -> Result<Vec<String>>;
}
/// [`FactExtractor`] implementation that delegates extraction to an LLM backend.
///
/// Cloning is cheap: the backend handle is shared through an `Arc`, and only the prompt
/// string is copied.
#[derive(Clone)]
pub struct LlmFactExtractor {
    /// Backend asked to produce the JSON extraction response.
    llm: Arc<dyn LlmBase>,
    /// System prompt used when `extract` is called without a custom prompt.
    default_prompt: String,
}
impl LlmFactExtractor {
    /// Creates an extractor whose default system prompt is `DEFAULT_EXTRACTION_PROMPT`.
    pub fn new(llm: Arc<dyn LlmBase>) -> Self {
        Self::with_prompt(llm, DEFAULT_EXTRACTION_PROMPT)
    }

    /// Creates an extractor with a caller-supplied default system prompt.
    ///
    /// The prompt is used when `extract` is called without a custom prompt. This allows a
    /// different default than `DEFAULT_EXTRACTION_PROMPT` without passing a prompt on
    /// every call.
    pub fn with_prompt(llm: Arc<dyn LlmBase>, default_prompt: impl Into<String>) -> Self {
        Self {
            llm,
            default_prompt: default_prompt.into(),
        }
    }
}
#[async_trait]
impl FactExtractor for LlmFactExtractor {
    /// Asks the LLM to extract atomic facts from `text`.
    ///
    /// Returns an empty list without calling the LLM when `text` is empty or whitespace-only.
    /// `custom_prompt` replaces the extractor's default system prompt. A blank custom prompt
    /// is ignored, so the model never receives an empty system message.
    ///
    /// # Errors
    ///
    /// Propagates any error from the LLM call. Returns `NeomemxError::JsonError` when the
    /// response is neither a JSON array of facts nor an object with a `facts` array.
    async fn extract(&self, text: &str, custom_prompt: Option<&str>) -> Result<Vec<String>> {
        if text.trim().is_empty() {
            return Ok(Vec::new());
        }

        // Treat `Some("")` or a whitespace-only prompt like `None`. An empty system message
        // would leave the model with no instructions about the expected output format.
        let prompt = custom_prompt
            .filter(|p| !p.trim().is_empty())
            .unwrap_or(&self.default_prompt);

        let messages = vec![
            Message::system(prompt),
            Message::user(&format!("Text to extract facts from:\n\n{}", text)),
        ];

        let response = self.llm.generate_json(messages).await?;
        // Models frequently wrap JSON in a Markdown fence even when asked for raw JSON.
        // Stripping it is a no-op for clean JSON.
        let parsed: ExtractionResponse = serde_json::from_str(strip_code_fence(&response))
            .map_err(NeomemxError::JsonError)?;
        Ok(parsed.into_strings())
    }
}

/// Removes a surrounding Markdown code fence (for example a leading "```json" line and a
/// trailing "```") from an LLM response and returns the trimmed inner text.
///
/// Input that does not start with a fence is returned with surrounding whitespace trimmed.
/// `serde_json` ignores that whitespace anyway.
fn strip_code_fence(raw: &str) -> &str {
    let trimmed = raw.trim();
    let rest = match trimmed.strip_prefix("```") {
        Some(rest) => rest,
        None => return trimmed,
    };
    // Skip the remainder of the opening fence line, which may carry a language tag.
    let body = match rest.find('\n') {
        Some(idx) => &rest[idx + 1..],
        None => rest,
    };
    body.strip_suffix("```").unwrap_or(body).trim()
}
/// Accepted shapes of the LLM's JSON extraction response.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order and uses the
/// first one that deserializes successfully. Keep that in mind before reordering or adding
/// variants.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ExtractionResponse {
    /// An object wrapping the list, e.g. `{"facts": [...]}`.
    Wrapped { facts: Vec<FactItem> },
    /// A bare JSON array of facts, e.g. `[...]`.
    Direct(Vec<FactItem>),
}
impl ExtractionResponse {
fn into_strings(self) -> Vec<String> {
match self {
ExtractionResponse::Wrapped { facts } => facts,
ExtractionResponse::Direct(facts) => facts,
}
.into_iter()
.filter_map(|f| f.into_string())
.collect()
}
}