1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_stream::try_stream;
5use async_trait::async_trait;
6use base64::{Engine as _, engine::general_purpose};
7use eventsource_stream::Eventsource;
8use futures::stream::StreamExt;
9use reqwest::{Client, Url};
10use serde::Deserialize;
11use serde_json::{Map, Value, json};
12use uuid::Uuid;
13
14use crate::json_schema::transform_openai_schema;
15use crate::messages::{
16 BinaryContent, ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart,
17 ProviderItemPart, TextPart, ToolCallPart, UserContent,
18};
19use crate::model::{
20 Model, ModelError, ModelRequestParameters, ModelSettings, ModelStream, OutputMode, StreamChunk,
21};
22use crate::providers::{Provider, ProviderError};
23use crate::usage::RequestUsage;
24
/// A fully assembled request for an OpenAI endpoint, as produced by `build_request`.
struct OpenAIRequest {
    /// The JSON request body (model, messages, tools, settings, ...).
    body: Value,
}
28
29fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
30 if error.is_timeout() {
31 return ModelError::Timeout;
32 }
33 if error.is_connect() {
34 return ModelError::Transport(format!("{label} connect error: {error}"));
35 }
36 ModelError::Transport(format!("{label} request failed: {error}"))
37}
38
/// Trims `body` and caps it at 2000 characters for inclusion in error messages.
///
/// Returns an empty string for blank input. Text that fits within the limit is returned
/// trimmed but otherwise unchanged. Longer text is cut to its first 2000 characters and
/// suffixed with `...[truncated]`. The limit counts Unicode scalar values, so the cut
/// never splits a multi-byte character.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 2000;
    let trimmed = body.trim();
    // `char_indices().nth(LIMIT)` yields the byte offset of the first character past the
    // limit in a single pass; `None` means the text already fits (this also covers "").
    match trimmed.char_indices().nth(LIMIT) {
        None => trimmed.to_string(),
        Some((cut, _)) => format!("{}...[truncated]", &trimmed[..cut]),
    }
}
51
52fn join_path(base: &Url, path: &str) -> Result<Url, ModelError> {
53 let mut url = base.clone();
54 let base_path = url.path().trim_end_matches('/');
55 let path = path.trim_start_matches('/');
56 let new_path = if base_path.is_empty() || base_path == "/" {
57 format!("/{path}")
58 } else {
59 format!("{base_path}/{path}")
60 };
61 url.set_path(&new_path);
62 Ok(url)
63}
64
65fn normalize_tool_call_id(id: Option<String>) -> String {
66 match id {
67 Some(value) if !value.trim().is_empty() => value,
68 _ => format!("call_{}", Uuid::new_v4().simple()),
69 }
70}
71
72fn normalize_tool_call_id_str(id: &str) -> String {
73 if id.trim().is_empty() {
74 format!("call_{}", Uuid::new_v4().simple())
75 } else {
76 id.to_string()
77 }
78}
79
80fn tool_return_content(value: &Value) -> String {
81 match value {
82 Value::String(value) => value.clone(),
83 _ => serde_json::to_string(value).unwrap_or_else(|_| value.to_string()),
84 }
85}
86
87fn tool_call_arguments(value: &Value) -> String {
88 match value {
89 Value::String(value) => value.clone(),
90 _ => serde_json::to_string(value).unwrap_or_else(|_| value.to_string()),
91 }
92}
93
/// Reports whether `media_type` denotes textual content that may be inlined as text.
///
/// Matches every `text/*` type plus a fixed set of textual `application/*` types.
/// Only the `type/subtype` part is compared, ASCII case-insensitively. Parameters such
/// as `; charset=utf-8` are ignored, because media type names are case-insensitive
/// (RFC 2045) and parameters do not change whether the payload is text.
fn is_text_like_media_type(media_type: &str) -> bool {
    let essence = media_type
        .split_once(';')
        .map_or(media_type, |(head, _)| head)
        .trim()
        .to_ascii_lowercase();
    essence.starts_with("text/")
        || matches!(
            essence.as_str(),
            "application/json"
                | "application/xml"
                | "application/xhtml+xml"
                | "application/javascript"
                | "application/x-www-form-urlencoded"
        )
}
105
/// Maps an audio media type to the `input_audio.format` value, or `None` if unsupported.
///
/// Only the `type/subtype` part is compared, ASCII case-insensitively. Parameters are
/// ignored, so `audio/ogg;codecs=opus`, `audio/ogg; codecs=opus` and `Audio/OGG` all map
/// to `ogg`.
fn audio_format_from_media_type(media_type: &str) -> Option<&'static str> {
    let essence = media_type
        .split_once(';')
        .map_or(media_type, |(head, _)| head)
        .trim()
        .to_ascii_lowercase();
    match essence.as_str() {
        "audio/wav" | "audio/x-wav" => Some("wav"),
        "audio/mpeg" | "audio/mp3" => Some("mp3"),
        "audio/ogg" => Some("ogg"),
        "audio/flac" => Some("flac"),
        "audio/aiff" => Some("aiff"),
        "audio/aac" => Some("aac"),
        _ => None,
    }
}
117
/// Splits a base64 `data:` URL into `(media_type, base64_payload)`.
///
/// Follows RFC 2397 (`data:<mediatype>[;params];base64,<data>`). The returned media type
/// keeps any parameters, e.g. `text/plain;charset=utf-8`. Returns `None` when `url` is not
/// a data URL, is not base64-encoded, or has a blank media type. The payload is returned
/// as-is and is not decoded or validated.
fn parse_data_url_base64(url: &str) -> Option<(String, String)> {
    let (meta, data) = url.strip_prefix("data:")?.split_once(',')?;
    // The media type may carry its own `;name=value` parameters, so `;base64` must be
    // matched as the final segment rather than by splitting at the first `;`.
    let media_type = meta.strip_suffix(";base64")?;
    if media_type.trim().is_empty() {
        return None;
    }
    Some((media_type.to_string(), data.to_string()))
}
127
128fn normalize_stream_tool_call_id(id: Option<String>, index: Option<usize>) -> String {
129 if let Some(value) = id.filter(|value| !value.trim().is_empty()) {
130 value
131 } else if let Some(index) = index {
132 format!("call_{index}")
133 } else {
134 normalize_tool_call_id(None)
135 }
136}
137
138fn contains_audio(messages: &[ModelMessage]) -> bool {
139 for message in messages {
140 if let ModelMessage::Request(req) = message {
141 for part in &req.parts {
142 if let ModelRequestPart::UserPrompt(prompt) = part {
143 for item in &prompt.content {
144 match item {
145 UserContent::Audio(_) => return true,
146 UserContent::Binary(binary) => {
147 if binary.media_type.starts_with("audio/") {
148 return true;
149 }
150 }
151 _ => {}
152 }
153 }
154 }
155 }
156 }
157 }
158 false
159}
160
/// Returns `true` for model names this module treats as Responses-API-only.
///
/// The check is a case-insensitive prefix match against a fixed list of model families.
fn is_responses_only_model(model: &str) -> bool {
    const RESPONSES_ONLY_PREFIXES: [&str; 4] = ["gpt-5", "gpt-4.1", "o1", "o3"];
    let lowered = model.to_lowercase();
    RESPONSES_ONLY_PREFIXES
        .iter()
        .any(|prefix| lowered.starts_with(prefix))
}
168
169fn prefers_responses(model: &str) -> bool {
170 let lowered = model.to_lowercase();
171 is_responses_only_model(model)
172 || lowered.starts_with("gpt-4o")
173 || lowered.starts_with("gpt-4.1")
174 || lowered.starts_with("o1")
175 || lowered.starts_with("o3")
176}
177
/// Feature switches for OpenAI-compatible chat-completions backends.
///
/// Servers that only partially implement the OpenAI API can turn off request fields they
/// reject, or refuse input kinds they cannot handle.
#[derive(Clone, Debug)]
pub(crate) struct OpenAIChatCapabilities {
    /// When set, a JSON-schema `response_format` may be sent for structured output.
    pub(crate) supports_response_format: bool,
    /// When set, `parallel_tool_calls: false` may be sent if any tool is sequential.
    pub(crate) supports_parallel_tool_calls: bool,
    /// When set, binary `image/*` user content is rejected with `ModelError::Unsupported`.
    pub(crate) reject_binary_images: bool,
}

impl Default for OpenAIChatCapabilities {
    /// Enables both optional request fields and accepts binary images.
    fn default() -> Self {
        OpenAIChatCapabilities {
            reject_binary_images: false,
            supports_parallel_tool_calls: true,
            supports_response_format: true,
        }
    }
}
194
195#[derive(Clone, Debug)]
196pub struct OpenAIProvider {
197 api_key: String,
198 base_url: Url,
199}
200
201impl OpenAIProvider {
202 pub fn new(
203 api_key: impl Into<String>,
204 base_url: impl AsRef<str>,
205 ) -> Result<Self, ProviderError> {
206 let url = Url::parse(base_url.as_ref())
207 .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
208 Ok(Self {
209 api_key: api_key.into(),
210 base_url: url,
211 })
212 }
213
214 pub fn from_env() -> Result<Self, ProviderError> {
215 let api_key = std::env::var("OPENAI_API_KEY")
216 .map_err(|_| ProviderError::MissingApiKey("openai".to_string()))?;
217 Self::new(api_key, "https://api.openai.com/v1")
218 }
219
220 pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, ProviderError> {
221 self.base_url = Url::parse(base_url.as_ref())
222 .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
223 Ok(self)
224 }
225}
226
#[cfg(test)]
mod tests {
    // Unit tests for message/content conversion, request building and the pure helpers
    // of the OpenAI chat, Responses and unified models.
    use super::*;
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    use crate::messages::{
        AudioUrl, BinaryContent, DocumentUrl, ImageUrl, ModelMessage, ModelRequest,
        ModelRequestPart, ModelResponse, ModelResponsePart, ProviderItemPart, TextPart,
        ToolCallPart, ToolReturnPart, UserContent,
    };

