Skip to main content

a3s_code_core/llm/
openai.rs

1//! OpenAI-compatible LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
15/// OpenAI client
/// OpenAI client
///
/// Client for OpenAI-compatible chat-completions endpoints. Build with
/// [`OpenAiClient::new`] and customize via the `with_*` builder methods.
pub struct OpenAiClient {
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    pub(crate) api_key: SecretString,
    /// Model identifier sent in the `model` field of every request.
    pub(crate) model: String,
    /// Server root; `/v1/chat/completions` is appended per request.
    pub(crate) base_url: String,
    /// Sampling temperature; omitted from the request when `None`.
    pub(crate) temperature: Option<f32>,
    /// `max_tokens` cap; omitted from the request when `None`.
    pub(crate) max_tokens: Option<usize>,
    /// HTTP transport abstraction (swappable via `with_http_client`).
    pub(crate) http: Arc<dyn HttpClient>,
    /// Policy used by `crate::retry::with_retry` for transient failures.
    pub(crate) retry_config: RetryConfig,
}
25
impl OpenAiClient {
    /// Create a client targeting the hosted OpenAI endpoint with default
    /// HTTP transport and retry policy; temperature and max_tokens unset.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            temperature: None,
            max_tokens: None,
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Builder: point the client at a different OpenAI-compatible server.
    /// The URL is normalized by `normalize_base_url` before being stored.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Builder: set the sampling temperature sent with each request.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Builder: set the `max_tokens` cap sent with each request.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Builder: override the retry policy for API calls.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Builder: inject a custom HTTP transport.
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Convert internal `Message`s into OpenAI chat-completions message JSON.
    ///
    /// Shape rules implemented below:
    /// - A single `Text` block becomes a plain string `content`.
    /// - A single `ToolResult` block becomes a whole `"role": "tool"` message
    ///   (early return), flattening block-style results into newline-joined text.
    /// - Multi-block content becomes an array of typed parts (`text`,
    ///   `image_url` as a base64 data URL, `function` for tool uses).
    /// - Assistant messages always carry `reasoning_content` (kimi-k2.5
    ///   requires it when thinking mode is enabled) and lift tool calls into
    ///   the message-level `tool_calls` array.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Flatten structured tool output to plain text;
                            // non-text blocks are dropped here.
                            let content_str = match content {
                                ToolResultContentField::Text(s) => s.clone(),
                                ToolResultContentField::Blocks(blocks) => blocks
                                    .iter()
                                    .filter_map(|b| {
                                        if let ToolResultContent::Text { text } = b {
                                            Some(text.clone())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n"),
                            };
                            // Tool results map to a dedicated "tool" role
                            // message, bypassing the role handling below.
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content_str,
                            });
                        }
                        // Other single-block kinds serialize as empty content.
                        _ => serde_json::json!(""),
                    }
                } else {
                    // NOTE(review): a ToolResult inside a multi-block message
                    // falls through to the `_ => {}` arm below and serializes
                    // as an empty object — confirm upstream never produces
                    // multi-block tool results.
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                ContentBlock::Image { source } => serde_json::json!({
                                    "type": "image_url",
                                    "image_url": {
                                        "url": format!(
                                            "data:{};base64,{}",
                                            source.media_type, source.data
                                        ),
                                    }
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                // Handle assistant messages — kimi-k2.5 requires reasoning_content
                // on all assistant messages when thinking mode is enabled
                if msg.role == "assistant" {
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // Tool calls live at the message level for OpenAI;
                        // content collapses to the message's plain text.
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                // Default path: user (or other) roles pass role + content through.
                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Convert tool definitions into OpenAI's `tools` array format
    /// (`{"type": "function", "function": {...}}` per tool).
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}
184
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Non-streaming chat completion.
    ///
    /// Builds the request JSON (system prompt first, then converted
    /// messages, optional sampling knobs and tools), POSTs it with retry,
    /// parses the first choice into an assistant `Message`, records
    /// telemetry, and returns the `LlmResponse`.
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let mut openai_messages = Vec::new();

            // System prompt is sent as its own leading "system" message.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
            });

            // Optional fields are only serialized when configured.
            if let Some(temp) = self.temperature {
                request["temperature"] = serde_json::json!(temp);
            }
            if let Some(max) = self.max_tokens {
                request["max_tokens"] = serde_json::json!(max);
            }

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry loop: each attempt classifies the HTTP status into
            // Success / Retryable / Fatal per the configured RetryConfig.
            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post(url, headers, request).await {
                        Ok(resp) => {
                            // Fall back to 500 if the transport reports an
                            // out-of-range status code.
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "OpenAI API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        // Transport errors are treated as fatal (no retry).
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: OpenAiResponse =
                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;

            // Only the first choice is used.
            let choice = parsed.choices.into_iter().next().context("No choices")?;

            let mut content = vec![];

            let reasoning_content = choice.message.reasoning_content;

            let text_content = choice.message.content;

            // Empty text is dropped rather than emitted as an empty block.
            if let Some(text) = text_content {
                if !text.is_empty() {
                    content.push(ContentBlock::Text { text });
                }
            }

            if let Some(tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    content.push(ContentBlock::ToolUse {
                        id: tc.id,
                        name: tc.function.name.clone(),
                        // Malformed argument JSON degrades to a default value
                        // (with a warning) instead of failing the whole call.
                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                            tracing::warn!(
                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                tc.function.name,
                                e
                            );
                            serde_json::Value::default()
                        }),
                    });
                }
            }

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.prompt_tokens,
                    completion_tokens: parsed.usage.completion_tokens,
                    total_tokens: parsed.usage.total_tokens,
                    // OpenAI reports cache hits in prompt_tokens_details.
                    cache_read_tokens: parsed
                        .usage
                        .prompt_tokens_details
                        .as_ref()
                        .and_then(|d| d.cached_tokens),
                    cache_write_tokens: None,
                },
                stop_reason: choice.finish_reason,
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Streaming chat completion.
    ///
    /// Sends the same request as [`complete`] with `"stream": true` (and
    /// `stream_options.include_usage` so the final chunk carries token
    /// usage), then spawns a task that parses the SSE byte stream and
    /// forwards `StreamEvent`s over the returned channel: `TextDelta`,
    /// `ToolUseStart`, `ToolUseInputDelta`, and a final `Done` carrying the
    /// assembled `LlmResponse` when `[DONE]` arrives.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let mut openai_messages = Vec::new();

            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            if let Some(temp) = self.temperature {
                request["temperature"] = serde_json::json!(temp);
            }
            if let Some(max) = self.max_tokens {
                request["max_tokens"] = serde_json::json!(max);
            }

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}/v1/chat/completions", self.base_url);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry only covers establishing the stream; once connected,
            // mid-stream failures are not retried.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor a server-provided Retry-After hint
                                // when classifying retryable failures.
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            tokio::spawn(async move {
                // Accumulators for assembling the final LlmResponse:
                // raw SSE bytes, finished content blocks, running text and
                // reasoning text, and in-flight tool calls keyed by their
                // stream `index` (id, name, concatenated argument JSON).
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // NOTE(review): a mid-stream error just closes the
                            // channel without a Done event — confirm callers
                            // handle channel closure as an error.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    // Chunks may split SSE events; accumulate and parse
                    // complete events only. Lossy UTF-8 conversion can mangle
                    // a multi-byte char split across chunks.
                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line ("\n\n").
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        // Drop the two separator newlines.
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // Terminal sentinel: flush accumulated
                                    // text and tool calls into content blocks
                                    // and emit the assembled response.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                        id: id.clone(),
                                        name: name.clone(),
                                        // Same lenient parse as the
                                        // non-streaming path: bad JSON
                                        // degrades to a default value.
                                        input: serde_json::from_str(args).unwrap_or_else(|e| {
                                            tracing::warn!(
                                                "Failed to parse tool arguments JSON for tool '{}': {}",
                                                name, e
                                            );
                                            serde_json::Value::default()
                                        }),
                                    });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    continue;
                                }

                                // Unparseable chunks are silently skipped.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    // Usage arrives (typically in the final
                                    // chunk) because include_usage is set.
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                        usage.cache_read_tokens = u
                                            .prompt_tokens_details
                                            .as_ref()
                                            .and_then(|d| d.cached_tokens);
                                    }

                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            // Reasoning text is accumulated
                                            // but not forwarded as an event.
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            if let Some(content) = delta.content {
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    // Tool-call deltas are
                                                    // merged by index: id and
                                                    // name replace, arguments
                                                    // concatenate.
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        if let Some(name) = func.name {
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // NOTE(review): if the stream ends without "[DONE]", no Done
                // event is emitted; the channel just closes — confirm this is
                // the intended contract with consumers.
            });

            Ok(rx)
        }
    }
}
557
558// OpenAI API response types (private)
/// Top-level non-streaming chat-completions response body.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    /// Completion choices; only the first is consumed by `complete`.
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Token accounting for the request.
    pub(crate) usage: OpenAiUsage,
}
564
/// One completion alternative in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    /// The assistant message produced for this choice.
    pub(crate) message: OpenAiMessage,
    /// Why generation stopped (e.g. absent while mid-generation); passed
    /// through as the `LlmResponse` stop reason.
    pub(crate) finish_reason: Option<String>,
}
570
/// Assistant message payload in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Reasoning/thinking text emitted by reasoning-capable models
    /// (e.g. kimi-k2.5 thinking mode); absent for standard models.
    pub(crate) reasoning_content: Option<String>,
    /// Plain text content; may be absent when only tool calls are returned.
    pub(crate) content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
