// memvid_cli/contextual.rs

//! Contextual retrieval module for improving chunk embeddings.
//!
//! Based on Anthropic's contextual retrieval technique: before embedding each chunk,
//! we prepend a context summary that places the chunk within the larger document.
//! This helps semantic search find chunks that would otherwise be missed due to lack of context.
//!
//! For example, a chunk saying "I've been using basil and mint in my cooking lately"
//! might get a context prefix like:
//! "This is a conversation where the user discusses their cooking preferences
//! and mentions growing herbs in their garden."
//!
//! This allows semantic queries like "dinner with homegrown ingredients" to find
//! the chunk even though it doesn't explicitly mention "dinner" or "homegrown".

15use anyhow::{anyhow, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18#[cfg(feature = "llama-cpp")]
19use std::path::PathBuf;
20use std::time::Duration;
21use tracing::{debug, info, warn};
22
/// The contextual prompt for generating chunk context.
///
/// `{document}` and `{chunk}` are literal placeholders substituted with
/// `str::replace` before the prompt is sent to the model.
const CONTEXTUAL_PROMPT: &str = r#"You are a document analysis assistant. Given a document and a chunk from that document, provide a brief context that situates the chunk within the document.

<document>
{document}
</document>

<chunk>
{chunk}
</chunk>

Provide a short context (2-3 sentences max) that:
1. Summarizes the document's topic and purpose
2. Notes any user preferences, personal information, or key facts mentioned in the document
3. Explains what this specific chunk is about within that context

Focus especially on first-person statements, preferences, and personal context that might be important for later retrieval.

Respond with ONLY the context, no preamble or explanation."#;
42
/// A single message in an OpenAI chat-completions request.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Message author role; this module only sends "user".
    role: String,
    /// The message text sent to the model.
    content: String,
}
49
/// Request body for the OpenAI `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation messages; here always a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on tokens generated in the reply.
    max_tokens: u32,
    /// Sampling temperature; 0.0 is used for deterministic output.
    temperature: f32,
}
58
/// OpenAI API response.
///
/// Only the fields this module reads are modeled; serde ignores the rest
/// of the response body.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Candidate completions; only the first is used.
    choices: Vec<ChatChoice>,
}

/// A single completion choice within the response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}

