use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{GenerationRequest, GenerationResponse, LLMProvider, ModelPricing};
use crate::{Error, Result};
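
/// Default endpoint for the OpenAI Chat Completions API.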
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
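
/// LLM provider backed by the OpenAI Chat Completions API.
///
/// Holds the HTTP client, credentials, and per-provider generation
/// defaults; per-request settings override these defaults.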
pub struct OpenAIProvider {
    client: Client,
    api_key: String,
    base_url: String,
    model: String,
    default_temperature: Option<f32>,
    default_max_tokens: Option<u32>,
    pricing: ModelPricing,
}
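
/// JSON body for the Chat Completions request. Optional fields are
/// omitted from the payload entirely rather than serialized as null.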
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
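
/// A single chat message ("system" or "user" role) in the request payload.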
#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}
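
/// The subset of the Chat Completions response that this provider consumes.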
#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
    usage: Usage,
}

#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
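
/// Token counts reported by the API, used for cost estimation.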
#[derive(Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
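
/// Error envelope returned by the API on non-2xx responses.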
#[derive(Deserialize)]
struct OpenAIError {
    error: ErrorDetail,
}

#[derive(Deserialize)]
struct ErrorDetail {
    message: String,
}

impl OpenAIProvider {
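    /// Creates a provider for `model`, resolving the API key from the
    /// `api_key` argument or, failing that, the `OPENAI_API_KEY`
    /// environment variable.
    ///
    /// A minimal usage sketch (the exact import path depends on where this
    /// module sits in your crate):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::new(
    ///     "gpt-4o-mini".to_string(),
    ///     None,      // fall back to OPENAI_API_KEY
    ///     None,      // default base URL
    ///     Some(0.2), // default temperature
    ///     Some(512), // default max_tokens
    /// )?;
    /// ```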
    pub fn new(
        model: String,
        api_key: Option<String>,
        base_url: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Result<Self> {
        let api_key = api_key
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .ok_or_else(|| {
                Error::Provider("OPENAI_API_KEY not set and no api_key provided".into())
            })?;

        let base_url = base_url.unwrap_or_else(|| OPENAI_API_URL.to_string());
        let client = Client::new();
        let pricing = get_model_pricing(&model);

        Ok(Self {
            client,
            api_key,
            base_url,
            model,
            default_temperature: temperature,
            default_max_tokens: max_tokens,
            pricing,
        })
    }
}

#[async_trait]
impl LLMProvider for OpenAIProvider {
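    /// Sends a single chat completion request and returns the generated
    /// text along with the token counts reported by the API.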
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
        let mut messages = Vec::new();

        if let Some(system) = &request.system_prompt {
            messages.push(Message {
                role: "system".to_string(),
                content: system.clone(),
            });
        }

        messages.push(Message {
            role: "user".to_string(),
            content: request.prompt,
        });
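
        // Per-request settings take precedence over the provider defaults.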
        let temperature = request.temperature.or(self.default_temperature);
        let max_tokens = request.max_tokens.or(self.default_max_tokens);

        let openai_request = OpenAIRequest {
            model: self.model.clone(),
            messages,
            temperature,
            max_tokens,
        };

        let response = self
            .client
            .post(&self.base_url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| Error::Provider(format!("Request failed: {}", e)))?;

        let status = response.status();
        let body = response
            .text()
            .await
            .map_err(|e| Error::Provider(format!("Failed to read response: {}", e)))?;
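
        // Prefer the structured API error message; if the body is not the
        // expected error shape, surface the status and raw body instead.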
        if !status.is_success() {
            let error: OpenAIError = serde_json::from_str(&body)
                .map_err(|_| Error::Provider(format!("API error ({}): {}", status, body)))?;
            return Err(Error::Provider(format!(
                "OpenAI API error: {}",
                error.error.message
            )));
        }

        let response: OpenAIResponse = serde_json::from_str(&body)
            .map_err(|e| Error::Provider(format!("Failed to parse response: {}", e)))?;

        let content = response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .ok_or_else(|| Error::Provider("No response content".into()))?;

        Ok(GenerationResponse {
            content,
            input_tokens: response.usage.prompt_tokens,
            output_tokens: response.usage.completion_tokens,
        })
    }
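
    /// Estimates the request cost in USD; pricing is quoted per million
    /// tokens, so both counts are scaled accordingly.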
    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.pricing.input_per_million;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.pricing.output_per_million;
        input_cost + output_cost
    }

    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }
}
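
/// USD pricing per million tokens for known models. These figures are a
/// snapshot and should be checked against OpenAI's current price list;
/// unrecognized models fall back to gpt-4o rates.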
fn get_model_pricing(model: &str) -> ModelPricing {
    match model {
        "gpt-4o" => ModelPricing {
            input_per_million: 2.50,
            output_per_million: 10.00,
        },
        "gpt-4o-mini" => ModelPricing {
            input_per_million: 0.15,
            output_per_million: 0.60,
        },
        "gpt-4-turbo" | "gpt-4-turbo-preview" => ModelPricing {
            input_per_million: 10.00,
            output_per_million: 30.00,
        },
        "gpt-4" => ModelPricing {
            input_per_million: 30.00,
            output_per_million: 60.00,
        },
        "gpt-3.5-turbo" => ModelPricing {
            input_per_million: 0.50,
            output_per_million: 1.50,
        },
        "o1" => ModelPricing {
            input_per_million: 15.00,
            output_per_million: 60.00,
        },
        "o1-mini" => ModelPricing {
            input_per_million: 3.00,
            output_per_million: 12.00,
        },
        "o3-mini" => ModelPricing {
            input_per_million: 1.10,
            output_per_million: 4.40,
        },
        _ => ModelPricing {
            input_per_million: 2.50,
            output_per_million: 10.00,
        },
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_pricing_gpt4o_mini() {
        let pricing = get_model_pricing("gpt-4o-mini");
        assert_eq!(pricing.input_per_million, 0.15);
        assert_eq!(pricing.output_per_million, 0.60);
    }

    #[test]
    fn test_cost_estimation() {
        let provider = OpenAIProvider::new(
            "gpt-4o-mini".to_string(),
            Some("test-key".to_string()),
            None,
            None,
            None,
        )
        .unwrap();

        let cost = provider.estimate_cost(1_000_000, 1_000_000);
        assert!((cost - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_provider_name_and_model() {
        let provider = OpenAIProvider::new(
            "gpt-4o".to_string(),
            Some("test-key".to_string()),
            None,
            Some(0.7),
            Some(1000),
        )
        .unwrap();

        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.model(), "gpt-4o");
    }

    #[test]
    fn test_missing_api_key_error() {
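        // `remove_var` is `unsafe` as of Rust 2024 because it mutates
        // process-global state; this can race with other tests that read
        // the same variable when tests run in parallel.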
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
        let result = OpenAIProvider::new("gpt-4o".to_string(), None, None, None, None);
        assert!(result.is_err());
    }
}