#![cfg_attr(not(feature = "async"), allow(unused_imports))]
use crate::{
core::{Document, Result, TextChunk},
ollama::{OllamaClient, OllamaConfig, OllamaGenerationParams},
};
use std::collections::{HashMap, HashSet};
/// Tunable settings for [`ContextualEnricher`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextualEnricherConfig {
    /// Master switch; when `false` the enrichment methods hand the chunks back untouched.
    pub enabled: bool,
    /// Value forwarded as the `keep_alive` generation parameter on every request.
    pub keep_alive: String,
    /// Fractional headroom added on top of the estimated token budget
    /// (the estimate is multiplied by `1.0 + safety_margin`).
    pub safety_margin: f32,
    /// Upper bound on generated tokens per chunk; sent as `num_predict` and
    /// also counted when sizing the context window.
    pub max_output_tokens: u32,
    /// Text inserted between the generated context and the original chunk content.
    pub context_separator: String,
}
impl Default for ContextualEnricherConfig {
    /// Enrichment on, a one-hour `keep_alive`, 5% safety margin, 150 output
    /// tokens, and a blank line between the context and the chunk.
    fn default() -> Self {
        let keep_alive = String::from("1h");
        let context_separator = String::from("\n\n");
        Self {
            enabled: true,
            keep_alive,
            safety_margin: 0.05,
            max_output_tokens: 150,
            context_separator,
        }
    }
}
/// Prepends an LLM-generated, document-aware context to text chunks.
pub struct ContextualEnricher {
    /// Ollama client used to generate the per-chunk context. It is only
    /// touched by the async enrichment methods, hence the `dead_code`
    /// allowance when the `async` feature is off.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    client: OllamaClient,
    /// Settings controlling whether and how enrichment is performed.
    config: ContextualEnricherConfig,
}
impl ContextualEnricher {
    /// Builds an enricher from an Ollama client configuration and explicit
    /// enricher settings.
    pub fn new(ollama_config: OllamaConfig, enricher_config: ContextualEnricherConfig) -> Self {
        Self {
            client: OllamaClient::new(ollama_config),
            config: enricher_config,
        }
    }

    /// Builds an enricher using [`ContextualEnricherConfig::default`].
    pub fn with_defaults(ollama_config: OllamaConfig) -> Self {
        Self::new(ollama_config, ContextualEnricherConfig::default())
    }

    /// Rough token estimate: one token per four bytes of `text`.
    ///
    /// The result saturates at `u32::MAX` instead of silently wrapping for
    /// inputs whose byte length divided by four does not fit in a `u32`.
    pub fn estimate_tokens(text: &str) -> u32 {
        (text.len() / 4).min(u32::MAX as usize) as u32
    }

    /// Computes the `num_ctx` (context window size) to request for a document.
    ///
    /// The budget is the sum of a fixed instruction allowance, the estimated
    /// document tokens, the largest estimated chunk in `chunks`, and the
    /// configured `max_output_tokens`; it is then scaled by the safety margin,
    /// rounded up to a multiple of 1024 and clamped to `4096..=131_072`.
    pub fn calculate_num_ctx(&self, document_text: &str, chunks: &[TextChunk]) -> u32 {
        let max_chunk_tokens = chunks
            .iter()
            .map(|c| Self::estimate_tokens(&c.content))
            .max()
            .unwrap_or(0);
        self.num_ctx_for_token_counts(Self::estimate_tokens(document_text), max_chunk_tokens)
    }

    /// Turns already-estimated token counts into a clamped, 1024-aligned
    /// context window size. Shared by [`Self::calculate_num_ctx`] and the
    /// enrichment path so the latter does not need to clone chunks.
    fn num_ctx_for_token_counts(&self, doc_tokens: u32, max_chunk_tokens: u32) -> u32 {
        const INSTRUCTION_TOKENS: u32 = 100;
        const CTX_GRANULARITY: u32 = 1024;
        const MIN_NUM_CTX: u32 = 4096;
        const MAX_NUM_CTX: u32 = 131_072;

        // Saturating arithmetic: an oversized input must end up at the upper
        // clamp, not panic (debug) or wrap around to a tiny window (release).
        let base = INSTRUCTION_TOKENS
            .saturating_add(doc_tokens)
            .saturating_add(max_chunk_tokens)
            .saturating_add(self.config.max_output_tokens);
        // Float-to-int `as` casts saturate, so this cannot overflow either.
        let with_margin = (base as f32 * (1.0 + self.config.safety_margin)) as u32;
        let rounded =
            with_margin.saturating_add(CTX_GRANULARITY - 1) / CTX_GRANULARITY * CTX_GRANULARITY;
        rounded.clamp(MIN_NUM_CTX, MAX_NUM_CTX)
    }

