1use anyhow::{anyhow, bail, Result};
17use memvid_core::{Reranker, RerankerConfig, RerankerDocument, RerankerResult};
18use reqwest::blocking::Client;
19use serde::{Deserialize, Serialize};
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::time::Duration;
22use tracing::{debug, info, warn};
23
/// OpenAI chat completions endpoint used for relevance scoring.
const OPENAI_CHAT_URL: &str = "https://api.openai.com/v1/chat/completions";

/// Model used when the caller does not supply one.
const DEFAULT_RERANK_MODEL: &str = "gpt-4o-mini";

/// Per-request timeout applied to the blocking HTTP client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum number of documents placed into a single scoring prompt;
/// larger candidate sets are split into batches of this size.
const MAX_DOCS_PER_PROMPT: usize = 20;
35
/// Request body serialized to the chat completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
    /// Model identifier (e.g. "gpt-4o-mini").
    model: &'a str,
    /// System + user messages that make up the scoring prompt.
    messages: Vec<ChatMessage<'a>>,
    /// Sampling temperature; set to 0.0 here for deterministic scoring.
    temperature: f32,
    /// Upper bound on completion length.
    max_tokens: usize,
}
44
/// A single chat message ("system" or "user") in the request payload.
#[derive(Debug, Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
50
51#[derive(Debug, Deserialize)]
53struct ChatResponse {
54 choices: Vec<ChatChoice>,
55 usage: ChatUsage,
56}
57
/// One completion choice; we only ever use the first.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
62
/// Assistant message content returned inside a choice.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
67
/// Token accounting block from the API; only `total_tokens` is logged.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    #[allow(dead_code)]
    prompt_tokens: usize,
    #[allow(dead_code)]
    completion_tokens: usize,
    total_tokens: usize,
}
76
77#[derive(Debug, Deserialize)]
79struct OpenAIErrorResponse {
80 error: OpenAIError,
81}
82
/// Error details (message + machine-readable type) from the API.
#[derive(Debug, Deserialize)]
struct OpenAIError {
    message: String,
    // The wire field is named "type", which is a Rust keyword.
    #[serde(rename = "type")]
    error_type: String,
}
89
/// One (document id, relevance score) pair parsed from the model's JSON reply.
#[derive(Debug, Deserialize)]
struct RelevanceScore {
    id: u64,
    score: f32,
}
96
/// Reranker backed by the OpenAI chat completions API.
///
/// Cloning is cheap-ish: the `reqwest` client is internally pooled and the
/// `ready` flag is shared via `Arc`, so clones observe the same readiness.
#[derive(Clone)]
pub struct OpenAIReranker {
    // Never logged or exposed; the manual Debug impl below omits it.
    api_key: String,
    model: String,
    config: RerankerConfig,
    client: Client,
    // Set to true after init() completes a successful probe request.
    ready: std::sync::Arc<AtomicBool>,
}
108
109impl std::fmt::Debug for OpenAIReranker {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 f.debug_struct("OpenAIReranker")
112 .field("model", &self.model)
113 .field("max_candidates", &self.config.max_candidates)
114 .field("ready", &self.ready.load(Ordering::Relaxed))
115 .finish()
116 }
117}
118
119impl OpenAIReranker {
120 pub fn new(api_key: String, model: Option<String>, config: RerankerConfig) -> Result<Self> {
127 if api_key.is_empty() {
128 bail!("OpenAI API key cannot be empty");
129 }
130
131 let client = Client::builder()
132 .timeout(REQUEST_TIMEOUT)
133 .build()
134 .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
135
136 Ok(Self {
137 api_key,
138 model: model.unwrap_or_else(|| DEFAULT_RERANK_MODEL.to_string()),
139 config,
140 client,
141 ready: std::sync::Arc::new(AtomicBool::new(false)),
142 })
143 }
144
145 pub fn from_env() -> Result<Self> {
147 let api_key = std::env::var("OPENAI_API_KEY")
148 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
149
150 let model = std::env::var("OPENAI_RERANK_MODEL").ok();
151
152 Self::new(api_key, model, RerankerConfig::default())
153 }
154
155 pub fn high_precision(api_key: String) -> Result<Self> {
157 Self::new(api_key, Some("gpt-4o".to_string()), RerankerConfig::high_precision())
158 }
159
160 pub fn high_recall(api_key: String) -> Result<Self> {
162 Self::new(api_key, None, RerankerConfig::high_recall())
163 }
164
165 fn build_prompt(&self, query: &str, documents: &[&RerankerDocument]) -> String {
167 let mut prompt = format!(
168 r#"You are a relevance scoring assistant. Given a query and a list of documents, score each document's relevance to the query on a scale of 0.0 to 1.0.
169
170Query: "{}"
171
172Documents:
173"#,
174 query
175 );
176
177 for (idx, doc) in documents.iter().enumerate() {
178 let preview = if doc.text.len() > 500 {
179 format!("{}...", &doc.text[..500])
180 } else {
181 doc.text.clone()
182 };
183 prompt.push_str(&format!(
184 "\n[{}] ID={}: {}\n",
185 idx + 1,
186 doc.id,
187 preview
188 ));
189 }
190
191 prompt.push_str(
192 r#"
193Return a JSON array of objects with "id" and "score" fields for each document.
194Score based on semantic relevance, not just keyword matching.
195Consider:
196- Direct answers to the query
197- Related context that helps answer the query
198- Factual relevance even if wording differs
199
200Output format (JSON only, no explanation):
201[{"id": 123, "score": 0.95}, {"id": 456, "score": 0.72}, ...]
202"#,
203 );
204
205 prompt
206 }
207
208 fn parse_scores(&self, response: &str) -> Result<Vec<RelevanceScore>> {
210 let json_start = response.find('[').ok_or_else(|| anyhow!("No JSON array found"))?;
212 let json_end = response.rfind(']').ok_or_else(|| anyhow!("No JSON array end found"))?;
213
214 let json_str = &response[json_start..=json_end];
215 let scores: Vec<RelevanceScore> = serde_json::from_str(json_str)
216 .map_err(|e| anyhow!("Failed to parse scores: {} from: {}", e, json_str))?;
217
218 Ok(scores)
219 }
220
221 fn call_openai(&self, prompt: &str) -> Result<String> {
223 let messages = vec![
224 ChatMessage {
225 role: "system",
226 content: "You are a document relevance scoring assistant. Output only valid JSON.",
227 },
228 ChatMessage {
229 role: "user",
230 content: prompt,
231 },
232 ];
233
234 let request = ChatRequest {
235 model: &self.model,
236 messages,
237 temperature: 0.0,
238 max_tokens: 1024,
239 };
240
241 let response = self
242 .client
243 .post(OPENAI_CHAT_URL)
244 .header("Authorization", format!("Bearer {}", self.api_key))
245 .header("Content-Type", "application/json")
246 .json(&request)
247 .send()
248 .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
249
250 let status = response.status();
251 let body = response
252 .text()
253 .map_err(|e| anyhow!("Failed to read response body: {}", e))?;
254
255 if !status.is_success() {
256 if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
257 bail!(
258 "OpenAI API error ({}): {}",
259 error_response.error.error_type,
260 error_response.error.message
261 );
262 }
263 bail!("OpenAI API request failed with status {}: {}", status, body);
264 }
265
266 let chat_response: ChatResponse = serde_json::from_str(&body)
267 .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
268
269 let content = chat_response
270 .choices
271 .first()
272 .map(|c| c.message.content.clone())
273 .ok_or_else(|| anyhow!("No response content"))?;
274
275 debug!(
276 "OpenAI rerank: {} tokens used, model={}",
277 chat_response.usage.total_tokens, self.model
278 );
279
280 Ok(content)
281 }
282
283 fn rerank_with_retry(
285 &self,
286 query: &str,
287 documents: &[&RerankerDocument],
288 max_retries: usize,
289 ) -> Result<Vec<RelevanceScore>> {
290 let prompt = self.build_prompt(query, documents);
291 let mut last_error = None;
292
293 for attempt in 0..max_retries {
294 match self.call_openai(&prompt) {
295 Ok(response) => match self.parse_scores(&response) {
296 Ok(scores) => return Ok(scores),
297 Err(e) => {
298 warn!("Failed to parse scores (attempt {}): {}", attempt + 1, e);
299 last_error = Some(e);
300 }
301 },
302 Err(e) => {
303 let error_str = e.to_string();
304 if error_str.contains("rate_limit") || error_str.contains("429") {
305 let backoff = Duration::from_millis(1000 * (1 << attempt));
306 warn!(
307 "Rate limited, retrying in {:?} (attempt {}/{})",
308 backoff,
309 attempt + 1,
310 max_retries
311 );
312 std::thread::sleep(backoff);
313 last_error = Some(e);
314 continue;
315 }
316 return Err(e);
317 }
318 }
319 }
320
321 Err(last_error.unwrap_or_else(|| anyhow!("Failed after {} retries", max_retries)))
322 }
323}
324
325impl Reranker for OpenAIReranker {
326 fn kind(&self) -> &str {
327 "openai"
328 }
329
330 fn rerank(
331 &self,
332 query: &str,
333 documents: &[RerankerDocument],
334 top_k: usize,
335 ) -> memvid_core::Result<Vec<RerankerResult>> {
336 if documents.is_empty() {
337 return Ok(Vec::new());
338 }
339
340 let max_candidates = self.config.max_candidates.min(documents.len());
342 let candidates: Vec<&RerankerDocument> = documents.iter().take(max_candidates).collect();
343
344 let mut all_scores: Vec<RelevanceScore> = Vec::new();
346
347 for chunk in candidates.chunks(MAX_DOCS_PER_PROMPT) {
348 let scores = self
349 .rerank_with_retry(query, chunk, 3)
350 .map_err(|e| memvid_core::MemvidError::RerankFailed {
351 reason: e.to_string().into_boxed_str(),
352 })?;
353 all_scores.extend(scores);
354 }
355
356 let mut results: Vec<RerankerResult> = all_scores
358 .into_iter()
359 .filter_map(|score| {
360 let original_rank = documents.iter().position(|d| d.id == score.id)?;
361 if score.score < self.config.min_score {
362 return None;
363 }
364 Some(RerankerResult {
365 id: score.id,
366 score: score.score,
367 original_rank: original_rank + 1,
368 new_rank: 0, })
370 })
371 .collect();
372
373 results.sort_by(|a, b| {
375 b.score
376 .partial_cmp(&a.score)
377 .unwrap_or(std::cmp::Ordering::Equal)
378 });
379
380 let top_k = top_k.min(self.config.top_k);
382 for (idx, result) in results.iter_mut().enumerate() {
383 result.new_rank = idx + 1;
384 }
385
386 Ok(results.into_iter().take(top_k).collect())
387 }
388
389 fn is_ready(&self) -> bool {
390 self.ready.load(Ordering::Relaxed)
391 }
392
393 fn init(&mut self) -> memvid_core::Result<()> {
394 info!("Initializing OpenAI reranker with model: {}", self.model);
395
396 let test_docs = vec![RerankerDocument::new(0, "Test document")];
398 let _ = self
399 .rerank_with_retry("test query", &[&test_docs[0]], 1)
400 .map_err(|e| memvid_core::MemvidError::RerankFailed {
401 reason: format!("Failed to initialize reranker: {}", e).into_boxed_str(),
402 })?;
403
404 info!("OpenAI reranker initialized successfully");
405 self.ready.store(true, Ordering::Relaxed);
406 Ok(())
407 }
408}
409
410pub fn try_openai_reranker() -> Option<OpenAIReranker> {
412 match OpenAIReranker::from_env() {
413 Ok(reranker) => {
414 info!("OpenAI reranker available");
415 Some(reranker)
416 }
417 Err(e) => {
418 debug!("OpenAI reranker not available: {}", e);
419 None
420 }
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared constructor for offline tests (the key is never used).
    fn dummy_reranker() -> OpenAIReranker {
        OpenAIReranker::new("test-key".to_string(), None, RerankerConfig::default()).unwrap()
    }

    #[test]
    fn test_empty_api_key() {
        assert!(OpenAIReranker::new(String::new(), None, RerankerConfig::default()).is_err());
    }

    #[test]
    fn test_build_prompt() {
        let reranker = dummy_reranker();

        let docs = vec![
            RerankerDocument::new(1, "First document about Rust"),
            RerankerDocument::new(2, "Second document about Python"),
        ];
        let doc_refs: Vec<&RerankerDocument> = docs.iter().collect();

        let prompt = reranker.build_prompt("What is Rust?", &doc_refs);

        for needle in ["What is Rust?", "ID=1", "ID=2", "First document", "Second document"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn test_parse_scores() {
        let reranker = dummy_reranker();

        let response = r#"Here are the scores:
[{"id": 1, "score": 0.95}, {"id": 2, "score": 0.42}]"#;

        let scores = reranker.parse_scores(response).unwrap();
        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].id, 1);
        assert!((scores[0].score - 0.95).abs() < 0.01);
        assert_eq!(scores[1].id, 2);
        assert!((scores[1].score - 0.42).abs() < 0.01);
    }

    /// Hits the live API; run with `cargo test -- --ignored` and a real key.
    #[test]
    #[ignore]
    fn test_real_rerank() {
        let reranker = OpenAIReranker::from_env().expect("OPENAI_API_KEY must be set");

        let docs = vec![
            RerankerDocument::new(1, "Rust is a systems programming language focused on safety."),
            RerankerDocument::new(2, "Python is great for data science and machine learning."),
            RerankerDocument::new(3, "Rust provides memory safety without garbage collection."),
        ];

        let results = reranker.rerank("What makes Rust safe?", &docs, 2).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].id == 1 || results[0].id == 3);
    }
}