577
/// A complete tool call from a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    /// Call id, echoed back in the follow-up "tool" role message.
    pub(crate) id: String,
    /// The function name and argument payload.
    pub(crate) function: OpenAiFunction,
}
583
/// Function name/arguments pair inside a tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    /// Name of the tool to invoke.
    pub(crate) name: String,
    /// Arguments as a JSON-encoded string (parsed leniently by the client).
    pub(crate) arguments: String,
}
589
/// Token usage block, shared by non-streaming responses and the final
/// streaming chunk (when `stream_options.include_usage` is set).
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
    /// OpenAI returns cached token count in `prompt_tokens_details.cached_tokens`
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}
599
/// Breakdown of prompt tokens; currently only the cached count is read.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    /// Prompt tokens served from the provider's prompt cache, if reported.
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
605
606// OpenAI streaming types
/// One parsed SSE `data:` payload from a streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    /// Per-chunk choices; only the first is consumed.
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    /// Usage totals, present on the final chunk when
    /// `stream_options.include_usage` is requested.
    pub(crate) usage: Option<OpenAiUsage>,
}
612
/// Choice entry within a streaming chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    /// Incremental message delta for this chunk, if any.
    pub(crate) delta: Option<OpenAiDelta>,
    /// Set on the chunk where this choice finishes generating.
    pub(crate) finish_reason: Option<String>,
}
618
/// Incremental message fields; each is a fragment to append to the
/// corresponding accumulator.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    /// Fragment of reasoning/thinking text.
    pub(crate) reasoning_content: Option<String>,
    /// Fragment of visible text content.
    pub(crate) content: Option<String>,
    /// Incremental tool-call updates, merged by `index`.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
625
/// Incremental tool-call update within a streaming delta.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Position of the tool call this fragment belongs to; used as the
    /// merge key while accumulating.
    pub(crate) index: usize,
    /// Call id, typically present only on the first fragment.
    pub(crate) id: Option<String>,
    /// Function name/argument fragments.
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
632
/// Function fragments within a streaming tool-call delta.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    /// Tool name, typically present only on the first fragment.
    pub(crate) name: Option<String>,
    /// JSON-argument fragment to append to the accumulated argument string.
    pub(crate) arguments: Option<String>,
}