//! OpenAI embedding provider (`manx_cli/rag/providers/openai.rs`).

1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
/// OpenAI API embedding provider
///
/// Wraps the `https://api.openai.com/v1/embeddings` endpoint behind the
/// crate's `EmbeddingProvider` trait.
pub struct OpenAiProvider {
    client: Client,           // reused HTTP client, built with a 30s timeout in `new`
    api_key: String,          // bearer token sent in the Authorization header
    model: String,            // embedding model name, e.g. "text-embedding-3-small"
    dimension: Option<usize>, // Cached dimension, lazily filled by `detect_dimension`
}
15
/// Request body for `POST /v1/embeddings` (single-string `input` form).
#[derive(Serialize)]
struct EmbeddingRequest {
    input: String,           // text to embed
    model: String,           // embedding model name
    encoding_format: String, // always "float" here (the API also supports "base64")
}
22
/// Top-level response body from the embeddings endpoint.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>, // one entry per input; we only ever send one input
    model: String,            // model the server actually used (logged if it differs)
    usage: Usage,             // token accounting, logged at debug level
}
29
/// A single embedding vector within the response `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>, // the embedding vector itself
    index: usize,        // position of the corresponding input (expected 0 here)
}
35
/// Token-usage accounting returned by the API.
#[derive(Deserialize)]
struct Usage {
    prompt_tokens: u32, // tokens consumed by the input text
    total_tokens: u32,  // total billed tokens for the request
}
41
42impl OpenAiProvider {
43    /// Create a new OpenAI provider
44    pub fn new(api_key: String, model: String) -> Self {
45        let client = Client::builder()
46            .timeout(std::time::Duration::from_secs(30))
47            .build()
48            .unwrap();
49
50        Self {
51            client,
52            api_key,
53            model,
54            dimension: None,
55        }
56    }
57
58    /// Detect dimension by making a test API call
59    #[allow(dead_code)]
60    pub async fn detect_dimension(&mut self) -> Result<usize> {
61        if let Some(dim) = self.dimension {
62            return Ok(dim);
63        }
64
65        log::info!(
66            "Detecting embedding dimension for OpenAI model: {}",
67            self.model
68        );
69
70        let test_embedding = self.call_api("test").await?;
71        let dimension = test_embedding.len();
72
73        self.dimension = Some(dimension);
74        log::info!("Detected dimension: {} for model {}", dimension, self.model);
75
76        Ok(dimension)
77    }
78
79    /// Get token usage from last API call
80    #[allow(dead_code)]
81    pub fn get_usage_stats(&self) -> Option<(u32, u32)> {
82        // This would store the last usage from call_api
83        // For now return None as we don't track it
84        None
85    }
86
87    /// Make API call to OpenAI embeddings endpoint
88    async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
89        let request = EmbeddingRequest {
90            input: text.to_string(),
91            model: self.model.clone(),
92            encoding_format: "float".to_string(),
93        };
94
95        let response = self
96            .client
97            .post("https://api.openai.com/v1/embeddings")
98            .header("Authorization", format!("Bearer {}", self.api_key))
99            .header("Content-Type", "application/json")
100            .json(&request)
101            .send()
102            .await?;
103
104        let status = response.status();
105        if !status.is_success() {
106            let error_text = response.text().await.unwrap_or_default();
107            return Err(anyhow!(
108                "OpenAI API error: HTTP {} - {}",
109                status,
110                error_text
111            ));
112        }
113
114        let embedding_response: EmbeddingResponse = response.json().await?;
115
116        if embedding_response.data.is_empty() {
117            return Err(anyhow!("No embeddings returned from OpenAI API"));
118        }
119
120        // Log usage statistics
121        log::debug!(
122            "OpenAI API usage: {} prompt tokens, {} total tokens",
123            embedding_response.usage.prompt_tokens,
124            embedding_response.usage.total_tokens
125        );
126
127        // Ensure we have the right embedding (index should match)
128        if embedding_response.data[0].index != 0 {
129            log::warn!(
130                "Unexpected embedding index: {}",
131                embedding_response.data[0].index
132            );
133        }
134
135        // Verify model matches request
136        if embedding_response.model != self.model {
137            log::info!(
138                "API returned model: {} (requested: {})",
139                embedding_response.model,
140                self.model
141            );
142        }
143
144        Ok(embedding_response.data[0].embedding.clone())
145    }
146
147    /// Get common OpenAI model information
148    pub fn get_model_info(model: &str) -> (usize, usize) {
149        match model {
150            "text-embedding-3-small" => (1536, 8191),
151            "text-embedding-3-large" => (3072, 8191),
152            "text-embedding-ada-002" => (1536, 8191),
153            _ => (1536, 8191), // Default
154        }
155    }
156}
157
158#[async_trait]
159impl ProviderTrait for OpenAiProvider {
160    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
161        if text.trim().is_empty() {
162            return Err(anyhow!("Cannot embed empty text"));
163        }
164
165        // Truncate text if too long (OpenAI models have token limits)
166        let (_, max_chars) = Self::get_model_info(&self.model);
167        let truncated_text = if text.len() > max_chars {
168            &text[..max_chars]
169        } else {
170            text
171        };
172
173        self.call_api(truncated_text).await
174    }
175
176    async fn get_dimension(&self) -> Result<usize> {
177        if let Some(dim) = self.dimension {
178            Ok(dim)
179        } else {
180            // Use known dimensions for common models
181            let (dim, _) = Self::get_model_info(&self.model);
182            Ok(dim)
183        }
184    }
185
186    async fn health_check(&self) -> Result<()> {
187        self.call_api("test").await.map(|_| ())
188    }
189
190    fn get_info(&self) -> ProviderInfo {
191        let (_, max_length) = Self::get_model_info(&self.model);
192
193        ProviderInfo {
194            name: "OpenAI Embeddings".to_string(),
195            provider_type: "openai".to_string(),
196            model_name: Some(self.model.clone()),
197            description: format!("OpenAI embeddings model: {}", self.model),
198            max_input_length: Some(max_length),
199        }
200    }
201}