use super::chat_api::{
    self, AssistantAudioData, AssistantAudioDataInner, AssistantMessageContent,
    AssistantMessageContentInner, ChatCompletionAudioParams, ChatCompletionMessageToolCall,
    ChatCompletionMessageToolCallUnion, ChatCompletionNamedToolChoice,
    ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage,
    ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
    ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
    ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContentPart,
    ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContentPart,
    ChatCompletionStreamOptions, ChatCompletionStreamOptionsInner,
    ChatCompletionStreamResponseDelta, ChatCompletionTool, ChatCompletionToolChoiceOption,
    ChatCompletionToolUnion, CompletionUsage, CreateChatCompletionRequest,
    CreateChatCompletionResponse, CreateChatCompletionStreamResponse,
    CreateModelResponseProperties, FunctionObject, JsonSchemaConfig, ModelIdsShared,
    ModelResponseProperties, NamedToolFunction, ReasoningEffort, ReasoningEffortEnum,
    ResponseFormat, ResponseFormatJsonObject, ResponseFormatJsonSchema,
    ResponseFormatJsonSchemaSchema, ResponseFormatText, ResponseModalities, ResponseModalityEnum,
    ToolCallFunction, ToolChoiceString, ToolMessageContent, VoiceIdsShared,
};
use crate::{
    client_utils, source_part_utils, stream_utils, AssistantMessage, AudioFormat, AudioOptions,
    ContentDelta, LanguageModel, LanguageModelError, LanguageModelInput, LanguageModelMetadata,
    LanguageModelResult, LanguageModelStream, Message, ModelResponse, ModelUsage, Part, PartDelta,
    PartialModelResponse, ResponseFormatJson, ResponseFormatOption, Tool, ToolCallPart,
    ToolChoiceOption, ToolChoiceTool, ToolMessage, UserMessage,
};
use async_stream::try_stream;
use futures::{future::BoxFuture, StreamExt};
use reqwest::{
    header::{self, HeaderMap, HeaderName, HeaderValue},
    Client,
};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};

const PROVIDER: &str = "openai";
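// Sample rate and channel count for OpenAI `pcm16` output audio; the chat
// completions API streams PCM16 as 24 kHz mono, so audio parts and deltas in
// that format are annotated with these values below.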
const OPENAI_AUDIO_SAMPLE_RATE: u32 = 24_000;
const OPENAI_AUDIO_CHANNELS: u32 = 1;

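/// OpenAI chat completions client implementing [`LanguageModel`].
///
/// Wraps a reqwest [`Client`] with credentials, an optional custom base URL
/// and headers, and optional metadata used for cost calculation.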
pub struct OpenAIChatModel {
    model_id: String,
    api_key: String,
    base_url: String,
    client: Client,
    metadata: Option<Arc<LanguageModelMetadata>>,
    headers: HashMap<String, String>,
}

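/// Construction options for [`OpenAIChatModel`].
///
/// `base_url` defaults to `https://api.openai.com/v1`, `client` defaults to a
/// fresh reqwest [`Client`], and any `headers` are sent on every request
/// alongside the `Authorization` header derived from `api_key`.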
#[derive(Clone, Default)]
pub struct OpenAIChatModelOptions {
    pub base_url: Option<String>,
    pub api_key: String,
    pub headers: Option<HashMap<String, String>>,
    pub client: Option<Client>,
}

impl OpenAIChatModel {
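    /// Creates a chat completions client for the given model ID.
    ///
    /// A minimal sketch (the model ID is a placeholder; any chat completions
    /// model ID works):
    ///
    /// ```ignore
    /// let model = OpenAIChatModel::new(
    ///     "gpt-4o-mini",
    ///     OpenAIChatModelOptions {
    ///         api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
    ///         ..Default::default()
    ///     },
    /// );
    /// ```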
    #[must_use]
    pub fn new(model_id: impl Into<String>, options: OpenAIChatModelOptions) -> Self {
        let OpenAIChatModelOptions {
            base_url,
            api_key,
            headers,
            client,
        } = options;

        let base_url = base_url
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string())
            .trim_end_matches('/')
            .to_string();
        let client = client.unwrap_or_else(Client::new);
        let headers = headers.unwrap_or_default();

        Self {
            model_id: model_id.into(),
            api_key,
            base_url,
            client,
            metadata: None,
            headers,
        }
    }

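    /// Attaches metadata (for example pricing) used to compute response cost.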
    #[must_use]
    pub fn with_metadata(mut self, metadata: LanguageModelMetadata) -> Self {
        self.metadata = Some(Arc::new(metadata));
        self
    }

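    /// Builds request headers: a bearer `Authorization` header plus any
    /// custom headers, validating each header name and value.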
    fn request_headers(&self) -> LanguageModelResult<HeaderMap> {
        let mut headers = HeaderMap::new();

        let auth_header =
            HeaderValue::from_str(&format!("Bearer {}", self.api_key)).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid OpenAI API key header value: {error}"
                ))
            })?;
        headers.insert(header::AUTHORIZATION, auth_header);

        for (key, value) in &self.headers {
            let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid OpenAI header name '{key}': {error}"
                ))
            })?;
            let header_value = HeaderValue::from_str(value).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid OpenAI header value for '{key}': {error}"
                ))
            })?;
            headers.insert(header_name, header_value);
        }

        Ok(headers)
    }
}

impl LanguageModel for OpenAIChatModel {
    fn provider(&self) -> &'static str {
        PROVIDER
    }

    fn model_id(&self) -> String {
        self.model_id.clone()
    }

    fn metadata(&self) -> Option<&LanguageModelMetadata> {
        self.metadata.as_deref()
    }

    fn generate(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
        Box::pin(async move {
            crate::opentelemetry::trace_generate(
                self.provider(),
                &self.model_id(),
                input,
                |input| async move {
                    let (request, payload) =
                        convert_to_openai_create_params(input, &self.model_id(), false)?;
                    let headers = self.request_headers()?;

                    let response: CreateChatCompletionResponse = client_utils::send_json(
                        &self.client,
                        &format!("{}/chat/completions", self.base_url),
                        &payload,
                        headers,
                    )
                    .await?;

                    let choice = response.choices.into_iter().next().ok_or_else(|| {
                        LanguageModelError::Invariant(
                            PROVIDER,
                            "No choices in response".to_string(),
                        )
                    })?;

                    let message = choice.message;

                    if let Some(refusal) = &message.refusal {
                        if !refusal.is_empty() {
                            return Err(LanguageModelError::Refusal(refusal.clone()));
                        }
                    }

                    let content = map_openai_message(message, request.audio)?;

                    let usage = response.usage.map(map_openai_usage).transpose()?;

                    let cost = if let (Some(usage), Some(pricing)) = (
                        usage.as_ref(),
                        self.metadata().and_then(|m| m.pricing.as_ref()),
                    ) {
                        Some(usage.calculate_cost(pricing))
                    } else {
                        None
                    };

                    Ok(ModelResponse {
                        content,
                        usage,
                        cost,
                    })
                },
            )
            .await
        })
    }

    fn stream(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
        Box::pin(async move {
            crate::opentelemetry::trace_stream(
                self.provider(),
                &self.model_id(),
                input,
                |input| async move {
                    let metadata = self.metadata.clone();
                    let (request, payload) =
                        convert_to_openai_create_params(input, &self.model_id(), true)?;
                    let CreateChatCompletionRequest { audio: audio_params, .. } = request;
                    let headers = self.request_headers()?;

                    let mut stream =
                        client_utils::send_sse_stream::<Value, CreateChatCompletionStreamResponse>(
                            &self.client,
                            &format!("{}/chat/completions", self.base_url),
                            &payload,
                            headers,
                            PROVIDER,
                        )
                        .await?;

                    let mut refusal = String::new();
                    let mut content_deltas: Vec<ContentDelta> = Vec::new();

                    let stream = try_stream! {
                        while let Some(chunk) = stream.next().await {
                            let chunk = chunk?;

                            if let Some(choice) = chunk.choices.unwrap_or_default().into_iter().next() {
                                let mut delta = choice.delta;

                                if let Some(delta_refusal) = delta.refusal.take() {
                                    refusal.push_str(&delta_refusal);
                                }

                                let deltas = map_openai_delta(
                                    delta,
                                    &content_deltas,
                                    audio_params.as_ref(),
                                )?;

                                for delta in deltas {
                                    content_deltas.push(delta.clone());
                                    yield PartialModelResponse {
                                        delta: Some(delta),
                                        ..Default::default()
                                    };
                                }
                            }

                            if let Some(usage) = chunk.usage {
                                let usage = map_openai_usage(usage)?;
                                let cost = metadata
                                    .as_ref()
                                    .and_then(|m| m.pricing.as_ref())
                                    .map(|pricing| usage.calculate_cost(pricing));

                                yield PartialModelResponse {
                                    delta: None,
                                    usage: Some(usage),
                                    cost,
                                };
                            }
                        }

                        if !refusal.is_empty() {
                            Err(LanguageModelError::Refusal(refusal))?;
                        }
                    };

                    Ok(LanguageModelStream::from_stream(stream))
                },
            )
            .await
        })
    }
}

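/// Translates a provider-agnostic [`LanguageModelInput`] into the typed chat
/// completions request plus the JSON payload actually sent over the wire
/// (the payload is the serialized request with `input.extra` merged on top).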
fn convert_to_openai_create_params(
    input: LanguageModelInput,
    model_id: &str,
    stream: bool,
) -> LanguageModelResult<(CreateChatCompletionRequest, Value)> {
    let messages = convert_to_openai_messages(input.messages, input.system_prompt)?;

    let modalities = input
        .modalities
        .as_ref()
        .map(|modalities| -> LanguageModelResult<ResponseModalities> {
            if modalities.is_empty() {
                Ok(ResponseModalities::Null)
            } else {
                let converted = modalities
                    .iter()
                    .map(convert_to_openai_modality)
                    .collect::<LanguageModelResult<Vec<_>>>()?;
                Ok(ResponseModalities::Array(converted))
            }
        })
        .transpose()?;

    let create_model_response_properties = CreateModelResponseProperties {
        model_response_properties: ModelResponseProperties {
            prompt_cache_key: None,
            safety_identifier: None,
            service_tier: None,
            temperature: input.temperature,
            top_logprobs: None,
            top_p: input.top_p,
            ..Default::default()
        },
        top_logprobs: None,
    };

    let audio = input.audio.map(convert_to_openai_audio).transpose()?;

    let reasoning_effort = input
        .reasoning
        .as_ref()
        .and_then(|reasoning| reasoning.budget_tokens)
        .map(convert_to_openai_reasoning_effort)
        .transpose()?;

    let request = CreateChatCompletionRequest {
        create_model_response_properties,
        audio,
        frequency_penalty: input.frequency_penalty,
        logit_bias: None,
        logprobs: None,
        max_completion_tokens: input
            .max_tokens
            .map(|value| {
                i32::try_from(value).map_err(|_| {
                    LanguageModelError::InvalidInput(
                        "max_tokens exceeds supported range for OpenAI chat completions"
                            .to_string(),
                    )
                })
            })
            .transpose()?,
        messages,
        modalities,
        model: ModelIdsShared::String(model_id.to_string()),
        n: None,
        parallel_tool_calls: None,
        prediction: None,
        presence_penalty: input.presence_penalty,
        reasoning_effort,
        response_format: input.response_format.map(convert_to_openai_response_format),
        seed: input.seed,
        stop: None,
        store: None,
        stream: Some(stream),
        stream_options: if stream {
            Some(ChatCompletionStreamOptions::Options(
                ChatCompletionStreamOptionsInner {
                    include_obfuscation: None,
                    include_usage: Some(true),
                },
            ))
        } else {
            None
        },
        tool_choice: input.tool_choice.map(convert_to_openai_tool_choice),
        tools: input
            .tools
            .map(|tools| tools.into_iter().map(convert_to_openai_tool).collect()),
        top_logprobs: None,
        verbosity: None,
        web_search_options: None,
    };

    let payload = merge_extra(&request, input.extra)?;

    Ok((request, payload))
}

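/// Flattens the provider-agnostic message list into OpenAI chat messages,
/// prepending `system_prompt` as a system message when present. A single tool
/// message may expand into several OpenAI messages, one per tool result.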
fn convert_to_openai_messages(
    messages: Vec<Message>,
    system_prompt: Option<String>,
) -> LanguageModelResult<Vec<ChatCompletionRequestMessage>> {
    let mut openai_messages = Vec::new();

    if let Some(prompt) = system_prompt {
        openai_messages.push(ChatCompletionRequestMessage::System(
            ChatCompletionRequestSystemMessage {
                content: chat_api::SystemMessageContent::Text(prompt),
                name: None,
            },
        ));
    }

    for message in messages {
        match message {
            Message::User(user_message) => {
                openai_messages.push(ChatCompletionRequestMessage::User(convert_user_message(
                    user_message,
                )?));
            }
            Message::Assistant(assistant_message) => {
                openai_messages.push(ChatCompletionRequestMessage::Assistant(
                    convert_assistant_message(assistant_message)?,
                ));
            }
            Message::Tool(tool_message) => {
                let tool_messages = convert_tool_message(tool_message)?;
                openai_messages.extend(
                    tool_messages
                        .into_iter()
                        .map(ChatCompletionRequestMessage::Tool),
                );
            }
        }
    }

    Ok(openai_messages)
}

fn convert_user_message(
    user_message: UserMessage,
) -> LanguageModelResult<ChatCompletionRequestUserMessage> {
    let parts = source_part_utils::get_compatible_parts_without_source_parts(user_message.content);
    let mut content_parts = Vec::new();

    for part in parts {
        match part {
            Part::Text(text_part) => {
                content_parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                    ChatCompletionRequestMessageContentPartText {
                        text: text_part.text,
                        type_field: "text".to_string(),
                    },
                ));
            }
            Part::Image(image_part) => {
                content_parts.push(ChatCompletionRequestUserMessageContentPart::Image(
                    ChatCompletionRequestMessageContentPartImage {
                        image_url: chat_api::ImageUrl {
                            detail: None,
                            url: format!(
                                "data:{};base64,{}",
                                image_part.mime_type, image_part.data
                            ),
                        },
                    },
                ));
            }
            Part::Audio(audio_part) => {
                let format = match audio_part.format {
                    AudioFormat::Mp3 => chat_api::InputAudioFormat::Mp3,
                    AudioFormat::Wav => chat_api::InputAudioFormat::Wav,
                    _ => {
                        return Err(LanguageModelError::Unsupported(
                            PROVIDER,
                            format!(
                                "Cannot convert audio format '{:?}' to OpenAI input audio format",
                                audio_part.format
                            ),
                        ))
                    }
                };
                content_parts.push(ChatCompletionRequestUserMessageContentPart::Audio(
                    ChatCompletionRequestMessageContentPartAudio {
                        input_audio: chat_api::InputAudio {
                            data: audio_part.data,
                            format,
                        },
                    },
                ));
            }
            unsupported => {
                return Err(LanguageModelError::Unsupported(
                    PROVIDER,
                    format!("Cannot convert part to OpenAI user message for type {unsupported:?}"),
                ));
            }
        }
    }

    if content_parts.is_empty() {
        return Err(LanguageModelError::InvalidInput(
            "User message content must not be empty".to_string(),
        ));
    }

    Ok(ChatCompletionRequestUserMessage {
        content: chat_api::UserMessageContent::Array(content_parts),
        name: None,
    })
}

fn convert_assistant_message(
    assistant_message: AssistantMessage,
) -> LanguageModelResult<ChatCompletionRequestAssistantMessage> {
    let parts =
        source_part_utils::get_compatible_parts_without_source_parts(assistant_message.content);

    let mut content_parts: Vec<chat_api::ChatCompletionRequestAssistantMessageContentPart> =
        Vec::new();
    let mut tool_calls: Vec<ChatCompletionMessageToolCallUnion> = Vec::new();
    let mut audio: Option<AssistantAudioData> = None;

    for part in parts {
        match part {
            Part::Text(text_part) => {
                content_parts.push(
                    chat_api::ChatCompletionRequestAssistantMessageContentPart::Text(
                        ChatCompletionRequestMessageContentPartText {
                            text: text_part.text,
                            type_field: "text".to_string(),
                        },
                    ),
                );
            }
            Part::ToolCall(tool_call_part) => {
                tool_calls.push(ChatCompletionMessageToolCallUnion::Function(
                    convert_to_openai_tool_call(tool_call_part)?,
                ));
            }
            Part::Audio(audio_part) => {
                let id = audio_part.id.ok_or_else(|| {
                    LanguageModelError::Unsupported(
                        PROVIDER,
                        "Cannot convert audio part to OpenAI assistant message without an ID"
                            .to_string(),
                    )
                })?;
                audio = Some(AssistantAudioData::Audio(AssistantAudioDataInner { id }));
            }
            unsupported => {
                return Err(LanguageModelError::Unsupported(
                    PROVIDER,
                    format!(
                        "Cannot convert part to OpenAI assistant message for type {unsupported:?}"
                    ),
                ));
            }
        }
    }

    let content = if content_parts.is_empty() {
        None
    } else {
        Some(AssistantMessageContent::Content(
            AssistantMessageContentInner::Array(content_parts),
        ))
    };

    Ok(ChatCompletionRequestAssistantMessage {
        audio,
        content,
        refusal: None,
        tool_calls: if tool_calls.is_empty() {
            None
        } else {
            Some(tool_calls)
        },
    })
}

fn convert_tool_message(
    tool_message: ToolMessage,
) -> LanguageModelResult<Vec<ChatCompletionRequestToolMessage>> {
    let mut result = Vec::new();

    for part in tool_message.content {
        match part {
            Part::ToolResult(tool_result_part) => {
                let mut content_parts = Vec::new();
                let converted_parts = source_part_utils::get_compatible_parts_without_source_parts(
                    tool_result_part.content,
                );
                for content_part in converted_parts {
                    match content_part {
                        Part::Text(text_part) => {
                            content_parts.push(ChatCompletionRequestToolMessageContentPart::Text(
                                ChatCompletionRequestMessageContentPartText {
                                    text: text_part.text,
                                    type_field: "text".to_string(),
                                },
                            ));
                        }
                        unsupported => {
                            return Err(LanguageModelError::Unsupported(
                                PROVIDER,
                                format!(
                                    "Tool messages must contain only text parts, found \
                                     {unsupported:?}"
                                ),
                            ));
                        }
                    }
                }

                result.push(ChatCompletionRequestToolMessage {
                    content: ToolMessageContent::Array(content_parts),
                    tool_call_id: tool_result_part.tool_call_id,
                });
            }
            unsupported => {
                return Err(LanguageModelError::InvalidInput(format!(
                    "Tool messages must contain only tool result parts, found {unsupported:?}"
                )));
            }
        }
    }

    Ok(result)
}

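/// Converts a provider-agnostic [`Tool`] into an OpenAI function tool
/// definition with strict schema mode enabled, so generated arguments are
/// expected to conform to `tool.parameters` exactly.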
fn convert_to_openai_tool(tool: Tool) -> ChatCompletionToolUnion {
    let function = FunctionObject {
        description: Some(tool.description),
        name: tool.name,
        parameters: Some(tool.parameters),
        strict: Some(true),
    };
    ChatCompletionToolUnion::Function(ChatCompletionTool {
        function,
        type_field: "function".to_string(),
    })
}

fn convert_to_openai_tool_call(
    part: ToolCallPart,
) -> LanguageModelResult<ChatCompletionMessageToolCall> {
    let ToolCallPart {
        tool_call_id,
        tool_name,
        args,
        id,
    } = part;

    let arguments = serde_json::to_string(&args).map_err(|error| {
        LanguageModelError::InvalidInput(format!(
            "Failed to serialize tool call arguments: {error}"
        ))
    })?;

    Ok(ChatCompletionMessageToolCall {
        function: ToolCallFunction {
            arguments,
            name: tool_name,
        },
        id: id.unwrap_or(tool_call_id),
        type_field: "function".to_string(),
    })
}

fn convert_to_openai_tool_choice(tool_choice: ToolChoiceOption) -> ChatCompletionToolChoiceOption {
    match tool_choice {
        ToolChoiceOption::Auto => ChatCompletionToolChoiceOption::String(ToolChoiceString::Auto),
        ToolChoiceOption::None => ChatCompletionToolChoiceOption::String(ToolChoiceString::None),
        ToolChoiceOption::Required => {
            ChatCompletionToolChoiceOption::String(ToolChoiceString::Required)
        }
        ToolChoiceOption::Tool(ToolChoiceTool { tool_name }) => {
            ChatCompletionToolChoiceOption::NamedTool(ChatCompletionNamedToolChoice {
                function: NamedToolFunction { name: tool_name },
                type_field: "function".to_string(),
            })
        }
    }
}

fn convert_to_openai_response_format(response_format: ResponseFormatOption) -> ResponseFormat {
    match response_format {
        ResponseFormatOption::Text => ResponseFormat::Text(ResponseFormatText {
            type_field: "text".to_string(),
        }),
        ResponseFormatOption::Json(ResponseFormatJson {
            name,
            description,
            schema,
        }) => {
            if let Some(schema) = schema {
                ResponseFormat::JsonSchema(ResponseFormatJsonSchema {
                    json_schema: JsonSchemaConfig {
                        description,
                        name,
                        schema: Some(ResponseFormatJsonSchemaSchema::from(schema)),
                        strict: Some(true),
                    },
                    type_field: "json_schema".to_string(),
                })
            } else {
                ResponseFormat::JsonObject(ResponseFormatJsonObject {
                    type_field: "json_object".to_string(),
                })
            }
        }
    }
}

fn convert_to_openai_modality(
    modality: &crate::Modality,
) -> LanguageModelResult<ResponseModalityEnum> {
    Ok(match modality {
        crate::Modality::Text => ResponseModalityEnum::Text,
        crate::Modality::Audio => ResponseModalityEnum::Audio,
        crate::Modality::Image => {
            return Err(LanguageModelError::Unsupported(
                PROVIDER,
                format!("Cannot convert modality {modality:?} to an OpenAI response modality"),
            ))
        }
    })
}

fn convert_to_openai_audio(audio: AudioOptions) -> LanguageModelResult<ChatCompletionAudioParams> {
    let voice = audio.voice.ok_or_else(|| {
        LanguageModelError::InvalidInput("Audio voice is required for OpenAI audio".to_string())
    })?;

    let format = match audio.format {
        Some(AudioFormat::Wav) => chat_api::AudioFormat::Wav,
        Some(AudioFormat::Mp3) => chat_api::AudioFormat::Mp3,
        Some(AudioFormat::Flac) => chat_api::AudioFormat::Flac,
        Some(AudioFormat::Aac) => chat_api::AudioFormat::Aac,
        Some(AudioFormat::Opus) => chat_api::AudioFormat::Opus,
        Some(AudioFormat::Linear16) => chat_api::AudioFormat::Pcm16,
        None => {
            return Err(LanguageModelError::InvalidInput(
                "Audio format is required for OpenAI audio".to_string(),
            ))
        }
        Some(other) => {
            return Err(LanguageModelError::Unsupported(
                PROVIDER,
                format!("Cannot convert audio format '{other:?}' to OpenAI audio format"),
            ))
        }
    };

    Ok(ChatCompletionAudioParams {
        format,
        voice: VoiceIdsShared::String(voice),
    })
}

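/// Maps a reasoning budget to OpenAI's discrete effort levels. The value is
/// treated as a sentinel rather than a literal token count: it must equal one
/// of the `OPENAI_REASONING_EFFORT_*` constants from `crate::openai::types`,
/// and any other value is rejected.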
fn convert_to_openai_reasoning_effort(budget_tokens: u32) -> LanguageModelResult<ReasoningEffort> {
    let effort = match budget_tokens {
        crate::openai::types::OPENAI_REASONING_EFFORT_MINIMAL => ReasoningEffortEnum::Minimal,
        crate::openai::types::OPENAI_REASONING_EFFORT_LOW => ReasoningEffortEnum::Low,
        crate::openai::types::OPENAI_REASONING_EFFORT_MEDIUM => ReasoningEffortEnum::Medium,
        crate::openai::types::OPENAI_REASONING_EFFORT_HIGH => ReasoningEffortEnum::High,
        _ => {
            return Err(LanguageModelError::Unsupported(
                PROVIDER,
                "Arbitrary budget_tokens values are not supported for OpenAI reasoning; use \
                 the OPENAI_REASONING_EFFORT_* constants to select an effort level."
                    .to_string(),
            ))
        }
    };

    Ok(ReasoningEffort::Enum(effort))
}

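/// Serializes the typed request and shallow-merges the caller-supplied
/// `extra` object on top, so top-level keys in `extra` override or extend the
/// generated payload. Non-object, non-null `extra` values are rejected.
///
/// A minimal sketch (the `"user"` key is only a placeholder):
///
/// ```ignore
/// let payload = merge_extra(&request, Some(serde_json::json!({ "user": "session-123" })))?;
/// assert_eq!(payload["user"], "session-123");
/// ```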
fn merge_extra(
    request: &CreateChatCompletionRequest,
    extra: Option<Value>,
) -> LanguageModelResult<Value> {
    let mut payload = serde_json::to_value(request).map_err(|error| {
        LanguageModelError::InvalidInput(format!("Failed to serialize OpenAI request: {error}"))
    })?;

    if let Some(extra) = extra {
        if let Value::Object(extra_map) = extra {
            let map = payload.as_object_mut().ok_or_else(|| {
                LanguageModelError::InvalidInput(
                    "Serialized OpenAI request is not an object".to_string(),
                )
            })?;
            for (key, value) in extra_map {
                map.insert(key, value);
            }
        } else if !extra.is_null() {
            return Err(LanguageModelError::InvalidInput(
                "OpenAI extra must be a JSON object".to_string(),
            ));
        }
    }

    Ok(payload)
}

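/// Maps a non-streaming response message to provider-agnostic parts: text
/// content, audio (which needs the original request's audio params to recover
/// the format), and function tool calls.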
fn map_openai_message(
    message: chat_api::ChatCompletionResponseMessage,
    audio_params: Option<ChatCompletionAudioParams>,
) -> LanguageModelResult<Vec<Part>> {
    let mut parts = Vec::new();

    if let Some(content) = message.content {
        if !content.is_empty() {
            parts.push(Part::Text(crate::TextPart {
                text: content,
                citations: None,
            }));
        }
    }

    if let Some(chat_api::AudioResponseData::Audio(data)) = message.audio {
        let audio_format = audio_params
            .map(|params| map_openai_audio_format(&params.format))
            .ok_or_else(|| {
                LanguageModelError::Invariant(
                    PROVIDER,
                    "Audio returned from OpenAI API but no audio parameter was provided"
                        .to_string(),
                )
            })?;

        let mut audio_part = crate::AudioPart {
            data: data.data,
            format: audio_format,
            sample_rate: None,
            channels: None,
            transcript: Some(data.transcript),
            id: Some(data.id),
        };

        if audio_part.format == AudioFormat::Linear16 {
            audio_part.sample_rate = Some(OPENAI_AUDIO_SAMPLE_RATE);
            audio_part.channels = Some(OPENAI_AUDIO_CHANNELS);
        }

        parts.push(Part::Audio(audio_part));
    }

    if let Some(tool_calls) = message.tool_calls {
        for tool_call in tool_calls {
            match tool_call {
                ChatCompletionMessageToolCallUnion::Function(function_tool_call) => {
                    parts.push(Part::ToolCall(map_openai_function_tool_call(
                        function_tool_call,
                    )?));
                }
                ChatCompletionMessageToolCallUnion::Custom(custom_tool_call) => {
                    return Err(LanguageModelError::NotImplemented(
                        PROVIDER,
                        format!(
                            "Cannot map OpenAI tool call of type {} to ToolCallPart",
                            custom_tool_call.type_field
                        ),
                    ));
                }
            }
        }
    }

    Ok(parts)
}

fn map_openai_audio_format(format: &chat_api::AudioFormat) -> AudioFormat {
    match format {
        chat_api::AudioFormat::Wav => AudioFormat::Wav,
        chat_api::AudioFormat::Mp3 => AudioFormat::Mp3,
        chat_api::AudioFormat::Flac => AudioFormat::Flac,
        chat_api::AudioFormat::Opus => AudioFormat::Opus,
        chat_api::AudioFormat::Pcm16 => AudioFormat::Linear16,
        chat_api::AudioFormat::Aac => AudioFormat::Aac,
    }
}

fn map_openai_function_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> LanguageModelResult<ToolCallPart> {
    if tool_call.type_field != "function" {
        return Err(LanguageModelError::NotImplemented(
            PROVIDER,
            format!(
                "Cannot map OpenAI tool call of type {} to ToolCallPart",
                tool_call.type_field
            ),
        ));
    }

    let args: Value = serde_json::from_str(&tool_call.function.arguments).map_err(|error| {
        LanguageModelError::Invariant(
            PROVIDER,
            format!("Failed to parse tool call arguments as JSON: {error}"),
        )
    })?;

    Ok(ToolCallPart {
        tool_call_id: tool_call.id,
        tool_name: tool_call.function.name,
        args,
        id: None,
    })
}

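/// Maps one streamed delta into zero or more [`ContentDelta`]s. OpenAI does
/// not index text and audio deltas, so `stream_utils::guess_delta_index`
/// infers each part's index from the deltas seen so far; tool call deltas
/// carry an explicit index, which is passed through as a hint.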
fn map_openai_delta(
    delta: ChatCompletionStreamResponseDelta,
    existing_content_deltas: &[ContentDelta],
    audio_params: Option<&ChatCompletionAudioParams>,
) -> LanguageModelResult<Vec<ContentDelta>> {
    let mut content_deltas = Vec::new();

    if let Some(content) = delta.content {
        if !content.is_empty() {
            let part = PartDelta::Text(crate::TextPartDelta {
                text: content,
                citation: None,
            });
            let combined = existing_content_deltas
                .iter()
                .chain(content_deltas.iter())
                .collect::<Vec<_>>();
            let index = stream_utils::guess_delta_index(&part, &combined, None);
            content_deltas.push(ContentDelta { index, part });
        }
    }

    if let Some(audio) = delta.audio {
        let mut audio_part = crate::AudioPartDelta {
            data: audio.data,
            format: audio_params.map(|params| map_openai_audio_format(&params.format)),
            sample_rate: None,
            channels: None,
            transcript: audio.transcript,
            id: audio.id,
        };

        if audio_part.format == Some(AudioFormat::Linear16) {
            audio_part.sample_rate = Some(OPENAI_AUDIO_SAMPLE_RATE);
            audio_part.channels = Some(OPENAI_AUDIO_CHANNELS);
        }

        let part = PartDelta::Audio(audio_part);
        let combined = existing_content_deltas
            .iter()
            .chain(content_deltas.iter())
            .collect::<Vec<_>>();
        let index = stream_utils::guess_delta_index(&part, &combined, None);
        content_deltas.push(ContentDelta { index, part });
    }

    if let Some(tool_calls) = delta.tool_calls {
        for tool_call in tool_calls {
            let mut part = crate::ToolCallPartDelta {
                tool_call_id: tool_call.id,
                tool_name: None,
                args: None,
                id: None,
            };

            if let Some(function) = tool_call.function {
                part.tool_name = function.name;
                part.args = function.arguments;
            }

            let part = PartDelta::ToolCall(part);

            let combined = existing_content_deltas
                .iter()
                .chain(content_deltas.iter())
                .collect::<Vec<_>>();
            let index = stream_utils::guess_delta_index(
                &part,
                &combined,
                Some(usize::try_from(tool_call.index).map_err(|_| {
                    LanguageModelError::Invariant(
                        PROVIDER,
                        "Received negative tool call index from OpenAI stream".to_string(),
                    )
                })?),
            );
            content_deltas.push(ContentDelta { index, part });
        }
    }

    Ok(content_deltas)
}

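/// Converts OpenAI usage counters into [`ModelUsage`], checking that the
/// signed token counts fit in `u32` and mapping prompt/completion token
/// detail breakdowns when present.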
fn map_openai_usage(usage: CompletionUsage) -> LanguageModelResult<ModelUsage> {
    let input_tokens = u32::try_from(usage.prompt_tokens).map_err(|_| {
        LanguageModelError::Invariant(
            PROVIDER,
            "OpenAI prompt_tokens exceeded u32 range".to_string(),
        )
    })?;
    let output_tokens = u32::try_from(usage.completion_tokens).map_err(|_| {
        LanguageModelError::Invariant(
            PROVIDER,
            "OpenAI completion_tokens exceeded u32 range".to_string(),
        )
    })?;

    let mut result = ModelUsage {
        input_tokens,
        output_tokens,
        input_tokens_details: None,
        output_tokens_details: None,
    };

    if let Some(details) = usage.prompt_tokens_details {
        result.input_tokens_details = Some(map_openai_prompt_tokens_details(details)?);
    }

    if let Some(details) = &usage.completion_tokens_details {
        result.output_tokens_details = Some(map_openai_completion_tokens_details(details)?);
    }

    Ok(result)
}

fn map_openai_prompt_tokens_details(
    details: chat_api::PromptTokensDetails,
) -> LanguageModelResult<crate::ModelTokensDetails> {
    let mut result = crate::ModelTokensDetails::default();

    if let Some(text_tokens) = details.text_tokens {
        result.text_tokens = Some(u32::try_from(text_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI text prompt tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(audio_tokens) = details.audio_tokens {
        result.audio_tokens = Some(u32::try_from(audio_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI audio prompt tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(image_tokens) = details.image_tokens {
        result.image_tokens = Some(u32::try_from(image_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI image prompt tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(cached_details) = details.cached_tokens_details {
        if let Some(text_tokens) = cached_details.text_tokens {
            result.cached_text_tokens = Some(u32::try_from(text_tokens).map_err(|_| {
                LanguageModelError::Invariant(
                    PROVIDER,
                    "OpenAI cached text prompt tokens exceeded u32 range".to_string(),
                )
            })?);
        }
        if let Some(audio_tokens) = cached_details.audio_tokens {
            result.cached_audio_tokens = Some(u32::try_from(audio_tokens).map_err(|_| {
                LanguageModelError::Invariant(
                    PROVIDER,
                    "OpenAI cached audio prompt tokens exceeded u32 range".to_string(),
                )
            })?);
        }
    }

    Ok(result)
}

fn map_openai_completion_tokens_details(
    details: &chat_api::CompletionTokensDetails,
) -> LanguageModelResult<crate::ModelTokensDetails> {
    let mut result = crate::ModelTokensDetails::default();

    if let Some(text_tokens) = details.text_tokens {
        result.text_tokens = Some(u32::try_from(text_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI text completion tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    if let Some(audio_tokens) = details.audio_tokens {
        result.audio_tokens = Some(u32::try_from(audio_tokens).map_err(|_| {
            LanguageModelError::Invariant(
                PROVIDER,
                "OpenAI audio completion tokens exceeded u32 range".to_string(),
            )
        })?);
    }

    Ok(result)
}