//! OpenAI-compatible `/chat/completions` provider client.
//! (codineer_api/providers/openai_compat.rs)

1#[path = "openai_compat_sse.rs"]
2mod openai_compat_sse;
3#[path = "openai_compat_stream.rs"]
4mod openai_compat_stream;
5
6use std::collections::VecDeque;
7
8use serde::{Deserialize, Deserializer};
9use serde_json::{json, Value};
10
11use crate::error::ApiError;
12use crate::providers::{parse_custom_provider_prefix, RetryPolicy};
13use crate::types::{
14    InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
15    ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
16};
17
18use openai_compat_sse::{first_non_empty_field, OpenAiSseParser};
19use openai_compat_stream::StreamState;
20
21pub use openai_compat_stream::MessageStream;
22
/// Default xAI API base URL, used when `XAI_BASE_URL` is not set.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Default OpenAI API base URL, used when `OPENAI_BASE_URL` is not set.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
// Response headers probed (in this order) for a provider-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
27
/// Static description of an OpenAI-compatible provider: its display name plus
/// the environment variables and default endpoint used to configure a client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Human-readable provider name used in error messages (e.g. "xAI").
    pub provider_name: &'static str,
    /// Environment variable holding the API key.
    pub api_key_env: &'static str,
    /// Environment variable that, when set, overrides the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when `base_url_env` is unset.
    pub default_base_url: &'static str,
}
35
// Credential variable lists reported to the user when a key is missing; kept
// as 'static slices so `credential_env_vars` can return them by reference.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
38
impl OpenAiCompatConfig {
    /// Configuration for the xAI provider.
    #[must_use]
    pub const fn xai() -> Self {
        Self {
            provider_name: "xAI",
            api_key_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: DEFAULT_XAI_BASE_URL,
        }
    }

    /// Configuration for the OpenAI provider.
    #[must_use]
    pub const fn openai() -> Self {
        Self {
            provider_name: "OpenAI",
            api_key_env: "OPENAI_API_KEY",
            base_url_env: "OPENAI_BASE_URL",
            default_base_url: DEFAULT_OPENAI_BASE_URL,
        }
    }

