graphrag-core 0.2.0

Core portable library for GraphRAG - works on native and WASM
Documentation
#![cfg_attr(not(feature = "async"), allow(unused_imports))]

//! Contextual Chunk Enrichment via LLM (Anthropic Contextual Retrieval pattern)
//!
//! Before embedding, an LLM augments each chunk with 2-3 sentences of context
//! derived from the entire document. Ollama's KV Cache means the document is
//! evaluated **once** for all chunks from the same document.
//!
//! ## How it works
//!
//! 1. Group chunks by source document
//! 2. For each document, calculate `num_ctx` from document + chunk sizes
//! 3. Build prompts using a static prefix (full document) + dynamic suffix (chunk)
//! 4. Call Ollama with `keep_alive` and `num_ctx` — the static prefix is cached
//! 5. Return enriched chunks: `[LLM context]\n\n[original chunk text]`
//!
//! ## Performance with KV Cache
//!
//! With Mistral-NeMo 12B on RTX 4070:
//! - First chunk: ~2 min (loads document into KV cache)
//! - Subsequent chunks: ~3-5 sec each (only the chunk is re-evaluated)
//! - ~100 chunks from a 45k-token book: **5-10 minutes total**
//!
//! Without `keep_alive`, Ollama unloads the model between requests, destroying
//! the KV cache and making this approach impractical at scale.

use crate::{
    core::{Document, Result, TextChunk},
    ollama::{OllamaClient, OllamaConfig, OllamaGenerationParams},
};
use std::collections::{HashMap, HashSet};

/// Configuration for contextual chunk enrichment
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextualEnricherConfig {
    /// Enable contextual enrichment. When false, chunks are returned unchanged.
    pub enabled: bool,

    /// How long to keep the model loaded in Ollama between requests.
    ///
    /// **Critical**: Without this, Ollama unloads the model between chunks, destroying
    /// the KV cache. The document context (static prefix) must be re-evaluated for
    /// every chunk, making this approach 20-50x slower.
    ///
    /// Set to `"1h"` for processing a book; `"30m"` for shorter documents.
    pub keep_alive: String,

    /// Safety margin added to the calculated `num_ctx` (as a fraction, e.g. 0.05 = 5%).
    ///
    /// Prevents silent truncation caused by slightly underestimating token counts.
    pub safety_margin: f32,

    /// Maximum tokens for the LLM-generated context sentences (output budget).
    ///
    /// 100-200 tokens is typically enough for 2-3 context sentences.
    pub max_output_tokens: u32,

    /// Separator inserted between the generated context and the original chunk text.
    pub context_separator: String,
}

impl Default for ContextualEnricherConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            keep_alive: "1h".to_string(),
            safety_margin: 0.05,
            max_output_tokens: 150,
            context_separator: "\n\n".to_string(),
        }
    }
}

/// LLM-based contextual chunk enricher (Anthropic Contextual Retrieval pattern)
///
/// Uses Ollama to prepend 2-3 context sentences to each chunk before embedding,
/// dramatically improving retrieval accuracy for long documents (typically +10-15%).
///
/// ## Example
///
/// ```ignore
/// use graphrag_core::text::contextual_enricher::{ContextualEnricher, ContextualEnricherConfig};
/// use graphrag_core::ollama::OllamaConfig;
///
/// # async fn example() -> graphrag_core::Result<()> {
/// let ollama_config = OllamaConfig {
///     enabled: true,
///     chat_model: "mistral-nemo:12b".to_string(),
///     keep_alive: Some("1h".to_string()),
///     ..OllamaConfig::default()
/// };
///
/// let enricher = ContextualEnricher::with_defaults(ollama_config);
///
/// // Process a document's chunks
/// let enriched = enricher.enrich_document_chunks(&document, &chunks).await?;
/// # Ok(())
/// # }
/// ```
pub struct ContextualEnricher {
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    client: OllamaClient,
    config: ContextualEnricherConfig,
}

impl ContextualEnricher {
    /// Create a new contextual enricher with explicit configs
    pub fn new(ollama_config: OllamaConfig, enricher_config: ContextualEnricherConfig) -> Self {
        let client = OllamaClient::new(ollama_config);
        Self {
            client,
            config: enricher_config,
        }
    }

    /// Create with default enricher config
    pub fn with_defaults(ollama_config: OllamaConfig) -> Self {
        Self::new(ollama_config, ContextualEnricherConfig::default())
    }

