// a3s_code_core/llm/openai.rs
1//! OpenAI-compatible LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
/// OpenAI-compatible chat-completions client.
///
/// Talks to `/v1/chat/completions` on `base_url`, so it also works
/// against OpenAI-compatible backends (see [`OpenAiClient::with_base_url`]).
pub struct OpenAiClient {
    /// Bearer token, wrapped so it is not exposed in logs/Debug output.
    pub(crate) api_key: SecretString,
    /// Model identifier sent verbatim in the request body.
    pub(crate) model: String,
    /// Scheme + host only; the `/v1/chat/completions` path is appended per request.
    pub(crate) base_url: String,
    /// Injectable HTTP transport abstraction (allows substituting the transport).
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied to every request.
    pub(crate) retry_config: RetryConfig,
}
23
impl OpenAiClient {
    /// Creates a client for the official OpenAI endpoint with the default
    /// HTTP transport and retry configuration.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Points the client at an OpenAI-compatible backend; the URL is
    /// cleaned up by `normalize_base_url` before being stored.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Overrides the retry policy.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Overrides the HTTP transport (e.g. to inject a test double).
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Converts internal [`Message`]s to OpenAI chat-completion JSON.
    ///
    /// Mapping rules:
    /// - A single `Text` block becomes a plain string `content`.
    /// - A single `ToolResult` block becomes a `role: "tool"` message;
    ///   multi-part text results are joined with newlines (non-text parts
    ///   are dropped).
    /// - Multi-block content becomes an array of typed parts
    ///   (`text` / `image_url` / `function`); unrecognized blocks map to `{}`.
    /// - Assistant messages always carry `reasoning_content` — kimi-k2.5
    ///   requires it on every assistant message when thinking mode is
    ///   enabled — and emit tool calls via the top-level `tool_calls` field.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                // First build the `content` value; tool results return early
                // because they use a different role ("tool") entirely.
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Flatten the result into a single string: OpenAI's
                            // `role: "tool"` messages take plain text content.
                            let content_str = match content {
                                ToolResultContentField::Text(s) => s.clone(),
                                ToolResultContentField::Blocks(blocks) => blocks
                                    .iter()
                                    .filter_map(|b| {
                                        if let ToolResultContent::Text { text } = b {
                                            Some(text.clone())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n"),
                            };
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content_str,
                            });
                        }
                        // Other single-block kinds (e.g. a lone ToolUse) have no
                        // sensible string form; assistant tool calls are handled
                        // below via `msg.tool_calls()` instead.
                        _ => serde_json::json!(""),
                    }
                } else {
                    // Multi-block content: emit OpenAI's typed part array.
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                ContentBlock::Image { source } => serde_json::json!({
                                    "type": "image_url",
                                    "image_url": {
                                        // Inline the image as a base64 data URL.
                                        "url": format!(
                                            "data:{};base64,{}",
                                            source.media_type, source.data
                                        ),
                                    }
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                // NOTE(review): ToolResult blocks inside a
                                // multi-block message are silently dropped here
                                // — confirm callers never mix them with other blocks.
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                // Handle assistant messages — kimi-k2.5 requires reasoning_content
                // on all assistant messages when thinking mode is enabled
                if msg.role == "assistant" {
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // Tool calls go in the dedicated `tool_calls` field; the
                        // plain text (if any) rides along as `content`.
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        // Arguments are serialized as a JSON string,
                                        // per the OpenAI wire format.
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                // Everything else (user/system) keeps its role unchanged.
                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Converts internal [`ToolDefinition`]s to OpenAI `tools` entries
    /// (`type: "function"` with name/description/JSON-schema parameters).
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}
170
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Sends a non-streaming chat completion and parses the first choice.
    ///
    /// Retries per `retry_config` on retryable HTTP statuses; any other
    /// error status or transport error is fatal. Records token usage to
    /// telemetry before returning.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let mut openai_messages = Vec::new();

            // OpenAI expects the system prompt as the first message.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            // Omit `tools` entirely when empty — some backends reject `[]`.
            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry loop: classify each attempt's HTTP status into
            // success / retryable / fatal for the generic retry driver.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            // Only the first choice is consumed; n > 1 is never requested.
            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            let reasoning_content = choice.message.reasoning_content.clone();

            // NOTE(review): when `content` is absent, the reasoning text is
            // promoted to visible text — so it can appear both as a Text block
            // and as `reasoning_content`. Looks intentional for reasoning-only
            // replies; confirm downstream consumers expect this.
            let text_content = choice.message.content.or(choice.message.reasoning_content);

            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        // Malformed arguments degrade to a default JSON value
                        // (with a warning) rather than failing the whole reply.
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // OpenAI wire format carries no cache counters here.
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Sends a streaming chat completion and returns a channel of
    /// [`StreamEvent`]s.
    ///
    /// The request itself is retried like `complete`; once a stream is
    /// open, SSE chunks are parsed on a spawned task that forwards
    /// text/tool-call deltas and finishes with `StreamEvent::Done` when
    /// the server sends `[DONE]`.
    ///
    /// NOTE(review): a mid-stream transport error or a stream that ends
    /// without `[DONE]` just closes the channel — no error event and no
    /// final `Done` is emitted. Confirm consumers treat channel close as
    /// abnormal termination.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            // `include_usage` makes the server append a final usage-only chunk.
            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor the server's Retry-After header if present.
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            tokio::spawn(async move {
                // Accumulators for the in-flight response:
                // - `buffer` holds raw bytes until a full SSE event ("\n\n") arrives
                // - `tool_calls` maps delta index -> (id, name, accumulated args)
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Transport failure: stop; channel close signals the end.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    // NOTE(review): lossy UTF-8 decode can corrupt a multi-byte
                    // character split across chunk boundaries — confirm the
                    // transport delivers whole UTF-8 sequences per chunk.
                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        buffer.drain(..2); // discard the "\n\n" separator

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // Terminator: assemble the final message from
                                    // the accumulated text and tool-call fragments.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            input: serde_json::from_str(args).unwrap_or_else(|e| {
                                                tracing::warn!(
                                                    "Failed to parse tool arguments JSON for tool '{}': {}",
                                                    name, e
                                                );
                                                serde_json::Value::default()
                                            }),
                                        });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    continue;
                                }

                                // Regular chunk: fold deltas into the accumulators.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            // NOTE(review): as in `complete`, a
                                            // reasoning-only delta is also emitted as
                                            // text, so it lands in both accumulators.
                                            let text_delta =
                                                delta.content.or(delta.reasoning_content);
                                            if let Some(content) = text_delta {
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    // Fragments for one call share an index;
                                                    // id/name arrive once, arguments stream in.
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
}
523
// OpenAI API response types (private)