    /// Assembles the prompt sent to the model: the whole document, the chunk
    /// to situate, and the instruction to answer with a short context only.
    #[cfg(feature = "async")]
    fn build_prompt(document_text: &str, chunk_text: &str) -> String {
        format!(
            "<document>\n{document}\n</document>\n\n\
             Here is the chunk we want to situate within the whole document:\n\
             <chunk>\n{chunk}\n</chunk>\n\n\
             Please give a short succinct context to situate this chunk within \
             the overall document for the purposes of improving search retrieval \
             of the chunk. Answer only with the succinct context and nothing else.",
            document = document_text,
            chunk = chunk_text,
        )
    }

    /// Generates the context for a single chunk and returns
    /// `<trimmed context><separator><original chunk content>`.
    ///
    /// # Errors
    /// Propagates any error returned by the Ollama client.
    // Gated on `async` (not `ureq`): it depends on `build_prompt`, which only
    // exists under the `async` feature.
    #[cfg(feature = "async")]
    async fn enrich_one(
        &self,
        chunk: &TextChunk,
        document_text: &str,
        num_ctx: u32,
    ) -> Result<String> {
        let prompt = Self::build_prompt(document_text, &chunk.content);
        let params = OllamaGenerationParams {
            num_predict: Some(self.config.max_output_tokens),
            temperature: Some(0.1),
            num_ctx: Some(num_ctx),
            keep_alive: Some(self.config.keep_alive.clone()),
            ..Default::default()
        };
        let context = self.client.generate_with_params(&prompt, params).await?;
        Ok(format!(
            "{}{}{}",
            context.trim(),
            self.config.context_separator,
            chunk.content,
        ))
    }

