// a3s_code_core/llm/openai.rs
1//! OpenAI-compatible LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::llm::types::{ToolResultContent, ToolResultContentField};
7use crate::retry::{AttemptOutcome, RetryConfig};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use futures::StreamExt;
11use serde::Deserialize;
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::mpsc;
15
/// OpenAI client
///
/// Client for any OpenAI-compatible chat-completions endpoint; the base
/// URL, request path, and provider label are configurable through the
/// `with_*` builder methods.
pub struct OpenAiClient {
    /// Provider label recorded in response metadata (default: "openai").
    pub(crate) provider_name: String,
    /// API key; exposed only when building the Authorization header.
    pub(crate) api_key: SecretString,
    /// Model identifier sent in each request body.
    pub(crate) model: String,
    /// Base URL with no trailing path, e.g. "https://api.openai.com".
    pub(crate) base_url: String,
    /// Path appended to `base_url`; always starts with '/'.
    pub(crate) chat_completions_path: String,
    /// Optional sampling temperature; omitted from the request when `None`.
    pub(crate) temperature: Option<f32>,
    /// Optional `max_tokens` cap; omitted from the request when `None`.
    pub(crate) max_tokens: Option<usize>,
    /// Pluggable HTTP transport.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied to every request.
    pub(crate) retry_config: RetryConfig,
}
28
29impl OpenAiClient {
30    pub(crate) fn parse_tool_arguments(tool_name: &str, arguments: &str) -> serde_json::Value {
31        if arguments.trim().is_empty() {
32            return serde_json::Value::Object(Default::default());
33        }
34
35        serde_json::from_str(arguments).unwrap_or_else(|e| {
36            tracing::warn!(
37                "Failed to parse tool arguments JSON for tool '{}': {}",
38                tool_name,
39                e
40            );
41            serde_json::json!({
42                "__parse_error": format!(
43                    "Malformed tool arguments: {}. Raw input: {}",
44                    e, arguments
45                )
46            })
47        })
48    }
49
    /// Create a client for the official OpenAI API with the default base
    /// URL, chat-completions path, HTTP client, and retry policy; sampling
    /// options (`temperature`, `max_tokens`) start unset.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            provider_name: "openai".to_string(),
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.openai.com".to_string(),
            chat_completions_path: "/v1/chat/completions".to_string(),
            temperature: None,
            max_tokens: None,
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Override the base URL; normalization (e.g. trailing-slash handling)
    /// is delegated to `normalize_base_url`.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    /// Override the provider label reported in response metadata.
    pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
        self.provider_name = provider_name.into();
        self
    }

    /// Override the chat-completions path; a leading '/' is added if missing.
    pub fn with_chat_completions_path(mut self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.chat_completions_path = if path.starts_with('/') {
            path
        } else {
            format!("/{}", path)
        };
        self
    }

    /// Set the sampling temperature sent with each request.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the `max_tokens` limit sent with each request.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Replace the retry policy used for all requests.
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Replace the HTTP transport implementation.
    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }
