Skip to main content

rustic_ai/providers/
openai.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_stream::try_stream;
5use async_trait::async_trait;
6use base64::{Engine as _, engine::general_purpose};
7use eventsource_stream::Eventsource;
8use futures::stream::StreamExt;
9use reqwest::{Client, Url};
10use serde::Deserialize;
11use serde_json::{Map, Value, json};
12use uuid::Uuid;
13
14use crate::json_schema::transform_openai_schema;
15use crate::messages::{
16    BinaryContent, ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart,
17    ProviderItemPart, TextPart, ToolCallPart, UserContent,
18};
19use crate::model::{
20    Model, ModelError, ModelRequestParameters, ModelSettings, ModelStream, OutputMode, StreamChunk,
21};
22use crate::providers::{Provider, ProviderError};
23use crate::usage::RequestUsage;
24
25struct OpenAIRequest {
26    body: Value,
27}
28
29fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
30    if error.is_timeout() {
31        return ModelError::Timeout;
32    }
33    if error.is_connect() {
34        return ModelError::Transport(format!("{label} connect error: {error}"));
35    }
36    ModelError::Transport(format!("{label} request failed: {error}"))
37}
38
/// Trims `body` and caps it at 2000 characters for use in error messages.
///
/// Bodies longer than the cap are cut on a character boundary and suffixed
/// with `...[truncated]`. A blank body yields an empty string.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 2000;
    let trimmed = body.trim();
    // The byte offset of the character at index LIMIT exists only when the
    // text is longer than LIMIT characters, so one scan answers both
    // "is it too long?" and "where do we cut?".
    match trimmed.char_indices().nth(LIMIT) {
        Some((cut, _)) => {
            let truncated = &trimmed[..cut];
            format!("{truncated}...[truncated]")
        }
        None => trimmed.to_string(),
    }
}
51
52fn join_path(base: &Url, path: &str) -> Result<Url, ModelError> {
53    let mut url = base.clone();
54    let base_path = url.path().trim_end_matches('/');
55    let path = path.trim_start_matches('/');
56    let new_path = if base_path.is_empty() || base_path == "/" {
57        format!("/{path}")
58    } else {
59        format!("{base_path}/{path}")
60    };
61    url.set_path(&new_path);
62    Ok(url)
63}
64
65fn normalize_tool_call_id(id: Option<String>) -> String {
66    match id {
67        Some(value) if !value.trim().is_empty() => value,
68        _ => format!("call_{}", Uuid::new_v4().simple()),
69    }
70}
71
72fn normalize_tool_call_id_str(id: &str) -> String {
73    if id.trim().is_empty() {
74        format!("call_{}", Uuid::new_v4().simple())
75    } else {
76        id.to_string()
77    }
78}
79
80fn tool_return_content(value: &Value) -> String {
81    match value {
82        Value::String(value) => value.clone(),
83        _ => serde_json::to_string(value).unwrap_or_else(|_| value.to_string()),
84    }
85}
86
87fn tool_call_arguments(value: &Value) -> String {
88    match value {
89        Value::String(value) => value.clone(),
90        _ => serde_json::to_string(value).unwrap_or_else(|_| value.to_string()),
91    }
92}
93
/// Reports whether `media_type` denotes textual content.
///
/// Anything starting with `text/` qualifies, as do a handful of
/// `application/*` types listed below. The comparison is exact and
/// case-sensitive.
fn is_text_like_media_type(media_type: &str) -> bool {
    const TEXTUAL_APPLICATION_TYPES: [&str; 5] = [
        "application/json",
        "application/xml",
        "application/xhtml+xml",
        "application/javascript",
        "application/x-www-form-urlencoded",
    ];

    if media_type.starts_with("text/") {
        return true;
    }
    TEXTUAL_APPLICATION_TYPES
        .iter()
        .any(|candidate| *candidate == media_type)
}
105
/// Maps an audio media type to the short format name used in audio payloads.
///
/// Only the `type/subtype` part is considered: parameters after the first
/// `;` (for example `codecs=opus`) and surrounding whitespace are ignored,
/// and the comparison is ASCII case-insensitive because media type names are
/// not case-sensitive. Returns `None` for media types that are not in the
/// table.
fn audio_format_from_media_type(media_type: &str) -> Option<&'static str> {
    const FORMATS: &[(&str, &str)] = &[
        ("audio/wav", "wav"),
        ("audio/x-wav", "wav"),
        ("audio/mpeg", "mp3"),
        ("audio/mp3", "mp3"),
        ("audio/ogg", "ogg"),
        ("audio/flac", "flac"),
        ("audio/aiff", "aiff"),
        ("audio/aac", "aac"),
    ];

    // Strip any `;param=value` suffix so `audio/ogg;codecs=opus` and
    // `audio/ogg; codecs=vorbis` both resolve through the `audio/ogg` entry.
    let essence = media_type.split(';').next().unwrap_or("").trim();
    FORMATS
        .iter()
        .find(|(known, _)| essence.eq_ignore_ascii_case(known))
        .map(|&(_, format)| format)
}
117
/// Splits a base64 `data:` URL into `(media_type, base64_payload)`.
///
/// The `base64` marker must be the last `;`-separated token before the
/// comma. Any parameters between the media type and the marker (for example
/// `;codecs=opus` or `;charset=utf-8`) are kept as part of the returned
/// media type string. The payload is returned as-is and is not decoded or
/// validated.
///
/// Returns `None` when `url` is not a `data:` URL, has no comma, is not
/// marked as base64, or does not name a media type.
fn parse_data_url_base64(url: &str) -> Option<(String, String)> {
    let data_url = url.strip_prefix("data:")?;
    let (meta, data) = data_url.split_once(',')?;
    // Split on the LAST `;` so media type parameters do not end up inside
    // the encoding token.
    let (media_type, encoding) = meta.rsplit_once(';')?;
    // The part before the first `;` is the actual type/subtype; it must be
    // present even when parameters follow.
    let essence = media_type.split(';').next().unwrap_or("");
    if encoding != "base64" || essence.trim().is_empty() {
        return None;
    }
    Some((media_type.to_string(), data.to_string()))
}
127
128fn normalize_stream_tool_call_id(id: Option<String>, index: Option<usize>) -> String {
129    if let Some(value) = id.filter(|value| !value.trim().is_empty()) {
130        value
131    } else if let Some(index) = index {
132        format!("call_{index}")
133    } else {
134        normalize_tool_call_id(None)
135    }
136}
137
138fn contains_audio(messages: &[ModelMessage]) -> bool {
139    for message in messages {
140        if let ModelMessage::Request(req) = message {
141            for part in &req.parts {
142                if let ModelRequestPart::UserPrompt(prompt) = part {
143                    for item in &prompt.content {
144                        match item {
145                            UserContent::Audio(_) => return true,
146                            UserContent::Binary(binary) => {
147                                if binary.media_type.starts_with("audio/") {
148                                    return true;
149                                }
150                            }
151                            _ => {}
152                        }
153                    }
154                }
155            }
156        }
157    }
158    false
159}
160
/// Reports whether `model` belongs to a family this module treats as
/// Responses-only.
///
/// The check is a case-insensitive prefix match against the families listed
/// in `RESPONSES_ONLY_PREFIXES`.
fn is_responses_only_model(model: &str) -> bool {
    const RESPONSES_ONLY_PREFIXES: [&str; 4] = ["gpt-5", "gpt-4.1", "o1", "o3"];

    let lowered = model.to_lowercase();
    RESPONSES_ONLY_PREFIXES
        .iter()
        .any(|prefix| lowered.starts_with(*prefix))
}
168
169fn prefers_responses(model: &str) -> bool {
170    let lowered = model.to_lowercase();
171    is_responses_only_model(model)
172        || lowered.starts_with("gpt-4o")
173        || lowered.starts_with("gpt-4.1")
174        || lowered.starts_with("o1")
175        || lowered.starts_with("o3")
176}
177
/// Feature switches for a chat-completions style backend.
///
/// Passed to `OpenAIChatModel::new_with_capabilities`.
#[derive(Clone, Debug)]
pub(crate) struct OpenAIChatCapabilities {
    /// NOTE(review): presumably controls whether `response_format` is sent
    /// to the backend — confirm against the request builder.
    pub(crate) supports_response_format: bool,
    /// NOTE(review): presumably controls whether `parallel_tool_calls` is
    /// sent to the backend — confirm against the request builder.
    pub(crate) supports_parallel_tool_calls: bool,
    /// When `true`, converting user content that contains a binary image
    /// fails with `ModelError::Unsupported` (exercised by the unit tests).
    pub(crate) reject_binary_images: bool,
}
184
185impl Default for OpenAIChatCapabilities {
186    fn default() -> Self {
187        Self {
188            supports_response_format: true,
189            supports_parallel_tool_calls: true,
190            reject_binary_images: false,
191        }
192    }
193}
194
195#[derive(Clone, Debug)]
196pub struct OpenAIProvider {
197    api_key: String,
198    base_url: Url,
199}
200
201impl OpenAIProvider {
202    pub fn new(
203        api_key: impl Into<String>,
204        base_url: impl AsRef<str>,
205    ) -> Result<Self, ProviderError> {
206        let url = Url::parse(base_url.as_ref())
207            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
208        Ok(Self {
209            api_key: api_key.into(),
210            base_url: url,
211        })
212    }
213
214    pub fn from_env() -> Result<Self, ProviderError> {
215        let api_key = std::env::var("OPENAI_API_KEY")
216            .map_err(|_| ProviderError::MissingApiKey("openai".to_string()))?;
217        Self::new(api_key, "https://api.openai.com/v1")
218    }
219
220    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, ProviderError> {
221        self.base_url = Url::parse(base_url.as_ref())
222            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
223        Ok(self)
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    use crate::messages::{
        AudioUrl, BinaryContent, DocumentUrl, ImageUrl, ModelMessage, ModelRequest,
        ModelRequestPart, ModelResponse, ModelResponsePart, ProviderItemPart, TextPart,
        ToolCallPart, ToolReturnPart, UserContent,
    };

    // ---- shared fixtures -------------------------------------------------

    /// Reads `tests/fixtures/<name>` relative to the crate root.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("tests");
        path.push("fixtures");
        path.push(name);
        std::fs::read(path).expect("fixture read")
    }

    /// Base URL shared by every model built in these tests.
    fn example_url() -> Url {
        Url::parse("https://example.com/").expect("valid url")
    }

    /// Chat-completions model with default capabilities.
    fn chat_model() -> OpenAIChatModel {
        OpenAIChatModel::new("gpt-4o-mini", "test-key".to_string(), example_url(), None)
    }

    /// Responses model used by the Responses API tests.
    fn responses_model() -> OpenAIResponsesModel {
        OpenAIResponsesModel::new("gpt-5-mini", "test-key".to_string(), example_url(), None)
    }

    /// Unified model for the given model name.
    fn unified_model(name: &str) -> OpenAIUnifiedModel {
        OpenAIUnifiedModel::new(name, "test-key".to_string(), example_url(), None)
    }

    /// Extracts a JSON string as `&str`; `None` when absent or not a string.
    fn text_of(value: Option<&Value>) -> Option<&str> {
        value.and_then(Value::as_str)
    }

    /// Binary user content with the given bytes and media type.
    fn binary(data: &[u8], media_type: &str) -> UserContent {
        UserContent::Binary(BinaryContent {
            data: data.to_vec(),
            media_type: media_type.to_string(),
        })
    }

    /// Response message holding a single `get_data` tool call.
    fn tool_call_response(id: &str, arguments: Value) -> ModelMessage {
        ModelMessage::Response(ModelResponse {
            parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                id: id.to_string(),
                name: "get_data".to_string(),
                arguments,
            })],
            usage: None,
            model_name: None,
            finish_reason: None,
        })
    }

    /// A tool call (`call-1`) followed by the matching tool return.
    fn tool_call_round_trip() -> Vec<ModelMessage> {
        vec![
            tool_call_response("call-1", json!({"a": 1})),
            ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                    tool_name: "get_data".to_string(),
                    tool_call_id: "call-1".to_string(),
                    content: json!({"ok": true}),
                })],
                instructions: None,
            }),
        ]
    }

    // ---- chat model: user content -----------------------------------------

    #[test]
    fn convert_user_content_handles_binary_media() {
        let model = chat_model();

        let image_bytes = fixture_bytes("fixture.jpg");
        let audio_bytes = fixture_bytes("fixture.m4a");
        let pdf_bytes = fixture_bytes("fixture.pdf");

        let content = vec![
            binary(&image_bytes, "image/jpeg"),
            binary(&audio_bytes, "audio/aac"),
            binary(&pdf_bytes, "application/pdf"),
        ];

        let value = model
            .convert_user_content(&content)
            .expect("convert user content");
        let parts = value.as_array().expect("parts array");
        assert_eq!(parts.len(), 3);
        let (image, audio, pdf) = (&parts[0], &parts[1], &parts[2]);

        // Images are inlined as a base64 data URL.
        assert_eq!(text_of(image.get("type")), Some("image_url"));
        let image_url = text_of(image.get("image_url").and_then(|inner| inner.get("url")))
            .expect("image url");
        let expected_image = format!("data:image/jpeg;base64,{}", STANDARD.encode(&image_bytes));
        assert_eq!(image_url, expected_image);

        // Audio is sent as base64 data plus a short format name.
        assert_eq!(text_of(audio.get("type")), Some("input_audio"));
        let audio_input = audio.get("input_audio").expect("input_audio");
        assert_eq!(text_of(audio_input.get("format")), Some("aac"));
        let audio_data = text_of(audio_input.get("data")).expect("audio data");
        assert_eq!(audio_data, STANDARD.encode(&audio_bytes));

        // Other binary content becomes a textual placeholder.
        assert_eq!(text_of(pdf.get("type")), Some("text"));
        let pdf_text = text_of(pdf.get("text")).expect("pdf text");
        let expected_text = format!("[binary content: {} bytes]", pdf_bytes.len());
        assert_eq!(pdf_text, expected_text);
    }

    #[test]
    fn convert_user_content_handles_text_and_urls() {
        let model = chat_model();

        let content = vec![
            UserContent::Text("hello".to_string()),
            UserContent::Image(ImageUrl {
                url: "https://example.com/image.png".to_string(),
                media_type: None,
            }),
            UserContent::Audio(AudioUrl {
                url: "data:audio/mpeg;base64,SGVsbG8=".to_string(),
                media_type: None,
            }),
            UserContent::Document(DocumentUrl {
                url: "data:text/plain;base64,SGVsbG8=".to_string(),
                media_type: None,
            }),
            UserContent::Document(DocumentUrl {
                url: "https://example.com/doc.pdf".to_string(),
                media_type: None,
            }),
        ];

        let parts = model
            .convert_user_content(&content)
            .expect("convert user content");
        let parts = parts.as_array().expect("parts array");
        assert_eq!(parts.len(), 5);

        assert_eq!(text_of(parts[0].get("type")), Some("text"));
        assert_eq!(text_of(parts[1].get("type")), Some("image_url"));
        assert_eq!(text_of(parts[2].get("type")), Some("input_audio"));
        assert_eq!(text_of(parts[3].get("text")), Some("Hello"));
        assert_eq!(
            text_of(parts[4].get("text")),
            Some("[document: https://example.com/doc.pdf]")
        );
    }

    #[test]
    fn convert_user_content_rejects_binary_images_when_disabled() {
        let capabilities = OpenAIChatCapabilities {
            supports_response_format: true,
            supports_parallel_tool_calls: true,
            reject_binary_images: true,
        };
        let model = OpenAIChatModel::new_with_capabilities(
            "gpt-4o-mini",
            "test-key".to_string(),
            example_url(),
            None,
            capabilities,
        );

        let content = vec![binary(&[1, 2, 3], "image/png")];

        let err = model
            .convert_user_content(&content)
            .expect_err("should error");
        match err {
            ModelError::Unsupported(message) => {
                assert!(message.contains("binary image inputs"));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    // ---- chat model: message replay ---------------------------------------

    #[test]
    fn make_messages_replays_tool_calls() {
        let model = chat_model();
        let messages = tool_call_round_trip();

        let out = model.make_messages(&messages).expect("make messages");
        assert_eq!(out.len(), 2);

        let assistant = out[0].as_object().expect("assistant message");
        assert_eq!(text_of(assistant.get("role")), Some("assistant"));
        assert_eq!(assistant.get("content"), Some(&Value::Null));
        let tool_calls = assistant
            .get("tool_calls")
            .and_then(Value::as_array)
            .expect("tool_calls");
        assert_eq!(tool_calls.len(), 1);

        let call = &tool_calls[0];
        assert_eq!(text_of(call.get("id")), Some("call-1"));
        let function = call.get("function").expect("function");
        assert_eq!(text_of(function.get("name")), Some("get_data"));
        assert_eq!(text_of(function.get("arguments")), Some("{\"a\":1}"));

        let tool = out[1].as_object().expect("tool message");
        assert_eq!(text_of(tool.get("role")), Some("tool"));
        assert_eq!(text_of(tool.get("tool_call_id")), Some("call-1"));
        assert_eq!(text_of(tool.get("content")), Some("{\"ok\":true}"));
    }

    #[test]
    fn make_messages_groups_consecutive_tool_calls() {
        let model = chat_model();

        let messages = vec![
            tool_call_response("call-1", json!({"a": 1})),
            tool_call_response("call-2", json!({"b": 2})),
        ];

        let out = model.make_messages(&messages).expect("make messages");
        assert_eq!(out.len(), 1);
        let assistant = out[0].as_object().expect("assistant message");
        let tool_calls = assistant
            .get("tool_calls")
            .and_then(Value::as_array)
            .expect("tool_calls");
        assert_eq!(tool_calls.len(), 2);
    }

    // ---- responses model ---------------------------------------------------

    #[test]
    fn responses_replays_tool_calls() {
        let model = responses_model();
        let messages = tool_call_round_trip();

        let out = model
            .make_input_messages(&messages)
            .expect("make input messages");
        assert_eq!(out.len(), 2);

        let call = out[0].as_object().expect("function call item");
        assert_eq!(text_of(call.get("type")), Some("function_call"));
        assert_eq!(text_of(call.get("call_id")), Some("call-1"));
        assert_eq!(text_of(call.get("name")), Some("get_data"));
        assert_eq!(text_of(call.get("arguments")), Some("{\"a\":1}"));

        let output = out[1].as_object().expect("function call output");
        assert_eq!(text_of(output.get("type")), Some("function_call_output"));
        assert_eq!(text_of(output.get("call_id")), Some("call-1"));
        assert_eq!(text_of(output.get("output")), Some("{\"ok\":true}"));
    }

    #[test]
    fn responses_replays_provider_items() {
        let model = responses_model();

        let raw_item = json!({
            "type": "reasoning",
            "summary": "ok"
        });

        let response = ModelResponse {
            parts: vec![
                ModelResponsePart::ProviderItem(ProviderItemPart {
                    provider: "openai_responses".to_string(),
                    payload: raw_item.clone(),
                }),
                ModelResponsePart::Text(TextPart {
                    content: "ignored".to_string(),
                }),
            ],
            usage: None,
            model_name: None,
            finish_reason: None,
        };
        let messages = vec![ModelMessage::Response(response)];

        let out = model
            .make_input_messages(&messages)
            .expect("make input messages");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0], raw_item);
    }

    #[test]
    fn responses_build_request_maps_max_tokens() {
        let model = responses_model();

        let messages = vec![ModelMessage::Request(ModelRequest::user_text_prompt("hi"))];
        let params = ModelRequestParameters::default();
        let mut settings = Map::new();
        settings.insert("max_tokens".to_string(), Value::Number(42.into()));

        let request = model
            .build_request(&messages, Some(&settings), &params, false)
            .expect("build request");
        let body = request.body.as_object().expect("body object");
        assert!(body.contains_key("max_output_tokens"));
        assert!(!body.contains_key("max_tokens"));
    }

    #[test]
    fn responses_stream_helpers_parse_tool_calls_and_usage() {
        let item = json!({
            "type": "function_call",
            "name": "echo",
            "call_id": "call-1",
            "arguments": "{\"msg\":\"hi\"}"
        });
        let call = parse_responses_stream_tool_call(&item).expect("tool call");
        assert_eq!(call.name, "echo");
        assert_eq!(call.id, "call-1");
        assert_eq!(call.arguments, json!({"msg": "hi"}));

        let response = json!({
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5
            }
        });
        let usage = parse_responses_stream_usage(&response).expect("usage");
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
    }

    #[test]
    fn responses_helpers_cover_media_filename_and_content() {
        let expectations = [
            ("application/pdf", "file.pdf"),
            ("text/plain", "file.txt"),
            ("image/png", "file.bin"),
        ];
        for (media_type, expected) in expectations {
            assert_eq!(
                OpenAIResponsesModel::filename_for_media_type(media_type),
                expected
            );
        }

        let model = responses_model();

        let content = vec![
            binary(b"hello", "text/plain"),
            UserContent::Document(DocumentUrl {
                url: "data:application/pdf;base64,SGVsbG8=".to_string(),
                media_type: None,
            }),
        ];

        let parts = model
            .convert_user_content(&content)
            .expect("convert content");
        let parts = parts.as_array().expect("parts array");
        assert_eq!(parts.len(), 2);
        assert_eq!(text_of(parts[0].get("type")), Some("input_text"));
        assert_eq!(text_of(parts[1].get("type")), Some("input_file"));
    }

    // ---- unified model -----------------------------------------------------

    #[test]
    fn unified_model_streaming_prefers_chat_when_available() {
        let model = unified_model("gpt-4o-mini");

        let mode = model.select_api(&[], true).expect("select api for stream");
        assert!(matches!(mode, OpenAIApiMode::Chat));
    }

    #[test]
    fn unified_model_streaming_supports_responses_only() {
        let model = unified_model("gpt-5-mini");

        let mode = model.select_api(&[], true).expect("select api for stream");
        assert!(matches!(mode, OpenAIApiMode::Responses));
    }

    // ---- free helper functions ---------------------------------------------

    #[test]
    fn helper_functions_cover_ids_and_media_types() {
        assert!(is_text_like_media_type("text/plain"));
        assert!(is_text_like_media_type("application/json"));
        assert!(!is_text_like_media_type("image/png"));

        assert_eq!(audio_format_from_media_type("audio/mpeg"), Some("mp3"));
        assert_eq!(audio_format_from_media_type("audio/aac"), Some("aac"));
        assert_eq!(audio_format_from_media_type("audio/unknown"), None);

        let (media_type, payload) =
            parse_data_url_base64("data:audio/mpeg;base64,SGVsbG8=").expect("parse");
        assert_eq!(media_type, "audio/mpeg");
        assert_eq!(payload, "SGVsbG8=");
        assert!(parse_data_url_base64("https://example.com").is_none());

        // Blank ids are replaced by generated ones.
        assert!(normalize_tool_call_id(Some("".to_string())).starts_with("call_"));
        assert!(normalize_tool_call_id_str("").starts_with("call_"));

        // Streaming ids fall back to the index, but an explicit id wins.
        assert_eq!(normalize_stream_tool_call_id(None, Some(2)), "call_2");
        assert_eq!(
            normalize_stream_tool_call_id(Some("explicit".to_string()), Some(1)),
            "explicit"
        );
    }

    #[test]
    fn helper_functions_cover_text_and_urls() {
        let long_body = "a".repeat(2100);
        let truncated = truncate_error_body(&format!("{long_body}\n"));
        assert!(truncated.ends_with("...[truncated]"));

        let base = Url::parse("https://example.com/v1/").expect("url");
        let joined = join_path(&base, "chat/completions").expect("join");
        assert_eq!(joined.as_str(), "https://example.com/v1/chat/completions");

        assert_eq!(tool_return_content(&json!("ok")), "ok");
        assert_eq!(tool_call_arguments(&json!({"a": 1})), "{\"a\":1}");

        assert!(is_responses_only_model("gpt-5-mini"));
        assert!(!is_responses_only_model("gpt-4o-mini"));
        assert!(prefers_responses("gpt-4o-mini"));
        assert!(!prefers_responses("gpt-3.5-turbo"));
    }

    #[test]
    fn contains_audio_detects_audio_inputs() {
        let prompt = crate::messages::UserPromptPart {
            content: vec![UserContent::Audio(AudioUrl {
                url: "data:audio/mpeg;base64,SGVsbG8=".to_string(),
                media_type: None,
            })],
        };
        let messages = vec![ModelMessage::Request(ModelRequest {
            parts: vec![ModelRequestPart::UserPrompt(prompt)],
            instructions: None,
        })];
        assert!(contains_audio(&messages));
    }
}
805
806impl Provider for OpenAIProvider {
807    fn name(&self) -> &str {
808        "openai"
809    }
810
811    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
812        Arc::new(OpenAIUnifiedModel::new(
813            model,
814            self.api_key.clone(),
815            self.base_url.clone(),
816            settings,
817        ))
818    }
819}
820
821#[derive(Clone, Debug)]
822pub struct OpenAIChatModel {
823    model: String,
824    api_key: String,
825    base_url: Url,
826    client: Client,
827    default_settings: Option<ModelSettings>,
828    capabilities: OpenAIChatCapabilities,
829}
830
831impl OpenAIChatModel {
832    pub fn new(
833        model: impl Into<String>,
834        api_key: String,
835        base_url: Url,
836        settings: Option<ModelSettings>,
837    ) -> Self {
838        Self::new_with_capabilities(
839            model,
840            api_key,
841            base_url,
842            settings,
843            OpenAIChatCapabilities::default(),
844        )
845    }
846
847    pub(crate) fn new_with_capabilities(
848        model: impl Into<String>,
849        api_key: String,
850        base_url: Url,
851        settings: Option<ModelSettings>,
852        capabilities: OpenAIChatCapabilities,
853    ) -> Self {
854        Self {
855            model: model.into(),
856            api_key,
857            base_url,
858            client: Client::new(),
859            default_settings: settings,
860            capabilities,
861        }
862    }
863
864    fn endpoint(&self) -> Result<Url, ModelError> {
865        join_path(&self.base_url, "chat/completions")
866    }
867
868    fn make_messages(&self, messages: &[ModelMessage]) -> Result<Vec<Value>, ModelError> {
869        let mut out = Vec::new();
870        for message in messages {
871            match message {
872                ModelMessage::Request(req) => {
873                    if let Some(instructions) = req
874                        .instructions
875                        .as_ref()
876                        .filter(|value| !value.trim().is_empty())
877                    {
878                        out.push(json!({"role": "system", "content": instructions}));
879                    }
880                    for part in &req.parts {
881                        match part {
882                            ModelRequestPart::SystemPrompt(prompt) => {
883                                out.push(json!({"role": "system", "content": prompt.content}))
884                            }
885                            ModelRequestPart::UserPrompt(prompt) => {
886                                let content = self.convert_user_content(&prompt.content)?;
887                                out.push(json!({"role": "user", "content": content}))
888                            }
889                            ModelRequestPart::ToolReturn(tool_return) => {
890                                let content = tool_return_content(&tool_return.content);
891                                out.push(json!({
892                                    "role": "tool",
893                                    "tool_call_id": normalize_tool_call_id_str(&tool_return.tool_call_id),
894                                    "content": content,
895                                }))
896                            }
897                            ModelRequestPart::RetryPrompt(retry) => {
898                                if retry.tool_name.is_some() {
899                                    out.push(json!({
900                                        "role": "tool",
901                                        "tool_call_id": normalize_tool_call_id(retry.tool_call_id.clone()),
902                                        "content": retry.content,
903                                    }));
904                                } else {
905                                    out.push(json!({
906                                        "role": "user",
907                                        "content": retry.content,
908                                    }));
909                                }
910                            }
911                        }
912                    }
913                }
914                ModelMessage::Response(res) => {
915                    let text = res.text();
916                    let tool_calls = res.tool_calls();
917
918                    if text.is_none() && tool_calls.is_empty() {
919                        continue;
920                    }
921
922                    let calls = tool_calls
923                        .into_iter()
924                        .map(|call| {
925                            let args = tool_call_arguments(&call.arguments);
926                            json!({
927                                "id": normalize_tool_call_id_str(&call.id),
928                                "type": "function",
929                                "function": {
930                                    "name": call.name,
931                                    "arguments": args,
932                                }
933                            })
934                        })
935                        .collect::<Vec<_>>();
936
937                    if text.is_none()
938                        && !calls.is_empty()
939                        && let Some(Value::Object(last)) = out.last_mut()
940                    {
941                        let is_assistant =
942                            last.get("role").and_then(|value| value.as_str()) == Some("assistant");
943                        let is_tool_calls = last.get("content").is_some_and(Value::is_null)
944                            && last.get("tool_calls").is_some();
945                        if is_assistant
946                            && is_tool_calls
947                            && let Some(existing) =
948                                last.get_mut("tool_calls").and_then(Value::as_array_mut)
949                        {
950                            existing.extend(calls);
951                            continue;
952                        }
953                    }
954
955                    let mut msg = Map::new();
956                    msg.insert("role".to_string(), Value::String("assistant".to_string()));
957
958                    if let Some(text) = text {
959                        msg.insert("content".to_string(), Value::String(text));
960                    } else if !calls.is_empty() {
961                        msg.insert("content".to_string(), Value::Null);
962                    }
963
964                    if !calls.is_empty() {
965                        msg.insert("tool_calls".to_string(), Value::Array(calls));
966                    }
967
968                    out.push(Value::Object(msg));
969                }
970            }
971        }
972        Ok(out)
973    }
974
975    fn convert_user_content(&self, content: &[UserContent]) -> Result<Value, ModelError> {
976        let mut parts = Vec::new();
977        for item in content {
978            match item {
979                UserContent::Text(text) => parts.push(json!({"type": "text", "text": text})),
980                UserContent::Image(image) => parts.push(json!({
981                    "type": "image_url",
982                    "image_url": {"url": image.url}
983                })),
984                UserContent::Binary(BinaryContent { data, media_type }) => {
985                    if media_type.starts_with("image/") {
986                        if self.capabilities.reject_binary_images {
987                            return Err(ModelError::Unsupported(
988                                "binary image inputs are not supported; provide an image URL"
989                                    .to_string(),
990                            ));
991                        }
992                        let encoded = general_purpose::STANDARD.encode(data);
993                        let data_url = format!("data:{};base64,{}", media_type, encoded);
994                        parts.push(json!({
995                            "type": "image_url",
996                            "image_url": {"url": data_url}
997                        }))
998                    } else if media_type.starts_with("audio/") {
999                        if let Some(format) = audio_format_from_media_type(media_type) {
1000                            let encoded = general_purpose::STANDARD.encode(data);
1001                            parts.push(json!({
1002                                "type": "input_audio",
1003                                "input_audio": {
1004                                    "data": encoded,
1005                                    "format": format
1006                                }
1007                            }))
1008                        } else {
1009                            parts.push(json!({
1010                                "type": "text",
1011                                "text": format!("[audio content: {} bytes]", data.len())
1012                            }))
1013                        }
1014                    } else if is_text_like_media_type(media_type) {
1015                        match std::str::from_utf8(data) {
1016                            Ok(text) => parts.push(json!({"type": "text", "text": text})),
1017                            Err(_) => parts.push(json!({
1018                                "type": "text",
1019                                "text": format!("[binary content: {} bytes]", data.len())
1020                            })),
1021                        }
1022                    } else {
1023                        parts.push(json!({
1024                            "type": "text",
1025                            "text": format!("[binary content: {} bytes]", data.len())
1026                        }))
1027                    }
1028                }
1029                UserContent::Audio(audio) => {
1030                    if let Some((media_type, data)) = parse_data_url_base64(&audio.url)
1031                        && let Some(format) = audio_format_from_media_type(&media_type)
1032                    {
1033                        parts.push(json!({
1034                            "type": "input_audio",
1035                            "input_audio": {
1036                                "data": data,
1037                                "format": format
1038                            }
1039                        }))
1040                    } else {
1041                        parts.push(json!({
1042                            "type": "text",
1043                            "text": format!("[audio: {}]", audio.url)
1044                        }))
1045                    }
1046                }
1047                UserContent::Video(video) => parts.push(json!({
1048                    "type": "text",
1049                    "text": format!("[video: {}]", video.url)
1050                })),
1051                UserContent::Document(doc) => {
1052                    if let Some((media_type, data)) = parse_data_url_base64(&doc.url)
1053                        && is_text_like_media_type(&media_type)
1054                    {
1055                        match general_purpose::STANDARD.decode(data.as_bytes()) {
1056                            Ok(bytes) => match String::from_utf8(bytes) {
1057                                Ok(text) => parts.push(json!({"type": "text", "text": text})),
1058                                Err(_) => parts.push(json!({
1059                                    "type": "text",
1060                                    "text": format!("[document: {}]", doc.url)
1061                                })),
1062                            },
1063                            Err(_) => parts.push(json!({
1064                                "type": "text",
1065                                "text": format!("[document: {}]", doc.url)
1066                            })),
1067                        }
1068                    } else {
1069                        parts.push(json!({
1070                            "type": "text",
1071                            "text": format!("[document: {}]", doc.url)
1072                        }))
1073                    }
1074                }
1075            }
1076        }
1077
1078        Ok(Value::Array(parts))
1079    }
1080
1081    fn build_body(
1082        &self,
1083        messages: &[ModelMessage],
1084        params: &ModelRequestParameters,
1085        stream: bool,
1086    ) -> Result<Value, ModelError> {
1087        let mut body = Map::new();
1088        body.insert("model".to_string(), Value::String(self.model.clone()));
1089        body.insert(
1090            "messages".to_string(),
1091            Value::Array(self.make_messages(messages)?),
1092        );
1093
1094        if !params.function_tools.is_empty() {
1095            let tools = params
1096                .function_tools
1097                .iter()
1098                .map(|tool| {
1099                    let (schema, _strict_ok) =
1100                        transform_openai_schema(&tool.parameters_json_schema, None);
1101                    json!({
1102                        "type": "function",
1103                        "function": {
1104                            "name": tool.name,
1105                            "description": tool.description,
1106                            "parameters": schema,
1107                        }
1108                    })
1109                })
1110                .collect();
1111            body.insert("tools".to_string(), Value::Array(tools));
1112            body.insert("tool_choice".to_string(), Value::String("auto".to_string()));
1113            if self.capabilities.supports_parallel_tool_calls
1114                && params.function_tools.iter().any(|tool| tool.sequential)
1115            {
1116                body.insert("parallel_tool_calls".to_string(), Value::Bool(false));
1117            }
1118        }
1119
1120        if params.output_mode == OutputMode::JsonSchema
1121            && let Some(schema) = params.output_schema.clone()
1122            && self.capabilities.supports_response_format
1123        {
1124            let strict = !params.allow_text_output;
1125            let (schema, _strict_ok) = transform_openai_schema(&schema, Some(strict));
1126            body.insert(
1127                "response_format".to_string(),
1128                json!({
1129                    "type": "json_schema",
1130                    "json_schema": {
1131                        "name": "output",
1132                        "schema": schema,
1133                        "strict": strict,
1134                    }
1135                }),
1136            );
1137        }
1138
1139        if stream {
1140            body.insert("stream".to_string(), Value::Bool(true));
1141            body.insert("stream_options".to_string(), json!({"include_usage": true}));
1142        }
1143
1144        if let Some(settings) = &self.default_settings {
1145            for (key, value) in settings {
1146                body.entry(key.clone()).or_insert(value.clone());
1147            }
1148        }
1149
1150        Ok(Value::Object(body))
1151    }
1152
1153    fn build_request(
1154        &self,
1155        messages: &[ModelMessage],
1156        settings: Option<&ModelSettings>,
1157        params: &ModelRequestParameters,
1158        stream: bool,
1159    ) -> Result<OpenAIRequest, ModelError> {
1160        let mut body = self.build_body(messages, params, stream)?;
1161        if let Some(settings) = settings
1162            && let Value::Object(map) = &mut body
1163        {
1164            for (key, value) in settings {
1165                map.insert(key.clone(), value.clone());
1166            }
1167        }
1168        Ok(OpenAIRequest { body })
1169    }
1170
1171    fn parse_tool_call(tool_call: &OpenAIToolCall) -> ToolCallPart {
1172        let args = tool_call
1173            .function
1174            .arguments
1175            .as_ref()
1176            .and_then(|arg| serde_json::from_str::<Value>(arg).ok())
1177            .unwrap_or_else(|| {
1178                tool_call
1179                    .function
1180                    .arguments
1181                    .clone()
1182                    .map(Value::String)
1183                    .unwrap_or_else(|| Value::Object(Map::new()))
1184            });
1185
1186        ToolCallPart {
1187            id: normalize_tool_call_id(tool_call.id.clone()),
1188            name: tool_call
1189                .function
1190                .name
1191                .clone()
1192                .unwrap_or_else(|| "tool".to_string()),
1193            arguments: args,
1194        }
1195    }
1196}
1197
1198#[async_trait]
1199impl Model for OpenAIChatModel {
1200    fn name(&self) -> &str {
1201        &self.model
1202    }
1203
1204    async fn request(
1205        &self,
1206        messages: &[ModelMessage],
1207        settings: Option<&ModelSettings>,
1208        params: &ModelRequestParameters,
1209    ) -> Result<ModelResponse, ModelError> {
1210        tracing::debug!(
1211            model = %self.model,
1212            tool_count = params.function_tools.len(),
1213            output_schema = params.output_schema.is_some(),
1214            "OpenAI chat request"
1215        );
1216        let request = self.build_request(messages, settings, params, false)?;
1217
1218        let response = self
1219            .client
1220            .post(self.endpoint()?)
1221            .bearer_auth(&self.api_key)
1222            .json(&request.body)
1223            .send()
1224            .await
1225            .map_err(|e| map_reqwest_error("OpenAI", e))?;
1226
1227        let status = response.status();
1228        if !status.is_success() {
1229            let body = response.text().await.unwrap_or_default();
1230            tracing::error!(
1231                status = status.as_u16(),
1232                model = %self.model,
1233                body = %truncate_error_body(&body),
1234                "OpenAI chat request failed"
1235            );
1236            return Err(ModelError::HttpStatus {
1237                status: status.as_u16(),
1238            });
1239        }
1240
1241        let body: OpenAIChatResponse = response.json().await.map_err(|e| {
1242            tracing::error!(error = %e, model = %self.model, "OpenAI response parse failed");
1243            ModelError::Provider(format!("OpenAI response parse failed: {e}"))
1244        })?;
1245
1246        let choice = body.choices.into_iter().next().ok_or_else(|| {
1247            tracing::error!(model = %self.model, "OpenAI response missing choices");
1248            ModelError::Provider("OpenAI response missing choices".to_string())
1249        })?;
1250
1251        let mut parts = Vec::new();
1252        if let Some(content) = choice.message.content {
1253            parts.push(ModelResponsePart::Text(TextPart { content }));
1254        }
1255
1256        if let Some(tool_calls) = choice.message.tool_calls {
1257            for call in tool_calls {
1258                parts.push(ModelResponsePart::ToolCall(Self::parse_tool_call(&call)));
1259            }
1260        } else if let Some(function_call) = choice.message.function_call {
1261            parts.push(ModelResponsePart::ToolCall(ToolCallPart {
1262                id: normalize_tool_call_id(None),
1263                name: function_call.name.unwrap_or_else(|| "tool".to_string()),
1264                arguments: function_call
1265                    .arguments
1266                    .as_ref()
1267                    .and_then(|arg| serde_json::from_str::<Value>(arg).ok())
1268                    .unwrap_or_else(|| {
1269                        function_call
1270                            .arguments
1271                            .clone()
1272                            .map(Value::String)
1273                            .unwrap_or_else(|| Value::Object(Map::new()))
1274                    }),
1275            }));
1276        }
1277
1278        let usage = body.usage.map(|usage| RequestUsage {
1279            input_tokens: usage.prompt_tokens.unwrap_or(0),
1280            output_tokens: usage.completion_tokens.unwrap_or(0),
1281            ..Default::default()
1282        });
1283
1284        Ok(ModelResponse {
1285            parts,
1286            usage,
1287            model_name: Some(self.model.clone()),
1288            finish_reason: choice.finish_reason,
1289        })
1290    }
1291
1292    async fn request_stream(
1293        &self,
1294        messages: &[ModelMessage],
1295        settings: Option<&ModelSettings>,
1296        params: &ModelRequestParameters,
1297    ) -> Result<ModelStream, ModelError> {
1298        tracing::debug!(
1299            model = %self.model,
1300            tool_count = params.function_tools.len(),
1301            output_schema = params.output_schema.is_some(),
1302            "OpenAI stream request"
1303        );
1304        let request = self.build_request(messages, settings, params, true)?;
1305
1306        let response = self
1307            .client
1308            .post(self.endpoint()?)
1309            .bearer_auth(&self.api_key)
1310            .json(&request.body)
1311            .send()
1312            .await
1313            .map_err(|e| map_reqwest_error("OpenAI stream", e))?;
1314
1315        let status = response.status();
1316        if !status.is_success() {
1317            let body = response.text().await.unwrap_or_default();
1318            tracing::error!(
1319                status = status.as_u16(),
1320                model = %self.model,
1321                body = %truncate_error_body(&body),
1322                "OpenAI stream request failed"
1323            );
1324            return Err(ModelError::HttpStatus {
1325                status: status.as_u16(),
1326            });
1327        }
1328
1329        let mut event_stream = response.bytes_stream().eventsource();
1330        let model_name = self.model.clone();
1331
1332        let s = try_stream! {
1333            let mut tool_accumulator: HashMap<String, ToolAccumulator> = HashMap::new();
1334            while let Some(event) = event_stream.next().await {
1335                let event = event.map_err(|e| {
1336                    tracing::error!(error = %e, model = %model_name, "OpenAI stream error");
1337                    ModelError::Provider(format!("OpenAI stream error: {e}"))
1338                })?;
1339                let data = event.data;
1340                if data.trim() == "[DONE]" {
1341                    if !tool_accumulator.is_empty() {
1342                        for (_id, acc) in tool_accumulator.drain() {
1343                            let args = serde_json::from_str::<Value>(&acc.arguments)
1344                                .unwrap_or_else(|_| Value::String(acc.arguments.clone()));
1345                            yield StreamChunk {
1346                                text_delta: None,
1347                                tool_call: Some(ToolCallPart {
1348                                    id: acc.id.clone(),
1349                                    name: acc.name.unwrap_or_else(|| "tool".to_string()),
1350                                    arguments: args,
1351                                }),
1352                                finish_reason: None,
1353                                usage: None,
1354                            };
1355                        }
1356                    }
1357                    break;
1358                }
1359
1360                let chunk: OpenAIChatStreamResponse = serde_json::from_str(&data)
1361                    .map_err(|e| {
1362                        tracing::error!(error = %e, model = %model_name, "OpenAI stream parse error");
1363                        ModelError::Provider(format!("OpenAI stream parse error: {e}"))
1364                    })?;
1365                if let Some(choice) = chunk.choices.into_iter().next() {
1366                    if let Some(content) = choice.delta.content {
1367                        yield StreamChunk {
1368                            text_delta: Some(content),
1369                            tool_call: None,
1370                            finish_reason: None,
1371                            usage: None,
1372                        };
1373                    }
1374
1375                    if let Some(tool_calls) = choice.delta.tool_calls {
1376                        for call in tool_calls {
1377                            let id = normalize_stream_tool_call_id(call.id.clone(), call.index);
1378                            let entry = tool_accumulator.entry(id.clone()).or_insert_with(|| ToolAccumulator {
1379                                id,
1380                                name: None,
1381                                arguments: String::new(),
1382                            });
1383                            if let Some(name) = call.function.name {
1384                                entry.name = Some(name);
1385                            }
1386                            if let Some(args) = call.function.arguments {
1387                                entry.arguments.push_str(&args);
1388                            }
1389                        }
1390                    }
1391
1392                    if let Some(reason) = choice.finish_reason.clone() {
1393                        if !tool_accumulator.is_empty() {
1394                            for (_id, acc) in tool_accumulator.drain() {
1395                                let args = serde_json::from_str::<Value>(&acc.arguments)
1396                                    .unwrap_or_else(|_| Value::String(acc.arguments.clone()));
1397                                yield StreamChunk {
1398                                    text_delta: None,
1399                                    tool_call: Some(ToolCallPart {
1400                                        id: acc.id.clone(),
1401                                        name: acc.name.unwrap_or_else(|| "tool".to_string()),
1402                                        arguments: args,
1403                                    }),
1404                                    finish_reason: Some(reason.clone()),
1405                                    usage: None,
1406                                };
1407                            }
1408                        }
1409                        yield StreamChunk {
1410                            text_delta: None,
1411                            tool_call: None,
1412                            finish_reason: Some(reason),
1413                            usage: chunk.usage.map(|usage| RequestUsage {
1414                                input_tokens: usage.prompt_tokens.unwrap_or(0),
1415                                output_tokens: usage.completion_tokens.unwrap_or(0),
1416                                ..Default::default()
1417                            }),
1418                        };
1419                    }
1420                }
1421            }
1422        };
1423
1424        Ok(Box::pin(s))
1425    }
1426}
1427
/// Top-level body of a non-streaming chat-completions response.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    /// Returned choices; only the first one is consumed by `request`.
    choices: Vec<OpenAIChoice>,
    /// Token accounting, when the response includes it.
    usage: Option<OpenAIUsage>,
}
1433
/// One entry of the `choices` array in a non-streaming response.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    /// The assistant message produced for this choice.
    message: OpenAIMessage,
    /// Finish reason string, passed through to `ModelResponse::finish_reason`.
    finish_reason: Option<String>,
}
1439
/// Assistant message inside a non-streaming choice.
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
    /// Text content; becomes a text part when present.
    content: Option<String>,
    /// Tool calls; when present, `function_call` is ignored.
    tool_calls: Option<Vec<OpenAIToolCall>>,
    /// Single function call, used only when `tool_calls` is absent.
    function_call: Option<OpenAIFunctionCall>,
}
1446
/// A complete tool call as it appears in a non-streaming message.
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
    /// Call id; passed through `normalize_tool_call_id` when converted.
    id: Option<String>,
    /// The function being invoked and its arguments.
    function: OpenAIToolFunction,
}
1452
/// Function payload of a non-streaming tool call.
#[derive(Debug, Deserialize)]
struct OpenAIToolFunction {
    /// Function name; `parse_tool_call` substitutes `"tool"` when missing.
    name: Option<String>,
    /// Arguments as a string, which `parse_tool_call` attempts to parse as JSON.
    arguments: Option<String>,
}
1458
/// Payload of the message-level `function_call` field.
#[derive(Debug, Deserialize)]
struct OpenAIFunctionCall {
    /// Function name; `request` substitutes `"tool"` when missing.
    name: Option<String>,
    /// Arguments as a string, which `request` attempts to parse as JSON.
    arguments: Option<String>,
}
1464
/// Token usage block shared by streaming and non-streaming responses.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    /// Mapped to `RequestUsage::input_tokens` (0 when absent).
    prompt_tokens: Option<u64>,
    /// Mapped to `RequestUsage::output_tokens` (0 when absent).
    completion_tokens: Option<u64>,
}
1470
1471#[derive(Debug, Deserialize)]
1472struct OpenAIChatStreamResponse {
1473    choices: Vec<OpenAIChatStreamChoice>,
1474    usage: Option<OpenAIUsage>,
1475}
1476
/// One entry of the `choices` array in a streamed chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChatStreamChoice {
    /// Incremental content for this chunk.
    delta: OpenAIChatStreamDelta,
    /// Set on the chunk that ends the choice; triggers flushing of buffered tool calls.
    finish_reason: Option<String>,
}
1482
/// Incremental payload of a streamed choice.
#[derive(Debug, Deserialize)]
struct OpenAIChatStreamDelta {
    /// Text fragment, forwarded as `StreamChunk::text_delta`.
    content: Option<String>,
    /// Tool-call fragments, accumulated until the call is complete.
    tool_calls: Option<Vec<OpenAIStreamToolCall>>,
}
1488
/// A tool-call fragment inside a streamed delta.
#[derive(Debug, Deserialize)]
struct OpenAIStreamToolCall {
    /// Call id, when the fragment carries one.
    id: Option<String>,
    /// Index of the call; combined with `id` by `normalize_stream_tool_call_id`
    /// to form the accumulator key.
    index: Option<usize>,
    /// Name and/or argument fragment for this call.
    function: OpenAIStreamToolFunction,
}
1495
/// Function payload of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
struct OpenAIStreamToolFunction {
    /// Function name; the latest non-empty value replaces the accumulated name.
    name: Option<String>,
    /// Argument fragment; appended to the accumulated argument string.
    arguments: Option<String>,
}
1501
/// Buffer for one tool call while its fragments are still being streamed.
#[derive(Debug)]
struct ToolAccumulator {
    /// Normalized call id (also the key under which the buffer is stored).
    id: String,
    /// Function name, once a fragment has supplied it.
    name: Option<String>,
    /// Concatenation of all argument fragments received so far.
    arguments: String,
}
1508
/// One decoded streaming event, deserialized from its `type`, `response`,
/// `item` and `delta` fields.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesStreamEvent {
    /// Event type, read from the JSON field `type`.
    #[serde(rename = "type")]
    kind: String,
    /// Raw `response` object, when the event carries one.
    response: Option<Value>,
    /// Raw `item` object, when the event carries one.
    item: Option<Value>,
    /// String `delta`, when the event carries one.
    delta: Option<String>,
}
1517
1518fn parse_responses_stream_usage(value: &Value) -> Option<RequestUsage> {
1519    let usage = value.get("usage")?;
1520    let input_tokens = usage
1521        .get("input_tokens")
1522        .and_then(|v| v.as_u64())
1523        .unwrap_or(0);
1524    let output_tokens = usage
1525        .get("output_tokens")
1526        .and_then(|v| v.as_u64())
1527        .unwrap_or(0);
1528    Some(RequestUsage {
1529        input_tokens,
1530        output_tokens,
1531        ..Default::default()
1532    })
1533}
1534
1535fn parse_responses_stream_tool_call(item: &Value) -> Option<ToolCallPart> {
1536    let item_type = item.get("type").and_then(|v| v.as_str())?;
1537    if item_type != "function_call" {
1538        return None;
1539    }
1540    let name = item
1541        .get("name")
1542        .and_then(|value| value.as_str())
1543        .unwrap_or("tool")
1544        .to_string();
1545    let call_id = item
1546        .get("call_id")
1547        .and_then(|value| value.as_str())
1548        .map(str::to_string)
1549        .or_else(|| {
1550            item.get("id")
1551                .and_then(|value| value.as_str())
1552                .map(str::to_string)
1553        });
1554    let arguments = item.get("arguments").cloned().unwrap_or(Value::Null);
1555    let args = match arguments {
1556        Value::String(value) => {
1557            serde_json::from_str::<Value>(&value).unwrap_or(Value::String(value))
1558        }
1559        other => other,
1560    };
1561    Some(ToolCallPart {
1562        id: normalize_tool_call_id(call_id),
1563        name,
1564        arguments: args,
1565    })
1566}
1567
/// Model that owns both a chat-completions client and a Responses client and
/// picks one per request (see `select_api`).
#[derive(Clone, Debug)]
pub struct OpenAIUnifiedModel {
    /// Model identifier returned by `name`.
    model: String,
    /// Client used when the chat-completions API is selected.
    chat: OpenAIChatModel,
    /// Client used when the Responses API is selected.
    responses: OpenAIResponsesModel,
    /// Result of `is_responses_only_model` for this model name.
    responses_only: bool,
    /// Result of `prefers_responses` for this model name.
    prefer_responses: bool,
}
1576
1577impl OpenAIUnifiedModel {
1578    pub fn new(
1579        model: impl Into<String>,
1580        api_key: String,
1581        base_url: Url,
1582        settings: Option<ModelSettings>,
1583    ) -> Self {
1584        let model = model.into();
1585        let responses_only = is_responses_only_model(&model);
1586        let prefer_responses = prefers_responses(&model);
1587        Self {
1588            chat: OpenAIChatModel::new(
1589                model.clone(),
1590                api_key.clone(),
1591                base_url.clone(),
1592                settings.clone(),
1593            ),
1594            responses: OpenAIResponsesModel::new(model.clone(), api_key, base_url, settings),
1595            model,
1596            responses_only,
1597            prefer_responses,
1598        }
1599    }
1600
1601    fn select_api(
1602        &self,
1603        messages: &[ModelMessage],
1604        stream: bool,
1605    ) -> Result<OpenAIApiMode, ModelError> {
1606        if contains_audio(messages) {
1607            if self.responses_only {
1608                return Err(ModelError::Unsupported(
1609                    "OpenAI Responses API does not support audio input".to_string(),
1610                ));
1611            }
1612            return Ok(OpenAIApiMode::Chat);
1613        }
1614        if stream {
1615            if self.responses_only {
1616                return Ok(OpenAIApiMode::Responses);
1617            }
1618            return Ok(OpenAIApiMode::Chat);
1619        }
1620        if self.prefer_responses || self.responses_only {
1621            Ok(OpenAIApiMode::Responses)
1622        } else {
1623            Ok(OpenAIApiMode::Chat)
1624        }
1625    }
1626}
1627
/// Which underlying client `OpenAIUnifiedModel` dispatches a request to.
#[derive(Clone, Copy, Debug)]
enum OpenAIApiMode {
    /// Dispatch to the `chat` client.
    Chat,
    /// Dispatch to the `responses` client.
    Responses,
}
1633
1634#[async_trait]
1635impl Model for OpenAIUnifiedModel {
1636    fn name(&self) -> &str {
1637        &self.model
1638    }
1639
1640    async fn request(
1641        &self,
1642        messages: &[ModelMessage],
1643        settings: Option<&ModelSettings>,
1644        params: &ModelRequestParameters,
1645    ) -> Result<ModelResponse, ModelError> {
1646        match self.select_api(messages, false)? {
1647            OpenAIApiMode::Chat => self.chat.request(messages, settings, params).await,
1648            OpenAIApiMode::Responses => self.responses.request(messages, settings, params).await,
1649        }
1650    }
1651
1652    async fn request_stream(
1653        &self,
1654        messages: &[ModelMessage],
1655        settings: Option<&ModelSettings>,
1656        params: &ModelRequestParameters,
1657    ) -> Result<ModelStream, ModelError> {
1658        match self.select_api(messages, true)? {
1659            OpenAIApiMode::Chat => self.chat.request_stream(messages, settings, params).await,
1660            OpenAIApiMode::Responses => {
1661                self.responses
1662                    .request_stream(messages, settings, params)
1663                    .await
1664            }
1665        }
1666    }
1667}
1668
/// Client for the `responses` endpoint under the configured base URL.
#[derive(Clone, Debug)]
pub struct OpenAIResponsesModel {
    /// Model identifier.
    model: String,
    /// API key supplied at construction.
    api_key: String,
    /// Base URL onto which `responses` is joined.
    base_url: Url,
    /// HTTP client created at construction.
    client: Client,
    /// Default settings supplied at construction.
    default_settings: Option<ModelSettings>,
}
1677
1678impl OpenAIResponsesModel {
1679    pub fn new(
1680        model: impl Into<String>,
1681        api_key: String,
1682        base_url: Url,
1683        settings: Option<ModelSettings>,
1684    ) -> Self {
1685        Self {
1686            model: model.into(),
1687            api_key,
1688            base_url,
1689            client: Client::new(),
1690            default_settings: settings,
1691        }
1692    }
1693
1694    fn endpoint(&self) -> Result<Url, ModelError> {
1695        join_path(&self.base_url, "responses")
1696    }
1697
1698    fn filename_for_media_type(media_type: &str) -> String {
1699        let ext = match media_type {
1700            "application/pdf" => "pdf",
1701            "text/plain" => "txt",
1702            "text/markdown" => "md",
1703            "application/json" => "json",
1704            _ => "bin",
1705        };
1706        format!("file.{ext}")
1707    }
1708
1709    fn make_input_messages(&self, messages: &[ModelMessage]) -> Result<Vec<Value>, ModelError> {
1710        let mut out = Vec::new();
1711        for message in messages {
1712            match message {
1713                ModelMessage::Request(req) => {
1714                    if let Some(instructions) = req
1715                        .instructions
1716                        .as_ref()
1717                        .filter(|value| !value.trim().is_empty())
1718                    {
1719                        out.push(json!({"role": "system", "content": instructions}));
1720                    }
1721                    for part in &req.parts {
1722                        match part {
1723                            ModelRequestPart::SystemPrompt(prompt) => {
1724                                out.push(json!({"role": "system", "content": prompt.content}))
1725                            }
1726                            ModelRequestPart::UserPrompt(prompt) => {
1727                                let content = self.convert_user_content(&prompt.content)?;
1728                                out.push(json!({"role": "user", "content": content}))
1729                            }
1730                            ModelRequestPart::ToolReturn(tool_return) => {
1731                                let content = tool_return_content(&tool_return.content);
1732                                out.push(json!({
1733                                    "type": "function_call_output",
1734                                    "call_id": normalize_tool_call_id_str(&tool_return.tool_call_id),
1735                                    "output": content,
1736                                }))
1737                            }
1738                            ModelRequestPart::RetryPrompt(retry) => {
1739                                if retry.tool_name.is_some() {
1740                                    out.push(json!({
1741                                        "type": "function_call_output",
1742                                        "call_id": normalize_tool_call_id(retry.tool_call_id.clone()),
1743                                        "output": retry.content,
1744                                    }));
1745                                } else {
1746                                    out.push(json!({
1747                                        "role": "user",
1748                                        "content": [ { "type": "input_text", "text": retry.content } ],
1749                                    }));
1750                                }
1751                            }
1752                        }
1753                    }
1754                }
1755                ModelMessage::Response(res) => {
1756                    let provider_items: Vec<Value> = res
1757                        .parts
1758                        .iter()
1759                        .filter_map(|part| match part {
1760                            ModelResponsePart::ProviderItem(item)
1761                                if item.provider == "openai_responses" =>
1762                            {
1763                                Some(item.payload.clone())
1764                            }
1765                            _ => None,
1766                        })
1767                        .collect();
1768                    if !provider_items.is_empty() {
1769                        out.extend(provider_items);
1770                        continue;
1771                    }
1772                    if let Some(text) = res.text() {
1773                        out.push(json!({"role": "assistant", "content": text}));
1774                    }
1775                    for call in res.tool_calls() {
1776                        let args = tool_call_arguments(&call.arguments);
1777                        out.push(json!({
1778                            "type": "function_call",
1779                            "call_id": normalize_tool_call_id_str(&call.id),
1780                            "name": call.name,
1781                            "arguments": args,
1782                        }));
1783                    }
1784                }
1785            }
1786        }
1787        Ok(out)
1788    }
1789
1790    fn convert_user_content(&self, content: &[UserContent]) -> Result<Value, ModelError> {
1791        let mut parts = Vec::new();
1792        for item in content {
1793            match item {
1794                UserContent::Text(text) => parts.push(json!({"type": "input_text", "text": text})),
1795                UserContent::Image(image) => parts.push(json!({
1796                    "type": "input_image",
1797                    "image_url": image.url
1798                })),
1799                UserContent::Binary(BinaryContent { data, media_type }) => {
1800                    if media_type.starts_with("image/") {
1801                        let encoded = general_purpose::STANDARD.encode(data);
1802                        let data_url = format!("data:{};base64,{}", media_type, encoded);
1803                        parts.push(json!({
1804                            "type": "input_image",
1805                            "image_url": data_url
1806                        }));
1807                    } else if media_type == "application/pdf" {
1808                        let encoded = general_purpose::STANDARD.encode(data);
1809                        let data_url = format!("data:{};base64,{}", media_type, encoded);
1810                        parts.push(json!({
1811                            "type": "input_file",
1812                            "file_data": data_url,
1813                            "filename": Self::filename_for_media_type(media_type),
1814                        }));
1815                    } else if is_text_like_media_type(media_type) {
1816                        match std::str::from_utf8(data) {
1817                            Ok(text) => parts.push(json!({"type": "input_text", "text": text})),
1818                            Err(_) => parts.push(json!({
1819                                "type": "input_text",
1820                                "text": format!("[binary content: {} bytes]", data.len())
1821                            })),
1822                        }
1823                    } else {
1824                        parts.push(json!({
1825                            "type": "input_text",
1826                            "text": format!("[binary content: {} bytes]", data.len())
1827                        }))
1828                    }
1829                }
1830                UserContent::Document(doc) => {
1831                    if let Some((media_type, data)) = parse_data_url_base64(&doc.url) {
1832                        let data_url = format!("data:{};base64,{}", media_type, data);
1833                        parts.push(json!({
1834                            "type": "input_file",
1835                            "file_data": data_url,
1836                            "filename": Self::filename_for_media_type(&media_type),
1837                        }));
1838                    } else {
1839                        parts.push(json!({
1840                            "type": "input_file",
1841                            "file_url": doc.url
1842                        }));
1843                    }
1844                }
1845                UserContent::Audio(audio) => parts.push(json!({
1846                    "type": "input_text",
1847                    "text": format!("[audio: {}]", audio.url)
1848                })),
1849                UserContent::Video(video) => parts.push(json!({
1850                    "type": "input_text",
1851                    "text": format!("[video: {}]", video.url)
1852                })),
1853            }
1854        }
1855        Ok(Value::Array(parts))
1856    }
1857
1858    fn build_body(
1859        &self,
1860        messages: &[ModelMessage],
1861        params: &ModelRequestParameters,
1862        stream: bool,
1863    ) -> Result<Value, ModelError> {
1864        let mut body = Map::new();
1865        body.insert("model".to_string(), Value::String(self.model.clone()));
1866        body.insert(
1867            "input".to_string(),
1868            Value::Array(self.make_input_messages(messages)?),
1869        );
1870
1871        if !params.function_tools.is_empty() {
1872            let tools = params
1873                .function_tools
1874                .iter()
1875                .map(|tool| {
1876                    let (schema, _strict_ok) =
1877                        transform_openai_schema(&tool.parameters_json_schema, None);
1878                    json!({
1879                        "type": "function",
1880                        "name": tool.name,
1881                        "description": tool.description,
1882                        "parameters": schema,
1883                    })
1884                })
1885                .collect();
1886            body.insert("tools".to_string(), Value::Array(tools));
1887            if params.function_tools.iter().any(|tool| tool.sequential) {
1888                body.insert("parallel_tool_calls".to_string(), Value::Bool(false));
1889            }
1890        }
1891
1892        if params.output_mode == OutputMode::JsonSchema
1893            && let Some(schema) = params.output_schema.clone()
1894        {
1895            let strict = !params.allow_text_output;
1896            let (schema, _strict_ok) = transform_openai_schema(&schema, Some(strict));
1897            body.insert(
1898                "text".to_string(),
1899                json!({
1900                    "format": {
1901                        "type": "json_schema",
1902                        "name": "output",
1903                        "schema": schema,
1904                        "strict": strict,
1905                    }
1906                }),
1907            );
1908        }
1909
1910        if stream {
1911            body.insert("stream".to_string(), Value::Bool(true));
1912        }
1913
1914        if let Some(settings) = &self.default_settings {
1915            for (key, value) in settings {
1916                if key == "max_tokens" {
1917                    body.insert("max_output_tokens".to_string(), value.clone());
1918                    continue;
1919                }
1920                body.insert(key.clone(), value.clone());
1921            }
1922        }
1923
1924        Ok(Value::Object(body))
1925    }
1926
1927    fn build_request(
1928        &self,
1929        messages: &[ModelMessage],
1930        settings: Option<&ModelSettings>,
1931        params: &ModelRequestParameters,
1932        stream: bool,
1933    ) -> Result<OpenAIRequest, ModelError> {
1934        let mut body = self.build_body(messages, params, stream)?;
1935        if let Some(settings) = settings
1936            && let Value::Object(map) = &mut body
1937        {
1938            for (key, value) in settings {
1939                if key == "max_tokens" {
1940                    map.insert("max_output_tokens".to_string(), value.clone());
1941                    continue;
1942                }
1943                map.insert(key.clone(), value.clone());
1944            }
1945        }
1946
1947        Ok(OpenAIRequest { body })
1948    }
1949}
1950
1951#[async_trait]
1952impl Model for OpenAIResponsesModel {
1953    fn name(&self) -> &str {
1954        &self.model
1955    }
1956
1957    async fn request(
1958        &self,
1959        messages: &[ModelMessage],
1960        settings: Option<&ModelSettings>,
1961        params: &ModelRequestParameters,
1962    ) -> Result<ModelResponse, ModelError> {
1963        tracing::debug!(
1964            model = %self.model,
1965            tool_count = params.function_tools.len(),
1966            output_schema = params.output_schema.is_some(),
1967            "OpenAI responses request"
1968        );
1969        let request = self.build_request(messages, settings, params, false)?;
1970
1971        let response = self
1972            .client
1973            .post(self.endpoint()?)
1974            .bearer_auth(&self.api_key)
1975            .json(&request.body)
1976            .send()
1977            .await
1978            .map_err(|e| map_reqwest_error("OpenAI Responses", e))?;
1979
1980        let status = response.status();
1981        if !status.is_success() {
1982            let body = response.text().await.unwrap_or_default();
1983            tracing::error!(
1984                status = status.as_u16(),
1985                model = %self.model,
1986                body = %truncate_error_body(&body),
1987                "OpenAI responses request failed"
1988            );
1989            return Err(ModelError::HttpStatus {
1990                status: status.as_u16(),
1991            });
1992        }
1993
1994        let body: OpenAIResponsesResponse = response.json().await.map_err(|e| {
1995            tracing::error!(
1996                error = %e,
1997                model = %self.model,
1998                "OpenAI responses parse failed"
1999            );
2000            ModelError::Provider(format!("OpenAI response parse failed: {e}"))
2001        })?;
2002
2003        let mut parts = Vec::new();
2004        for item in body.output {
2005            parts.push(ModelResponsePart::ProviderItem(ProviderItemPart {
2006                provider: "openai_responses".to_string(),
2007                payload: item.clone(),
2008            }));
2009
2010            if let Some(item_type) = item.get("type").and_then(|value| value.as_str()) {
2011                match item_type {
2012                    "message" => {
2013                        if let Some(content) =
2014                            item.get("content").and_then(|value| value.as_array())
2015                        {
2016                            for part in content {
2017                                if part.get("type").and_then(|value| value.as_str())
2018                                    == Some("output_text")
2019                                    && let Some(text) =
2020                                        part.get("text").and_then(|value| value.as_str())
2021                                {
2022                                    parts.push(ModelResponsePart::Text(TextPart {
2023                                        content: text.to_string(),
2024                                    }));
2025                                }
2026                            }
2027                        }
2028                    }
2029                    "function_call" => {
2030                        let name = item
2031                            .get("name")
2032                            .and_then(|value| value.as_str())
2033                            .unwrap_or("tool")
2034                            .to_string();
2035                        let call_id = item
2036                            .get("call_id")
2037                            .and_then(|value| value.as_str())
2038                            .map(str::to_string);
2039                        let arguments = item.get("arguments").cloned().unwrap_or(Value::Null);
2040                        let args = match arguments {
2041                            Value::String(value) => serde_json::from_str::<Value>(&value)
2042                                .unwrap_or(Value::String(value)),
2043                            other => other,
2044                        };
2045                        parts.push(ModelResponsePart::ToolCall(ToolCallPart {
2046                            id: normalize_tool_call_id(call_id),
2047                            name,
2048                            arguments: args,
2049                        }));
2050                    }
2051                    _ => {}
2052                }
2053            }
2054        }
2055
2056        let usage = body.usage.map(|usage| RequestUsage {
2057            input_tokens: usage.input_tokens.unwrap_or(0),
2058            output_tokens: usage.output_tokens.unwrap_or(0),
2059            ..Default::default()
2060        });
2061
2062        Ok(ModelResponse {
2063            parts,
2064            usage,
2065            model_name: body.model.or_else(|| Some(self.model.clone())),
2066            finish_reason: body.finish_reason,
2067        })
2068    }
2069
2070    async fn request_stream(
2071        &self,
2072        messages: &[ModelMessage],
2073        settings: Option<&ModelSettings>,
2074        params: &ModelRequestParameters,
2075    ) -> Result<ModelStream, ModelError> {
2076        tracing::debug!(
2077            model = %self.model,
2078            tool_count = params.function_tools.len(),
2079            output_schema = params.output_schema.is_some(),
2080            "OpenAI responses stream request"
2081        );
2082        let request = self.build_request(messages, settings, params, true)?;
2083
2084        let response = self
2085            .client
2086            .post(self.endpoint()?)
2087            .bearer_auth(&self.api_key)
2088            .json(&request.body)
2089            .send()
2090            .await
2091            .map_err(|e| map_reqwest_error("OpenAI Responses stream", e))?;
2092
2093        let status = response.status();
2094        if !status.is_success() {
2095            let body = response.text().await.unwrap_or_default();
2096            tracing::error!(
2097                status = status.as_u16(),
2098                model = %self.model,
2099                body = %truncate_error_body(&body),
2100                "OpenAI responses stream request failed"
2101            );
2102            return Err(ModelError::HttpStatus {
2103                status: status.as_u16(),
2104            });
2105        }
2106
2107        let mut event_stream = response.bytes_stream().eventsource();
2108        let model_name = self.model.clone();
2109
2110        let s = try_stream! {
2111            while let Some(event) = event_stream.next().await {
2112                let event = event.map_err(|e| {
2113                    tracing::error!(error = %e, model = %model_name, "OpenAI responses stream error");
2114                    ModelError::Provider(format!("OpenAI responses stream error: {e}"))
2115                })?;
2116                let data = event.data;
2117                if data.trim() == "[DONE]" {
2118                    break;
2119                }
2120                let event: OpenAIResponsesStreamEvent = serde_json::from_str(&data).map_err(|e| {
2121                    tracing::error!(error = %e, model = %model_name, "OpenAI responses stream parse error");
2122                    ModelError::Provider(format!("OpenAI responses stream parse error: {e}"))
2123                })?;
2124
2125                match event.kind.as_str() {
2126                    "response.output_text.delta" => {
2127                        if let Some(delta) = event.delta {
2128                            yield StreamChunk {
2129                                text_delta: Some(delta),
2130                                tool_call: None,
2131                                finish_reason: None,
2132                                usage: None,
2133                            };
2134                        }
2135                    }
2136                    "response.output_item.done" => {
2137                        if let Some(item) = event.item
2138                            && let Some(call) = parse_responses_stream_tool_call(&item)
2139                        {
2140                            yield StreamChunk {
2141                                text_delta: None,
2142                                tool_call: Some(call),
2143                                finish_reason: None,
2144                                usage: None,
2145                            };
2146                        }
2147                    }
2148                    "response.completed" | "response.done" => {
2149                        let usage = event
2150                            .response
2151                            .as_ref()
2152                            .and_then(parse_responses_stream_usage);
2153                        yield StreamChunk {
2154                            text_delta: None,
2155                            tool_call: None,
2156                            finish_reason: Some("stop".to_string()),
2157                            usage,
2158                        };
2159                    }
2160                    "response.failed" => {
2161                        let detail = event
2162                            .response
2163                            .map(|value| value.to_string())
2164                            .unwrap_or_else(|| "response.failed".to_string());
2165                        Err(ModelError::Provider(format!(
2166                            "OpenAI responses stream failed: {detail}"
2167                        )))?;
2168                    }
2169                    _ => {}
2170                }
2171            }
2172        };
2173
2174        Ok(Box::pin(s))
2175    }
2176}
2177
2178#[derive(Debug, Deserialize)]
2179struct OpenAIResponsesResponse {
2180    output: Vec<Value>,
2181    usage: Option<OpenAIResponsesUsage>,
2182    model: Option<String>,
2183    #[serde(rename = "finish_reason")]
2184    finish_reason: Option<String>,
2185}
2186
2187#[derive(Debug, Deserialize)]
2188struct OpenAIResponsesUsage {
2189    input_tokens: Option<u64>,
2190    output_tokens: Option<u64>,
2191}