1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::retry::{AttemptOutcome, RetryConfig};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// Client for the OpenAI Chat Completions API (and OpenAI-compatible
/// endpoints reachable via `with_base_url`).
pub struct OpenAiClient {
    /// API key; kept in a secret wrapper and revealed only via `expose()`
    /// when building the Authorization header.
    pub(crate) api_key: SecretString,
    /// Model identifier sent verbatim in the request body.
    pub(crate) model: String,
    /// Base URL without the API path; "/v1/chat/completions" is appended.
    pub(crate) base_url: String,
    /// HTTP transport behind a trait object so tests can inject a fake.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry/backoff policy applied to each request attempt.
    pub(crate) retry_config: RetryConfig,
}
22
23impl OpenAiClient {
24 pub fn new(api_key: String, model: String) -> Self {
25 Self {
26 api_key: SecretString::new(api_key),
27 model,
28 base_url: "https://api.openai.com".to_string(),
29 http: default_http_client(),
30 retry_config: RetryConfig::default(),
31 }
32 }
33
34 pub fn with_base_url(mut self, base_url: String) -> Self {
35 self.base_url = normalize_base_url(&base_url);
36 self
37 }
38
39 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
40 self.retry_config = retry_config;
41 self
42 }
43
44 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
45 self.http = http;
46 self
47 }
48
49 pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
50 messages
51 .iter()
52 .map(|msg| {
53 let content: serde_json::Value = if msg.content.len() == 1 {
54 match &msg.content[0] {
55 ContentBlock::Text { text } => serde_json::json!(text),
56 ContentBlock::ToolResult {
57 tool_use_id,
58 content,
59 ..
60 } => {
61 return serde_json::json!({
62 "role": "tool",
63 "tool_call_id": tool_use_id,
64 "content": content,
65 });
66 }
67 _ => serde_json::json!(""),
68 }
69 } else {
70 serde_json::json!(msg
71 .content
72 .iter()
73 .map(|block| {
74 match block {
75 ContentBlock::Text { text } => serde_json::json!({
76 "type": "text",
77 "text": text,
78 }),
79 ContentBlock::ToolUse { id, name, input } => serde_json::json!({
80 "type": "function",
81 "id": id,
82 "function": {
83 "name": name,
84 "arguments": input.to_string(),
85 }
86 }),
87 _ => serde_json::json!({}),
88 }
89 })
90 .collect::<Vec<_>>())
91 };
92
93 if msg.role == "assistant" {
96 let rc = msg.reasoning_content.as_deref().unwrap_or("");
97 let tool_calls: Vec<_> = msg.tool_calls();
98 if !tool_calls.is_empty() {
99 return serde_json::json!({
100 "role": "assistant",
101 "content": msg.text(),
102 "reasoning_content": rc,
103 "tool_calls": tool_calls.iter().map(|tc| {
104 serde_json::json!({
105 "id": tc.id,
106 "type": "function",
107 "function": {
108 "name": tc.name,
109 "arguments": tc.args.to_string(),
110 }
111 })
112 }).collect::<Vec<_>>(),
113 });
114 }
115 return serde_json::json!({
116 "role": "assistant",
117 "content": content,
118 "reasoning_content": rc,
119 });
120 }
121
122 serde_json::json!({
123 "role": msg.role,
124 "content": content,
125 })
126 })
127 .collect()
128 }
129
130 pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
131 tools
132 .iter()
133 .map(|t| {
134 serde_json::json!({
135 "type": "function",
136 "function": {
137 "name": t.name,
138 "description": t.description,
139 "parameters": t.parameters,
140 }
141 })
142 })
143 .collect()
144 }
145}
146
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Sends a non-streaming chat-completion request, retrying on retryable
    /// HTTP statuses, and returns the parsed response plus token usage.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let mut openai_messages = Vec::new();

            // The system prompt goes first as a "system" role message.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            // Omit the "tools" field entirely when no tools are defined.
            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry loop: 2xx succeeds, retryable statuses back off, any
            // other status — and transport errors — fail immediately.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                // Re-borrow per attempt so the closure stays `Fn`.
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    // No Retry-After available on this path.
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            // Only the first choice is consumed; extra choices are ignored.
            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            let reasoning_content = choice.message.reasoning_content.clone();

            // Fall back to reasoning_content when content is absent.
            // NOTE(review): when that fallback fires, the same text is also
            // kept in `reasoning_content` above — confirm the duplication is
            // intended.
            let text_content = choice.message.content.or(choice.message.reasoning_content);

            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            // Tool-call arguments arrive as a JSON-encoded string; a parse
            // failure is logged and degrades to a default value instead of
            // failing the whole response.
            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // The usage payload parsed here carries no cache fields.
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Sends a streaming chat-completion request. Incremental events are
    /// forwarded over the returned channel; a final `StreamEvent::Done`
    /// carries the assembled message plus usage once "[DONE]" is seen.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            // stream_options.include_usage asks the server to emit a final
            // chunk containing token counts.
            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retries only cover establishing the stream; once the byte
            // stream is handed to the background task, mid-stream errors
            // are not retried.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor a Retry-After header when provided.
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            // Parse the SSE byte stream on a background task so the caller
            // can start consuming events from `rx` immediately.
            let mut stream = streaming_resp.byte_stream;
            tokio::spawn(async move {
                // Raw text accumulated until a complete SSE event ("\n\n").
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                // index -> (id, name, accumulated argument text); BTreeMap
                // keeps tool calls ordered by their stream index.
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Mid-stream transport error: stop parsing; the
                            // channel closes when `tx` is dropped.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // Consume every complete SSE event; a trailing partial
                    // event stays in `buffer` for the next chunk.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        // Discard the "\n\n" separator itself.
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // Terminal sentinel: flush accumulated
                                    // text and tool calls into the final
                                    // response and emit Done.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            input: serde_json::from_str(args).unwrap_or_else(|e| {
                                                tracing::warn!(
                                                    "Failed to parse tool arguments JSON for tool '{}': {}",
                                                    name, e
                                                );
                                                serde_json::Value::default()
                                            }),
                                        });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    // Receiver may already be dropped; ignore.
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    continue;
                                }

                                // Unparseable chunks are silently skipped.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            // NOTE(review): when `content` is
                                            // absent, the reasoning delta is
                                            // also forwarded as text — confirm
                                            // the duplication is intended.
                                            let text_delta =
                                                delta.content.or(delta.reasoning_content);
                                            if let Some(content) = text_delta {
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    // Deltas for one tool call
                                                    // share an index; merge id,
                                                    // name and argument text.
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
}
499
/// Top-level non-streaming response body from `/v1/chat/completions`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Token accounting; required, so a response lacking `usage` fails to
    /// parse.
    pub(crate) usage: OpenAiUsage,
}
506
/// One completion candidate; only the first choice is consumed by
/// `complete`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    /// e.g. "stop" or "tool_calls"; absent on some responses.
    pub(crate) finish_reason: Option<String>,
}
512
/// Assistant message payload of a non-streaming choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Chain-of-thought text; presumably only emitted by reasoning-capable
    /// OpenAI-compatible backends — absent otherwise.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
519
/// A complete (non-streaming) tool invocation requested by the model.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}
525
/// Function name plus its arguments.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    /// JSON-encoded argument object as a string; parsed by the caller.
    pub(crate) arguments: String,
}
531
/// Token counts reported by the API.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
}
538
/// One parsed SSE `data:` payload from a streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    /// May be empty — presumably on the final usage-only chunk emitted when
    /// `stream_options.include_usage` is set; verify against the API.
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    pub(crate) usage: Option<OpenAiUsage>,
}
545
/// Streaming counterpart of `OpenAiChoice`; all fields are incremental.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    /// Set once, on the chunk that terminates this choice.
    pub(crate) finish_reason: Option<String>,
}
551
/// Incremental content carried by a streaming chunk; each field holds only
/// the new fragment, to be appended by the consumer.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
558
/// Partial tool-call update within a streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Correlates fragments belonging to the same tool call across chunks.
    pub(crate) index: usize,
    /// Usually present only on the first fragment of a call.
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
565
/// Partial function data for a streaming tool call; `arguments` fragments
/// are concatenated by the consumer before JSON-parsing.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}