Skip to main content

a3s_code_core/llm/
openai.rs

1//! OpenAI-compatible LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use crate::retry::{AttemptOutcome, RetryConfig};
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use tracing::Instrument;
14
15/// OpenAI client
16pub struct OpenAiClient {
17    pub(crate) api_key: SecretString,
18    pub(crate) model: String,
19    pub(crate) base_url: String,
20    pub(crate) http: Arc<dyn HttpClient>,
21    pub(crate) retry_config: RetryConfig,
22}
23
24impl OpenAiClient {
25    pub fn new(api_key: String, model: String) -> Self {
26        Self {
27            api_key: SecretString::new(api_key),
28            model,
29            base_url: "https://api.openai.com".to_string(),
30            http: default_http_client(),
31            retry_config: RetryConfig::default(),
32        }
33    }
34
35    pub fn with_base_url(mut self, base_url: String) -> Self {
36        self.base_url = normalize_base_url(&base_url);
37        self
38    }
39
40    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
41        self.retry_config = retry_config;
42        self
43    }
44
45    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
46        self.http = http;
47        self
48    }
49
50    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
51        messages
52            .iter()
53            .map(|msg| {
54                let content: serde_json::Value = if msg.content.len() == 1 {
55                    match &msg.content[0] {
56                        ContentBlock::Text { text } => serde_json::json!(text),
57                        ContentBlock::ToolResult {
58                            tool_use_id,
59                            content,
60                            ..
61                        } => {
62                            return serde_json::json!({
63                                "role": "tool",
64                                "tool_call_id": tool_use_id,
65                                "content": content,
66                            });
67                        }
68                        _ => serde_json::json!(""),
69                    }
70                } else {
71                    serde_json::json!(msg
72                        .content
73                        .iter()
74                        .map(|block| {
75                            match block {
76                                ContentBlock::Text { text } => serde_json::json!({
77                                    "type": "text",
78                                    "text": text,
79                                }),
80                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
81                                    "type": "function",
82                                    "id": id,
83                                    "function": {
84                                        "name": name,
85                                        "arguments": input.to_string(),
86                                    }
87                                }),
88                                _ => serde_json::json!({}),
89                            }
90                        })
91                        .collect::<Vec<_>>())
92                };
93
94                // Handle assistant messages — kimi-k2.5 requires reasoning_content
95                // on all assistant messages when thinking mode is enabled
96                if msg.role == "assistant" {
97                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
98                    let tool_calls: Vec<_> = msg.tool_calls();
99                    if !tool_calls.is_empty() {
100                        return serde_json::json!({
101                            "role": "assistant",
102                            "content": msg.text(),
103                            "reasoning_content": rc,
104                            "tool_calls": tool_calls.iter().map(|tc| {
105                                serde_json::json!({
106                                    "id": tc.id,
107                                    "type": "function",
108                                    "function": {
109                                        "name": tc.name,
110                                        "arguments": tc.args.to_string(),
111                                    }
112                                })
113                            }).collect::<Vec<_>>(),
114                        });
115                    }
116                    return serde_json::json!({
117                        "role": "assistant",
118                        "content": content,
119                        "reasoning_content": rc,
120                    });
121                }
122
123                serde_json::json!({
124                    "role": msg.role,
125                    "content": content,
126                })
127            })
128            .collect()
129    }
130
131    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
132        tools
133            .iter()
134            .map(|t| {
135                serde_json::json!({
136                    "type": "function",
137                    "function": {
138                        "name": t.name,
139                        "description": t.description,
140                        "parameters": t.parameters,
141                    }
142                })
143            })
144            .collect()
145    }
146}
147
148#[async_trait]
149impl LlmClient for OpenAiClient {
150    async fn complete(
151        &self,
152        messages: &[Message],
153        system: Option<&str>,
154        tools: &[ToolDefinition],
155    ) -> Result<LlmResponse> {
156        let span = tracing::info_span!(
157            "a3s.llm.completion",
158            "a3s.llm.provider" = "openai",
159            "a3s.llm.model" = %self.model,
160            "a3s.llm.streaming" = false,
161            "a3s.llm.prompt_tokens" = tracing::field::Empty,
162            "a3s.llm.completion_tokens" = tracing::field::Empty,
163            "a3s.llm.total_tokens" = tracing::field::Empty,
164            "a3s.llm.stop_reason" = tracing::field::Empty,
165        );
166        async {
167            let mut openai_messages = Vec::new();
168
169            if let Some(sys) = system {
170                openai_messages.push(serde_json::json!({
171                    "role": "system",
172                    "content": sys,
173                }));
174            }
175
176            openai_messages.extend(self.convert_messages(messages));
177
178            let mut request = serde_json::json!({
179                "model": self.model,
180                "messages": openai_messages,
181            });
182
183            if !tools.is_empty() {
184                request["tools"] = serde_json::json!(self.convert_tools(tools));
185            }
186
187            let url = format!("{}/v1/chat/completions", self.base_url);
188            let auth_header = format!("Bearer {}", self.api_key.expose());
189            let headers = vec![("Authorization", auth_header.as_str())];
190
191            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
192                let http = &self.http;
193                let url = &url;
194                let headers = headers.clone();
195                let request = &request;
196                async move {
197                    match http.post(url, headers, request).await {
198                        Ok(resp) => {
199                            let status = reqwest::StatusCode::from_u16(resp.status)
200                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
201                            if status.is_success() {
202                                AttemptOutcome::Success(resp.body)
203                            } else if self.retry_config.is_retryable_status(status) {
204                                AttemptOutcome::Retryable {
205                                    status,
206                                    body: resp.body,
207                                    retry_after: None,
208                                }
209                            } else {
210                                AttemptOutcome::Fatal(anyhow::anyhow!(
211                                    "OpenAI API error at {} ({}): {}",
212                                    url,
213                                    status,
214                                    resp.body
215                                ))
216                            }
217                        }
218                        Err(e) => AttemptOutcome::Fatal(e),
219                    }
220                }
221            })
222            .await?;
223
224            let parsed: OpenAiResponse =
225                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
226
227            let choice = parsed.choices.into_iter().next().context("No choices")?;
228
229            let mut content = vec![];
230
231            let reasoning_content = choice.message.reasoning_content.clone();
232
233            let text_content = choice.message.content.or(choice.message.reasoning_content);
234
235            if let Some(text) = text_content {
236                if !text.is_empty() {
237                    content.push(ContentBlock::Text { text });
238                }
239            }
240
241            if let Some(tool_calls) = choice.message.tool_calls {
242                for tc in tool_calls {
243                    content.push(ContentBlock::ToolUse {
244                        id: tc.id,
245                        name: tc.function.name.clone(),
246                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
247                            tracing::warn!(
248                                "Failed to parse tool arguments JSON for tool '{}': {}",
249                                tc.function.name,
250                                e
251                            );
252                            serde_json::Value::default()
253                        }),
254                    });
255                }
256            }
257
258            let llm_response = LlmResponse {
259                message: Message {
260                    role: "assistant".to_string(),
261                    content,
262                    reasoning_content,
263                },
264                usage: TokenUsage {
265                    prompt_tokens: parsed.usage.prompt_tokens,
266                    completion_tokens: parsed.usage.completion_tokens,
267                    total_tokens: parsed.usage.total_tokens,
268                    cache_read_tokens: None,
269                    cache_write_tokens: None,
270                },
271                stop_reason: choice.finish_reason,
272            };
273
274            crate::telemetry::record_llm_usage(
275                llm_response.usage.prompt_tokens,
276                llm_response.usage.completion_tokens,
277                llm_response.usage.total_tokens,
278                llm_response.stop_reason.as_deref(),
279            );
280
281            Ok(llm_response)
282        }
283        .instrument(span)
284        .await
285    }
286
287    async fn complete_streaming(
288        &self,
289        messages: &[Message],
290        system: Option<&str>,
291        tools: &[ToolDefinition],
292    ) -> Result<mpsc::Receiver<StreamEvent>> {
293        let span = tracing::info_span!(
294            "a3s.llm.completion",
295            "a3s.llm.provider" = "openai",
296            "a3s.llm.model" = %self.model,
297            "a3s.llm.streaming" = true,
298            "a3s.llm.prompt_tokens" = tracing::field::Empty,
299            "a3s.llm.completion_tokens" = tracing::field::Empty,
300            "a3s.llm.total_tokens" = tracing::field::Empty,
301            "a3s.llm.stop_reason" = tracing::field::Empty,
302        );
303        async {
304        let mut openai_messages = Vec::new();
305
306        if let Some(sys) = system {
307            openai_messages.push(serde_json::json!({
308                "role": "system",
309                "content": sys,
310            }));
311        }
312
313        openai_messages.extend(self.convert_messages(messages));
314
315        let mut request = serde_json::json!({
316            "model": self.model,
317            "messages": openai_messages,
318            "stream": true,
319            "stream_options": { "include_usage": true },
320        });
321
322        if !tools.is_empty() {
323            request["tools"] = serde_json::json!(self.convert_tools(tools));
324        }
325
326        let url = format!("{}/v1/chat/completions", self.base_url);
327        let auth_header = format!("Bearer {}", self.api_key.expose());
328        let headers = vec![("Authorization", auth_header.as_str())];
329
330        let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
331            let http = &self.http;
332            let url = &url;
333            let headers = headers.clone();
334            let request = &request;
335            async move {
336                match http.post_streaming(url, headers, request).await {
337                    Ok(resp) => {
338                        let status = reqwest::StatusCode::from_u16(resp.status)
339                            .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
340                        if status.is_success() {
341                            AttemptOutcome::Success(resp)
342                        } else {
343                            let retry_after = resp
344                                .retry_after
345                                .as_deref()
346                                .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
347                            if self.retry_config.is_retryable_status(status) {
348                                AttemptOutcome::Retryable {
349                                    status,
350                                    body: resp.error_body,
351                                    retry_after,
352                                }
353                            } else {
354                                AttemptOutcome::Fatal(anyhow::anyhow!(
355                                    "OpenAI API error at {} ({}): {}", url, status, resp.error_body
356                                ))
357                            }
358                        }
359                    }
360                    Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
361                        "Failed to send streaming request: {}", e
362                    )),
363                }
364            }
365        })
366        .await?;
367
368        let (tx, rx) = mpsc::channel(100);
369
370        let mut stream = streaming_resp.byte_stream;
371        tokio::spawn(async move {
372            let mut buffer = String::new();
373            let mut content_blocks: Vec<ContentBlock> = Vec::new();
374            let mut text_content = String::new();
375            let mut reasoning_content_accum = String::new();
376            let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
377                std::collections::BTreeMap::new();
378            let mut usage = TokenUsage::default();
379            let mut finish_reason = None;
380
381            while let Some(chunk_result) = stream.next().await {
382                let chunk = match chunk_result {
383                    Ok(c) => c,
384                    Err(e) => {
385                        tracing::error!("Stream error: {}", e);
386                        break;
387                    }
388                };
389
390                buffer.push_str(&String::from_utf8_lossy(&chunk));
391
392                while let Some(event_end) = buffer.find("\n\n") {
393                    let event_data: String = buffer.drain(..event_end).collect();
394                    buffer.drain(..2);
395
396                    for line in event_data.lines() {
397                        if let Some(data) = line.strip_prefix("data: ") {
398                            if data == "[DONE]" {
399                                if !text_content.is_empty() {
400                                    content_blocks.push(ContentBlock::Text {
401                                        text: text_content.clone(),
402                                    });
403                                }
404                                for (_, (id, name, args)) in tool_calls.iter() {
405                                    content_blocks.push(ContentBlock::ToolUse {
406                                        id: id.clone(),
407                                        name: name.clone(),
408                                        input: serde_json::from_str(args).unwrap_or_else(|e| {
409                                            tracing::warn!(
410                                                "Failed to parse tool arguments JSON for tool '{}': {}",
411                                                name, e
412                                            );
413                                            serde_json::Value::default()
414                                        }),
415                                    });
416                                }
417                                tool_calls.clear();
418                                crate::telemetry::record_llm_usage(
419                                    usage.prompt_tokens,
420                                    usage.completion_tokens,
421                                    usage.total_tokens,
422                                    finish_reason.as_deref(),
423                                );
424                                let response = LlmResponse {
425                                    message: Message {
426                                        role: "assistant".to_string(),
427                                        content: std::mem::take(&mut content_blocks),
428                                        reasoning_content: if reasoning_content_accum.is_empty() { None } else { Some(std::mem::take(&mut reasoning_content_accum)) },
429                                    },
430                                    usage: usage.clone(),
431                                    stop_reason: std::mem::take(&mut finish_reason),
432                                };
433                                let _ = tx.send(StreamEvent::Done(response)).await;
434                                continue;
435                            }
436
437                            if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
438                                if let Some(u) = event.usage {
439                                    usage.prompt_tokens = u.prompt_tokens;
440                                    usage.completion_tokens = u.completion_tokens;
441                                    usage.total_tokens = u.total_tokens;
442                                }
443
444                                if let Some(choice) = event.choices.into_iter().next() {
445                                    if let Some(reason) = choice.finish_reason {
446                                        finish_reason = Some(reason);
447                                    }
448
449                                    if let Some(delta) = choice.delta {
450                                        if let Some(ref rc) = delta.reasoning_content {
451                                            reasoning_content_accum.push_str(rc);
452                                        }
453
454                                        let text_delta = delta.content
455                                            .or(delta.reasoning_content);
456                                        if let Some(content) = text_delta {
457                                            text_content.push_str(&content);
458                                            let _ = tx.send(StreamEvent::TextDelta(content)).await;
459                                        }
460
461                                        if let Some(tcs) = delta.tool_calls {
462                                            for tc in tcs {
463                                                let entry = tool_calls
464                                                    .entry(tc.index)
465                                                    .or_insert_with(|| {
466                                                        (
467                                                            String::new(),
468                                                            String::new(),
469                                                            String::new(),
470                                                        )
471                                                    });
472
473                                                if let Some(id) = tc.id {
474                                                    entry.0 = id;
475                                                }
476                                                if let Some(func) = tc.function {
477                                                    if let Some(name) = func.name {
478                                                        entry.1 = name.clone();
479                                                        let _ = tx
480                                                            .send(StreamEvent::ToolUseStart {
481                                                                id: entry.0.clone(),
482                                                                name,
483                                                            })
484                                                            .await;
485                                                    }
486                                                    if let Some(args) = func.arguments {
487                                                        entry.2.push_str(&args);
488                                                        let _ = tx
489                                                            .send(StreamEvent::ToolUseInputDelta(
490                                                                args,
491                                                            ))
492                                                            .await;
493                                                    }
494                                                }
495                                            }
496                                        }
497                                    }
498                                }
499                            }
500                        }
501                    }
502                }
503            }
504        });
505
506        Ok(rx)
507        }
508        .instrument(span)
509        .await
510    }
511}
512
513// OpenAI API response types (private)
514#[derive(Debug, Deserialize)]
515pub(crate) struct OpenAiResponse {
516    pub(crate) choices: Vec<OpenAiChoice>,
517    pub(crate) usage: OpenAiUsage,
518}
519
520#[derive(Debug, Deserialize)]
521pub(crate) struct OpenAiChoice {
522    pub(crate) message: OpenAiMessage,
523    pub(crate) finish_reason: Option<String>,
524}
525
526#[derive(Debug, Deserialize)]
527pub(crate) struct OpenAiMessage {
528    pub(crate) reasoning_content: Option<String>,
529    pub(crate) content: Option<String>,
530    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
531}
532
533#[derive(Debug, Deserialize)]
534pub(crate) struct OpenAiToolCall {
535    pub(crate) id: String,
536    pub(crate) function: OpenAiFunction,
537}
538
539#[derive(Debug, Deserialize)]
540pub(crate) struct OpenAiFunction {
541    pub(crate) name: String,
542    pub(crate) arguments: String,
543}
544
545#[derive(Debug, Deserialize)]
546pub(crate) struct OpenAiUsage {
547    pub(crate) prompt_tokens: usize,
548    pub(crate) completion_tokens: usize,
549    pub(crate) total_tokens: usize,
550}
551
552// OpenAI streaming types
553#[derive(Debug, Deserialize)]
554pub(crate) struct OpenAiStreamChunk {
555    pub(crate) choices: Vec<OpenAiStreamChoice>,
556    pub(crate) usage: Option<OpenAiUsage>,
557}
558
559#[derive(Debug, Deserialize)]
560pub(crate) struct OpenAiStreamChoice {
561    pub(crate) delta: Option<OpenAiDelta>,
562    pub(crate) finish_reason: Option<String>,
563}
564
565#[derive(Debug, Deserialize)]
566pub(crate) struct OpenAiDelta {
567    pub(crate) reasoning_content: Option<String>,
568    pub(crate) content: Option<String>,
569    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
570}
571
572#[derive(Debug, Deserialize)]
573pub(crate) struct OpenAiToolCallDelta {
574    pub(crate) index: usize,
575    pub(crate) id: Option<String>,
576    pub(crate) function: Option<OpenAiFunctionDelta>,
577}
578
579#[derive(Debug, Deserialize)]
580pub(crate) struct OpenAiFunctionDelta {
581    pub(crate) name: Option<String>,
582    pub(crate) arguments: Option<String>,
583}