1mod json_normalizer;
2
3use async_trait::async_trait;
4use derive_builder::Builder;
5use futures_util::StreamExt;
6use reqwest::Client;
7#[cfg(not(target_arch = "wasm32"))]
8use reqwest_eventsource::{Event, RequestBuilderExt};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::sync::mpsc;
13
14use crate::llm::capability::Capability;
15use crate::llm::completion::{
16 FinishReason, ModelSelector, ProviderType, RawCompletionEvent, RawCompletionEventStream,
17 RawCompletionRequest, RawCompletionResponse, RawInputContent, RawInputItem, RawOutputContent,
18 RawOutputItem, Role, ToolChoice as RawToolChoice, Usage as CompletionUsage,
19};
20use crate::llm::error::{Error, LlmResult, OpenAIConfigError};
21use crate::llm::model::Model;
22use crate::llm::provider::LlmProvider;
23use crate::llm::response::RawResponseFormat;
24use crate::llm::tools::{RawToolCall, RawToolDefinition};
25use crate::llm::transcription::{
26 AudioSource, AudioTranscriptionRequest, AudioTranscriptionResponse, TranscriptionLanguage,
27 TranscriptionPrompt,
28};
29use json_normalizer::normalize_openai_schema;
30use serde_json::{Value, json};
31
/// Connection settings for the OpenAI provider.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// Secret API key, sent as a `Bearer` token on every request.
    pub api_key: String,
    /// API root, e.g. `https://api.openai.com`; `/v1/...` paths are appended
    /// directly, so no trailing slash.
    pub base_url: String,
    /// Optional `OpenAI-Organization` header value.
    pub organization: Option<String>,
    /// Model used when a request does not select a specific one.
    pub default_model: String,
}
39
40impl OpenAIConfig {
41 pub fn new(
42 api_key: impl Into<String>,
43 default_model: impl Into<String>,
44 ) -> Result<Self, OpenAIConfigError> {
45 let api_key = api_key.into();
46 if api_key.is_empty() {
47 return Err(OpenAIConfigError::MissingApiKey);
48 }
49 Ok(Self {
50 api_key,
51 base_url: "https://api.openai.com".to_string(),
52 organization: None,
53 default_model: default_model.into(),
54 })
55 }
56
57 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
58 self.base_url = base_url.into();
59 self
60 }
61
62 pub fn with_organization(mut self, org: impl Into<String>) -> Self {
63 self.organization = Some(org.into());
64 self
65 }
66}
67
/// OpenAI API client implementing [`LlmProvider`].
pub struct OpenAI {
    /// Shared HTTP client used for all requests.
    client: Client,
    /// Connection settings (key, base URL, organization, default model).
    config: OpenAIConfig,
    /// Lazily-filled cache backing `available_models`.
    cached_models: Arc<RwLock<Option<Vec<Model>>>>,
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ChatMessage {
76 pub role: String,
77 pub content: Option<String>,
78 pub name: Option<String>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 pub tool_call_id: Option<String>,
81 pub tool_calls: Option<Vec<ChatToolCall>>,
82}
83
/// A tool call attached to a chat message.
///
/// NOTE: `rename_all = "camelCase"` is a no-op here — every field name is a
/// single word (`type` via the raw identifier), so wire names are unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ChatToolCall {
    /// Provider-assigned call id, echoed back in tool result messages.
    pub id: String,
    /// Call kind discriminator.
    pub r#type: String,
    pub function: ChatToolCallFunction,
}
91
/// Function name plus JSON-encoded arguments of a tool call.
/// (`rename_all = "camelCase"` is a no-op for these single-word fields.)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ChatToolCallFunction {
    pub name: String,
    /// Arguments as a JSON string, exactly as the API transmits them.
    pub arguments: String,
}
98
99#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
100#[serde(rename_all = "camelCase")]
101pub struct ChatRequest {
102 pub model: String,
103 pub messages: Vec<ChatMessage>,
104 pub temperature: Option<f32>,
105 pub top_p: Option<f32>,
106 pub max_tokens: Option<u32>,
107 pub stream: Option<bool>,
108 pub tools: Option<Vec<ToolDefinition>>,
109 pub tool_choice: Option<ToolChoice>,
110 pub response_format: Option<ResponseFormat>,
111}
112
/// Tool made available to the model in a chat request.
/// (`rename_all = "camelCase"` is a no-op for these single-word fields.)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolDefinition {
    /// Tool kind discriminator.
    pub r#type: String,
    pub function: ToolFunction,
}
119
/// Declaration of a callable function: name, description, and a JSON-Schema
/// object describing its parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolFunction {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema for the function's arguments.
    pub parameters: serde_json::Value,
}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129#[serde(rename_all = "camelCase")]
130pub struct ToolChoice {
131 pub r#type: String,
132 pub function: Option<ToolChoiceFunction>,
133}
134
/// Names the specific function the model is forced to call.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolChoiceFunction {
    pub name: String,
}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct ResponseFormat {
144 pub r#type: String,
145 pub json_schema: Option<JsonSchema>,
146}
147
/// Named JSON schema used with structured output.
/// (`rename_all = "camelCase"` is a no-op for these single-word fields.)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JsonSchema {
    pub name: String,
    /// Whether the provider must enforce the schema exactly.
    pub strict: Option<bool>,
    /// The JSON Schema document itself.
    pub schema: serde_json::Value,
}
155
/// Non-streaming response body from `/v1/chat/completions`.
/// Field names already match the API's snake_case wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    pub id: String,
    /// Object type tag (e.g. `"chat.completion"`).
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
165
/// One SSE chunk of a streaming chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChunk {
    pub id: String,
    /// Object type tag (e.g. `"chat.completion.chunk"`).
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct StreamChoice {
178 pub index: u32,
179 pub delta: StreamDelta,
180 pub finish_reason: Option<String>,
181}
182
183#[derive(Debug, Clone, Default, Serialize, Deserialize)]
184#[serde(rename_all = "camelCase")]
185pub struct StreamDelta {
186 pub role: Option<String>,
187 pub content: Option<String>,
188 pub tool_calls: Option<Vec<ChatToolCall>>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct Choice {
194 pub index: u32,
195 pub message: ChatMessage,
196 pub finish_reason: Option<String>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200#[serde(rename_all = "camelCase")]
201pub struct Usage {
202 pub prompt_tokens: u32,
203 pub completion_tokens: u32,
204 pub total_tokens: u32,
205}
206
/// Request body for `POST /v1/responses` (the Responses API).
/// Optional fields are omitted when `None`; `rename_all = "snake_case"` keeps
/// wire names identical to the Rust field names.
#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ResponsesRequest {
    pub model: String,
    /// Conversation turns and tool-call history.
    pub input: Vec<ResponseInputItem>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    /// Enables server-sent-event streaming when `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ResponseToolDefinition>>,
    /// Either a mode string (`"auto"`, `"required"`, `"none"`) or a
    /// `{"type": "function", "name": ...}` object — hence the loose `Value`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<Value>,
    /// Output text/JSON-schema configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<ResponseTextConfig>,
}
227
/// One input item of a Responses API request, discriminated by a `type` tag
/// (`"message"`, `"function_call"`, `"function_call_output"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ResponseInputItem {
    /// A conversation message with one or more content parts.
    Message {
        role: String,
        content: Vec<ResponseContent>,
    },
    /// A previously issued tool call being replayed into the context.
    FunctionCall {
        call_id: String,
        name: String,
        /// JSON-encoded arguments string.
        arguments: String,
    },
    /// The tool's result for a prior `FunctionCall`, matched via `call_id`.
    FunctionCallOutput {
        call_id: String,
        output: String,
    },
}
245
/// One content part of a Responses API message, tagged by `type`.
/// `OutputText` is used for assistant-authored text echoed back as context.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ResponseContent {
    InputText { text: String },
    OutputText { text: String },
    InputImage { image_url: String },
}
253
/// Tool declaration in the Responses API's flat format (name/description/
/// parameters at the top level, unlike the nested chat-completions format).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ResponseToolDefinition {
    /// Tool kind discriminator (serialized as `type`).
    pub r#type: String,
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON Schema for the tool's arguments.
    pub parameters: serde_json::Value,
    /// Whether the schema is enforced exactly.
    pub strict: bool,
}
264
/// Wrapper for the Responses API `text` parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ResponseTextConfig {
    pub format: ResponseTextFormat,
}
270
/// Output format selector for the Responses API, tagged by `type`:
/// plain text, a named strict JSON schema, or free-form JSON object mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ResponseTextFormat {
    Text,
    JsonSchema {
        name: String,
        schema: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        strict: Option<bool>,
    },
    JsonObject,
}
285
286#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
287#[serde(rename_all = "camelCase")]
288pub struct EvalCreateRequest {
289 pub model: String,
290 pub dataset_id: String,
291 pub subject: Option<String>,
292 pub metrics: Option<Vec<EvalMetric>>,
293}
294
/// Metric selector for an eval run, identified by its `type` string.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EvalMetric {
    pub r#type: String,
}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
302#[serde(rename_all = "camelCase")]
303pub struct Eval {
304 pub id: String,
305 pub object: String,
306 pub created: u64,
307 pub status: String,
308 pub model: String,
309 pub dataset_id: String,
310 pub metrics: Option<serde_json::Value>,
311}
312
/// Paginated list envelope returned by `GET /v1/evals`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalListResponse {
    pub data: Vec<Eval>,
    /// Cursor of the first item in this page.
    pub first_id: Option<String>,
    /// Cursor of the last item in this page.
    pub last_id: Option<String>,
    pub has_more: bool,
}
320
/// Payload of a `response.output_text.delta` SSE event (a text fragment).
#[derive(Debug, Clone, Deserialize)]
struct ResponseOutputTextDeltaEvent {
    delta: String,
}
325
/// Payload of `response.output_item.added` / `response.output_item.done`
/// SSE events.
#[derive(Debug, Clone, Deserialize)]
struct ResponseOutputItemEvent {
    item: ResponseOutputItem,
}
330
/// One output item from a streamed `response.output_item.*` event.
#[derive(Debug, Clone, Deserialize)]
struct ResponseOutputItem {
    /// Stream-local item id; used to key accumulated argument deltas.
    id: String,
    /// Function-call id; the stream handler falls back to `id` when absent.
    #[serde(default)]
    call_id: Option<String>,
    /// Item discriminator, e.g. `"function_call"`.
    #[serde(rename = "type")]
    item_type: String,
    /// Function name (function-call items only).
    #[serde(default)]
    name: Option<String>,
    /// JSON-encoded arguments (function-call items only).
    #[serde(default)]
    arguments: Option<String>,
}
343
/// Payload of a `response.function_call_arguments.delta` SSE event:
/// a fragment of the arguments string for the item identified by `item_id`.
#[derive(Debug, Clone, Deserialize)]
struct ResponseFunctionCallArgumentsDeltaEvent {
    item_id: String,
    delta: String,
}
349
/// Payload of a `response.completed` SSE event; carries the full final
/// response object, parsed later by `parse_responses_response`.
#[derive(Debug, Clone, Deserialize)]
struct ResponseCompletedEvent {
    response: Value,
}
354
355impl OpenAI {
356 pub fn new(config: OpenAIConfig) -> Self {
357 let client = Client::builder()
358 .build()
359 .expect("failed to build reqwest client");
360 Self {
361 client,
362 config,
363 cached_models: Arc::new(RwLock::new(None)),
364 }
365 }
366
367 pub fn auth_header(&self) -> String {
368 format!("Bearer {}", self.config.api_key)
369 }
370
371 pub async fn chat(&self, request: &ChatRequest) -> LlmResult<ChatResponse> {
372 let url = format!("{}/v1/chat/completions", self.config.base_url);
373 let auth = self.auth_header();
374
375 let mut req_builder = self
376 .client
377 .post(&url)
378 .header("Authorization", auth)
379 .header("Content-Type", "application/json");
380
381 if let Some(ref org) = self.config.organization {
382 req_builder = req_builder.header("OpenAI-Organization", org);
383 }
384
385 let response = req_builder.json(request).send().await?;
386
387 if !response.status().is_success() {
388 let status = response.status();
389 let body = response.text().await.unwrap_or_default();
390 return Err(Error::Provider {
391 provider: "openai".to_string(),
392 status: status.as_u16(),
393 message: body,
394 });
395 }
396
397 let body = response.text().await?;
398 let parsed: ChatResponse =
399 serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
400 Ok(parsed)
401 }
402
403 pub async fn responses(&self, request: &ResponsesRequest) -> LlmResult<Value> {
404 let url = format!("{}/v1/responses", self.config.base_url);
405 let auth = self.auth_header();
406
407 let response = self
408 .client
409 .post(&url)
410 .header("Authorization", auth)
411 .header("Content-Type", "application/json")
412 .json(request)
413 .send()
414 .await?;
415
416 if !response.status().is_success() {
417 let status = response.status();
418 let body = response.text().await.unwrap_or_default();
419 return Err(Error::Provider {
420 provider: "openai".to_string(),
421 status: status.as_u16(),
422 message: body,
423 });
424 }
425
426 let body = response.text().await?;
427 let parsed: Value = serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
428 Ok(parsed)
429 }
430
431 pub async fn create_eval(&self, request: &EvalCreateRequest) -> LlmResult<Eval> {
432 let url = format!("{}/v1/evals", self.config.base_url);
433 let auth = self.auth_header();
434
435 let response = self
436 .client
437 .post(&url)
438 .header("Authorization", auth)
439 .header("Content-Type", "application/json")
440 .json(request)
441 .send()
442 .await?;
443
444 if !response.status().is_success() {
445 let status = response.status();
446 let body = response.text().await.unwrap_or_default();
447 return Err(Error::Provider {
448 provider: "openai".to_string(),
449 status: status.as_u16(),
450 message: body,
451 });
452 }
453
454 let body = response.text().await?;
455 let parsed: Eval = serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
456 Ok(parsed)
457 }
458
459 pub async fn list_evals(&self) -> LlmResult<EvalListResponse> {
460 let url = format!("{}/v1/evals", self.config.base_url);
461 let auth = self.auth_header();
462
463 let response = self
464 .client
465 .get(&url)
466 .header("Authorization", auth)
467 .send()
468 .await?;
469
470 if !response.status().is_success() {
471 let status = response.status();
472 let body = response.text().await.unwrap_or_default();
473 return Err(Error::Provider {
474 provider: "openai".to_string(),
475 status: status.as_u16(),
476 message: body,
477 });
478 }
479
480 let body = response.text().await?;
481 let parsed: EvalListResponse =
482 serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
483 Ok(parsed)
484 }
485}
486
// On wasm the futures cannot be `Send`, hence the two cfg_attr variants.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl LlmProvider for OpenAI {
    fn provider_type(&self) -> ProviderType {
        ProviderType::OpenAI
    }

    fn provider_name(&self) -> &'static str {
        "openai"
    }

    /// Static capability list for this provider.
    fn capabilities(&self) -> &[Capability] {
        &[
            Capability::ChatCompletion,
            Capability::AudioTranscription,
            Capability::Evals,
        ]
    }

    /// Returns a hard-coded model list, memoized in `cached_models`.
    /// The write lock is held across check-and-fill so concurrent callers
    /// cannot both populate the cache.
    async fn available_models(&self) -> LlmResult<Vec<Model>> {
        let mut cache = self.cached_models.write().await;
        if let Some(ref models) = *cache {
            return Ok(models.clone());
        }

        let models = vec![
            Model::new("gpt-4o"),
            Model::new("gpt-4o-mini"),
            Model::new("gpt-4-turbo"),
            Model::new("gpt-4"),
            Model::new("gpt-3.5-turbo"),
        ];

        *cache = Some(models.clone());
        Ok(models)
    }

    /// Non-streaming completion via the Responses API (`/v1/responses`),
    /// not the legacy chat-completions endpoint.
    async fn chat_raw(&self, req: RawCompletionRequest) -> LlmResult<RawCompletionResponse> {
        let responses_req = build_responses_request(&self.config.default_model, req)?;
        let response = self.responses(&responses_req).await?;
        parse_responses_response(response)
    }

    /// Streaming completion over SSE. Spawns a task that translates OpenAI
    /// `response.*` events into `RawCompletionEvent`s pushed through an mpsc
    /// channel; any parse/transport error terminates the stream after being
    /// forwarded to the receiver.
    #[cfg(not(target_arch = "wasm32"))]
    async fn chat_raw_stream(
        &self,
        req: RawCompletionRequest,
    ) -> LlmResult<RawCompletionEventStream> {
        let responses_req = build_responses_request(&self.config.default_model, req)?;
        let url = format!("{}/v1/responses", self.config.base_url);
        let auth = self.auth_header();
        let mut req_builder = self
            .client
            .post(&url)
            .header("Authorization", auth)
            .header("Content-Type", "application/json");

        if let Some(ref org) = self.config.organization {
            req_builder = req_builder.header("OpenAI-Organization", org);
        }

        let event_source = req_builder
            .json(&responses_req)
            .eventsource()
            .map_err(|error| Error::from_eventsource_builder("openai", error))?;

        // Bounded channel: backpressure if the consumer falls behind.
        let (sender, receiver) = mpsc::channel(32);

        tokio::spawn(async move {
            let mut event_source = event_source;
            // In-flight function calls keyed by stream item id, storing
            // (call_id, name, accumulated arguments).
            let mut function_calls: std::collections::HashMap<String, (String, String, String)> =
                std::collections::HashMap::new();

            while let Some(event) = event_source.next().await {
                match event {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(message)) => match message.event.as_str() {
                        // Plain text fragment: forward as a TextDelta.
                        "response.output_text.delta" => {
                            let parsed: ResponseOutputTextDeltaEvent =
                                match serde_json::from_str(&message.data) {
                                    Ok(parsed) => parsed,
                                    Err(error) => {
                                        let _ = sender
                                            .send(Err(Error::parse(message.data, error)))
                                            .await;
                                        let _ = event_source.close();
                                        return;
                                    }
                                };
                            if sender
                                .send(Ok(RawCompletionEvent::TextDelta { text: parsed.delta }))
                                .await
                                .is_err()
                            {
                                // Receiver dropped: stop streaming.
                                let _ = event_source.close();
                                return;
                            }
                        }
                        // Function-call items: track on `added`, emit on `done`.
                        "response.output_item.added" | "response.output_item.done" => {
                            let parsed: ResponseOutputItemEvent =
                                match serde_json::from_str(&message.data) {
                                    Ok(parsed) => parsed,
                                    Err(error) => {
                                        let _ = sender
                                            .send(Err(Error::parse(message.data, error)))
                                            .await;
                                        let _ = event_source.close();
                                        return;
                                    }
                                };
                            if parsed.item.item_type == "function_call" {
                                let item_id = parsed.item.id;
                                // `call_id` may be absent; fall back to the item id.
                                let call_id = parsed
                                    .item
                                    .call_id
                                    .clone()
                                    .unwrap_or_else(|| item_id.clone());
                                let name = parsed.item.name.unwrap_or_default();
                                let arguments = parsed.item.arguments.unwrap_or_default();
                                function_calls.insert(
                                    item_id.clone(),
                                    (call_id.clone(), name.clone(), arguments.clone()),
                                );
                                if message.event == "response.output_item.done" {
                                    // NOTE(review): the emitted call uses the
                                    // `done` event's own arguments, not the
                                    // accumulated deltas in `function_calls` —
                                    // confirm the done event always carries the
                                    // full arguments string.
                                    match parse_function_call(&call_id, &name, &arguments) {
                                        Ok(call) => {
                                            if sender
                                                .send(Ok(RawCompletionEvent::ToolCall { call }))
                                                .await
                                                .is_err()
                                            {
                                                let _ = event_source.close();
                                                return;
                                            }
                                        }
                                        Err(error) => {
                                            let _ = sender.send(Err(error)).await;
                                            let _ = event_source.close();
                                            return;
                                        }
                                    }
                                }
                            }
                        }
                        // Argument fragment: append to the tracked call's buffer.
                        "response.function_call_arguments.delta" => {
                            let parsed: ResponseFunctionCallArgumentsDeltaEvent =
                                match serde_json::from_str(&message.data) {
                                    Ok(parsed) => parsed,
                                    Err(error) => {
                                        let _ = sender
                                            .send(Err(Error::parse(message.data, error)))
                                            .await;
                                        let _ = event_source.close();
                                        return;
                                    }
                                };
                            if let Some((_, _, arguments)) = function_calls.get_mut(&parsed.item_id)
                            {
                                arguments.push_str(&parsed.delta);
                            }
                        }
                        // Terminal event: parse the embedded full response and
                        // emit it as the Done event, then stop reading.
                        "response.completed" => {
                            let parsed: ResponseCompletedEvent =
                                match serde_json::from_str(&message.data) {
                                    Ok(parsed) => parsed,
                                    Err(error) => {
                                        let _ = sender
                                            .send(Err(Error::parse(message.data, error)))
                                            .await;
                                        let _ = event_source.close();
                                        return;
                                    }
                                };
                            let final_response = match parse_responses_response(parsed.response) {
                                Ok(response) => response,
                                Err(error) => {
                                    let _ = sender.send(Err(error)).await;
                                    let _ = event_source.close();
                                    return;
                                }
                            };
                            let _ = sender
                                .send(Ok(RawCompletionEvent::Done(final_response)))
                                .await;
                            break;
                        }
                        // Server-reported failure: surface the raw payload.
                        "response.failed" => {
                            let _ = sender
                                .send(Err(Error::InvalidResponse {
                                    reason: format!("OpenAI stream failed: {}", message.data),
                                }))
                                .await;
                            let _ = event_source.close();
                            return;
                        }
                        // Unknown event types are ignored on purpose.
                        _ => {}
                    },
                    Err(error) => {
                        let _ = sender
                            .send(Err(Error::from_eventsource("openai", error)))
                            .await;
                        let _ = event_source.close();
                        return;
                    }
                }
            }

            let _ = event_source.close();
        });

        Ok(RawCompletionEventStream::new(receiver))
    }

    /// Transcribes audio via `POST /v1/audio/transcriptions` (multipart).
    ///
    /// Model selection: anything non-specific falls back to `whisper-1`.
    /// URL sources are rejected; path sources are read from disk with the
    /// MIME type inferred from the file extension.
    async fn transcribe(
        &self,
        req: AudioTranscriptionRequest,
    ) -> LlmResult<AudioTranscriptionResponse> {
        let url = format!("{}/v1/audio/transcriptions", self.config.base_url);

        let model = match &req.model {
            ModelSelector::Any | ModelSelector::Provider(_) => "whisper-1".to_string(),
            ModelSelector::Specific { model, .. } => model.clone(),
        };

        let (audio_data, file_name, mime_type) = match &req.audio {
            // NOTE(review): raw byte input is assumed to be WAV — confirm
            // callers only pass WAV data through `AudioSource::Data`.
            AudioSource::Data(data) => (
                data.clone(),
                "audio.wav".to_string(),
                "audio/wav".to_string(),
            ),
            AudioSource::Url(_) => {
                return Err(Error::InvalidRequest {
                    reason: "URL audio not supported yet".to_string(),
                });
            }
            AudioSource::Path(path) => (
                std::fs::read(path).map_err(|e| Error::InvalidRequest {
                    reason: e.to_string(),
                })?,
                path.file_name()
                    .and_then(|name| name.to_str())
                    .unwrap_or("audio")
                    .to_string(),
                match path.extension().and_then(|ext| ext.to_str()) {
                    Some("ogg") => "audio/ogg",
                    Some("mp3") => "audio/mpeg",
                    Some("m4a") => "audio/mp4",
                    Some("wav") => "audio/wav",
                    Some("webm") => "audio/webm",
                    Some("flac") => "audio/flac",
                    _ => "application/octet-stream",
                }
                .to_string(),
            ),
        };

        let part = reqwest::multipart::Part::bytes(audio_data)
            .file_name(file_name)
            .mime_str(&mime_type)
            .map_err(|e| Error::InvalidRequest {
                reason: e.to_string(),
            })?;

        let mut form = reqwest::multipart::Form::new()
            .text("model", model.clone())
            .part("file", part);

        // Optional form fields are only attached when explicitly requested.
        if let TranscriptionLanguage::Explicit { language } = req.language {
            form = form.text("language", language);
        }

        if let TranscriptionPrompt::Hint { text } = req.prompt {
            form = form.text("prompt", text);
        }

        if let Some(response_format) = req.response_format.as_openai_str() {
            form = form.text("response_format", response_format.to_string());
        }

        let response = self
            .client
            .post(&url)
            .header("Authorization", self.auth_header())
            .multipart(form)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(Error::Provider {
                provider: "openai".to_string(),
                status: status.as_u16(),
                message: body,
            });
        }

        let body = response.text().await?;
        // Minimal local struct: only the `text` field is consumed.
        #[derive(Deserialize)]
        struct TranscriptionResponse {
            text: String,
        }
        let parsed: TranscriptionResponse =
            serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;

        Ok(AudioTranscriptionResponse {
            provider: ProviderType::OpenAI,
            model,
            text: parsed.text,
        })
    }
}
799
800fn build_responses_request(
801 default_model: &str,
802 req: RawCompletionRequest,
803) -> LlmResult<ResponsesRequest> {
804 let model = match req.model {
805 ModelSelector::Any => default_model.to_string(),
806 ModelSelector::Provider(_) => default_model.to_string(),
807 ModelSelector::Specific { model, .. } => model,
808 };
809
810 let input = req
811 .input
812 .into_iter()
813 .map(|item| -> LlmResult<ResponseInputItem> {
814 Ok(match item {
815 RawInputItem::Message { role, content } => ResponseInputItem::Message {
816 role: match role {
817 Role::System => "system".to_string(),
818 Role::User => "user".to_string(),
819 Role::Assistant => "assistant".to_string(),
820 },
821 content: content
822 .into_iter()
823 .map(|content| match content {
824 RawInputContent::Text { text } => match role {
825 Role::Assistant => ResponseContent::OutputText { text },
826 Role::System | Role::User => ResponseContent::InputText { text },
827 },
828 RawInputContent::ImageUrl { url } => {
829 ResponseContent::InputImage { image_url: url }
830 }
831 })
832 .collect(),
833 },
834 RawInputItem::ToolCall { call } => ResponseInputItem::FunctionCall {
835 call_id: call.id,
836 name: call.name,
837 arguments: serde_json::to_string(&call.arguments)
838 .map_err(|error| Error::parse("tool call arguments", error))?,
839 },
840 RawInputItem::ToolResult {
841 tool_use_id,
842 content,
843 } => ResponseInputItem::FunctionCallOutput {
844 call_id: tool_use_id,
845 output: content,
846 },
847 })
848 })
849 .collect::<LlmResult<Vec<_>>>()?;
850
851 Ok(ResponsesRequest {
852 model,
853 input,
854 temperature: req.temperature.as_option(),
855 top_p: req.top_p.as_option(),
856 max_output_tokens: req.token_limit.as_option(),
857 stream: Some(req.response_mode.is_streaming()),
858 tools: req.tools.map(map_response_tools),
859 tool_choice: map_responses_tool_choice(req.tool_choice),
860 text: req.response_format.map(map_response_text_config),
861 })
862}
863
864fn map_response_tools(tools: Vec<RawToolDefinition>) -> Vec<ResponseToolDefinition> {
865 tools
866 .into_iter()
867 .map(|tool| ResponseToolDefinition {
868 r#type: tool.kind,
869 name: tool.function.name,
870 description: tool.function.description,
871 parameters: normalize_openai_schema(tool.function.parameters),
872 strict: true,
873 })
874 .collect()
875}
876
877fn map_responses_tool_choice(choice: RawToolChoice) -> Option<Value> {
878 match choice {
879 RawToolChoice::ProviderDefault => None,
880 RawToolChoice::Auto => Some(json!("auto")),
881 RawToolChoice::Required => Some(json!("required")),
882 RawToolChoice::Specific { name } => Some(json!({
883 "type": "function",
884 "name": name,
885 })),
886 RawToolChoice::None => Some(json!("none")),
887 }
888}
889
890fn map_response_text_config(format: RawResponseFormat) -> ResponseTextConfig {
891 ResponseTextConfig {
892 format: match format.json_schema {
893 Some(schema) => ResponseTextFormat::JsonSchema {
894 name: schema.name,
895 schema: normalize_openai_schema(schema.schema),
896 description: None,
897 strict: schema.strict,
898 },
899 None if format.r#type == "json_object" => ResponseTextFormat::JsonObject,
900 None => ResponseTextFormat::Text,
901 },
902 }
903}
904
905fn parse_responses_response(value: Value) -> LlmResult<RawCompletionResponse> {
906 let model = value
907 .get("model")
908 .and_then(Value::as_str)
909 .ok_or(Error::InvalidResponse {
910 reason: "OpenAI responses payload missing model".to_string(),
911 })?
912 .to_string();
913
914 let output_values =
915 value
916 .get("output")
917 .and_then(Value::as_array)
918 .ok_or(Error::InvalidResponse {
919 reason: "OpenAI responses payload missing output".to_string(),
920 })?;
921
922 let mut output = Vec::new();
923 let mut saw_tool_call = false;
924
925 for item in output_values {
926 match item.get("type").and_then(Value::as_str) {
927 Some("message") => {
928 let mut content = Vec::new();
929 if let Some(parts) = item.get("content").and_then(Value::as_array) {
930 for part in parts {
931 match part.get("type").and_then(Value::as_str) {
932 Some("output_text") => {
933 if let Some(text) = part.get("text").and_then(Value::as_str) {
934 content.push(RawOutputContent::Text {
935 text: text.to_string(),
936 });
937 }
938 }
939 Some("output_json") => {
940 if let Some(json) = part.get("json") {
941 content.push(RawOutputContent::Json {
942 value: json.clone(),
943 });
944 }
945 }
946 _ => {}
947 }
948 }
949 }
950 if !content.is_empty() {
951 output.push(RawOutputItem::Message {
952 role: Role::Assistant,
953 content,
954 });
955 }
956 }
957 Some("function_call") => {
958 let call_id = item
959 .get("call_id")
960 .and_then(Value::as_str)
961 .or_else(|| item.get("id").and_then(Value::as_str))
962 .ok_or(Error::InvalidResponse {
963 reason: "OpenAI function_call missing call id".to_string(),
964 })?;
965 let name =
966 item.get("name")
967 .and_then(Value::as_str)
968 .ok_or(Error::InvalidResponse {
969 reason: "OpenAI function_call missing name".to_string(),
970 })?;
971 let arguments = item.get("arguments").and_then(Value::as_str).ok_or(
972 Error::InvalidResponse {
973 reason: "OpenAI function_call missing arguments".to_string(),
974 },
975 )?;
976 output.push(RawOutputItem::ToolCall {
977 call: parse_function_call(call_id, name, arguments)?,
978 });
979 saw_tool_call = true;
980 }
981 Some("reasoning") => {
982 let summary = item
983 .get("summary")
984 .and_then(Value::as_array)
985 .into_iter()
986 .flatten()
987 .filter_map(|part| part.get("text").and_then(Value::as_str))
988 .collect::<Vec<_>>()
989 .join("\n");
990 if !summary.is_empty() {
991 output.push(RawOutputItem::Reasoning { text: summary });
992 }
993 }
994 _ => {}
995 }
996 }
997
998 let usage = value.get("usage").cloned().unwrap_or_else(|| json!({}));
999 let prompt_tokens = usage
1000 .get("input_tokens")
1001 .and_then(Value::as_u64)
1002 .unwrap_or(0) as u32;
1003 let completion_tokens = usage
1004 .get("output_tokens")
1005 .and_then(Value::as_u64)
1006 .unwrap_or(0) as u32;
1007 let total_tokens = usage
1008 .get("total_tokens")
1009 .and_then(Value::as_u64)
1010 .unwrap_or((prompt_tokens + completion_tokens) as u64) as u32;
1011
1012 Ok(RawCompletionResponse {
1013 provider: ProviderType::OpenAI,
1014 model,
1015 output,
1016 usage: CompletionUsage {
1017 prompt_tokens,
1018 completion_tokens,
1019 total_tokens,
1020 },
1021 finish_reason: if saw_tool_call {
1022 FinishReason::ToolCalls
1023 } else {
1024 FinishReason::Stop
1025 },
1026 })
1027}
1028
1029fn parse_function_call(call_id: &str, name: &str, arguments: &str) -> LlmResult<RawToolCall> {
1030 Ok(RawToolCall {
1031 id: call_id.to_string(),
1032 name: name.to_string(),
1033 arguments: serde_json::from_str(arguments)
1034 .map_err(|e| Error::parse("tool arguments", e))?,
1035 })
1036}