// memvid_cli/openai_reranker.rs

//! OpenAI Reranker Provider
//!
//! This module provides a `Reranker` implementation that uses OpenAI's
//! GPT models to score and rerank search results for improved relevance.
//!
//! ## Environment Variables
//! - `OPENAI_API_KEY`: Required API key for OpenAI
//! - `OPENAI_RERANK_MODEL`: Optional model override (default: gpt-4o-mini)
//!
//! ## Features
//! - Uses structured prompting to score relevance
//! - Batch processing in prompt-sized chunks
//! - Automatic retry with exponential backoff on rate limits
//! - Thread-safe for concurrent use

use anyhow::{anyhow, bail, Result};
use memvid_core::{Reranker, RerankerConfig, RerankerDocument, RerankerResult};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tracing::{debug, info, warn};

/// OpenAI chat completions API endpoint
const OPENAI_CHAT_URL: &str = "https://api.openai.com/v1/chat/completions";

/// Default model for reranking (fast and cost-effective)
const DEFAULT_RERANK_MODEL: &str = "gpt-4o-mini";

/// Request timeout applied to every HTTP call made by the blocking client
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum documents to rerank in a single prompt; larger candidate sets are
/// split into chunks of this size, one API call per chunk
const MAX_DOCS_PER_PROMPT: usize = 20;

/// OpenAI chat request payload.
///
/// Borrows the model name and message contents so serializing the request
/// body does not copy the (potentially large) prompt.
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    // Set to 0.0 by `call_openai` for deterministic scoring.
    temperature: f32,
    max_tokens: usize,
}

/// A single chat message in the request ("system" or "user" role).
#[derive(Debug, Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}

/// OpenAI chat response (only the fields this module reads).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
    usage: ChatUsage,
}

/// One completion choice; only the message content is used.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}

/// The assistant message inside a choice; `content` carries the JSON scores.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}

/// Token accounting returned by the API. Only `total_tokens` is logged; the
/// per-phase counts are kept to match the response schema.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    #[allow(dead_code)]
    prompt_tokens: usize,
    #[allow(dead_code)]
    completion_tokens: usize,
    total_tokens: usize,
}

/// OpenAI error response envelope (`{"error": {...}}`).
#[derive(Debug, Deserialize)]
struct OpenAIErrorResponse {
    error: OpenAIError,
}

/// Error detail from OpenAI; `type` is renamed because it is a Rust keyword.
#[derive(Debug, Deserialize)]
struct OpenAIError {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
}

/// Parsed relevance score from the LLM response: a document id paired with a
/// 0.0–1.0 relevance score (the range requested by the prompt).
#[derive(Debug, Deserialize)]
struct RelevanceScore {
    id: u64,
    score: f32,
}

/// OpenAI Reranker Provider
///
/// Uses GPT models to evaluate query-document relevance for improved ranking.
#[derive(Clone)]
pub struct OpenAIReranker {
    // Bearer token for the API; intentionally omitted from the Debug impl.
    api_key: String,
    // Chat model name, e.g. "gpt-4o-mini".
    model: String,
    // Candidate/top-k limits and the minimum score threshold.
    config: RerankerConfig,
    // Blocking HTTP client configured with REQUEST_TIMEOUT.
    client: Client,
    // Readiness flag set by `init`; Arc-shared so clones see the same state.
    ready: std::sync::Arc<AtomicBool>,
}

// Manual Debug impl: deliberately leaves out `api_key` so the secret can
// never leak into logs or error output.
impl std::fmt::Debug for OpenAIReranker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIReranker")
            .field("model", &self.model)
            .field("max_candidates", &self.config.max_candidates)
            .field("ready", &self.ready.load(Ordering::Relaxed))
            .finish()
    }
}

