1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
/// Client for the OpenAI Chat Completions API (and OpenAI-compatible
/// servers reachable via [`Self::with_base_url`]).
pub struct OpenAiClient {
    /// API key, wrapped so it is not logged accidentally; exposed only
    /// when building the `Authorization` header.
    pub(crate) api_key: SecretString,
    /// Model identifier sent with every request (e.g. "gpt-4o").
    pub(crate) model: String,
    /// Base URL without the `/v1/chat/completions` path; defaults to the
    /// public OpenAI endpoint.
    pub(crate) base_url: String,
    /// Optional sampling temperature; omitted from the request when `None`
    /// so the server-side default applies.
    pub(crate) temperature: Option<f32>,
    /// Optional `max_tokens` cap; omitted from the request when `None`.
    pub(crate) max_tokens: Option<usize>,
    /// Pluggable HTTP transport (replaceable via `with_http_client`).
    pub(crate) http: Arc<dyn HttpClient>,
    /// Policy used to classify and retry transient HTTP failures.
    pub(crate) retry_config: RetryConfig,
}
25
impl OpenAiClient {
    /// Creates a client for `model` using the default OpenAI endpoint,
    /// default HTTP transport, and default retry policy.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            temperature: None,
            max_tokens: None,
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Overrides the API base URL (proxies, OpenAI-compatible servers).
    /// The value is normalized by `normalize_base_url` before being stored.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Sets the sampling temperature to include in requests.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the `max_tokens` completion cap to include in requests.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Replaces the retry policy used for transient HTTP failures.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Replaces the HTTP transport (useful for injecting a mock in tests).
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Converts provider-agnostic [`Message`]s into OpenAI chat-completion
    /// message objects.
    ///
    /// Mapping rules:
    /// - A single `ToolResult` block becomes a whole `role: "tool"` message
    ///   (early return), with block-style content flattened to text.
    /// - A single `Text` block becomes a plain string `content`.
    /// - Multiple blocks become an array of typed content parts.
    /// - Assistant messages additionally carry `reasoning_content` and, when
    ///   the message contains tool calls, a `tool_calls` array.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                // Compute the `content` field first; tool results return
                // early because they map to an entire message, not a field.
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Flatten block-style results to newline-joined
                            // text; non-text blocks are dropped.
                            let content_str = match content {
                                ToolResultContentField::Text(s) => s.clone(),
                                ToolResultContentField::Blocks(blocks) => blocks
                                    .iter()
                                    .filter_map(|b| {
                                        if let ToolResultContent::Text { text } = b {
                                            Some(text.clone())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n"),
                            };
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content_str,
                            });
                        }
                        // Other single-block shapes (e.g. a lone ToolUse)
                        // have no string form here; assistant tool calls are
                        // emitted below via the `tool_calls` field instead.
                        _ => serde_json::json!(""),
                    }
                } else {
                    // Multi-part content: emit OpenAI typed content parts.
                    // NOTE(review): an empty `content` vec yields an empty
                    // array here — confirm the API accepts that shape.
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                // Images are inlined as base64 data URLs.
                                ContentBlock::Image { source } => serde_json::json!({
                                    "type": "image_url",
                                    "image_url": {
                                        "url": format!(
                                            "data:{};base64,{}",
                                            source.media_type, source.data
                                        ),
                                    }
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                if msg.role == "assistant" {
                    // NOTE(review): `reasoning_content` is always sent for
                    // assistant messages (empty string when absent) — some
                    // OpenAI-compatible servers use this field; confirm
                    // strict OpenAI endpoints tolerate it.
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Converts provider-agnostic tool definitions into the OpenAI `tools`
    /// array of `function` descriptors.
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}
184
185#[async_trait]
186impl LlmClient for OpenAiClient {
187 async fn complete(
188 &self,
189 messages: &[Message],
190 system: Option<&str>,
191 tools: &[ToolDefinition],
192 ) -> Result<LlmResponse> {
193 {
194 let mut openai_messages = Vec::new();
195
196 if let Some(sys) = system {
197 openai_messages.push(serde_json::json!({
198 "role": "system",
199 "content": sys,
200 }));
201 }
202
203 openai_messages.extend(self.convert_messages(messages));
204
205 let mut request = serde_json::json!({
206 "model": self.model,
207 "messages": openai_messages,
208 });
209
210 if let Some(temp) = self.temperature {
211 request["temperature"] = serde_json::json!(temp);
212 }
213 if let Some(max) = self.max_tokens {
214 request["max_tokens"] = serde_json::json!(max);
215 }
216
217 if !tools.is_empty() {
218 request["tools"] = serde_json::json!(self.convert_tools(tools));
219 }
220
221 let url = format!("{}/v1/chat/completions", self.base_url);
222 let auth_header = format!("Bearer {}", self.api_key.expose());
223 let headers = vec![("Authorization", auth_header.as_str())];
224
225 let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
226 let http = &self.http;
227 let url = &url;
228 let headers = headers.clone();
229 let request = &request;
230 async move {
231 match http.post(url, headers, request).await {
232 Ok(resp) => {
233 let status = reqwest::StatusCode::from_u16(resp.status)
234 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
235 if status.is_success() {
236 AttemptOutcome::Success(resp.body)
237 } else if self.retry_config.is_retryable_status(status) {
238 AttemptOutcome::Retryable {
239 status,
240 body: resp.body,
241 retry_after: None,
242 }
243 } else {
244 AttemptOutcome::Fatal(anyhow::anyhow!(
245 "OpenAI API error at {} ({}): {}",
246 url,
247 status,
248 resp.body
249 ))
250 }
251 }
252 Err(e) => {
253 eprintln!("[DEBUG] HTTP error: {:?}", e);
254 AttemptOutcome::Fatal(e)
255 }
256 }
257 }
258 })
259 .await?;
260
261 let parsed: OpenAiResponse =
262 serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
263
264 let choice = parsed.choices.into_iter().next().context("No choices")?;
265
266 let mut content = vec![];
267
268 let reasoning_content = choice.message.reasoning_content;
269
270 let text_content = choice.message.content;
271
272 if let Some(text) = text_content {
273 if !text.is_empty() {
274 content.push(ContentBlock::Text { text });
275 }
276 }
277
278 if let Some(tool_calls) = choice.message.tool_calls {
279 for tc in tool_calls {
280 content.push(ContentBlock::ToolUse {
281 id: tc.id,
282 name: tc.function.name.clone(),
283 input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
284 tracing::warn!(
285 "Failed to parse tool arguments JSON for tool '{}': {}",
286 tc.function.name,
287 e
288 );
289 serde_json::Value::default()
290 }),
291 });
292 }
293 }
294
295 let llm_response = LlmResponse {
296 message: Message {
297 role: "assistant".to_string(),
298 content,
299 reasoning_content,
300 },
301 usage: TokenUsage {
302 prompt_tokens: parsed.usage.prompt_tokens,
303 completion_tokens: parsed.usage.completion_tokens,
304 total_tokens: parsed.usage.total_tokens,
305 cache_read_tokens: parsed
306 .usage
307 .prompt_tokens_details
308 .as_ref()
309 .and_then(|d| d.cached_tokens),
310 cache_write_tokens: None,
311 },
312 stop_reason: choice.finish_reason,
313 };
314
315 crate::telemetry::record_llm_usage(
316 llm_response.usage.prompt_tokens,
317 llm_response.usage.completion_tokens,
318 llm_response.usage.total_tokens,
319 llm_response.stop_reason.as_deref(),
320 );
321
322 Ok(llm_response)
323 }
324 }
325
326 async fn complete_streaming(
327 &self,
328 messages: &[Message],
329 system: Option<&str>,
330 tools: &[ToolDefinition],
331 ) -> Result<mpsc::Receiver<StreamEvent>> {
332 {
333 let mut openai_messages = Vec::new();
334
335 if let Some(sys) = system {
336 openai_messages.push(serde_json::json!({
337 "role": "system",
338 "content": sys,
339 }));
340 }
341
342 openai_messages.extend(self.convert_messages(messages));
343
344 let mut request = serde_json::json!({
345 "model": self.model,
346 "messages": openai_messages,
347 "stream": true,
348 "stream_options": { "include_usage": true },
349 });
350
351 if let Some(temp) = self.temperature {
352 request["temperature"] = serde_json::json!(temp);
353 }
354 if let Some(max) = self.max_tokens {
355 request["max_tokens"] = serde_json::json!(max);
356 }
357
358 if !tools.is_empty() {
359 request["tools"] = serde_json::json!(self.convert_tools(tools));
360 }
361
362 let url = format!("{}/v1/chat/completions", self.base_url);
363 let auth_header = format!("Bearer {}", self.api_key.expose());
364 let headers = vec![("Authorization", auth_header.as_str())];
365
366 let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
367 let http = &self.http;
368 let url = &url;
369 let headers = headers.clone();
370 let request = &request;
371 async move {
372 match http.post_streaming(url, headers, request).await {
373 Ok(resp) => {
374 let status = reqwest::StatusCode::from_u16(resp.status)
375 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
376 if status.is_success() {
377 AttemptOutcome::Success(resp)
378 } else {
379 let retry_after = resp
380 .retry_after
381 .as_deref()
382 .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
383 if self.retry_config.is_retryable_status(status) {
384 AttemptOutcome::Retryable {
385 status,
386 body: resp.error_body,
387 retry_after,
388 }
389 } else {
390 AttemptOutcome::Fatal(anyhow::anyhow!(
391 "OpenAI API error at {} ({}): {}",
392 url,
393 status,
394 resp.error_body
395 ))
396 }
397 }
398 }
399 Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
400 "Failed to send streaming request: {}",
401 e
402 )),
403 }
404 }
405 })
406 .await?;
407
408 let (tx, rx) = mpsc::channel(100);
409
410 let mut stream = streaming_resp.byte_stream;
411 tokio::spawn(async move {
412 let mut buffer = String::new();
413 let mut content_blocks: Vec<ContentBlock> = Vec::new();
414 let mut text_content = String::new();
415 let mut reasoning_content_accum = String::new();
416 let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
417 std::collections::BTreeMap::new();
418 let mut usage = TokenUsage::default();
419 let mut finish_reason = None;
420
421 while let Some(chunk_result) = stream.next().await {
422 let chunk = match chunk_result {
423 Ok(c) => c,
424 Err(e) => {
425 tracing::error!("Stream error: {}", e);
426 break;
427 }
428 };
429
430 buffer.push_str(&String::from_utf8_lossy(&chunk));
431
432 while let Some(event_end) = buffer.find("\n\n") {
433 let event_data: String = buffer.drain(..event_end).collect();
434 buffer.drain(..2);
435
436 for line in event_data.lines() {
437 if let Some(data) = line.strip_prefix("data: ") {
438 if data == "[DONE]" {
439 if !text_content.is_empty() {
440 content_blocks.push(ContentBlock::Text {
441 text: text_content.clone(),
442 });
443 }
444 for (_, (id, name, args)) in tool_calls.iter() {
445 content_blocks.push(ContentBlock::ToolUse {
446 id: id.clone(),
447 name: name.clone(),
448 input: serde_json::from_str(args).unwrap_or_else(|e| {
449 tracing::warn!(
450 "Failed to parse tool arguments JSON for tool '{}': {}",
451 name, e
452 );
453 serde_json::Value::default()
454 }),
455 });
456 }
457 tool_calls.clear();
458 crate::telemetry::record_llm_usage(
459 usage.prompt_tokens,
460 usage.completion_tokens,
461 usage.total_tokens,
462 finish_reason.as_deref(),
463 );
464 let response = LlmResponse {
465 message: Message {
466 role: "assistant".to_string(),
467 content: std::mem::take(&mut content_blocks),
468 reasoning_content: if reasoning_content_accum.is_empty()
469 {
470 None
471 } else {
472 Some(std::mem::take(&mut reasoning_content_accum))
473 },
474 },
475 usage: usage.clone(),
476 stop_reason: std::mem::take(&mut finish_reason),
477 };
478 let _ = tx.send(StreamEvent::Done(response)).await;
479 continue;
480 }
481
482 if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
483 if let Some(u) = event.usage {
484 usage.prompt_tokens = u.prompt_tokens;
485 usage.completion_tokens = u.completion_tokens;
486 usage.total_tokens = u.total_tokens;
487 usage.cache_read_tokens = u
488 .prompt_tokens_details
489 .as_ref()
490 .and_then(|d| d.cached_tokens);
491 }
492
493 if let Some(choice) = event.choices.into_iter().next() {
494 if let Some(reason) = choice.finish_reason {
495 finish_reason = Some(reason);
496 }
497
498 if let Some(delta) = choice.delta {
499 if let Some(ref rc) = delta.reasoning_content {
500 reasoning_content_accum.push_str(rc);
501 }
502
503 if let Some(content) = delta.content {
504 text_content.push_str(&content);
505 let _ =
506 tx.send(StreamEvent::TextDelta(content)).await;
507 }
508
509 if let Some(tcs) = delta.tool_calls {
510 for tc in tcs {
511 let entry = tool_calls
512 .entry(tc.index)
513 .or_insert_with(|| {
514 (
515 String::new(),
516 String::new(),
517 String::new(),
518 )
519 });
520
521 if let Some(id) = tc.id {
522 entry.0 = id;
523 }
524 if let Some(func) = tc.function {
525 if let Some(name) = func.name {
526 entry.1 = name.clone();
527 let _ = tx
528 .send(StreamEvent::ToolUseStart {
529 id: entry.0.clone(),
530 name,
531 })
532 .await;
533 }
534 if let Some(args) = func.arguments {
535 entry.2.push_str(&args);
536 let _ = tx
537 .send(
538 StreamEvent::ToolUseInputDelta(
539 args,
540 ),
541 )
542 .await;
543 }
544 }
545 }
546 }
547 }
548 }
549 }
550 }
551 }
552 }
553 }
554 });
555
556 Ok(rx)
557 }
558 }
559}
560
/// Top-level JSON body of a non-streaming `/v1/chat/completions` reply.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    pub(crate) choices: Vec<OpenAiChoice>,
    pub(crate) usage: OpenAiUsage,
}

/// One completion choice; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    // e.g. "stop" or "tool_calls"; passed through as the stop reason.
    pub(crate) finish_reason: Option<String>,
}

/// Assistant message payload of a choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    // Extension field used by some OpenAI-compatible servers.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}

/// A fully-formed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}

/// Function name plus its arguments as a JSON-encoded string.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    pub(crate) arguments: String,
}

/// Token accounting block shared by streaming and non-streaming replies.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
    // Absent on servers that do not report cache details.
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; currently only the cached count is used.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
608
/// One parsed SSE `data:` payload from a streaming completion.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    // Populated only on the final chunk when
    // `stream_options.include_usage` was requested.
    pub(crate) usage: Option<OpenAiUsage>,
}

/// Per-choice delta within a stream chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    pub(crate) finish_reason: Option<String>,
}

/// Incremental content for the assistant message; each field may be
/// absent on any given chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

/// Partial tool call; deltas sharing an `index` belong to the same call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    pub(crate) index: usize,
    // Sent once, on the first delta for this call.
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}

/// Partial function data: the name arrives once, arguments arrive as
/// string fragments to be concatenated.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}