1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
/// Client for the OpenAI Chat Completions API (and OpenAI-compatible servers).
///
/// Build with [`OpenAiClient::new`], then optionally customize via the
/// `with_*` builder methods. Implements `LlmClient` for both blocking and
/// streaming completions.
pub struct OpenAiClient {
    // API key wrapped in SecretString so it is not exposed via Debug/logging.
    pub(crate) api_key: SecretString,
    // Model identifier sent as the request's "model" field.
    pub(crate) model: String,
    // Base URL without the "/v1/chat/completions" path; defaults to the
    // public OpenAI endpoint.
    pub(crate) base_url: String,
    // HTTP transport behind a trait object so tests can inject a fake.
    pub(crate) http: Arc<dyn HttpClient>,
    // Retry policy applied to every request attempt.
    pub(crate) retry_config: RetryConfig,
}
23
impl OpenAiClient {
    /// Creates a client for the given API key and model, pointed at the
    /// public OpenAI endpoint with the default HTTP transport and retry policy.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Overrides the base URL (for OpenAI-compatible servers). The value is
    /// passed through `normalize_base_url` before being stored.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Overrides the retry policy used by `complete` and `complete_streaming`.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Injects a custom HTTP transport (primarily for testing).
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Converts provider-agnostic `Message`s into OpenAI chat-completion
    /// message objects.
    ///
    /// Shapes produced, in evaluation order:
    /// - a message whose single block is a `ToolResult` returns early as a
    ///   `role: "tool"` message carrying `tool_call_id` (the role handling
    ///   below is bypassed);
    /// - a single `Text` block serializes `content` as a plain JSON string;
    /// - multiple blocks serialize `content` as an array of typed parts
    ///   (`text` / `image_url` with a base64 data URL / `function`), with
    ///   unrecognized block kinds degraded to `{}`;
    /// - assistant messages additionally carry `reasoning_content` and, when
    ///   `msg.tool_calls()` is non-empty, a `tool_calls` array — in that case
    ///   `content` is replaced by `msg.text()` (plain text only).
    ///
    /// NOTE(review): a message holding more than one `ToolResult` block does
    /// not get the early `role: "tool"` treatment — each result serializes to
    /// `{}` in the parts array. Confirm upstream never batches tool results
    /// into one message.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                // First, compute the generic `content` value; tool results
                // short-circuit with an early `return` from the closure.
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Flatten block-form tool results to newline-joined
                            // text; non-text blocks are dropped.
                            let content_str = match content {
                                ToolResultContentField::Text(s) => s.clone(),
                                ToolResultContentField::Blocks(blocks) => blocks
                                    .iter()
                                    .filter_map(|b| {
                                        if let ToolResultContent::Text { text } = b {
                                            Some(text.clone())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n"),
                            };
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content_str,
                            });
                        }
                        _ => serde_json::json!(""),
                    }
                } else {
                    // Multi-block content becomes an array of typed parts.
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                ContentBlock::Image { source } => serde_json::json!({
                                    "type": "image_url",
                                    "image_url": {
                                        "url": format!(
                                            "data:{};base64,{}",
                                            source.media_type, source.data
                                        ),
                                    }
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        // OpenAI expects arguments as a JSON
                                        // *string*, not an object.
                                        "arguments": input.to_string(),
                                    }
                                }),
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                if msg.role == "assistant" {
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // With tool calls present, `content` carries only the
                        // plain text (the precomputed `content` above is
                        // intentionally discarded).
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Converts tool definitions into the OpenAI `tools` array format
    /// (`{"type": "function", "function": {...}}` wrappers).
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}
170
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Performs a blocking (non-streaming) chat completion.
    ///
    /// Builds the request (system prompt first, then converted messages,
    /// plus `tools` when non-empty), POSTs it with retries per
    /// `self.retry_config`, then parses the first choice into an
    /// `LlmResponse` and records usage telemetry.
    ///
    /// # Errors
    /// Fails on non-retryable HTTP statuses, exhausted retries, transport
    /// errors, unparseable response bodies, or an empty `choices` array.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry loop: success passes the body through; retryable statuses
            // re-enter the loop; everything else is fatal.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    // No Retry-After parsing on this path
                                    // (header not surfaced by `post`).
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        // NOTE(review): transport errors are treated as fatal,
                        // not retryable — confirm that is intended.
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            // Only the first choice is consumed; additional choices ignored.
            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            let reasoning_content = choice.message.reasoning_content.clone();

            // NOTE(review): when `content` is absent, the reasoning text is
            // promoted to the visible text block while also being kept in
            // `reasoning_content` — confirm this duplication is intended.
            let text_content = choice.message.content.or(choice.message.reasoning_content);

            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        // Malformed argument JSON degrades to Value::Null
                        // (Value::default()) with a warning rather than failing
                        // the whole completion.
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // OpenAI path does not surface cache token counts.
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Performs a streaming chat completion, returning a channel of
    /// [`StreamEvent`]s.
    ///
    /// The request sets `stream: true` and `stream_options.include_usage` so
    /// the final chunk carries token usage. After the (retried) HTTP request
    /// succeeds, a task is spawned that parses the SSE byte stream and
    /// forwards `TextDelta`, `ToolUseStart`, `ToolUseInputDelta`, and a final
    /// `Done` event; dropping the receiver aborts forwarding (sends are
    /// ignored via `let _ =`).
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                // Ask the server to append a final usage-bearing chunk.
                "stream_options": { "include_usage": true },
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retries cover only connection establishment / the initial HTTP
            // status; once the byte stream starts, errors are not retried.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor Retry-After when the server sent one.
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            tokio::spawn(async move {
                // Raw bytes accumulate here until a complete SSE event
                // (terminated by a blank line) is available.
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                // Partial tool calls keyed by stream index:
                // (id, name, accumulated argument JSON). BTreeMap keeps the
                // final blocks in index order.
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Mid-stream errors end the stream without a Done
                            // event.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line ("\n\n").
                    while let Some(event_end) = buffer.find("\n\n") {
                        // Take the event text, then discard the separator.
                        let event_data: String = buffer.drain(..event_end).collect();
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // Finalize: flush accumulated text and
                                    // tool calls into content blocks, record
                                    // telemetry, and emit the Done event.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            input: serde_json::from_str(args).unwrap_or_else(|e| {
                                                tracing::warn!(
                                                    "Failed to parse tool arguments JSON for tool '{}': {}",
                                                    name, e
                                                );
                                                serde_json::Value::default()
                                            }),
                                        });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    continue;
                                }

                                // Unparseable chunks are silently skipped.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            // NOTE(review): reasoning deltas
                                            // are mirrored into the text stream
                                            // when `content` is absent, while
                                            // also accumulated above — confirm
                                            // this duplication is intended.
                                            let text_delta =
                                                delta.content.or(delta.reasoning_content);
                                            if let Some(content) = text_delta {
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
}
523
/// Top-level body of a non-streaming chat completion response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    /// Candidate completions; only the first is consumed by `complete`.
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Token accounting; required — a response without `usage` fails to parse.
    pub(crate) usage: OpenAiUsage,
}
530
/// One completion candidate from a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    /// The assistant message produced for this choice.
    pub(crate) message: OpenAiMessage,
    /// Why generation stopped (e.g. "stop", "tool_calls"); passed through as
    /// the `LlmResponse::stop_reason`.
    pub(crate) finish_reason: Option<String>,
}
536
/// Assistant message payload of a non-streaming choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    // Presumably an extension field emitted by reasoning-capable
    // OpenAI-compatible models; absent from the base OpenAI schema.
    pub(crate) reasoning_content: Option<String>,
    /// Plain text content; may be absent when the model only calls tools.
    pub(crate) content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
