Skip to main content

rig/providers/openai/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
// ================================================================
// Main OpenAI Client
// ================================================================

/// Default root URL of the OpenAI REST API.
///
/// Both client flavours expose this as their `ProviderBuilder::BASE_URL`;
/// `from_env` replaces it when `OPENAI_BASE_URL` is set.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
// ================================================================
// OpenAI Responses API Extension
// ================================================================

/// Stateless provider extension that routes a client through the OpenAI
/// Responses API.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Builder marker whose `ProviderBuilder::build` yields [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;

// ================================================================
// OpenAI Completions API Extension
// ================================================================

/// Stateless provider extension that routes a client through the OpenAI
/// Chat Completions API.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Builder marker whose `ProviderBuilder::build` yields [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
41pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
46pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52    const VERIFY_PATH: &'static str = "/models";
53}
54
55impl Provider for OpenAICompletionsExt {
56    type Builder = OpenAICompletionsExtBuilder;
57    const VERIFY_PATH: &'static str = "/models";
58}
59
60impl<H> Capabilities<H> for OpenAIResponsesExt {
61    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
62    type Embeddings = Capable<super::EmbeddingModel<H>>;
63    type Transcription = Capable<super::TranscriptionModel<H>>;
64    type ModelListing = Capable<super::OpenAIModelLister<H>>;
65    #[cfg(feature = "image")]
66    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
67    #[cfg(feature = "audio")]
68    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
69}
70
71impl<H> Capabilities<H> for OpenAICompletionsExt {
72    type Completion = Capable<super::completion::CompletionModel<H>>;
73    type Embeddings = Capable<super::GenericEmbeddingModel<OpenAICompletionsExt, H>>;
74    type Transcription = Capable<super::TranscriptionModel<H>>;
75    type ModelListing = Capable<super::OpenAIModelLister<H>>;
76    #[cfg(feature = "image")]
77    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
78    #[cfg(feature = "audio")]
79    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
80}
81
82impl DebugExt for OpenAIResponsesExt {}
83
84impl DebugExt for OpenAICompletionsExt {}
85
86impl ProviderBuilder for OpenAIResponsesExtBuilder {
87    type Extension<H>
88        = OpenAIResponsesExt
89    where
90        H: HttpClientExt;
91    type ApiKey = OpenAIApiKey;
92
93    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
94
95    fn build<H>(
96        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
97    ) -> http_client::Result<Self::Extension<H>>
98    where
99        H: HttpClientExt,
100    {
101        Ok(OpenAIResponsesExt)
102    }
103}
104
105impl ProviderBuilder for OpenAICompletionsExtBuilder {
106    type Extension<H>
107        = OpenAICompletionsExt
108    where
109        H: HttpClientExt;
110    type ApiKey = OpenAIApiKey;
111
112    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
113
114    fn build<H>(
115        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
116    ) -> http_client::Result<Self::Extension<H>>
117    where
118        H: HttpClientExt,
119    {
120        Ok(OpenAICompletionsExt)
121    }
122}
123
124impl<H> Client<H>
125where
126    H: HttpClientExt
127        + Clone
128        + std::fmt::Debug
129        + Default
130        + WasmCompatSend
131        + WasmCompatSync
132        + 'static,
133{
134    /// Create an extractor builder with the given completion model.
135    /// Uses the OpenAI Responses API (default behavior).
136    pub fn extractor<U>(
137        &self,
138        model: impl Into<String>,
139    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
140    where
141        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
142    {
143        ExtractorBuilder::new(self.completion_model(model))
144    }
145
146    /// Create a Completions API client from this Responses API client.
147    /// Useful for switching to the traditional Chat Completions API.
148    pub fn completions_api(self) -> CompletionsClient<H> {
149        self.with_ext(OpenAICompletionsExt)
150    }
151}
152
#[cfg(all(not(target_family = "wasm"), feature = "websocket"))]
impl Client<reqwest::Client> {
    /// Creates a builder for a Responses API WebSocket session on `model`.
    ///
    /// WebSocket mode currently uses a native `tokio-tungstenite` transport and does
    /// not reuse custom `HttpClientExt` backends, so this API is only exposed for the
    /// default `reqwest::Client` transport.
    pub fn responses_websocket_builder(
        &self,
        model: impl Into<String>,
    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
        let completion_model = self.completion_model(model);
        super::responses_api::websocket::ResponsesWebSocketSessionBuilder::new(completion_model)
    }

    /// Opens a Responses API WebSocket session on `model`.
    ///
    /// Shorthand for [`Self::responses_websocket_builder`] followed by `connect()`.
    /// This API is OpenAI-specific and only available on non-wasm targets in `rig-core`.
    pub async fn responses_websocket(
        &self,
        model: impl Into<String>,
    ) -> Result<
        super::responses_api::websocket::ResponsesWebSocketSession,
        crate::completion::CompletionError,
    > {
        let session_builder = self.responses_websocket_builder(model);
        session_builder.connect().await
    }
}
178
179impl<H> CompletionsClient<H>
180where
181    H: HttpClientExt
182        + Clone
183        + std::fmt::Debug
184        + Default
185        + WasmCompatSend
186        + WasmCompatSync
187        + 'static,
188{
189    /// Create an extractor builder with the given completion model.
190    /// Uses the OpenAI Chat Completions API.
191    pub fn extractor<U>(
192        &self,
193        model: impl Into<String>,
194    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
195    where
196        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
197    {
198        ExtractorBuilder::new(self.completion_model(model))
199    }
200
201    /// Create a Responses API client from this Completions API client.
202    /// Useful for switching to the newer Responses API.
203    pub fn responses_api(self) -> Client<H> {
204        self.with_ext(OpenAIResponsesExt)
205    }
206}
207
208impl ProviderClient for Client {
209    type Input = OpenAIApiKey;
210    type Error = crate::client::ProviderClientError;
211
212    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
213    fn from_env() -> Result<Self, Self::Error> {
214        let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
215        let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
216
217        let mut builder = Client::builder().api_key(&api_key);
218
219        if let Some(base) = base_url {
220            builder = builder.base_url(&base);
221        }
222
223        builder.build().map_err(Into::into)
224    }
225
226    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
227        Self::new(input).map_err(Into::into)
228    }
229}
230
231impl ProviderClient for CompletionsClient {
232    type Input = OpenAIApiKey;
233    type Error = crate::client::ProviderClientError;
234
235    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
236    fn from_env() -> Result<Self, Self::Error> {
237        let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
238        let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
239
240        let mut builder = CompletionsClient::builder().api_key(&api_key);
241
242        if let Some(base) = base_url {
243            builder = builder.base_url(&base);
244        }
245
246        builder.build().map_err(Into::into)
247    }
248
249    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
250        Self::new(input).map_err(Into::into)
251    }
252}
253
254#[derive(Debug, Deserialize)]
255pub struct ApiErrorResponse {
256    pub(crate) message: String,
257}
258
259#[derive(Debug, Deserialize)]
260#[serde(untagged)]
261pub(crate) enum ApiResponse<T> {
262    Ok(T),
263    Err(ApiErrorResponse),
264}
265
#[cfg(test)]
mod tests {
    use crate::client::{CompletionClient, EmbeddingsClient};
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Text carried by both plain-text assistant fixtures.
    const GREETING: &str = "\n\nHello there, how may I assist you today?";