103
    /// Convert internal `Message`s into the OpenAI chat-completions wire format.
    ///
    /// Mapping rules, as implemented below:
    /// - A message whose single block is `Text` gets a plain-string `content`.
    /// - A message whose single block is `ToolResult` is emitted as a
    ///   `role: "tool"` message keyed by `tool_call_id` (early return — the
    ///   message's own role is ignored), with block contents flattened to
    ///   newline-joined text.
    /// - Multi-block messages become an array of typed parts; images are
    ///   inlined as base64 data URLs.
    ///   NOTE(review): a `ToolResult` inside a multi-block message falls into
    ///   the `_` arm and serializes as `{}` — confirm callers never mix
    ///   ToolResult with other blocks.
    /// - Assistant messages always include `reasoning_content` (empty string
    ///   when absent) because kimi-k2.5 requires it when thinking mode is
    ///   enabled; tool calls go in the dedicated `tool_calls` field with
    ///   stringified JSON arguments.
    pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
        messages
            .iter()
            .map(|msg| {
                // Single-block messages get compact representations.
                let content: serde_json::Value = if msg.content.len() == 1 {
                    match &msg.content[0] {
                        ContentBlock::Text { text } => serde_json::json!(text),
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Flatten the tool result to plain text; non-text
                            // blocks inside it are dropped.
                            let content_str = match content {
                                ToolResultContentField::Text(s) => s.clone(),
                                ToolResultContentField::Blocks(blocks) => blocks
                                    .iter()
                                    .filter_map(|b| {
                                        if let ToolResultContent::Text { text } = b {
                                            Some(text.clone())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect::<Vec<_>>()
                                    .join("\n"),
                            };
                            // Tool results map to a dedicated "tool" role message.
                            return serde_json::json!({
                                "role": "tool",
                                "tool_call_id": tool_use_id,
                                "content": content_str,
                            });
                        }
                        _ => serde_json::json!(""),
                    }
                } else {
                    // Multi-block content: emit an array of typed parts.
                    serde_json::json!(msg
                        .content
                        .iter()
                        .map(|block| {
                            match block {
                                ContentBlock::Text { text } => serde_json::json!({
                                    "type": "text",
                                    "text": text,
                                }),
                                ContentBlock::Image { source } => serde_json::json!({
                                    "type": "image_url",
                                    "image_url": {
                                        "url": format!(
                                            "data:{};base64,{}",
                                            source.media_type, source.data
                                        ),
                                    }
                                }),
                                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
                                    "type": "function",
                                    "id": id,
                                    "function": {
                                        "name": name,
                                        "arguments": input.to_string(),
                                    }
                                }),
                                _ => serde_json::json!({}),
                            }
                        })
                        .collect::<Vec<_>>())
                };

                // Handle assistant messages — kimi-k2.5 requires reasoning_content
                // on all assistant messages when thinking mode is enabled
                if msg.role == "assistant" {
                    let rc = msg.reasoning_content.as_deref().unwrap_or("");
                    let tool_calls: Vec<_> = msg.tool_calls();
                    if !tool_calls.is_empty() {
                        // With tool calls present, content collapses to the
                        // message's plain text; the calls are serialized into
                        // the `tool_calls` field.
                        return serde_json::json!({
                            "role": "assistant",
                            "content": msg.text(),
                            "reasoning_content": rc,
                            "tool_calls": tool_calls.iter().map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.args.to_string(),
                                    }
                                })
                            }).collect::<Vec<_>>(),
                        });
                    }
                    return serde_json::json!({
                        "role": "assistant",
                        "content": content,
                        "reasoning_content": rc,
                    });
                }

                serde_json::json!({
                    "role": msg.role,
                    "content": content,
                })
            })
            .collect()
    }
