// a3s_code_core/llm/openai.rs
//! OpenAI-compatible LLM client

use super::http::{default_http_client, normalize_base_url, HttpClient};
use super::types::*;
use super::LlmClient;
use crate::llm::types::{ToolResultContent, ToolResultContentField};
use crate::retry::{AttemptOutcome, RetryConfig};
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::mpsc;

/// OpenAI client
///
/// Talks to any OpenAI-compatible `/v1/chat/completions` endpoint; the
/// base URL is configurable so the same client can target compatible
/// providers (see `with_base_url`).
pub struct OpenAiClient {
    // API key; read via `expose()` only when building the Authorization header.
    pub(crate) api_key: SecretString,
    // Model identifier sent as the `model` field of every request.
    pub(crate) model: String,
    // Endpoint root, e.g. "https://api.openai.com" (no trailing path).
    pub(crate) base_url: String,
    // Sampling temperature; omitted from requests when None.
    pub(crate) temperature: Option<f32>,
    // Generation cap (`max_tokens`); omitted from requests when None.
    pub(crate) max_tokens: Option<usize>,
    // Pluggable HTTP transport (trait object so it can be swapped out).
    pub(crate) http: Arc<dyn HttpClient>,
    // Retry/backoff policy applied around each request.
    pub(crate) retry_config: RetryConfig,
}

impl OpenAiClient {
    /// Build a client for `model` authenticated with `api_key`, pointing at
    /// the public OpenAI endpoint with no sampling overrides. Customize via
    /// the `with_*` builder methods.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            temperature: None,
            max_tokens: None,
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Point the client at a different OpenAI-compatible server; the URL is
    /// normalized via `normalize_base_url` before being stored.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Set the sampling temperature sent with each request.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Cap the number of tokens the model may generate per response.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Replace the default retry policy.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Swap in a custom HTTP transport.
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }

    /// Convert provider-neutral `Message`s into OpenAI chat-completions JSON.
    ///
    /// Single-block messages are flattened: plain text becomes a JSON string
    /// and a tool result becomes a dedicated `"role": "tool"` message (early
    /// return). Multi-block messages become an array of typed content parts.
    /// Assistant messages always carry `reasoning_content` because kimi-k2.5
    /// requires it when thinking mode is enabled.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Flatten structured tool output to plain text;
                            // non-text blocks inside the result are dropped.
                            let content_str = match content {
                                ToolResultContentField::Text(s) => s.clone(),
                                ToolResultContentField::Blocks(blocks) => blocks
                                    .iter()
                                    .filter_map(|b| {
                                        if let ToolResultContent::Text { text } = b {
                                            Some(text.clone())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n"),
                            };
                            // A tool result maps to a whole "tool" message,
                            // so return early rather than fall through to the
                            // role-based handling below.
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content_str,
                            });
                        }
                        // Other single-block kinds degrade to empty text here;
                        // assistant ToolUse is reconstructed from tool_calls()
                        // in the assistant branch below.
                        _ => serde_json::json!(""),
                    }
                } else {
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                // Images travel inline as base64 data URLs.
                                ContentBlock::Image { source } => serde_json::json!({
                                    "type": "image_url",
                                    "image_url": {
                                        "url": format!(
                                            "data:{};base64,{}",
                                            source.media_type, source.data
                                        ),
                                    }
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                // NOTE(review): a ToolResult inside a multi-block
                                // message is emitted as `{}` and effectively
                                // lost — confirm callers never mix tool results
                                // with other content blocks.
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                // Handle assistant messages — kimi-k2.5 requires reasoning_content
                // on all assistant messages when thinking mode is enabled
                if msg.role == "assistant" {
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // Assistant tool invocations go in the dedicated
                        // `tool_calls` array, with text in `content`.
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }

    /// Convert neutral tool definitions into OpenAI "function" tool specs.
    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
        tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    }
                })
            })
            .collect()
    }
}

