1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use crate::retry::{AttemptOutcome, RetryConfig};
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use tracing::Instrument;
14
/// Client for the OpenAI Chat Completions API (and OpenAI-compatible servers).
///
/// Construct with [`OpenAiClient::new`], then optionally customize via the
/// `with_*` builder methods.
pub struct OpenAiClient {
    /// API key; wrapped in `SecretString` (presumably to keep it out of
    /// Debug/log output — see that type).
    pub(crate) api_key: SecretString,
    /// Model identifier sent with every request.
    pub(crate) model: String,
    /// Base URL without the `/v1/chat/completions` path.
    pub(crate) base_url: String,
    /// HTTP transport, injectable for testing via `with_http_client`.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied to each request attempt.
    pub(crate) retry_config: RetryConfig,
}
23
impl OpenAiClient {
    /// Creates a client pointed at the official OpenAI endpoint with the
    /// default HTTP transport and retry configuration.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Overrides the base URL (proxies, Azure-style gateways, local
    /// OpenAI-compatible servers). The value is passed through
    /// `normalize_base_url` first (presumably trailing-slash/path cleanup —
    /// see that helper).
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Replaces the retry policy.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Replaces the HTTP transport (used by tests to inject a fake client).
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Converts internal [`Message`]s into Chat Completions message objects.
    ///
    /// Shapes produced:
    /// - a message whose single block is a `ToolResult` becomes a
    ///   `role: "tool"` message keyed by `tool_call_id` (the message's own
    ///   role is ignored for this case);
    /// - assistant messages always carry `reasoning_content` ("" when
    ///   absent); when tool calls are present they are emitted as a
    ///   `tool_calls` array and the converted `content` value is discarded;
    /// - everything else becomes `{role, content}` where `content` is a
    ///   plain string for single text blocks or an array of typed parts for
    ///   multi-block messages.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                // Single-block messages collapse to a plain string (or
                // early-return a "tool" message); multi-block messages
                // become an array of typed parts.
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Tool results map to a dedicated "tool" role
                            // message tied back to the originating call id.
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content,
                            });
                        }
                        // Other lone blocks (e.g. a single ToolUse) have no
                        // string form; for assistant messages the tool calls
                        // are re-emitted from msg.tool_calls() below, so
                        // this placeholder is normally discarded.
                        _ => serde_json::json!(""),
                    }
                } else {
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                // NOTE(review): ToolUse inside a content
                                // array is likely unreachable for assistant
                                // messages (the tool_calls branch below
                                // returns first) — confirm whether any
                                // caller relies on this shape.
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                if msg.role == "assistant" {
                    // Assistant messages always include reasoning_content,
                    // defaulting to the empty string when not present.
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // Tool-calling turn: text comes from msg.text() and
                        // each call's arguments are serialized back to a
                        // JSON *string*, as the API expects.
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Converts tool definitions into the Chat Completions `tools` array
    /// (`type: "function"` entries with name/description/JSON-schema
    /// parameters).
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}
147
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Sends one non-streaming chat completion request.
    ///
    /// The optional `system` prompt is prepended as a `system` message and
    /// `tools` are attached only when non-empty. Attempts are retried per
    /// `self.retry_config`; retryable HTTP statuses loop, everything else
    /// (including transport errors) is fatal.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        // Token/stop-reason fields are declared Empty; they are reported
        // later through telemetry.
        let span = tracing::info_span!(
            "a3s.llm.completion",
            "a3s.llm.provider" = "openai",
            "a3s.llm.model" = %self.model,
            "a3s.llm.streaming" = false,
            "a3s.llm.prompt_tokens" = tracing::field::Empty,
            "a3s.llm.completion_tokens" = tracing::field::Empty,
            "a3s.llm.total_tokens" = tracing::field::Empty,
            "a3s.llm.stop_reason" = tracing::field::Empty,
        );
        async {
            let mut openai_messages = Vec::new();

            // System prompt always goes first.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Classify each attempt: 2xx succeeds, configured statuses are
            // retried, anything else aborts with the response body in the
            // error message.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                // Re-borrow per attempt so the closure can be called again.
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            // Fold unparseable status codes into a 500 so
                            // they hit the generic non-success paths below.
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        // NOTE(review): transport errors are never retried
                        // here (unlike HTTP 5xx) — confirm that's intended.
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            // Only the first choice is consumed (n > 1 is never requested).
            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            let reasoning_content = choice.message.reasoning_content.clone();

            // Fall back to reasoning text when regular content is absent so
            // reasoning-only replies still produce a text block. (Reasoning
            // text is then carried both as the text block and in
            // `reasoning_content`.)
            let text_content = choice.message.content.or(choice.message.reasoning_content);

            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        // Malformed argument JSON degrades to a default
                        // Value (null) with a warning instead of failing
                        // the whole completion.
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // Cache metrics are not reported by this endpoint here.
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
        .instrument(span)
        .await
    }

    /// Sends a streaming chat completion request.
    ///
    /// Returns the receiving end of a channel that yields incremental
    /// [`StreamEvent`]s (text deltas, tool-use starts, tool argument
    /// deltas) and a final `StreamEvent::Done` carrying the assembled
    /// response once the server emits the `[DONE]` sentinel. If the stream
    /// errors mid-way, the channel closes without a `Done` event.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        let span = tracing::info_span!(
            "a3s.llm.completion",
            "a3s.llm.provider" = "openai",
            "a3s.llm.model" = %self.model,
            "a3s.llm.streaming" = true,
            "a3s.llm.prompt_tokens" = tracing::field::Empty,
            "a3s.llm.completion_tokens" = tracing::field::Empty,
            "a3s.llm.total_tokens" = tracing::field::Empty,
            "a3s.llm.stop_reason" = tracing::field::Empty,
        );
        async {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            // include_usage makes the server append a usage-bearing chunk
            // before [DONE].
            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry only the request handshake; once a stream is open,
            // mid-stream failures are handled by the reader task below.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor the server's Retry-After header when
                                // present (unlike the non-streaming path).
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}", url, status, resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}", e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            // Parse the SSE byte stream on a background task; the caller
            // drains events from `rx` at its own pace. Send errors (receiver
            // dropped) are deliberately ignored.
            tokio::spawn(async move {
                // Raw text not yet split into complete SSE events.
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                // Assistant text accumulated across deltas.
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                // Partial tool calls keyed by stream index:
                // (id, name, accumulated argument JSON). BTreeMap keeps the
                // final blocks in index order.
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Transport error mid-stream: log and stop; no
                            // Done event is emitted.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    // NOTE(review): from_utf8_lossy on arbitrary chunk
                    // boundaries can mangle a multi-byte UTF-8 character
                    // split across two chunks — confirm chunking preserves
                    // character boundaries.
                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line. Assumes
                    // LF-only framing ("\n\n"); a CRLF-framed stream would
                    // never match — TODO confirm upstream always sends LF.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        // Drop the two separator newlines.
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // End of stream: flush accumulated text
                                    // and tool calls into final blocks and
                                    // emit the assembled response.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            // Malformed argument JSON degrades
                                            // to a default Value with a warning.
                                            input: serde_json::from_str(args).unwrap_or_else(|e| {
                                                tracing::warn!(
                                                    "Failed to parse tool arguments JSON for tool '{}': {}",
                                                    name, e
                                                );
                                                serde_json::Value::default()
                                            }),
                                        });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty() { None } else { Some(std::mem::take(&mut reasoning_content_accum)) },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    // Skip JSON parsing for this line; any
                                    // remaining lines are still scanned.
                                    continue;
                                }

                                // Unparseable chunks are silently skipped.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    // Usage arrives on its own chunk when
                                    // include_usage is set.
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            // Reasoning-only deltas also
                                            // surface as text deltas (in
                                            // addition to the accumulator
                                            // above).
                                            let text_delta = delta.content
                                                .or(delta.reasoning_content);
                                            if let Some(content) = text_delta {
                                                text_content.push_str(&content);
                                                let _ = tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    // id is set before the
                                                    // function is examined so
                                                    // ToolUseStart below sees it.
                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        // A name delta marks the
                                                        // start of a tool call.
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        // Argument JSON arrives in
                                                        // fragments; accumulate and
                                                        // forward each fragment.
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseInputDelta(
                                                                    args,
                                                                ))
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
        .instrument(span)
        .await
    }
}
512
/// Top-level body of a non-streaming chat completion response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Required field: a response body without `usage` fails to parse.
    pub(crate) usage: OpenAiUsage,
}
519
/// One completion choice; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    pub(crate) finish_reason: Option<String>,
}
525
/// Assistant message in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Reasoning text emitted by reasoning-capable models; absent otherwise.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
532
/// A completed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}
538
/// Function payload of a tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    /// Arguments arrive as a JSON-encoded *string*, not a JSON object.
    pub(crate) arguments: String,
}
544
/// Token accounting block shared by responses and streaming usage chunks.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
}
551
552#[derive(Debug, Deserialize)]
554pub(crate) struct OpenAiStreamChunk {
555 pub(crate) choices: Vec<OpenAiStreamChoice>,
556 pub(crate) usage: Option<OpenAiUsage>,
557}
558
/// Per-choice entry inside a streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    /// Set on the last content-bearing chunk of the choice.
    pub(crate) finish_reason: Option<String>,
}
564
/// Incremental message fragment within a streaming choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    /// Reasoning-text fragment from reasoning-capable models.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
571
/// Fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Correlates fragments of the same tool call across chunks.
    pub(crate) index: usize,
    /// Present only on the first fragment of a call.
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
578
/// Fragment of a streamed function payload.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    /// Present only on the fragment that opens the call.
    pub(crate) name: Option<String>,
    /// Partial JSON-argument text; fragments concatenate into the full
    /// argument string.
    pub(crate) arguments: Option<String>,
}