1use anyhow::{anyhow, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18#[cfg(feature = "llama-cpp")]
19use std::path::PathBuf;
20use tracing::{debug, info, warn};
21
/// Prompt template for contextual chunk enrichment.
///
/// The `{document}` and `{chunk}` placeholders are substituted verbatim (via
/// `str::replace`, not `format!`) before the prompt is sent to the model; the
/// model is asked to return ONLY a short situating context for the chunk.
const CONTEXTUAL_PROMPT: &str = r#"You are a document analysis assistant. Given a document and a chunk from that document, provide a brief context that situates the chunk within the document.

<document>
{document}
</document>

<chunk>
{chunk}
</chunk>

Provide a short context (2-3 sentences max) that:
1. Summarizes the document's topic and purpose
2. Notes any user preferences, personal information, or key facts mentioned in the document
3. Explains what this specific chunk is about within that context

Focus especially on first-person statements, preferences, and personal context that might be important for later retrieval.

Respond with ONLY the context, no preamble or explanation."#;
41
/// One message in an OpenAI chat-completions request payload.
/// Field names match the OpenAI wire format — do not rename.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Speaker role, e.g. "user".
    role: String,
    // Message text sent to the model.
    content: String,
}
48
/// Request body for OpenAI's `/v1/chat/completions` endpoint.
/// Field names match the OpenAI wire format — do not rename.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Model identifier, e.g. "gpt-4o-mini".
    model: String,
    // Conversation history; here always a single user message.
    messages: Vec<ChatMessage>,
    // Upper bound on generated tokens.
    max_tokens: u32,
    // Sampling temperature; 0.0 is used for deterministic context output.
    temperature: f32,
}
57
/// Minimal deserialization of an OpenAI chat-completions response:
/// only the `choices` array is read; all other fields are ignored.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}
63
/// One completion choice from the OpenAI response; only the message is kept.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
68
/// The assistant message inside a completion choice; only the text is kept.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
73
/// Backend used to generate contextual prefixes for chunks.
///
/// NOTE(review): no `Debug` derive here — if one is added later, be aware it
/// would print `api_key`; consider redacting.
pub enum ContextualEngine {
    /// Hosted OpenAI chat-completions API (key from `OPENAI_API_KEY`).
    OpenAI { api_key: String, model: String },
    /// Local llama.cpp GGUF model; only compiled with the `llama-cpp` feature.
    #[cfg(feature = "llama-cpp")]
    Local { model_path: PathBuf },
}
82
83impl ContextualEngine {
84 pub fn openai() -> Result<Self> {
86 let api_key = std::env::var("OPENAI_API_KEY")
87 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
88 Ok(Self::OpenAI {
89 api_key,
90 model: "gpt-4o-mini".to_string(),
91 })
92 }
93
94 pub fn openai_with_model(model: &str) -> Result<Self> {
96 let api_key = std::env::var("OPENAI_API_KEY")
97 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
98 Ok(Self::OpenAI {
99 api_key,
100 model: model.to_string(),
101 })
102 }
103
104 #[cfg(feature = "llama-cpp")]
106 pub fn local(model_path: PathBuf) -> Self {
107 Self::Local { model_path }
108 }
109
110 pub fn generate_context(&self, document: &str, chunk: &str) -> Result<String> {
113 match self {
114 Self::OpenAI { api_key, model } => {
115 Self::generate_context_openai(api_key, model, document, chunk)
116 }
117 #[cfg(feature = "llama-cpp")]
118 Self::Local { model_path } => Self::generate_context_local(model_path, document, chunk),
119 }
120 }
121
122 pub fn generate_contexts_batch(
125 &self,
126 document: &str,
127 chunks: &[String],
128 ) -> Result<Vec<String>> {
129 match self {
130 Self::OpenAI { api_key, model } => {
131 Self::generate_contexts_batch_openai(api_key, model, document, chunks)
132 }
133 #[cfg(feature = "llama-cpp")]
134 Self::Local { model_path } => {
135 let mut contexts = Vec::with_capacity(chunks.len());
137 for chunk in chunks {
138 let ctx = Self::generate_context_local(model_path, document, chunk)?;
139 contexts.push(ctx);
140 }
141 Ok(contexts)
142 }
143 }
144 }
145
146 fn generate_context_openai(
148 api_key: &str,
149 model: &str,
150 document: &str,
151 chunk: &str,
152 ) -> Result<String> {
153 let client = Client::new();
154
155 let truncated_doc = if document.len() > 6000 {
157 format!("{}...[truncated]", &document[..6000])
158 } else {
159 document.to_string()
160 };
161
162 let prompt = CONTEXTUAL_PROMPT
163 .replace("{document}", &truncated_doc)
164 .replace("{chunk}", chunk);
165
166 let request = ChatRequest {
167 model: model.to_string(),
168 messages: vec![ChatMessage {
169 role: "user".to_string(),
170 content: prompt,
171 }],
172 max_tokens: 200,
173 temperature: 0.0,
174 };
175
176 let response = client
177 .post("https://api.openai.com/v1/chat/completions")
178 .header("Authorization", format!("Bearer {}", api_key))
179 .header("Content-Type", "application/json")
180 .json(&request)
181 .send()
182 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
183
184 if !response.status().is_success() {
185 let status = response.status();
186 let body = response.text().unwrap_or_default();
187 return Err(anyhow!("OpenAI API error {}: {}", status, body));
188 }
189
190 let chat_response: ChatResponse = response
191 .json()
192 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
193
194 chat_response
195 .choices
196 .first()
197 .map(|c| c.message.content.clone())
198 .ok_or_else(|| anyhow!("No response from OpenAI"))
199 }
200
201 fn generate_contexts_batch_openai(
204 api_key: &str,
205 model: &str,
206 document: &str,
207 chunks: &[String],
208 ) -> Result<Vec<String>> {
209 eprintln!(
210 " Generating contextual prefixes for {} chunks...",
211 chunks.len()
212 );
213 info!(
214 "Generating contextual prefixes for {} chunks sequentially",
215 chunks.len()
216 );
217
218 let mut contexts = Vec::with_capacity(chunks.len());
219 for (i, chunk) in chunks.iter().enumerate() {
220 if i > 0 && i % 5 == 0 {
221 eprintln!(" Context progress: {}/{}", i, chunks.len());
222 }
223
224 match Self::generate_context_openai(api_key, model, document, chunk) {
225 Ok(ctx) => {
226 debug!(
227 "Generated context for chunk {}: {}...",
228 i,
229 &ctx[..ctx.len().min(50)]
230 );
231 contexts.push(ctx);
232 }
233 Err(e) => {
234 warn!("Failed to generate context for chunk {}: {}", i, e);
235 contexts.push(String::new()); }
237 }
238 }
239
240 eprintln!(
241 " Contextual prefix generation complete ({} contexts)",
242 contexts.len()
243 );
244 info!("Contextual prefix generation complete");
245 Ok(contexts)
246 }
247
248 #[cfg(feature = "llama-cpp")]
250 fn generate_context_local(model_path: &PathBuf, document: &str, chunk: &str) -> Result<String> {
251 use llama_cpp::standard_sampler::StandardSampler;
252 use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
253 use tokio::runtime::Runtime;
254
255 if !model_path.exists() {
256 return Err(anyhow!(
257 "Model file not found: {}. Run 'memvid models install phi-3.5-mini' first.",
258 model_path.display()
259 ));
260 }
261
262 debug!("Loading local model from {}", model_path.display());
264 let model = LlamaModel::load_from_file(model_path, LlamaParams::default())
265 .map_err(|e| anyhow!("Failed to load model: {}", e))?;
266
267 let truncated_doc = if document.len() > 4000 {
269 format!("{}...[truncated]", &document[..4000])
270 } else {
271 document.to_string()
272 };
273
274 let prompt = format!(
276 r#"<|system|>
277You are a document analysis assistant. Given a document and a chunk, provide brief context.
278<|end|>
279<|user|>
280Document:
281{truncated_doc}
282
283Chunk:
284{chunk}
285
286Provide a short context (2-3 sentences) that summarizes what this document is about and what user preferences or key facts are mentioned. Focus on first-person statements.
287<|end|>
288<|assistant|>
289"#
290 );
291
292 let mut session_params = SessionParams::default();
294 session_params.n_ctx = 4096;
295 session_params.n_batch = 512;
296 if session_params.n_ubatch == 0 {
297 session_params.n_ubatch = 512;
298 }
299
300 let mut session = model
301 .create_session(session_params)
302 .map_err(|e| anyhow!("Failed to create session: {}", e))?;
303
304 let tokens = model
306 .tokenize_bytes(prompt.as_bytes(), true, true)
307 .map_err(|e| anyhow!("Failed to tokenize: {}", e))?;
308
309 session
310 .advance_context_with_tokens(&tokens)
311 .map_err(|e| anyhow!("Failed to prime context: {}", e))?;
312
313 let handle = session
315 .start_completing_with(StandardSampler::default(), 200)
316 .map_err(|e| anyhow!("Failed to start completion: {}", e))?;
317
318 let runtime = Runtime::new().map_err(|e| anyhow!("Failed to create runtime: {}", e))?;
319 let generated = runtime.block_on(async { handle.into_string_async().await });
320
321 Ok(generated.trim().to_string())
322 }
323}
324
/// Prepend each chunk with its generated context as `[Context: ...]`.
///
/// Chunks are paired positionally with `contexts`; an empty context leaves
/// the chunk untouched. If the two slices differ in length, only the shorter
/// prefix is produced (zip semantics). The document itself is unused here.
pub fn apply_contextual_prefixes(
    _document: &str,
    chunks: &[String],
    contexts: &[String],
) -> Vec<String> {
    let mut prefixed = Vec::with_capacity(chunks.len());
    for (chunk, context) in chunks.iter().zip(contexts) {
        let entry = if context.is_empty() {
            // No context available (e.g. generation failed) — keep the chunk as-is.
            chunk.clone()
        } else {
            format!("[Context: {}]\n\n{}", context, chunk)
        };
        prefixed.push(entry);
    }
    prefixed
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-empty contexts produce a "[Context: ...]" prefix ahead of each chunk.
    #[test]
    fn test_apply_contextual_prefixes() {
        let document = "A conversation about cooking";
        let chunks = vec!["I like basil".to_string(), "I grow tomatoes".to_string()];
        let contexts = vec![
            "User discusses their herb preferences".to_string(),
            "User mentions their garden".to_string(),
        ];

        let result = apply_contextual_prefixes(document, &chunks, &contexts);

        assert_eq!(result.len(), 2);
        assert!(result[0].contains("[Context:"));
        assert!(result[0].contains("I like basil"));
        assert!(result[1].contains("User mentions their garden"));
    }

    /// An empty context string (the failure placeholder from batch generation)
    /// leaves the chunk unchanged.
    #[test]
    fn test_apply_contextual_prefixes_empty_context() {
        let document = "A document";
        let chunks = vec!["Some text".to_string()];
        let contexts = vec![String::new()];

        let result = apply_contextual_prefixes(document, &chunks, &contexts);

        assert_eq!(result[0], "Some text");
    }
}