/// Top-level non-streaming chat-completion response body.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    pub(crate) choices: Vec<OpenAiChoice>,
    pub(crate) usage: OpenAiUsage,
}
530
/// One entry of `choices`; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    /// e.g. "stop" or "tool_calls"; absent on some backends.
    pub(crate) finish_reason: Option<String>,
}
536
/// Assistant message payload of a non-streaming choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Extension field emitted by reasoning models (e.g. kimi-k2.5);
    /// absent on vanilla OpenAI.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
543
/// A complete (non-streamed) tool call attached to a message.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}
549
/// Function name plus its arguments as a JSON-encoded string
/// (parsed into a `serde_json::Value` by the caller).
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    pub(crate) arguments: String,
}
555
556#[derive(Debug, Deserialize)]
557pub(crate) struct OpenAiUsage {
558    pub(crate) prompt_tokens: usize,
559    pub(crate) completion_tokens: usize,
560    pub(crate) total_tokens: usize,
561}
562
563// OpenAI streaming types
564#[derive(Debug, Deserialize)]
565pub(crate) struct OpenAiStreamChunk {
566    pub(crate) choices: Vec<OpenAiStreamChoice>,
567    pub(crate) usage: Option<OpenAiUsage>,
568}
569
/// One streamed choice; carries an incremental `delta` and, on the
/// final content chunk, a `finish_reason`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    pub(crate) finish_reason: Option<String>,
}
575
/// Incremental message fragment; each field is present only when that
/// part of the message advanced in this chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    /// Reasoning-model extension field; absent on vanilla OpenAI.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
582
/// Fragment of a streamed tool call.
///
/// Fragments belonging to the same call share an `index`; `id` arrives
/// once, while `function.arguments` accumulates across chunks.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    pub(crate) index: usize,
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
589
/// Function fragment of a streamed tool call: `name` arrives once,
/// `arguments` is a partial JSON string to be concatenated.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}