// memvid_cli/contextual.rs

//! Contextual retrieval module for improving chunk embeddings.
//!
//! Based on Anthropic's contextual retrieval technique: before embedding each chunk,
//! we prepend a context summary that places the chunk within the larger document.
//! This helps semantic search find chunks that would otherwise be missed due to
//! lack of context.
//!
//! For example, a chunk saying "I've been using basil and mint in my cooking lately"
//! might get a context prefix like:
//! "This is a conversation where the user discusses their cooking preferences
//! and mentions growing herbs in their garden."
//!
//! This allows semantic queries like "dinner with homegrown ingredients" to find
//! the chunk even though it doesn't explicitly mention "dinner" or "homegrown".
14
use anyhow::{anyhow, Result};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
#[cfg(feature = "llama-cpp")]
use std::path::PathBuf;
use tracing::{debug, info, warn};
21
/// The prompt template for generating chunk context.
///
/// `{document}` and `{chunk}` are literal placeholders that are substituted
/// with `str::replace` before the prompt is sent to the model (see
/// `generate_context_openai`); they are not Rust format args.
const CONTEXTUAL_PROMPT: &str = r#"You are a document analysis assistant. Given a document and a chunk from that document, provide a brief context that situates the chunk within the document.

<document>
{document}
</document>

<chunk>
{chunk}
</chunk>

Provide a short context (2-3 sentences max) that:
1. Summarizes the document's topic and purpose
2. Notes any user preferences, personal information, or key facts mentioned in the document
3. Explains what this specific chunk is about within that context

Focus especially on first-person statements, preferences, and personal context that might be important for later retrieval.