    /// Environment variables to mention in missing-credentials errors.
    ///
    /// Returns an empty slice for configs whose `api_key_env` is not one of
    /// the known providers (e.g. ad-hoc custom configs).
    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        match self.api_key_env {
            "XAI_API_KEY" => XAI_ENV_VARS,
            "OPENAI_API_KEY" => OPENAI_ENV_VARS,
            _ => &[],
        }
    }
}
68
/// HTTP client for OpenAI-compatible `/chat/completions` endpoints.
#[derive(Clone)]
pub struct OpenAiCompatClient {
    http: reqwest::Client,
    // Bearer token; the auth header is omitted entirely when this is empty
    // (see `send_raw_request`), supporting unauthenticated local servers.
    api_key: String,
    base_url: String,
    // Extra query string appended to the endpoint URL (e.g. Azure's
    // `api-version=...`); normalized to `None` when blank.
    endpoint_query: Option<String>,
    retry: RetryPolicy,
}
77
// Manual `Debug` implementation: redacts `api_key` so credentials can never
// leak into logs or error output.
impl std::fmt::Debug for OpenAiCompatClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatClient")
            .field("base_url", &self.base_url)
            .field("endpoint_query", &self.endpoint_query)
            .field("api_key", &"***")
            .finish()
    }
}
87
impl OpenAiCompatClient {
    /// Creates a client for a known provider using an explicit API key; the
    /// base URL comes from `config.base_url_env` or the provider default.
    #[must_use]
    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
        Self {
            http: crate::default_http_client(),
            api_key: api_key.into(),
            base_url: read_base_url(config),
            endpoint_query: None,
            retry: RetryPolicy::default(),
        }
    }

    /// Creates a client for an arbitrary OpenAI-compatible endpoint.
    #[must_use]
    pub fn new_custom(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        Self {
            http: crate::default_http_client(),
            api_key: api_key.into(),
            base_url: base_url.into(),
            endpoint_query: None,
            retry: RetryPolicy::default(),
        }
    }

    /// Sets an extra query string appended to the chat-completions URL
    /// (e.g. Azure's `api-version=...`). Blank values become `None`.
    #[must_use]
    pub fn with_endpoint_query(mut self, endpoint_query: Option<String>) -> Self {
        self.endpoint_query = endpoint_query
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty());
        self
    }

    /// Builds a client from environment variables.
    ///
    /// # Errors
    /// Returns a missing-credentials error when `config.api_key_env` is
    /// unset or empty.
    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
            return Err(ApiError::missing_credentials(
                config.provider_name,
                config.credential_env_vars(),
            ));
        };
        Ok(Self::new(api_key, config))
    }

    /// Overrides the base URL.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Overrides the retry policy.
    #[must_use]
    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
        self.retry = retry;
        self
    }

    /// Sends a non-streaming chat request and normalizes the payload into
    /// the crate's `MessageResponse` shape.
    ///
    /// # Errors
    /// Propagates transport, HTTP, and decoding failures as `ApiError`.
    pub async fn send_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageResponse, ApiError> {
        // Force `stream: false` regardless of what the caller set.
        let request = MessageRequest {
            stream: false,
            ..request.clone()
        };
        let response = self.send_with_retry(&request).await?;
        // Capture the request id before `json()` consumes the response.
        let request_id = request_id_from_headers(response.headers());
        let payload = response.json::<ChatCompletionResponse>().await?;
        let mut normalized = normalize_response(&request.model, payload)?;
        // Prefer an id found in the body; fall back to the header value.
        if normalized.request_id.is_none() {
            normalized.request_id = request_id;
        }
        Ok(normalized)
    }

    /// Sends a streaming chat request and returns an SSE-backed stream.
    ///
    /// # Errors
    /// Propagates transport and HTTP failures as `ApiError`.
    pub async fn stream_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageStream, ApiError> {
        let response = self
            .send_with_retry(&request.clone().with_streaming())
            .await?;
        Ok(MessageStream {
            request_id: request_id_from_headers(response.headers()),
            response,
            parser: OpenAiSseParser::new(),
            pending: VecDeque::new(),
            done: false,
            state: StreamState::new(request.model.clone()),
        })
    }

    /// Sends the request, retrying retryable failures with exponential
    /// backoff for up to `retry.max_retries` additional attempts.
    ///
    /// Non-retryable errors are returned immediately; once retries are
    /// exhausted, the last retryable error is wrapped in
    /// `ApiError::RetriesExhausted`.
    async fn send_with_retry(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let mut attempts = 0;

        let last_error = loop {
            attempts += 1;
            // The `attempts <= max_retries + 1` guard stays true for the
            // final allowed attempt, so its error falls through to the
            // `RetriesExhausted` wrapping below instead of escaping raw.
            let retryable_error = match self.send_raw_request(request).await {
                Ok(response) => match expect_success(response).await {
                    Ok(response) => return Ok(response),
                    Err(error)
                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
                    {
                        error
                    }
                    Err(error) => return Err(error),
                },
                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
                    error
                }
                Err(error) => return Err(error),
            };

            if attempts > self.retry.max_retries {
                break retryable_error;
            }

            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
        };

        Err(ApiError::RetriesExhausted {
            attempts,
            last_error: Box::new(last_error),
        })
    }

    /// Issues a single POST to the chat-completions endpoint.
    ///
    /// The bearer-auth header is omitted when `api_key` is empty (e.g.
    /// local servers that require no authentication).
    async fn send_raw_request(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let request_url = chat_completions_endpoint(&self.base_url, self.endpoint_query.as_deref());
        let mut req = self
            .http
            .post(&request_url)
            .header("content-type", "application/json");
        if !self.api_key.is_empty() {
            req = req.bearer_auth(&self.api_key);
        }
        req.json(&build_chat_completion_request(request))
            .send()
            .await
            .map_err(ApiError::from)
    }

    /// Computes the delay for 1-based `attempt`:
    /// `initial_backoff * 2^(attempt - 1)`, capped at `max_backoff`.
    ///
    /// # Errors
    /// Returns `ApiError::BackoffOverflow` when the shift itself would
    /// overflow (`attempt >= 33`); `Duration` multiplication overflow is
    /// instead saturated to `max_backoff` via `map_or`.
    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
            return Err(ApiError::BackoffOverflow {
                attempt,
                base_delay: self.retry.initial_backoff,
            });
        };
        Ok(self
            .retry
            .initial_backoff
            .checked_mul(multiplier)
            .map_or(self.retry.max_backoff, |delay| {
                delay.min(self.retry.max_backoff)
            }))
    }
}
247
248// ---------------------------------------------------------------------------
249// Non-streaming DTOs
250// ---------------------------------------------------------------------------
251
/// Top-level non-streaming chat-completion payload as returned by the API.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    id: String,
    model: String,
    choices: Vec<ChatChoice>,
    // Some compat servers omit usage entirely; treated as zero tokens.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}
