// rig_core/providers/openai/client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
15// ================================================================
16// Main OpenAI Client
17// ================================================================
18const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
20// ================================================================
21// OpenAI Responses API Extension
22// ================================================================
23#[derive(Debug, Default, Clone, Copy)]
24pub struct OpenAIResponsesExt;
25
26#[derive(Debug, Default, Clone, Copy)]
27pub struct OpenAIResponsesExtBuilder;
28
29// ================================================================
30// OpenAI Completions API Extension
31// ================================================================
32#[derive(Debug, Default, Clone, Copy)]
33pub struct OpenAICompletionsExt;
34
35#[derive(Debug, Default, Clone, Copy)]
36pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
41pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = crate::markers::Missing> =
43    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
46pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = crate::markers::Missing> =
48    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52    const VERIFY_PATH: &'static str = "/models";
53}
54
55impl Provider for OpenAICompletionsExt {
56    type Builder = OpenAICompletionsExtBuilder;
57    const VERIFY_PATH: &'static str = "/models";
58}
59
60impl<H> Capabilities<H> for OpenAIResponsesExt {
61    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
62    type Embeddings = Capable<super::EmbeddingModel<H>>;
63    type Transcription = Capable<super::TranscriptionModel<H>>;
64    type ModelListing = Capable<super::OpenAIModelLister<H>>;
65    #[cfg(feature = "image")]
66    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
67    #[cfg(feature = "audio")]
68    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
69    type Rerank = Nothing;
70}
71
72impl<H> Capabilities<H> for OpenAICompletionsExt {
73    type Completion = Capable<super::completion::CompletionModel<H>>;
74    type Embeddings = Capable<super::GenericEmbeddingModel<OpenAICompletionsExt, H>>;
75    type Transcription = Capable<super::TranscriptionModel<H>>;
76    type ModelListing = Capable<super::OpenAIModelLister<H>>;
77    #[cfg(feature = "image")]
78    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
79    #[cfg(feature = "audio")]
80    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
81    type Rerank = Nothing;
82}
83
84impl DebugExt for OpenAIResponsesExt {}
85
86impl DebugExt for OpenAICompletionsExt {}
87
88impl ProviderBuilder for OpenAIResponsesExtBuilder {
89    type Extension<H>
90        = OpenAIResponsesExt
91    where
92        H: HttpClientExt;
93    type ApiKey = OpenAIApiKey;
94
95    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
96
97    fn build<H>(
98        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
99    ) -> http_client::Result<Self::Extension<H>>
100    where
101        H: HttpClientExt,
102    {
103        Ok(OpenAIResponsesExt)
104    }
105}
106
107impl ProviderBuilder for OpenAICompletionsExtBuilder {
108    type Extension<H>
109        = OpenAICompletionsExt
110    where
111        H: HttpClientExt;
112    type ApiKey = OpenAIApiKey;
113
114    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
115
116    fn build<H>(
117        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
118    ) -> http_client::Result<Self::Extension<H>>
119    where
120        H: HttpClientExt,
121    {
122        Ok(OpenAICompletionsExt)
123    }
124}
125
126impl<H> Client<H>
127where
128    H: HttpClientExt
129        + Clone
130        + std::fmt::Debug
131        + Default
132        + WasmCompatSend
133        + WasmCompatSync
134        + 'static,
135{
136    /// Create an extractor builder with the given completion model.
137    /// Uses the OpenAI Responses API (default behavior).
138    pub fn extractor<U>(
139        &self,
140        model: impl Into<String>,
141    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
142    where
143        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
144    {
145        ExtractorBuilder::new(self.completion_model(model))
146    }
147
148    /// Create a Completions API client from this Responses API client.
149    /// Useful for switching to the traditional Chat Completions API.
150    pub fn completions_api(self) -> CompletionsClient<H> {
151        self.with_ext(OpenAICompletionsExt)
152    }
153}
154
#[cfg(all(not(target_family = "wasm"), feature = "websocket"))]
impl Client<reqwest::Client> {
    /// Create a builder for an OpenAI Responses API WebSocket session on `model`.
    ///
    /// WebSocket mode currently uses a native `tokio-tungstenite` transport and does
    /// not reuse custom `HttpClientExt` backends, so this API is only exposed for the
    /// default `reqwest::Client` transport, and only on non-wasm targets with the
    /// `websocket` feature enabled.
    pub fn responses_websocket_builder(
        &self,
        model: impl Into<String>,
    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
        super::responses_api::websocket::ResponsesWebSocketSessionBuilder::new(
            self.completion_model(model),
        )
    }

    /// Open an OpenAI Responses API WebSocket session on `model` in one step.
    ///
    /// Equivalent to `self.responses_websocket_builder(model).connect().await`, i.e.
    /// the builder is connected without any further configuration. Call
    /// [`Self::responses_websocket_builder`] instead if you need the builder itself.
    /// This API is OpenAI-specific and only available on non-wasm targets in `rig-core`.
    ///
    /// # Errors
    ///
    /// Propagates the [`CompletionError`](crate::completion::CompletionError) returned
    /// by the builder's `connect` when the session cannot be opened.
    pub async fn responses_websocket(
        &self,
        model: impl Into<String>,
    ) -> Result<
        super::responses_api::websocket::ResponsesWebSocketSession,
        crate::completion::CompletionError,
    > {
        self.responses_websocket_builder(model).connect().await
    }
}
180
181impl<H> CompletionsClient<H>
182where
183    H: HttpClientExt
184        + Clone
185        + std::fmt::Debug
186        + Default
187        + WasmCompatSend
188        + WasmCompatSync
189        + 'static,
190{
191    /// Create an extractor builder with the given completion model.
192    /// Uses the OpenAI Chat Completions API.
193    pub fn extractor<U>(
194        &self,
195        model: impl Into<String>,
196    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
197    where
198        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
199    {
200        ExtractorBuilder::new(self.completion_model(model))
201    }
202
203    /// Create a Responses API client from this Completions API client.
204    /// Useful for switching to the newer Responses API.
205    pub fn responses_api(self) -> Client<H> {
206        self.with_ext(OpenAIResponsesExt)
207    }
208}
209
210impl ProviderClient for Client {
211    type Input = OpenAIApiKey;
212    type Error = crate::client::ProviderClientError;
213
214    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
215    fn from_env() -> Result<Self, Self::Error> {
216        let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
217        let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
218
219        let mut builder = Client::builder().api_key(&api_key);
220
221        if let Some(base) = base_url {
222            builder = builder.base_url(&base);
223        }
224
225        builder.build().map_err(Into::into)
226    }
227
228    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
229        Self::new(input).map_err(Into::into)
230    }
231}
232
233impl ProviderClient for CompletionsClient {
234    type Input = OpenAIApiKey;
235    type Error = crate::client::ProviderClientError;
236
237    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
238    fn from_env() -> Result<Self, Self::Error> {
239        let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
240        let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
241
242        let mut builder = CompletionsClient::builder().api_key(&api_key);
243
244        if let Some(base) = base_url {
245            builder = builder.base_url(&base);
246        }
247
248        builder.build().map_err(Into::into)
249    }
250
251    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
252        Self::new(input).map_err(Into::into)
253    }
254}
255
256#[derive(Debug, Deserialize)]
257pub struct ApiErrorResponse {
258    pub(crate) message: String,
259}
260
261#[derive(Debug, Deserialize)]
262#[serde(untagged)]
263pub(crate) enum ApiResponse<T> {
264    Ok(T),
265    Err(ApiErrorResponse),
266}
267
#[cfg(test)]
mod tests {
    use crate::client::{CompletionClient, EmbeddingsClient};
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserialize an OpenAI wire-format [`Message`], panicking with the JSON path,
    /// line and column of the first error so failures point at the offending field.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let parts: Vec<UserContent> = content.into_iter().collect();
                // The payload carries text, image and audio parts; none may be dropped.
                assert_eq!(parts.len(), 3);
                assert_eq!(
                    parts[0],
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    parts[1],
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default(),
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::many(vec![
                UserContent::Text {
                    text: "What's in this image?".to_string(),
                },
                UserContent::Image {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: ImageDetail::default(),
                    },
                },
            ])
            .unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        // Single non-text content should still serialize as array
        assert!(serialized["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::openai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    #[test]
    fn test_legacy_chat_completion_model_type_annotation_still_compiles() {
        let client = crate::providers::openai::Client::new("dummy-key")
            .expect("Client::new() failed")
            .completions_api();

        let _model: crate::providers::openai::completion::CompletionModel<reqwest::Client> =
            client.completion_model("gpt-4o");
    }

    #[test]
    fn test_legacy_embedding_model_type_annotation_still_compiles() {
        let client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");

        let _model: crate::providers::openai::EmbeddingModel<reqwest::Client> =
            client.embedding_model(crate::providers::openai::TEXT_EMBEDDING_3_SMALL);
    }
}