1use crate::types::{Message, ToolDefinition};
2use async_trait::async_trait;
3use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
4use serde::{Deserialize, Serialize};
5
6#[derive(Serialize)]
8struct ChatRequest {
9 model: String,
11 messages: Vec<Message>,
13 tools: Vec<ToolDefinition>,
15 #[serde(skip_serializing_if = "Option::is_none")]
17 temperature: Option<f32>,
18 #[serde(skip_serializing_if = "Option::is_none")]
20 max_tokens: Option<u32>,
21 #[serde(skip_serializing_if = "Option::is_none")]
23 stream: Option<bool>,
24}
25
/// Top-level non-streaming response; only the fields this module reads
/// are modeled, everything else in the JSON is ignored by serde.
#[derive(Deserialize)]
struct ChatResponse {
    /// Candidate completions; callers use the first entry.
    choices: Vec<Choice>,
}
32
/// One parsed SSE `data:` payload from a streaming response.
#[derive(Deserialize)]
struct StreamChunk {
    /// Incremental choices; only the first entry is consumed downstream.
    choices: Vec<StreamChoice>,
}
38
/// A single choice inside a streaming chunk.
#[derive(Deserialize)]
struct StreamChoice {
    /// The incremental update for this choice.
    delta: Delta,
}
44
/// Incremental fields within a streaming choice. Both fields default to
/// `None` because a chunk may carry text, tool calls, both, or neither.
#[derive(Deserialize)]
struct Delta {
    /// Text appended to the assistant message, if any.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call data carried by this chunk, if any.
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}
53
/// A single completion candidate in a non-streaming response.
#[derive(Deserialize)]
struct Choice {
    /// The full assistant message for this candidate.
    message: Message,
}
60
/// LLM provider backed by an OpenAI-compatible chat-completions HTTP API.
///
/// Configure with the builder-style methods (`base_url`, `api_key`, `model`,
/// ...), then invoke through the `LLMProvider` trait implementation.
pub struct OpenAIProvider {
    /// Shared HTTP client; reqwest handles connection pooling internally.
    client: reqwest::Client,
    /// API root, e.g. "https://api.openai.com/v1" (no trailing slash).
    base_url: String,
    /// Bearer token sent in the `Authorization` header; empty by default.
    api_key: String,
    /// Model name placed in every request body.
    model: String,
    /// Optional sampling temperature; `None` uses the server default.
    temperature: Option<f32>,
    /// Optional cap on generated tokens; `None` uses the server default.
    max_tokens: Option<u32>,
    /// Extra headers appended to every request.
    custom_headers: HeaderMap,
    /// Number of retries allowed after the first failed attempt.
    max_retries: u32,
    /// Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
}
93
impl Default for OpenAIProvider {
    /// Equivalent to [`OpenAIProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
99
100impl OpenAIProvider {
101 pub fn new() -> Self {
111 Self {
112 client: reqwest::Client::new(),
113 base_url: "https://api.openai.com/v1".into(),
114 api_key: "".into(),
115 model: "gpt-4o".into(),
116 temperature: None,
117 max_tokens: None,
118 custom_headers: HeaderMap::new(),
119 max_retries: 3,
120 retry_delay_ms: 1000,
121 }
122 }
123
124 pub fn base_url(mut self, value: impl Into<String>) -> Self {
135 self.base_url = value.into();
136 self
137 }
138
139 pub fn api_key(mut self, value: impl Into<String>) -> Self {
150 self.api_key = value.into();
151 self
152 }
153
154 pub fn model(mut self, value: impl Into<String>) -> Self {
165 self.model = value.into();
166 self
167 }
168
169 pub fn temperature(mut self, value: impl Into<Option<f32>>) -> Self {
180 self.temperature = value.into();
181 self
182 }
183
184 pub fn max_tokens(mut self, value: impl Into<Option<u32>>) -> Self {
195 self.max_tokens = value.into();
196 self
197 }
198
199 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
214 self.custom_headers.insert(
215 HeaderName::try_from(key.into()).unwrap(),
216 HeaderValue::try_from(value.into()).unwrap(),
217 );
218 self
219 }
220
221 pub fn max_retries(mut self, retries: u32) -> Self {
232 self.max_retries = retries;
233 self
234 }
235
236 pub fn retry_delay(mut self, delay_ms: u64) -> Self {
247 self.retry_delay_ms = delay_ms;
248 self
249 }
250}
251
#[async_trait]
impl super::LLMProvider for OpenAIProvider {
    /// Calls the chat-completions API with retry.
    ///
    /// Each failed attempt sleeps for the fixed `retry_delay_ms` and tries
    /// again, up to `max_retries` extra attempts (`max_retries + 1` calls
    /// total). Every error is treated as retryable — including non-transient
    /// ones such as auth failures — and the last error is returned once the
    /// budget is exhausted.
    async fn call(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        mut stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            tracing::debug!(
                model = %self.model,
                messages = messages.len(),
                tools = tools.len(),
                streaming = stream_callback.is_some(),
                attempt = attempt,
                max_retries = self.max_retries,
                "Calling LLM API"
            );

            // `as_deref_mut` re-borrows the callback each iteration so the
            // Option is not consumed by the first attempt.
            match self
                .call_once(messages, tools, stream_callback.as_deref_mut())
                .await
            {
                Ok(message) => return Ok(message),
                Err(e) if attempt > self.max_retries => {
                    tracing::debug!("Max retries exceeded");
                    return Err(e);
                }
                Err(e) => {
                    tracing::debug!("API call failed, retrying: {}", e);
                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
                        .await;
                }
            }
        }
    }
}
291
292impl OpenAIProvider {
293 async fn call_once(
294 &self,
295 messages: &[Message],
296 tools: &[ToolDefinition],
297 stream_callback: Option<&mut super::StreamCallback>,
298 ) -> anyhow::Result<Message> {
299 let request = ChatRequest {
300 model: self.model.clone(),
301 messages: messages.to_vec(),
302 tools: tools.to_vec(),
303 temperature: self.temperature,
304 max_tokens: self.max_tokens,
305 stream: if stream_callback.is_some() {
306 Some(true)
307 } else {
308 None
309 },
310 };
311
312 let response = self
313 .client
314 .post(format!("{}/chat/completions", self.base_url))
315 .header("Authorization", format!("Bearer {}", self.api_key))
316 .header("Content-Type", "application/json")
317 .headers(self.custom_headers.clone())
318 .json(&request)
319 .send()
320 .await?;
321
322 let status = response.status();
323 tracing::trace!("LLM API response status: {}", status);
324
325 if !status.is_success() {
326 let body = response.text().await?;
327 tracing::debug!("LLM API error: status={}, body={}", status, body);
328 anyhow::bail!("API error ({}): {}", status, body);
329 }
330
331 if let Some(callback) = stream_callback {
332 self.handle_stream(response, callback).await
333 } else {
334 let body = response.text().await?;
335 let chat_response: ChatResponse = serde_json::from_str(&body)
336 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}. Body: {}", e, body))?;
337 tracing::debug!("LLM API call completed successfully");
338 Ok(chat_response.choices[0].message.clone())
339 }
340 }
341
    /// Consumes an SSE response body, forwarding each content delta to
    /// `callback` and accumulating the deltas into one assistant `Message`.
    async fn handle_stream(
        &self,
        response: reqwest::Response,
        callback: &mut super::StreamCallback,
    ) -> anyhow::Result<Message> {
        use futures::TryStreamExt;

        // Network chunks can split an SSE line anywhere, so raw bytes are
        // buffered and only complete (newline-terminated) lines are parsed.
        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        let mut content = String::new();
        let mut tool_calls = Vec::new();

        while let Some(chunk) = stream.try_next().await? {
            buffer.push_str(&String::from_utf8_lossy(&chunk));

            // Drain every complete line currently in the buffer.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer.drain(..=line_end);

                if let Some(data) = line.strip_prefix("data: ") {
                    // End-of-stream sentinel. Note this only breaks the inner
                    // line loop; the outer loop ends when the connection
                    // closes, which follows immediately after [DONE].
                    if data == "[DONE]" {
                        break;
                    }

                    // Unparseable payloads are skipped silently (best effort).
                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        if let Some(choice) = chunk.choices.first() {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                                callback(delta_content.clone());
                            }

                            // NOTE(review): deltas are appended verbatim. If
                            // the API streams a single tool call as several
                            // partial fragments, this collects fragments
                            // rather than merging them by index — confirm
                            // against crate::types::ToolCall's shape.
                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                tool_calls.extend(delta_tool_calls.clone());
                            }
                        }
                    }
                }
            }
        }

        tracing::debug!("Streaming completed, total length: {}", content.len());
        Ok(Message::Assistant {
            content,
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
        })
    }
392}