207
208    pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
209        tools
210            .iter()
211            .map(|t| {
212                serde_json::json!({
213                    "type": "function",
214                    "function": {
215                        "name": t.name,
216                        "description": t.description,
217                        "parameters": t.parameters,
218                    }
219                })
220            })
221            .collect()
222    }
223}
224
225#[async_trait]
226impl LlmClient for OpenAiClient {
227    async fn complete(
228        &self,
229        messages: &[Message],
230        system: Option<&str>,
231        tools: &[ToolDefinition],
232    ) -> Result<LlmResponse> {
233        {
234            let request_started_at = Instant::now();
235            let mut openai_messages = Vec::new();
236
237            if let Some(sys) = system {
238                openai_messages.push(serde_json::json!({
239                    "role": "system",
240                    "content": sys,
241                }));
242            }
243
244            openai_messages.extend(self.convert_messages(messages));
245
246            let mut request = serde_json::json!({
247                "model": self.model,
248                "messages": openai_messages,
249            });
250
251            if let Some(temp) = self.temperature {
252                request["temperature"] = serde_json::json!(temp);
253            }
254            if let Some(max) = self.max_tokens {
255                request["max_tokens"] = serde_json::json!(max);
256            }
257
258            if !tools.is_empty() {
259                request["tools"] = serde_json::json!(self.convert_tools(tools));
260            }
261
262            let url = format!("{}{}", self.base_url, self.chat_completions_path);
263            let auth_header = format!("Bearer {}", self.api_key.expose());
264            let headers = vec![("Authorization", auth_header.as_str())];
265
266            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
267                let http = &self.http;
268                let url = &url;
269                let headers = headers.clone();
270                let request = &request;
271                async move {
272                    match http.post(url, headers, request).await {
273                        Ok(resp) => {
274                            let status = reqwest::StatusCode::from_u16(resp.status)
275                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
276                            if status.is_success() {
277                                AttemptOutcome::Success(resp.body)
278                            } else if self.retry_config.is_retryable_status(status) {
279                                AttemptOutcome::Retryable {
280                                    status,
281                                    body: resp.body,
282                                    retry_after: None,
283                                }
284                            } else {
285                                AttemptOutcome::Fatal(anyhow::anyhow!(
286                                    "OpenAI API error at {} ({}): {}",
287                                    url,
288                                    status,
289                                    resp.body
290                                ))
291                            }
292                        }
293                        Err(e) => {
294                            eprintln!("[DEBUG] HTTP error: {:?}", e);
295                            AttemptOutcome::Fatal(e)
296                        }
297                    }
298                }
299            })
300            .await?;
301
302            let parsed: OpenAiResponse =
303                serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
304
305            let choice = parsed.choices.into_iter().next().context("No choices")?;
306
307            let mut content = vec![];
308
309            let reasoning_content = choice.message.reasoning_content;
310
311            let text_content = choice.message.content;
312
313            if let Some(text) = text_content {
314                if !text.is_empty() {
315                    content.push(ContentBlock::Text { text });
316                }
317            }
318
319            if let Some(tool_calls) = choice.message.tool_calls {
320                for tc in tool_calls {
321                    content.push(ContentBlock::ToolUse {
322                        id: tc.id,
323                        name: tc.function.name.clone(),
324                        input: Self::parse_tool_arguments(
325                            &tc.function.name,
326                            &tc.function.arguments,
327                        ),
328                    });
329                }
330            }
331
332            let llm_response = LlmResponse {
333                message: Message {
334                    role: "assistant".to_string(),
335                    content,
336                    reasoning_content,
337                },
338                usage: TokenUsage {
339                    prompt_tokens: parsed.usage.prompt_tokens,
340                    completion_tokens: parsed.usage.completion_tokens,
341                    total_tokens: parsed.usage.total_tokens,
342                    cache_read_tokens: parsed
343                        .usage
344                        .prompt_tokens_details
345                        .as_ref()
346                        .and_then(|d| d.cached_tokens),
347                    cache_write_tokens: None,
348                },
349                stop_reason: choice.finish_reason,
350                meta: Some(LlmResponseMeta {
351                    provider: Some(self.provider_name.clone()),
352                    request_model: Some(self.model.clone()),
353                    request_url: Some(url.clone()),
354                    response_id: parsed.id,
355                    response_model: parsed.model,
356                    response_object: parsed.object,
357                    first_token_ms: None,
358                    duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
359                }),
360            };
361
362            crate::telemetry::record_llm_usage(
363                llm_response.usage.prompt_tokens,
364                llm_response.usage.completion_tokens,
365                llm_response.usage.total_tokens,
366                llm_response.stop_reason.as_deref(),
367            );
368
369            Ok(llm_response)
370        }
371    }
372
    /// Send a streaming chat-completions request.
    ///
    /// Returns a channel of `StreamEvent`s: text deltas and tool-call
    /// deltas as they arrive, then a `Done` event carrying the fully
    /// assembled response once the server sends `[DONE]`.
    /// NOTE(review): if the stream terminates without `[DONE]` (transport
    /// error mid-stream), no `Done` event is sent — the channel just closes;
    /// confirm callers handle that case.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let request_started_at = Instant::now();
            let mut openai_messages = Vec::new();

            // The system prompt (if any) leads the message list.
            if let Some(sys) = system {
                openai_messages.push(serde_json::json!({
                    "role": "system",
                    "content": sys,
                }));
            }

            openai_messages.extend(self.convert_messages(messages));

            // `include_usage` asks the server to append a final usage chunk.
            let mut request = serde_json::json!({
                "model": self.model,
                "messages": openai_messages,
                "stream": true,
                "stream_options": { "include_usage": true },
            });

            // Optional knobs are serialized only when configured.
            if let Some(temp) = self.temperature {
                request["temperature"] = serde_json::json!(temp);
            }
            if let Some(max) = self.max_tokens {
                request["max_tokens"] = serde_json::json!(max);
            }

            if !tools.is_empty() {
                request["tools"] = serde_json::json!(self.convert_tools(tools));
            }

            let url = format!("{}{}", self.base_url, self.chat_completions_path);
            let auth_header = format!("Bearer {}", self.api_key.expose());
            let headers = vec![("Authorization", auth_header.as_str())];

            // Retry applies only to establishing the stream; once the body
            // is streaming, errors are logged and the stream ends.
            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request = &request;
                async move {
                    match http.post_streaming(url, headers, request).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor the server's Retry-After header when
                                // present (unlike the non-streaming path).
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "OpenAI API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            let provider_name = self.provider_name.clone();
            let request_model = self.model.clone();
            let request_url = url.clone();
            // Background task: parse SSE frames and forward events. State is
            // accumulated here and flushed into the final Done response.
            tokio::spawn(async move {
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut text_content = String::new();
                let mut reasoning_content_accum = String::new();
                // Keyed by the delta's `index`; value is (id, name, args) so
                // argument fragments can be appended across chunks. BTreeMap
                // keeps tool calls in index order for the final response.
                let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                    std::collections::BTreeMap::new();
                let mut usage = TokenUsage::default();
                let mut finish_reason = None;
                let mut response_id = None;
                let mut response_model = None;
                let mut response_object = None;
                let mut first_token_ms = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Mid-stream failure: stop reading; channel closes
                            // without a Done event.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    // NOTE(review): from_utf8_lossy on a raw chunk can mangle
                    // a multi-byte UTF-8 sequence split across chunk
                    // boundaries — confirm the transport delivers whole
                    // sequences.
                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line; assumes "\n\n"
                    // delimiters (not "\r\n\r\n") — TODO confirm for all
                    // supported providers.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        // Discard the "\n\n" separator itself.
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // End of stream: flush accumulated text and
                                    // tool calls into the final response.
                                    if !text_content.is_empty() {
                                        content_blocks.push(ContentBlock::Text {
                                            text: text_content.clone(),
                                        });
                                    }
                                    for (_, (id, name, args)) in tool_calls.iter() {
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            input: Self::parse_tool_arguments(name, args),
                                        });
                                    }
                                    tool_calls.clear();
                                    crate::telemetry::record_llm_usage(
                                        usage.prompt_tokens,
                                        usage.completion_tokens,
                                        usage.total_tokens,
                                        finish_reason.as_deref(),
                                    );
                                    let response = LlmResponse {
                                        message: Message {
                                            role: "assistant".to_string(),
                                            content: std::mem::take(&mut content_blocks),
                                            reasoning_content: if reasoning_content_accum.is_empty()
                                            {
                                                None
                                            } else {
                                                Some(std::mem::take(&mut reasoning_content_accum))
                                            },
                                        },
                                        usage: usage.clone(),
                                        stop_reason: std::mem::take(&mut finish_reason),
                                        meta: Some(LlmResponseMeta {
                                            provider: Some(provider_name.clone()),
                                            request_model: Some(request_model.clone()),
                                            request_url: Some(request_url.clone()),
                                            response_id: response_id.clone(),
                                            response_model: response_model.clone(),
                                            response_object: response_object.clone(),
                                            first_token_ms,
                                            duration_ms: Some(
                                                request_started_at.elapsed().as_millis() as u64,
                                            ),
                                        }),
                                    };
                                    let _ = tx.send(StreamEvent::Done(response)).await;
                                    continue;
                                }

                                // Malformed chunks are silently skipped.
                                if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
                                    // Response identity fields: first chunk wins.
                                    if response_id.is_none() {
                                        response_id = event.id.clone();
                                    }
                                    if response_model.is_none() {
                                        response_model = event.model.clone();
                                    }
                                    if response_object.is_none() {
                                        response_object = event.object.clone();
                                    }
                                    // Usage arrives in a dedicated chunk (see
                                    // stream_options.include_usage above).
                                    if let Some(u) = event.usage {
                                        usage.prompt_tokens = u.prompt_tokens;
                                        usage.completion_tokens = u.completion_tokens;
                                        usage.total_tokens = u.total_tokens;
                                        usage.cache_read_tokens = u
                                            .prompt_tokens_details
                                            .as_ref()
                                            .and_then(|d| d.cached_tokens);
                                    }

                                    // Only the first choice is consumed.
                                    if let Some(choice) = event.choices.into_iter().next() {
                                        if let Some(reason) = choice.finish_reason {
                                            finish_reason = Some(reason);
                                        }

                                        if let Some(delta) = choice.delta {
                                            if let Some(ref rc) = delta.reasoning_content {
                                                reasoning_content_accum.push_str(rc);
                                            }

                                            if let Some(content) = delta.content {
                                                // First visible output marks
                                                // time-to-first-token.
                                                if first_token_ms.is_none() {
                                                    first_token_ms = Some(
                                                        request_started_at.elapsed().as_millis()
                                                            as u64,
                                                    );
                                                }
                                                text_content.push_str(&content);
                                                let _ =
                                                    tx.send(StreamEvent::TextDelta(content)).await;
                                            }

                                            if let Some(tcs) = delta.tool_calls {
                                                for tc in tcs {
                                                    // Accumulate by index; id,
                                                    // name, and argument text
                                                    // may arrive in separate
                                                    // chunks.
                                                    let entry = tool_calls
                                                        .entry(tc.index)
                                                        .or_insert_with(|| {
                                                            (
                                                                String::new(),
                                                                String::new(),
                                                                String::new(),
                                                            )
                                                        });

                                                    if let Some(id) = tc.id {
                                                        entry.0 = id;
                                                    }
                                                    if let Some(func) = tc.function {
                                                        if let Some(name) = func.name {
                                                            if first_token_ms.is_none() {
                                                                first_token_ms = Some(
                                                                    request_started_at
                                                                        .elapsed()
                                                                        .as_millis()
                                                                        as u64,
                                                                );
                                                            }
                                                            entry.1 = name.clone();
                                                            let _ = tx
                                                                .send(StreamEvent::ToolUseStart {
                                                                    id: entry.0.clone(),
                                                                    name,
                                                                })
                                                                .await;
                                                        }
                                                        if let Some(args) = func.arguments {
                                                            entry.2.push_str(&args);
                                                            let _ = tx
                                                                .send(
                                                                    StreamEvent::ToolUseInputDelta(
                                                                        args,
                                                                    ),
                                                                )
                                                                .await;
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
643}
644
// OpenAI API response types (private)

/// Top-level non-streaming chat-completions response body.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    #[serde(default)]
    pub(crate) id: Option<String>,
    #[serde(default)]
    pub(crate) object: Option<String>,
    #[serde(default)]
    pub(crate) model: Option<String>,
    // Required fields: deserialization fails if the server omits them.
    pub(crate) choices: Vec<OpenAiChoice>,
    pub(crate) usage: OpenAiUsage,
}

/// One completion choice; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    pub(crate) message: OpenAiMessage,
    pub(crate) finish_reason: Option<String>,
}

/// Assistant message payload of a choice.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    // Extension field used by reasoning-capable models (e.g. kimi-k2.5).
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}

/// A fully-formed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    pub(crate) id: String,
    pub(crate) function: OpenAiFunction,
}

/// Function name plus raw JSON-string arguments of a tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    pub(crate) name: String,
    pub(crate) arguments: String,
}