    /// Reads `tests/fixtures/<name>` relative to the crate root.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name);
        std::fs::read(path).expect("fixture read")
    }

    // Binary image → data-URL image part, AAC audio → input_audio, PDF → size placeholder.
    #[test]
    fn convert_user_content_handles_binary_media() {
        let model = OpenAIChatModel::new(
            "gpt-4o-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let image_bytes = fixture_bytes("fixture.jpg");
        let audio_bytes = fixture_bytes("fixture.m4a");
        let pdf_bytes = fixture_bytes("fixture.pdf");

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: image_bytes.clone(),
                media_type: "image/jpeg".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: audio_bytes.clone(),
                media_type: "audio/aac".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
        ];

        let value = model
            .convert_user_content(&content)
            .expect("convert user content");
        let parts = value.as_array().expect("parts array");
        assert_eq!(parts.len(), 3);

        let image = &parts[0];
        assert_eq!(
            image.get("type"),
            Some(&Value::String("image_url".to_string()))
        );
        let image_url = image
            .get("image_url")
            .and_then(|value| value.get("url"))
            .and_then(|value| value.as_str())
            .expect("image url");
        let expected_image = format!("data:image/jpeg;base64,{}", STANDARD.encode(&image_bytes));
        assert_eq!(image_url, expected_image);

        let audio = &parts[1];
        assert_eq!(
            audio.get("type"),
            Some(&Value::String("input_audio".to_string()))
        );
        let audio_input = audio.get("input_audio").expect("input_audio");
        assert_eq!(
            audio_input.get("format"),
            Some(&Value::String("aac".to_string()))
        );
        let audio_data = audio_input
            .get("data")
            .and_then(|value| value.as_str())
            .expect("audio data");
        assert_eq!(audio_data, STANDARD.encode(&audio_bytes));

        let pdf = &parts[2];
        assert_eq!(pdf.get("type"), Some(&Value::String("text".to_string())));
        let pdf_text = pdf
            .get("text")
            .and_then(|value| value.as_str())
            .expect("pdf text");
        let expected_text = format!("[binary content: {} bytes]", pdf_bytes.len());
        assert_eq!(pdf_text, expected_text);
    }

    // A tool call followed by its return replays as an assistant tool_calls message plus
    // a `tool` message.
    #[test]
    fn make_messages_replays_tool_calls() {
        let model = OpenAIChatModel::new(
            "gpt-4o-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let messages = vec![
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                    id: "call-1".to_string(),
                    name: "get_data".to_string(),
                    arguments: json!({"a": 1}),
                })],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                    tool_name: "get_data".to_string(),
                    tool_call_id: "call-1".to_string(),
                    content: json!({"ok": true}),
                })],
                instructions: None,
            }),
        ];

        let out = model.make_messages(&messages).expect("make messages");
        assert_eq!(out.len(), 2);

        let assistant = out[0].as_object().expect("assistant message");
        assert_eq!(
            assistant.get("role"),
            Some(&Value::String("assistant".to_string()))
        );
        assert_eq!(assistant.get("content"), Some(&Value::Null));
        let tool_calls = assistant
            .get("tool_calls")
            .and_then(|value| value.as_array())
            .expect("tool_calls");
        assert_eq!(tool_calls.len(), 1);
        let call = &tool_calls[0];
        assert_eq!(call.get("id"), Some(&Value::String("call-1".to_string())));
        let function = call.get("function").expect("function");
        assert_eq!(
            function.get("name"),
            Some(&Value::String("get_data".to_string()))
        );
        assert_eq!(
            function.get("arguments"),
            Some(&Value::String("{\"a\":1}".to_string()))
        );

        let tool = out[1].as_object().expect("tool message");
        assert_eq!(tool.get("role"), Some(&Value::String("tool".to_string())));
        assert_eq!(
            tool.get("tool_call_id"),
            Some(&Value::String("call-1".to_string()))
        );
        assert_eq!(
            tool.get("content"),
            Some(&Value::String("{\"ok\":true}".to_string()))
        );
    }

    // The Responses API replays tool calls as function_call / function_call_output items.
    #[test]
    fn responses_replays_tool_calls() {
        let model = OpenAIResponsesModel::new(
            "gpt-5-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let messages = vec![
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                    id: "call-1".to_string(),
                    name: "get_data".to_string(),
                    arguments: json!({"a": 1}),
                })],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                    tool_name: "get_data".to_string(),
                    tool_call_id: "call-1".to_string(),
                    content: json!({"ok": true}),
                })],
                instructions: None,
            }),
        ];

        let out = model
            .make_input_messages(&messages)
            .expect("make input messages");
        assert_eq!(out.len(), 2);

        let call = out[0].as_object().expect("function call item");
        assert_eq!(
            call.get("type"),
            Some(&Value::String("function_call".to_string()))
        );
        assert_eq!(
            call.get("call_id"),
            Some(&Value::String("call-1".to_string()))
        );
        assert_eq!(
            call.get("name"),
            Some(&Value::String("get_data".to_string()))
        );
        assert_eq!(
            call.get("arguments"),
            Some(&Value::String("{\"a\":1}".to_string()))
        );

        let output = out[1].as_object().expect("function call output");
        assert_eq!(
            output.get("type"),
            Some(&Value::String("function_call_output".to_string()))
        );
        assert_eq!(
            output.get("call_id"),
            Some(&Value::String("call-1".to_string()))
        );
        assert_eq!(
            output.get("output"),
            Some(&Value::String("{\"ok\":true}".to_string()))
        );
    }

    // A response holding a raw provider item replays that item verbatim and nothing else.
    #[test]
    fn responses_replays_provider_items() {
        let model = OpenAIResponsesModel::new(
            "gpt-5-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let raw_item = json!({
            "type": "reasoning",
            "summary": "ok"
        });

        let messages = vec![ModelMessage::Response(ModelResponse {
            parts: vec![
                ModelResponsePart::ProviderItem(ProviderItemPart {
                    provider: "openai_responses".to_string(),
                    payload: raw_item.clone(),
                }),
                ModelResponsePart::Text(TextPart {
                    content: "ignored".to_string(),
                }),
            ],
            usage: None,
            model_name: None,
            finish_reason: None,
        })];

        let out = model
            .make_input_messages(&messages)
            .expect("make input messages");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0], raw_item);
    }

    // Streaming for a chat-capable model routes to the chat API.
    #[test]
    fn unified_model_streaming_prefers_chat_when_available() {
        let model = OpenAIUnifiedModel::new(
            "gpt-4o-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let mode = model.select_api(&[], true).expect("select api for stream");
        assert!(matches!(mode, OpenAIApiMode::Chat));
    }

    // Streaming for a Responses-only model routes to the Responses API.
    #[test]
    fn unified_model_streaming_supports_responses_only() {
        let model = OpenAIUnifiedModel::new(
            "gpt-5-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let mode = model.select_api(&[], true).expect("select api for stream");
        assert!(matches!(mode, OpenAIApiMode::Responses));
    }

    // Consecutive tool-call-only responses collapse into one assistant message.
    #[test]
    fn make_messages_groups_consecutive_tool_calls() {
        let model = OpenAIChatModel::new(
            "gpt-4o-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let messages = vec![
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                    id: "call-1".to_string(),
                    name: "get_data".to_string(),
                    arguments: json!({"a": 1}),
                })],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                    id: "call-2".to_string(),
                    name: "get_data".to_string(),
                    arguments: json!({"b": 2}),
                })],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
        ];

        let out = model.make_messages(&messages).expect("make messages");
        assert_eq!(out.len(), 1);
        let assistant = out[0].as_object().expect("assistant message");
        let tool_calls = assistant
            .get("tool_calls")
            .and_then(|value| value.as_array())
            .expect("tool_calls");
        assert_eq!(tool_calls.len(), 2);
    }

    // `max_tokens` is renamed to `max_output_tokens` for the Responses API.
    #[test]
    fn responses_build_request_maps_max_tokens() {
        let model = OpenAIResponsesModel::new(
            "gpt-5-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let messages = vec![ModelMessage::Request(ModelRequest::user_text_prompt("hi"))];
        let params = ModelRequestParameters::default();
        let mut settings = Map::new();
        settings.insert("max_tokens".to_string(), Value::Number(42.into()));

        let request = model
            .build_request(&messages, Some(&settings), &params, false)
            .expect("build request");
        let body = request.body.as_object().expect("body object");
        assert!(body.contains_key("max_output_tokens"));
        assert!(!body.contains_key("max_tokens"));
    }

    // Streaming helpers decode function_call items and usage blocks.
    #[test]
    fn responses_stream_helpers_parse_tool_calls_and_usage() {
        let item = json!({
            "type": "function_call",
            "name": "echo",
            "call_id": "call-1",
            "arguments": "{\"msg\":\"hi\"}"
        });
        let call = parse_responses_stream_tool_call(&item).expect("tool call");
        assert_eq!(call.name, "echo");
        assert_eq!(call.id, "call-1");
        assert_eq!(call.arguments, json!({"msg": "hi"}));

        let response = json!({
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5
            }
        });
        let usage = parse_responses_stream_usage(&response).expect("usage");
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
    }

    // Media-type classification, data-URL parsing and tool-call id normalization.
    #[test]
    fn helper_functions_cover_ids_and_media_types() {
        assert!(is_text_like_media_type("text/plain"));
        assert!(is_text_like_media_type("application/json"));
        assert!(!is_text_like_media_type("image/png"));

        assert_eq!(audio_format_from_media_type("audio/mpeg"), Some("mp3"));
        assert_eq!(audio_format_from_media_type("audio/aac"), Some("aac"));
        assert_eq!(audio_format_from_media_type("audio/unknown"), None);

        let parsed = parse_data_url_base64("data:audio/mpeg;base64,SGVsbG8=").expect("parse");
        assert_eq!(parsed.0, "audio/mpeg");
        assert_eq!(parsed.1, "SGVsbG8=");
        assert!(parse_data_url_base64("https://example.com").is_none());

        let id = normalize_tool_call_id(Some("".to_string()));
        assert!(id.starts_with("call_"));
        let id = normalize_tool_call_id_str("");
        assert!(id.starts_with("call_"));

        let id = normalize_stream_tool_call_id(None, Some(2));
        assert_eq!(id, "call_2");

        let id = normalize_stream_tool_call_id(Some("explicit".to_string()), Some(1));
        assert_eq!(id, "explicit");
    }

    // Error-body truncation, URL joining, value rendering and model routing predicates.
    #[test]
    fn helper_functions_cover_text_and_urls() {
        let long_body = "a".repeat(2100);
        let truncated = truncate_error_body(&format!("{long_body}\n"));
        assert!(truncated.ends_with("...[truncated]"));

        let base = Url::parse("https://example.com/v1/").expect("url");
        let joined = join_path(&base, "chat/completions").expect("join");
        assert_eq!(joined.as_str(), "https://example.com/v1/chat/completions");

        assert_eq!(tool_return_content(&json!("ok")), "ok");
        assert_eq!(tool_call_arguments(&json!({"a": 1})), "{\"a\":1}");

        assert!(is_responses_only_model("gpt-5-mini"));
        assert!(!is_responses_only_model("gpt-4o-mini"));
        assert!(prefers_responses("gpt-4o-mini"));
        assert!(!prefers_responses("gpt-3.5-turbo"));
    }

    // An audio URL inside a user prompt is detected.
    #[test]
    fn contains_audio_detects_audio_inputs() {
        let messages = vec![ModelMessage::Request(ModelRequest {
            parts: vec![ModelRequestPart::UserPrompt(
                crate::messages::UserPromptPart {
                    content: vec![UserContent::Audio(AudioUrl {
                        url: "data:audio/mpeg;base64,SGVsbG8=".to_string(),
                        media_type: None,
                    })],
                },
            )],
            instructions: None,
        })];
        assert!(contains_audio(&messages));
    }

    // Text, image URLs, audio data URLs and documents map to the expected chat parts.
    #[test]
    fn convert_user_content_handles_text_and_urls() {
        let model = OpenAIChatModel::new(
            "gpt-4o-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let content = vec![
            UserContent::Text("hello".to_string()),
            UserContent::Image(ImageUrl {
                url: "https://example.com/image.png".to_string(),
                media_type: None,
            }),
            UserContent::Audio(AudioUrl {
                url: "data:audio/mpeg;base64,SGVsbG8=".to_string(),
                media_type: None,
            }),
            UserContent::Document(DocumentUrl {
                url: "data:text/plain;base64,SGVsbG8=".to_string(),
                media_type: None,
            }),
            UserContent::Document(DocumentUrl {
                url: "https://example.com/doc.pdf".to_string(),
                media_type: None,
            }),
        ];

        let parts = model
            .convert_user_content(&content)
            .expect("convert user content");
        let parts = parts.as_array().expect("parts array");
        assert_eq!(parts.len(), 5);
        assert_eq!(
            parts[0].get("type"),
            Some(&Value::String("text".to_string()))
        );
        assert_eq!(
            parts[1].get("type"),
            Some(&Value::String("image_url".to_string()))
        );
        assert_eq!(
            parts[2].get("type"),
            Some(&Value::String("input_audio".to_string()))
        );
        assert_eq!(
            parts[3].get("text"),
            Some(&Value::String("Hello".to_string()))
        );
        assert_eq!(
            parts[4].get("text"),
            Some(&Value::String(
                "[document: https://example.com/doc.pdf]".to_string()
            ))
        );
    }

    // With `reject_binary_images`, binary image content is an Unsupported error.
    #[test]
    fn convert_user_content_rejects_binary_images_when_disabled() {
        let model = OpenAIChatModel::new_with_capabilities(
            "gpt-4o-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
            OpenAIChatCapabilities {
                supports_response_format: true,
                supports_parallel_tool_calls: true,
                reject_binary_images: true,
            },
        );

        let content = vec![UserContent::Binary(BinaryContent {
            data: vec![1, 2, 3],
            media_type: "image/png".to_string(),
        })];

        let err = model
            .convert_user_content(&content)
            .expect_err("should error");
        match err {
            ModelError::Unsupported(message) => {
                assert!(message.contains("binary image inputs"));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    // Responses helpers: file names per media type and input_text / input_file parts.
    #[test]
    fn responses_helpers_cover_media_filename_and_content() {
        assert_eq!(
            OpenAIResponsesModel::filename_for_media_type("application/pdf"),
            "file.pdf"
        );
        assert_eq!(
            OpenAIResponsesModel::filename_for_media_type("text/plain"),
            "file.txt"
        );
        assert_eq!(
            OpenAIResponsesModel::filename_for_media_type("image/png"),
            "file.bin"
        );

        let model = OpenAIResponsesModel::new(
            "gpt-5-mini",
            "test-key".to_string(),
            Url::parse("https://example.com/").expect("valid url"),
            None,
        );

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: b"hello".to_vec(),
                media_type: "text/plain".to_string(),
            }),
            UserContent::Document(DocumentUrl {
                url: "data:application/pdf;base64,SGVsbG8=".to_string(),
                media_type: None,
            }),
        ];

        let parts = model
            .convert_user_content(&content)
            .expect("convert content");
        let parts = parts.as_array().expect("parts array");
        assert_eq!(parts.len(), 2);
        assert_eq!(
            parts[0].get("type"),
            Some(&Value::String("input_text".to_string()))
        );
        assert_eq!(
            parts[1].get("type"),
            Some(&Value::String("input_file".to_string()))
        );
    }
}
805
806impl Provider for OpenAIProvider {
807 fn name(&self) -> &str {
808 "openai"
809 }
810
811 fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
812 Arc::new(OpenAIUnifiedModel::new(
813 model,
814 self.api_key.clone(),
815 self.base_url.clone(),
816 settings,
817 ))
818 }
819}
820
/// Model backed by the OpenAI-style `chat/completions` endpoint.
#[derive(Clone, Debug)]
pub struct OpenAIChatModel {
    /// Model identifier, sent as the request body's `model` field.
    model: String,
    /// API key for this model's requests (how it is attached is outside this view).
    api_key: String,
    /// API root that `chat/completions` is joined onto.
    base_url: Url,
    /// HTTP client; a fresh one is created by `new_with_capabilities`.
    client: Client,
    /// Per-model default settings, copied into the body only for keys not already set.
    default_settings: Option<ModelSettings>,
    /// Backend feature switches controlling which request fields/inputs are allowed.
    capabilities: OpenAIChatCapabilities,
}
830
831impl OpenAIChatModel {
832 pub fn new(
833 model: impl Into<String>,
834 api_key: String,
835 base_url: Url,
836 settings: Option<ModelSettings>,
837 ) -> Self {
838 Self::new_with_capabilities(
839 model,
840 api_key,
841 base_url,
842 settings,
843 OpenAIChatCapabilities::default(),
844 )
845 }
846
847 pub(crate) fn new_with_capabilities(
848 model: impl Into<String>,
849 api_key: String,
850 base_url: Url,
851 settings: Option<ModelSettings>,
852 capabilities: OpenAIChatCapabilities,
853 ) -> Self {
854 Self {
855 model: model.into(),
856 api_key,
857 base_url,
858 client: Client::new(),
859 default_settings: settings,
860 capabilities,
861 }
862 }
863
864 fn endpoint(&self) -> Result<Url, ModelError> {
865 join_path(&self.base_url, "chat/completions")
866 }
867
868 fn make_messages(&self, messages: &[ModelMessage]) -> Result<Vec<Value>, ModelError> {
869 let mut out = Vec::new();
870 for message in messages {
871 match message {
872 ModelMessage::Request(req) => {
873 if let Some(instructions) = req
874 .instructions
875 .as_ref()
876 .filter(|value| !value.trim().is_empty())
877 {
878 out.push(json!({"role": "system", "content": instructions}));
879 }
880 for part in &req.parts {
881 match part {
882 ModelRequestPart::SystemPrompt(prompt) => {
883 out.push(json!({"role": "system", "content": prompt.content}))
884 }
885 ModelRequestPart::UserPrompt(prompt) => {
886 let content = self.convert_user_content(&prompt.content)?;
887 out.push(json!({"role": "user", "content": content}))
888 }
889 ModelRequestPart::ToolReturn(tool_return) => {
890 let content = tool_return_content(&tool_return.content);
891 out.push(json!({
892 "role": "tool",
893 "tool_call_id": normalize_tool_call_id_str(&tool_return.tool_call_id),
894 "content": content,
895 }))
896 }
897 ModelRequestPart::RetryPrompt(retry) => {
898 if retry.tool_name.is_some() {
899 out.push(json!({
900 "role": "tool",
901 "tool_call_id": normalize_tool_call_id(retry.tool_call_id.clone()),
902 "content": retry.content,
903 }));
904 } else {
905 out.push(json!({
906 "role": "user",
907 "content": retry.content,
908 }));
909 }
910 }
911 }
912 }
913 }
914 ModelMessage::Response(res) => {
915 let text = res.text();
916 let tool_calls = res.tool_calls();
917
918 if text.is_none() && tool_calls.is_empty() {
919 continue;
920 }
921
922 let calls = tool_calls
923 .into_iter()
924 .map(|call| {
925 let args = tool_call_arguments(&call.arguments);
926 json!({
927 "id": normalize_tool_call_id_str(&call.id),
928 "type": "function",
929 "function": {
930 "name": call.name,
931 "arguments": args,
932 }
933 })
934 })
935 .collect::<Vec<_>>();
936
937 if text.is_none()
938 && !calls.is_empty()
939 && let Some(Value::Object(last)) = out.last_mut()
940 {
941 let is_assistant =
942 last.get("role").and_then(|value| value.as_str()) == Some("assistant");
943 let is_tool_calls = last.get("content").is_some_and(Value::is_null)
944 && last.get("tool_calls").is_some();
945 if is_assistant
946 && is_tool_calls
947 && let Some(existing) =
948 last.get_mut("tool_calls").and_then(Value::as_array_mut)
949 {
950 existing.extend(calls);
951 continue;
952 }
953 }
954
955 let mut msg = Map::new();
956 msg.insert("role".to_string(), Value::String("assistant".to_string()));
957
958 if let Some(text) = text {
959 msg.insert("content".to_string(), Value::String(text));
960 } else if !calls.is_empty() {
961 msg.insert("content".to_string(), Value::Null);
962 }
963
964 if !calls.is_empty() {
965 msg.insert("tool_calls".to_string(), Value::Array(calls));
966 }
967
968 out.push(Value::Object(msg));
969 }
970 }
971 }
972 Ok(out)
973 }
974
975 fn convert_user_content(&self, content: &[UserContent]) -> Result<Value, ModelError> {
976 let mut parts = Vec::new();
977 for item in content {
978 match item {
979 UserContent::Text(text) => parts.push(json!({"type": "text", "text": text})),
980 UserContent::Image(image) => parts.push(json!({
981 "type": "image_url",
982 "image_url": {"url": image.url}
983 })),
984 UserContent::Binary(BinaryContent { data, media_type }) => {
985 if media_type.starts_with("image/") {
986 if self.capabilities.reject_binary_images {
987 return Err(ModelError::Unsupported(
988 "binary image inputs are not supported; provide an image URL"
989 .to_string(),
990 ));
991 }
992 let encoded = general_purpose::STANDARD.encode(data);
993 let data_url = format!("data:{};base64,{}", media_type, encoded);
994 parts.push(json!({
995 "type": "image_url",
996 "image_url": {"url": data_url}
997 }))
998 } else if media_type.starts_with("audio/") {
999 if let Some(format) = audio_format_from_media_type(media_type) {
1000 let encoded = general_purpose::STANDARD.encode(data);
1001 parts.push(json!({
1002 "type": "input_audio",
1003 "input_audio": {
1004 "data": encoded,
1005 "format": format
1006 }
1007 }))
1008 } else {
1009 parts.push(json!({
1010 "type": "text",
1011 "text": format!("[audio content: {} bytes]", data.len())
1012 }))
1013 }
1014 } else if is_text_like_media_type(media_type) {
1015 match std::str::from_utf8(data) {
1016 Ok(text) => parts.push(json!({"type": "text", "text": text})),
1017 Err(_) => parts.push(json!({
1018 "type": "text",
1019 "text": format!("[binary content: {} bytes]", data.len())
1020 })),
1021 }
1022 } else {
1023 parts.push(json!({
1024 "type": "text",
1025 "text": format!("[binary content: {} bytes]", data.len())
1026 }))
1027 }
1028 }
1029 UserContent::Audio(audio) => {
1030 if let Some((media_type, data)) = parse_data_url_base64(&audio.url)
1031 && let Some(format) = audio_format_from_media_type(&media_type)
1032 {
1033 parts.push(json!({
1034 "type": "input_audio",
1035 "input_audio": {
1036 "data": data,
1037 "format": format
1038 }
1039 }))
1040 } else {
1041 parts.push(json!({
1042 "type": "text",
1043 "text": format!("[audio: {}]", audio.url)
1044 }))
1045 }
1046 }
1047 UserContent::Video(video) => parts.push(json!({
1048 "type": "text",
1049 "text": format!("[video: {}]", video.url)
1050 })),
1051 UserContent::Document(doc) => {
1052 if let Some((media_type, data)) = parse_data_url_base64(&doc.url)
1053 && is_text_like_media_type(&media_type)
1054 {
1055 match general_purpose::STANDARD.decode(data.as_bytes()) {
1056 Ok(bytes) => match String::from_utf8(bytes) {
1057 Ok(text) => parts.push(json!({"type": "text", "text": text})),
1058 Err(_) => parts.push(json!({
1059 "type": "text",
1060 "text": format!("[document: {}]", doc.url)
1061 })),
1062 },
1063 Err(_) => parts.push(json!({
1064 "type": "text",
1065 "text": format!("[document: {}]", doc.url)
1066 })),
1067 }
1068 } else {
1069 parts.push(json!({
1070 "type": "text",
1071 "text": format!("[document: {}]", doc.url)
1072 }))
1073 }
1074 }
1075 }
1076 }
1077
1078 Ok(Value::Array(parts))
1079 }
1080
1081 fn build_body(
1082 &self,
1083 messages: &[ModelMessage],
1084 params: &ModelRequestParameters,
1085 stream: bool,
1086 ) -> Result<Value, ModelError> {
1087 let mut body = Map::new();
1088 body.insert("model".to_string(), Value::String(self.model.clone()));
1089 body.insert(
1090 "messages".to_string(),
1091 Value::Array(self.make_messages(messages)?),
1092 );
1093
1094 if !params.function_tools.is_empty() {
1095 let tools = params
1096 .function_tools
1097 .iter()
1098 .map(|tool| {
1099 let (schema, _strict_ok) =
1100 transform_openai_schema(&tool.parameters_json_schema, None);
1101 json!({
1102 "type": "function",
1103 "function": {
1104 "name": tool.name,
1105 "description": tool.description,
1106 "parameters": schema,
1107 }
1108 })
1109 })
1110 .collect();
1111 body.insert("tools".to_string(), Value::Array(tools));
1112 body.insert("tool_choice".to_string(), Value::String("auto".to_string()));
1113 if self.capabilities.supports_parallel_tool_calls
1114 && params.function_tools.iter().any(|tool| tool.sequential)
1115 {
1116 body.insert("parallel_tool_calls".to_string(), Value::Bool(false));
1117 }
1118 }
1119
1120 if params.output_mode == OutputMode::JsonSchema
1121 && let Some(schema) = params.output_schema.clone()
1122 && self.capabilities.supports_response_format
1123 {
1124 let strict = !params.allow_text_output;
1125 let (schema, _strict_ok) = transform_openai_schema(&schema, Some(strict));
1126 body.insert(
1127 "response_format".to_string(),
1128 json!({
1129 "type": "json_schema",
1130 "json_schema": {
1131 "name": "output",
1132 "schema": schema,
1133 "strict": strict,
1134 }
1135 }),
1136 );
1137 }
1138
1139 if stream {
1140 body.insert("stream".to_string(), Value::Bool(true));
1141 body.insert("stream_options".to_string(), json!({"include_usage": true}));
1142 }
1143
1144 if let Some(settings) = &self.default_settings {
1145 for (key, value) in settings {
1146 body.entry(key.clone()).or_insert(value.clone());
1147 }
1148 }
1149
1150 Ok(Value::Object(body))
1151 }
1152
1153 fn build_request(
1154 &self,
1155 messages: &[ModelMessage],
1156 settings: Option<&ModelSettings>,
1157 params: &ModelRequestParameters,
1158 stream: bool,
1159 ) -> Result<OpenAIRequest, ModelError> {
1160 let mut body = self.build_body(messages, params, stream)?;
1161 if let Some(settings) = settings
1162 && let Value::Object(map) = &mut body
1163 {
1164 for (key, value) in settings {
1165 map.insert(key.clone(), value.clone());
1166 }
1167 }
1168 Ok(OpenAIRequest { body })
1169 }
1170
1171 fn parse_tool_call(tool_call: &OpenAIToolCall) -> ToolCallPart {
1172 let args = tool_call
1173 .function
1174 .arguments
1175 .as_ref()
1176 .and_then(|arg| serde_json::from_str::<Value>(arg).ok())
1177 .unwrap_or_else(|| {
1178 tool_call
1179 .function
1180 .arguments
1181 .clone()
1182 .map(Value::String)
1183 .unwrap_or_else(|| Value::Object(Map::new()))
1184 });
1185
1186 ToolCallPart {
1187 id: normalize_tool_call_id(tool_call.id.clone()),
1188 name: tool_call
1189 .function
1190 .name
1191 .clone()
1192 .unwrap_or_else(|| "tool".to_string()),
1193 arguments: args,
1194 }
1195 }
1196}
1197
1198#[async_trait]
1199impl Model for OpenAIChatModel {
1200 fn name(&self) -> &str {
1201 &self.model
1202 }
1203
1204 async fn request(
1205 &self,
1206 messages: &[ModelMessage],
1207 settings: Option<&ModelSettings>,
1208 params: &ModelRequestParameters,
1209 ) -> Result<ModelResponse, ModelError> {
1210 tracing::debug!(
1211 model = %self.model,
1212 tool_count = params.function_tools.len(),
1213 output_schema = params.output_schema.is_some(),
1214 "OpenAI chat request"
1215 );
1216 let request = self.build_request(messages, settings, params, false)?;
1217
1218 let response = self
1219 .client
1220 .post(self.endpoint()?)
1221 .bearer_auth(&self.api_key)
1222 .json(&request.body)
1223 .send()
1224 .await
1225 .map_err(|e| map_reqwest_error("OpenAI", e))?;
1226
1227 let status = response.status();
1228 if !status.is_success() {
1229 let body = response.text().await.unwrap_or_default();
1230 tracing::error!(
1231 status = status.as_u16(),
1232 model = %self.model,
1233 body = %truncate_error_body(&body),
1234 "OpenAI chat request failed"
1235 );
1236 return Err(ModelError::HttpStatus {
1237 status: status.as_u16(),
1238 });
1239 }
1240
1241 let body: OpenAIChatResponse = response.json().await.map_err(|e| {
1242 tracing::error!(error = %e, model = %self.model, "OpenAI response parse failed");
1243 ModelError::Provider(format!("OpenAI response parse failed: {e}"))
1244 })?;
1245
1246 let choice = body.choices.into_iter().next().ok_or_else(|| {
1247 tracing::error!(model = %self.model, "OpenAI response missing choices");
1248 ModelError::Provider("OpenAI response missing choices".to_string())
1249 })?;
1250
1251 let mut parts = Vec::new();
1252 if let Some(content) = choice.message.content {
1253 parts.push(ModelResponsePart::Text(TextPart { content }));
1254 }
1255
1256 if let Some(tool_calls) = choice.message.tool_calls {
1257 for call in tool_calls {
1258 parts.push(ModelResponsePart::ToolCall(Self::parse_tool_call(&call)));
1259 }
1260 } else if let Some(function_call) = choice.message.function_call {
1261 parts.push(ModelResponsePart::ToolCall(ToolCallPart {
1262 id: normalize_tool_call_id(None),
1263 name: function_call.name.unwrap_or_else(|| "tool".to_string()),
1264 arguments: function_call
1265 .arguments
1266 .as_ref()
1267 .and_then(|arg| serde_json::from_str::<Value>(arg).ok())
1268 .unwrap_or_else(|| {
1269 function_call
1270 .arguments
1271 .clone()
1272 .map(Value::String)
1273 .unwrap_or_else(|| Value::Object(Map::new()))
1274 }),
1275 }));
1276 }
1277
1278 let usage = body.usage.map(|usage| RequestUsage {
1279 input_tokens: usage.prompt_tokens.unwrap_or(0),
1280 output_tokens: usage.completion_tokens.unwrap_or(0),
1281 ..Default::default()
1282 });
1283
1284 Ok(ModelResponse {
1285 parts,
1286 usage,
1287 model_name: Some(self.model.clone()),
1288 finish_reason: choice.finish_reason,
1289 })
1290 }
1291
1292 async fn request_stream(
1293 &self,
1294 messages: &[ModelMessage],
1295 settings: Option<&ModelSettings>,
1296 params: &ModelRequestParameters,
1297 ) -> Result<ModelStream, ModelError> {
1298 tracing::debug!(
1299 model = %self.model,
1300 tool_count = params.function_tools.len(),
1301 output_schema = params.output_schema.is_some(),
1302 "OpenAI stream request"
1303 );
1304 let request = self.build_request(messages, settings, params, true)?;
1305
1306 let response = self
1307 .client
1308 .post(self.endpoint()?)
1309 .bearer_auth(&self.api_key)
1310 .json(&request.body)
1311 .send()
1312 .await
1313 .map_err(|e| map_reqwest_error("OpenAI stream", e))?;
1314
1315 let status = response.status();
1316 if !status.is_success() {
1317 let body = response.text().await.unwrap_or_default();
1318 tracing::error!(
1319 status = status.as_u16(),
1320 model = %self.model,
1321 body = %truncate_error_body(&body),
1322 "OpenAI stream request failed"
1323 );
1324 return Err(ModelError::HttpStatus {
1325 status: status.as_u16(),
1326 });
1327 }
1328
1329 let mut event_stream = response.bytes_stream().eventsource();
1330 let model_name = self.model.clone();
1331
1332 let s = try_stream! {
1333 let mut tool_accumulator: HashMap<String, ToolAccumulator> = HashMap::new();
1334 while let Some(event) = event_stream.next().await {
1335 let event = event.map_err(|e| {
1336 tracing::error!(error = %e, model = %model_name, "OpenAI stream error");
1337 ModelError::Provider(format!("OpenAI stream error: {e}"))
1338 })?;
1339 let data = event.data;
1340 if data.trim() == "[DONE]" {
1341 if !tool_accumulator.is_empty() {
1342 for (_id, acc) in tool_accumulator.drain() {
1343 let args = serde_json::from_str::<Value>(&acc.arguments)
1344 .unwrap_or_else(|_| Value::String(acc.arguments.clone()));
1345 yield StreamChunk {
1346 text_delta: None,
1347 tool_call: Some(ToolCallPart {
1348 id: acc.id.clone(),
1349 name: acc.name.unwrap_or_else(|| "tool".to_string()),
1350 arguments: args,
1351 }),
1352 finish_reason: None,
1353 usage: None,
1354 };
1355 }
1356 }
1357 break;
1358 }
1359
1360 let chunk: OpenAIChatStreamResponse = serde_json::from_str(&data)
1361 .map_err(|e| {
1362 tracing::error!(error = %e, model = %model_name, "OpenAI stream parse error");
1363 ModelError::Provider(format!("OpenAI stream parse error: {e}"))
1364 })?;
1365 if let Some(choice) = chunk.choices.into_iter().next() {
1366 if let Some(content) = choice.delta.content {
1367 yield StreamChunk {
1368 text_delta: Some(content),
1369 tool_call: None,
1370 finish_reason: None,
1371 usage: None,
1372 };
1373 }
1374
1375 if let Some(tool_calls) = choice.delta.tool_calls {
1376 for call in tool_calls {
1377 let id = normalize_stream_tool_call_id(call.id.clone(), call.index);
1378 let entry = tool_accumulator.entry(id.clone()).or_insert_with(|| ToolAccumulator {
1379 id,
1380 name: None,
1381 arguments: String::new(),
1382 });
1383 if let Some(name) = call.function.name {
1384 entry.name = Some(name);
1385 }
1386 if let Some(args) = call.function.arguments {
1387 entry.arguments.push_str(&args);
1388 }
1389 }
1390 }
1391
1392 if let Some(reason) = choice.finish_reason.clone() {
1393 if !tool_accumulator.is_empty() {
1394 for (_id, acc) in tool_accumulator.drain() {
1395 let args = serde_json::from_str::<Value>(&acc.arguments)
1396 .unwrap_or_else(|_| Value::String(acc.arguments.clone()));
1397 yield StreamChunk {
1398 text_delta: None,
1399 tool_call: Some(ToolCallPart {
1400 id: acc.id.clone(),
1401 name: acc.name.unwrap_or_else(|| "tool".to_string()),
1402 arguments: args,
1403 }),
1404 finish_reason: Some(reason.clone()),
1405 usage: None,
1406 };
1407 }
1408 }
1409 yield StreamChunk {
1410 text_delta: None,
1411 tool_call: None,
1412 finish_reason: Some(reason),
1413 usage: chunk.usage.map(|usage| RequestUsage {
1414 input_tokens: usage.prompt_tokens.unwrap_or(0),
1415 output_tokens: usage.completion_tokens.unwrap_or(0),
1416 ..Default::default()
1417 }),
1418 };
1419 }
1420 }
1421 }
1422 };
1423
1424 Ok(Box::pin(s))
1425 }
1426}
1427
/// Body of a non-streaming Chat Completions response.
///
/// Only the first entry of `choices` is consumed. Fields not declared here are ignored during
/// deserialization.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    choices: Vec<OpenAIChoice>,
    /// Token accounting; absent counts are treated as 0 when converted.
    usage: Option<OpenAIUsage>,
}
1433
/// One choice of a non-streaming chat response.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
    /// Passed through unchanged as the response's `finish_reason`.
    finish_reason: Option<String>,
}
1439
/// Assistant message of a chat choice.
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
    /// Becomes a text part when present.
    content: Option<String>,
    /// Modern tool calls; when present they take precedence over `function_call`.
    tool_calls: Option<Vec<OpenAIToolCall>>,
    /// Legacy single function call, only consulted when `tool_calls` is absent.
    function_call: Option<OpenAIFunctionCall>,
}
1446
/// A complete tool call from a non-streaming chat response.
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
    /// Provider call id; passed through `normalize_tool_call_id` when converted.
    id: Option<String>,
    function: OpenAIToolFunction,
}
1452
/// Function name and argument string of a chat tool call.
#[derive(Debug, Deserialize)]
struct OpenAIToolFunction {
    /// Defaults to `"tool"` when missing.
    name: Option<String>,
    /// Argument string; decoded as JSON when possible, otherwise kept verbatim.
    arguments: Option<String>,
}
1458
/// Legacy `function_call` payload of a chat message.
///
/// It carries no call id, so one is generated when it is converted into a tool call.
#[derive(Debug, Deserialize)]
struct OpenAIFunctionCall {
    name: Option<String>,
    arguments: Option<String>,
}
1464
/// Token counts reported by Chat Completions, used for both streaming and non-streaming
/// responses.
///
/// Missing counts are mapped to 0 when converted to `RequestUsage`.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u64>,
    completion_tokens: Option<u64>,
}
1470
/// One parsed SSE `data:` payload of a streamed chat completion.
///
/// Only the first choice is read.
// NOTE(review): with `stream_options.include_usage`, OpenAI appears to send `usage` on a
// trailing chunk whose `choices` array is empty — confirm against the API reference.
// `choices` is still required here, so a payload without the field fails to parse.
#[derive(Debug, Deserialize)]
struct OpenAIChatStreamResponse {
    choices: Vec<OpenAIChatStreamChoice>,
    usage: Option<OpenAIUsage>,
}
1476
/// One choice within a streamed chat chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChatStreamChoice {
    /// Incremental content for this chunk.
    delta: OpenAIChatStreamDelta,
    /// Present on the chunk that ends the choice.
    finish_reason: Option<String>,
}
1482
/// Incremental delta of a streamed chat choice: a text fragment and/or tool-call fragments.
#[derive(Debug, Deserialize)]
struct OpenAIChatStreamDelta {
    content: Option<String>,
    tool_calls: Option<Vec<OpenAIStreamToolCall>>,
}
1488
/// Fragment of a streamed tool call; fragments are concatenated into a full call.
#[derive(Debug, Deserialize)]
struct OpenAIStreamToolCall {
    // NOTE(review): OpenAI appears to send `id` only on a call's first fragment and `index`
    // on every fragment — confirm before keying accumulation on `id` alone.
    id: Option<String>,
    index: Option<usize>,
    function: OpenAIStreamToolFunction,
}
1495
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
struct OpenAIStreamToolFunction {
    /// When present, overwrites the accumulated name.
    name: Option<String>,
    /// Partial argument text, appended to the accumulated argument string.
    arguments: Option<String>,
}
1501
/// A streamed chat tool call being assembled from its fragments.
#[derive(Debug)]
struct ToolAccumulator {
    /// Call id, derived via `normalize_stream_tool_call_id`.
    id: String,
    /// Last name seen; defaults to `"tool"` when the call is emitted.
    name: Option<String>,
    /// Concatenated argument fragments; decoded as JSON when the call is emitted.
    arguments: String,
}
1508
/// One SSE `data:` payload from a Responses API stream.
///
/// Only the fields read by the stream handler are declared. Each is populated only for the
/// event types that carry it.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesStreamEvent {
    /// Event type, e.g. `response.output_text.delta` or `response.completed`.
    #[serde(rename = "type")]
    kind: String,
    /// Full response object on lifecycle events; read for usage and failure details.
    response: Option<Value>,
    /// Output item on `response.output_item.*` events.
    item: Option<Value>,
    /// Text fragment on delta events.
    delta: Option<String>,
}
1517
1518fn parse_responses_stream_usage(value: &Value) -> Option<RequestUsage> {
1519 let usage = value.get("usage")?;
1520 let input_tokens = usage
1521 .get("input_tokens")
1522 .and_then(|v| v.as_u64())
1523 .unwrap_or(0);
1524 let output_tokens = usage
1525 .get("output_tokens")
1526 .and_then(|v| v.as_u64())
1527 .unwrap_or(0);
1528 Some(RequestUsage {
1529 input_tokens,
1530 output_tokens,
1531 ..Default::default()
1532 })
1533}
1534
1535fn parse_responses_stream_tool_call(item: &Value) -> Option<ToolCallPart> {
1536 let item_type = item.get("type").and_then(|v| v.as_str())?;
1537 if item_type != "function_call" {
1538 return None;
1539 }
1540 let name = item
1541 .get("name")
1542 .and_then(|value| value.as_str())
1543 .unwrap_or("tool")
1544 .to_string();
1545 let call_id = item
1546 .get("call_id")
1547 .and_then(|value| value.as_str())
1548 .map(str::to_string)
1549 .or_else(|| {
1550 item.get("id")
1551 .and_then(|value| value.as_str())
1552 .map(str::to_string)
1553 });
1554 let arguments = item.get("arguments").cloned().unwrap_or(Value::Null);
1555 let args = match arguments {
1556 Value::String(value) => {
1557 serde_json::from_str::<Value>(&value).unwrap_or(Value::String(value))
1558 }
1559 other => other,
1560 };
1561 Some(ToolCallPart {
1562 id: normalize_tool_call_id(call_id),
1563 name,
1564 arguments: args,
1565 })
1566}
1567
/// OpenAI model that routes each request to either the Chat Completions or the Responses
/// backend. The backend is chosen per request by `select_api`.
#[derive(Clone, Debug)]
pub struct OpenAIUnifiedModel {
    model: String,
    chat: OpenAIChatModel,
    responses: OpenAIResponsesModel,
    // Computed once from the model name via `is_responses_only_model`.
    responses_only: bool,
    // Computed once from the model name via `prefers_responses`; only affects non-streaming.
    prefer_responses: bool,
}
1576
1577impl OpenAIUnifiedModel {
1578 pub fn new(
1579 model: impl Into<String>,
1580 api_key: String,
1581 base_url: Url,
1582 settings: Option<ModelSettings>,
1583 ) -> Self {
1584 let model = model.into();
1585 let responses_only = is_responses_only_model(&model);
1586 let prefer_responses = prefers_responses(&model);
1587 Self {
1588 chat: OpenAIChatModel::new(
1589 model.clone(),
1590 api_key.clone(),
1591 base_url.clone(),
1592 settings.clone(),
1593 ),
1594 responses: OpenAIResponsesModel::new(model.clone(), api_key, base_url, settings),
1595 model,
1596 responses_only,
1597 prefer_responses,
1598 }
1599 }
1600
1601 fn select_api(
1602 &self,
1603 messages: &[ModelMessage],
1604 stream: bool,
1605 ) -> Result<OpenAIApiMode, ModelError> {
1606 if contains_audio(messages) {
1607 if self.responses_only {
1608 return Err(ModelError::Unsupported(
1609 "OpenAI Responses API does not support audio input".to_string(),
1610 ));
1611 }
1612 return Ok(OpenAIApiMode::Chat);
1613 }
1614 if stream {
1615 if self.responses_only {
1616 return Ok(OpenAIApiMode::Responses);
1617 }
1618 return Ok(OpenAIApiMode::Chat);
1619 }
1620 if self.prefer_responses || self.responses_only {
1621 Ok(OpenAIApiMode::Responses)
1622 } else {
1623 Ok(OpenAIApiMode::Chat)
1624 }
1625 }
1626}
1627
/// Which OpenAI HTTP API a unified-model request is routed to.
#[derive(Clone, Copy, Debug)]
enum OpenAIApiMode {
    /// Chat Completions (`OpenAIChatModel`).
    Chat,
    /// Responses API (`OpenAIResponsesModel`).
    Responses,
}
1633
1634#[async_trait]
1635impl Model for OpenAIUnifiedModel {
1636 fn name(&self) -> &str {
1637 &self.model
1638 }
1639
1640 async fn request(
1641 &self,
1642 messages: &[ModelMessage],
1643 settings: Option<&ModelSettings>,
1644 params: &ModelRequestParameters,
1645 ) -> Result<ModelResponse, ModelError> {
1646 match self.select_api(messages, false)? {
1647 OpenAIApiMode::Chat => self.chat.request(messages, settings, params).await,
1648 OpenAIApiMode::Responses => self.responses.request(messages, settings, params).await,
1649 }
1650 }
1651
1652 async fn request_stream(
1653 &self,
1654 messages: &[ModelMessage],
1655 settings: Option<&ModelSettings>,
1656 params: &ModelRequestParameters,
1657 ) -> Result<ModelStream, ModelError> {
1658 match self.select_api(messages, true)? {
1659 OpenAIApiMode::Chat => self.chat.request_stream(messages, settings, params).await,
1660 OpenAIApiMode::Responses => {
1661 self.responses
1662 .request_stream(messages, settings, params)
1663 .await
1664 }
1665 }
1666 }
1667}
1668
/// Client for the OpenAI Responses API (`POST {base_url}/responses`).
#[derive(Clone, Debug)]
pub struct OpenAIResponsesModel {
    model: String,
    // Sent as a bearer token on every request.
    api_key: String,
    base_url: Url,
    client: Client,
    // Default body fields merged into every request by `build_body`.
    default_settings: Option<ModelSettings>,
}
1677
1678impl OpenAIResponsesModel {
1679 pub fn new(
1680 model: impl Into<String>,
1681 api_key: String,
1682 base_url: Url,
1683 settings: Option<ModelSettings>,
1684 ) -> Self {
1685 Self {
1686 model: model.into(),
1687 api_key,
1688 base_url,
1689 client: Client::new(),
1690 default_settings: settings,
1691 }
1692 }
1693
1694 fn endpoint(&self) -> Result<Url, ModelError> {
1695 join_path(&self.base_url, "responses")
1696 }
1697
1698 fn filename_for_media_type(media_type: &str) -> String {
1699 let ext = match media_type {
1700 "application/pdf" => "pdf",
1701 "text/plain" => "txt",
1702 "text/markdown" => "md",
1703 "application/json" => "json",
1704 _ => "bin",
1705 };
1706 format!("file.{ext}")
1707 }
1708
1709 fn make_input_messages(&self, messages: &[ModelMessage]) -> Result<Vec<Value>, ModelError> {
1710 let mut out = Vec::new();
1711 for message in messages {
1712 match message {
1713 ModelMessage::Request(req) => {
1714 if let Some(instructions) = req
1715 .instructions
1716 .as_ref()
1717 .filter(|value| !value.trim().is_empty())
1718 {
1719 out.push(json!({"role": "system", "content": instructions}));
1720 }
1721 for part in &req.parts {
1722 match part {
1723 ModelRequestPart::SystemPrompt(prompt) => {
1724 out.push(json!({"role": "system", "content": prompt.content}))
1725 }
1726 ModelRequestPart::UserPrompt(prompt) => {
1727 let content = self.convert_user_content(&prompt.content)?;
1728 out.push(json!({"role": "user", "content": content}))
1729 }
1730 ModelRequestPart::ToolReturn(tool_return) => {
1731 let content = tool_return_content(&tool_return.content);
1732 out.push(json!({
1733 "type": "function_call_output",
1734 "call_id": normalize_tool_call_id_str(&tool_return.tool_call_id),
1735 "output": content,
1736 }))
1737 }
1738 ModelRequestPart::RetryPrompt(retry) => {
1739 if retry.tool_name.is_some() {
1740 out.push(json!({
1741 "type": "function_call_output",
1742 "call_id": normalize_tool_call_id(retry.tool_call_id.clone()),
1743 "output": retry.content,
1744 }));
1745 } else {
1746 out.push(json!({
1747 "role": "user",
1748 "content": [ { "type": "input_text", "text": retry.content } ],
1749 }));
1750 }
1751 }
1752 }
1753 }
1754 }
1755 ModelMessage::Response(res) => {
1756 let provider_items: Vec<Value> = res
1757 .parts
1758 .iter()
1759 .filter_map(|part| match part {
1760 ModelResponsePart::ProviderItem(item)
1761 if item.provider == "openai_responses" =>
1762 {
1763 Some(item.payload.clone())
1764 }
1765 _ => None,
1766 })
1767 .collect();
1768 if !provider_items.is_empty() {
1769 out.extend(provider_items);
1770 continue;
1771 }
1772 if let Some(text) = res.text() {
1773 out.push(json!({"role": "assistant", "content": text}));
1774 }
1775 for call in res.tool_calls() {
1776 let args = tool_call_arguments(&call.arguments);
1777 out.push(json!({
1778 "type": "function_call",
1779 "call_id": normalize_tool_call_id_str(&call.id),
1780 "name": call.name,
1781 "arguments": args,
1782 }));
1783 }
1784 }
1785 }
1786 }
1787 Ok(out)
1788 }
1789
1790 fn convert_user_content(&self, content: &[UserContent]) -> Result<Value, ModelError> {
1791 let mut parts = Vec::new();
1792 for item in content {
1793 match item {
1794 UserContent::Text(text) => parts.push(json!({"type": "input_text", "text": text})),
1795 UserContent::Image(image) => parts.push(json!({
1796 "type": "input_image",
1797 "image_url": image.url
1798 })),
1799 UserContent::Binary(BinaryContent { data, media_type }) => {
1800 if media_type.starts_with("image/") {
1801 let encoded = general_purpose::STANDARD.encode(data);
1802 let data_url = format!("data:{};base64,{}", media_type, encoded);
1803 parts.push(json!({
1804 "type": "input_image",
1805 "image_url": data_url
1806 }));
1807 } else if media_type == "application/pdf" {
1808 let encoded = general_purpose::STANDARD.encode(data);
1809 let data_url = format!("data:{};base64,{}", media_type, encoded);
1810 parts.push(json!({
1811 "type": "input_file",
1812 "file_data": data_url,
1813 "filename": Self::filename_for_media_type(media_type),
1814 }));
1815 } else if is_text_like_media_type(media_type) {
1816 match std::str::from_utf8(data) {
1817 Ok(text) => parts.push(json!({"type": "input_text", "text": text})),
1818 Err(_) => parts.push(json!({
1819 "type": "input_text",
1820 "text": format!("[binary content: {} bytes]", data.len())
1821 })),
1822 }
1823 } else {
1824 parts.push(json!({
1825 "type": "input_text",
1826 "text": format!("[binary content: {} bytes]", data.len())
1827 }))
1828 }
1829 }
1830 UserContent::Document(doc) => {
1831 if let Some((media_type, data)) = parse_data_url_base64(&doc.url) {
1832 let data_url = format!("data:{};base64,{}", media_type, data);
1833 parts.push(json!({
1834 "type": "input_file",
1835 "file_data": data_url,
1836 "filename": Self::filename_for_media_type(&media_type),
1837 }));
1838 } else {
1839 parts.push(json!({
1840 "type": "input_file",
1841 "file_url": doc.url
1842 }));
1843 }
1844 }
1845 UserContent::Audio(audio) => parts.push(json!({
1846 "type": "input_text",
1847 "text": format!("[audio: {}]", audio.url)
1848 })),
1849 UserContent::Video(video) => parts.push(json!({
1850 "type": "input_text",
1851 "text": format!("[video: {}]", video.url)
1852 })),
1853 }
1854 }
1855 Ok(Value::Array(parts))
1856 }
1857
1858 fn build_body(
1859 &self,
1860 messages: &[ModelMessage],
1861 params: &ModelRequestParameters,
1862 stream: bool,
1863 ) -> Result<Value, ModelError> {
1864 let mut body = Map::new();
1865 body.insert("model".to_string(), Value::String(self.model.clone()));
1866 body.insert(
1867 "input".to_string(),
1868 Value::Array(self.make_input_messages(messages)?),
1869 );
1870
1871 if !params.function_tools.is_empty() {
1872 let tools = params
1873 .function_tools
1874 .iter()
1875 .map(|tool| {
1876 let (schema, _strict_ok) =
1877 transform_openai_schema(&tool.parameters_json_schema, None);
1878 json!({
1879 "type": "function",
1880 "name": tool.name,
1881 "description": tool.description,
1882 "parameters": schema,
1883 })
1884 })
1885 .collect();
1886 body.insert("tools".to_string(), Value::Array(tools));
1887 if params.function_tools.iter().any(|tool| tool.sequential) {
1888 body.insert("parallel_tool_calls".to_string(), Value::Bool(false));
1889 }
1890 }
1891
1892 if params.output_mode == OutputMode::JsonSchema
1893 && let Some(schema) = params.output_schema.clone()
1894 {
1895 let strict = !params.allow_text_output;
1896 let (schema, _strict_ok) = transform_openai_schema(&schema, Some(strict));
1897 body.insert(
1898 "text".to_string(),
1899 json!({
1900 "format": {
1901 "type": "json_schema",
1902 "name": "output",
1903 "schema": schema,
1904 "strict": strict,
1905 }
1906 }),
1907 );
1908 }
1909
1910 if stream {
1911 body.insert("stream".to_string(), Value::Bool(true));
1912 }
1913
1914 if let Some(settings) = &self.default_settings {
1915 for (key, value) in settings {
1916 if key == "max_tokens" {
1917 body.insert("max_output_tokens".to_string(), value.clone());
1918 continue;
1919 }
1920 body.insert(key.clone(), value.clone());
1921 }
1922 }
1923
1924 Ok(Value::Object(body))
1925 }
1926
1927 fn build_request(
1928 &self,
1929 messages: &[ModelMessage],
1930 settings: Option<&ModelSettings>,
1931 params: &ModelRequestParameters,
1932 stream: bool,
1933 ) -> Result<OpenAIRequest, ModelError> {
1934 let mut body = self.build_body(messages, params, stream)?;
1935 if let Some(settings) = settings
1936 && let Value::Object(map) = &mut body
1937 {
1938 for (key, value) in settings {
1939 if key == "max_tokens" {
1940 map.insert("max_output_tokens".to_string(), value.clone());
1941 continue;
1942 }
1943 map.insert(key.clone(), value.clone());
1944 }
1945 }
1946
1947 Ok(OpenAIRequest { body })
1948 }
1949}
1950
1951#[async_trait]
1952impl Model for OpenAIResponsesModel {
1953 fn name(&self) -> &str {
1954 &self.model
1955 }
1956
1957 async fn request(
1958 &self,
1959 messages: &[ModelMessage],
1960 settings: Option<&ModelSettings>,
1961 params: &ModelRequestParameters,
1962 ) -> Result<ModelResponse, ModelError> {
1963 tracing::debug!(
1964 model = %self.model,
1965 tool_count = params.function_tools.len(),
1966 output_schema = params.output_schema.is_some(),
1967 "OpenAI responses request"
1968 );
1969 let request = self.build_request(messages, settings, params, false)?;
1970
1971 let response = self
1972 .client
1973 .post(self.endpoint()?)
1974 .bearer_auth(&self.api_key)
1975 .json(&request.body)
1976 .send()
1977 .await
1978 .map_err(|e| map_reqwest_error("OpenAI Responses", e))?;
1979
1980 let status = response.status();
1981 if !status.is_success() {
1982 let body = response.text().await.unwrap_or_default();
1983 tracing::error!(
1984 status = status.as_u16(),
1985 model = %self.model,
1986 body = %truncate_error_body(&body),
1987 "OpenAI responses request failed"
1988 );
1989 return Err(ModelError::HttpStatus {
1990 status: status.as_u16(),
1991 });
1992 }
1993
1994 let body: OpenAIResponsesResponse = response.json().await.map_err(|e| {
1995 tracing::error!(
1996 error = %e,
1997 model = %self.model,
1998 "OpenAI responses parse failed"
1999 );
2000 ModelError::Provider(format!("OpenAI response parse failed: {e}"))
2001 })?;
2002
2003 let mut parts = Vec::new();
2004 for item in body.output {
2005 parts.push(ModelResponsePart::ProviderItem(ProviderItemPart {
2006 provider: "openai_responses".to_string(),
2007 payload: item.clone(),
2008 }));
2009
2010 if let Some(item_type) = item.get("type").and_then(|value| value.as_str()) {
2011 match item_type {
2012 "message" => {
2013 if let Some(content) =
2014 item.get("content").and_then(|value| value.as_array())
2015 {
2016 for part in content {
2017 if part.get("type").and_then(|value| value.as_str())
2018 == Some("output_text")
2019 && let Some(text) =
2020 part.get("text").and_then(|value| value.as_str())
2021 {
2022 parts.push(ModelResponsePart::Text(TextPart {
2023 content: text.to_string(),
2024 }));
2025 }
2026 }
2027 }
2028 }
2029 "function_call" => {
2030 let name = item
2031 .get("name")
2032 .and_then(|value| value.as_str())
2033 .unwrap_or("tool")
2034 .to_string();
2035 let call_id = item
2036 .get("call_id")
2037 .and_then(|value| value.as_str())
2038 .map(str::to_string);
2039 let arguments = item.get("arguments").cloned().unwrap_or(Value::Null);
2040 let args = match arguments {
2041 Value::String(value) => serde_json::from_str::<Value>(&value)
2042 .unwrap_or(Value::String(value)),
2043 other => other,
2044 };
2045 parts.push(ModelResponsePart::ToolCall(ToolCallPart {
2046 id: normalize_tool_call_id(call_id),
2047 name,
2048 arguments: args,
2049 }));
2050 }
2051 _ => {}
2052 }
2053 }
2054 }
2055
2056 let usage = body.usage.map(|usage| RequestUsage {
2057 input_tokens: usage.input_tokens.unwrap_or(0),
2058 output_tokens: usage.output_tokens.unwrap_or(0),
2059 ..Default::default()
2060 });
2061
2062 Ok(ModelResponse {
2063 parts,
2064 usage,
2065 model_name: body.model.or_else(|| Some(self.model.clone())),
2066 finish_reason: body.finish_reason,
2067 })
2068 }
2069
2070 async fn request_stream(
2071 &self,
2072 messages: &[ModelMessage],
2073 settings: Option<&ModelSettings>,
2074 params: &ModelRequestParameters,
2075 ) -> Result<ModelStream, ModelError> {
2076 tracing::debug!(
2077 model = %self.model,
2078 tool_count = params.function_tools.len(),
2079 output_schema = params.output_schema.is_some(),
2080 "OpenAI responses stream request"
2081 );
2082 let request = self.build_request(messages, settings, params, true)?;
2083
2084 let response = self
2085 .client
2086 .post(self.endpoint()?)
2087 .bearer_auth(&self.api_key)
2088 .json(&request.body)
2089 .send()
2090 .await
2091 .map_err(|e| map_reqwest_error("OpenAI Responses stream", e))?;
2092
2093 let status = response.status();
2094 if !status.is_success() {
2095 let body = response.text().await.unwrap_or_default();
2096 tracing::error!(
2097 status = status.as_u16(),
2098 model = %self.model,
2099 body = %truncate_error_body(&body),
2100 "OpenAI responses stream request failed"
2101 );
2102 return Err(ModelError::HttpStatus {
2103 status: status.as_u16(),
2104 });
2105 }
2106
2107 let mut event_stream = response.bytes_stream().eventsource();
2108 let model_name = self.model.clone();
2109
2110 let s = try_stream! {
2111 while let Some(event) = event_stream.next().await {
2112 let event = event.map_err(|e| {
2113 tracing::error!(error = %e, model = %model_name, "OpenAI responses stream error");
2114 ModelError::Provider(format!("OpenAI responses stream error: {e}"))
2115 })?;
2116 let data = event.data;
2117 if data.trim() == "[DONE]" {
2118 break;
2119 }
2120 let event: OpenAIResponsesStreamEvent = serde_json::from_str(&data).map_err(|e| {
2121 tracing::error!(error = %e, model = %model_name, "OpenAI responses stream parse error");
2122 ModelError::Provider(format!("OpenAI responses stream parse error: {e}"))
2123 })?;
2124
2125 match event.kind.as_str() {
2126 "response.output_text.delta" => {
2127 if let Some(delta) = event.delta {
2128 yield StreamChunk {
2129 text_delta: Some(delta),
2130 tool_call: None,
2131 finish_reason: None,
2132 usage: None,
2133 };
2134 }
2135 }
2136 "response.output_item.done" => {
2137 if let Some(item) = event.item
2138 && let Some(call) = parse_responses_stream_tool_call(&item)
2139 {
2140 yield StreamChunk {
2141 text_delta: None,
2142 tool_call: Some(call),
2143 finish_reason: None,
2144 usage: None,
2145 };
2146 }
2147 }
2148 "response.completed" | "response.done" => {
2149 let usage = event
2150 .response
2151 .as_ref()
2152 .and_then(parse_responses_stream_usage);
2153 yield StreamChunk {
2154 text_delta: None,
2155 tool_call: None,
2156 finish_reason: Some("stop".to_string()),
2157 usage,
2158 };
2159 }
2160 "response.failed" => {
2161 let detail = event
2162 .response
2163 .map(|value| value.to_string())
2164 .unwrap_or_else(|| "response.failed".to_string());
2165 Err(ModelError::Provider(format!(
2166 "OpenAI responses stream failed: {detail}"
2167 )))?;
2168 }
2169 _ => {}
2170 }
2171 }
2172 };
2173
2174 Ok(Box::pin(s))
2175 }
2176}
2177
/// Body of a non-streaming Responses API reply.
///
/// Only the fields consumed by `request` are declared; other fields are ignored during
/// deserialization.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesResponse {
    /// Raw output items, kept as JSON so they can be stored and replayed as provider items.
    output: Vec<Value>,
    usage: Option<OpenAIResponsesUsage>,
    /// Server-reported model; `request` falls back to the configured name when absent.
    model: Option<String>,
    // NOTE(review): the Responses API appears to report completion via `status` /
    // `incomplete_details` rather than `finish_reason`, so this is probably always `None`
    // against api.openai.com — confirm. The rename below is a no-op (same name).
    #[serde(rename = "finish_reason")]
    finish_reason: Option<String>,
}
2186
/// Token counts from a non-streaming Responses API reply.
///
/// Missing counts are mapped to 0 when converted to `RequestUsage`.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesUsage {
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
}