260
/// One completion choice; only the first choice is consumed downstream.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}
267
/// Assistant message body of a completion choice.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    /// Visible text; providers send either a string or `[{type,text}]` parts.
    #[serde(default, deserialize_with = "deserialize_openai_text_content")]
    content: Option<String>,
    // Provider-specific "reasoning"/"thinking" fields; consulted in this
    // declaration order by `assistant_visible_text` when `content` is empty.
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    thought: Option<String>,
    #[serde(default)]
    thinking: Option<String>,
    #[serde(default)]
    tool_calls: Vec<ResponseToolCall>,
}
284
285impl ChatMessage {
286    fn assistant_visible_text(&self) -> Option<String> {
287        first_non_empty_field(&[
288            &self.content,
289            &self.reasoning_content,
290            &self.reasoning,
291            &self.thought,
292            &self.thinking,
293        ])
294    }
295}
296
/// A tool invocation emitted by the model in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ResponseToolCall {
    id: String,
    function: ResponseToolFunction,
}
302
/// Function name plus JSON-encoded arguments of a tool call.
#[derive(Debug, Deserialize)]
struct ResponseToolFunction {
    name: String,
    // Arrives as a JSON string; parsed leniently by `parse_tool_arguments`.
    arguments: String,
}
308
/// Token-usage accounting in OpenAI's naming; fields default to zero when
/// a provider omits them.
#[derive(Debug, Deserialize)]
pub(super) struct OpenAiUsage {
    #[serde(default)]
    pub prompt_tokens: u32,
    #[serde(default)]
    pub completion_tokens: u32,
}
316
/// Wrapper for the `{"error": {...}}` body returned on HTTP failures.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
    error: ErrorBody,
}
321
/// Inner error details; both fields are optional across compat providers.
#[derive(Debug, Deserialize)]
struct ErrorBody {
    #[serde(rename = "type")]
    error_type: Option<String>,
    message: Option<String>,
}
328
329// ---------------------------------------------------------------------------
330// Request / response mapping
331// ---------------------------------------------------------------------------
332
333fn upstream_openai_model(model: &str) -> String {
334    parse_custom_provider_prefix(model)
335        .map(|(_, rest)| rest.to_string())
336        .unwrap_or_else(|| model.to_string())
337}
338
339fn build_chat_completion_request(request: &MessageRequest) -> Value {
340    let mut messages = Vec::new();
341    if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
342        messages.push(json!({
343            "role": "system",
344            "content": system,
345        }));
346    }
347    for message in &request.messages {
348        messages.extend(translate_message(message));
349    }
350
351    let upstream_model = upstream_openai_model(&request.model);
352    const MAX_TOKENS_OPENAI_COMPAT_CAP: u32 = 32_768;
353    let max_tokens = request.max_tokens.clamp(1, MAX_TOKENS_OPENAI_COMPAT_CAP);
354    let mut payload = json!({
355        "model": upstream_model,
356        "max_tokens": max_tokens,
357        "messages": messages,
358        "stream": request.stream,
359    });
360
361    if let Some(tools) = &request.tools {
362        payload["tools"] =
363            Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
364    }
365    if let Some(tool_choice) = &request.tool_choice {
366        payload["tool_choice"] = openai_tool_choice(tool_choice);
367    }
368
369    payload
370}
371
/// Translates one crate-level `InputMessage` into zero or more
/// OpenAI-style message objects.
///
/// Assistant turns collapse all text blocks into a single message (with
/// `tool_calls` attached); any `ToolResult` block inside an assistant turn is
/// dropped. Non-assistant turns fan out per block: text becomes a `user`
/// message, tool results become `tool` messages, and `ToolUse` blocks are
/// skipped (they only make sense on assistant turns).
fn translate_message(message: &InputMessage) -> Vec<Value> {
    match message.role.as_str() {
        "assistant" => {
            let mut text = String::new();
            let mut tool_calls = Vec::new();
            for block in &message.content {
                match block {
                    InputContentBlock::Text { text: value } => text.push_str(value),
                    // Tool arguments are serialized back to a JSON string, as
                    // the OpenAI wire format requires.
                    InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
                        "id": id,
                        "type": "function",
                        "function": {
                            "name": name,
                            "arguments": serde_json::to_string(input).unwrap_or_default(),
                        }
                    })),
                    InputContentBlock::ToolResult { .. } => {}
                }
            }
            if text.is_empty() && tool_calls.is_empty() {
                // Nothing translatable: emit no message at all.
                Vec::new()
            } else {
                // `content` is null (not "") when there is no text.
                let mut msg = json!({
                    "role": "assistant",
                    "content": (!text.is_empty()).then_some(text),
                });
                // Only include tool_calls when non-empty; some providers
                // (e.g. DashScope) reject an empty array.
                if !tool_calls.is_empty() {
                    msg["tool_calls"] = json!(tool_calls);
                }
                vec![msg]
            }
        }
        _ => message
            .content
            .iter()
            .filter_map(|block| match block {
                InputContentBlock::Text { text } => Some(json!({
                    "role": "user",
                    "content": text,
                })),
                InputContentBlock::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => Some(json!({
                    "role": "tool",
                    "tool_call_id": tool_use_id,
                    "content": flatten_tool_result_content(content),
                    // NOTE(review): `is_error` is nonstandard for OpenAI tool
                    // messages; presumably tolerated by target providers.
                    "is_error": is_error,
                })),
                InputContentBlock::ToolUse { .. } => None,
            })
            .collect(),
    }
}
429
430fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
431    content
432        .iter()
433        .map(|block| match block {
434            ToolResultContentBlock::Text { text } => text.clone(),
435            ToolResultContentBlock::Json { value } => value.to_string(),
436        })
437        .collect::<Vec<_>>()
438        .join("\n")
439}
440
/// Wraps a crate-level `ToolDefinition` in OpenAI's function-tool envelope.
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
    json!({
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            // The input schema is passed through verbatim as JSON Schema.
            "parameters": tool.input_schema,
        }
    })
}
451
452fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
453    match tool_choice {
454        ToolChoice::Auto => Value::String("auto".to_string()),
455        ToolChoice::Any => Value::String("required".to_string()),
456        ToolChoice::Tool { name } => json!({
457            "type": "function",
458            "function": { "name": name },
459        }),
460    }
461}
462
463fn normalize_response(
464    model: &str,
465    response: ChatCompletionResponse,
466) -> Result<MessageResponse, ApiError> {
467    let choice = response
468        .choices
469        .into_iter()
470        .next()
471        .ok_or(ApiError::InvalidSseFrame(
472            "chat completion response missing choices",
473        ))?;
474    let mut content = Vec::new();
475    if let Some(text) = choice.message.assistant_visible_text() {
476        content.push(OutputContentBlock::Text { text });
477    }
478    for tool_call in choice.message.tool_calls {
479        content.push(OutputContentBlock::ToolUse {
480            id: tool_call.id,
481            name: tool_call.function.name,
482            input: parse_tool_arguments(&tool_call.function.arguments),
483        });
484    }
485
486    Ok(MessageResponse {
487        id: response.id,
488        kind: "message".to_string(),
489        role: choice.message.role,
490        content,
491        model: response.model.if_empty_then(model.to_string()),
492        stop_reason: choice
493            .finish_reason
494            .map(|value| normalize_finish_reason(&value)),
495        stop_sequence: None,
496        usage: Usage {
497            input_tokens: response
498                .usage
499                .as_ref()
500                .map_or(0, |usage| usage.prompt_tokens),
501            cache_creation_input_tokens: 0,
502            cache_read_input_tokens: 0,
503            output_tokens: response
504                .usage
505                .as_ref()
506                .map_or(0, |usage| usage.completion_tokens),
507        },
508        request_id: None,
509    })
510}
511
512fn parse_tool_arguments(arguments: &str) -> Value {
513    serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
514}
515
516// ---------------------------------------------------------------------------
517// Deserialization helpers
518// ---------------------------------------------------------------------------
519
/// OpenAI-compatible APIs usually send a string; some use `[{type,text}]`-style array parts.
///
/// Normalization rules:
/// - `null` or empty string -> `None`
/// - array parts: objects contribute their `"text"` (or, failing that,
///   `"content"`) string field, bare strings are appended as-is, anything
///   else is ignored; an all-empty result -> `None`.
fn deserialize_openai_text_content<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: Deserializer<'de>,
{
    // Untagged: serde tries `Str` first, then `Arr`.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Raw {
        Str(String),
        Arr(Vec<Value>),
    }
    match Option::<Raw>::deserialize(deserializer)? {
        None => Ok(None),
        Some(Raw::Str(s)) if s.is_empty() => Ok(None),
        Some(Raw::Str(s)) => Ok(Some(s)),
        Some(Raw::Arr(parts)) => {
            let mut joined = String::new();
            for part in parts {
                match part {
                    Value::Object(map) => {
                        if let Some(text) = map.get("text").and_then(Value::as_str) {
                            joined.push_str(text);
                        } else if let Some(text) = map.get("content").and_then(Value::as_str) {
                            joined.push_str(text);
                        }
                    }
                    Value::String(s) => joined.push_str(&s),
                    _ => {}
                }
            }
            Ok((!joined.is_empty()).then_some(joined))
        }
    }
}
554
555// ---------------------------------------------------------------------------
556// Env / URL / HTTP helpers
557// ---------------------------------------------------------------------------
558
559fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
560    match std::env::var(key) {
561        Ok(value) if !value.is_empty() => Ok(Some(value)),
562        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
563        Err(error) => Err(ApiError::from(error)),
564    }
565}
566
567#[must_use]
568pub fn has_api_key(key: &str) -> bool {
569    read_env_non_empty(key)
570        .ok()
571        .and_then(std::convert::identity)
572        .is_some()
573}
574
575#[must_use]
576pub fn read_base_url(config: OpenAiCompatConfig) -> String {
577    std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
578}
579
580fn chat_completions_endpoint(base_url: &str, extra_query: Option<&str>) -> String {
581    let trimmed = base_url.trim();
582    let (path_part, base_query) = match trimmed.split_once('?') {
583        Some((p, q)) => (p.trim_end_matches('/'), Some(q)),
584        None => (trimmed.trim_end_matches('/'), None),
585    };
586    let path = if path_part.ends_with("/chat/completions") {
587        path_part.to_string()
588    } else {
589        format!("{path_part}/chat/completions")
590    };
591    merge_url_query(&path, base_query, extra_query)
592}
593
/// Appends up to two query strings to `path`, skipping blank segments and
/// joining the survivors with `&`; returns `path` untouched when both are
/// empty or absent.
fn merge_url_query(path: &str, base_query: Option<&str>, extra_query: Option<&str>) -> String {
    let joined = base_query
        .into_iter()
        .chain(extra_query)
        .map(str::trim)
        .filter(|q| !q.is_empty())
        .collect::<Vec<_>>()
        .join("&");
    if joined.is_empty() {
        path.to_string()
    } else {
        format!("{path}?{joined}")
    }
}
608
609fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
610    headers
611        .get(REQUEST_ID_HEADER)
612        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
613        .and_then(|value| value.to_str().ok())
614        .map(ToOwned::to_owned)
615}
616
617async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
618    let status = response.status();
619    if status.is_success() {
620        return Ok(response);
621    }
622
623    let body = response.text().await.unwrap_or_default();
624    let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
625    let retryable = is_retryable_status(status);
626
627    Err(ApiError::Api {
628        status,
629        error_type: parsed_error
630            .as_ref()
631            .and_then(|error| error.error.error_type.clone()),
632        message: parsed_error
633            .as_ref()
634            .and_then(|error| error.error.message.clone()),
635        body,
636        retryable,
637    })
638}
639
640const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
641    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
642}
643
/// Maps OpenAI finish reasons onto the crate's stop-reason vocabulary
/// ("stop" -> "end_turn", "tool_calls" -> "tool_use"); unknown values pass
/// through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    let mapped = match value {
        "stop" => "end_turn",
        "tool_calls" => "tool_use",
        other => other,
    };
    mapped.to_string()
}
652
/// Fallback helper: substitute a default when a `String` is empty.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is used.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        match self.is_empty() {
            true => fallback,
            false => self,
        }
    }
}
666
// Inline unit tests for URL construction and response normalization; the
// broader suite lives in `openai_compat_tests.rs`.
#[cfg(test)]
mod openai_compat_inner_tests {
    use super::*;
    use crate::types::OutputContentBlock;