    /// Deserializes an OpenAI wire-format [`Message`], panicking with the JSON
    /// path and line/column of the first mismatch so failures are easy to locate.
    fn parse_message(json: &str) -> Message {
        let de = &mut serde_json::Deserializer::from_str(json);
        deserialize(de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        // String content must surface as a single text part.
        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );

        // Array content with `"tool_calls": null` must yield no tool calls.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        // Tool-call-only message: `null` content is empty and the arguments
        // string is compared as parsed JSON.
        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        // Only the text and image parts are checked; the audio part only has to parse.
        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::Image {
                image_url: ImageUrl {
                    url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // rig message -> OpenAI wire messages.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // OpenAI wire message -> rig message must round-trip to the originals.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // OpenAI wire message -> rig message.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // rig message -> OpenAI wire messages must round-trip to the originals.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let parts = vec![
            UserContent::Text {
                text: "What's in this image?".to_string(),
            },
            UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            },
        ];
        let user_message = Message::User {
            content: OneOrMany::many(parts).unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let image = UserContent::Image {
            image_url: ImageUrl {
                url: "https://example.com/image.jpg".to_string(),
                detail: ImageDetail::default(),
            },
        };
        let user_message = Message::User {
            content: OneOrMany::one(image),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        // Only a lone *text* part collapses to a string; anything else stays an array.
        assert!(serialized["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::openai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    #[test]
    fn test_legacy_chat_completion_model_type_annotation_still_compiles() {
        let client = crate::providers::openai::Client::new("dummy-key")
            .expect("Client::new() failed")
            .completions_api();

        let _model: crate::providers::openai::completion::CompletionModel<reqwest::Client> =
            client.completion_model("gpt-4o");
    }

    #[test]
    fn test_legacy_embedding_model_type_annotation_still_compiles() {
        let client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");

        let _model: crate::providers::openai::EmbeddingModel<reqwest::Client> =
            client.embedding_model(crate::providers::openai::TEXT_EMBEDDING_3_SMALL);
    }
}