1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::mpsc;
15
/// Client for the OpenAI Chat Completions API, also usable against
/// OpenAI-compatible providers via a custom base URL / path / provider name.
pub struct OpenAiClient {
    // Label reported in response metadata (defaults to "openai").
    pub(crate) provider_name: String,
    // Bearer token sent in the `Authorization` header.
    pub(crate) api_key: SecretString,
    // Model identifier sent in the request body (e.g. "gpt-4o").
    pub(crate) model: String,
    // Scheme + host (+ optional prefix), already normalized (no trailing slash).
    pub(crate) base_url: String,
    // Endpoint path appended to `base_url`; always starts with '/'.
    pub(crate) chat_completions_path: String,
    // Optional sampling temperature; omitted from the request when `None`.
    pub(crate) temperature: Option<f32>,
    // Optional `max_tokens` cap; omitted from the request when `None`.
    pub(crate) max_tokens: Option<usize>,
    // Injectable HTTP transport (mockable in tests).
    pub(crate) http: Arc<dyn HttpClient>,
    // Backoff/retry policy for transient HTTP failures.
    pub(crate) retry_config: RetryConfig,
}
28
29impl OpenAiClient {
30 pub(crate) fn parse_tool_arguments(tool_name: &str, arguments: &str) -> serde_json::Value {
31 if arguments.trim().is_empty() {
32 return serde_json::Value::Object(Default::default());
33 }
34
35 serde_json::from_str(arguments).unwrap_or_else(|e| {
36 tracing::warn!(
37 "Failed to parse tool arguments JSON for tool '{}': {}",
38 tool_name,
39 e
40 );
41 serde_json::json!({
42 "__parse_error": format!(
43 "Malformed tool arguments: {}. Raw input: {}",
44 e, arguments
45 )
46 })
47 })
48 }
49
50 pub fn new(api_key: String, model: String) -> Self {
51 Self {
52 provider_name: "openai".to_string(),
53 api_key: SecretString::new(api_key),
54 model,
55 base_url: "https://api.openai.com".to_string(),
56 chat_completions_path: "/v1/chat/completions".to_string(),
57 temperature: None,
58 max_tokens: None,
59 http: default_http_client(),
60 retry_config: RetryConfig::default(),
61 }
62 }
63
64 pub fn with_base_url(mut self, base_url: String) -> Self {
65 self.base_url = normalize_base_url(&base_url);
66 self
67 }
68
69 pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
70 self.provider_name = provider_name.into();
71 self
72 }
73
74 pub fn with_chat_completions_path(mut self, path: impl Into<String>) -> Self {
75 let path = path.into();
76 self.chat_completions_path = if path.starts_with('/') {
77 path
78 } else {
79 format!("/{}", path)
80 };
81 self
82 }
83
84 pub fn with_temperature(mut self, temperature: f32) -> Self {
85 self.temperature = Some(temperature);
86 self
87 }
88
89 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
90 self.max_tokens = Some(max_tokens);
91 self
92 }
93
94 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
95 self.retry_config = retry_config;
96 self
97 }
98
99 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
100 self.http = http;
101 self
102 }
103
104 pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
105 messages
106 .iter()
107 .map(|msg| {
108 let content: serde_json::Value = if msg.content.len() == 1 {
109 match &msg.content[0] {
110 ContentBlock::Text { text } => serde_json::json!(text),
111 ContentBlock::ToolResult {
112 tool_use_id,
113 content,
114 ..
115 } => {
116 let content_str = match content {
117 ToolResultContentField::Text(s) => s.clone(),
118 ToolResultContentField::Blocks(blocks) => blocks
119 .iter()
120 .filter_map(|b| {
121 if let ToolResultContent::Text { text } = b {
122 Some(text.clone())
123 } else {
124 None
125 }
126 })
127 .collect::<Vec<_>>()
128 .join("\n"),
129 };
130 return serde_json::json!({
131 "role": "tool",
132 "tool_call_id": tool_use_id,
133 "content": content_str,
134 });
135 }
136 _ => serde_json::json!(""),
137 }
138 } else {
139 serde_json::json!(msg
140 .content
141 .iter()
142 .map(|block| {
143 match block {
144 ContentBlock::Text { text } => serde_json::json!({
145 "type": "text",
146 "text": text,
147 }),
148 ContentBlock::Image { source } => serde_json::json!({
149 "type": "image_url",
150 "image_url": {
151 "url": format!(
152 "data:{};base64,{}",
153 source.media_type, source.data
154 ),
155 }
156 }),
157 ContentBlock::ToolUse { id, name, input } => serde_json::json!({
158 "type": "function",
159 "id": id,
160 "function": {
161 "name": name,
162 "arguments": input.to_string(),
163 }
164 }),
165 _ => serde_json::json!({}),
166 }
167 })
168 .collect::<Vec<_>>())
169 };
170
171 if msg.role == "assistant" {
174 let rc = msg.reasoning_content.as_deref().unwrap_or("");
175 let tool_calls: Vec<_> = msg.tool_calls();
176 if !tool_calls.is_empty() {
177 return serde_json::json!({
178 "role": "assistant",
179 "content": msg.text(),
180 "reasoning_content": rc,
181 "tool_calls": tool_calls.iter().map(|tc| {
182 serde_json::json!({
183 "id": tc.id,
184 "type": "function",
185 "function": {
186 "name": tc.name,
187 "arguments": tc.args.to_string(),
188 }
189 })
190 }).collect::<Vec<_>>(),
191 });
192 }
193 return serde_json::json!({
194 "role": "assistant",
195 "content": content,
196 "reasoning_content": rc,
197 });
198 }
199
200 serde_json::json!({
201 "role": msg.role,
202 "content": content,
203 })
204 })
205 .collect()
206 }
207
208 pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
209 tools
210 .iter()
211 .map(|t| {
212 serde_json::json!({
213 "type": "function",
214 "function": {
215 "name": t.name,
216 "description": t.description,
217 "parameters": t.parameters,
218 }
219 })
220 })
221 .collect()
222 }
223}
224
225#[async_trait]
226impl LlmClient for OpenAiClient {
227 async fn complete(
228 &self,
229 messages: &[Message],
230 system: Option<&str>,
231 tools: &[ToolDefinition],
232 ) -> Result<LlmResponse> {
233 {
234 let request_started_at = Instant::now();
235 let mut openai_messages = Vec::new();
236
237 if let Some(sys) = system {
238 openai_messages.push(serde_json::json!({
239 "role": "system",
240 "content": sys,
241 }));
242 }
243
244 openai_messages.extend(self.convert_messages(messages));
245
246 let mut request = serde_json::json!({
247 "model": self.model,
248 "messages": openai_messages,
249 });
250
251 if let Some(temp) = self.temperature {
252 request["temperature"] = serde_json::json!(temp);
253 }
254 if let Some(max) = self.max_tokens {
255 request["max_tokens"] = serde_json::json!(max);
256 }
257
258 if !tools.is_empty() {
259 request["tools"] = serde_json::json!(self.convert_tools(tools));
260 }
261
262 let url = format!("{}{}", self.base_url, self.chat_completions_path);
263 let auth_header = format!("Bearer {}", self.api_key.expose());
264 let headers = vec![("Authorization", auth_header.as_str())];
265
266 let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
267 let http = &self.http;
268 let url = &url;
269 let headers = headers.clone();
270 let request = &request;
271 async move {
272 match http.post(url, headers, request).await {
273 Ok(resp) => {
274 let status = reqwest::StatusCode::from_u16(resp.status)
275 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
276 if status.is_success() {
277 AttemptOutcome::Success(resp.body)
278 } else if self.retry_config.is_retryable_status(status) {
279 AttemptOutcome::Retryable {
280 status,
281 body: resp.body,
282 retry_after: None,
283 }
284 } else {
285 AttemptOutcome::Fatal(anyhow::anyhow!(
286 "OpenAI API error at {} ({}): {}",
287 url,
288 status,
289 resp.body
290 ))
291 }
292 }
293 Err(e) => {
294 eprintln!("[DEBUG] HTTP error: {:?}", e);
295 AttemptOutcome::Fatal(e)
296 }
297 }
298 }
299 })
300 .await?;
301
302 let parsed: OpenAiResponse =
303 serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
304
305 let choice = parsed.choices.into_iter().next().context("No choices")?;
306
307 let mut content = vec![];
308
309 let reasoning_content = choice.message.reasoning_content;
310
311 let text_content = choice.message.content;
312
313 if let Some(text) = text_content {
314 if !text.is_empty() {
315 content.push(ContentBlock::Text { text });
316 }
317 }
318
319 if let Some(tool_calls) = choice.message.tool_calls {
320 for tc in tool_calls {
321 content.push(ContentBlock::ToolUse {
322 id: tc.id,
323 name: tc.function.name.clone(),
324 input: Self::parse_tool_arguments(
325 &tc.function.name,
326 &tc.function.arguments,
327 ),
328 });
329 }
330 }
331
332 let llm_response = LlmResponse {
333 message: Message {
334 role: "assistant".to_string(),
335 content,
336 reasoning_content,
337 },
338 usage: TokenUsage {
339 prompt_tokens: parsed.usage.prompt_tokens,
340 completion_tokens: parsed.usage.completion_tokens,
341 total_tokens: parsed.usage.total_tokens,
342 cache_read_tokens: parsed
343 .usage
344 .prompt_tokens_details
345 .as_ref()
346 .and_then(|d| d.cached_tokens),
347 cache_write_tokens: None,
348 },
349 stop_reason: choice.finish_reason,
350 meta: Some(LlmResponseMeta {
351 provider: Some(self.provider_name.clone()),
352 request_model: Some(self.model.clone()),
353 request_url: Some(url.clone()),
354 response_id: parsed.id,
355 response_model: parsed.model,
356 response_object: parsed.object,
357 first_token_ms: None,
358 duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
359 }),
360 };
361
362 crate::telemetry::record_llm_usage(
363 llm_response.usage.prompt_tokens,
364 llm_response.usage.completion_tokens,
365 llm_response.usage.total_tokens,
366 llm_response.stop_reason.as_deref(),
367 );
368
369 Ok(llm_response)
370 }
371 }
372
373 async fn complete_streaming(
374 &self,
375 messages: &[Message],
376 system: Option<&str>,
377 tools: &[ToolDefinition],
378 ) -> Result<mpsc::Receiver<StreamEvent>> {
379 {
380 let request_started_at = Instant::now();
381 let mut openai_messages = Vec::new();
382
383 if let Some(sys) = system {
384 openai_messages.push(serde_json::json!({
385 "role": "system",
386 "content": sys,
387 }));
388 }
389
390 openai_messages.extend(self.convert_messages(messages));
391
392 let mut request = serde_json::json!({
393 "model": self.model,
394 "messages": openai_messages,
395 "stream": true,
396 "stream_options": { "include_usage": true },
397 });
398
399 if let Some(temp) = self.temperature {
400 request["temperature"] = serde_json::json!(temp);
401 }
402 if let Some(max) = self.max_tokens {
403 request["max_tokens"] = serde_json::json!(max);
404 }
405
406 if !tools.is_empty() {
407 request["tools"] = serde_json::json!(self.convert_tools(tools));
408 }
409
410 let url = format!("{}{}", self.base_url, self.chat_completions_path);
411 let auth_header = format!("Bearer {}", self.api_key.expose());
412 let headers = vec![("Authorization", auth_header.as_str())];
413
414 let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
415 let http = &self.http;
416 let url = &url;
417 let headers = headers.clone();
418 let request = &request;
419 async move {
420 match http.post_streaming(url, headers, request).await {
421 Ok(resp) => {
422 let status = reqwest::StatusCode::from_u16(resp.status)
423 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
424 if status.is_success() {
425 AttemptOutcome::Success(resp)
426 } else {
427 let retry_after = resp
428 .retry_after
429 .as_deref()
430 .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
431 if self.retry_config.is_retryable_status(status) {
432 AttemptOutcome::Retryable {
433 status,
434 body: resp.error_body,
435 retry_after,
436 }
437 } else {
438 AttemptOutcome::Fatal(anyhow::anyhow!(
439 "OpenAI API error at {} ({}): {}",
440 url,
441 status,
442 resp.error_body
443 ))
444 }
445 }
446 }
447 Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
448 "Failed to send streaming request: {}",
449 e
450 )),
451 }
452 }
453 })
454 .await?;
455
456 let (tx, rx) = mpsc::channel(100);
457
458 let mut stream = streaming_resp.byte_stream;
459 let provider_name = self.provider_name.clone();
460 let request_model = self.model.clone();
461 let request_url = url.clone();
462 tokio::spawn(async move {
463 let mut buffer = String::new();
464 let mut content_blocks: Vec<ContentBlock> = Vec::new();
465 let mut text_content = String::new();
466 let mut reasoning_content_accum = String::new();
467 let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
468 std::collections::BTreeMap::new();
469 let mut usage = TokenUsage::default();
470 let mut finish_reason = None;
471 let mut response_id = None;
472 let mut response_model = None;
473 let mut response_object = None;
474 let mut first_token_ms = None;
475
476 while let Some(chunk_result) = stream.next().await {
477 let chunk = match chunk_result {
478 Ok(c) => c,
479 Err(e) => {
480 tracing::error!("Stream error: {}", e);
481 break;
482 }
483 };
484
485 buffer.push_str(&String::from_utf8_lossy(&chunk));
486
487 while let Some(event_end) = buffer.find("\n\n") {
488 let event_data: String = buffer.drain(..event_end).collect();
489 buffer.drain(..2);
490
491 for line in event_data.lines() {
492 if let Some(data) = line.strip_prefix("data: ") {
493 if data == "[DONE]" {
494 if !text_content.is_empty() {
495 content_blocks.push(ContentBlock::Text {
496 text: text_content.clone(),
497 });
498 }
499 for (_, (id, name, args)) in tool_calls.iter() {
500 content_blocks.push(ContentBlock::ToolUse {
501 id: id.clone(),
502 name: name.clone(),
503 input: Self::parse_tool_arguments(name, args),
504 });
505 }
506 tool_calls.clear();
507 crate::telemetry::record_llm_usage(
508 usage.prompt_tokens,
509 usage.completion_tokens,
510 usage.total_tokens,
511 finish_reason.as_deref(),
512 );
513 let response = LlmResponse {
514 message: Message {
515 role: "assistant".to_string(),
516 content: std::mem::take(&mut content_blocks),
517 reasoning_content: if reasoning_content_accum.is_empty()
518 {
519 None
520 } else {
521 Some(std::mem::take(&mut reasoning_content_accum))
522 },
523 },
524 usage: usage.clone(),
525 stop_reason: std::mem::take(&mut finish_reason),
526 meta: Some(LlmResponseMeta {
527 provider: Some(provider_name.clone()),
528 request_model: Some(request_model.clone()),
529 request_url: Some(request_url.clone()),
530 response_id: response_id.clone(),
531 response_model: response_model.clone(),
532 response_object: response_object.clone(),
533 first_token_ms,
534 duration_ms: Some(
535 request_started_at.elapsed().as_millis() as u64,
536 ),
537 }),
538 };
539 let _ = tx.send(StreamEvent::Done(response)).await;
540 continue;
541 }
542
543 if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
544 if response_id.is_none() {
545 response_id = event.id.clone();
546 }
547 if response_model.is_none() {
548 response_model = event.model.clone();
549 }
550 if response_object.is_none() {
551 response_object = event.object.clone();
552 }
553 if let Some(u) = event.usage {
554 usage.prompt_tokens = u.prompt_tokens;
555 usage.completion_tokens = u.completion_tokens;
556 usage.total_tokens = u.total_tokens;
557 usage.cache_read_tokens = u
558 .prompt_tokens_details
559 .as_ref()
560 .and_then(|d| d.cached_tokens);
561 }
562
563 if let Some(choice) = event.choices.into_iter().next() {
564 if let Some(reason) = choice.finish_reason {
565 finish_reason = Some(reason);
566 }
567
568 if let Some(delta) = choice.delta {
569 if let Some(ref rc) = delta.reasoning_content {
570 reasoning_content_accum.push_str(rc);
571 }
572
573 if let Some(content) = delta.content {
574 if first_token_ms.is_none() {
575 first_token_ms = Some(
576 request_started_at.elapsed().as_millis()
577 as u64,
578 );
579 }
580 text_content.push_str(&content);
581 let _ =
582 tx.send(StreamEvent::TextDelta(content)).await;
583 }
584
585 if let Some(tcs) = delta.tool_calls {
586 for tc in tcs {
587 let entry = tool_calls
588 .entry(tc.index)
589 .or_insert_with(|| {
590 (
591 String::new(),
592 String::new(),
593 String::new(),
594 )
595 });
596
597 if let Some(id) = tc.id {
598 entry.0 = id;
599 }
600 if let Some(func) = tc.function {
601 if let Some(name) = func.name {
602 if first_token_ms.is_none() {
603 first_token_ms = Some(
604 request_started_at
605 .elapsed()
606 .as_millis()
607 as u64,
608 );
609 }
610 entry.1 = name.clone();
611 let _ = tx
612 .send(StreamEvent::ToolUseStart {
613 id: entry.0.clone(),
614 name,
615 })
616 .await;
617 }
618 if let Some(args) = func.arguments {
619 entry.2.push_str(&args);
620 let _ = tx
621 .send(
622 StreamEvent::ToolUseInputDelta(
623 args,
624 ),
625 )
626 .await;
627 }
628 }
629 }
630 }
631 }
632 }
633 }
634 }
635 }
636 }
637 }
638 });
639
640 Ok(rx)
641 }
642 }
643}
644
/// Top-level body of a non-streaming chat-completions response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    // Response id (e.g. "chatcmpl-…"); optional for compatible providers.
    #[serde(default)]
    pub(crate) id: Option<String>,
    // Object type, typically "chat.completion".
    #[serde(default)]
    pub(crate) object: Option<String>,
    // Model the server actually used (may differ from the request model).
    #[serde(default)]
    pub(crate) model: Option<String>,
    // Completion choices; only the first is consumed.
    pub(crate) choices: Vec<OpenAiChoice>,
    // Token accounting; required in non-streaming responses.
    pub(crate) usage: OpenAiUsage,
}
657
/// One completion choice from a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    // The assistant message produced for this choice.
    pub(crate) message: OpenAiMessage,
    // Why generation stopped ("stop", "length", "tool_calls", …).
    pub(crate) finish_reason: Option<String>,
}
663
/// Assistant message payload within a choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    // Non-standard field emitted by some reasoning-capable providers.
    pub(crate) reasoning_content: Option<String>,
    // Text content; may be null when the message is tool calls only.
    pub(crate) content: Option<String>,
    // Function/tool calls requested by the model, if any.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