    // Azure-style deployment URL: the path gets `/chat/completions` appended
    // and the `api-version` query attached.
    #[test]
    fn chat_completions_url_appends_api_version() {
        assert_eq!(
            chat_completions_endpoint(
                "https://my.openai.azure.com/openai/deployments/gpt4",
                Some("api-version=2024-02-15-preview"),
            ),
            "https://my.openai.azure.com/openai/deployments/gpt4/chat/completions?api-version=2024-02-15-preview"
        );
    }

    // A base URL already carrying a query keeps it, with the extra query
    // appended after `&`.
    #[test]
    fn chat_completions_url_merges_base_query_and_api_version() {
        assert_eq!(
            chat_completions_endpoint(
                "https://x/v1/chat/completions?existing=1",
                Some("api-version=2024-02-15-preview"),
            ),
            "https://x/v1/chat/completions?existing=1&api-version=2024-02-15-preview"
        );
    }

    // `content` sent as `[{type,text}]` parts is flattened into one text block.
    #[test]
    fn non_streaming_message_parses_content_array() {
        let json = r#"{
            "id":"1",
            "model":"qwen",
            "choices":[{
                "message":{"role":"assistant","content":[{"type":"text","text":"hello"}]},
                "finish_reason":"stop"
            }],
            "usage":{"prompt_tokens":1,"completion_tokens":1}
        }"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        let msg = normalize_response("qwen", resp).expect("normalize");
        assert_eq!(
            msg.content,
            vec![OutputContentBlock::Text {
                text: "hello".to_string()
            }]
        );
    }

    // When `content` is null, `reasoning_content` is surfaced as the text.
    #[test]
    fn non_streaming_reasoning_only_message() {
        let json = r#"{
            "id":"1",
            "model":"qwen",
            "choices":[{
                "message":{"role":"assistant","content":null,"reasoning_content":"think"},
                "finish_reason":"stop"
            }],
            "usage":{"prompt_tokens":1,"completion_tokens":1}
        }"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        let msg = normalize_response("qwen", resp).expect("normalize");
        assert_eq!(
            msg.content,
            vec![OutputContentBlock::Text {
                text: "think".to_string()
            }]
        );
    }
}
736
737#[cfg(test)]
738#[path = "openai_compat_tests.rs"]
739mod tests;