    /// Estimate token count from text length (approximation: chars / 4)
    ///
    /// This is a fast, dependency-free estimate. For most Western-language text,
    /// 1 token ≈ 4 characters. The safety_margin in config compensates for variance.
    pub fn estimate_tokens(text: &str) -> u32 {
        (text.len() / 4) as u32
    }

    /// Calculate the required `num_ctx` for a document and its chunks
    ///
    /// Formula:
    /// ```text
    /// num_ctx = tokens(prompt_instructions)   // ~100 fixed tokens
    ///         + tokens(full_document)          // static prefix → KV cached
    ///         + tokens(largest_chunk)          // dynamic suffix
    ///         + max_output_tokens              // generation budget
    ///         + safety_margin (5%)             // avoid edge-case truncation
    /// ```
    ///
    /// The result is rounded up to the nearest 1024 and clamped to [4096, 131072].
    pub fn calculate_num_ctx(&self, document_text: &str, chunks: &[TextChunk]) -> u32 {
        let instruction_tokens = 100u32;
        let doc_tokens = Self::estimate_tokens(document_text);
        let max_chunk_tokens = chunks
            .iter()
            .map(|c| Self::estimate_tokens(&c.content))
            .max()
            .unwrap_or(0);

        let base =
            instruction_tokens + doc_tokens + max_chunk_tokens + self.config.max_output_tokens;
        let with_margin = (base as f32 * (1.0 + self.config.safety_margin)) as u32;

        // Round up to nearest 1024 for memory alignment efficiency
        let rounded = ((with_margin + 1023) / 1024) * 1024;

        // Clamp: minimum 4096 (sensible floor), maximum 131072 (128k, typical GPU limit)
        rounded.clamp(4096, 131_072)
    }

    /// Build the contextual enrichment prompt using the KV-cache-friendly structure
    ///
    /// The document text is the **static prefix** — it will be KV-cached after the
    /// first request and reused for all subsequent chunks from the same document.
    /// Only the chunk content changes between requests.
    #[cfg(feature = "async")]
    fn build_prompt(document_text: &str, chunk_text: &str) -> String {
        format!(
            "<document>\n{document}\n</document>\n\n\
             Here is the chunk we want to situate within the whole document:\n\
             <chunk>\n{chunk}\n</chunk>\n\n\
             Please give a short succinct context to situate this chunk within \
             the overall document for the purposes of improving search retrieval \
             of the chunk. Answer only with the succinct context and nothing else.",
            document = document_text,
            chunk = chunk_text,
        )
    }

    /// Enrich a single chunk with LLM-generated context
    #[cfg(feature = "ureq")]
    async fn enrich_one(
        &self,
        chunk: &TextChunk,
        document_text: &str,
        num_ctx: u32,
    ) -> Result<String> {
        let prompt = Self::build_prompt(document_text, &chunk.content);

        let params = OllamaGenerationParams {
            num_predict: Some(self.config.max_output_tokens),
            temperature: Some(0.1), // low temperature for factual, deterministic context
            num_ctx: Some(num_ctx),
            keep_alive: Some(self.config.keep_alive.clone()),
            ..Default::default()
        };

        let context = self.client.generate_with_params(&prompt, params).await?;

        Ok(format!(
            "{}{}{}",
            context.trim(),
            self.config.context_separator,
            chunk.content,
        ))
    }

    /// Enrich all chunks belonging to a single document
    ///
    /// All chunks are processed with the same `num_ctx` value (calculated once),
    /// and Ollama's KV cache preserves the document context across all requests.
    #[cfg(feature = "ureq")]
    pub async fn enrich_document_chunks(
        &self,
        document: &Document,
        chunks: &[TextChunk],
    ) -> Result<Vec<TextChunk>> {
        if !self.config.enabled {
            return Ok(chunks.to_vec());
        }

        let doc_chunks: Vec<&TextChunk> = chunks
            .iter()
            .filter(|c| c.document_id == document.id)
            .collect();

        if doc_chunks.is_empty() {
            return Ok(chunks.to_vec());
        }

        let doc_chunk_owned: Vec<TextChunk> = doc_chunks.iter().map(|c| (*c).clone()).collect();
        let num_ctx = self.calculate_num_ctx(&document.content, &doc_chunk_owned);

        #[cfg(feature = "tracing")]
        tracing::info!(
            doc = %document.title,
            chunks = doc_chunks.len(),
            num_ctx,
            "Starting contextual enrichment (KV cache enabled)",
        );

        let mut result = chunks.to_vec();

        for (i, chunk) in doc_chunks.iter().enumerate() {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                "Enriching chunk {}/{} (id={})",
                i + 1,
                doc_chunks.len(),
                chunk.id,
            );

            match self.enrich_one(chunk, &document.content, num_ctx).await {
                Ok(enriched_content) => {
                    if let Some(target) = result.iter_mut().find(|c| c.id == chunk.id) {
                        target.content = enriched_content;
                    }
                },
                Err(e) => {
                    // Fail gracefully: keep original chunk on error
                    #[cfg(feature = "tracing")]
                    tracing::warn!(
                        chunk_id = %chunk.id,
                        error = %e,
                        "Contextual enrichment failed for chunk, keeping original",
                    );
                },
            }
        }