Respond with ONLY the context, no preamble or explanation."#;
41
/// OpenAI API request message.
///
/// Field names match the OpenAI chat-completions JSON schema exactly, so no
/// serde renames are required.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Message author role; this module only ever sends "user".
    role: String,
    /// The rendered prompt text.
    content: String,
}
48
/// OpenAI API request body for the chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation history; this module always sends a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on generated tokens (set to 200 for short contexts).
    max_tokens: u32,
    /// Sampling temperature (0.0 here, for deterministic output).
    temperature: f32,
}
57
/// OpenAI API response body (only the fields this module reads).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion candidates; only the first choice is used.
    choices: Vec<ChatChoice>,
}
63
/// A single completion choice inside a `ChatResponse`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The assistant message produced for this choice.
    message: ChatMessageResponse,
}
68
/// The assistant message payload inside a `ChatChoice`.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// The generated context text.
    content: String,
}
73
/// Contextual retrieval engine that can use either OpenAI or local models.
pub enum ContextualEngine {
    /// OpenAI API-based context generation via the chat-completions endpoint.
    /// `api_key` is read from the `OPENAI_API_KEY` environment variable by the
    /// constructors; `model` is the chat model name (default "gpt-4o-mini").
    OpenAI { api_key: String, model: String },
    /// Local LLM-based context generation (llama.cpp); `model_path` points at
    /// a GGUF model file on disk.
    #[cfg(feature = "llama-cpp")]
    Local { model_path: PathBuf },
}
82
83impl ContextualEngine {
84    /// Create a new OpenAI-based contextual engine.
85    pub fn openai() -> Result<Self> {
86        let api_key = std::env::var("OPENAI_API_KEY")
87            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
88        Ok(Self::OpenAI {
89            api_key,
90            model: "gpt-4o-mini".to_string(),
91        })
92    }
93
94    /// Create a new OpenAI-based contextual engine with a specific model.
95    pub fn openai_with_model(model: &str) -> Result<Self> {
96        let api_key = std::env::var("OPENAI_API_KEY")
97            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
98        Ok(Self::OpenAI {
99            api_key,
100            model: model.to_string(),
101        })
102    }
103
104    /// Create a new local LLM-based contextual engine.
105    #[cfg(feature = "llama-cpp")]
106    pub fn local(model_path: PathBuf) -> Self {
107        Self::Local { model_path }
108    }
109
110    /// Generate context for a chunk within a document.
111    /// Returns the context string to prepend to the chunk before embedding.
112    pub fn generate_context(&self, document: &str, chunk: &str) -> Result<String> {
113        match self {
114            Self::OpenAI { api_key, model } => {
115                Self::generate_context_openai(api_key, model, document, chunk)
116            }
117            #[cfg(feature = "llama-cpp")]
118            Self::Local { model_path } => Self::generate_context_local(model_path, document, chunk),
119        }
120    }
121
122    /// Generate contextual prefixes for multiple chunks in parallel (OpenAI only).
123    /// Returns a vector of context strings in the same order as the input chunks.
124    pub fn generate_contexts_batch(
125        &self,
126        document: &str,
127        chunks: &[String],
128    ) -> Result<Vec<String>> {
129        match self {
130            Self::OpenAI { api_key, model } => {
131                Self::generate_contexts_batch_openai(api_key, model, document, chunks)
132            }
133            #[cfg(feature = "llama-cpp")]
134            Self::Local { model_path } => {
135                // Local models don't support batching efficiently, fall back to sequential
136                let mut contexts = Vec::with_capacity(chunks.len());
137                for chunk in chunks {
138                    let ctx = Self::generate_context_local(model_path, document, chunk)?;
139                    contexts.push(ctx);
140                }
141                Ok(contexts)
142            }
143        }
144    }
145
146    /// Generate context using OpenAI API.
147    fn generate_context_openai(
148        api_key: &str,
149        model: &str,
150        document: &str,
151        chunk: &str,
152    ) -> Result<String> {
153        let client = Client::new();
154
155        // Truncate document if too long (keep first ~6000 chars to fit in context)
156        let truncated_doc = if document.len() > 6000 {
157            format!("{}...[truncated]", &document[..6000])
158        } else {
159            document.to_string()
160        };
161
162        let prompt = CONTEXTUAL_PROMPT
163            .replace("{document}", &truncated_doc)
164            .replace("{chunk}", chunk);
165
166        let request = ChatRequest {
167            model: model.to_string(),
168            messages: vec![ChatMessage {
169                role: "user".to_string(),
170                content: prompt,
171            }],
172            max_tokens: 200,
173            temperature: 0.0,
174        };
175
176        let response = client
177            .post("https://api.openai.com/v1/chat/completions")
178            .header("Authorization", format!("Bearer {}", api_key))
179            .header("Content-Type", "application/json")
180            .json(&request)
181            .send()
182            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
183
184        if !response.status().is_success() {
185            let status = response.status();
186            let body = response.text().unwrap_or_default();
187            return Err(anyhow!("OpenAI API error {}: {}", status, body));
188        }
189
190        let chat_response: ChatResponse = response
191            .json()
192            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
193
194        chat_response
195            .choices
196            .first()
197            .map(|c| c.message.content.clone())
198            .ok_or_else(|| anyhow!("No response from OpenAI"))
199    }
200
201    /// Generate contexts for chunks sequentially using OpenAI.
202    /// Uses sequential processing to avoid issues with blocking HTTP in rayon threads.
203    fn generate_contexts_batch_openai(
204        api_key: &str,
205        model: &str,
206        document: &str,
207        chunks: &[String],
208    ) -> Result<Vec<String>> {
209        eprintln!(
210            "  Generating contextual prefixes for {} chunks...",
211            chunks.len()
212        );
213        info!(
214            "Generating contextual prefixes for {} chunks sequentially",
215            chunks.len()
216        );
217
218        let mut contexts = Vec::with_capacity(chunks.len());
219        for (i, chunk) in chunks.iter().enumerate() {
220            if i > 0 && i % 5 == 0 {
221                eprintln!("    Context progress: {}/{}", i, chunks.len());
222            }
223
224            match Self::generate_context_openai(api_key, model, document, chunk) {
225                Ok(ctx) => {
226                    debug!(
227                        "Generated context for chunk {}: {}...",
228                        i,
229                        &ctx[..ctx.len().min(50)]
230                    );
231                    contexts.push(ctx);
232                }
233                Err(e) => {
234                    warn!("Failed to generate context for chunk {}: {}", i, e);
235                    contexts.push(String::new()); // Empty context on failure
236                }
237            }
238        }
239
240        eprintln!(
241            "  Contextual prefix generation complete ({} contexts)",
242            contexts.len()
243        );
244        info!("Contextual prefix generation complete");
245        Ok(contexts)
246    }
247
248    /// Generate context using local LLM.
249    #[cfg(feature = "llama-cpp")]
250    fn generate_context_local(model_path: &PathBuf, document: &str, chunk: &str) -> Result<String> {
251        use llama_cpp::standard_sampler::StandardSampler;
252        use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
253        use tokio::runtime::Runtime;
254
255        if !model_path.exists() {
256            return Err(anyhow!(
257                "Model file not found: {}. Run 'memvid models install phi-3.5-mini' first.",
258                model_path.display()
259            ));
260        }
261
262        // Load model
263        debug!("Loading local model from {}", model_path.display());
264        let model = LlamaModel::load_from_file(model_path, LlamaParams::default())
265            .map_err(|e| anyhow!("Failed to load model: {}", e))?;
266
267        // Truncate document if too long
268        let truncated_doc = if document.len() > 4000 {
269            format!("{}...[truncated]", &document[..4000])
270        } else {
271            document.to_string()
272        };
273
274        // Build prompt in Phi-3.5 format
275        let prompt = format!(
276            r#"<|system|>
277You are a document analysis assistant. Given a document and a chunk, provide brief context.
278<|end|>
279<|user|>
280Document:
281{truncated_doc}
282
283Chunk:
284{chunk}
285
286Provide a short context (2-3 sentences) that summarizes what this document is about and what user preferences or key facts are mentioned. Focus on first-person statements.
287<|end|>
288<|assistant|>
289"#
290        );
291
292        // Create session
293        let mut session_params = SessionParams::default();
294        session_params.n_ctx = 4096;
295        session_params.n_batch = 512;
296        if session_params.n_ubatch == 0 {
297            session_params.n_ubatch = 512;
298        }
299
300        let mut session = model
301            .create_session(session_params)
302            .map_err(|e| anyhow!("Failed to create session: {}", e))?;
303
304        // Tokenize and prime context
305        let tokens = model
306            .tokenize_bytes(prompt.as_bytes(), true, true)
307            .map_err(|e| anyhow!("Failed to tokenize: {}", e))?;
308
309        session
310            .advance_context_with_tokens(&tokens)
311            .map_err(|e| anyhow!("Failed to prime context: {}", e))?;
312
313        // Generate
314        let handle = session
315            .start_completing_with(StandardSampler::default(), 200)
316            .map_err(|e| anyhow!("Failed to start completion: {}", e))?;
317
318        let runtime = Runtime::new().map_err(|e| anyhow!("Failed to create runtime: {}", e))?;
319        let generated = runtime.block_on(async { handle.into_string_async().await });
320
321        Ok(generated.trim().to_string())
322    }
323}
324
/// Apply contextual prefixes to chunks for embedding.
///
/// Each chunk is paired positionally with its context; a chunk whose context
/// is empty (e.g. generation failed) is passed through unchanged. Returns the
/// new chunk texts with context prepended.
pub fn apply_contextual_prefixes(
    _document: &str,
    chunks: &[String],
    contexts: &[String],
) -> Vec<String> {
    let mut prefixed = Vec::with_capacity(chunks.len());
    for (chunk, context) in chunks.iter().zip(contexts) {
        let text = if context.is_empty() {
            chunk.clone()
        } else {
            format!("[Context: {}]\n\n{}", context, chunk)
        };
        prefixed.push(text);
    }
    prefixed
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-empty contexts get wrapped in a "[Context: ...]" header ahead of
    /// the original chunk text.
    #[test]
    fn test_apply_contextual_prefixes() {
        let doc = "A conversation about cooking";
        let chunk_texts = vec!["I like basil".to_string(), "I grow tomatoes".to_string()];
        let chunk_contexts = vec![
            "User discusses their herb preferences".to_string(),
            "User mentions their garden".to_string(),
        ];

        let prefixed = apply_contextual_prefixes(doc, &chunk_texts, &chunk_contexts);

        assert_eq!(prefixed.len(), 2);
        assert!(prefixed[0].contains("[Context:"));
        assert!(prefixed[0].contains("I like basil"));
        assert!(prefixed[1].contains("User mentions their garden"));
    }

    /// An empty context must leave the chunk untouched — no "[Context: ]" noise.
    #[test]
    fn test_apply_contextual_prefixes_empty_context() {
        let chunk_texts = vec!["Some text".to_string()];
        let chunk_contexts = vec![String::new()];

        let prefixed = apply_contextual_prefixes("A document", &chunk_texts, &chunk_contexts);

        assert_eq!(prefixed[0], "Some text");
    }
}
376}