Skip to main content

agents/llm/provider/
openai.rs

1mod json_normalizer;
2
3use async_trait::async_trait;
4use derive_builder::Builder;
5use futures_util::StreamExt;
6use reqwest::Client;
7#[cfg(not(target_arch = "wasm32"))]
8use reqwest_eventsource::{Event, RequestBuilderExt};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::sync::mpsc;
13
14use crate::llm::capability::Capability;
15use crate::llm::completion::{
16    FinishReason, ModelSelector, ProviderType, RawCompletionEvent, RawCompletionEventStream,
17    RawCompletionRequest, RawCompletionResponse, RawInputContent, RawInputItem, RawOutputContent,
18    RawOutputItem, Role, ToolChoice as RawToolChoice, Usage as CompletionUsage,
19};
20use crate::llm::error::{Error, LlmResult, OpenAIConfigError};
21use crate::llm::model::Model;
22use crate::llm::provider::LlmProvider;
23use crate::llm::response::RawResponseFormat;
24use crate::llm::tools::{RawToolCall, RawToolDefinition};
25use crate::llm::transcription::{
26    AudioSource, AudioTranscriptionRequest, AudioTranscriptionResponse, TranscriptionLanguage,
27    TranscriptionPrompt,
28};
29use json_normalizer::normalize_openai_schema;
30use serde_json::{Value, json};
31
32#[derive(Debug, Clone)]
33pub struct OpenAIConfig {
34    pub api_key: String,
35    pub base_url: String,
36    pub organization: Option<String>,
37    pub default_model: String,
38}
39
40impl OpenAIConfig {
41    pub fn new(
42        api_key: impl Into<String>,
43        default_model: impl Into<String>,
44    ) -> Result<Self, OpenAIConfigError> {
45        let api_key = api_key.into();
46        if api_key.is_empty() {
47            return Err(OpenAIConfigError::MissingApiKey);
48        }
49        Ok(Self {
50            api_key,
51            base_url: "https://api.openai.com".to_string(),
52            organization: None,
53            default_model: default_model.into(),
54        })
55    }
56
57    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
58        self.base_url = base_url.into();
59        self
60    }
61
62    pub fn with_organization(mut self, org: impl Into<String>) -> Self {
63        self.organization = Some(org.into());
64        self
65    }
66}
67
68pub struct OpenAI {
69    client: Client,
70    config: OpenAIConfig,
71    cached_models: Arc<RwLock<Option<Vec<Model>>>>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ChatMessage {
76    pub role: String,
77    pub content: Option<String>,
78    pub name: Option<String>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub tool_call_id: Option<String>,
81    pub tool_calls: Option<Vec<ChatToolCall>>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85#[serde(rename_all = "camelCase")]
86pub struct ChatToolCall {
87    pub id: String,
88    pub r#type: String,
89    pub function: ChatToolCallFunction,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93#[serde(rename_all = "camelCase")]
94pub struct ChatToolCallFunction {
95    pub name: String,
96    pub arguments: String,
97}
98
99#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
100#[serde(rename_all = "camelCase")]
101pub struct ChatRequest {
102    pub model: String,
103    pub messages: Vec<ChatMessage>,
104    pub temperature: Option<f32>,
105    pub top_p: Option<f32>,
106    pub max_tokens: Option<u32>,
107    pub stream: Option<bool>,
108    pub tools: Option<Vec<ToolDefinition>>,
109    pub tool_choice: Option<ToolChoice>,
110    pub response_format: Option<ResponseFormat>,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct ToolDefinition {
116    pub r#type: String,
117    pub function: ToolFunction,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct ToolFunction {
123    pub name: String,
124    pub description: Option<String>,
125    pub parameters: serde_json::Value,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129#[serde(rename_all = "camelCase")]
130pub struct ToolChoice {
131    pub r#type: String,
132    pub function: Option<ToolChoiceFunction>,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
136#[serde(rename_all = "camelCase")]
137pub struct ToolChoiceFunction {
138    pub name: String,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct ResponseFormat {
144    pub r#type: String,
145    pub json_schema: Option<JsonSchema>,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
149#[serde(rename_all = "camelCase")]
150pub struct JsonSchema {
151    pub name: String,
152    pub strict: Option<bool>,
153    pub schema: serde_json::Value,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct ChatResponse {
158    pub id: String,
159    pub object: String,
160    pub created: u64,
161    pub model: String,
162    pub choices: Vec<Choice>,
163    pub usage: Usage,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct ChatStreamChunk {
168    pub id: String,
169    pub object: String,
170    pub created: u64,
171    pub model: String,
172    pub choices: Vec<StreamChoice>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct StreamChoice {
178    pub index: u32,
179    pub delta: StreamDelta,
180    pub finish_reason: Option<String>,
181}
182
183#[derive(Debug, Clone, Default, Serialize, Deserialize)]
184#[serde(rename_all = "camelCase")]
185pub struct StreamDelta {
186    pub role: Option<String>,
187    pub content: Option<String>,
188    pub tool_calls: Option<Vec<ChatToolCall>>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct Choice {
194    pub index: u32,
195    pub message: ChatMessage,
196    pub finish_reason: Option<String>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200#[serde(rename_all = "camelCase")]
201pub struct Usage {
202    pub prompt_tokens: u32,
203    pub completion_tokens: u32,
204    pub total_tokens: u32,
205}
206
207#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
208#[serde(rename_all = "snake_case")]
209pub struct ResponsesRequest {
210    pub model: String,
211    pub input: Vec<ResponseInputItem>,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    pub temperature: Option<f32>,
214    #[serde(skip_serializing_if = "Option::is_none")]
215    pub top_p: Option<f32>,
216    #[serde(skip_serializing_if = "Option::is_none")]
217    pub max_output_tokens: Option<u32>,
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub stream: Option<bool>,
220    #[serde(skip_serializing_if = "Option::is_none")]
221    pub tools: Option<Vec<ResponseToolDefinition>>,
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub tool_choice: Option<Value>,
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub text: Option<ResponseTextConfig>,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229#[serde(rename_all = "snake_case", tag = "type")]
230pub enum ResponseInputItem {
231    Message {
232        role: String,
233        content: Vec<ResponseContent>,
234    },
235    FunctionCall {
236        call_id: String,
237        name: String,
238        arguments: String,
239    },
240    FunctionCallOutput {
241        call_id: String,
242        output: String,
243    },
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
247#[serde(rename_all = "snake_case", tag = "type")]
248pub enum ResponseContent {
249    InputText { text: String },
250    OutputText { text: String },
251    InputImage { image_url: String },
252}
253
254#[derive(Debug, Clone, Serialize, Deserialize)]
255#[serde(rename_all = "snake_case")]
256pub struct ResponseToolDefinition {
257    pub r#type: String,
258    pub name: String,
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub description: Option<String>,
261    pub parameters: serde_json::Value,
262    pub strict: bool,
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266#[serde(rename_all = "snake_case")]
267pub struct ResponseTextConfig {
268    pub format: ResponseTextFormat,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272#[serde(rename_all = "snake_case", tag = "type")]
273pub enum ResponseTextFormat {
274    Text,
275    JsonSchema {
276        name: String,
277        schema: serde_json::Value,
278        #[serde(skip_serializing_if = "Option::is_none")]
279        description: Option<String>,
280        #[serde(skip_serializing_if = "Option::is_none")]
281        strict: Option<bool>,
282    },
283    JsonObject,
284}
285
286#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
287#[serde(rename_all = "camelCase")]
288pub struct EvalCreateRequest {
289    pub model: String,
290    pub dataset_id: String,
291    pub subject: Option<String>,
292    pub metrics: Option<Vec<EvalMetric>>,
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
296#[serde(rename_all = "camelCase")]
297pub struct EvalMetric {
298    pub r#type: String,
299}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
302#[serde(rename_all = "camelCase")]
303pub struct Eval {
304    pub id: String,
305    pub object: String,
306    pub created: u64,
307    pub status: String,
308    pub model: String,
309    pub dataset_id: String,
310    pub metrics: Option<serde_json::Value>,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct EvalListResponse {
315    pub data: Vec<Eval>,
316    pub first_id: Option<String>,
317    pub last_id: Option<String>,
318    pub has_more: bool,
319}
320
321#[derive(Debug, Clone, Deserialize)]
322struct ResponseOutputTextDeltaEvent {
323    delta: String,
324}
325
326#[derive(Debug, Clone, Deserialize)]
327struct ResponseOutputItemEvent {
328    item: ResponseOutputItem,
329}
330
331#[derive(Debug, Clone, Deserialize)]
332struct ResponseOutputItem {
333    id: String,
334    #[serde(default)]
335    call_id: Option<String>,
336    #[serde(rename = "type")]
337    item_type: String,
338    #[serde(default)]
339    name: Option<String>,
340    #[serde(default)]
341    arguments: Option<String>,
342}
343
344#[derive(Debug, Clone, Deserialize)]
345struct ResponseFunctionCallArgumentsDeltaEvent {
346    item_id: String,
347    delta: String,
348}
349
350#[derive(Debug, Clone, Deserialize)]
351struct ResponseCompletedEvent {
352    response: Value,
353}
354
355impl OpenAI {
356    pub fn new(config: OpenAIConfig) -> Self {
357        let client = Client::builder()
358            .build()
359            .expect("failed to build reqwest client");
360        Self {
361            client,
362            config,
363            cached_models: Arc::new(RwLock::new(None)),
364        }
365    }
366
367    pub fn auth_header(&self) -> String {
368        format!("Bearer {}", self.config.api_key)
369    }
370
371    pub async fn chat(&self, request: &ChatRequest) -> LlmResult<ChatResponse> {
372        let url = format!("{}/v1/chat/completions", self.config.base_url);
373        let auth = self.auth_header();
374
375        let mut req_builder = self
376            .client
377            .post(&url)
378            .header("Authorization", auth)
379            .header("Content-Type", "application/json");
380
381        if let Some(ref org) = self.config.organization {
382            req_builder = req_builder.header("OpenAI-Organization", org);
383        }
384
385        let response = req_builder.json(request).send().await?;
386
387        if !response.status().is_success() {
388            let status = response.status();
389            let body = response.text().await.unwrap_or_default();
390            return Err(Error::Provider {
391                provider: "openai".to_string(),
392                status: status.as_u16(),
393                message: body,
394            });
395        }
396
397        let body = response.text().await?;
398        let parsed: ChatResponse =
399            serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
400        Ok(parsed)
401    }
402
403    pub async fn responses(&self, request: &ResponsesRequest) -> LlmResult<Value> {
404        let url = format!("{}/v1/responses", self.config.base_url);
405        let auth = self.auth_header();
406
407        let response = self
408            .client
409            .post(&url)
410            .header("Authorization", auth)
411            .header("Content-Type", "application/json")
412            .json(request)
413            .send()
414            .await?;
415
416        if !response.status().is_success() {
417            let status = response.status();
418            let body = response.text().await.unwrap_or_default();
419            return Err(Error::Provider {
420                provider: "openai".to_string(),
421                status: status.as_u16(),
422                message: body,
423            });
424        }
425
426        let body = response.text().await?;
427        let parsed: Value = serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
428        Ok(parsed)
429    }
430
431    pub async fn create_eval(&self, request: &EvalCreateRequest) -> LlmResult<Eval> {
432        let url = format!("{}/v1/evals", self.config.base_url);
433        let auth = self.auth_header();
434
435        let response = self
436            .client
437            .post(&url)
438            .header("Authorization", auth)
439            .header("Content-Type", "application/json")
440            .json(request)
441            .send()
442            .await?;
443
444        if !response.status().is_success() {
445            let status = response.status();
446            let body = response.text().await.unwrap_or_default();
447            return Err(Error::Provider {
448                provider: "openai".to_string(),
449                status: status.as_u16(),
450                message: body,
451            });
452        }
453
454        let body = response.text().await?;
455        let parsed: Eval = serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
456        Ok(parsed)
457    }
458
459    pub async fn list_evals(&self) -> LlmResult<EvalListResponse> {
460        let url = format!("{}/v1/evals", self.config.base_url);
461        let auth = self.auth_header();
462
463        let response = self
464            .client
465            .get(&url)
466            .header("Authorization", auth)
467            .send()
468            .await?;
469
470        if !response.status().is_success() {
471            let status = response.status();
472            let body = response.text().await.unwrap_or_default();
473            return Err(Error::Provider {
474                provider: "openai".to_string(),
475                status: status.as_u16(),
476                message: body,
477            });
478        }
479
480        let body = response.text().await?;
481        let parsed: EvalListResponse =
482            serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
483        Ok(parsed)
484    }
485}
486
487#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
488#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
489impl LlmProvider for OpenAI {
490    fn provider_type(&self) -> ProviderType {
491        ProviderType::OpenAI
492    }
493
494    fn provider_name(&self) -> &'static str {
495        "openai"
496    }
497
498    fn capabilities(&self) -> &[Capability] {
499        &[
500            Capability::ChatCompletion,
501            Capability::AudioTranscription,
502            Capability::Evals,
503        ]
504    }
505
506    async fn available_models(&self) -> LlmResult<Vec<Model>> {
507        let mut cache = self.cached_models.write().await;
508        if let Some(ref models) = *cache {
509            return Ok(models.clone());
510        }
511
512        let models = vec![
513            Model::new("gpt-4o"),
514            Model::new("gpt-4o-mini"),
515            Model::new("gpt-4-turbo"),
516            Model::new("gpt-4"),
517            Model::new("gpt-3.5-turbo"),
518        ];
519
520        *cache = Some(models.clone());
521        Ok(models)
522    }
523
524    async fn chat_raw(&self, req: RawCompletionRequest) -> LlmResult<RawCompletionResponse> {
525        let responses_req = build_responses_request(&self.config.default_model, req)?;
526        let response = self.responses(&responses_req).await?;
527        parse_responses_response(response)
528    }
529
530    #[cfg(not(target_arch = "wasm32"))]
531    async fn chat_raw_stream(
532        &self,
533        req: RawCompletionRequest,
534    ) -> LlmResult<RawCompletionEventStream> {
535        let responses_req = build_responses_request(&self.config.default_model, req)?;
536        let url = format!("{}/v1/responses", self.config.base_url);
537        let auth = self.auth_header();
538        let mut req_builder = self
539            .client
540            .post(&url)
541            .header("Authorization", auth)
542            .header("Content-Type", "application/json");
543
544        if let Some(ref org) = self.config.organization {
545            req_builder = req_builder.header("OpenAI-Organization", org);
546        }
547
548        let event_source = req_builder
549            .json(&responses_req)
550            .eventsource()
551            .map_err(|error| Error::from_eventsource_builder("openai", error))?;
552
553        let (sender, receiver) = mpsc::channel(32);
554
555        tokio::spawn(async move {
556            let mut event_source = event_source;
557            let mut function_calls: std::collections::HashMap<String, (String, String, String)> =
558                std::collections::HashMap::new();
559
560            while let Some(event) = event_source.next().await {
561                match event {
562                    Ok(Event::Open) => {}
563                    Ok(Event::Message(message)) => match message.event.as_str() {
564                        "response.output_text.delta" => {
565                            let parsed: ResponseOutputTextDeltaEvent =
566                                match serde_json::from_str(&message.data) {
567                                    Ok(parsed) => parsed,
568                                    Err(error) => {
569                                        let _ = sender
570                                            .send(Err(Error::parse(message.data, error)))
571                                            .await;
572                                        let _ = event_source.close();
573                                        return;
574                                    }
575                                };
576                            if sender
577                                .send(Ok(RawCompletionEvent::TextDelta { text: parsed.delta }))
578                                .await
579                                .is_err()
580                            {
581                                let _ = event_source.close();
582                                return;
583                            }
584                        }
585                        "response.output_item.added" | "response.output_item.done" => {
586                            let parsed: ResponseOutputItemEvent =
587                                match serde_json::from_str(&message.data) {
588                                    Ok(parsed) => parsed,
589                                    Err(error) => {
590                                        let _ = sender
591                                            .send(Err(Error::parse(message.data, error)))
592                                            .await;
593                                        let _ = event_source.close();
594                                        return;
595                                    }
596                                };
597                            if parsed.item.item_type == "function_call" {
598                                let item_id = parsed.item.id;
599                                let call_id = parsed
600                                    .item
601                                    .call_id
602                                    .clone()
603                                    .unwrap_or_else(|| item_id.clone());
604                                let name = parsed.item.name.unwrap_or_default();
605                                let arguments = parsed.item.arguments.unwrap_or_default();
606                                function_calls.insert(
607                                    item_id.clone(),
608                                    (call_id.clone(), name.clone(), arguments.clone()),
609                                );
610                                if message.event == "response.output_item.done" {
611                                    match parse_function_call(&call_id, &name, &arguments) {
612                                        Ok(call) => {
613                                            if sender
614                                                .send(Ok(RawCompletionEvent::ToolCall { call }))
615                                                .await
616                                                .is_err()
617                                            {
618                                                let _ = event_source.close();
619                                                return;
620                                            }
621                                        }
622                                        Err(error) => {
623                                            let _ = sender.send(Err(error)).await;
624                                            let _ = event_source.close();
625                                            return;
626                                        }
627                                    }
628                                }
629                            }
630                        }
631                        "response.function_call_arguments.delta" => {
632                            let parsed: ResponseFunctionCallArgumentsDeltaEvent =
633                                match serde_json::from_str(&message.data) {
634                                    Ok(parsed) => parsed,
635                                    Err(error) => {
636                                        let _ = sender
637                                            .send(Err(Error::parse(message.data, error)))
638                                            .await;
639                                        let _ = event_source.close();
640                                        return;
641                                    }
642                                };
643                            if let Some((_, _, arguments)) = function_calls.get_mut(&parsed.item_id)
644                            {
645                                arguments.push_str(&parsed.delta);
646                            }
647                        }
648                        "response.completed" => {
649                            let parsed: ResponseCompletedEvent =
650                                match serde_json::from_str(&message.data) {
651                                    Ok(parsed) => parsed,
652                                    Err(error) => {
653                                        let _ = sender
654                                            .send(Err(Error::parse(message.data, error)))
655                                            .await;
656                                        let _ = event_source.close();
657                                        return;
658                                    }
659                                };
660                            let final_response = match parse_responses_response(parsed.response) {
661                                Ok(response) => response,
662                                Err(error) => {
663                                    let _ = sender.send(Err(error)).await;
664                                    let _ = event_source.close();
665                                    return;
666                                }
667                            };
668                            let _ = sender
669                                .send(Ok(RawCompletionEvent::Done(final_response)))
670                                .await;
671                            break;
672                        }
673                        "response.failed" => {
674                            let _ = sender
675                                .send(Err(Error::InvalidResponse {
676                                    reason: format!("OpenAI stream failed: {}", message.data),
677                                }))
678                                .await;
679                            let _ = event_source.close();
680                            return;
681                        }
682                        _ => {}
683                    },
684                    Err(error) => {
685                        let _ = sender
686                            .send(Err(Error::from_eventsource("openai", error)))
687                            .await;
688                        let _ = event_source.close();
689                        return;
690                    }
691                }
692            }
693
694            let _ = event_source.close();
695        });
696
697        Ok(RawCompletionEventStream::new(receiver))
698    }
699
700    async fn transcribe(
701        &self,
702        req: AudioTranscriptionRequest,
703    ) -> LlmResult<AudioTranscriptionResponse> {
704        let url = format!("{}/v1/audio/transcriptions", self.config.base_url);
705
706        let model = match &req.model {
707            ModelSelector::Any | ModelSelector::Provider(_) => "whisper-1".to_string(),
708            ModelSelector::Specific { model, .. } => model.clone(),
709        };
710
711        let (audio_data, file_name, mime_type) = match &req.audio {
712            AudioSource::Data(data) => (
713                data.clone(),
714                "audio.wav".to_string(),
715                "audio/wav".to_string(),
716            ),
717            AudioSource::Url(_) => {
718                return Err(Error::InvalidRequest {
719                    reason: "URL audio not supported yet".to_string(),
720                });
721            }
722            AudioSource::Path(path) => (
723                std::fs::read(path).map_err(|e| Error::InvalidRequest {
724                    reason: e.to_string(),
725                })?,
726                path.file_name()
727                    .and_then(|name| name.to_str())
728                    .unwrap_or("audio")
729                    .to_string(),
730                match path.extension().and_then(|ext| ext.to_str()) {
731                    Some("ogg") => "audio/ogg",
732                    Some("mp3") => "audio/mpeg",
733                    Some("m4a") => "audio/mp4",
734                    Some("wav") => "audio/wav",
735                    Some("webm") => "audio/webm",
736                    Some("flac") => "audio/flac",
737                    _ => "application/octet-stream",
738                }
739                .to_string(),
740            ),
741        };
742
743        let part = reqwest::multipart::Part::bytes(audio_data)
744            .file_name(file_name)
745            .mime_str(&mime_type)
746            .map_err(|e| Error::InvalidRequest {
747                reason: e.to_string(),
748            })?;
749
750        let mut form = reqwest::multipart::Form::new()
751            .text("model", model.clone())
752            .part("file", part);
753
754        if let TranscriptionLanguage::Explicit { language } = req.language {
755            form = form.text("language", language);
756        }
757
758        if let TranscriptionPrompt::Hint { text } = req.prompt {
759            form = form.text("prompt", text);
760        }
761
762        if let Some(response_format) = req.response_format.as_openai_str() {
763            form = form.text("response_format", response_format.to_string());
764        }
765
766        let response = self
767            .client
768            .post(&url)
769            .header("Authorization", self.auth_header())
770            .multipart(form)
771            .send()
772            .await?;
773
774        if !response.status().is_success() {
775            let status = response.status();
776            let body = response.text().await.unwrap_or_default();
777            return Err(Error::Provider {
778                provider: "openai".to_string(),
779                status: status.as_u16(),
780                message: body,
781            });
782        }
783
784        let body = response.text().await?;
785        #[derive(Deserialize)]
786        struct TranscriptionResponse {
787            text: String,
788        }
789        let parsed: TranscriptionResponse =
790            serde_json::from_str(&body).map_err(|e| Error::parse(body, e))?;
791
792        Ok(AudioTranscriptionResponse {
793            provider: ProviderType::OpenAI,
794            model,
795            text: parsed.text,
796        })
797    }
798}
799
800fn build_responses_request(
801    default_model: &str,
802    req: RawCompletionRequest,
803) -> LlmResult<ResponsesRequest> {
804    let model = match req.model {
805        ModelSelector::Any => default_model.to_string(),
806        ModelSelector::Provider(_) => default_model.to_string(),
807        ModelSelector::Specific { model, .. } => model,
808    };
809
810    let input = req
811        .input
812        .into_iter()
813        .map(|item| -> LlmResult<ResponseInputItem> {
814            Ok(match item {
815                RawInputItem::Message { role, content } => ResponseInputItem::Message {
816                    role: match role {
817                        Role::System => "system".to_string(),
818                        Role::User => "user".to_string(),
819                        Role::Assistant => "assistant".to_string(),
820                    },
821                    content: content
822                        .into_iter()
823                        .map(|content| match content {
824                            RawInputContent::Text { text } => match role {
825                                Role::Assistant => ResponseContent::OutputText { text },
826                                Role::System | Role::User => ResponseContent::InputText { text },
827                            },
828                            RawInputContent::ImageUrl { url } => {
829                                ResponseContent::InputImage { image_url: url }
830                            }
831                        })
832                        .collect(),
833                },
834                RawInputItem::ToolCall { call } => ResponseInputItem::FunctionCall {
835                    call_id: call.id,
836                    name: call.name,
837                    arguments: serde_json::to_string(&call.arguments)
838                        .map_err(|error| Error::parse("tool call arguments", error))?,
839                },
840                RawInputItem::ToolResult {
841                    tool_use_id,
842                    content,
843                } => ResponseInputItem::FunctionCallOutput {
844                    call_id: tool_use_id,
845                    output: content,
846                },
847            })
848        })
849        .collect::<LlmResult<Vec<_>>>()?;
850
851    Ok(ResponsesRequest {
852        model,
853        input,
854        temperature: req.temperature.as_option(),
855        top_p: req.top_p.as_option(),
856        max_output_tokens: req.token_limit.as_option(),
857        stream: Some(req.response_mode.is_streaming()),
858        tools: req.tools.map(map_response_tools),
859        tool_choice: map_responses_tool_choice(req.tool_choice),
860        text: req.response_format.map(map_response_text_config),
861    })
862}
863
864fn map_response_tools(tools: Vec<RawToolDefinition>) -> Vec<ResponseToolDefinition> {
865    tools
866        .into_iter()
867        .map(|tool| ResponseToolDefinition {
868            r#type: tool.kind,
869            name: tool.function.name,
870            description: tool.function.description,
871            parameters: normalize_openai_schema(tool.function.parameters),
872            strict: true,
873        })
874        .collect()
875}
876
877fn map_responses_tool_choice(choice: RawToolChoice) -> Option<Value> {
878    match choice {
879        RawToolChoice::ProviderDefault => None,
880        RawToolChoice::Auto => Some(json!("auto")),
881        RawToolChoice::Required => Some(json!("required")),
882        RawToolChoice::Specific { name } => Some(json!({
883            "type": "function",
884            "name": name,
885        })),
886        RawToolChoice::None => Some(json!("none")),
887    }
888}
889
890fn map_response_text_config(format: RawResponseFormat) -> ResponseTextConfig {
891    ResponseTextConfig {
892        format: match format.json_schema {
893            Some(schema) => ResponseTextFormat::JsonSchema {
894                name: schema.name,
895                schema: normalize_openai_schema(schema.schema),
896                description: None,
897                strict: schema.strict,
898            },
899            None if format.r#type == "json_object" => ResponseTextFormat::JsonObject,
900            None => ResponseTextFormat::Text,
901        },
902    }
903}
904
905fn parse_responses_response(value: Value) -> LlmResult<RawCompletionResponse> {
906    let model = value
907        .get("model")
908        .and_then(Value::as_str)
909        .ok_or(Error::InvalidResponse {
910            reason: "OpenAI responses payload missing model".to_string(),
911        })?
912        .to_string();
913
914    let output_values =
915        value
916            .get("output")
917            .and_then(Value::as_array)
918            .ok_or(Error::InvalidResponse {
919                reason: "OpenAI responses payload missing output".to_string(),
920            })?;
921
922    let mut output = Vec::new();
923    let mut saw_tool_call = false;
924
925    for item in output_values {
926        match item.get("type").and_then(Value::as_str) {
927            Some("message") => {
928                let mut content = Vec::new();
929                if let Some(parts) = item.get("content").and_then(Value::as_array) {
930                    for part in parts {
931                        match part.get("type").and_then(Value::as_str) {
932                            Some("output_text") => {
933                                if let Some(text) = part.get("text").and_then(Value::as_str) {
934                                    content.push(RawOutputContent::Text {
935                                        text: text.to_string(),
936                                    });
937                                }
938                            }
939                            Some("output_json") => {
940                                if let Some(json) = part.get("json") {
941                                    content.push(RawOutputContent::Json {
942                                        value: json.clone(),
943                                    });
944                                }
945                            }
946                            _ => {}
947                        }
948                    }
949                }
950                if !content.is_empty() {
951                    output.push(RawOutputItem::Message {
952                        role: Role::Assistant,
953                        content,
954                    });
955                }
956            }
957            Some("function_call") => {
958                let call_id = item
959                    .get("call_id")
960                    .and_then(Value::as_str)
961                    .or_else(|| item.get("id").and_then(Value::as_str))
962                    .ok_or(Error::InvalidResponse {
963                        reason: "OpenAI function_call missing call id".to_string(),
964                    })?;
965                let name =
966                    item.get("name")
967                        .and_then(Value::as_str)
968                        .ok_or(Error::InvalidResponse {
969                            reason: "OpenAI function_call missing name".to_string(),
970                        })?;
971                let arguments = item.get("arguments").and_then(Value::as_str).ok_or(
972                    Error::InvalidResponse {
973                        reason: "OpenAI function_call missing arguments".to_string(),
974                    },
975                )?;
976                output.push(RawOutputItem::ToolCall {
977                    call: parse_function_call(call_id, name, arguments)?,
978                });
979                saw_tool_call = true;
980            }
981            Some("reasoning") => {
982                let summary = item
983                    .get("summary")
984                    .and_then(Value::as_array)
985                    .into_iter()
986                    .flatten()
987                    .filter_map(|part| part.get("text").and_then(Value::as_str))
988                    .collect::<Vec<_>>()
989                    .join("\n");
990                if !summary.is_empty() {
991                    output.push(RawOutputItem::Reasoning { text: summary });
992                }
993            }
994            _ => {}
995        }
996    }
997
998    let usage = value.get("usage").cloned().unwrap_or_else(|| json!({}));
999    let prompt_tokens = usage
1000        .get("input_tokens")
1001        .and_then(Value::as_u64)
1002        .unwrap_or(0) as u32;
1003    let completion_tokens = usage
1004        .get("output_tokens")
1005        .and_then(Value::as_u64)
1006        .unwrap_or(0) as u32;
1007    let total_tokens = usage
1008        .get("total_tokens")
1009        .and_then(Value::as_u64)
1010        .unwrap_or((prompt_tokens + completion_tokens) as u64) as u32;
1011
1012    Ok(RawCompletionResponse {
1013        provider: ProviderType::OpenAI,
1014        model,
1015        output,
1016        usage: CompletionUsage {
1017            prompt_tokens,
1018            completion_tokens,
1019            total_tokens,
1020        },
1021        finish_reason: if saw_tool_call {
1022            FinishReason::ToolCalls
1023        } else {
1024            FinishReason::Stop
1025        },
1026    })
1027}
1028
1029fn parse_function_call(call_id: &str, name: &str, arguments: &str) -> LlmResult<RawToolCall> {
1030    Ok(RawToolCall {
1031        id: call_id.to_string(),
1032        name: name.to_string(),
1033        arguments: serde_json::from_str(arguments)
1034            .map_err(|e| Error::parse("tool arguments", e))?,
1035    })
1036}