        Ok(result)
    }

    /// Enrich chunks from multiple documents in a single pipeline call
    ///
    /// Documents are processed one at a time so Ollama can maximise KV cache
    /// reuse (the document context stays loaded for all its chunks before
    /// moving to the next document).
    #[cfg(feature = "ureq")]
    pub async fn enrich_chunks(
        &self,
        documents: &[Document],
        chunks: Vec<TextChunk>,
    ) -> Result<Vec<TextChunk>> {
        if !self.config.enabled {
            return Ok(chunks);
        }

        let doc_map: HashMap<_, &Document> = documents.iter().map(|d| (d.id.clone(), d)).collect();

        let mut enriched = chunks.clone();
        let mut processed: HashSet<_> = HashSet::new();

        for chunk in &chunks {
            if processed.contains(&chunk.document_id) {
                continue;
            }
            if let Some(doc) = doc_map.get(&chunk.document_id) {
                enriched = self.enrich_document_chunks(doc, &enriched).await?;
                processed.insert(chunk.document_id.clone());
            }
        }

        Ok(enriched)
    }

    /// Access the enricher configuration
    pub fn config(&self) -> &ContextualEnricherConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, DocumentId};

    fn make_document(content: &str) -> Document {
        Document::new(
            DocumentId::new("doc_test".to_string()),
            "Test Document".to_string(),
            content.to_string(),
        )
    }

    fn make_chunk(doc_id: DocumentId, id: &str, content: &str) -> TextChunk {
        TextChunk::new(
            ChunkId::new(id.to_string()),
            doc_id,
            content.to_string(),
            0,
            content.len(),
        )
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "a".repeat(400);
        assert_eq!(ContextualEnricher::estimate_tokens(&text), 100);
    }

    #[test]
    fn test_calculate_num_ctx_minimum() {
        let config = OllamaConfig::default();
        let enricher = ContextualEnricher::with_defaults(config);
        let doc = make_document("short document");
        let chunk = make_chunk(doc.id.clone(), "c0", "short chunk");
        // Should always be at least 4096
        assert!(enricher.calculate_num_ctx(&doc.content, &[chunk]) >= 4096);
    }

    #[test]
    fn test_calculate_num_ctx_large_document() {
        let config = OllamaConfig::default();
        let enricher = ContextualEnricher::with_defaults(config);
        // Simulate a 45k-token book (~180k chars)
        let doc_content = "word ".repeat(36_000);
        let chunk_content = "word ".repeat(500); // ~500 tokens
        let doc = make_document(&doc_content);
        let chunk = make_chunk(doc.id.clone(), "c0", &chunk_content);

        let num_ctx = enricher.calculate_num_ctx(&doc.content, &[chunk]);
        // Should be > doc tokens + chunk tokens
        assert!(num_ctx > 36_000 + 500);
        // Should be a multiple of 1024
        assert_eq!(num_ctx % 1024, 0);
    }

    #[test]
    fn test_build_prompt_contains_document_and_chunk() {
        let prompt = ContextualEnricher::build_prompt("full document text", "chunk excerpt");
        assert!(prompt.contains("full document text"));
        assert!(prompt.contains("chunk excerpt"));
        assert!(prompt.contains("<document>"));
        assert!(prompt.contains("<chunk>"));
    }

    #[test]
    fn test_disabled_enricher_returns_original() {
        // When disabled, enrich_chunks should be a no-op
        let config_with_disabled = ContextualEnricherConfig {
            enabled: false,
            ..Default::default()
        };
        // We can verify the config
        assert!(!config_with_disabled.enabled);
    }
}