/// The assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    /// The generated context text.
    content: String,
}
74
/// Contextual retrieval engine that can use either OpenAI or local models.
///
/// Construct via [`ContextualEngine::openai`] / [`ContextualEngine::openai_with_model`],
/// or (with the `llama-cpp` feature) `ContextualEngine::local`.
pub enum ContextualEngine {
    /// OpenAI API-based context generation; holds the bearer key and model id.
    OpenAI { api_key: String, model: String },
    /// Local LLM-based context generation (llama.cpp); holds the on-disk model path.
    #[cfg(feature = "llama-cpp")]
    Local { model_path: PathBuf },
}
83
84impl ContextualEngine {
85    /// Create a new OpenAI-based contextual engine.
86    pub fn openai() -> Result<Self> {
87        let api_key = std::env::var("OPENAI_API_KEY")
88            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
89        Ok(Self::OpenAI {
90            api_key,
91            model: "gpt-4o-mini".to_string(),
92        })
93    }
94
95    /// Create a new OpenAI-based contextual engine with a specific model.
96    pub fn openai_with_model(model: &str) -> Result<Self> {
97        let api_key = std::env::var("OPENAI_API_KEY")
98            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
99        Ok(Self::OpenAI {
100            api_key,
101            model: model.to_string(),
102        })
103    }
104
    /// Create a new local LLM-based contextual engine.
    ///
    /// `model_path` should point to the model file on disk; existence is
    /// checked lazily at generation time, not here.
    #[cfg(feature = "llama-cpp")]
    pub fn local(model_path: PathBuf) -> Self {
        Self::Local { model_path }
    }
110
    /// Generate context for a chunk within a document.
    /// Returns the context string to prepend to the chunk before embedding.
    ///
    /// For the OpenAI backend this performs one blocking HTTP request with a
    /// 60-second timeout; for the local backend it runs llama.cpp inference.
    pub fn generate_context(&self, document: &str, chunk: &str) -> Result<String> {
        match self {
            Self::OpenAI { api_key, model } => {
                // A fresh client per call; batch callers share one client via
                // generate_contexts_batch instead.
                let client = crate::http::blocking_client(Duration::from_secs(60))?;
                Self::generate_context_openai(&client, api_key, model, document, chunk)
            }
            #[cfg(feature = "llama-cpp")]
            Self::Local { model_path } => Self::generate_context_local(model_path, document, chunk),
        }
    }
123
124    /// Generate contextual prefixes for multiple chunks in parallel (OpenAI only).
125    /// Returns a vector of context strings in the same order as the input chunks.
126    pub fn generate_contexts_batch(
127        &self,
128        document: &str,
129        chunks: &[String],
130    ) -> Result<Vec<String>> {
131        match self {
132            Self::OpenAI { api_key, model } => {
133                Self::generate_contexts_batch_openai(api_key, model, document, chunks)
134            }
135            #[cfg(feature = "llama-cpp")]
136            Self::Local { model_path } => {
137                // Local models don't support batching efficiently, fall back to sequential
138                let mut contexts = Vec::with_capacity(chunks.len());
139                for chunk in chunks {
140                    let ctx = Self::generate_context_local(model_path, document, chunk)?;
141                    contexts.push(ctx);
142                }
143                Ok(contexts)
144            }
145        }
146    }
147
148    /// Generate context using OpenAI API.
149    fn generate_context_openai(
150        client: &Client,
151        api_key: &str,
152        model: &str,
153        document: &str,
154        chunk: &str,
155    ) -> Result<String> {
156        // Truncate document if too long (keep first ~6000 chars to fit in context)
157        let truncated_doc = if document.len() > 6000 {
158            format!("{}...[truncated]", &document[..6000])
159        } else {
160            document.to_string()
161        };
162
163        let prompt = CONTEXTUAL_PROMPT
164            .replace("{document}", &truncated_doc)
165            .replace("{chunk}", chunk);
166
167        let request = ChatRequest {
168            model: model.to_string(),
169            messages: vec![ChatMessage {
170                role: "user".to_string(),
171                content: prompt,
172            }],
173            max_tokens: 200,
174            temperature: 0.0,
175        };
176
177        let response = client
178            .post("https://api.openai.com/v1/chat/completions")
179            .header("Authorization", format!("Bearer {}", api_key))
180            .header("Content-Type", "application/json")
181            .json(&request)
182            .send()
183            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
184
185        if !response.status().is_success() {
186            let status = response.status();
187            let body = response.text().unwrap_or_default();
188            return Err(anyhow!("OpenAI API error {}: {}", status, body));
189        }
190
191        let chat_response: ChatResponse = response
192            .json()
193            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
194
195        chat_response
196            .choices
197            .first()
198            .map(|c| c.message.content.clone())
199            .ok_or_else(|| anyhow!("No response from OpenAI"))
200    }
201
202    /// Generate contexts for chunks sequentially using OpenAI.
203    /// Uses sequential processing to avoid issues with blocking HTTP in rayon threads.
204    fn generate_contexts_batch_openai(
205        api_key: &str,
206        model: &str,
207        document: &str,
208        chunks: &[String],
209    ) -> Result<Vec<String>> {
210        let client = crate::http::blocking_client(Duration::from_secs(60))?;
211
212        eprintln!(
213            "  Generating contextual prefixes for {} chunks...",
214            chunks.len()
215        );
216        info!(
217            "Generating contextual prefixes for {} chunks sequentially",
218            chunks.len()
219        );
220
221        let mut contexts = Vec::with_capacity(chunks.len());
222        for (i, chunk) in chunks.iter().enumerate() {
223            if i > 0 && i % 5 == 0 {
224                eprintln!("    Context progress: {}/{}", i, chunks.len());
225            }
226
227            match Self::generate_context_openai(&client, api_key, model, document, chunk) {
228                Ok(ctx) => {
229                    debug!(
230                        "Generated context for chunk {}: {}...",
231                        i,
232                        &ctx[..ctx.len().min(50)]
233                    );
234                    contexts.push(ctx);
235                }
236                Err(e) => {
237                    warn!("Failed to generate context for chunk {}: {}", i, e);
238                    contexts.push(String::new()); // Empty context on failure
239                }
240            }
241        }
242
243        eprintln!(
244            "  Contextual prefix generation complete ({} contexts)",
245            contexts.len()
246        );
247        info!("Contextual prefix generation complete");
248        Ok(contexts)
249    }
250
251    /// Generate context using local LLM.
252    #[cfg(feature = "llama-cpp")]
253    fn generate_context_local(model_path: &PathBuf, document: &str, chunk: &str) -> Result<String> {
254        use llama_cpp::standard_sampler::StandardSampler;
255        use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
256        use tokio::runtime::Runtime;
257
258        if !model_path.exists() {
259            return Err(anyhow!(
260                "Model file not found: {}. Run 'memvid models install phi-3.5-mini' first.",
261                model_path.display()
262            ));
263        }
264
265        // Load model
266        debug!("Loading local model from {}", model_path.display());
267        let model = LlamaModel::load_from_file(model_path, LlamaParams::default())
268            .map_err(|e| anyhow!("Failed to load model: {}", e))?;
269
270        // Truncate document if too long
271        let truncated_doc = if document.len() > 4000 {
272            format!("{}...[truncated]", &document[..4000])
273        } else {
274            document.to_string()
275        };
276
277        // Build prompt in Phi-3.5 format
278        let prompt = format!(
279            r#"<|system|>
280You are a document analysis assistant. Given a document and a chunk, provide brief context.
281<|end|>
282<|user|>
283Document:
284{truncated_doc}
285
286Chunk:
287{chunk}
288
289Provide a short context (2-3 sentences) that summarizes what this document is about and what user preferences or key facts are mentioned. Focus on first-person statements.
290<|end|>
291<|assistant|>
292"#
293        );
294
295        // Create session
296        let mut session_params = SessionParams::default();
297        session_params.n_ctx = 4096;
298        session_params.n_batch = 512;
299        if session_params.n_ubatch == 0 {
300            session_params.n_ubatch = 512;
301        }
302
303        let mut session = model
304            .create_session(session_params)
305            .map_err(|e| anyhow!("Failed to create session: {}", e))?;
306
307        // Tokenize and prime context
308        let tokens = model
309            .tokenize_bytes(prompt.as_bytes(), true, true)
310            .map_err(|e| anyhow!("Failed to tokenize: {}", e))?;
311
312        session
313            .advance_context_with_tokens(&tokens)
314            .map_err(|e| anyhow!("Failed to prime context: {}", e))?;
315
316        // Generate
317        let handle = session
318            .start_completing_with(StandardSampler::default(), 200)
319            .map_err(|e| anyhow!("Failed to start completion: {}", e))?;
320
321        let runtime = Runtime::new().map_err(|e| anyhow!("Failed to create runtime: {}", e))?;
322        let generated = runtime.block_on(async { handle.into_string_async().await });
323
324        Ok(generated.trim().to_string())
325    }
326}
327
/// Apply contextual prefixes to chunks for embedding.
/// Returns new chunk texts with context prepended.
///
/// Chunks paired with an empty context are passed through unchanged; pairing
/// stops at the shorter of the two slices.
pub fn apply_contextual_prefixes(
    _document: &str,
    chunks: &[String],
    contexts: &[String],
) -> Vec<String> {
    let mut prefixed = Vec::with_capacity(chunks.len().min(contexts.len()));
    for (chunk, context) in chunks.iter().zip(contexts) {
        if context.is_empty() {
            prefixed.push(chunk.clone());
        } else {
            prefixed.push(format!("[Context: {}]\n\n{}", context, chunk));
        }
    }
    prefixed
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-empty contexts are prepended in `[Context: …]` form, in input order.
    #[test]
    fn test_apply_contextual_prefixes() {
        let chunks: Vec<String> = ["I like basil", "I grow tomatoes"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let contexts: Vec<String> = [
            "User discusses their herb preferences",
            "User mentions their garden",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let result =
            apply_contextual_prefixes("A conversation about cooking", &chunks, &contexts);

        assert_eq!(result.len(), 2);
        assert!(result[0].starts_with("[Context:"));
        assert!(result[0].contains("I like basil"));
        assert!(result[1].contains("User mentions their garden"));
    }

    /// An empty context leaves the chunk text completely unchanged.
    #[test]
    fn test_apply_contextual_prefixes_empty_context() {
        let chunks = vec!["Some text".to_string()];
        let contexts = vec![String::new()];

        let result = apply_contextual_prefixes("A document", &chunks, &contexts);

        assert_eq!(result[0], "Some text");
    }
}
379}