use ai_sdk_core::{CallOptions, Content, FinishReason, LanguageModel, Usage};
use ai_sdk_provider::language_model::{
    self, AssistantContentPart, Message, ResponseInfo, SourcePart, SourceType, TextPart,
    ToolCallPart, UserContentPart,
};
use ai_sdk_provider::{GenerateResponse, Result, StreamPart, StreamResponse};
use ai_sdk_provider_utils::merge_headers_reqwest;
use async_stream::stream;
use async_trait::async_trait;
use futures::stream::StreamExt;
use log::warn;
use reqwest::Client;
use std::collections::HashMap;

use crate::error::OpenAIError;
use crate::model_detection::{
    is_reasoning_model, is_search_preview_model, supports_flex_processing,
};
use crate::openai_config::{OpenAIConfig, OpenAIUrlOptions};

use super::{generate_source_id, options::OpenAIChatOptions};

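/// A `LanguageModel` implementation backed by the OpenAI Chat Completions API.
///
/// Holds the target model ID, the provider configuration (URL and header
/// factories), and a shared `reqwest` client used by both `do_generate` and
/// `do_stream`.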
pub struct OpenAIChatModel {
    model_id: String,
    config: OpenAIConfig,
    client: Client,
}

impl OpenAIChatModel {
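    /// Creates a chat model for the given model ID and provider config.
    ///
    /// A minimal usage sketch (not compiled here; it assumes `OpenAIConfig`
    /// provides a `Default` impl or an equivalent constructor):
    ///
    /// ```ignore
    /// let model = OpenAIChatModel::new("gpt-4o", OpenAIConfig::default());
    /// ```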
    pub fn new(model_id: impl Into<String>, config: impl Into<OpenAIConfig>) -> Self {
        Self {
            model_id: model_id.into(),
            config: config.into(),
            client: Client::new(),
        }
    }

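    /// Converts the provider-agnostic prompt into OpenAI chat messages.
    ///
    /// System prompts are sent under the `developer` role for reasoning
    /// models and `system` otherwise; user file parts are converted into
    /// multimodal content parts, and tool results become `tool` messages.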
    fn convert_prompt_to_messages(&self, prompt: &[Message]) -> Vec<crate::api_types::ChatMessage> {
        let mut openai_messages = Vec::new();

        for msg in prompt {
            match msg {
                Message::System { content } => {
                    let role = if is_reasoning_model(&self.model_id) {
                        "developer"
                    } else {
                        "system"
                    };
                    openai_messages.push(crate::api_types::ChatMessage {
                        role: role.into(),
                        content: Some(crate::api_types::ChatMessageContent::Text(content.clone())),
                        tool_calls: None,
                        tool_call_id: None,
                        annotations: None,
                    });
                }
                Message::User { content } => {
                    let has_files = content
                        .iter()
                        .any(|part| matches!(part, UserContentPart::File { .. }));

                    if has_files {
                        // Multimodal message: send structured content parts.
                        let mut openai_content = Vec::new();

                        for part in content {
                            match part {
                                UserContentPart::Text { text } => {
                                    openai_content.push(crate::multimodal::OpenAIContentPart::Text {
                                        text: text.clone(),
                                    });
                                }
                                UserContentPart::File { data, media_type } => {
                                    if media_type.starts_with("image/") {
                                        match crate::multimodal::convert_image_part(data, media_type) {
                                            Ok(part) => openai_content.push(part),
                                            Err(e) => {
                                                warn!("Failed to convert image: {}", e);
                                            }
                                        }
                                    } else if media_type.starts_with("audio/") {
                                        match crate::multimodal::convert_audio_part(data, media_type) {
                                            Ok(part) => openai_content.push(part),
                                            Err(e) => {
                                                warn!("Failed to convert audio: {}", e);
                                            }
                                        }
                                    } else {
                                        warn!("Unsupported media type: {}", media_type);
                                    }
                                }
                            }
                        }

                        openai_messages.push(crate::api_types::ChatMessage {
                            role: "user".into(),
                            content: Some(crate::api_types::ChatMessageContent::Parts(
                                openai_content,
                            )),
                            tool_calls: None,
                            tool_call_id: None,
                            annotations: None,
                        });
                    } else {
                        // Text-only message: join the text parts into one string.
                        let text_content = content
                            .iter()
                            .filter_map(|part| match part {
                                UserContentPart::Text { text } => Some(text.clone()),
                                UserContentPart::File { .. } => None,
                            })
                            .collect::<Vec<_>>()
                            .join("\n");

                        openai_messages.push(crate::api_types::ChatMessage {
                            role: "user".into(),
                            content: Some(crate::api_types::ChatMessageContent::Text(text_content)),
                            tool_calls: None,
                            tool_call_id: None,
                            annotations: None,
                        });
                    }
                }
                Message::Assistant { content } => {
                    let mut text_content = String::new();
                    let mut tool_calls = Vec::new();

                    for part in content {
                        match part {
                            AssistantContentPart::Text(text_part) => {
                                text_content.push_str(&text_part.text);
                            }
                            AssistantContentPart::ToolCall(tool_call) => {
                                tool_calls.push(crate::api_types::OpenAIToolCall {
                                    id: tool_call.tool_call_id.clone(),
                                    r#type: "function".to_string(),
                                    function: crate::api_types::OpenAIFunctionCall {
                                        name: tool_call.tool_name.clone(),
                                        arguments: tool_call.input.clone(),
                                    },
                                });
                            }
                            // Other assistant parts have no chat-completions
                            // equivalent and are skipped.
                            _ => {}
                        }
                    }

                    openai_messages.push(crate::api_types::ChatMessage {
                        role: "assistant".into(),
                        content: if text_content.is_empty() {
                            None
                        } else {
                            Some(crate::api_types::ChatMessageContent::Text(text_content))
                        },
                        tool_calls: if tool_calls.is_empty() {
                            None
                        } else {
                            Some(tool_calls)
                        },
                        tool_call_id: None,
                        annotations: None,
                    });
                }
                Message::Tool { content } => {
                    // Each tool result becomes its own `tool` message, keyed by call ID.
                    for tool_result in content {
                        openai_messages.push(crate::api_types::ChatMessage {
                            role: "tool".into(),
                            content: Some(crate::api_types::ChatMessageContent::Text(
                                serde_json::to_string(&tool_result.output).unwrap_or_default(),
                            )),
                            tool_calls: None,
                            tool_call_id: Some(tool_result.tool_call_id.clone()),
                            annotations: None,
                        });
                    }
                }
            }
        }

        openai_messages
    }

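    /// Converts function tools into OpenAI's tool format. Provider-defined
    /// tools have no Chat Completions equivalent and are dropped.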
    fn convert_tools(&self, tools: &[language_model::Tool]) -> Vec<crate::api_types::OpenAITool> {
        tools
            .iter()
            .filter_map(|tool| match tool {
                language_model::Tool::Function(function_tool) => {
                    Some(crate::api_types::OpenAITool {
                        r#type: "function".to_string(),
                        function: crate::api_types::OpenAIFunction {
                            name: function_tool.name.clone(),
                            description: function_tool.description.clone(),
                            parameters: function_tool.input_schema.clone(),
                        },
                    })
                }
                language_model::Tool::ProviderDefined(_) => None,
            })
            .collect()
    }

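    /// Maps the generic tool-choice setting onto OpenAI's representation:
    /// `auto`, `none`, and `required` as plain strings, or a specific
    /// function selected by name.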
    fn convert_tool_choice(
        &self,
        tool_choice: &language_model::ToolChoice,
    ) -> crate::api_types::OpenAIToolChoice {
        match tool_choice {
            language_model::ToolChoice::Auto => {
                crate::api_types::OpenAIToolChoice::String("auto".to_string())
            }
            language_model::ToolChoice::None => {
                crate::api_types::OpenAIToolChoice::String("none".to_string())
            }
            language_model::ToolChoice::Required => {
                crate::api_types::OpenAIToolChoice::String("required".to_string())
            }
            language_model::ToolChoice::Tool { tool_name } => {
                crate::api_types::OpenAIToolChoice::Specific {
                    r#type: "function".to_string(),
                    function: crate::api_types::OpenAIFunctionName {
                        name: tool_name.clone(),
                    },
                }
            }
        }
    }

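    /// Maps the generic response format onto OpenAI's. JSON with a schema
    /// uses strict `json_schema` mode; JSON without one falls back to
    /// `json_object`.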
    fn convert_response_format(
        &self,
        response_format: &language_model::ResponseFormat,
    ) -> crate::api_types::OpenAIResponseFormat {
        match response_format {
            language_model::ResponseFormat::Text => crate::api_types::OpenAIResponseFormat::Text,
            language_model::ResponseFormat::Json {
                schema,
                name,
                description,
            } => {
                if let Some(schema) = schema {
                    crate::api_types::OpenAIResponseFormat::JsonSchema {
                        json_schema: crate::api_types::OpenAIJsonSchema {
                            name: name.clone().unwrap_or_else(|| "response".to_string()),
                            description: description.clone(),
                            schema: schema.clone(),
                            strict: Some(true),
                        },
                    }
                } else {
                    crate::api_types::OpenAIResponseFormat::JsonObject
                }
            }
        }
    }

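    /// Maps an OpenAI finish-reason string onto the SDK's `FinishReason`;
    /// unrecognized or missing values map to `Unknown`.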
    fn map_finish_reason(&self, reason: Option<&str>) -> FinishReason {
        match reason {
            Some("stop") => FinishReason::Stop,
            Some("length") => FinishReason::Length,
            Some("content_filter") => FinishReason::ContentFilter,
            Some("tool_calls") => FinishReason::ToolCalls,
            _ => FinishReason::Unknown,
        }
    }
}

#[async_trait]
impl LanguageModel for OpenAIChatModel {
    fn provider(&self) -> &str {
        "openai"
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }

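    /// Declares that HTTP(S) image URLs may be passed through to the API
    /// as-is rather than being downloaded and inlined first.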
    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        let mut urls = HashMap::new();
        urls.insert("image/*".into(), vec![r"^https?://.*$".into()]);
        urls
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        let openai_opts = OpenAIChatOptions::from_provider_options(&options.provider_options);

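        // Search-preview models reject the temperature parameter, so drop it
        // (with a warning) rather than fail the request.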
        let temperature = if is_search_preview_model(&self.model_id) {
            if options.temperature.is_some() {
                warn!(
                    "Temperature is not supported for search preview model '{}'. Removing temperature setting.",
                    self.model_id
                );
            }
            None
        } else {
            options.temperature
        };

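        // Reasoning models only accept `max_completion_tokens`; other models
        // use the classic `max_tokens` field.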
        let (max_tokens, max_completion_tokens) = if is_reasoning_model(&self.model_id) {
            let mct = openai_opts
                .max_completion_tokens
                .or(options.max_output_tokens);
            (None, mct)
        } else {
            (options.max_output_tokens, openai_opts.max_completion_tokens)
        };

        let service_tier = match &openai_opts.service_tier {
            Some(tier) if tier == "flex" && !supports_flex_processing(&self.model_id) => {
                warn!(
                    "Flex processing is not supported for model '{}'. Supported models: o3, o4-mini, gpt-5. Service tier will be ignored.",
                    self.model_id
                );
                None
            }
            tier => tier.clone(),
        };

        let request = crate::api_types::ChatCompletionRequest {
            model: self.model_id.clone(),
            messages: self.convert_prompt_to_messages(&options.prompt),
            temperature,
            max_tokens,
            stream: Some(false),
            tools: options.tools.as_ref().map(|t| self.convert_tools(t)),
            tool_choice: options
                .tool_choice
                .as_ref()
                .map(|tc| self.convert_tool_choice(tc)),
            response_format: options
                .response_format
                .as_ref()
                .map(|rf| self.convert_response_format(rf)),
            stream_options: None,
            logit_bias: openai_opts.logit_bias,
            logprobs: openai_opts.logprobs,
            top_logprobs: None,
            user: openai_opts.user,
            parallel_tool_calls: openai_opts.parallel_tool_calls,
            reasoning_effort: openai_opts.reasoning_effort,
            max_completion_tokens,
            store: openai_opts.store,
            metadata: openai_opts.metadata,
            prediction: openai_opts.prediction,
            service_tier,
            verbosity: openai_opts.verbosity,
            prompt_cache_key: openai_opts.prompt_cache_key,
            safety_identifier: openai_opts.safety_identifier,
        };

        let url = (self.config.url)(OpenAIUrlOptions {
            model_id: self.model_id.clone(),
            path: "/chat/completions".to_string(),
        });

        let response = self
            .client
            .post(url)
            .header("Content-Type", "application/json")
            .headers(merge_headers_reqwest(
                (self.config.headers)(),
                options.headers.as_ref(),
            ))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status_code = response.status().as_u16();
            // Avoid panicking if the error body cannot be read.
            let message = response.text().await.unwrap_or_default();
            return Err(crate::error::OpenAIError::ApiError {
                message,
                status_code: Some(status_code),
            }
            .into());
        }

        let headers: HashMap<String, String> = response
            .headers()
            .iter()
            .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
            .collect();

        let api_response: crate::api_types::ChatCompletionResponse = response.json().await?;

        // Guard against an empty `choices` array instead of panicking on [0].
        let choice = api_response.choices.first().ok_or_else(|| OpenAIError::ApiError {
            message: "Response contained no choices".to_string(),
            status_code: None,
        })?;

        let mut content = Vec::new();

        if let Some(message_content) = &choice.message.content {
            let text = match message_content {
                crate::api_types::ChatMessageContent::Text(s) => s.clone(),
                crate::api_types::ChatMessageContent::Parts(parts) => parts
                    .iter()
                    .filter_map(|part| match part {
                        crate::multimodal::OpenAIContentPart::Text { text } => Some(text.clone()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n"),
            };

            if !text.is_empty() {
                content.push(Content::Text(TextPart {
                    text,
                    provider_metadata: None,
                }));
            }
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                content.push(Content::ToolCall(ToolCallPart {
                    tool_call_id: tool_call.id.clone(),
                    tool_name: tool_call.function.name.clone(),
                    input: tool_call.function.arguments.clone(),
                    provider_executed: None,
                    dynamic: None,
                    provider_metadata: None,
                }));
            }
        }

        if let Some(annotations) = &choice.message.annotations {
            for annotation in annotations {
                content.push(Content::Source(SourcePart {
                    id: generate_source_id(),
                    source_type: SourceType::Url,
                    url: Some(annotation.url.clone()),
                    title: Some(annotation.title.clone()),
                    provider_metadata: None,
                }));
            }
        }

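        // Map OpenAI usage counters (including reasoning and cached-token
        // details, when present) onto the SDK's `Usage` struct.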
        let usage = api_response
            .usage
            .as_ref()
            .map(|u| Usage {
                input_tokens: Some(u.prompt_tokens),
                output_tokens: u.completion_tokens,
                total_tokens: Some(u.total_tokens),
                reasoning_tokens: u
                    .completion_tokens_details
                    .as_ref()
                    .and_then(|d| d.reasoning_tokens),
                cached_input_tokens: u
                    .prompt_tokens_details
                    .as_ref()
                    .and_then(|d| d.cached_tokens),
            })
            .unwrap_or_default();

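        // Prefer `ToolCalls` whenever the message carries tool calls,
        // regardless of the finish reason the API reported.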
        let finish_reason = if choice.message.tool_calls.is_some() {
            FinishReason::ToolCalls
        } else {
            self.map_finish_reason(choice.finish_reason.as_deref())
        };

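        // Collect OpenAI-specific extras (logprobs, predicted-output token
        // counts) under the "openai" provider-metadata key.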
        let provider_metadata = {
            let mut openai_metadata: HashMap<String, ai_sdk_provider::json_value::JsonValue> =
                HashMap::new();

            if let Some(logprobs) = &choice.logprobs {
                if let Some(content_logprobs) = logprobs.get("content") {
                    let json_value: ai_sdk_provider::json_value::JsonValue =
                        serde_json::from_value(content_logprobs.clone())
                            .unwrap_or(ai_sdk_provider::json_value::JsonValue::Null);
                    openai_metadata.insert("logprobs".to_string(), json_value);
                }
            }

            if let Some(usage_info) = &api_response.usage {
                if let Some(details) = &usage_info.completion_tokens_details {
                    if let Some(accepted) = details.accepted_prediction_tokens {
                        openai_metadata.insert(
                            "acceptedPredictionTokens".to_string(),
                            ai_sdk_provider::json_value::JsonValue::Number(accepted.into()),
                        );
                    }
                    if let Some(rejected) = details.rejected_prediction_tokens {
                        openai_metadata.insert(
                            "rejectedPredictionTokens".to_string(),
                            ai_sdk_provider::json_value::JsonValue::Number(rejected.into()),
                        );
                    }
                }
            }

            if openai_metadata.is_empty() {
                None
            } else {
                let mut metadata = HashMap::new();
                metadata.insert("openai".to_string(), openai_metadata);
                Some(metadata)
            }
        };

        let response_info = Some(ResponseInfo {
            headers: Some(headers),
            body: Some(serde_json::to_value(&api_response).unwrap_or(serde_json::json!({}))),
            id: Some(api_response.id.clone()),
            timestamp: Some({
                // `created` is Unix epoch seconds; render it as a full RFC 3339
                // UTC timestamp. Date math uses the standard civil-from-days
                // algorithm to avoid pulling in a date-time dependency.
                let secs = api_response.created as i64;
                let (days, rem) = (secs.div_euclid(86_400), secs.rem_euclid(86_400));
                let z = days + 719_468;
                let (era, doe) = (z.div_euclid(146_097), z.rem_euclid(146_097));
                let yoe = (doe - doe / 1460 + doe / 36_524 - doe / 146_096) / 365;
                let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
                let mp = (5 * doy + 2) / 153;
                let day = doy - (153 * mp + 2) / 5 + 1;
                let month = if mp < 10 { mp + 3 } else { mp - 9 };
                let year = yoe + era * 400 + i64::from(month <= 2);
                format!(
                    "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
                    year, month, day, rem / 3600, (rem % 3600) / 60, rem % 60
                )
            }),
            model_id: Some(api_response.model.clone()),
        });

        Ok(GenerateResponse {
            content,
            finish_reason,
            usage,
            provider_metadata,
            request: None,
            response: response_info,
            warnings: vec![],
        })
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        let openai_opts = OpenAIChatOptions::from_provider_options(&options.provider_options);

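        // Same parameter normalization as `do_generate`: search-preview models
        // reject `temperature`, reasoning models take `max_completion_tokens`,
        // and flex processing is gated on model support.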
        let temperature = if is_search_preview_model(&self.model_id) {
            if options.temperature.is_some() {
                warn!(
                    "Temperature is not supported for search preview model '{}'. Removing temperature setting.",
                    self.model_id
                );
            }
            None
        } else {
            options.temperature
        };

        let (max_tokens, max_completion_tokens) = if is_reasoning_model(&self.model_id) {
            let mct = openai_opts
                .max_completion_tokens
                .or(options.max_output_tokens);
            (None, mct)
        } else {
            (options.max_output_tokens, openai_opts.max_completion_tokens)
        };

        let service_tier = match &openai_opts.service_tier {
            Some(tier) if tier == "flex" && !supports_flex_processing(&self.model_id) => {
                warn!(
                    "Flex processing is not supported for model '{}'. Supported models: o3, o4-mini, gpt-5. Service tier will be ignored.",
                    self.model_id
                );
                None
            }
            tier => tier.clone(),
        };

        let request = crate::api_types::ChatCompletionRequest {
            model: self.model_id.clone(),
            messages: self.convert_prompt_to_messages(&options.prompt),
            temperature,
            max_tokens,
            stream: Some(true),
            tools: options.tools.as_ref().map(|t| self.convert_tools(t)),
            tool_choice: options
                .tool_choice
                .as_ref()
                .map(|tc| self.convert_tool_choice(tc)),
            response_format: options
                .response_format
                .as_ref()
                .map(|rf| self.convert_response_format(rf)),
            // Ask the API to attach usage totals to the final stream chunk.
            stream_options: Some(crate::api_types::StreamOptions {
                include_usage: true,
            }),
            logit_bias: openai_opts.logit_bias,
            logprobs: openai_opts.logprobs,
            top_logprobs: None,
            user: openai_opts.user,
            parallel_tool_calls: openai_opts.parallel_tool_calls,
            reasoning_effort: openai_opts.reasoning_effort,
            max_completion_tokens,
            store: openai_opts.store,
            metadata: openai_opts.metadata,
            prediction: openai_opts.prediction,
            service_tier,
            verbosity: openai_opts.verbosity,
            prompt_cache_key: openai_opts.prompt_cache_key,
            safety_identifier: openai_opts.safety_identifier,
        };

        let url = (self.config.url)(OpenAIUrlOptions {
            model_id: self.model_id.clone(),
            path: "/chat/completions".to_string(),
        });

        let response = self
            .client
            .post(url)
            .header("Content-Type", "application/json")
            .headers(merge_headers_reqwest(
                (self.config.headers)(),
                options.headers.as_ref(),
            ))
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            return Err(crate::error::OpenAIError::ApiError {
                message: format!("API returned status {}", status),
                status_code: Some(status.as_u16()),
            }
            .into());
        }

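        // Parse the body as server-sent events: buffer incoming bytes, split
        // on newlines, and decode each `data:` payload as a JSON chunk.
        // Tool-call arguments arrive as deltas and are accumulated until the
        // final finish reason is seen.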
        let stream_impl = stream! {
            let mut byte_stream = response.bytes_stream();
            let mut buffer = String::new();
            let mut tool_calls: Vec<crate::api_types::OpenAIToolCall> = Vec::new();
            let mut accumulated_usage: Option<Usage> = None;
            let mut last_finish_reason: Option<FinishReason> = None;

            while let Some(chunk_result) = byte_stream.next().await {
                match chunk_result {
                    Ok(bytes) => {
                        buffer.push_str(&String::from_utf8_lossy(&bytes));

                        // Process every complete line currently in the buffer.
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer.drain(..line_end + 1);

                            if let Some(data) = line.strip_prefix("data: ") {
                                // The server ends the stream after "[DONE]".
                                if data == "[DONE]" {
                                    break;
                                }

                                if let Ok(chunk) = serde_json::from_str::<crate::api_types::ChatCompletionChunk>(data) {
                                    if let Some(usage_info) = &chunk.usage {
                                        accumulated_usage = Some(Usage {
                                            input_tokens: Some(usage_info.prompt_tokens),
                                            output_tokens: usage_info.completion_tokens,
                                            total_tokens: Some(usage_info.total_tokens),
                                            reasoning_tokens: usage_info
                                                .completion_tokens_details
                                                .as_ref()
                                                .and_then(|d| d.reasoning_tokens),
                                            cached_input_tokens: usage_info
                                                .prompt_tokens_details
                                                .as_ref()
                                                .and_then(|d| d.cached_tokens),
                                        });
                                    }

                                    if let Some(choice) = chunk.choices.first() {
                                        if let Some(content) = &choice.delta.content {
                                            yield Ok(StreamPart::TextDelta {
                                                id: "0".into(),
                                                delta: content.clone(),
                                                provider_metadata: None,
                                            });
                                        }

                                        if let Some(tool_call_deltas) = &choice.delta.tool_calls {
                                            for tool_call_delta in tool_call_deltas {
                                                let index = tool_call_delta.index as usize;

                                                // The first delta for a new tool call carries its ID and
                                                // name; later deltas only append argument fragments.
                                                if tool_calls.len() <= index {
                                                    let tool_id = tool_call_delta.id.clone().unwrap_or_default();
                                                    let tool_name = tool_call_delta.function.name.clone().unwrap_or_default();

                                                    tool_calls.push(crate::api_types::OpenAIToolCall {
                                                        id: tool_id.clone(),
                                                        r#type: "function".to_string(),
                                                        function: crate::api_types::OpenAIFunctionCall {
                                                            name: tool_name.clone(),
                                                            arguments: String::new(),
                                                        },
                                                    });

                                                    yield Ok(StreamPart::ToolInputStart {
                                                        id: tool_id,
                                                        tool_name,
                                                        provider_metadata: None,
                                                        provider_executed: None,
                                                        dynamic: None,
                                                        title: None,
                                                    });
                                                }

                                                if let Some(args_delta) = &tool_call_delta.function.arguments {
                                                    tool_calls[index].function.arguments.push_str(args_delta);

                                                    yield Ok(StreamPart::ToolInputDelta {
                                                        id: tool_calls[index].id.clone(),
                                                        delta: args_delta.clone(),
                                                        provider_metadata: None,
                                                    });
                                                }
                                            }
                                        }

                                        if let Some(annotations) = &choice.delta.annotations {
                                            for annotation in annotations {
                                                yield Ok(StreamPart::Source(SourcePart {
                                                    id: generate_source_id(),
                                                    source_type: SourceType::Url,
                                                    url: Some(annotation.url.clone()),
                                                    title: Some(annotation.title.clone()),
                                                    provider_metadata: None,
                                                }));
                                            }
                                        }

                                        if let Some(finish_reason) = &choice.finish_reason {
                                            if !finish_reason.is_empty() && finish_reason != "null" {
                                                // Close out accumulated tool calls before finishing.
                                                for tool_call in &tool_calls {
                                                    yield Ok(StreamPart::ToolInputEnd {
                                                        id: tool_call.id.clone(),
                                                        provider_metadata: None,
                                                    });

                                                    yield Ok(StreamPart::ToolCall(ToolCallPart {
                                                        tool_call_id: tool_call.id.clone(),
                                                        tool_name: tool_call.function.name.clone(),
                                                        input: tool_call.function.arguments.clone(),
                                                        provider_executed: None,
                                                        dynamic: None,
                                                        provider_metadata: None,
                                                    }));
                                                }

                                                // Inline mapping: the generator cannot capture `&self`.
                                                let mapped_reason = match finish_reason.as_str() {
                                                    "stop" => FinishReason::Stop,
                                                    "length" => FinishReason::Length,
                                                    "content_filter" => FinishReason::ContentFilter,
                                                    "tool_calls" => FinishReason::ToolCalls,
                                                    _ => FinishReason::Unknown,
                                                };

                                                last_finish_reason = Some(mapped_reason);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(OpenAIError::NetworkError(e).into());
                        break;
                    }
                }
            }

            // Emit the final Finish part once the server reported a finish reason.
            if let Some(finish_reason) = last_finish_reason {
                let usage_to_send = accumulated_usage.unwrap_or_default();
                yield Ok(StreamPart::Finish {
                    usage: usage_to_send,
                    finish_reason,
                    provider_metadata: None,
                });
            }
        };

        Ok(StreamResponse {
            stream: Box::pin(stream_impl),
            request: None,
            response: None,
        })
    }
}