185#[async_trait]
186impl LlmClient for OpenAiClient {
187    async fn complete(
188        &self,
189        messages: &[Message],
190        system: Option<&str>,
191        tools: &[ToolDefinition],
192    ) -> Result<LlmResponse> {
193        {
194            let mut openai_messages = Vec::new();
195
196            if let Some(sys) = system {
197                openai_messages.push(serde_json::json!({
198                    "role": "system",
199                    "content": sys,
200                }));
201            }
202
203            openai_messages.extend(self.convert_messages(messages));
204
205            let mut request = serde_json::json!({
206                "model": self.model,
207                "messages": openai_messages,
208            });
209
210            if let Some(temp) = self.temperature {
211                request["temperature"] = serde_json::json!(temp);
212            }
213            if let Some(max) = self.max_tokens {
214                request["max_tokens"] = serde_json::json!(max);
215            }
216
217            if !tools.is_empty() {
218                request["tools"] = serde_json::json!(self.convert_tools(tools));
219            }
220
221            let url = format!("{}/v1/chat/completions", self.base_url);
222            let auth_header = format!("Bearer {}", self.api_key.expose());
223            let headers = vec![("Authorization", auth_header.as_str())];
224
225            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
226                let http = &self.http;
227                let url = &url;
228                let headers = headers.clone();
229                let request = &request;
230                async move {
231                    match http.post(url, headers, request).await {
232                        Ok(resp) => {
233                            let status = reqwest::StatusCode::from_u16(resp.status)
234                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
235                            if status.is_success() {
236                                AttemptOutcome::Success(resp.body)
237                            } else if self.retry_config.is_retryable_status(status) {
238                                AttemptOutcome::Retryable {
239                                    status,
240                                    body: resp.body,
241                                    retry_after: None,
242                                }
243                            } else {
244                                AttemptOutcome::Fatal(anyhow::anyhow!(
245                                    "OpenAI API error at {} ({}): {}",
246                                    url,
247                                    status,
248                                    resp.body
249                                ))
250                            }
251                        }
252                        Err(e) => {
253                            eprintln!("[DEBUG] HTTP error: {:?}", e);
254                            AttemptOutcome::Fatal(e)
255                        }
256                    }
257                }
258            })
259            .await?;
260
261            let parsed: OpenAiResponse =
262                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
263
264            let choice = parsed.choices.into_iter().next().context("No choices")?;
265
266            let mut content = vec![];
267
268            let reasoning_content = choice.message.reasoning_content;
269
270            let text_content = choice.message.content;
271
272            if let Some(text) = text_content {
273                if !text.is_empty() {
274                    content.push(ContentBlock::Text { text });
275                }
276            }
277
278            if let Some(tool_calls) = choice.message.tool_calls {
279                for tc in tool_calls {
280                    content.push(ContentBlock::ToolUse {
281                        id: tc.id,
282                        name: tc.function.name.clone(),
283                        input: serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
284                            tracing::warn!(
285                                "Failed to parse tool arguments JSON for tool '{}': {}",
286                                tc.function.name,
287                                e
288                            );
289                            serde_json::Value::default()
290                        }),
291                    });
292                }
293            }
294
295            let llm_response = LlmResponse {
296                message: Message {
297                    role: "assistant".to_string(),
298                    content,
299                    reasoning_content,
300                },
301                usage: TokenUsage {
302                    prompt_tokens: parsed.usage.prompt_tokens,
303                    completion_tokens: parsed.usage.completion_tokens,
304                    total_tokens: parsed.usage.total_tokens,
305                    cache_read_tokens: parsed
306                        .usage
307                        .prompt_tokens_details
308                        .as_ref()
309                        .and_then(|d| d.cached_tokens),
310                    cache_write_tokens: None,
311                },
312                stop_reason: choice.finish_reason,
313            };
314
315            crate::telemetry::record_llm_usage(
316                llm_response.usage.prompt_tokens,
317                llm_response.usage.completion_tokens,
318                llm_response.usage.total_tokens,
319                llm_response.stop_reason.as_deref(),
320            );
321
322            Ok(llm_response)
323        }
324    }
325
326    async fn complete_streaming(
327        &self,
328        messages: &[Message],
329        system: Option<&str>,
330        tools: &[ToolDefinition],
331    ) -> Result<mpsc::Receiver<StreamEvent>> {
332        {
333            let mut openai_messages = Vec::new();
334
335            if let Some(sys) = system {
336                openai_messages.push(serde_json::json!({
337                    "role": "system",
338                    "content": sys,
339                }));
340            }
341
342            openai_messages.extend(self.convert_messages(messages));
343
344            let mut request = serde_json::json!({
345                "model": self.model,
346                "messages": openai_messages,
347                "stream": true,
348                "stream_options": { "include_usage": true },
349            });
350
351            if let Some(temp) = self.temperature {
352                request["temperature"] = serde_json::json!(temp);
353            }
354            if let Some(max) = self.max_tokens {
355                request["max_tokens"] = serde_json::json!(max);
356            }
357
358            if !tools.is_empty() {
359                request["tools"] = serde_json::json!(self.convert_tools(tools));
360            }
361
362            let url = format!("{}/v1/chat/completions", self.base_url);
363            let auth_header = format!("Bearer {}", self.api_key.expose());
364            let headers = vec![("Authorization", auth_header.as_str())];
365
366            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
367                let http = &self.http;
368                let url = &url;
369                let headers = headers.clone();
370                let request = &request;
371                async move {
372                    match http.post_streaming(url, headers, request).await {
373                        Ok(resp) => {
374                            let status = reqwest::StatusCode::from_u16(resp.status)
375                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
376                            if status.is_success() {
377                                AttemptOutcome::Success(resp)
378                            } else {
379                                let retry_after = resp
380                                    .retry_after
381                                    .as_deref()
382                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
383                                if self.retry_config.is_retryable_status(status) {
384                                    AttemptOutcome::Retryable {
385                                        status,
386                                        body: resp.error_body,
387                                        retry_after,
388                                    }
389                                } else {
390                                    AttemptOutcome::Fatal(anyhow::anyhow!(
391                                        "OpenAI API error at {} ({}): {}",
392                                        url,
393                                        status,
394                                        resp.error_body
395                                    ))
396                                }
397                            }
398                        }
399                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
400                            "Failed to send streaming request: {}",
401                            e
402                        )),
403                    }
404                }
405            })
406            .await?;
407
408            let (tx, rx) = mpsc::channel(100);
409
410            let mut stream = streaming_resp.byte_stream;
411            tokio::spawn(async move {
412                let mut buffer = String::new();
413                let mut content_blocks: Vec<ContentBlock> = Vec::new();
414                let mut text_content = String::new();
415                let mut reasoning_content_accum = String::new();
416                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
417                    std::collections::BTreeMap::new();
418                let mut usage = TokenUsage::default();
419                let mut finish_reason = None;
420
421                while let Some(chunk_result) = stream.next().await {
422                    let chunk = match chunk_result {
423                        Ok(c) => c,
424                        Err(e) => {
425                            tracing::error!("Stream error: {}", e);
426                            break;
427                        }
428                    };
429
430                    buffer.push_str(&String::from_utf8_lossy(&chunk));
431
432                    while let Some(event_end) = buffer.find("\n\n") {
433                        let event_data: String = buffer.drain(..event_end).collect();
434                        buffer.drain(..2);
435
436                        for line in event_data.lines() {
437                            if let Some(data) = line.strip_prefix("data: ") {
438                                if data == "[DONE]" {
439                                    if !text_content.is_empty() {
440                                        content_blocks.push(ContentBlock::Text {
441                                            text: text_content.clone(),
442                                        });
443                                    }
444                                    for (_, (id, name, args)) in tool_calls.iter() {
445                                        content_blocks.push(ContentBlock::ToolUse {
446                                        id: id.clone(),
447                                        name: name.clone(),
448                                        input: serde_json::from_str(args).unwrap_or_else(|e| {
449                                            tracing::warn!(
450                                                "Failed to parse tool arguments JSON for tool '{}': {}",
451                                                name, e
452                                            );
453                                            serde_json::Value::default()
454                                        }),
455                                    });
456                                    }
457                                    tool_calls.clear();
458                                    crate::telemetry::record_llm_usage(
459                                        usage.prompt_tokens,
460                                        usage.completion_tokens,
461                                        usage.total_tokens,
462                                        finish_reason.as_deref(),
463                                    );
464                                    let response = LlmResponse {
465                                        message: Message {
466                                            role: "assistant".to_string(),
467                                            content: std::mem::take(&mut content_blocks),
468                                            reasoning_content: if reasoning_content_accum.is_empty()
469                                            {
470                                                None
471                                            } else {
472                                                Some(std::mem::take(&mut reasoning_content_accum))
473                                            },
474                                        },
475                                        usage: usage.clone(),
476                                        stop_reason: std::mem::take(&mut finish_reason),
477                                    };
478                                    let _ = tx.send(StreamEvent::Done(response)).await;
479                                    continue;
480                                }
481
482                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
483                                    if let Some(u) = event.usage {
484                                        usage.prompt_tokens = u.prompt_tokens;
485                                        usage.completion_tokens = u.completion_tokens;
486                                        usage.total_tokens = u.total_tokens;
487                                        usage.cache_read_tokens = u
488                                            .prompt_tokens_details
489                                            .as_ref()
490                                            .and_then(|d| d.cached_tokens);
491                                    }
492
493                                    if let Some(choice) = event.choices.into_iter().next() {
494                                        if let Some(reason) = choice.finish_reason {
495                                            finish_reason = Some(reason);
496                                        }
497
498                                        if let Some(delta) = choice.delta {
499                                            if let Some(ref rc) = delta.reasoning_content {
500                                                reasoning_content_accum.push_str(rc);
501                                            }
502
503                                            if let Some(content) = delta.content {
504                                                text_content.push_str(&content);
505                                                let _ =
506                                                    tx.send(StreamEvent::TextDelta(content)).await;
507                                            }
508
509                                            if let Some(tcs) = delta.tool_calls {
510                                                for tc in tcs {
511                                                    let entry = tool_calls
512                                                        .entry(tc.index)
513                                                        .or_insert_with(|| {
514                                                            (
515                                                                String::new(),
516                                                                String::new(),
517                                                                String::new(),
518                                                            )
519                                                        });
520
521                                                    if let Some(id) = tc.id {
522                                                        entry.0 = id;
523                                                    }
524                                                    if let Some(func) = tc.function {
525                                                        if let Some(name) = func.name {
526                                                            entry.1 = name.clone();
527                                                            let _ = tx
528                                                                .send(StreamEvent::ToolUseStart {
529                                                                    id: entry.0.clone(),
530                                                                    name,
531                                                                })
532                                                                .await;
533                                                        }
534                                                        if let Some(args) = func.arguments {
535                                                            entry.2.push_str(&args);
536                                                            let _ = tx
537                                                                .send(
538                                                                    StreamEvent::ToolUseInputDelta(
539                                                                        args,
540                                                                    ),
541                                                                )
542                                                                .await;
543                                                        }
544                                                    }
545                                                }
546                                            }
547                                        }
548                                    }
549                                }
550                            }
551                        }
552                    }
553                }
554            });
555
556            Ok(rx)
557        }
558    }
559}
560
// OpenAI API response types (private)

