ricecoder_providers/providers/
openai.rs1use async_trait::async_trait;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use tracing::{debug, error, warn};
10
11use crate::error::ProviderError;
12use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
13use crate::provider::Provider;
14use crate::token_counter::TokenCounter;
15
16pub struct OpenAiProvider {
18 api_key: String,
19 client: Arc<Client>,
20 base_url: String,
21 token_counter: Arc<TokenCounter>,
22}
23
24impl OpenAiProvider {
25 pub fn new(api_key: String) -> Result<Self, ProviderError> {
27 if api_key.is_empty() {
28 return Err(ProviderError::ConfigError(
29 "OpenAI API key is required".to_string(),
30 ));
31 }
32
33 Ok(Self {
34 api_key,
35 client: Arc::new(Client::new()),
36 base_url: "https://api.openai.com/v1".to_string(),
37 token_counter: Arc::new(TokenCounter::new()),
38 })
39 }
40
41 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self, ProviderError> {
43 if api_key.is_empty() {
44 return Err(ProviderError::ConfigError(
45 "OpenAI API key is required".to_string(),
46 ));
47 }
48
49 Ok(Self {
50 api_key,
51 client: Arc::new(Client::new()),
52 base_url,
53 token_counter: Arc::new(TokenCounter::new()),
54 })
55 }
56
57 fn get_auth_header(&self) -> String {
59 format!("Bearer {}", self.api_key)
60 }
61
62 fn convert_response(
64 response: OpenAiChatResponse,
65 model: String,
66 ) -> Result<ChatResponse, ProviderError> {
67 let content = response
68 .choices
69 .first()
70 .and_then(|c| c.message.as_ref())
71 .map(|m| m.content.clone())
72 .ok_or_else(|| ProviderError::ProviderError("No content in response".to_string()))?;
73
74 let finish_reason = match response
75 .choices
76 .first()
77 .and_then(|c| c.finish_reason.as_deref())
78 {
79 Some("stop") => FinishReason::Stop,
80 Some("length") => FinishReason::Length,
81 Some("error") => FinishReason::Error,
82 _ => FinishReason::Stop,
83 };
84
85 Ok(ChatResponse {
86 content,
87 model,
88 usage: TokenUsage {
89 prompt_tokens: response.usage.prompt_tokens,
90 completion_tokens: response.usage.completion_tokens,
91 total_tokens: response.usage.total_tokens,
92 },
93 finish_reason,
94 })
95 }
96}
97
98#[async_trait]
99impl Provider for OpenAiProvider {
100 fn id(&self) -> &str {
101 "openai"
102 }
103
104 fn name(&self) -> &str {
105 "OpenAI"
106 }
107
108 fn models(&self) -> Vec<ModelInfo> {
109 vec![
110 ModelInfo {
111 id: "gpt-4".to_string(),
112 name: "GPT-4".to_string(),
113 provider: "openai".to_string(),
114 context_window: 8192,
115 capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
116 pricing: Some(crate::models::Pricing {
117 input_per_1k_tokens: 0.03,
118 output_per_1k_tokens: 0.06,
119 }),
120 },
121 ModelInfo {
122 id: "gpt-4-turbo".to_string(),
123 name: "GPT-4 Turbo".to_string(),
124 provider: "openai".to_string(),
125 context_window: 128000,
126 capabilities: vec![
127 Capability::Chat,
128 Capability::Code,
129 Capability::Vision,
130 Capability::Streaming,
131 ],
132 pricing: Some(crate::models::Pricing {
133 input_per_1k_tokens: 0.01,
134 output_per_1k_tokens: 0.03,
135 }),
136 },
137 ModelInfo {
138 id: "gpt-4o".to_string(),
139 name: "GPT-4o".to_string(),
140 provider: "openai".to_string(),
141 context_window: 128000,
142 capabilities: vec![
143 Capability::Chat,
144 Capability::Code,
145 Capability::Vision,
146 Capability::Streaming,
147 ],
148 pricing: Some(crate::models::Pricing {
149 input_per_1k_tokens: 0.005,
150 output_per_1k_tokens: 0.015,
151 }),
152 },
153 ModelInfo {
154 id: "gpt-3.5-turbo".to_string(),
155 name: "GPT-3.5 Turbo".to_string(),
156 provider: "openai".to_string(),
157 context_window: 4096,
158 capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
159 pricing: Some(crate::models::Pricing {
160 input_per_1k_tokens: 0.0005,
161 output_per_1k_tokens: 0.0015,
162 }),
163 },
164 ]
165 }
166
167 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
168 let model_id = &request.model;
170 if !self.models().iter().any(|m| m.id == *model_id) {
171 return Err(ProviderError::InvalidModel(model_id.clone()));
172 }
173
174 let openai_request = OpenAiChatRequest {
175 model: request.model.clone(),
176 messages: request
177 .messages
178 .iter()
179 .map(|m| OpenAiMessage {
180 role: m.role.clone(),
181 content: m.content.clone(),
182 })
183 .collect(),
184 temperature: request.temperature,
185 max_tokens: request.max_tokens,
186 };
187
188 debug!(
189 "Sending chat request to OpenAI for model: {}",
190 request.model
191 );
192
193 let response = self
194 .client
195 .post(format!("{}/chat/completions", self.base_url))
196 .header("Authorization", self.get_auth_header())
197 .header("Content-Type", "application/json")
198 .json(&openai_request)
199 .send()
200 .await
201 .map_err(|e| {
202 error!("OpenAI API request failed: {}", e);
203 ProviderError::from(e)
204 })?;
205
206 let status = response.status();
207 if !status.is_success() {
208 let error_text = response.text().await.unwrap_or_default();
209 error!("OpenAI API error ({}): {}", status, error_text);
210
211 return match status.as_u16() {
212 401 => Err(ProviderError::AuthError),
213 429 => Err(ProviderError::RateLimited(60)),
214 _ => Err(ProviderError::ProviderError(format!(
215 "OpenAI API error: {}",
216 status
217 ))),
218 };
219 }
220
221 let openai_response: OpenAiChatResponse = response.json().await?;
222 Self::convert_response(openai_response, request.model)
223 }
224
225 async fn chat_stream(
226 &self,
227 _request: ChatRequest,
228 ) -> Result<crate::provider::ChatStream, ProviderError> {
229 Err(ProviderError::ProviderError(
231 "Streaming not yet implemented for OpenAI".to_string(),
232 ))
233 }
234
235 fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
236 if !self.models().iter().any(|m| m.id == model) {
238 return Err(ProviderError::InvalidModel(model.to_string()));
239 }
240
241 let tokens = self.token_counter.count_tokens_openai(content, model);
243 Ok(tokens)
244 }
245
246 async fn health_check(&self) -> Result<bool, ProviderError> {
247 debug!("Performing health check for OpenAI provider");
248
249 let response = self
251 .client
252 .get(format!("{}/models", self.base_url))
253 .header("Authorization", self.get_auth_header())
254 .send()
255 .await
256 .map_err(|e| {
257 warn!("OpenAI health check failed: {}", e);
258 ProviderError::from(e)
259 })?;
260
261 match response.status().as_u16() {
262 200 => {
263 debug!("OpenAI health check passed");
264 Ok(true)
265 }
266 401 => {
267 error!("OpenAI health check failed: authentication error");
268 Err(ProviderError::AuthError)
269 }
270 _ => {
271 warn!(
272 "OpenAI health check failed with status: {}",
273 response.status()
274 );
275 Ok(false)
276 }
277 }
278 }
279}
280
281#[derive(Debug, Serialize)]
283struct OpenAiChatRequest {
284 model: String,
285 messages: Vec<OpenAiMessage>,
286 #[serde(skip_serializing_if = "Option::is_none")]
287 temperature: Option<f32>,
288 #[serde(skip_serializing_if = "Option::is_none")]
289 max_tokens: Option<usize>,
290}
291
292#[derive(Debug, Serialize, Deserialize)]
294struct OpenAiMessage {
295 role: String,
296 content: String,
297}
298
299#[derive(Debug, Deserialize)]
301struct OpenAiChatResponse {
302 choices: Vec<OpenAiChoice>,
303 usage: OpenAiUsage,
304}
305
306#[derive(Debug, Deserialize)]
308struct OpenAiChoice {
309 message: Option<OpenAiMessage>,
310 finish_reason: Option<String>,
311}
312
313#[derive(Debug, Deserialize)]
315struct OpenAiUsage {
316 prompt_tokens: usize,
317 completion_tokens: usize,
318 total_tokens: usize,
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a decoded response with a single assistant choice, for `convert_response` tests.
    fn sample_response(finish_reason: Option<&str>) -> OpenAiChatResponse {
        OpenAiChatResponse {
            choices: vec![OpenAiChoice {
                message: Some(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: "Hello!".to_string(),
                }),
                finish_reason: finish_reason.map(str::to_string),
            }],
            usage: OpenAiUsage {
                prompt_tokens: 3,
                completion_tokens: 2,
                total_tokens: 5,
            },
        }
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider = OpenAiProvider::new("test-key".to_string());
        assert!(provider.is_ok());
    }

    #[test]
    fn test_openai_provider_creation_empty_key() {
        let provider = OpenAiProvider::new("".to_string());
        assert!(provider.is_err());
    }

    #[test]
    fn test_openai_provider_with_base_url() {
        let provider = OpenAiProvider::with_base_url(
            "test-key".to_string(),
            "http://localhost:8080/v1".to_string(),
        );
        assert!(provider.is_ok());
    }

    #[test]
    fn test_openai_provider_with_base_url_empty_key() {
        let provider =
            OpenAiProvider::with_base_url(String::new(), "http://localhost:8080/v1".to_string());
        assert!(provider.is_err());
    }

    #[test]
    fn test_openai_provider_id() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        assert_eq!(provider.id(), "openai");
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        assert_eq!(provider.name(), "OpenAI");
    }

    #[test]
    fn test_openai_models() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        let models = provider.models();
        assert_eq!(models.len(), 4);
        assert!(models.iter().any(|m| m.id == "gpt-4"));
        assert!(models.iter().any(|m| m.id == "gpt-4-turbo"));
        assert!(models.iter().any(|m| m.id == "gpt-4o"));
        assert!(models.iter().any(|m| m.id == "gpt-3.5-turbo"));
    }

    #[test]
    fn test_token_counting() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        let tokens = provider.count_tokens("Hello, world!", "gpt-4").unwrap();
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_counting_invalid_model() {
        let provider = OpenAiProvider::new("test-key".to_string()).unwrap();
        let result = provider.count_tokens("Hello, world!", "invalid-model");
        assert!(result.is_err());
    }

    #[test]
    fn test_convert_response() {
        let response =
            OpenAiProvider::convert_response(sample_response(Some("length")), "gpt-4".to_string())
                .unwrap();
        assert_eq!(response.content, "Hello!");
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.usage.prompt_tokens, 3);
        assert_eq!(response.usage.completion_tokens, 2);
        assert_eq!(response.usage.total_tokens, 5);
        assert!(matches!(response.finish_reason, FinishReason::Length));
    }

    #[test]
    fn test_convert_response_missing_finish_reason_is_stop() {
        let response =
            OpenAiProvider::convert_response(sample_response(None), "gpt-4".to_string()).unwrap();
        assert!(matches!(response.finish_reason, FinishReason::Stop));
    }

    #[test]
    fn test_convert_response_without_choices() {
        let mut raw = sample_response(Some("stop"));
        raw.choices.clear();
        assert!(OpenAiProvider::convert_response(raw, "gpt-4".to_string()).is_err());
    }

    #[test]
    fn test_convert_response_without_message() {
        let mut raw = sample_response(Some("stop"));
        raw.choices[0].message = None;
        assert!(OpenAiProvider::convert_response(raw, "gpt-4".to_string()).is_err());
    }
}