llm_sdk/openai/chat_model.rs

use super::chat_api::{
    self, AssistantAudioData, AssistantAudioDataInner, AssistantMessageContent,
    AssistantMessageContentInner, ChatCompletionAudioParams, ChatCompletionMessageToolCall,
    ChatCompletionMessageToolCallUnion, ChatCompletionNamedToolChoice,
    ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage,
    ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
    ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
    ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContentPart,
    ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContentPart,
    ChatCompletionStreamOptions, ChatCompletionStreamOptionsInner,
    ChatCompletionStreamResponseDelta, ChatCompletionTool, ChatCompletionToolChoiceOption,
    ChatCompletionToolUnion, CompletionUsage, CreateChatCompletionRequest,
    CreateChatCompletionResponse, CreateChatCompletionStreamResponse,
    CreateModelResponseProperties, FunctionObject, JsonSchemaConfig, ModelIdsShared,
    ModelResponseProperties, NamedToolFunction, ReasoningEffort, ReasoningEffortEnum,
    ResponseFormat, ResponseFormatJsonObject, ResponseFormatJsonSchema,
    ResponseFormatJsonSchemaSchema, ResponseFormatText, ResponseModalities, ResponseModalityEnum,
    ToolCallFunction, ToolChoiceString, ToolMessageContent, VoiceIdsShared,
};
use crate::{
    client_utils, source_part_utils, stream_utils, AssistantMessage, AudioFormat, AudioOptions,
    ContentDelta, LanguageModel, LanguageModelError, LanguageModelInput, LanguageModelMetadata,
    LanguageModelResult, LanguageModelStream, Message, ModelResponse, ModelUsage, Part, PartDelta,
    PartialModelResponse, ResponseFormatJson, ResponseFormatOption, Tool, ToolCallPart,
    ToolChoiceOption, ToolChoiceTool, ToolMessage, UserMessage,
};
use async_stream::try_stream;
use futures::{future::BoxFuture, StreamExt};
use reqwest::{
    header::{self, HeaderMap, HeaderName, HeaderValue},
    Client,
};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};

const PROVIDER: &str = "openai";
const OPENAI_AUDIO_SAMPLE_RATE: u32 = 24_000;
const OPENAI_AUDIO_CHANNELS: u32 = 1;

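/// A [`LanguageModel`] implementation backed by the OpenAI Chat Completions
/// API (`POST {base_url}/chat/completions`).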
pub struct OpenAIChatModel {
    model_id: String,
    api_key: String,
    base_url: String,
    client: Client,
    metadata: Option<Arc<LanguageModelMetadata>>,
    headers: HashMap<String, String>,
}

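/// Construction options for [`OpenAIChatModel`]. Only `api_key` is required;
/// `base_url` defaults to `https://api.openai.com/v1` and `client` to a fresh
/// `reqwest::Client`.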
#[derive(Clone, Default)]
pub struct OpenAIChatModelOptions {
    pub base_url: Option<String>,
    pub api_key: String,
    pub headers: Option<HashMap<String, String>>,
    pub client: Option<Client>,
}

impl OpenAIChatModel {
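    /// Creates a chat model for the given `model_id`.
    ///
    /// A minimal usage sketch; it assumes the crate is named `llm_sdk`, that
    /// `LanguageModelInput` implements `Default`, and that an async runtime
    /// such as Tokio drives the future (none of which is shown in this file):
    ///
    /// ```no_run
    /// use llm_sdk::openai::{OpenAIChatModel, OpenAIChatModelOptions};
    /// use llm_sdk::{LanguageModel, LanguageModelInput};
    ///
    /// # async fn demo() -> llm_sdk::LanguageModelResult<()> {
    /// let model = OpenAIChatModel::new(
    ///     "gpt-4o-mini",
    ///     OpenAIChatModelOptions {
    ///         api_key: std::env::var("OPENAI_API_KEY").expect("API key"),
    ///         ..Default::default()
    ///     },
    /// );
    /// let response = model.generate(LanguageModelInput::default()).await?;
    /// println!("{:?}", response.content);
    /// # Ok(())
    /// # }
    /// ```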
    #[must_use]
    pub fn new(model_id: impl Into<String>, options: OpenAIChatModelOptions) -> Self {
        let OpenAIChatModelOptions {
            base_url,
            api_key,
            headers,
            client,
        } = options;

        let base_url = base_url
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string())
            .trim_end_matches('/')
            .to_string();
        let client = client.unwrap_or_else(Client::new);
        let headers = headers.unwrap_or_default();

        Self {
            model_id: model_id.into(),
            api_key,
            base_url,
            client,
            metadata: None,
            headers,
        }
    }

    #[must_use]
    pub fn with_metadata(mut self, metadata: LanguageModelMetadata) -> Self {
        self.metadata = Some(Arc::new(metadata));
        self
    }

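    /// Builds request headers: a `Bearer` authorization header derived from
    /// the API key, plus any user-supplied headers, validating header names
    /// and values along the way.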
    fn request_headers(&self) -> LanguageModelResult<HeaderMap> {
        let mut headers = HeaderMap::new();

        let auth_header =
            HeaderValue::from_str(&format!("Bearer {}", self.api_key)).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid OpenAI API key header value: {error}"
                ))
            })?;
        headers.insert(header::AUTHORIZATION, auth_header);

        for (key, value) in &self.headers {
            let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid OpenAI header name '{key}': {error}"
                ))
            })?;
            let header_value = HeaderValue::from_str(value).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid OpenAI header value for '{key}': {error}"
                ))
            })?;
            headers.insert(header_name, header_value);
        }

        Ok(headers)
    }
}

impl LanguageModel for OpenAIChatModel {
    fn provider(&self) -> &'static str {
        PROVIDER
    }

    fn model_id(&self) -> String {
        self.model_id.clone()
    }

    fn metadata(&self) -> Option<&LanguageModelMetadata> {
        self.metadata.as_deref()
    }

    fn generate(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
        Box::pin(async move {
            crate::opentelemetry::trace_generate(
                self.provider(),
                &self.model_id(),
                input,
                |input| async move {
                    let (request, payload) =
                        convert_to_openai_create_params(input, &self.model_id(), false)?;
                    let headers = self.request_headers()?;

                    let response: CreateChatCompletionResponse = client_utils::send_json(
                        &self.client,
                        &format!("{}/chat/completions", self.base_url),
                        &payload,
                        headers,
                    )
                    .await?;

                    let choice = response.choices.into_iter().next().ok_or_else(|| {
                        LanguageModelError::Invariant(
                            PROVIDER,
                            "No choices in response".to_string(),
                        )
                    })?;

                    let message = choice.message;

                    // A non-empty refusal takes precedence over any content.
                    if let Some(refusal) = &message.refusal {
                        if !refusal.is_empty() {
                            return Err(LanguageModelError::Refusal(refusal.clone()));
                        }
                    }

                    let content = map_openai_message(message, request.audio)?;

                    let usage = response.usage.map(map_openai_usage).transpose()?;

                    // Cost can only be computed when both usage and pricing metadata exist.
                    let cost = if let (Some(usage), Some(pricing)) = (
                        usage.as_ref(),
                        self.metadata().and_then(|m| m.pricing.as_ref()),
                    ) {
                        Some(usage.calculate_cost(pricing))
                    } else {
                        None
                    };

                    Ok(ModelResponse {
                        content,
                        usage,
                        cost,
                    })
                },
            )
            .await
        })
    }

    fn stream(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
        Box::pin(async move {
            crate::opentelemetry::trace_stream(
                self.provider(),
                &self.model_id(),
                input,
                |input| async move {
                    let metadata = self.metadata.clone();
                    let (request, payload) =
                        convert_to_openai_create_params(input, &self.model_id(), true)?;
                    let CreateChatCompletionRequest {
                        audio: audio_params, ..
                    } = request;
                    let headers = self.request_headers()?;

                    let mut stream =
                        client_utils::send_sse_stream::<Value, CreateChatCompletionStreamResponse>(
                            &self.client,
                            &format!("{}/chat/completions", self.base_url),
                            &payload,
                            headers,
                            PROVIDER,
                        )
                        .await?;

                    // Refusal text arrives incrementally; accumulate it and fail once the
                    // stream ends. Accumulated deltas feed index guessing in map_openai_delta.
                    let mut refusal = String::new();
                    let mut content_deltas: Vec<ContentDelta> = Vec::new();

                    let stream = try_stream! {
                        while let Some(chunk) = stream.next().await {
                            let chunk = chunk?;

                            if let Some(choice) = chunk.choices.unwrap_or_default().into_iter().next() {
                                let mut delta = choice.delta;

                                if let Some(delta_refusal) = delta.refusal.take() {
                                    refusal.push_str(&delta_refusal);
                                }

                                let deltas = map_openai_delta(
                                    delta,
                                    &content_deltas,
                                    audio_params.as_ref(),
                                )?;

                                for delta in deltas {
                                    content_deltas.push(delta.clone());
                                    yield PartialModelResponse {
                                        delta: Some(delta),
                                        ..Default::default()
                                    };
                                }
                            }

                            // The final chunk carries usage (include_usage is set above).
                            if let Some(usage) = chunk.usage {
                                let usage = map_openai_usage(usage)?;
                                let cost = metadata
                                    .as_ref()
                                    .and_then(|m| m.pricing.as_ref())
                                    .map(|pricing| usage.calculate_cost(pricing));

                                yield PartialModelResponse {
                                    delta: None,
                                    usage: Some(usage),
                                    cost,
                                };
                            }
                        }

                        if !refusal.is_empty() {
                            Err(LanguageModelError::Refusal(refusal))?;
                        }
                    };

                    Ok(LanguageModelStream::from_stream(stream))
                },
            )
            .await
        })
    }
}

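/// Translates a provider-agnostic `LanguageModelInput` into an OpenAI
/// `CreateChatCompletionRequest`, returning both the typed request and the
/// wire payload (the request serialized to JSON with `input.extra` merged in).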
fn convert_to_openai_create_params(
    input: LanguageModelInput,
    model_id: &str,
    stream: bool,
) -> LanguageModelResult<(CreateChatCompletionRequest, Value)> {
    let messages = convert_to_openai_messages(input.messages, input.system_prompt)?;

    let modalities = input
        .modalities
        .as_ref()
        .map(|modalities| -> LanguageModelResult<ResponseModalities> {
            if modalities.is_empty() {
                Ok(ResponseModalities::Null)
            } else {
                let converted = modalities
                    .iter()
                    .map(convert_to_openai_modality)
                    .collect::<LanguageModelResult<Vec<_>>>()?;
                Ok(ResponseModalities::Array(converted))
            }
        })
        .transpose()?;

    let create_model_response_properties = CreateModelResponseProperties {
        model_response_properties: ModelResponseProperties {
            prompt_cache_key: None,
            safety_identifier: None,
            service_tier: None,
            temperature: input.temperature,
            top_logprobs: None,
            top_p: input.top_p,
            ..Default::default()
        },
        top_logprobs: None,
    };

    let audio = input.audio.map(convert_to_openai_audio).transpose()?;

    let reasoning_effort = input
        .reasoning
        .as_ref()
        .and_then(|reasoning| reasoning.budget_tokens)
        .map(convert_to_openai_reasoning_effort)
        .transpose()?;

    let request = CreateChatCompletionRequest {
        create_model_response_properties,
        audio,
        frequency_penalty: input.frequency_penalty,
        logit_bias: None,
        logprobs: None,
        max_completion_tokens: input
            .max_tokens
            .map(|value| {
                i32::try_from(value).map_err(|_| {
                    LanguageModelError::InvalidInput(
                        "max_tokens exceeds supported range for OpenAI chat completions"
                            .to_string(),
                    )
                })
            })
            .transpose()?,
        messages,
        modalities,
        model: ModelIdsShared::String(model_id.to_string()),
        n: None,
        parallel_tool_calls: None,
        prediction: None,
        presence_penalty: input.presence_penalty,
        reasoning_effort,
        response_format: input.response_format.map(convert_to_openai_response_format),
        seed: input.seed,
        stop: None,
        store: None,
        stream: Some(stream),
        stream_options: if stream {
            Some(ChatCompletionStreamOptions::Options(
                ChatCompletionStreamOptionsInner {
                    include_obfuscation: None,
                    include_usage: Some(true),
                },
            ))
        } else {
            None
        },
        tool_choice: input.tool_choice.map(convert_to_openai_tool_choice),
        tools: input
            .tools
            .map(|tools| tools.into_iter().map(convert_to_openai_tool).collect()),
        top_logprobs: None,
        verbosity: None,
        web_search_options: None,
    };

    let payload = merge_extra(&request, input.extra)?;

    Ok((request, payload))
}

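/// Flattens SDK messages into OpenAI chat messages. A system prompt, when
/// present, becomes a leading system message, and a tool message may fan out
/// into several OpenAI tool messages (one per tool result part).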
fn convert_to_openai_messages(
    messages: Vec<Message>,
    system_prompt: Option<String>,
) -> LanguageModelResult<Vec<ChatCompletionRequestMessage>> {
    let mut openai_messages = Vec::new();

    if let Some(prompt) = system_prompt {
        openai_messages.push(ChatCompletionRequestMessage::System(
            ChatCompletionRequestSystemMessage {
                content: chat_api::SystemMessageContent::Text(prompt),
                name: None,
            },
        ));
    }

    for message in messages {
        match message {
            Message::User(user_message) => {
                openai_messages.push(ChatCompletionRequestMessage::User(convert_user_message(
                    user_message,
                )?));
            }
            Message::Assistant(assistant_message) => {
                openai_messages.push(ChatCompletionRequestMessage::Assistant(
                    convert_assistant_message(assistant_message)?,
                ));
            }
            Message::Tool(tool_message) => {
                let tool_messages = convert_tool_message(tool_message)?;
                openai_messages.extend(
                    tool_messages
                        .into_iter()
                        .map(ChatCompletionRequestMessage::Tool),
                );
            }
        }
    }

    Ok(openai_messages)
}

fn convert_user_message(
    user_message: UserMessage,
) -> LanguageModelResult<ChatCompletionRequestUserMessage> {
    let parts = source_part_utils::get_compatible_parts_without_source_parts(user_message.content);
    let mut content_parts = Vec::new();

    for part in parts {
        match part {
            Part::Text(text_part) => {
                content_parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                    ChatCompletionRequestMessageContentPartText {
                        text: text_part.text,
                        type_field: "text".to_string(),
                    },
                ));
            }
            Part::Image(image_part) => {
                content_parts.push(ChatCompletionRequestUserMessageContentPart::Image(
                    ChatCompletionRequestMessageContentPartImage {
                        image_url: chat_api::ImageUrl {
                            detail: None,
                            url: format!(
                                "data:{};base64,{}",
                                image_part.mime_type, image_part.data
                            ),
                        },
                    },
                ));
            }
            Part::Audio(audio_part) => {
                let format = match audio_part.format {
                    AudioFormat::Mp3 => chat_api::InputAudioFormat::Mp3,
                    AudioFormat::Wav => chat_api::InputAudioFormat::Wav,
                    _ => {
                        return Err(LanguageModelError::Unsupported(
                            PROVIDER,
                            format!(
                                "Cannot convert audio format '{:?}' to OpenAI input audio format",
                                audio_part.format
                            ),
                        ))
                    }
                };
                content_parts.push(ChatCompletionRequestUserMessageContentPart::Audio(
                    ChatCompletionRequestMessageContentPartAudio {
                        input_audio: chat_api::InputAudio {
                            data: audio_part.data,
                            format,
                        },
                    },
                ));
            }
            unsupported => {
                return Err(LanguageModelError::Unsupported(
                    PROVIDER,
                    format!("Cannot convert part to OpenAI user message for type {unsupported:?}"),
                ));
            }
        }
    }

    if content_parts.is_empty() {
        return Err(LanguageModelError::InvalidInput(
            "User message content must not be empty".to_string(),
        ));
    }

    Ok(ChatCompletionRequestUserMessage {
        content: chat_api::UserMessageContent::Array(content_parts),
        name: None,
    })
}

fn convert_assistant_message(
    assistant_message: AssistantMessage,
) -> LanguageModelResult<ChatCompletionRequestAssistantMessage> {
    let parts =
        source_part_utils::get_compatible_parts_without_source_parts(assistant_message.content);

    let mut content_parts: Vec<chat_api::ChatCompletionRequestAssistantMessageContentPart> =
        Vec::new();
    let mut tool_calls: Vec<ChatCompletionMessageToolCallUnion> = Vec::new();
    let mut audio: Option<AssistantAudioData> = None;

    for part in parts {
        match part {
            Part::Text(text_part) => {
                content_parts.push(
                    chat_api::ChatCompletionRequestAssistantMessageContentPart::Text(
                        ChatCompletionRequestMessageContentPartText {
                            text: text_part.text,
                            type_field: "text".to_string(),
                        },
                    ),
                );
            }
            Part::ToolCall(tool_call_part) => {
                tool_calls.push(ChatCompletionMessageToolCallUnion::Function(
                    convert_to_openai_tool_call(tool_call_part)?,
                ));
            }
            Part::Audio(audio_part) => {
                // Assistant audio is echoed back to OpenAI by ID only.
                let id = audio_part.id.ok_or_else(|| {
                    LanguageModelError::Unsupported(
                        PROVIDER,
                        "Cannot convert audio part to OpenAI assistant message without an ID"
                            .to_string(),
                    )
                })?;
                audio = Some(AssistantAudioData::Audio(AssistantAudioDataInner { id }));
            }
            unsupported => {
                return Err(LanguageModelError::Unsupported(
                    PROVIDER,
                    format!(
                        "Cannot convert part to OpenAI assistant message for type {unsupported:?}"
                    ),
                ));
            }
        }
    }

    let content = if content_parts.is_empty() {
        None
    } else {
        Some(AssistantMessageContent::Content(
            AssistantMessageContentInner::Array(content_parts),
        ))
    };

    Ok(ChatCompletionRequestAssistantMessage {
        audio,
        content,
        refusal: None,
        tool_calls: if tool_calls.is_empty() {
            None
        } else {
            Some(tool_calls)
        },
    })
}

fn convert_tool_message(
    tool_message: ToolMessage,
) -> LanguageModelResult<Vec<ChatCompletionRequestToolMessage>> {
    let mut result = Vec::new();

    for part in tool_message.content {
        match part {
            Part::ToolResult(tool_result_part) => {
                let mut content_parts = Vec::new();
                let converted_parts = source_part_utils::get_compatible_parts_without_source_parts(
                    tool_result_part.content,
                );
                for content_part in converted_parts {
                    match content_part {
                        Part::Text(text_part) => {
                            content_parts.push(ChatCompletionRequestToolMessageContentPart::Text(
                                ChatCompletionRequestMessageContentPartText {
                                    text: text_part.text,
                                    type_field: "text".to_string(),
                                },
                            ));
                        }
                        unsupported => {
                            return Err(LanguageModelError::Unsupported(
                                PROVIDER,
                                format!(
                                    "Tool messages must contain only text parts, found \
                                     {unsupported:?}"
                                ),
                            ));
                        }
                    }
                }

                result.push(ChatCompletionRequestToolMessage {
                    content: ToolMessageContent::Array(content_parts),
                    tool_call_id: tool_result_part.tool_call_id,
                });
            }
            unsupported => {
                return Err(LanguageModelError::InvalidInput(format!(
                    "Tool messages must contain only tool result parts, found {unsupported:?}"
                )));
            }
        }
    }

    Ok(result)
}

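/// Converts an SDK tool definition into an OpenAI function tool, with strict
/// schema adherence enabled.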
fn convert_to_openai_tool(tool: Tool) -> ChatCompletionToolUnion {
    let function = FunctionObject {
        description: Some(tool.description),
        name: tool.name,
        parameters: Some(tool.parameters),
        strict: Some(true),
    };
    ChatCompletionToolUnion::Function(ChatCompletionTool {
        function,
        type_field: "function".to_string(),
    })
}

fn convert_to_openai_tool_call(
    part: ToolCallPart,
) -> LanguageModelResult<ChatCompletionMessageToolCall> {
    let ToolCallPart {
        tool_call_id,
        tool_name,
        args,
        id,
    } = part;

    let arguments = serde_json::to_string(&args).map_err(|error| {
        LanguageModelError::InvalidInput(format!(
            "Failed to serialize tool call arguments: {error}"
        ))
    })?;

    Ok(ChatCompletionMessageToolCall {
        function: ToolCallFunction {
            arguments,
            name: tool_name,
        },
        // Prefer the provider-assigned ID; fall back to the SDK tool call ID.
        id: id.unwrap_or(tool_call_id),
        type_field: "function".to_string(),
    })
}

fn convert_to_openai_tool_choice(tool_choice: ToolChoiceOption) -> ChatCompletionToolChoiceOption {
    match tool_choice {
        ToolChoiceOption::Auto => ChatCompletionToolChoiceOption::String(ToolChoiceString::Auto),
        ToolChoiceOption::None => ChatCompletionToolChoiceOption::String(ToolChoiceString::None),
        ToolChoiceOption::Required => {
            ChatCompletionToolChoiceOption::String(ToolChoiceString::Required)
        }
        ToolChoiceOption::Tool(ToolChoiceTool { tool_name }) => {
            ChatCompletionToolChoiceOption::NamedTool(ChatCompletionNamedToolChoice {
                function: NamedToolFunction { name: tool_name },
                type_field: "function".to_string(),
            })
        }
    }
}

fn convert_to_openai_response_format(response_format: ResponseFormatOption) -> ResponseFormat {
    match response_format {
        ResponseFormatOption::Text => ResponseFormat::Text(ResponseFormatText {
            type_field: "text".to_string(),
        }),
        ResponseFormatOption::Json(ResponseFormatJson {
            name,
            description,
            schema,
        }) => {
            // With a schema, use strict structured outputs; otherwise fall back to
            // plain JSON mode.
            if let Some(schema) = schema {
                ResponseFormat::JsonSchema(ResponseFormatJsonSchema {
                    json_schema: JsonSchemaConfig {
                        description,
                        name,
                        schema: Some(ResponseFormatJsonSchemaSchema::from(schema)),
                        strict: Some(true),
                    },
                    type_field: "json_schema".to_string(),
                })
            } else {
                ResponseFormat::JsonObject(ResponseFormatJsonObject {
                    type_field: "json_object".to_string(),
                })
            }
        }
    }
}

fn convert_to_openai_modality(
    modality: &crate::Modality,
) -> LanguageModelResult<ResponseModalityEnum> {
    Ok(match modality {
        crate::Modality::Text => ResponseModalityEnum::Text,
        crate::Modality::Audio => ResponseModalityEnum::Audio,
        crate::Modality::Image => {
            return Err(LanguageModelError::Unsupported(
                PROVIDER,
                format!("Cannot convert modality '{modality:?}' to an OpenAI response modality"),
            ))
        }
    })
}

fn convert_to_openai_audio(audio: AudioOptions) -> LanguageModelResult<ChatCompletionAudioParams> {
    let voice = audio.voice.ok_or_else(|| {
        LanguageModelError::InvalidInput("Audio voice is required for OpenAI audio".to_string())
    })?;

    let format = match audio.format {
        Some(AudioFormat::Wav) => chat_api::AudioFormat::Wav,
        Some(AudioFormat::Mp3) => chat_api::AudioFormat::Mp3,
        Some(AudioFormat::Flac) => chat_api::AudioFormat::Flac,
        Some(AudioFormat::Aac) => chat_api::AudioFormat::Aac,
        Some(AudioFormat::Opus) => chat_api::AudioFormat::Opus,
        Some(AudioFormat::Linear16) => chat_api::AudioFormat::Pcm16,
        None => {
            return Err(LanguageModelError::InvalidInput(
                "Audio format is required for OpenAI audio".to_string(),
            ))
        }
        Some(other) => {
            return Err(LanguageModelError::Unsupported(
                PROVIDER,
                format!("Cannot convert audio format '{other:?}' to OpenAI audio format"),
            ))
        }
    };

    Ok(ChatCompletionAudioParams {
        format,
        voice: VoiceIdsShared::String(voice),
    })
}

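/// Maps a `budget_tokens` value onto OpenAI's discrete reasoning-effort
/// levels. Only the sentinel `OPENAI_REASONING_EFFORT_*` constants are
/// accepted; arbitrary budgets have no OpenAI equivalent and are rejected.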
fn convert_to_openai_reasoning_effort(budget_tokens: u32) -> LanguageModelResult<ReasoningEffort> {
    let effort = match budget_tokens {
        crate::openai::types::OPENAI_REASONING_EFFORT_MINIMAL => ReasoningEffortEnum::Minimal,
        crate::openai::types::OPENAI_REASONING_EFFORT_LOW => ReasoningEffortEnum::Low,
        crate::openai::types::OPENAI_REASONING_EFFORT_MEDIUM => ReasoningEffortEnum::Medium,
        crate::openai::types::OPENAI_REASONING_EFFORT_HIGH => ReasoningEffortEnum::High,
        _ => {
            return Err(LanguageModelError::Unsupported(
                PROVIDER,
                "Budget tokens property is not supported for OpenAI reasoning. Use the \
                 OPENAI_REASONING_EFFORT_* constants to map it to an OpenAI reasoning effort \
                 level."
                    .to_string(),
            ))
        }
    };

    Ok(ReasoningEffort::Enum(effort))
}

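/// Serializes the typed request to JSON and overlays `extra` on top, letting
/// callers set provider-specific fields (top-level keys in `extra` overwrite
/// the corresponding serialized fields). `extra` must be a JSON object or
/// null.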
fn merge_extra(
    request: &CreateChatCompletionRequest,
    extra: Option<Value>,
) -> LanguageModelResult<Value> {
    let mut payload = serde_json::to_value(request).map_err(|error| {
        LanguageModelError::InvalidInput(format!("Failed to serialize OpenAI request: {error}"))
    })?;

    if let Some(extra) = extra {
        if let Value::Object(extra_map) = extra {
            let map = payload.as_object_mut().ok_or_else(|| {
                LanguageModelError::InvalidInput(
                    "Serialized OpenAI request is not an object".to_string(),
                )
            })?;
            for (key, value) in extra_map {
                map.insert(key, value);
            }
        } else if !extra.is_null() {
            return Err(LanguageModelError::InvalidInput(
                "OpenAI extra must be a JSON object".to_string(),
            ));
        }
    }

    Ok(payload)
}

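/// Maps a non-streaming OpenAI response message into SDK parts (text, audio,
/// tool calls). `audio_params` echoes the request's audio settings so the
/// returned audio part can record its format; PCM16 output is additionally
/// tagged with OpenAI's fixed 24 kHz mono characteristics.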
fn map_openai_message(
    message: chat_api::ChatCompletionResponseMessage,
    audio_params: Option<ChatCompletionAudioParams>,
) -> LanguageModelResult<Vec<Part>> {
    let mut parts = Vec::new();

    if let Some(content) = message.content {
        if !content.is_empty() {
            parts.push(Part::Text(crate::TextPart {
                text: content,
                citations: None,
            }));
        }
    }

    if let Some(chat_api::AudioResponseData::Audio(data)) = message.audio {
        let audio_format = audio_params
            .map(|params| map_openai_audio_format(&params.format))
            .ok_or_else(|| {
                LanguageModelError::Invariant(
                    PROVIDER,
                    "Audio returned from OpenAI API but no audio parameter was provided"
                        .to_string(),
                )
            })?;

        let mut audio_part = crate::AudioPart {
            data: data.data,
            format: audio_format,
            sample_rate: None,
            channels: None,
            transcript: Some(data.transcript),
            id: Some(data.id),
        };

        if audio_part.format == AudioFormat::Linear16 {
            audio_part.sample_rate = Some(OPENAI_AUDIO_SAMPLE_RATE);
            audio_part.channels = Some(OPENAI_AUDIO_CHANNELS);
        }

        parts.push(Part::Audio(audio_part));
    }

    if let Some(tool_calls) = message.tool_calls {
        for tool_call in tool_calls {
            match tool_call {
                ChatCompletionMessageToolCallUnion::Function(function_tool_call) => {
                    parts.push(Part::ToolCall(map_openai_function_tool_call(
                        function_tool_call,
                    )?));
                }
                ChatCompletionMessageToolCallUnion::Custom(custom_tool_call) => {
                    return Err(LanguageModelError::NotImplemented(
                        PROVIDER,
                        format!(
                            "Cannot map OpenAI tool call of type {} to ToolCallPart",
                            custom_tool_call.type_field
                        ),
                    ));
                }
            }
        }
    }

    Ok(parts)
}

fn map_openai_audio_format(format: &chat_api::AudioFormat) -> AudioFormat {
    match format {
        chat_api::AudioFormat::Wav => AudioFormat::Wav,
        chat_api::AudioFormat::Mp3 => AudioFormat::Mp3,
        chat_api::AudioFormat::Flac => AudioFormat::Flac,
        chat_api::AudioFormat::Opus => AudioFormat::Opus,
        chat_api::AudioFormat::Pcm16 => AudioFormat::Linear16,
        chat_api::AudioFormat::Aac => AudioFormat::Aac,
    }
}

fn map_openai_function_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> LanguageModelResult<ToolCallPart> {
    if tool_call.type_field != "function" {
        return Err(LanguageModelError::NotImplemented(
            PROVIDER,
            format!(
                "Cannot map OpenAI tool call of type {} to ToolCallPart",
                tool_call.type_field
            ),
        ));
    }

    let args: Value = serde_json::from_str(&tool_call.function.arguments).map_err(|error| {
        LanguageModelError::Invariant(
            PROVIDER,
            format!("Failed to parse tool call arguments as JSON: {error}"),
        )
    })?;

    Ok(ToolCallPart {
        tool_call_id: tool_call.id,
        tool_name: tool_call.function.name,
        args,
        id: None,
    })
}

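/// Maps one streaming delta into SDK content deltas. OpenAI does not index
/// text and audio deltas, so `stream_utils::guess_delta_index` infers an
/// index from the deltas seen so far; tool call deltas carry an explicit
/// index, which is passed through as a hint.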
fn map_openai_delta(
    delta: ChatCompletionStreamResponseDelta,
    existing_content_deltas: &[ContentDelta],
    audio_params: Option<&ChatCompletionAudioParams>,
) -> LanguageModelResult<Vec<ContentDelta>> {
    let mut content_deltas = Vec::new();

    if let Some(content) = delta.content {
        if !content.is_empty() {
            let part = PartDelta::Text(crate::TextPartDelta {
                text: content,
                citation: None,
            });
            let combined = existing_content_deltas
                .iter()
                .chain(content_deltas.iter())
                .collect::<Vec<_>>();
            let index = stream_utils::guess_delta_index(&part, &combined, None);
            content_deltas.push(ContentDelta { index, part });
        }
    }

    if let Some(audio) = delta.audio {
        let mut audio_part = crate::AudioPartDelta {
            data: audio.data,
            format: audio_params.map(|params| map_openai_audio_format(&params.format)),
            sample_rate: None,
            channels: None,
            transcript: audio.transcript,
            id: audio.id,
        };

        if audio_part.format == Some(AudioFormat::Linear16) {
            audio_part.sample_rate = Some(OPENAI_AUDIO_SAMPLE_RATE);
            audio_part.channels = Some(OPENAI_AUDIO_CHANNELS);
        }

        let part = PartDelta::Audio(audio_part);
        let combined = existing_content_deltas
            .iter()
            .chain(content_deltas.iter())
            .collect::<Vec<_>>();
        let index = stream_utils::guess_delta_index(&part, &combined, None);
        content_deltas.push(ContentDelta { index, part });
    }

    if let Some(tool_calls) = delta.tool_calls {
        for tool_call in tool_calls {
            let mut part = crate::ToolCallPartDelta {
                tool_call_id: tool_call.id,
                tool_name: None,
                args: None,
                id: None,
            };

            if let Some(function) = tool_call.function {
                if part.tool_name.is_none() {
                    part.tool_name = function.name;
                }
                if part.args.is_none() {
                    part.args = function.arguments;
                }
            }

            let part = PartDelta::ToolCall(part);

            let combined = existing_content_deltas
                .iter()
                .chain(content_deltas.iter())
                .collect::<Vec<_>>();
            let index = stream_utils::guess_delta_index(
                &part,
                &combined,
                Some(usize::try_from(tool_call.index).map_err(|_| {
                    LanguageModelError::Invariant(
                        PROVIDER,
                        "Received negative tool call index from OpenAI stream".to_string(),
                    )
                })?),
            );
            content_deltas.push(ContentDelta { index, part });
        }
    }

    Ok(content_deltas)
}

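/// Converts OpenAI usage counters into `ModelUsage`, guarding the narrowing
/// conversions to `u32` and carrying over per-modality token breakdowns when
/// present.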
fn map_openai_usage(usage: CompletionUsage) -> LanguageModelResult<ModelUsage> {
    let input_tokens = u32::try_from(usage.prompt_tokens).map_err(|_| {
        LanguageModelError::Invariant(
            PROVIDER,
            "OpenAI prompt_tokens exceeded u32 range".to_string(),
        )
    })?;
    let output_tokens = u32::try_from(usage.completion_tokens).map_err(|_| {
        LanguageModelError::Invariant(
            PROVIDER,
            "OpenAI completion_tokens exceeded u32 range".to_string(),
        )
    })?;

    let mut result = ModelUsage {
        input_tokens,
        output_tokens,
        input_tokens_details: None,
        output_tokens_details: None,
    };

    if let Some(details) = usage.prompt_tokens_details {
        result.input_tokens_details = Some(map_openai_prompt_tokens_details(details)?);
    }

    if let Some(details) = &usage.completion_tokens_details {
        result.output_tokens_details = Some(map_openai_completion_tokens_details(details)?);
    }

    Ok(result)
}

fn map_openai_prompt_tokens_details(
    details: chat_api::PromptTokensDetails,
) -> LanguageModelResult<crate::ModelTokensDetails> {
    let mut result = crate::ModelTokensDetails::default();

    if let Some(text_tokens) = details.text_tokens {
        result.text_tokens = Some(u32::try_from(text_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI text prompt tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(audio_tokens) = details.audio_tokens {
        result.audio_tokens = Some(u32::try_from(audio_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI audio prompt tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(image_tokens) = details.image_tokens {
        result.image_tokens = Some(u32::try_from(image_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI image prompt tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(cached_details) = details.cached_tokens_details {
        if let Some(text_tokens) = cached_details.text_tokens {
            result.cached_text_tokens = Some(u32::try_from(text_tokens).map_err(|_| {
                LanguageModelError::Invariant(
                    PROVIDER,
                    "OpenAI cached text prompt tokens exceeded u32 range".to_string(),
                )
            })?);
        }
        if let Some(audio_tokens) = cached_details.audio_tokens {
            result.cached_audio_tokens = Some(u32::try_from(audio_tokens).map_err(|_| {
                LanguageModelError::Invariant(
                    PROVIDER,
                    "OpenAI cached audio prompt tokens exceeded u32 range".to_string(),
                )
            })?);
        }
    }

    Ok(result)
}

fn map_openai_completion_tokens_details(
    details: &chat_api::CompletionTokensDetails,
) -> LanguageModelResult<crate::ModelTokensDetails> {
    let mut result = crate::ModelTokensDetails::default();

    if let Some(text_tokens) = details.text_tokens {
        result.text_tokens = Some(u32::try_from(text_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI text completion tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(audio_tokens) = details.audio_tokens {
        result.audio_tokens = Some(u32::try_from(audio_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI audio completion tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    Ok(result)
}