ricecoder_providers/providers/openai.rs

//! OpenAI provider implementation
//!
//! Supports the GPT-4, GPT-4 Turbo, GPT-4o, and GPT-3.5 Turbo models via the OpenAI API.

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{debug, error, warn};

use crate::error::ProviderError;
use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
use crate::provider::Provider;
use crate::token_counter::TokenCounter;

/// OpenAI provider implementation
pub struct OpenAiProvider {
    api_key: String,
    client: Arc<Client>,
    base_url: String,
    token_counter: Arc<TokenCounter>,
}

impl OpenAiProvider {
    /// Create a new OpenAI provider instance
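    ///
    /// # Example
    ///
    /// A minimal sketch; the `ricecoder_providers` crate path is assumed from
    /// this file's location and may differ in your workspace, so the example
    /// is marked `ignore`.
    ///
    /// ```ignore
    /// use ricecoder_providers::providers::openai::OpenAiProvider;
    ///
    /// let provider = OpenAiProvider::new("sk-your-key".to_string())
    ///     .expect("API key must be non-empty");
    /// ```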
    pub fn new(api_key: String) -> Result<Self, ProviderError> {
        if api_key.is_empty() {
            return Err(ProviderError::ConfigError(
                "OpenAI API key is required".to_string(),
            ));
        }

        Ok(Self {
            api_key,
            client: Arc::new(Client::new()),
            base_url: "https://api.openai.com/v1".to_string(),
            token_counter: Arc::new(TokenCounter::new()),
        })
    }

    /// Create a new OpenAI provider with a custom base URL
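    ///
    /// Useful for OpenAI-compatible endpoints such as proxies or local
    /// gateways. A sketch with a placeholder URL:
    ///
    /// ```ignore
    /// let provider = OpenAiProvider::with_base_url(
    ///     "sk-your-key".to_string(),
    ///     "http://localhost:8080/v1".to_string(),
    /// )?;
    /// ```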
    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self, ProviderError> {
        let mut provider = Self::new(api_key)?;
        provider.base_url = base_url;
        Ok(provider)
    }

    /// Build the `Authorization` header value.
    ///
    /// The returned string contains the raw API key and must never be logged.
    fn get_auth_header(&self) -> String {
        format!("Bearer {}", self.api_key)
    }

    /// Convert an OpenAI API response to our ChatResponse
    fn convert_response(
        response: OpenAiChatResponse,
        model: String,
    ) -> Result<ChatResponse, ProviderError> {
        let content = response
            .choices
            .first()
            .and_then(|c| c.message.as_ref())
            .map(|m| m.content.clone())
            .ok_or_else(|| ProviderError::ProviderError("No content in response".to_string()))?;

        // Unrecognized finish reasons (e.g. "content_filter", "tool_calls")
        // fall back to Stop.
        let finish_reason = match response
            .choices
            .first()
            .and_then(|c| c.finish_reason.as_deref())
        {
            Some("stop") => FinishReason::Stop,
            Some("length") => FinishReason::Length,
            Some("error") => FinishReason::Error,
            _ => FinishReason::Stop,
        };

        Ok(ChatResponse {
            content,
            model,
            usage: TokenUsage {
                prompt_tokens: response.usage.prompt_tokens,
                completion_tokens: response.usage.completion_tokens,
                total_tokens: response.usage.total_tokens,
            },
            finish_reason,
        })
    }
}

#[async_trait]
impl Provider for OpenAiProvider {
    fn id(&self) -> &str {
        "openai"
    }

    fn name(&self) -> &str {
        "OpenAI"
    }

    fn models(&self) -> Vec<ModelInfo> {
        vec![
            ModelInfo {
                id: "gpt-4".to_string(),
                name: "GPT-4".to_string(),
                provider: "openai".to_string(),
                context_window: 8192,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: Some(crate::models::Pricing {
                    input_per_1k_tokens: 0.03,
                    output_per_1k_tokens: 0.06,
                }),
            },
            ModelInfo {
                id: "gpt-4-turbo".to_string(),
                name: "GPT-4 Turbo".to_string(),
                provider: "openai".to_string(),
                context_window: 128000,
                capabilities: vec![
                    Capability::Chat,
                    Capability::Code,
                    Capability::Vision,
                    Capability::Streaming,
                ],
                pricing: Some(crate::models::Pricing {
                    input_per_1k_tokens: 0.01,
                    output_per_1k_tokens: 0.03,
                }),
            },
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128000,
                capabilities: vec![
                    Capability::Chat,
                    Capability::Code,
                    Capability::Vision,
                    Capability::Streaming,
                ],
                pricing: Some(crate::models::Pricing {
                    input_per_1k_tokens: 0.005,
                    output_per_1k_tokens: 0.015,
                }),
            },
            ModelInfo {
                id: "gpt-3.5-turbo".to_string(),
                name: "GPT-3.5 Turbo".to_string(),
                provider: "openai".to_string(),
                // The current gpt-3.5-turbo alias (0125, matching the pricing
                // below) has a 16,385-token context window; 4,096 was the
                // original model's limit.
                context_window: 16385,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: Some(crate::models::Pricing {
                    input_per_1k_tokens: 0.0005,
                    output_per_1k_tokens: 0.0015,
                }),
            },
        ]
    }

    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        // Validate model
        let model_id = &request.model;
        if !self.models().iter().any(|m| m.id == *model_id) {
            return Err(ProviderError::InvalidModel(model_id.clone()));
        }

        let openai_request = OpenAiChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| OpenAiMessage {
                    role: m.role.clone(),
                    content: m.content.clone(),
                })
                .collect(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };

        debug!(
            "Sending chat request to OpenAI for model: {}",
            request.model
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", self.get_auth_header())
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| {
                error!("OpenAI API request failed: {}", e);
                ProviderError::from(e)
            })?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("OpenAI API error ({}): {}", status, error_text);

            return match status.as_u16() {
                401 => Err(ProviderError::AuthError),
                // Fixed 60-second retry hint; the Retry-After header is not
                // consulted here.
                429 => Err(ProviderError::RateLimited(60)),
                _ => Err(ProviderError::ProviderError(format!(
                    "OpenAI API error: {}",
                    status
                ))),
            };
        }

        let openai_response: OpenAiChatResponse = response.json().await?;
        Self::convert_response(openai_response, request.model)
    }

    async fn chat_stream(
        &self,
        _request: ChatRequest,
    ) -> Result<crate::provider::ChatStream, ProviderError> {
        // Streaming support will be implemented in a future iteration
        Err(ProviderError::ProviderError(
            "Streaming not yet implemented for OpenAI".to_string(),
        ))
    }

    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
        // Validate model
        if !self.models().iter().any(|m| m.id == model) {
            return Err(ProviderError::InvalidModel(model.to_string()));
        }

        // Use token counter with caching for performance
        let tokens = self.token_counter.count_tokens_openai(content, model);
        Ok(tokens)
    }

    async fn health_check(&self) -> Result<bool, ProviderError> {
        debug!("Performing health check for OpenAI provider");

        // Try to list models as a health check
        let response = self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", self.get_auth_header())
            .send()
            .await
            .map_err(|e| {
                warn!("OpenAI health check failed: {}", e);
                ProviderError::from(e)
            })?;

        match response.status().as_u16() {
            200 => {
                debug!("OpenAI health check passed");
                Ok(true)
            }
            401 => {
                error!("OpenAI health check failed: authentication error");
                Err(ProviderError::AuthError)
            }
            _ => {
                warn!(
                    "OpenAI health check failed with status: {}",
                    response.status()
                );
                Ok(false)
            }
        }
    }
}

/// OpenAI API request format
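///
/// Serializes to the `/chat/completions` request body; the optional fields
/// are omitted when unset. For example:
///
/// ```json
/// {"model":"gpt-4o","messages":[{"role":"user","content":"Hi"}],"temperature":0.2}
/// ```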
#[derive(Debug, Serialize)]
struct OpenAiChatRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
}

/// OpenAI API message format
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    role: String,
    content: String,
}

/// OpenAI API response format
#[derive(Debug, Deserialize)]
struct OpenAiChatResponse {
    choices: Vec<OpenAiChoice>,
    usage: OpenAiUsage,
}

/// OpenAI API choice format
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: Option<OpenAiMessage>,
    finish_reason: Option<String>,
}

/// OpenAI API usage format
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_provider_creation() {
        let provider = OpenAiProvider::new("test-key".to_string());
        assert!(provider.is_ok());
    }

    #[test]
    fn test_openai_provider_creation_empty_key() {
        let provider = OpenAiProvider::new("".to_string());
        assert!(provider.is_err());
    }
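
    #[test]
    fn test_openai_provider_with_base_url() {
        // Illustrative addition: checks that the custom-endpoint constructor
        // stores the override. The URL is a placeholder, not a real service.
        let provider = OpenAiProvider::with_base_url(
            "test-key".to_string(),
            "http://localhost:8080/v1".to_string(),
        )
        .unwrap();
        assert_eq!(provider.base_url, "http://localhost:8080/v1");
    }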

    #[test]
    fn test_openai_provider_id() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        assert_eq!(provider.id(), "openai");
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        assert_eq!(provider.name(), "OpenAI");
    }

    #[test]
    fn test_openai_models() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        let models = provider.models();
        assert_eq!(models.len(), 4);
        assert!(models.iter().any(|m| m.id == "gpt-4"));
        assert!(models.iter().any(|m| m.id == "gpt-4-turbo"));
        assert!(models.iter().any(|m| m.id == "gpt-4o"));
        assert!(models.iter().any(|m| m.id == "gpt-3.5-turbo"));
    }

    #[test]
    fn test_token_counting() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        let tokens = provider.count_tokens("Hello, world!", "gpt-4").unwrap();
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_counting_invalid_model() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        let result = provider.count_tokens("Hello, world!", "invalid-model");
        assert!(result.is_err());
    }
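
    #[test]
    fn test_convert_response() {
        // Illustrative addition: exercises the response-conversion path by
        // building the wire structs directly instead of calling the API.
        let response = OpenAiChatResponse {
            choices: vec![OpenAiChoice {
                message: Some(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: "Hi there!".to_string(),
                }),
                finish_reason: Some("stop".to_string()),
            }],
            usage: OpenAiUsage {
                prompt_tokens: 10,
                completion_tokens: 3,
                total_tokens: 13,
            },
        };

        let converted =
            OpenAiProvider::convert_response(response, "gpt-4o".to_string()).unwrap();
        assert_eq!(converted.content, "Hi there!");
        assert_eq!(converted.model, "gpt-4o");
        assert_eq!(converted.usage.total_tokens, 13);
        assert!(matches!(converted.finish_reason, FinishReason::Stop));
    }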
}