use anyhow::{anyhow, Result};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
#[cfg(feature = "llama-cpp")]
use std::path::PathBuf;
use std::time::Duration;
use tracing::{debug, info, warn};
/// Prompt template used by `generate_context_openai` to ask for a situating context.
///
/// The `{document}` and `{chunk}` placeholders are filled in with `str::replace`
/// before the prompt is sent as a single user message.
const CONTEXTUAL_PROMPT: &str = r#"You are a document analysis assistant. Given a document and a chunk from that document, provide a brief context that situates the chunk within the document.
<document>
{document}
</document>
<chunk>
{chunk}
</chunk>
Provide a short context (2-3 sentences max) that:
1. Summarizes the document's topic and purpose
2. Notes any user preferences, personal information, or key facts mentioned in the document
3. Explains what this specific chunk is about within that context
Focus especially on first-person statements, preferences, and personal context that might be important for later retrieval.
Respond with ONLY the context, no preamble or explanation."#;
/// One entry of the `messages` list in a [`ChatRequest`].
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
/// Speaker role; this file only ever sends `"user"`.
role: String,
/// Message text (here: the fully rendered prompt).
content: String,
}
/// JSON body posted to the chat completions endpoint by `generate_context_openai`.
#[derive(Debug, Serialize)]
struct ChatRequest {
/// Name of the model to query.
model: String,
/// Conversation to complete; a single user message in this file.
messages: Vec<ChatMessage>,
/// Upper bound on generated tokens (set to 200 by the caller in this file).
max_tokens: u32,
/// Sampling temperature (set to 0.0 by the caller in this file).
temperature: f32,
}
/// The part of the chat completions response that this module reads.
#[derive(Debug, Deserialize)]
struct ChatResponse {
/// Returned completions; only the first one is used.
choices: Vec<ChatChoice>,
}
/// A single completion inside [`ChatResponse::choices`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
/// The message generated for this choice.
message: ChatMessageResponse,
}
/// Message payload of a [`ChatChoice`]; only the text is deserialized.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
/// Generated text, returned to callers as the chunk context.
content: String,
}
/// Backend used to generate a short situating context for document chunks.
pub enum ContextualEngine {
/// Remote generation through the OpenAI chat completions API.
OpenAI { api_key: String, model: String },
/// Local generation from a model file; only compiled with the `llama-cpp` feature.
#[cfg(feature = "llama-cpp")]
Local { model_path: PathBuf },
}
impl ContextualEngine {
/// Builds an OpenAI-backed engine that uses the default `gpt-4o-mini` model.
///
/// # Errors
/// Fails when the `OPENAI_API_KEY` environment variable cannot be read.
pub fn openai() -> Result<Self> {
    // Same key lookup and error message as the explicit-model constructor.
    Self::openai_with_model("gpt-4o-mini")
}
/// Builds an OpenAI-backed engine that queries the given `model`.
///
/// # Errors
/// Fails when the `OPENAI_API_KEY` environment variable cannot be read.
pub fn openai_with_model(model: &str) -> Result<Self> {
    match std::env::var("OPENAI_API_KEY") {
        Ok(api_key) => Ok(Self::OpenAI {
            api_key,
            model: model.to_owned(),
        }),
        Err(_) => Err(anyhow!("OPENAI_API_KEY environment variable not set")),
    }
}
/// Builds an engine backed by the local model file at `model_path`.
///
/// The path is only stored here; it is checked and loaded when a context is generated.
#[cfg(feature = "llama-cpp")]
pub fn local(model_path: PathBuf) -> Self {
Self::Local { model_path }
}
/// Generates the situating context for one `chunk` of `document`.
///
/// Dispatches to the backend selected by this engine. The OpenAI backend
/// builds a fresh blocking HTTP client with a 60 second timeout per call.
///
/// # Errors
/// Propagates any failure from client construction or from the backend.
pub fn generate_context(&self, document: &str, chunk: &str) -> Result<String> {
    match self {
        #[cfg(feature = "llama-cpp")]
        Self::Local { model_path } => Self::generate_context_local(model_path, document, chunk),
        Self::OpenAI { api_key, model } => {
            let timeout = Duration::from_secs(60);
            let http = crate::http::blocking_client(timeout)?;
            Self::generate_context_openai(&http, api_key, model, document, chunk)
        }
    }
}
/// Generates one context per entry of `chunks`, in order.
///
/// The OpenAI backend delegates to `generate_contexts_batch_openai`. The local
/// backend runs `generate_context_local` chunk by chunk and stops at the first
/// error.
///
/// # Errors
/// Returns the first error reported by the selected backend.
pub fn generate_contexts_batch(
    &self,
    document: &str,
    chunks: &[String],
) -> Result<Vec<String>> {
    match self {
        #[cfg(feature = "llama-cpp")]
        Self::Local { model_path } => chunks
            .iter()
            // Collecting into `Result` short-circuits on the first failing chunk.
            .map(|chunk| Self::generate_context_local(model_path, document, chunk))
            .collect(),
        Self::OpenAI { api_key, model } => {
            Self::generate_contexts_batch_openai(api_key, model, document, chunks)
        }
    }
}
/// Asks the OpenAI chat completions endpoint for a short context situating
/// `chunk` within `document`.
///
/// Documents longer than 6000 bytes are cut to at most 6000 bytes, always on a
/// UTF-8 character boundary, and suffixed with `...[truncated]` before being
/// inserted into [`CONTEXTUAL_PROMPT`].
///
/// # Errors
/// Returns an error when the request cannot be sent, the API answers with a
/// non-success status, the body cannot be parsed, or no choice is returned.
fn generate_context_openai(
    client: &Client,
    api_key: &str,
    model: &str,
    document: &str,
    chunk: &str,
) -> Result<String> {
    // Byte budget for the document portion of the prompt.
    const MAX_DOCUMENT_BYTES: usize = 6000;

    let truncated_doc = if document.len() > MAX_DOCUMENT_BYTES {
        // Slicing a `str` in the middle of a multi-byte character panics, so
        // back up to the nearest boundary (offset 0 is always one).
        let mut cut = MAX_DOCUMENT_BYTES;
        while !document.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...[truncated]", &document[..cut])
    } else {
        document.to_string()
    };
    let prompt = CONTEXTUAL_PROMPT
        .replace("{document}", &truncated_doc)
        .replace("{chunk}", chunk);
    let request = ChatRequest {
        model: model.to_string(),
        messages: vec![ChatMessage {
            role: "user".to_string(),
            content: prompt,
        }],
        max_tokens: 200,
        temperature: 0.0,
    };
    let response = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
    if !response.status().is_success() {
        // Capture the status first: reading the body consumes the response.
        let status = response.status();
        let body = response.text().unwrap_or_default();
        return Err(anyhow!("OpenAI API error {}: {}", status, body));
    }
    let chat_response: ChatResponse = response
        .json()
        .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
    // The response is owned, so move the first choice out instead of cloning it.
    chat_response
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.message.content)
        .ok_or_else(|| anyhow!("No response from OpenAI"))
}
/// Generates contexts for every chunk sequentially over one shared HTTP client.
///
/// Progress is written to stderr and to the tracing log. A chunk whose request
/// fails is logged as a warning and gets an empty context, so the returned
/// vector always has exactly `chunks.len()` entries in chunk order.
///
/// # Errors
/// Fails only when the HTTP client itself cannot be constructed.
fn generate_contexts_batch_openai(
    api_key: &str,
    model: &str,
    document: &str,
    chunks: &[String],
) -> Result<Vec<String>> {
    let client = crate::http::blocking_client(Duration::from_secs(60))?;
    eprintln!(
        "  Generating contextual prefixes for {} chunks...",
        chunks.len()
    );
    info!(
        "Generating contextual prefixes for {} chunks sequentially",
        chunks.len()
    );
    let mut contexts = Vec::with_capacity(chunks.len());
    for (i, chunk) in chunks.iter().enumerate() {
        if i > 0 && i % 5 == 0 {
            eprintln!("    Context progress: {}/{}", i, chunks.len());
        }
        match Self::generate_context_openai(&client, api_key, model, document, chunk) {
            Ok(ctx) => {
                // Preview by characters, not bytes: byte-slicing model output
                // at a fixed offset panics when it lands inside a multi-byte
                // character.
                debug!(
                    "Generated context for chunk {}: {}...",
                    i,
                    ctx.chars().take(50).collect::<String>()
                );
                contexts.push(ctx);
            }
            Err(e) => {
                warn!("Failed to generate context for chunk {}: {}", i, e);
                // Best effort: keep positions aligned with `chunks`.
                contexts.push(String::new());
            }
        }
    }
    eprintln!(
        "  Contextual prefix generation complete ({} contexts)",
        contexts.len()
    );
    info!("Contextual prefix generation complete");
    Ok(contexts)
}
/// Generates a context for `chunk` with a local model loaded from `model_path`.
///
/// The model file is loaded on every call. Documents longer than 4000 bytes are
/// cut to at most 4000 bytes, always on a UTF-8 character boundary, and suffixed
/// with `...[truncated]`. Up to 200 tokens are requested and the trimmed
/// completion is returned.
///
/// # Errors
/// Returns an error when the model file is missing or cannot be loaded, or when
/// session creation, tokenization, context priming, completion start-up or
/// runtime creation fails.
#[cfg(feature = "llama-cpp")]
fn generate_context_local(model_path: &PathBuf, document: &str, chunk: &str) -> Result<String> {
    use llama_cpp::standard_sampler::StandardSampler;
    use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
    use tokio::runtime::Runtime;

    // Byte budget for the document portion of the prompt.
    const MAX_DOCUMENT_BYTES: usize = 4000;

    if !model_path.exists() {
        return Err(anyhow!(
            "Model file not found: {}. Run 'memvid models install phi-3.5-mini' first.",
            model_path.display()
        ));
    }
    debug!("Loading local model from {}", model_path.display());
    let model = LlamaModel::load_from_file(model_path, LlamaParams::default())
        .map_err(|e| anyhow!("Failed to load model: {}", e))?;
    let truncated_doc = if document.len() > MAX_DOCUMENT_BYTES {
        // Slicing a `str` in the middle of a multi-byte character panics, so
        // back up to the nearest boundary (offset 0 is always one).
        let mut cut = MAX_DOCUMENT_BYTES;
        while !document.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...[truncated]", &document[..cut])
    } else {
        document.to_string()
    };
    // The raw string below is sent verbatim; its lines are intentionally not indented.
    let prompt = format!(
        r#"<|system|>
You are a document analysis assistant. Given a document and a chunk, provide brief context.
<|end|>
<|user|>
Document:
{truncated_doc}
Chunk:
{chunk}
Provide a short context (2-3 sentences) that summarizes what this document is about and what user preferences or key facts are mentioned. Focus on first-person statements.
<|end|>
<|assistant|>
"#
    );
    let mut session_params = SessionParams::default();
    session_params.n_ctx = 4096;
    session_params.n_batch = 512;
    // Only override the micro-batch size when the defaults left it unset.
    if session_params.n_ubatch == 0 {
        session_params.n_ubatch = 512;
    }
    let mut session = model
        .create_session(session_params)
        .map_err(|e| anyhow!("Failed to create session: {}", e))?;
    let tokens = model
        .tokenize_bytes(prompt.as_bytes(), true, true)
        .map_err(|e| anyhow!("Failed to tokenize: {}", e))?;
    session
        .advance_context_with_tokens(&tokens)
        .map_err(|e| anyhow!("Failed to prime context: {}", e))?;
    let handle = session
        .start_completing_with(StandardSampler::default(), 200)
        .map_err(|e| anyhow!("Failed to start completion: {}", e))?;
    // The completion handle is consumed asynchronously, so drive it to the end
    // on a runtime created just for this call.
    let runtime = Runtime::new().map_err(|e| anyhow!("Failed to create runtime: {}", e))?;
    let generated = runtime.block_on(async { handle.into_string_async().await });
    Ok(generated.trim().to_string())
}
}
/// Prepends each chunk's generated context as a `[Context: ...]` header.
///
/// The result always has one entry per element of `chunks`, in the same order.
/// A chunk whose context is empty, or that has no context at all because
/// `contexts` is shorter than `chunks`, is returned unchanged. Surplus entries
/// in `contexts` are ignored. `_document` is currently unused.
pub fn apply_contextual_prefixes(
    _document: &str,
    chunks: &[String],
    contexts: &[String],
) -> Vec<String> {
    chunks
        .iter()
        .enumerate()
        // Look contexts up by index rather than zipping: `zip` would silently
        // drop every chunk past the end of a shorter `contexts` slice.
        .map(|(index, chunk)| match contexts.get(index) {
            Some(context) if !context.is_empty() => {
                format!("[Context: {}]\n\n{}", context, chunk)
            }
            _ => chunk.clone(),
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-empty contexts are prepended as a `[Context: ...]` header to their chunk.
    #[test]
    fn test_apply_contextual_prefixes() {
        let chunks: Vec<String> = ["I like basil", "I grow tomatoes"]
            .iter()
            .map(|text| text.to_string())
            .collect();
        let contexts: Vec<String> = [
            "User discusses their herb preferences",
            "User mentions their garden",
        ]
        .iter()
        .map(|text| text.to_string())
        .collect();

        let prefixed = apply_contextual_prefixes("A conversation about cooking", &chunks, &contexts);

        assert_eq!(prefixed.len(), 2);
        let (first, second) = (&prefixed[0], &prefixed[1]);
        assert!(first.contains("[Context:"));
        assert!(first.contains("I like basil"));
        assert!(second.contains("User mentions their garden"));
    }

    /// An empty context leaves its chunk untouched.
    #[test]
    fn test_apply_contextual_prefixes_empty_context() {
        let chunks = vec![String::from("Some text")];
        let contexts = vec![String::new()];

        let prefixed = apply_contextual_prefixes("A document", &chunks, &contexts);

        assert_eq!(prefixed[0], "Some text");
    }
}