Skip to main content

codineer_api/providers/openai_compat.rs

1#[path = "openai_compat_sse.rs"]
2mod openai_compat_sse;
3#[path = "openai_compat_stream.rs"]
4mod openai_compat_stream;
5
6use std::collections::VecDeque;
7
8use serde::{Deserialize, Deserializer};
9use serde_json::{json, Value};
10
11use crate::error::ApiError;
12use crate::providers::{parse_custom_provider_prefix, RetryPolicy};
13use crate::types::{
14    InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
15    ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
16};
17
18use openai_compat_sse::{first_non_empty_field, OpenAiSseParser};
19use openai_compat_stream::StreamState;
20
21pub use openai_compat_stream::MessageStream;
22
/// API root used for xAI when `XAI_BASE_URL` does not override it.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// API root used for OpenAI when `OPENAI_BASE_URL` does not override it.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Primary response header consulted for a server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback response header consulted when `REQUEST_ID_HEADER` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
27
/// Static description of an OpenAI-compatible provider preset.
///
/// All fields are `'static` strings, so the config is `Copy` and can be built in `const fn`s.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Human-readable provider name used in error messages (e.g. `"OpenAI"`).
    pub provider_name: &'static str,
    /// Environment variable holding the API key.
    pub api_key_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when `base_url_env` does not provide one.
    pub default_base_url: &'static str,
}
35
36const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
37const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
38
39impl OpenAiCompatConfig {
40    #[must_use]
41    pub const fn xai() -> Self {
42        Self {
43            provider_name: "xAI",
44            api_key_env: "XAI_API_KEY",
45            base_url_env: "XAI_BASE_URL",
46            default_base_url: DEFAULT_XAI_BASE_URL,
47        }
48    }
49
50    #[must_use]
51    pub const fn openai() -> Self {
52        Self {
53            provider_name: "OpenAI",
54            api_key_env: "OPENAI_API_KEY",
55            base_url_env: "OPENAI_BASE_URL",
56            default_base_url: DEFAULT_OPENAI_BASE_URL,
57        }
58    }
59    #[must_use]
60    pub fn credential_env_vars(self) -> &'static [&'static str] {
61        match self.api_key_env {
62            "XAI_API_KEY" => XAI_ENV_VARS,
63            "OPENAI_API_KEY" => OPENAI_ENV_VARS,
64            _ => &[],
65        }
66    }
67}
68
69#[derive(Clone)]
70pub struct OpenAiCompatClient {
71    http: reqwest::Client,
72    api_key: String,
73    base_url: String,
74    endpoint_query: Option<String>,
75    retry: RetryPolicy,
76    custom_headers: std::collections::BTreeMap<String, String>,
77}
78
79impl std::fmt::Debug for OpenAiCompatClient {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("OpenAiCompatClient")
82            .field("base_url", &self.base_url)
83            .field("endpoint_query", &self.endpoint_query)
84            .field("api_key", &"***")
85            .finish()
86    }
87}
88
89impl OpenAiCompatClient {
90    #[must_use]
91    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
92        Self {
93            http: crate::default_http_client(),
94            api_key: api_key.into(),
95            base_url: read_base_url(config),
96            endpoint_query: None,
97            retry: RetryPolicy::default(),
98            custom_headers: std::collections::BTreeMap::new(),
99        }
100    }
101
102    #[must_use]
103    pub fn new_custom(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
104        Self {
105            http: crate::default_http_client(),
106            api_key: api_key.into(),
107            base_url: base_url.into(),
108            endpoint_query: None,
109            retry: RetryPolicy::default(),
110            custom_headers: std::collections::BTreeMap::new(),
111        }
112    }
113
114    #[must_use]
115    pub fn with_endpoint_query(mut self, endpoint_query: Option<String>) -> Self {
116        self.endpoint_query = endpoint_query
117            .map(|s| s.trim().to_string())
118            .filter(|s| !s.is_empty());
119        self
120    }
121
122    #[must_use]
123    pub fn with_custom_headers(
124        mut self,
125        headers: std::collections::BTreeMap<String, String>,
126    ) -> Self {
127        self.custom_headers = headers;
128        self
129    }
130
131    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
132        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
133            return Err(ApiError::missing_credentials(
134                config.provider_name,
135                config.credential_env_vars(),
136            ));
137        };
138        Ok(Self::new(api_key, config))
139    }
140
141    #[must_use]
142    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
143        self.base_url = base_url.into();
144        self
145    }
146
147    #[must_use]
148    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
149        self.retry = retry;
150        self
151    }
152
153    pub async fn send_message(
154        &self,
155        request: &MessageRequest,
156    ) -> Result<MessageResponse, ApiError> {
157        let request = MessageRequest {
158            stream: false,
159            ..request.clone()
160        };
161        let response = self.send_with_retry(&request).await?;
162        let request_id = request_id_from_headers(response.headers());
163        let payload = response.json::<ChatCompletionResponse>().await?;
164        let mut normalized = normalize_response(&request.model, payload)?;
165        if normalized.request_id.is_none() {
166            normalized.request_id = request_id;
167        }
168        Ok(normalized)
169    }
170
171    pub async fn stream_message(
172        &self,
173        request: &MessageRequest,
174    ) -> Result<MessageStream, ApiError> {
175        let response = self
176            .send_with_retry(&request.clone().with_streaming())
177            .await?;
178        Ok(MessageStream {
179            request_id: request_id_from_headers(response.headers()),
180            response,
181            parser: OpenAiSseParser::new(),
182            pending: VecDeque::new(),
183            done: false,
184            state: StreamState::new(request.model.clone()),
185        })
186    }
187
188    async fn send_with_retry(
189        &self,
190        request: &MessageRequest,
191    ) -> Result<reqwest::Response, ApiError> {
192        let mut attempts = 0;
193
194        let last_error = loop {
195            attempts += 1;
196            let retryable_error = match self.send_raw_request(request).await {
197                Ok(response) => match expect_success(response).await {
198                    Ok(response) => return Ok(response),
199                    Err(error)
200                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
201                    {
202                        error
203                    }
204                    Err(error) => return Err(error),
205                },
206                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
207                    error
208                }
209                Err(error) => return Err(error),
210            };
211
212            if attempts > self.retry.max_retries {
213                break retryable_error;
214            }
215
216            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
217        };
218
219        Err(ApiError::RetriesExhausted {
220            attempts,
221            last_error: Box::new(last_error),
222        })
223    }
224
225    async fn send_raw_request(
226        &self,
227        request: &MessageRequest,
228    ) -> Result<reqwest::Response, ApiError> {
229        let request_url = chat_completions_endpoint(&self.base_url, self.endpoint_query.as_deref());
230        let mut req = self
231            .http
232            .post(&request_url)
233            .header("content-type", "application/json");
234        if !self.api_key.is_empty() {
235            req = req.bearer_auth(&self.api_key);
236        }
237        for (name, value) in &self.custom_headers {
238            req = req.header(name.as_str(), value.as_str());
239        }
240        req.json(&build_chat_completion_request(request))
241            .send()
242            .await
243            .map_err(ApiError::from)
244    }
245
246    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
247        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
248            return Err(ApiError::BackoffOverflow {
249                attempt,
250                base_delay: self.retry.initial_backoff,
251            });
252        };
253        Ok(self
254            .retry
255            .initial_backoff
256            .checked_mul(multiplier)
257            .map_or(self.retry.max_backoff, |delay| {
258                delay.min(self.retry.max_backoff)
259            }))
260    }
261}
262
263// ---------------------------------------------------------------------------
264// Non-streaming DTOs
265// ---------------------------------------------------------------------------
266
267#[derive(Debug, Deserialize)]
268struct ChatCompletionResponse {
269    id: String,
270    model: String,
271    choices: Vec<ChatChoice>,
272    #[serde(default)]
273    usage: Option<OpenAiUsage>,
274}
275
276#[derive(Debug, Deserialize)]
277struct ChatChoice {
278    message: ChatMessage,
279    #[serde(default)]
280    finish_reason: Option<String>,
281}
282
283#[derive(Debug, Deserialize)]
284struct ChatMessage {
285    role: String,
286    #[serde(default, deserialize_with = "deserialize_openai_text_content")]
287    content: Option<String>,
288    #[serde(default)]
289    reasoning_content: Option<String>,
290    #[serde(default)]
291    reasoning: Option<String>,
292    #[serde(default)]
293    thought: Option<String>,
294    #[serde(default)]
295    thinking: Option<String>,
296    #[serde(default)]
297    tool_calls: Vec<ResponseToolCall>,
298}
299
300impl ChatMessage {
301    fn assistant_visible_text(&self) -> Option<String> {
302        first_non_empty_field(&[
303            &self.content,
304            &self.reasoning_content,
305            &self.reasoning,
306            &self.thought,
307            &self.thinking,
308        ])
309    }
310}
311
312#[derive(Debug, Deserialize)]
313struct ResponseToolCall {
314    id: String,
315    function: ResponseToolFunction,
316}
317
318#[derive(Debug, Deserialize)]
319struct ResponseToolFunction {
320    name: String,
321    arguments: String,
322}
323
324#[derive(Debug, Deserialize)]
325pub(super) struct OpenAiUsage {
326    #[serde(default)]
327    pub prompt_tokens: u32,
328    #[serde(default)]
329    pub completion_tokens: u32,
330}
331
332#[derive(Debug, Deserialize)]
333struct ErrorEnvelope {
334    error: ErrorBody,
335}
336
337#[derive(Debug, Deserialize)]
338struct ErrorBody {
339    #[serde(rename = "type")]
340    error_type: Option<String>,
341    message: Option<String>,
342}
343
344// ---------------------------------------------------------------------------
345// Request / response mapping
346// ---------------------------------------------------------------------------
347
348fn upstream_openai_model(model: &str) -> String {
349    parse_custom_provider_prefix(model)
350        .map(|(_, rest)| rest.to_string())
351        .unwrap_or_else(|| model.to_string())
352}
353
354fn build_chat_completion_request(request: &MessageRequest) -> Value {
355    let mut messages = Vec::new();
356    if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
357        messages.push(json!({
358            "role": "system",
359            "content": system,
360        }));
361    }
362    for message in &request.messages {
363        messages.extend(translate_message(message));
364    }
365
366    let upstream_model = upstream_openai_model(&request.model);
367    const MAX_TOKENS_OPENAI_COMPAT_CAP: u32 = 32_768;
368    let max_tokens = request.max_tokens.clamp(1, MAX_TOKENS_OPENAI_COMPAT_CAP);
369    let mut payload = json!({
370        "model": upstream_model,
371        "max_tokens": max_tokens,
372        "messages": messages,
373        "stream": request.stream,
374    });
375
376    if let Some(tools) = &request.tools {
377        payload["tools"] =
378            Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
379    }
380    if let Some(tool_choice) = &request.tool_choice {
381        payload["tool_choice"] = openai_tool_choice(tool_choice);
382    }
383
384    payload
385}
386
387fn translate_message(message: &InputMessage) -> Vec<Value> {
388    match message.role.as_str() {
389        "assistant" => {
390            let mut text = String::new();
391            let mut tool_calls = Vec::new();
392            for block in &message.content {
393                match block {
394                    InputContentBlock::Text { text: value } => text.push_str(value),
395                    InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
396                        "id": id,
397                        "type": "function",
398                        "function": {
399                            "name": name,
400                            "arguments": serde_json::to_string(input).unwrap_or_default(),
401                        }
402                    })),
403                    InputContentBlock::Image { .. } | InputContentBlock::ToolResult { .. } => {}
404                }
405            }
406            if text.is_empty() && tool_calls.is_empty() {
407                Vec::new()
408            } else {
409                let mut msg = json!({
410                    "role": "assistant",
411                    "content": (!text.is_empty()).then_some(text),
412                });
413                // Only include tool_calls when non-empty; some providers
414                // (e.g. DashScope) reject an empty array.
415                if !tool_calls.is_empty() {
416                    msg["tool_calls"] = json!(tool_calls);
417                }
418                vec![msg]
419            }
420        }
421        _ => {
422            let has_image = message
423                .content
424                .iter()
425                .any(|b| matches!(b, InputContentBlock::Image { .. }));
426            let mut result = Vec::new();
427            let mut user_parts: Vec<Value> = Vec::new();
428
429            for block in &message.content {
430                match block {
431                    InputContentBlock::Text { text } => {
432                        if has_image {
433                            user_parts.push(json!({ "type": "text", "text": text }));
434                        } else {
435                            result.push(json!({ "role": "user", "content": text }));
436                        }
437                    }
438                    InputContentBlock::Image { source } => {
439                        let data_url = format!("data:{};base64,{}", source.media_type, source.data);
440                        user_parts.push(json!({
441                            "type": "image_url",
442                            "image_url": { "url": data_url }
443                        }));
444                    }
445                    InputContentBlock::ToolResult {
446                        tool_use_id,
447                        content,
448                        is_error,
449                    } => {
450                        flush_user_parts(&mut user_parts, &mut result);
451                        result.push(json!({
452                            "role": "tool",
453                            "tool_call_id": tool_use_id,
454                            "content": flatten_tool_result_content(content),
455                            "is_error": is_error,
456                        }));
457                    }
458                    InputContentBlock::ToolUse { .. } => {}
459                }
460            }
461            flush_user_parts(&mut user_parts, &mut result);
462            result
463        }
464    }
465}
466
467fn flush_user_parts(parts: &mut Vec<Value>, result: &mut Vec<Value>) {
468    if parts.is_empty() {
469        return;
470    }
471    let content = Value::Array(std::mem::take(parts));
472    result.push(json!({ "role": "user", "content": content }));
473}
474
475fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
476    content
477        .iter()
478        .filter_map(|block| match block {
479            ToolResultContentBlock::Text { text } => Some(text.clone()),
480            ToolResultContentBlock::Json { value } => Some(value.to_string()),
481            ToolResultContentBlock::Image { .. } => None,
482        })
483        .collect::<Vec<_>>()
484        .join("\n")
485}
486
487fn openai_tool_definition(tool: &ToolDefinition) -> Value {
488    json!({
489        "type": "function",
490        "function": {
491            "name": tool.name,
492            "description": tool.description,
493            "parameters": tool.input_schema,
494        }
495    })
496}
497
498fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
499    match tool_choice {
500        ToolChoice::Auto => Value::String("auto".to_string()),
501        ToolChoice::Any => Value::String("required".to_string()),
502        ToolChoice::Tool { name } => json!({
503            "type": "function",
504            "function": { "name": name },
505        }),
506    }
507}
508
509fn normalize_response(
510    model: &str,
511    response: ChatCompletionResponse,
512) -> Result<MessageResponse, ApiError> {
513    let choice = response
514        .choices
515        .into_iter()
516        .next()
517        .ok_or(ApiError::InvalidSseFrame(
518            "chat completion response missing choices",
519        ))?;
520    let mut content = Vec::new();
521    if let Some(text) = choice.message.assistant_visible_text() {
522        content.push(OutputContentBlock::Text { text });
523    }
524    for tool_call in choice.message.tool_calls {
525        content.push(OutputContentBlock::ToolUse {
526            id: tool_call.id,
527            name: tool_call.function.name,
528            input: parse_tool_arguments(&tool_call.function.arguments),
529        });
530    }
531
532    Ok(MessageResponse {
533        id: response.id,
534        kind: "message".to_string(),
535        role: choice.message.role,
536        content,
537        model: response.model.if_empty_then(model.to_string()),
538        stop_reason: choice
539            .finish_reason
540            .map(|value| normalize_finish_reason(&value)),
541        stop_sequence: None,
542        usage: Usage {
543            input_tokens: response
544                .usage
545                .as_ref()
546                .map_or(0, |usage| usage.prompt_tokens),
547            cache_creation_input_tokens: 0,
548            cache_read_input_tokens: 0,
549            output_tokens: response
550                .usage
551                .as_ref()
552                .map_or(0, |usage| usage.completion_tokens),
553        },
554        request_id: None,
555    })
556}
557
558fn parse_tool_arguments(arguments: &str) -> Value {
559    serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
560}
561
562// ---------------------------------------------------------------------------
563// Deserialization helpers
564// ---------------------------------------------------------------------------
565
566/// OpenAI-compatible APIs usually send a string; some use `[{type,text}]`-style array parts.
567fn deserialize_openai_text_content<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
568where
569    D: Deserializer<'de>,
570{
571    #[derive(Deserialize)]
572    #[serde(untagged)]
573    enum Raw {
574        Str(String),
575        Arr(Vec<Value>),
576    }
577    match Option::<Raw>::deserialize(deserializer)? {
578        None => Ok(None),
579        Some(Raw::Str(s)) if s.is_empty() => Ok(None),
580        Some(Raw::Str(s)) => Ok(Some(s)),
581        Some(Raw::Arr(parts)) => {
582            let mut joined = String::new();
583            for part in parts {
584                match part {
585                    Value::Object(map) => {
586                        if let Some(text) = map.get("text").and_then(Value::as_str) {
587                            joined.push_str(text);
588                        } else if let Some(text) = map.get("content").and_then(Value::as_str) {
589                            joined.push_str(text);
590                        }
591                    }
592                    Value::String(s) => joined.push_str(&s),
593                    _ => {}
594                }
595            }
596            Ok((!joined.is_empty()).then_some(joined))
597        }
598    }
599}
600
601// ---------------------------------------------------------------------------
602// Env / URL / HTTP helpers
603// ---------------------------------------------------------------------------
604
605fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
606    match std::env::var(key) {
607        Ok(value) if !value.is_empty() => Ok(Some(value)),
608        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
609        Err(error) => Err(ApiError::from(error)),
610    }
611}
612
613#[must_use]
614pub fn has_api_key(key: &str) -> bool {
615    read_env_non_empty(key)
616        .ok()
617        .and_then(std::convert::identity)
618        .is_some()
619}
620
621#[must_use]
622pub fn read_base_url(config: OpenAiCompatConfig) -> String {
623    std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
624}
625
/// Builds the full chat-completions URL from a base URL and an optional extra query.
///
/// Surrounding whitespace and trailing slashes are removed. `/chat/completions` is
/// appended unless the path already ends with it. Any query already on the base URL is
/// kept, and the extra query is appended after it.
fn chat_completions_endpoint(base_url: &str, extra_query: Option<&str>) -> String {
    let trimmed = base_url.trim();
    let (path_part, base_query) = match trimmed.split_once('?') {
        Some((path, query)) => (path.trim_end_matches('/'), Some(query)),
        None => (trimmed.trim_end_matches('/'), None),
    };
    let path = if path_part.ends_with("/chat/completions") {
        path_part.to_string()
    } else {
        format!("{path_part}/chat/completions")
    };
    merge_url_query(&path, base_query, extra_query)
}