670
/// A complete (non-streamed) tool call from the model.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    // Call id; echoed back as `tool_call_id` in the tool-result message.
    pub(crate) id: String,
    // Function name and raw JSON arguments.
    pub(crate) function: OpenAiFunction,
}
676
/// Function payload of a tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    // Tool/function name.
    pub(crate) name: String,
    // Arguments as a raw JSON string (parsed by `parse_tool_arguments`).
    pub(crate) arguments: String,
}
682
/// Token accounting attached to a response (or final streaming chunk).
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    // Tokens in the prompt/input.
    pub(crate) prompt_tokens: usize,
    // Tokens generated in the completion.
    pub(crate) completion_tokens: usize,
    // prompt + completion.
    pub(crate) total_tokens: usize,
    // Breakdown including cached prompt tokens; absent on older providers.
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}
692
/// Optional breakdown of prompt tokens.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    // Prompt tokens served from the provider's prompt cache.
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
698
699#[derive(Debug, Deserialize)]
701pub(crate) struct OpenAiStreamChunk {
702 #[serde(default)]
703 pub(crate) id: Option<String>,
704 #[serde(default)]
705 pub(crate) object: Option<String>,
706 #[serde(default)]
707 pub(crate) model: Option<String>,
708 pub(crate) choices: Vec<OpenAiStreamChoice>,
709 pub(crate) usage: Option<OpenAiUsage>,
710}
711
/// One choice entry inside a streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    // Incremental delta for this choice; may be absent on the final chunk.
    pub(crate) delta: Option<OpenAiDelta>,
    // Set once, on the chunk where generation for this choice ends.
    pub(crate) finish_reason: Option<String>,
}
717
/// Incremental message delta within a streaming choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    // Non-standard reasoning text fragment from some providers.
    pub(crate) reasoning_content: Option<String>,
    // Text fragment to append to the assistant message.
    pub(crate) content: Option<String>,
    // Tool-call fragments, keyed by index in `OpenAiToolCallDelta`.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
724
/// Fragment of a streamed tool call; fragments with the same `index`
/// belong to the same call and are accumulated by the stream reader.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    // Position of the tool call this fragment belongs to.
    pub(crate) index: usize,
    // Call id; usually only on the first fragment of a call.
    pub(crate) id: Option<String>,
    // Name and/or argument fragment.
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
731
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    // Function name; usually only on the first fragment.
    pub(crate) name: Option<String>,
    // Argument JSON fragment to append to the accumulated arguments string.
    pub(crate) arguments: Option<String>,
}