119impl OpenAIReranker {
120    /// Create a new OpenAI reranker
121    ///
122    /// # Arguments
123    /// * `api_key` - OpenAI API key
124    /// * `model` - Model to use (e.g., "gpt-4o-mini", "gpt-4o")
125    /// * `config` - Reranker configuration
126    pub fn new(api_key: String, model: Option<String>, config: RerankerConfig) -> Result<Self> {
127        if api_key.is_empty() {
128            bail!("OpenAI API key cannot be empty");
129        }
130
131        let client = Client::builder()
132            .timeout(REQUEST_TIMEOUT)
133            .build()
134            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
135
136        Ok(Self {
137            api_key,
138            model: model.unwrap_or_else(|| DEFAULT_RERANK_MODEL.to_string()),
139            config,
140            client,
141            ready: std::sync::Arc::new(AtomicBool::new(false)),
142        })
143    }
144
145    /// Create reranker from environment variables
146    pub fn from_env() -> Result<Self> {
147        let api_key = std::env::var("OPENAI_API_KEY")
148            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
149
150        let model = std::env::var("OPENAI_RERANK_MODEL").ok();
151
152        Self::new(api_key, model, RerankerConfig::default())
153    }
154
155    /// Create reranker with high precision config
156    pub fn high_precision(api_key: String) -> Result<Self> {
157        Self::new(api_key, Some("gpt-4o".to_string()), RerankerConfig::high_precision())
158    }
159
160    /// Create reranker with high recall config
161    pub fn high_recall(api_key: String) -> Result<Self> {
162        Self::new(api_key, None, RerankerConfig::high_recall())
163    }
164
165    /// Build the reranking prompt
166    fn build_prompt(&self, query: &str, documents: &[&RerankerDocument]) -> String {
167        let mut prompt = format!(
168            r#"You are a relevance scoring assistant. Given a query and a list of documents, score each document's relevance to the query on a scale of 0.0 to 1.0.
169
170Query: "{}"
171
172Documents:
173"#,
174            query
175        );
176
177        for (idx, doc) in documents.iter().enumerate() {
178            let preview = if doc.text.len() > 500 {
179                format!("{}...", &doc.text[..500])
180            } else {
181                doc.text.clone()
182            };
183            prompt.push_str(&format!(
184                "\n[{}] ID={}: {}\n",
185                idx + 1,
186                doc.id,
187                preview
188            ));
189        }
190
191        prompt.push_str(
192            r#"
193Return a JSON array of objects with "id" and "score" fields for each document.
194Score based on semantic relevance, not just keyword matching.
195Consider:
196- Direct answers to the query
197- Related context that helps answer the query
198- Factual relevance even if wording differs
199
200Output format (JSON only, no explanation):
201[{"id": 123, "score": 0.95}, {"id": 456, "score": 0.72}, ...]
202"#,
203        );
204
205        prompt
206    }
207
208    /// Parse relevance scores from LLM response
209    fn parse_scores(&self, response: &str) -> Result<Vec<RelevanceScore>> {
210        // Try to find JSON array in response
211        let json_start = response.find('[').ok_or_else(|| anyhow!("No JSON array found"))?;
212        let json_end = response.rfind(']').ok_or_else(|| anyhow!("No JSON array end found"))?;
213
214        let json_str = &response[json_start..=json_end];
215        let scores: Vec<RelevanceScore> = serde_json::from_str(json_str)
216            .map_err(|e| anyhow!("Failed to parse scores: {} from: {}", e, json_str))?;
217
218        Ok(scores)
219    }
220
221    /// Call OpenAI chat API
222    fn call_openai(&self, prompt: &str) -> Result<String> {
223        let messages = vec![
224            ChatMessage {
225                role: "system",
226                content: "You are a document relevance scoring assistant. Output only valid JSON.",
227            },
228            ChatMessage {
229                role: "user",
230                content: prompt,
231            },
232        ];
233
234        let request = ChatRequest {
235            model: &self.model,
236            messages,
237            temperature: 0.0,
238            max_tokens: 1024,
239        };
240
241        let response = self
242            .client
243            .post(OPENAI_CHAT_URL)
244            .header("Authorization", format!("Bearer {}", self.api_key))
245            .header("Content-Type", "application/json")
246            .json(&request)
247            .send()
248            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
249
250        let status = response.status();
251        let body = response
252            .text()
253            .map_err(|e| anyhow!("Failed to read response body: {}", e))?;
254
255        if !status.is_success() {
256            if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
257                bail!(
258                    "OpenAI API error ({}): {}",
259                    error_response.error.error_type,
260                    error_response.error.message
261                );
262            }
263            bail!("OpenAI API request failed with status {}: {}", status, body);
264        }
265
266        let chat_response: ChatResponse = serde_json::from_str(&body)
267            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
268
269        let content = chat_response
270            .choices
271            .first()
272            .map(|c| c.message.content.clone())
273            .ok_or_else(|| anyhow!("No response content"))?;
274
275        debug!(
276            "OpenAI rerank: {} tokens used, model={}",
277            chat_response.usage.total_tokens, self.model
278        );
279
280        Ok(content)
281    }
282
283    /// Rerank with retry logic
284    fn rerank_with_retry(
285        &self,
286        query: &str,
287        documents: &[&RerankerDocument],
288        max_retries: usize,
289    ) -> Result<Vec<RelevanceScore>> {
290        let prompt = self.build_prompt(query, documents);
291        let mut last_error = None;
292
293        for attempt in 0..max_retries {
294            match self.call_openai(&prompt) {
295                Ok(response) => match self.parse_scores(&response) {
296                    Ok(scores) => return Ok(scores),
297                    Err(e) => {
298                        warn!("Failed to parse scores (attempt {}): {}", attempt + 1, e);
299                        last_error = Some(e);
300                    }
301                },
302                Err(e) => {
303                    let error_str = e.to_string();
304                    if error_str.contains("rate_limit") || error_str.contains("429") {
305                        let backoff = Duration::from_millis(1000 * (1 << attempt));
306                        warn!(
307                            "Rate limited, retrying in {:?} (attempt {}/{})",
308                            backoff,
309                            attempt + 1,
310                            max_retries
311                        );
312                        std::thread::sleep(backoff);
313                        last_error = Some(e);
314                        continue;
315                    }
316                    return Err(e);
317                }
318            }
319        }
320
321        Err(last_error.unwrap_or_else(|| anyhow!("Failed after {} retries", max_retries)))
322    }
323}
324
325impl Reranker for OpenAIReranker {
326    fn kind(&self) -> &str {
327        "openai"
328    }
329
330    fn rerank(
331        &self,
332        query: &str,
333        documents: &[RerankerDocument],
334        top_k: usize,
335    ) -> memvid_core::Result<Vec<RerankerResult>> {
336        if documents.is_empty() {
337            return Ok(Vec::new());
338        }
339
340        // Limit candidates
341        let max_candidates = self.config.max_candidates.min(documents.len());
342        let candidates: Vec<&RerankerDocument> = documents.iter().take(max_candidates).collect();
343
344        // Process in batches if needed
345        let mut all_scores: Vec<RelevanceScore> = Vec::new();
346
347        for chunk in candidates.chunks(MAX_DOCS_PER_PROMPT) {
348            let scores = self
349                .rerank_with_retry(query, chunk, 3)
350                .map_err(|e| memvid_core::MemvidError::RerankFailed {
351                    reason: e.to_string().into_boxed_str(),
352                })?;
353            all_scores.extend(scores);
354        }
355
356        // Build results with original ranks
357        let mut results: Vec<RerankerResult> = all_scores
358            .into_iter()
359            .filter_map(|score| {
360                let original_rank = documents.iter().position(|d| d.id == score.id)?;
361                if score.score < self.config.min_score {
362                    return None;
363                }
364                Some(RerankerResult {
365                    id: score.id,
366                    score: score.score,
367                    original_rank: original_rank + 1,
368                    new_rank: 0, // Will be set after sorting
369                })
370            })
371            .collect();
372
373        // Sort by score descending
374        results.sort_by(|a, b| {
375            b.score
376                .partial_cmp(&a.score)
377                .unwrap_or(std::cmp::Ordering::Equal)
378        });
379
380        // Assign new ranks and limit to top_k
381        let top_k = top_k.min(self.config.top_k);
382        for (idx, result) in results.iter_mut().enumerate() {
383            result.new_rank = idx + 1;
384        }
385
386        Ok(results.into_iter().take(top_k).collect())
387    }
388
389    fn is_ready(&self) -> bool {
390        self.ready.load(Ordering::Relaxed)
391    }
392
393    fn init(&mut self) -> memvid_core::Result<()> {
394        info!("Initializing OpenAI reranker with model: {}", self.model);
395
396        // Test with a simple rerank to validate API key
397        let test_docs = vec![RerankerDocument::new(0, "Test document")];
398        let _ = self
399            .rerank_with_retry("test query", &[&test_docs[0]], 1)
400            .map_err(|e| memvid_core::MemvidError::RerankFailed {
401                reason: format!("Failed to initialize reranker: {}", e).into_boxed_str(),
402            })?;
403
404        info!("OpenAI reranker initialized successfully");
405        self.ready.store(true, Ordering::Relaxed);
406        Ok(())
407    }
408}
409
410/// Helper to create an OpenAI reranker or return None
411pub fn try_openai_reranker() -> Option<OpenAIReranker> {
412    match OpenAIReranker::from_env() {
413        Ok(reranker) => {
414            info!("OpenAI reranker available");
415            Some(reranker)
416        }
417        Err(e) => {
418            debug!("OpenAI reranker not available: {}", e);
419            None
420        }
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    // The constructor must reject an empty API key outright.
    #[test]
    fn test_empty_api_key() {
        let result = OpenAIReranker::new(String::new(), None, RerankerConfig::default());
        assert!(result.is_err());
    }

    // The prompt should embed the query plus every document with its ID and text.
    #[test]
    fn test_build_prompt() {
        let reranker = OpenAIReranker::new(
            "test-key".to_string(),
            None,
            RerankerConfig::default(),
        )
        .unwrap();

        let docs = vec![
            RerankerDocument::new(1, "First document about Rust"),
            RerankerDocument::new(2, "Second document about Python"),
        ];

        let doc_refs: Vec<&RerankerDocument> = docs.iter().collect();
        let prompt = reranker.build_prompt("What is Rust?", &doc_refs);

        assert!(prompt.contains("What is Rust?"));
        assert!(prompt.contains("ID=1"));
        assert!(prompt.contains("ID=2"));
        assert!(prompt.contains("First document"));
        assert!(prompt.contains("Second document"));
    }

    // Scores must be extracted even when the model wraps the JSON in prose.
    #[test]
    fn test_parse_scores() {
        let reranker = OpenAIReranker::new(
            "test-key".to_string(),
            None,
            RerankerConfig::default(),
        )
        .unwrap();

        let response = r#"Here are the scores:
[{"id": 1, "score": 0.95}, {"id": 2, "score": 0.42}]"#;

        let scores = reranker.parse_scores(response).unwrap();
        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].id, 1);
        assert!((scores[0].score - 0.95).abs() < 0.01);
        assert_eq!(scores[1].id, 2);
        assert!((scores[1].score - 0.42).abs() < 0.01);
    }

    // Live end-to-end rerank; run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore] // Requires valid API key
    fn test_real_rerank() {
        let reranker = OpenAIReranker::from_env().expect("OPENAI_API_KEY must be set");

        let docs = vec![
            RerankerDocument::new(1, "Rust is a systems programming language focused on safety."),
            RerankerDocument::new(2, "Python is great for data science and machine learning."),
            RerankerDocument::new(3, "Rust provides memory safety without garbage collection."),
        ];

        let results = reranker.rerank("What makes Rust safe?", &docs, 2).unwrap();
        assert!(!results.is_empty());
        // Document about Rust safety should rank higher
        assert!(results[0].id == 1 || results[0].id == 3);
    }
}