/// Appends the base query and the extra query to `path`, joined with `&`.
///
/// Each part is trimmed of whitespace, a leading `?`, and leading/trailing `&`. This lets
/// inputs such as `?api-version=...` or `a=1&` yield a well-formed URL instead of
/// `...?a=1&&?b=2`. Parts that end up empty are dropped, and with no parts left `path` is
/// returned unchanged.
fn merge_url_query(path: &str, base_query: Option<&str>, extra_query: Option<&str>) -> String {
    let segments: Vec<&str> = base_query
        .into_iter()
        .chain(extra_query)
        .map(|query| query.trim().trim_start_matches('?').trim_matches('&'))
        .filter(|query| !query.is_empty())
        .collect();
    if segments.is_empty() {
        path.to_string()
    } else {
        format!("{path}?{}", segments.join("&"))
    }
}
654
655fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
656    headers
657        .get(REQUEST_ID_HEADER)
658        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
659        .and_then(|value| value.to_str().ok())
660        .map(ToOwned::to_owned)
661}
662
663async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
664    let status = response.status();
665    if status.is_success() {
666        return Ok(response);
667    }
668
669    let url = response.url().to_string();
670    let body = response.text().await.unwrap_or_default();
671    let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
672    let retryable = is_retryable_status(status);
673
674    Err(ApiError::Api {
675        status,
676        error_type: parsed_error
677            .as_ref()
678            .and_then(|error| error.error.error_type.clone()),
679        message: parsed_error
680            .as_ref()
681            .and_then(|error| error.error.message.clone()),
682        body,
683        url: Some(url),
684        retryable,
685    })
686}
687
688const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
689    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
690}
691
/// Maps an OpenAI `finish_reason` onto the Anthropic-style stop-reason vocabulary.
///
/// Mappings:
/// - `stop` → `end_turn`
/// - `tool_calls` and legacy `function_call` → `tool_use`
/// - `length` (output-token limit hit) → `max_tokens`
///
/// Unknown reasons are passed through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    match value {
        "stop" => "end_turn",
        "tool_calls" | "function_call" => "tool_use",
        "length" => "max_tokens",
        other => other,
    }
    .to_string()
}
700
/// Small helper for choosing a fallback when an owned string is empty.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is returned.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        if !self.is_empty() {
            return self;
        }
        fallback
    }
}
714
#[cfg(test)]
mod openai_compat_inner_tests {
    use super::*;
    use crate::types::OutputContentBlock;

