// a3s_code_core/llm/openai.rs
1//! OpenAI-compatible LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::retry::{AttemptOutcome, RetryConfig};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// OpenAI client
///
/// Speaks the OpenAI chat-completions wire format; also usable against
/// OpenAI-compatible servers via [`OpenAiClient::with_base_url`].
pub struct OpenAiClient {
    /// API key wrapped in a secret type so it is not exposed accidentally.
    pub(crate) api_key: SecretString,
    /// Model identifier sent as the `model` field of each request.
    pub(crate) model: String,
    /// Base URL without the `/v1/chat/completions` suffix.
    pub(crate) base_url: String,
    /// Injectable HTTP transport (see `with_http_client`).
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied to each API call.
    pub(crate) retry_config: RetryConfig,
}
22
impl OpenAiClient {
    /// Creates a client for the official OpenAI endpoint
    /// (`https://api.openai.com`) with the default HTTP transport and
    /// retry policy.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Builder: retargets the client at an OpenAI-compatible server.
    /// The URL is cleaned up by `normalize_base_url`; the
    /// `/v1/chat/completions` path is appended per request.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Builder: overrides the retry policy used for API calls.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Builder: swaps in a custom HTTP transport (e.g. a mock for tests).
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Converts internal [`Message`]s into OpenAI chat-completion message
    /// objects.
    ///
    /// - A message whose single block is `Text` flattens to a plain string
    ///   `content`.
    /// - A message whose single block is `ToolResult` short-circuits to a
    ///   `"role": "tool"` message carrying `tool_call_id`.
    /// - Multi-block messages become a content array.
    /// - Assistant messages get special handling (see the kimi note below).
    ///
    /// NOTE(review): in the multi-block path a `ToolResult` hits the `_` arm
    /// and serializes as `{}` — confirm tool results only ever arrive as
    /// single-block messages.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Tool results map to a dedicated "tool" role
                            // message; bypass the generic path entirely.
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content,
                            });
                        }
                        // A lone ToolUse (or other block) has no plain-text
                        // form here; assistant tool calls are reconstructed
                        // from msg.tool_calls() below.
                        _ => serde_json::json!(""),
                    }
                } else {
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                // Handle assistant messages — kimi-k2.5 requires reasoning_content
                // on all assistant messages when thinking mode is enabled
                if msg.role == "assistant" {
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // Assistant turn that invoked tools: flattened text in
                        // `content`, invocations in the `tool_calls` array.
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    // Plain assistant turn: reuse the content computed above.
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Converts tool definitions into the OpenAI `tools` array format
    /// (`{"type": "function", "function": {...}}`).
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}
146
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Sends a non-streaming chat-completion request and returns the
    /// assistant message plus token usage.
    ///
    /// The optional `system` prompt is prepended as a `"system"` message;
    /// `tools` are attached only when non-empty. Responses with a retryable
    /// HTTP status are retried per `self.retry_config`.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry loop: a success status yields the raw JSON body; a
            // retryable status re-enters the loop; anything else is fatal.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        // NOTE(review): transport errors (DNS, reset,
                        // timeout) are treated as fatal and never retried —
                        // confirm that is intended.
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            // Clone before the `.or(...)` below moves the field out.
            let reasoning_content = choice.message.reasoning_content.clone();

            // Fall back to reasoning text when the model returned no regular
            // content (reasoning-only replies).
            let text_content = choice.message.content.or(choice.message.reasoning_content);

            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        // Unparseable argument JSON degrades to a null value
                        // (with a warning) rather than failing the response.
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // Cache accounting is not reported by this API shape.
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Sends a streaming chat-completion request and returns a channel of
    /// [`StreamEvent`]s: text deltas, tool-use start/input deltas, and a
    /// final `Done` carrying the assembled response plus usage.
    ///
    /// `stream_options.include_usage` is set so the final SSE chunk reports
    /// token usage. Parsing happens on a spawned task; send results are
    /// ignored (`let _ =`), so a dropped receiver simply discards events.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Only the initial connection is retried; once a success status
            // arrives, the byte stream itself is consumed without retry.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor the server's Retry-After value when
                                // one was captured on the response.
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            tokio::spawn(async move {
                // Accumulators for reassembling the response across chunks.
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                // Partial tool calls keyed by stream index:
                // (id, name, accumulated JSON-arguments string).
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // NOTE(review): a mid-stream transport error ends
                            // the task without emitting Done or any error
                            // event — confirm callers tolerate the channel
                            // simply closing.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are delimited by a blank line ("\n\n");
                    // drain and process every complete event in the buffer.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        buffer.drain(..2); // drop the "\n\n" delimiter itself

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // Terminal sentinel: flush accumulated
                                    // text and tool calls into the final
                                    // response and emit Done.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                        id: id.clone(),
                                        name: name.clone(),
                                        input: serde_json::from_str(args).unwrap_or_else(|e| {
                                            tracing::warn!(
                                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                                name, e
                                            );
                                            serde_json::Value::default()
                                        }),
                                    });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    // Keep scanning remaining lines/chunks;
                                    // accumulators were reset via mem::take.
                                    continue;
                                }

                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            // NOTE(review): when `content` is
                                            // absent, reasoning text is ALSO
                                            // surfaced as a TextDelta (and
                                            // later a Text block) on top of
                                            // the separate reasoning
                                            // accumulator — confirm this
                                            // duplication is intended.
                                            let text_delta =
                                                delta.content.or(delta.reasoning_content);
                                            if let Some(content) = text_delta {
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    // Create or extend the
                                                    // partial call for this
                                                    // stream index.
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        // Name arrives once and
                                                        // marks the start of a
                                                        // tool invocation.
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        // Arguments arrive as
                                                        // JSON string fragments.
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
}
499
// OpenAI API response types (private)

/// Non-streaming chat-completion response body (deserialized subset).
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    /// Completion candidates; only the first is consumed by `complete`.
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Token accounting. NOTE(review): required field — deserialization
    /// fails if a compatible server omits `usage`; confirm acceptable.
    pub(crate) usage: OpenAiUsage,
}
506
/// One completion candidate.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    /// The assistant message produced for this choice.
    pub(crate) message: OpenAiMessage,
    /// Why generation stopped (e.g. "stop", "tool_calls"), if reported.
    pub(crate) finish_reason: Option<String>,
}
512
/// Assistant message payload of a non-streaming choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Reasoning/thinking text (extension used by reasoning-capable models).
    pub(crate) reasoning_content: Option<String>,
    /// Regular text content, if any.
    pub(crate) content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
519
/// A complete tool invocation in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    /// Provider-assigned call id, echoed back in "tool" role messages.
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}
525
/// Function name plus its arguments payload.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    /// Arguments as a raw JSON string (parsed leniently by the caller).
    pub(crate) arguments: String,
}
531
/// Token usage block, shared by non-streaming and streaming responses.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
}
538
// OpenAI streaming types

/// One SSE `data:` payload from a streaming completion.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    /// May be empty on the final usage-only chunk.
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    /// Present on the last chunk when `include_usage` was requested.
    pub(crate) usage: Option<OpenAiUsage>,
}
545
/// Per-choice delta within a streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    /// Set on the chunk that ends this choice.
    pub(crate) finish_reason: Option<String>,
}
551
/// Incremental content for a streaming choice; every field is optional.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
558
/// Partial tool call; fragments are merged by `index` across chunks.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Position of the tool call within the message; stable across chunks.
    pub(crate) index: usize,
    /// Sent once, on the first fragment of the call.
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
565
/// Partial function payload: name arrives once, arguments in fragments.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}