1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
/// LLM client for the OpenAI chat-completions API, also usable against
/// OpenAI-compatible servers via [`OpenAiClient::with_base_url`].
pub struct OpenAiClient {
    // API key, wrapped so it is not accidentally logged; exposed only when
    // building the Authorization header.
    pub(crate) api_key: SecretString,
    // Model identifier sent with every request (e.g. "gpt-4o").
    pub(crate) model: String,
    // Scheme+host prefix; "/v1/chat/completions" is appended per request.
    pub(crate) base_url: String,
    // Optional sampling temperature; omitted from the request when None.
    pub(crate) temperature: Option<f32>,
    // Optional completion-token cap; omitted from the request when None.
    pub(crate) max_tokens: Option<usize>,
    // Pluggable HTTP transport (injectable for testing).
    pub(crate) http: Arc<dyn HttpClient>,
    // Retry policy applied to transient HTTP failures.
    pub(crate) retry_config: RetryConfig,
}
25
26impl OpenAiClient {
27 pub fn new(api_key: String, model: String) -> Self {
28 Self {
29 api_key: SecretString::new(api_key),
30 model,
31 base_url: "https://api.openai.com".to_string(),
32 temperature: None,
33 max_tokens: None,
34 http: default_http_client(),
35 retry_config: RetryConfig::default(),
36 }
37 }
38
39 pub fn with_base_url(mut self, base_url: String) -> Self {
40 self.base_url = normalize_base_url(&base_url);
41 self
42 }
43
44 pub fn with_temperature(mut self, temperature: f32) -> Self {
45 self.temperature = Some(temperature);
46 self
47 }
48
49 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
50 self.max_tokens = Some(max_tokens);
51 self
52 }
53
54 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
55 self.retry_config = retry_config;
56 self
57 }
58
59 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
60 self.http = http;
61 self
62 }
63
64 pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
65 messages
66 .iter()
67 .map(|msg| {
68 let content: serde_json::Value = if msg.content.len() == 1 {
69 match &msg.content[0] {
70 ContentBlock::Text { text } => serde_json::json!(text),
71 ContentBlock::ToolResult {
72 tool_use_id,
73 content,
74 ..
75 } => {
76 let content_str = match content {
77 ToolResultContentField::Text(s) => s.clone(),
78 ToolResultContentField::Blocks(blocks) => blocks
79 .iter()
80 .filter_map(|b| {
81 if let ToolResultContent::Text { text } = b {
82 Some(text.clone())
83 } else {
84 None
85 }
86 })
87 .collect::<Vec<_>>()
88 .join("\n"),
89 };
90 return serde_json::json!({
91 "role": "tool",
92 "tool_call_id": tool_use_id,
93 "content": content_str,
94 });
95 }
96 _ => serde_json::json!(""),
97 }
98 } else {
99 serde_json::json!(msg
100 .content
101 .iter()
102 .map(|block| {
103 match block {
104 ContentBlock::Text { text } => serde_json::json!({
105 "type": "text",
106 "text": text,
107 }),
108 ContentBlock::Image { source } => serde_json::json!({
109 "type": "image_url",
110 "image_url": {
111 "url": format!(
112 "data:{};base64,{}",
113 source.media_type, source.data
114 ),
115 }
116 }),
117 ContentBlock::ToolUse { id, name, input } => serde_json::json!({
118 "type": "function",
119 "id": id,
120 "function": {
121 "name": name,
122 "arguments": input.to_string(),
123 }
124 }),
125 _ => serde_json::json!({}),
126 }
127 })
128 .collect::<Vec<_>>())
129 };
130
131 if msg.role == "assistant" {
134 let rc = msg.reasoning_content.as_deref().unwrap_or("");
135 let tool_calls: Vec<_> = msg.tool_calls();
136 if !tool_calls.is_empty() {
137 return serde_json::json!({
138 "role": "assistant",
139 "content": msg.text(),
140 "reasoning_content": rc,
141 "tool_calls": tool_calls.iter().map(|tc| {
142 serde_json::json!({
143 "id": tc.id,
144 "type": "function",
145 "function": {
146 "name": tc.name,
147 "arguments": tc.args.to_string(),
148 }
149 })
150 }).collect::<Vec<_>>(),
151 });
152 }
153 return serde_json::json!({
154 "role": "assistant",
155 "content": content,
156 "reasoning_content": rc,
157 });
158 }
159
160 serde_json::json!({
161 "role": msg.role,
162 "content": content,
163 })
164 })
165 .collect()
166 }
167
168 pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
169 tools
170 .iter()
171 .map(|t| {
172 serde_json::json!({
173 "type": "function",
174 "function": {
175 "name": t.name,
176 "description": t.description,
177 "parameters": t.parameters,
178 }
179 })
180 })
181 .collect()
182 }
183}
184
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Sends a non-streaming chat-completion request and returns the parsed
    /// response.
    ///
    /// The optional `system` prompt is prepended as a `role: "system"`
    /// message. Retryable HTTP statuses are retried per `self.retry_config`;
    /// transport errors and non-retryable statuses fail immediately. Token
    /// usage is recorded via telemetry before returning.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let mut openai_messages = Vec::new();

            // System prompt goes first, as its own message.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            // Optional sampling parameters are only included when configured.
            if let Some(temp) = self.temperature {
                request["temperature"] = serde_json::json!(temp);
            }
            if let Some(max) = self.max_tokens {
                request["max_tokens"] = serde_json::json!(max);
            }

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry loop: success yields the raw body; retryable statuses are
            // handed back to the retry machinery (no Retry-After information
            // is available on this non-streaming transport); anything else is
            // fatal.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            // Fall back to 500 if the transport reports a
                            // status code outside the valid range.
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            // Only the first choice is used (requests do not set n > 1).
            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            let reasoning_content = choice.message.reasoning_content;

            let text_content = choice.message.content;

            // Emit a text block only when the model produced non-empty text.
            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            // Tool-call arguments arrive as a JSON-encoded string; a parse
            // failure is logged and degrades to a default (null) value rather
            // than failing the whole response.
            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // cached_tokens is optional on the wire; absent -> None.
                    cache_read_tokens: parsed
                        .usage
                        .prompt_tokens_details
                        .as_ref()
                        .and_then(|d| d.cached_tokens),
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Sends a streaming chat-completion request and returns a channel of
    /// [`StreamEvent`]s.
    ///
    /// Deltas (text, reasoning, tool-call fragments) are forwarded as they
    /// arrive; on the `[DONE]` sentinel a final `StreamEvent::Done` carrying
    /// the assembled message, usage, and stop reason is emitted. Only the
    /// initial request is retried — once the byte stream is open, stream
    /// errors terminate the background task.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let mut openai_messages = Vec::new();

            // System prompt goes first, as its own message.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            // stream_options.include_usage asks the server to append a final
            // usage chunk before [DONE].
            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            // Optional sampling parameters are only included when configured.
            if let Some(temp) = self.temperature {
                request["temperature"] = serde_json::json!(temp);
            }
            if let Some(max) = self.max_tokens {
                request["max_tokens"] = serde_json::json!(max);
            }

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry establishing the stream; unlike the non-streaming path,
            // a Retry-After header is honored here when present.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            // Parse the SSE byte stream on a background task; the receiver
            // half of the channel is returned to the caller immediately.
            let mut stream = streaming_resp.byte_stream;
            tokio::spawn(async move {
                // Unparsed bytes carried across chunk boundaries.
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                // Incremental tool calls keyed by their stream index; the
                // tuple accumulates (id, name, argument JSON fragments).
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Mid-stream transport error: stop parsing; the
                            // channel closes when tx is dropped.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line ("\n\n");
                    // consume each complete event and keep trailing partial
                    // data buffered.
                    // NOTE(review): a CRLF-framed stream ("\r\n\r\n") would
                    // not match this delimiter — presumably fine for OpenAI,
                    // but verify against other compatible servers.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        // Also remove the two-byte separator itself.
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // Finalize: flush accumulated text and
                                    // tool calls into content blocks, record
                                    // telemetry, and emit the Done event.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            // Malformed argument JSON is
                                            // logged and degrades to a
                                            // default (null) value.
                                            input: serde_json::from_str(args).unwrap_or_else(|e| {
                                                tracing::warn!(
                                                    "Failed to parse tool arguments JSON for tool '{}': {}",
                                                    name, e
                                                );
                                                serde_json::Value::default()
                                            }),
                                        });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    continue;
                                }

                                // Unparseable chunks are silently skipped.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    // Usage arrives in a dedicated chunk when
                                    // stream_options.include_usage is set.
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                        usage.cache_read_tokens = u
                                            .prompt_tokens_details
                                            .as_ref()
                                            .and_then(|d| d.cached_tokens);
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            // Forward text deltas immediately
                                            // and accumulate for the final
                                            // message.
                                            if let Some(content) = delta.content {
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    // id/name arrive once in
                                                    // the first fragment;
                                                    // arguments stream in
                                                    // pieces and are
                                                    // concatenated.
                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
}
557
// --- Wire types for the non-streaming chat-completions response. ---
// Only the fields this client reads are modeled; serde ignores any other
// fields present in the server JSON.

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    pub(crate) choices: Vec<OpenAiChoice>,
    pub(crate) usage: OpenAiUsage,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    // e.g. "stop", "length", "tool_calls"; absent on some servers.
    pub(crate) finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    // Non-standard reasoning channel (e.g. DeepSeek); usually absent.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    // JSON-encoded argument object, delivered as a string by the API.
    pub(crate) arguments: String,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
    // Not all OpenAI-compatible servers send this breakdown.
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    // Prompt tokens served from the provider-side prompt cache.
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
605
// --- Wire types for streaming (SSE) chat-completion chunks. ---
// Every field is optional or defaultable because each chunk carries only a
// fragment of the final message.

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    // Present only on the final usage chunk when
    // stream_options.include_usage was requested.
    pub(crate) usage: Option<OpenAiUsage>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    pub(crate) finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    // Non-standard reasoning channel (e.g. DeepSeek); usually absent.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    // Position of this tool call within the message; used to group the
    // fragments of one call across chunks.
    pub(crate) index: usize,
    // Sent only on the first fragment of a call.
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    // Sent only on the first fragment of a call.
    pub(crate) name: Option<String>,
    // Partial argument JSON; fragments are concatenated by the consumer.
    pub(crate) arguments: Option<String>,
}