/// Top-level body of a non-streaming chat-completions response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    pub(crate) choices: Vec<OpenAiChoice>,
    pub(crate) usage: OpenAiUsage,
}

/// One completion candidate; this client only consumes the first.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    // e.g. "stop", "tool_calls"; forwarded as the response's stop_reason.
    pub(crate) finish_reason: Option<String>,
}

/// Assistant message payload of a choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    // Extended-thinking text (e.g. kimi-k2.5); may be absent.
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}

/// A tool invocation requested by the model (non-streaming form).
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}

/// Function name plus its arguments as a raw JSON-encoded string.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    // JSON text; parsed leniently by the caller (invalid JSON -> Null).
    pub(crate) arguments: String,
}

/// Token accounting attached to responses (and to the final stream chunk
/// when `stream_options.include_usage` is set).
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
    /// OpenAI returns cached token count in `prompt_tokens_details.cached_tokens`
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; currently only the cache-hit count is read.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}

// OpenAI streaming types

/// One parsed SSE `data:` payload from a streaming completion.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    // Present only on the final usage chunk when include_usage is set.
    pub(crate) usage: Option<OpenAiUsage>,
}

/// Streaming counterpart of a choice: incremental delta + finish reason.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    pub(crate) finish_reason: Option<String>,
}

/// Incremental message fragment; each field is present only when that part
/// of the message advanced in this chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

/// Streamed fragment of a tool call; `index` correlates fragments belonging
/// to the same invocation across chunks.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    pub(crate) index: usize,
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}

/// Streamed fragment of a function call: the name arrives once, the
/// argument JSON arrives as string pieces to be concatenated.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}