memvid_cli/
openai_reranker.rs

1//! OpenAI Reranker Provider
2//!
3//! This module provides a `Reranker` implementation that uses OpenAI's
4//! GPT models to score and rerank search results for improved relevance.
5//!
6//! ## Environment Variables
7//! - `OPENAI_API_KEY`: Required API key for OpenAI
8//! - `OPENAI_RERANK_MODEL`: Optional model override (default: gpt-4o-mini)
9//!
10//! ## Features
11//! - Uses structured prompting to score relevance
//! - Batch processing in chunks of up to 20 documents per prompt
13//! - Automatic retry with exponential backoff
14//! - Thread-safe for concurrent use
15
16use anyhow::{anyhow, bail, Result};
17use memvid_core::{Reranker, RerankerConfig, RerankerDocument, RerankerResult};
18use reqwest::blocking::Client;
19use serde::{Deserialize, Serialize};
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::time::Duration;
22use tracing::{debug, info, warn};
23
24/// OpenAI chat completions API endpoint
25const OPENAI_CHAT_URL: &str = "https://api.openai.com/v1/chat/completions";
26
27/// Default model for reranking (fast and cost-effective)
28const DEFAULT_RERANK_MODEL: &str = "gpt-4o-mini";
29
30/// Request timeout for reranking
31const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
32
33/// Maximum documents to rerank in a single prompt
34const MAX_DOCS_PER_PROMPT: usize = 20;
35
36/// OpenAI chat request payload
37#[derive(Debug, Serialize)]
38struct ChatRequest<'a> {
39    model: &'a str,
40    messages: Vec<ChatMessage<'a>>,
41    temperature: f32,
42    max_tokens: usize,
43}
44
/// A single chat message in the request (role is "system" or "user" here).
#[derive(Debug, Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
50
/// OpenAI chat response
///
/// Minimal projection of the completions response: only the fields this
/// module reads are deserialized; unknown fields are ignored by serde.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion candidates; only the first choice is used.
    choices: Vec<ChatChoice>,
    /// Token accounting reported by the API (logged for diagnostics).
    usage: ChatUsage,
}
57
/// One completion candidate within a `ChatResponse`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}
62
/// The assistant message inside a choice; `content` carries the model's text.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
67
/// Token usage block of the completions response.
///
/// Only `total_tokens` is read (for debug logging); the per-direction counts
/// are kept for completeness of the wire format but are currently unused.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    #[allow(dead_code)]
    prompt_tokens: usize,
    #[allow(dead_code)]
    completion_tokens: usize,
    total_tokens: usize,
}
76
77/// OpenAI error response
78#[derive(Debug, Deserialize)]
79struct OpenAIErrorResponse {
80    error: OpenAIError,
81}
82
/// Error detail from the OpenAI API ("type" is renamed to avoid the keyword).
#[derive(Debug, Deserialize)]
struct OpenAIError {
    /// Human-readable error description.
    message: String,
    /// Machine-readable category (e.g. "invalid_request_error").
    #[serde(rename = "type")]
    error_type: String,
}
89
/// Parsed relevance score from LLM response
#[derive(Debug, Deserialize)]
struct RelevanceScore {
    /// Document id echoed back by the model (matches `RerankerDocument` ids in the prompt).
    id: u64,
    /// Relevance score; the prompt requests 0.0–1.0, though the model's
    /// output is not clamped here.
    score: f32,
}
96
/// OpenAI Reranker Provider
///
/// Uses GPT models to evaluate query-document relevance for improved ranking.
#[derive(Clone)]
pub struct OpenAIReranker {
    /// Bearer token sent on every request; intentionally omitted from `Debug`.
    api_key: String,
    /// Chat model used for scoring (e.g. "gpt-4o-mini").
    model: String,
    /// Candidate caps and score thresholds applied during reranking.
    config: RerankerConfig,
    /// Blocking HTTP client built with a fixed request timeout.
    client: Client,
    /// Set to true once `init()` validates the API key; shared across clones.
    ready: std::sync::Arc<AtomicBool>,
}
108
109impl std::fmt::Debug for OpenAIReranker {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        f.debug_struct("OpenAIReranker")
112            .field("model", &self.model)
113            .field("max_candidates", &self.config.max_candidates)
114            .field("ready", &self.ready.load(Ordering::Relaxed))
115            .finish()
116    }
117}
118
119impl OpenAIReranker {
120    /// Create a new OpenAI reranker
121    ///
122    /// # Arguments
123    /// * `api_key` - OpenAI API key
124    /// * `model` - Model to use (e.g., "gpt-4o-mini", "gpt-4o")
125    /// * `config` - Reranker configuration
126    pub fn new(api_key: String, model: Option<String>, config: RerankerConfig) -> Result<Self> {
127        if api_key.is_empty() {
128            bail!("OpenAI API key cannot be empty");
129        }
130
131        let client = Client::builder()
132            .timeout(REQUEST_TIMEOUT)
133            .build()
134            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
135
136        Ok(Self {
137            api_key,
138            model: model.unwrap_or_else(|| DEFAULT_RERANK_MODEL.to_string()),
139            config,
140            client,
141            ready: std::sync::Arc::new(AtomicBool::new(false)),
142        })
143    }
144
145    /// Create reranker from environment variables
146    pub fn from_env() -> Result<Self> {
147        let api_key = std::env::var("OPENAI_API_KEY")
148            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
149
150        let model = std::env::var("OPENAI_RERANK_MODEL").ok();
151
152        Self::new(api_key, model, RerankerConfig::default())
153    }
154
155    /// Create reranker with high precision config
156    pub fn high_precision(api_key: String) -> Result<Self> {
157        Self::new(
158            api_key,
159            Some("gpt-4o".to_string()),
160            RerankerConfig::high_precision(),
161        )
162    }
163
164    /// Create reranker with high recall config
165    pub fn high_recall(api_key: String) -> Result<Self> {
166        Self::new(api_key, None, RerankerConfig::high_recall())
167    }
168
169    /// Build the reranking prompt
170    fn build_prompt(&self, query: &str, documents: &[&RerankerDocument]) -> String {
171        let mut prompt = format!(
172            r#"You are a relevance scoring assistant. Given a query and a list of documents, score each document's relevance to the query on a scale of 0.0 to 1.0.
173
174Query: "{}"
175
176Documents:
177"#,
178            query
179        );
180
181        for (idx, doc) in documents.iter().enumerate() {
182            let preview = if doc.text.len() > 500 {
183                format!("{}...", &doc.text[..500])
184            } else {
185                doc.text.clone()
186            };
187            prompt.push_str(&format!("\n[{}] ID={}: {}\n", idx + 1, doc.id, preview));
188        }
189
190        prompt.push_str(
191            r#"
192Return a JSON array of objects with "id" and "score" fields for each document.
193Score based on semantic relevance, not just keyword matching.
194Consider:
195- Direct answers to the query
196- Related context that helps answer the query
197- Factual relevance even if wording differs
198
199Output format (JSON only, no explanation):
200[{"id": 123, "score": 0.95}, {"id": 456, "score": 0.72}, ...]
201"#,
202        );
203
204        prompt
205    }
206
207    /// Parse relevance scores from LLM response
208    fn parse_scores(&self, response: &str) -> Result<Vec<RelevanceScore>> {
209        // Try to find JSON array in response
210        let json_start = response
211            .find('[')
212            .ok_or_else(|| anyhow!("No JSON array found"))?;
213        let json_end = response
214            .rfind(']')
215            .ok_or_else(|| anyhow!("No JSON array end found"))?;
216
217        let json_str = &response[json_start..=json_end];
218        let scores: Vec<RelevanceScore> = serde_json::from_str(json_str)
219            .map_err(|e| anyhow!("Failed to parse scores: {} from: {}", e, json_str))?;
220
221        Ok(scores)
222    }
223
224    /// Call OpenAI chat API
225    fn call_openai(&self, prompt: &str) -> Result<String> {
226        let messages = vec![
227            ChatMessage {
228                role: "system",
229                content: "You are a document relevance scoring assistant. Output only valid JSON.",
230            },
231            ChatMessage {
232                role: "user",
233                content: prompt,
234            },
235        ];
236
237        let request = ChatRequest {
238            model: &self.model,
239            messages,
240            temperature: 0.0,
241            max_tokens: 1024,
242        };
243
244        let response = self
245            .client
246            .post(OPENAI_CHAT_URL)
247            .header("Authorization", format!("Bearer {}", self.api_key))
248            .header("Content-Type", "application/json")
249            .json(&request)
250            .send()
251            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
252
253        let status = response.status();
254        let body = response
255            .text()
256            .map_err(|e| anyhow!("Failed to read response body: {}", e))?;
257
258        if !status.is_success() {
259            if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
260                bail!(
261                    "OpenAI API error ({}): {}",
262                    error_response.error.error_type,
263                    error_response.error.message
264                );
265            }
266            bail!("OpenAI API request failed with status {}: {}", status, body);
267        }
268
269        let chat_response: ChatResponse = serde_json::from_str(&body)
270            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
271
272        let content = chat_response
273            .choices
274            .first()
275            .map(|c| c.message.content.clone())
276            .ok_or_else(|| anyhow!("No response content"))?;
277
278        debug!(
279            "OpenAI rerank: {} tokens used, model={}",
280            chat_response.usage.total_tokens, self.model
281        );
282
283        Ok(content)
284    }
285
286    /// Rerank with retry logic
287    fn rerank_with_retry(
288        &self,
289        query: &str,
290        documents: &[&RerankerDocument],
291        max_retries: usize,
292    ) -> Result<Vec<RelevanceScore>> {
293        let prompt = self.build_prompt(query, documents);
294        let mut last_error = None;
295
296        for attempt in 0..max_retries {
297            match self.call_openai(&prompt) {
298                Ok(response) => match self.parse_scores(&response) {
299                    Ok(scores) => return Ok(scores),
300                    Err(e) => {
301                        warn!("Failed to parse scores (attempt {}): {}", attempt + 1, e);
302                        last_error = Some(e);
303                    }
304                },
305                Err(e) => {
306                    let error_str = e.to_string();
307                    if error_str.contains("rate_limit") || error_str.contains("429") {
308                        let backoff = Duration::from_millis(1000 * (1 << attempt));
309                        warn!(
310                            "Rate limited, retrying in {:?} (attempt {}/{})",
311                            backoff,
312                            attempt + 1,
313                            max_retries
314                        );
315                        std::thread::sleep(backoff);
316                        last_error = Some(e);
317                        continue;
318                    }
319                    return Err(e);
320                }
321            }
322        }
323
324        Err(last_error.unwrap_or_else(|| anyhow!("Failed after {} retries", max_retries)))
325    }
326}
327
328impl Reranker for OpenAIReranker {
329    fn kind(&self) -> &str {
330        "openai"
331    }
332
333    fn rerank(
334        &self,
335        query: &str,
336        documents: &[RerankerDocument],
337        top_k: usize,
338    ) -> memvid_core::Result<Vec<RerankerResult>> {
339        if documents.is_empty() {
340            return Ok(Vec::new());
341        }
342
343        // Limit candidates
344        let max_candidates = self.config.max_candidates.min(documents.len());
345        let candidates: Vec<&RerankerDocument> = documents.iter().take(max_candidates).collect();
346
347        // Process in batches if needed
348        let mut all_scores: Vec<RelevanceScore> = Vec::new();
349
350        for chunk in candidates.chunks(MAX_DOCS_PER_PROMPT) {
351            let scores = self.rerank_with_retry(query, chunk, 3).map_err(|e| {
352                memvid_core::MemvidError::RerankFailed {
353                    reason: e.to_string().into_boxed_str(),
354                }
355            })?;
356            all_scores.extend(scores);
357        }
358
359        // Build results with original ranks
360        let mut results: Vec<RerankerResult> = all_scores
361            .into_iter()
362            .filter_map(|score| {
363                let original_rank = documents.iter().position(|d| d.id == score.id)?;
364                if score.score < self.config.min_score {
365                    return None;
366                }
367                Some(RerankerResult {
368                    id: score.id,
369                    score: score.score,
370                    original_rank: original_rank + 1,
371                    new_rank: 0, // Will be set after sorting
372                })
373            })
374            .collect();
375
376        // Sort by score descending
377        results.sort_by(|a, b| {
378            b.score
379                .partial_cmp(&a.score)
380                .unwrap_or(std::cmp::Ordering::Equal)
381        });
382
383        // Assign new ranks and limit to top_k
384        let top_k = top_k.min(self.config.top_k);
385        for (idx, result) in results.iter_mut().enumerate() {
386            result.new_rank = idx + 1;
387        }
388
389        Ok(results.into_iter().take(top_k).collect())
390    }
391
392    fn is_ready(&self) -> bool {
393        self.ready.load(Ordering::Relaxed)
394    }
395
396    fn init(&mut self) -> memvid_core::Result<()> {
397        info!("Initializing OpenAI reranker with model: {}", self.model);
398
399        // Test with a simple rerank to validate API key
400        let test_docs = vec![RerankerDocument::new(0, "Test document")];
401        let _ = self
402            .rerank_with_retry("test query", &[&test_docs[0]], 1)
403            .map_err(|e| memvid_core::MemvidError::RerankFailed {
404                reason: format!("Failed to initialize reranker: {}", e).into_boxed_str(),
405            })?;
406
407        info!("OpenAI reranker initialized successfully");
408        self.ready.store(true, Ordering::Relaxed);
409        Ok(())
410    }
411}
412
413/// Helper to create an OpenAI reranker or return None
414pub fn try_openai_reranker() -> Option<OpenAIReranker> {
415    match OpenAIReranker::from_env() {
416        Ok(reranker) => {
417            info!("OpenAI reranker available");
418            Some(reranker)
419        }
420        Err(e) => {
421            debug!("OpenAI reranker not available: {}", e);
422            None
423        }
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a reranker with a dummy key and default config for unit tests.
    fn make_reranker() -> OpenAIReranker {
        OpenAIReranker::new("test-key".to_string(), None, RerankerConfig::default())
            .expect("non-empty key should construct")
    }

    #[test]
    fn test_empty_api_key() {
        // An empty key must be rejected at construction time.
        assert!(OpenAIReranker::new(String::new(), None, RerankerConfig::default()).is_err());
    }

    #[test]
    fn test_build_prompt() {
        let reranker = make_reranker();

        let documents = vec![
            RerankerDocument::new(1, "First document about Rust"),
            RerankerDocument::new(2, "Second document about Python"),
        ];
        let refs: Vec<&RerankerDocument> = documents.iter().collect();

        let prompt = reranker.build_prompt("What is Rust?", &refs);

        // The prompt must embed the query, every id, and every document body.
        for needle in [
            "What is Rust?",
            "ID=1",
            "ID=2",
            "First document",
            "Second document",
        ] {
            assert!(prompt.contains(needle), "prompt missing: {}", needle);
        }
    }

    #[test]
    fn test_parse_scores() {
        let reranker = make_reranker();

        // Surrounding prose must be tolerated; only the JSON array matters.
        let raw = r#"Here are the scores:
[{"id": 1, "score": 0.95}, {"id": 2, "score": 0.42}]"#;

        let parsed = reranker.parse_scores(raw).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].id, 1);
        assert!((parsed[0].score - 0.95).abs() < 0.01);
        assert_eq!(parsed[1].id, 2);
        assert!((parsed[1].score - 0.42).abs() < 0.01);
    }

    #[test]
    #[ignore] // Requires valid API key
    fn test_real_rerank() {
        let reranker = OpenAIReranker::from_env().expect("OPENAI_API_KEY must be set");

        let docs = vec![
            RerankerDocument::new(
                1,
                "Rust is a systems programming language focused on safety.",
            ),
            RerankerDocument::new(2, "Python is great for data science and machine learning."),
            RerankerDocument::new(3, "Rust provides memory safety without garbage collection."),
        ];

        let results = reranker.rerank("What makes Rust safe?", &docs, 2).unwrap();
        assert!(!results.is_empty());
        // Document about Rust safety should rank higher
        assert!(results[0].id == 1 || results[0].id == 3);
    }
}
492}