543
/// A complete (non-streamed) tool call in an assistant message.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    /// Call identifier, echoed back as `tool_call_id` in tool-result messages.
    pub(crate) id: String,
    /// The function name and serialized arguments.
    pub(crate) function: OpenAiFunction,
}
549
/// Function payload of a tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    /// Arguments as a JSON-encoded string (per the OpenAI wire format),
    /// parsed into a `serde_json::Value` by the caller.
    pub(crate) arguments: String,
}
555
/// Token usage block, shared by non-streaming responses and the final
/// usage-bearing streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
}
562
563#[derive(Debug, Deserialize)]
565pub(crate) struct OpenAiStreamChunk {
566 pub(crate) choices: Vec<OpenAiStreamChoice>,
567 pub(crate) usage: Option<OpenAiUsage>,
568}
569
/// One choice within a streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    /// Incremental content for this choice; may be absent on the final chunk.
    pub(crate) delta: Option<OpenAiDelta>,
    /// Set once, on the chunk where generation for this choice ends.
    pub(crate) finish_reason: Option<String>,
}
575
/// Incremental message content within a streaming choice. All fields are
/// optional; each chunk carries only what changed.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    // Presumably a reasoning-model extension field; absent from the base
    // OpenAI schema.
    pub(crate) reasoning_content: Option<String>,
    /// Appended text fragment.
    pub(crate) content: Option<String>,
    /// Incremental tool-call fragments, correlated by `index`.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
582
/// A fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Position of the tool call within the message; used to accumulate
    /// fragments belonging to the same call.
    pub(crate) index: usize,
    /// Call id; typically present only on the first fragment of a call.
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
589
/// Function-call fragment within a streamed tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    /// Function name; typically sent once, on the first fragment.
    pub(crate) name: Option<String>,
    /// Partial argument JSON, concatenated across fragments by the consumer.
    pub(crate) arguments: Option<String>,
}