neomemx 0.1.2

A high-performance memory library for AI agents with semantic search
Documentation
//! Fact extraction logic

use crate::error::NeomemxError;
use crate::error::Result;
use crate::extraction::types::FactItem;
use crate::llm::{LlmBase, Message};
use async_trait::async_trait;
use serde::Deserialize;
use std::sync::Arc;

/// Built-in prompt that `LlmFactExtractor::new` stores as the extractor's default; it is sent
/// as the system message whenever `extract` is called without a custom prompt.
const DEFAULT_EXTRACTION_PROMPT: &str = "You are a helpful assistant that extracts atomic facts from text. Return a JSON array of fact strings.";

/// Trait for fact extraction
///
/// Implementors turn a piece of text into a list of fact strings. The `Send + Sync`
/// supertraits mean every implementor can be shared and moved across threads.
#[async_trait]
pub trait FactExtractor: Send + Sync {
    /// Extract facts from text
    ///
    /// # Arguments
    ///
    /// * `text` - The input to extract facts from.
    /// * `custom_prompt` - Optional prompt override. The LLM-backed implementation in this
    ///   file sends it as the system message in place of its stored default prompt.
    ///
    /// # Returns
    ///
    /// The extracted facts as plain strings. The LLM-backed implementation in this file
    /// returns an empty vector without calling the model when `text` is blank.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the LLM-backed implementation in this file propagates errors
    /// from the model call and reports unparseable responses as a JSON error.
    async fn extract(&self, text: &str, custom_prompt: Option<&str>) -> Result<Vec<String>>;
}

/// Default fact extractor implementation using LLM
pub struct LlmFactExtractor {
    /// Shared handle to the LLM backend that receives the extraction request.
    llm: Arc<dyn LlmBase>,
    /// System prompt used when `extract` is called without a custom prompt.
    default_prompt: String,
}

impl LlmFactExtractor {
    /// Build an extractor backed by `llm`.
    ///
    /// The extractor starts out with the built-in extraction prompt as its default system
    /// prompt; callers can still override the prompt per call through `extract`.
    pub fn new(llm: Arc<dyn LlmBase>) -> Self {
        let default_prompt = String::from(DEFAULT_EXTRACTION_PROMPT);
        Self {
            llm,
            default_prompt,
        }
    }
}

#[async_trait]
impl FactExtractor for LlmFactExtractor {
    /// Ask the LLM to extract facts from `text` and return them as strings.
    ///
    /// Blank (empty or whitespace-only) input short-circuits to an empty list without
    /// contacting the LLM. Otherwise the request consists of a system message (the custom
    /// prompt if given, else the stored default) followed by a user message carrying the text.
    ///
    /// # Errors
    ///
    /// Propagates any error from the LLM call, and returns `NeomemxError::JsonError` when the
    /// response cannot be deserialized into one of the accepted response shapes.
    async fn extract(&self, text: &str, custom_prompt: Option<&str>) -> Result<Vec<String>> {
        // Nothing to extract from blank input; skip the LLM round-trip entirely.
        if text.trim().is_empty() {
            return Ok(Vec::new());
        }

        let system_prompt = match custom_prompt {
            Some(overriding) => overriding,
            None => self.default_prompt.as_str(),
        };
        let user_content = format!("Text to extract facts from:\n\n{}", text);

        let conversation = vec![
            Message::system(system_prompt),
            Message::user(&user_content),
        ];

        let raw_json = self.llm.generate_json(conversation).await?;

        let facts = serde_json::from_str::<ExtractionResponse>(&raw_json)
            .map_err(NeomemxError::JsonError)?
            .into_strings();

        Ok(facts)
    }
}

/// JSON shapes accepted as an extraction response.
///
/// The enum is `untagged`, so serde selects the variant from the shape of the JSON itself
/// (trying variants in declaration order) rather than from an explicit tag.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ExtractionResponse {
    /// A JSON object carrying the items under a `facts` key.
    Wrapped { facts: Vec<FactItem> },
    /// A bare JSON array of items.
    Direct(Vec<FactItem>),
}

impl ExtractionResponse {
    fn into_strings(self) -> Vec<String> {
        match self {
            ExtractionResponse::Wrapped { facts } => facts,
            ExtractionResponse::Direct(facts) => facts,
        }
        .into_iter()
        .filter_map(|f| f.into_string())
        .collect()
    }
}