1use anyhow::{anyhow, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18#[cfg(feature = "llama-cpp")]
19use std::path::PathBuf;
20use std::time::Duration;
21use tracing::{debug, info, warn};
22
/// Prompt template for contextual chunk enrichment (Anthropic-style
/// "contextual retrieval"): `{document}` and `{chunk}` are replaced via
/// plain `str::replace` before the prompt is sent to the model.
/// NOTE: the template text is part of runtime behavior — do not reword.
const CONTEXTUAL_PROMPT: &str = r#"You are a document analysis assistant. Given a document and a chunk from that document, provide a brief context that situates the chunk within the document.

<document>
{document}
</document>

<chunk>
{chunk}
</chunk>

Provide a short context (2-3 sentences max) that:
1. Summarizes the document's topic and purpose
2. Notes any user preferences, personal information, or key facts mentioned in the document
3. Explains what this specific chunk is about within that context

Focus especially on first-person statements, preferences, and personal context that might be important for later retrieval.

Respond with ONLY the context, no preamble or explanation."#;
42
/// One message in an OpenAI chat-completions request body.
/// Field names match the OpenAI wire format — do not rename.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    // Author role, e.g. "user" (the only value this module sends).
    role: String,
    // Message text (here: the filled-in CONTEXTUAL_PROMPT).
    content: String,
}
49
/// Request payload for OpenAI `/v1/chat/completions` — only the fields
/// this module sends. Field names match the OpenAI wire format.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Chat model name, e.g. "gpt-4o-mini".
    model: String,
    messages: Vec<ChatMessage>,
    // Upper bound on generated tokens for the context.
    max_tokens: u32,
    // Sampling temperature; this module uses 0.0 for determinism.
    temperature: f32,
}
58
/// Minimal deserialization of an OpenAI chat-completions response:
/// only the `choices` array is read; all other fields are ignored.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}
64
/// A single completion choice; only its `message` is used.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
69
/// Assistant message inside a choice; only the text content is read.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
74
/// Backend used to generate contextual prefixes for document chunks.
pub enum ContextualEngine {
    /// Hosted generation via the OpenAI chat-completions API.
    OpenAI { api_key: String, model: String },
    /// On-device generation via a llama.cpp GGUF model file
    /// (only compiled with the `llama-cpp` cargo feature).
    #[cfg(feature = "llama-cpp")]
    Local { model_path: PathBuf },
}
83
84impl ContextualEngine {
85 pub fn openai() -> Result<Self> {
87 let api_key = std::env::var("OPENAI_API_KEY")
88 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
89 Ok(Self::OpenAI {
90 api_key,
91 model: "gpt-4o-mini".to_string(),
92 })
93 }
94
95 pub fn openai_with_model(model: &str) -> Result<Self> {
97 let api_key = std::env::var("OPENAI_API_KEY")
98 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
99 Ok(Self::OpenAI {
100 api_key,
101 model: model.to_string(),
102 })
103 }
104
105 #[cfg(feature = "llama-cpp")]
107 pub fn local(model_path: PathBuf) -> Self {
108 Self::Local { model_path }
109 }
110
111 pub fn generate_context(&self, document: &str, chunk: &str) -> Result<String> {
114 match self {
115 Self::OpenAI { api_key, model } => {
116 let client = crate::http::blocking_client(Duration::from_secs(60))?;
117 Self::generate_context_openai(&client, api_key, model, document, chunk)
118 }
119 #[cfg(feature = "llama-cpp")]
120 Self::Local { model_path } => Self::generate_context_local(model_path, document, chunk),
121 }
122 }
123
124 pub fn generate_contexts_batch(
127 &self,
128 document: &str,
129 chunks: &[String],
130 ) -> Result<Vec<String>> {
131 match self {
132 Self::OpenAI { api_key, model } => {
133 Self::generate_contexts_batch_openai(api_key, model, document, chunks)
134 }
135 #[cfg(feature = "llama-cpp")]
136 Self::Local { model_path } => {
137 let mut contexts = Vec::with_capacity(chunks.len());
139 for chunk in chunks {
140 let ctx = Self::generate_context_local(model_path, document, chunk)?;
141 contexts.push(ctx);
142 }
143 Ok(contexts)
144 }
145 }
146 }
147
148 fn generate_context_openai(
150 client: &Client,
151 api_key: &str,
152 model: &str,
153 document: &str,
154 chunk: &str,
155 ) -> Result<String> {
156 let truncated_doc = if document.len() > 6000 {
158 format!("{}...[truncated]", &document[..6000])
159 } else {
160 document.to_string()
161 };
162
163 let prompt = CONTEXTUAL_PROMPT
164 .replace("{document}", &truncated_doc)
165 .replace("{chunk}", chunk);
166
167 let request = ChatRequest {
168 model: model.to_string(),
169 messages: vec![ChatMessage {
170 role: "user".to_string(),
171 content: prompt,
172 }],
173 max_tokens: 200,
174 temperature: 0.0,
175 };
176
177 let response = client
178 .post("https://api.openai.com/v1/chat/completions")
179 .header("Authorization", format!("Bearer {}", api_key))
180 .header("Content-Type", "application/json")
181 .json(&request)
182 .send()
183 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
184
185 if !response.status().is_success() {
186 let status = response.status();
187 let body = response.text().unwrap_or_default();
188 return Err(anyhow!("OpenAI API error {}: {}", status, body));
189 }
190
191 let chat_response: ChatResponse = response
192 .json()
193 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
194
195 chat_response
196 .choices
197 .first()
198 .map(|c| c.message.content.clone())
199 .ok_or_else(|| anyhow!("No response from OpenAI"))
200 }
201
202 fn generate_contexts_batch_openai(
205 api_key: &str,
206 model: &str,
207 document: &str,
208 chunks: &[String],
209 ) -> Result<Vec<String>> {
210 let client = crate::http::blocking_client(Duration::from_secs(60))?;
211
212 eprintln!(
213 " Generating contextual prefixes for {} chunks...",
214 chunks.len()
215 );
216 info!(
217 "Generating contextual prefixes for {} chunks sequentially",
218 chunks.len()
219 );
220
221 let mut contexts = Vec::with_capacity(chunks.len());
222 for (i, chunk) in chunks.iter().enumerate() {
223 if i > 0 && i % 5 == 0 {
224 eprintln!(" Context progress: {}/{}", i, chunks.len());
225 }
226
227 match Self::generate_context_openai(&client, api_key, model, document, chunk) {
228 Ok(ctx) => {
229 debug!(
230 "Generated context for chunk {}: {}...",
231 i,
232 &ctx[..ctx.len().min(50)]
233 );
234 contexts.push(ctx);
235 }
236 Err(e) => {
237 warn!("Failed to generate context for chunk {}: {}", i, e);
238 contexts.push(String::new()); }
240 }
241 }
242
243 eprintln!(
244 " Contextual prefix generation complete ({} contexts)",
245 contexts.len()
246 );
247 info!("Contextual prefix generation complete");
248 Ok(contexts)
249 }
250
251 #[cfg(feature = "llama-cpp")]
253 fn generate_context_local(model_path: &PathBuf, document: &str, chunk: &str) -> Result<String> {
254 use llama_cpp::standard_sampler::StandardSampler;
255 use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
256 use tokio::runtime::Runtime;
257
258 if !model_path.exists() {
259 return Err(anyhow!(
260 "Model file not found: {}. Run 'memvid models install phi-3.5-mini' first.",
261 model_path.display()
262 ));
263 }
264
265 debug!("Loading local model from {}", model_path.display());
267 let model = LlamaModel::load_from_file(model_path, LlamaParams::default())
268 .map_err(|e| anyhow!("Failed to load model: {}", e))?;
269
270 let truncated_doc = if document.len() > 4000 {
272 format!("{}...[truncated]", &document[..4000])
273 } else {
274 document.to_string()
275 };
276
277 let prompt = format!(
279 r#"<|system|>
280You are a document analysis assistant. Given a document and a chunk, provide brief context.
281<|end|>
282<|user|>
283Document:
284{truncated_doc}
285
286Chunk:
287{chunk}
288
289Provide a short context (2-3 sentences) that summarizes what this document is about and what user preferences or key facts are mentioned. Focus on first-person statements.
290<|end|>
291<|assistant|>
292"#
293 );
294
295 let mut session_params = SessionParams::default();
297 session_params.n_ctx = 4096;
298 session_params.n_batch = 512;
299 if session_params.n_ubatch == 0 {
300 session_params.n_ubatch = 512;
301 }
302
303 let mut session = model
304 .create_session(session_params)
305 .map_err(|e| anyhow!("Failed to create session: {}", e))?;
306
307 let tokens = model
309 .tokenize_bytes(prompt.as_bytes(), true, true)
310 .map_err(|e| anyhow!("Failed to tokenize: {}", e))?;
311
312 session
313 .advance_context_with_tokens(&tokens)
314 .map_err(|e| anyhow!("Failed to prime context: {}", e))?;
315
316 let handle = session
318 .start_completing_with(StandardSampler::default(), 200)
319 .map_err(|e| anyhow!("Failed to start completion: {}", e))?;
320
321 let runtime = Runtime::new().map_err(|e| anyhow!("Failed to create runtime: {}", e))?;
322 let generated = runtime.block_on(async { handle.into_string_async().await });
323
324 Ok(generated.trim().to_string())
325 }
326}
327
/// Prepends each generated context to its chunk as a "[Context: ...]"
/// header separated by a blank line.
///
/// Chunks whose context is empty (e.g. a failed generation) are returned
/// unchanged. Pairing stops at the shorter of `chunks`/`contexts`. The
/// `_document` parameter is currently unused but kept for API stability.
pub fn apply_contextual_prefixes(
    _document: &str,
    chunks: &[String],
    contexts: &[String],
) -> Vec<String> {
    let mut prefixed = Vec::with_capacity(chunks.len().min(contexts.len()));
    for (chunk, context) in chunks.iter().zip(contexts) {
        let entry = if context.is_empty() {
            chunk.clone()
        } else {
            format!("[Context: {}]\n\n{}", context, chunk)
        };
        prefixed.push(entry);
    }
    prefixed
}
347
#[cfg(test)]
mod tests {
    use super::*;

    // Non-empty contexts are prepended as "[Context: ...]" headers and the
    // original chunk text is preserved verbatim.
    #[test]
    fn test_apply_contextual_prefixes() {
        let document = "A conversation about cooking";
        let chunks = vec!["I like basil".to_string(), "I grow tomatoes".to_string()];
        let contexts = vec![
            "User discusses their herb preferences".to_string(),
            "User mentions their garden".to_string(),
        ];

        let result = apply_contextual_prefixes(document, &chunks, &contexts);

        assert_eq!(result.len(), 2);
        assert!(result[0].contains("[Context:"));
        assert!(result[0].contains("I like basil"));
        assert!(result[1].contains("User mentions their garden"));
    }

    // An empty context (e.g. a failed generation) leaves the chunk unchanged.
    #[test]
    fn test_apply_contextual_prefixes_empty_context() {
        let document = "A document";
        let chunks = vec!["Some text".to_string()];
        let contexts = vec![String::new()];

        let result = apply_contextual_prefixes(document, &chunks, &contexts);

        assert_eq!(result[0], "Some text");
    }
}