/// Token accounting reported by the server.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    pub(crate) prompt_tokens: usize,
    pub(crate) completion_tokens: usize,
    pub(crate) total_tokens: usize,
    /// OpenAI returns cached token count in `prompt_tokens_details.cached_tokens`
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; currently only the cached count is used.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
698
// OpenAI streaming types

/// One SSE chunk of a streaming chat-completions response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    #[serde(default)]
    pub(crate) id: Option<String>,
    #[serde(default)]
    pub(crate) object: Option<String>,
    #[serde(default)]
    pub(crate) model: Option<String>,
    // May be empty (e.g. in the trailing usage-only chunk).
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    // Populated only when `stream_options.include_usage` was requested.
    pub(crate) usage: Option<OpenAiUsage>,
}

/// A choice within a stream chunk; carries an incremental delta.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    pub(crate) delta: Option<OpenAiDelta>,
    pub(crate) finish_reason: Option<String>,
}

/// Incremental message fragment; each field may arrive independently.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    pub(crate) reasoning_content: Option<String>,
    pub(crate) content: Option<String>,
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

/// Fragment of a tool call; `index` correlates fragments of the same call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    pub(crate) index: usize,
    pub(crate) id: Option<String>,
    pub(crate) function: Option<OpenAiFunctionDelta>,
}

/// Fragment of a tool call's function: name and/or partial argument text.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}