1use anyhow::{anyhow, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18#[cfg(feature = "llama-cpp")]
19use std::path::PathBuf;
20use tracing::{debug, info, warn};
21
/// Prompt template used to generate a contextual prefix for a chunk.
///
/// `{document}` and `{chunk}` are placeholders that are substituted via plain
/// string replacement before the prompt is sent to the model.
const CONTEXTUAL_PROMPT: &str = r#"You are a document analysis assistant. Given a document and a chunk from that document, provide a brief context that situates the chunk within the document.

<document>
{document}
</document>

<chunk>
{chunk}
</chunk>

Provide a short context (2-3 sentences max) that:
1. Summarizes the document's topic and purpose
2. Notes any user preferences, personal information, or key facts mentioned in the document
3. Explains what this specific chunk is about within that context

Focus especially on first-person statements, preferences, and personal context that might be important for later retrieval.

Respond with ONLY the context, no preamble or explanation."#;
41
/// A single message in an OpenAI chat-completions request.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Speaker role as the API expects it; this module only sends "user".
    role: String,
    /// Message text (the fully rendered prompt).
    content: String,
}
48
/// Request body for the OpenAI `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "gpt-4o-mini".
    model: String,
    /// Conversation so far; this module sends a single user message.
    messages: Vec<ChatMessage>,
    /// Upper bound on generated tokens (set to 200 by this module).
    max_tokens: u32,
    /// Sampling temperature (set to 0.0 here for deterministic output).
    temperature: f32,
}
57
/// Top-level response body from the chat-completions endpoint; only the
/// fields this module reads are deserialized.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}
63
/// One completion candidate inside a `ChatResponse`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
68
/// The assistant message carried by a `ChatChoice`; only the text is kept.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
73
/// Backend used to generate contextual prefixes for document chunks.
pub enum ContextualEngine {
    /// Remote backend using the OpenAI chat-completions API.
    OpenAI {
        /// API key read from the `OPENAI_API_KEY` environment variable.
        api_key: String,
        /// Model identifier, e.g. "gpt-4o-mini".
        model: String,
    },
    /// Local llama.cpp inference backend (only with the `llama-cpp` feature).
    #[cfg(feature = "llama-cpp")]
    Local {
        /// Path to the model file on disk.
        model_path: PathBuf,
    },
}
87
88impl ContextualEngine {
89 pub fn openai() -> Result<Self> {
91 let api_key = std::env::var("OPENAI_API_KEY")
92 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
93 Ok(Self::OpenAI {
94 api_key,
95 model: "gpt-4o-mini".to_string(),
96 })
97 }
98
99 pub fn openai_with_model(model: &str) -> Result<Self> {
101 let api_key = std::env::var("OPENAI_API_KEY")
102 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
103 Ok(Self::OpenAI {
104 api_key,
105 model: model.to_string(),
106 })
107 }
108
109 #[cfg(feature = "llama-cpp")]
111 pub fn local(model_path: PathBuf) -> Self {
112 Self::Local { model_path }
113 }
114
115 pub fn generate_context(&self, document: &str, chunk: &str) -> Result<String> {
118 match self {
119 Self::OpenAI { api_key, model } => {
120 Self::generate_context_openai(api_key, model, document, chunk)
121 }
122 #[cfg(feature = "llama-cpp")]
123 Self::Local { model_path } => {
124 Self::generate_context_local(model_path, document, chunk)
125 }
126 }
127 }
128
129 pub fn generate_contexts_batch(
132 &self,
133 document: &str,
134 chunks: &[String],
135 ) -> Result<Vec<String>> {
136 match self {
137 Self::OpenAI { api_key, model } => {
138 Self::generate_contexts_batch_openai(api_key, model, document, chunks)
139 }
140 #[cfg(feature = "llama-cpp")]
141 Self::Local { model_path } => {
142 let mut contexts = Vec::with_capacity(chunks.len());
144 for chunk in chunks {
145 let ctx = Self::generate_context_local(model_path, document, chunk)?;
146 contexts.push(ctx);
147 }
148 Ok(contexts)
149 }
150 }
151 }
152
153 fn generate_context_openai(
155 api_key: &str,
156 model: &str,
157 document: &str,
158 chunk: &str,
159 ) -> Result<String> {
160 let client = Client::new();
161
162 let truncated_doc = if document.len() > 6000 {
164 format!("{}...[truncated]", &document[..6000])
165 } else {
166 document.to_string()
167 };
168
169 let prompt = CONTEXTUAL_PROMPT
170 .replace("{document}", &truncated_doc)
171 .replace("{chunk}", chunk);
172
173 let request = ChatRequest {
174 model: model.to_string(),
175 messages: vec![ChatMessage {
176 role: "user".to_string(),
177 content: prompt,
178 }],
179 max_tokens: 200,
180 temperature: 0.0,
181 };
182
183 let response = client
184 .post("https://api.openai.com/v1/chat/completions")
185 .header("Authorization", format!("Bearer {}", api_key))
186 .header("Content-Type", "application/json")
187 .json(&request)
188 .send()
189 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
190
191 if !response.status().is_success() {
192 let status = response.status();
193 let body = response.text().unwrap_or_default();
194 return Err(anyhow!("OpenAI API error {}: {}", status, body));
195 }
196
197 let chat_response: ChatResponse = response
198 .json()
199 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
200
201 chat_response
202 .choices
203 .first()
204 .map(|c| c.message.content.clone())
205 .ok_or_else(|| anyhow!("No response from OpenAI"))
206 }
207
208 fn generate_contexts_batch_openai(
211 api_key: &str,
212 model: &str,
213 document: &str,
214 chunks: &[String],
215 ) -> Result<Vec<String>> {
216 eprintln!(
217 " Generating contextual prefixes for {} chunks...",
218 chunks.len()
219 );
220 info!(
221 "Generating contextual prefixes for {} chunks sequentially",
222 chunks.len()
223 );
224
225 let mut contexts = Vec::with_capacity(chunks.len());
226 for (i, chunk) in chunks.iter().enumerate() {
227 if i > 0 && i % 5 == 0 {
228 eprintln!(" Context progress: {}/{}", i, chunks.len());
229 }
230
231 match Self::generate_context_openai(api_key, model, document, chunk) {
232 Ok(ctx) => {
233 debug!("Generated context for chunk {}: {}...", i, &ctx[..ctx.len().min(50)]);
234 contexts.push(ctx);
235 }
236 Err(e) => {
237 warn!("Failed to generate context for chunk {}: {}", i, e);
238 contexts.push(String::new()); }
240 }
241 }
242
243 eprintln!(" Contextual prefix generation complete ({} contexts)", contexts.len());
244 info!("Contextual prefix generation complete");
245 Ok(contexts)
246 }
247
248 #[cfg(feature = "llama-cpp")]
250 fn generate_context_local(model_path: &PathBuf, document: &str, chunk: &str) -> Result<String> {
251 use llama_cpp::standard_sampler::StandardSampler;
252 use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
253 use tokio::runtime::Runtime;
254
255 if !model_path.exists() {
256 return Err(anyhow!(
257 "Model file not found: {}. Run 'memvid models install phi-3.5-mini' first.",
258 model_path.display()
259 ));
260 }
261
262 debug!("Loading local model from {}", model_path.display());
264 let model = LlamaModel::load_from_file(model_path, LlamaParams::default())
265 .map_err(|e| anyhow!("Failed to load model: {}", e))?;
266
267 let truncated_doc = if document.len() > 4000 {
269 format!("{}...[truncated]", &document[..4000])
270 } else {
271 document.to_string()
272 };
273
274 let prompt = format!(
276 r#"<|system|>
277You are a document analysis assistant. Given a document and a chunk, provide brief context.
278<|end|>
279<|user|>
280Document:
281{truncated_doc}
282
283Chunk:
284{chunk}
285
286Provide a short context (2-3 sentences) that summarizes what this document is about and what user preferences or key facts are mentioned. Focus on first-person statements.
287<|end|>
288<|assistant|>
289"#
290 );
291
292 let mut session_params = SessionParams::default();
294 session_params.n_ctx = 4096;
295 session_params.n_batch = 512;
296 if session_params.n_ubatch == 0 {
297 session_params.n_ubatch = 512;
298 }
299
300 let mut session = model
301 .create_session(session_params)
302 .map_err(|e| anyhow!("Failed to create session: {}", e))?;
303
304 let tokens = model
306 .tokenize_bytes(prompt.as_bytes(), true, true)
307 .map_err(|e| anyhow!("Failed to tokenize: {}", e))?;
308
309 session
310 .advance_context_with_tokens(&tokens)
311 .map_err(|e| anyhow!("Failed to prime context: {}", e))?;
312
313 let handle = session
315 .start_completing_with(StandardSampler::default(), 200)
316 .map_err(|e| anyhow!("Failed to start completion: {}", e))?;
317
318 let runtime = Runtime::new().map_err(|e| anyhow!("Failed to create runtime: {}", e))?;
319 let generated = runtime.block_on(async { handle.into_string_async().await });
320
321 Ok(generated.trim().to_string())
322 }
323}
324
/// Prepends each chunk with its generated context as `[Context: ...]`.
///
/// Chunks whose context is empty — or missing, if `contexts` is shorter than
/// `chunks` — are returned unchanged. The previous `zip`-based implementation
/// silently dropped trailing chunks on a length mismatch; indexing with
/// `get` preserves every chunk.
///
/// The `_document` parameter is unused but kept for interface stability.
pub fn apply_contextual_prefixes(
    _document: &str,
    chunks: &[String],
    contexts: &[String],
) -> Vec<String> {
    chunks
        .iter()
        .enumerate()
        .map(|(i, chunk)| match contexts.get(i) {
            Some(context) if !context.is_empty() => {
                format!("[Context: {}]\n\n{}", context, chunk)
            }
            // No context (empty or missing): pass the chunk through as-is.
            _ => chunk.clone(),
        })
        .collect()
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-empty contexts are prepended as `[Context: ...]` blocks.
    #[test]
    fn test_apply_contextual_prefixes() {
        let doc = "A conversation about cooking";
        let chunks: Vec<String> = ["I like basil", "I grow tomatoes"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let contexts: Vec<String> = [
            "User discusses their herb preferences",
            "User mentions their garden",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let prefixed = apply_contextual_prefixes(doc, &chunks, &contexts);

        assert_eq!(prefixed.len(), 2);
        assert!(prefixed[0].contains("[Context:"));
        assert!(prefixed[0].contains("I like basil"));
        assert!(prefixed[1].contains("User mentions their garden"));
    }

    /// An empty context leaves the chunk untouched.
    #[test]
    fn test_apply_contextual_prefixes_empty_context() {
        let lone_chunk = vec!["Some text".to_owned()];
        let blank_context = vec![String::new()];

        let prefixed = apply_contextual_prefixes("A document", &lone_chunk, &blank_context);

        assert_eq!(prefixed[0], "Some text");
    }
}