
synth_claw/providers/openai.rs
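
//! OpenAI provider: implements `LLMProvider` against the Chat Completions
//! endpoint (`/v1/chat/completions`), with a per-model pricing table used
//! for cost estimation.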

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{GenerationRequest, GenerationResponse, LLMProvider, ModelPricing};
use crate::{Error, Result};

const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

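/// Chat-completions client for a single configured model. Per-request
/// settings override the `default_*` values captured at construction.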
pub struct OpenAIProvider {
    client: Client,
    api_key: String,
    base_url: String,
    model: String,
    default_temperature: Option<f32>,
    default_max_tokens: Option<u32>,
    pricing: ModelPricing,
}

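/// Request body for the Chat Completions endpoint. Optional fields are
/// omitted from the JSON entirely (not sent as `null`) so the API applies
/// its own defaults.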
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Note: newer reasoning models (the o-series) expect
    // `max_completion_tokens` instead of `max_tokens`; this struct keeps the
    // classic field, which the other chat models accept.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
    usage: Usage,
}

#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}

#[derive(Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

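/// Error envelope the API returns on non-2xx responses:
/// `{"error": {"message": "...", ...}}`.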
#[derive(Deserialize)]
struct OpenAIError {
    error: ErrorDetail,
}

#[derive(Deserialize)]
struct ErrorDetail {
    message: String,
}

impl OpenAIProvider {
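    /// Builds a provider for `model`. The API key falls back to the
    /// `OPENAI_API_KEY` environment variable, and `base_url` defaults to the
    /// official endpoint, so a compatible proxy can be substituted.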
    pub fn new(
        model: String,
        api_key: Option<String>,
        base_url: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Result<Self> {
        let api_key = api_key
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .ok_or_else(|| {
                Error::Provider("OPENAI_API_KEY not set and no api_key provided".into())
            })?;

        let base_url = base_url.unwrap_or_else(|| OPENAI_API_URL.to_string());
        let client = Client::new();
        let pricing = get_model_pricing(&model);

        Ok(Self {
            client,
            api_key,
            base_url,
            model,
            default_temperature: temperature,
            default_max_tokens: max_tokens,
            pricing,
        })
    }
}

#[async_trait]
impl LLMProvider for OpenAIProvider {
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
        // Build the chat transcript: optional system prompt first, then the user turn.
        let mut messages = Vec::new();

        if let Some(system) = &request.system_prompt {
            messages.push(Message {
                role: "system".to_string(),
                content: system.clone(),
            });
        }

        messages.push(Message {
            role: "user".to_string(),
            content: request.prompt,
        });

        // Per-request settings win; fall back to the provider-level defaults.
        let temperature = request.temperature.or(self.default_temperature);
        let max_tokens = request.max_tokens.or(self.default_max_tokens);

        let openai_request = OpenAIRequest {
            model: self.model.clone(),
            messages,
            temperature,
            max_tokens,
        };

        let response = self
            .client
            .post(&self.base_url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| Error::Provider(format!("Request failed: {}", e)))?;

        // Read the body before checking status so error responses can be surfaced verbatim.
        let status = response.status();
        let body = response
            .text()
            .await
            .map_err(|e| Error::Provider(format!("Failed to read response: {}", e)))?;

        if !status.is_success() {
            // Prefer the structured error message; fall back to the raw body
            // if it doesn't match the expected error envelope.
            let error: OpenAIError = serde_json::from_str(&body)
                .map_err(|_| Error::Provider(format!("API error ({}): {}", status, body)))?;
            return Err(Error::Provider(format!(
                "OpenAI API error: {}",
                error.error.message
            )));
        }

        let response: OpenAIResponse = serde_json::from_str(&body)
            .map_err(|e| Error::Provider(format!("Failed to parse response: {}", e)))?;

        let content = response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .ok_or_else(|| Error::Provider("No response content".into()))?;

        Ok(GenerationResponse {
            content,
            input_tokens: response.usage.prompt_tokens,
            output_tokens: response.usage.completion_tokens,
        })
    }

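    /// Dollar cost of a request: each side is priced per million tokens.
    /// E.g. for gpt-4o-mini ($0.15 in / $0.60 out), 10_000 input + 2_000
    /// output tokens cost 0.01 * 0.15 + 0.002 * 0.60 = $0.0027.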
    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.pricing.input_per_million;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.pricing.output_per_million;
        input_cost + output_cost
    }

    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }
}

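// Usage sketch (not part of the module): assumes a tokio runtime, and shows
// only the `GenerationRequest` fields this file actually reads; the real
// struct may carry more.
//
//     let provider = OpenAIProvider::new("gpt-4o-mini".into(), None, None, None, None)?;
//     let resp = provider
//         .generate(GenerationRequest {
//             prompt: "Say hello".into(),
//             system_prompt: None,
//             temperature: Some(0.2),
//             max_tokens: Some(64),
//         })
//         .await?;
//     println!(
//         "{} (~${:.6})",
//         resp.content,
//         provider.estimate_cost(resp.input_tokens, resp.output_tokens)
//     );
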
/// Pricing snapshot in USD per million tokens (rates drift; check OpenAI's
/// current price list). Unknown models fall back to gpt-4o rates.
fn get_model_pricing(model: &str) -> ModelPricing {
    match model {
        "gpt-4o" => ModelPricing {
            input_per_million: 2.50,
            output_per_million: 10.00,
        },
        "gpt-4o-mini" => ModelPricing {
            input_per_million: 0.15,
            output_per_million: 0.60,
        },
        "gpt-4-turbo" | "gpt-4-turbo-preview" => ModelPricing {
            input_per_million: 10.00,
            output_per_million: 30.00,
        },
        "gpt-4" => ModelPricing {
            input_per_million: 30.00,
            output_per_million: 60.00,
        },
        "gpt-3.5-turbo" => ModelPricing {
            input_per_million: 0.50,
            output_per_million: 1.50,
        },
        "o1" => ModelPricing {
            input_per_million: 15.00,
            output_per_million: 60.00,
        },
        "o1-mini" => ModelPricing {
            input_per_million: 3.00,
            output_per_million: 12.00,
        },
        "o3-mini" => ModelPricing {
            input_per_million: 1.10,
            output_per_million: 4.40,
        },
        _ => ModelPricing {
            input_per_million: 2.50,
            output_per_million: 10.00,
        },
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_pricing_gpt4o_mini() {
        let pricing = get_model_pricing("gpt-4o-mini");
        assert_eq!(pricing.input_per_million, 0.15);
        assert_eq!(pricing.output_per_million, 0.60);
    }

    #[test]
    fn test_cost_estimation() {
        let provider = OpenAIProvider::new(
            "gpt-4o-mini".to_string(),
            Some("test-key".to_string()),
            None,
            None,
            None,
        )
        .unwrap();

        // 1M input tokens + 1M output tokens
        let cost = provider.estimate_cost(1_000_000, 1_000_000);
        assert!((cost - 0.75).abs() < 0.001); // $0.15 + $0.60 = $0.75
    }

    #[test]
    fn test_provider_name_and_model() {
        let provider = OpenAIProvider::new(
            "gpt-4o".to_string(),
            Some("test-key".to_string()),
            None,
            Some(0.7),
            Some(1000),
        )
        .unwrap();

        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.model(), "gpt-4o");
    }

    #[test]
    fn test_missing_api_key_error() {
        // `remove_var` is unsafe on Rust 2024 because mutating the environment
        // is not thread-safe; fine here as long as tests touching env vars
        // don't run concurrently.
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
        let result = OpenAIProvider::new("gpt-4o".to_string(), None, None, None, None);
        assert!(result.is_err());
    }
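
    // Added check (mirrors the table above): unrecognized model names fall
    // back to the gpt-4o rates via the `_` arm.
    #[test]
    fn test_model_pricing_default_fallback() {
        let pricing = get_model_pricing("some-unknown-model");
        assert_eq!(pricing.input_per_million, 2.50);
        assert_eq!(pricing.output_per_million, 10.00);
    }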
}