    /// Enriches every chunk in `chunks` that belongs to `document`.
    ///
    /// Returns a copy of `chunks` in the same order. Chunks of other documents
    /// are returned unchanged, as are chunks whose enrichment request failed
    /// (the failure is logged when the `tracing` feature is on). When
    /// enrichment is disabled or no chunk belongs to `document`, the copy is
    /// returned as-is.
    #[cfg(feature = "async")]
    pub async fn enrich_document_chunks(
        &self,
        document: &Document,
        chunks: &[TextChunk],
    ) -> Result<Vec<TextChunk>> {
        let mut result = chunks.to_vec();
        if !self.config.enabled {
            return Ok(result);
        }

        // One pass over the matching chunks: how many there are and how large
        // the biggest one is (needed to size the context window).
        let (total, max_chunk_tokens) = chunks
            .iter()
            .filter(|c| c.document_id == document.id)
            .fold((0usize, 0u32), |(count, largest), c| {
                (count + 1, largest.max(Self::estimate_tokens(&c.content)))
            });
        if total == 0 {
            return Ok(result);
        }

        // The same num_ctx is used for every chunk of the document.
        let num_ctx = self.num_ctx_for_token_counts(
            Self::estimate_tokens(&document.content),
            max_chunk_tokens,
        );

        #[cfg(feature = "tracing")]
        tracing::info!(
            doc = %document.title,
            chunks = total,
            num_ctx,
            "Starting contextual enrichment (KV cache enabled)",
        );

        // Update the copy in place instead of searching it by id per chunk.
        let targets = result
            .iter_mut()
            .filter(|c| c.document_id == document.id);
        for (index, chunk) in targets.enumerate() {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                "Enriching chunk {}/{} (id={})",
                index + 1,
                total,
                chunk.id,
            );
            // `index` is only needed for logging.
            #[cfg(not(feature = "tracing"))]
            let _ = index;

            let outcome = self.enrich_one(&*chunk, &document.content, num_ctx).await;
            match outcome {
                Ok(enriched_content) => chunk.content = enriched_content,
                Err(e) => {
                    // Best effort by design: a failed chunk keeps its original text.
                    #[cfg(feature = "tracing")]
                    tracing::warn!(
                        chunk_id = %chunk.id,
                        error = %e,
                        "Contextual enrichment failed for chunk, keeping original",
                    );
                    #[cfg(not(feature = "tracing"))]
                    let _ = e;
                }
            }
        }
        Ok(result)
    }

    /// Enriches `chunks` across several documents.
    ///
    /// Documents are processed once each, in the order their first chunk
    /// appears in `chunks`. Chunks whose `document_id` is not present in
    /// `documents` are returned unchanged. When enrichment is disabled the
    /// input vector is returned as-is.
    ///
    /// # Errors
    /// Propagates any error from [`Self::enrich_document_chunks`].
    #[cfg(feature = "async")]
    pub async fn enrich_chunks(
        &self,
        documents: &[Document],
        chunks: Vec<TextChunk>,
    ) -> Result<Vec<TextChunk>> {
        if !self.config.enabled {
            return Ok(chunks);
        }

        let doc_map: HashMap<_, &Document> =
            documents.iter().map(|d| (d.id.clone(), d)).collect();

        // Work out which documents to process (and in which order) up front,
        // so the chunk vector can be moved rather than cloned.
        let mut seen = HashSet::new();
        let mut pending: Vec<&Document> = Vec::new();
        for chunk in &chunks {
            if let Some(doc) = doc_map.get(&chunk.document_id) {
                if seen.insert(chunk.document_id.clone()) {
                    pending.push(*doc);
                }
            }
        }

        let mut enriched = chunks;
        for doc in pending {
            enriched = self.enrich_document_chunks(doc, &enriched).await?;
        }
        Ok(enriched)
    }

    /// Returns the enricher's configuration.
    pub fn config(&self) -> &ContextualEnricherConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, DocumentId};

    /// Builds a document with a fixed id/title and the given content.
    fn make_document(content: &str) -> Document {
        Document::new(
            DocumentId::new("doc_test".to_string()),
            "Test Document".to_string(),
            content.to_string(),
        )
    }

    /// Builds a chunk of `doc_id` spanning the whole of `content`.
    fn make_chunk(doc_id: DocumentId, id: &str, content: &str) -> TextChunk {
        TextChunk::new(
            ChunkId::new(id.to_string()),
            doc_id,
            content.to_string(),
            0,
            content.len(),
        )
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "a".repeat(400);
        assert_eq!(ContextualEnricher::estimate_tokens(&text), 100);
    }

    #[test]
    fn test_calculate_num_ctx_minimum() {
        let config = OllamaConfig::default();
        let enricher = ContextualEnricher::with_defaults(config);
        let doc = make_document("short document");
        let chunk = make_chunk(doc.id.clone(), "c0", "short chunk");
        assert!(enricher.calculate_num_ctx(&doc.content, &[chunk]) >= 4096);
    }

    #[test]
    fn test_calculate_num_ctx_large_document() {
        let config = OllamaConfig::default();
        let enricher = ContextualEnricher::with_defaults(config);
        let doc_content = "word ".repeat(36_000);
        let chunk_content = "word ".repeat(500);
        let doc = make_document(&doc_content);
        let chunk = make_chunk(doc.id.clone(), "c0", &chunk_content);
        let num_ctx = enricher.calculate_num_ctx(&doc.content, &[chunk]);
        assert!(num_ctx > 36_000 + 500);
        assert_eq!(num_ctx % 1024, 0);
    }

    #[test]
    fn test_calculate_num_ctx_is_capped() {
        let config = OllamaConfig::default();
        let enricher = ContextualEnricher::with_defaults(config);
        // ~250k estimated tokens, well above the 131_072 ceiling.
        let doc = make_document(&"word ".repeat(200_000));
        let chunk = make_chunk(doc.id.clone(), "c0", "short chunk");
        assert_eq!(enricher.calculate_num_ctx(&doc.content, &[chunk]), 131_072);
    }

    // `build_prompt` only exists when the `async` feature is enabled, so the
    // test must carry the same gate or the test build breaks without it.
    #[cfg(feature = "async")]
    #[test]
    fn test_build_prompt_contains_document_and_chunk() {
        let prompt = ContextualEnricher::build_prompt("full document text", "chunk excerpt");
        assert!(prompt.contains("full document text"));
        assert!(prompt.contains("chunk excerpt"));
        assert!(prompt.contains("<document>"));
        assert!(prompt.contains("<chunk>"));
    }

    #[test]
    fn test_disabled_enricher_returns_original() {
        // Only checks that the flag can be overridden via struct-update syntax;
        // the async enrichment path itself is not exercised here.
        let config_with_disabled = ContextualEnricherConfig {
            enabled: false,
            ..Default::default()
        };
        assert!(!config_with_disabled.enabled);
    }
}