    /// Deserializes a chat-completion body and normalizes it as a reply for `qwen`.
    fn normalize_fixture(body: &str) -> MessageResponse {
        let parsed: ChatCompletionResponse = serde_json::from_str(body).unwrap();
        normalize_response("qwen", parsed).expect("normalize")
    }

    /// The content expected for a reply consisting of one text block.
    fn single_text(text: &str) -> Vec<OutputContentBlock> {
        vec![OutputContentBlock::Text {
            text: text.to_string(),
        }]
    }

    #[test]
    fn chat_completions_url_appends_api_version() {
        let url = chat_completions_endpoint(
            "https://my.openai.azure.com/openai/deployments/gpt4",
            Some("api-version=2024-02-15-preview"),
        );
        assert_eq!(
            url,
            "https://my.openai.azure.com/openai/deployments/gpt4/chat/completions?api-version=2024-02-15-preview"
        );
    }

    #[test]
    fn chat_completions_url_merges_base_query_and_api_version() {
        let url = chat_completions_endpoint(
            "https://x/v1/chat/completions?existing=1",
            Some("api-version=2024-02-15-preview"),
        );
        assert_eq!(
            url,
            "https://x/v1/chat/completions?existing=1&api-version=2024-02-15-preview"
        );
    }

    #[test]
    fn non_streaming_message_parses_content_array() {
        let body = r#"{
            "id":"1",
            "model":"qwen",
            "choices":[{
                "message":{"role":"assistant","content":[{"type":"text","text":"hello"}]},
                "finish_reason":"stop"
            }],
            "usage":{"prompt_tokens":1,"completion_tokens":1}
        }"#;
        assert_eq!(normalize_fixture(body).content, single_text("hello"));
    }

    #[test]
    fn non_streaming_reasoning_only_message() {
        let body = r#"{
            "id":"1",
            "model":"qwen",
            "choices":[{
                "message":{"role":"assistant","content":null,"reasoning_content":"think"},
                "finish_reason":"stop"
            }],
            "usage":{"prompt_tokens":1,"completion_tokens":1}
        }"#;
        assert_eq!(normalize_fixture(body).content, single_text("think"));
    }

    #[test]
    fn translate_user_message_with_image_produces_content_array() {
        use crate::types::ImageSource;
        let message = InputMessage {
            role: "user".to_string(),
            content: vec![
                InputContentBlock::Text {
                    text: "describe this".to_string(),
                },
                InputContentBlock::Image {
                    source: ImageSource {
                        source_type: "base64".to_string(),
                        media_type: "image/png".to_string(),
                        data: "abc123".to_string(),
                    },
                },
            ],
        };

        let translated = translate_message(&message);
        assert_eq!(translated.len(), 1);
        let parts = translated[0]["content"]
            .as_array()
            .expect("content should be an array");
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[0]["text"], "describe this");
        assert_eq!(parts[1]["type"], "image_url");
        assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,abc123");
    }

    #[test]
    fn translate_text_only_user_message_stays_string() {
        let translated = translate_message(&InputMessage::user_text("hello"));
        assert_eq!(translated.len(), 1);
        assert_eq!(translated[0]["content"], "hello");
    }
}
822
823#[cfg(test)]
824#[path = "openai_compat_tests.rs"]
825mod tests;