Skip to main content

rig/providers/openai/client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
// ---------------------------------------------------------------
// Main OpenAI client
// ---------------------------------------------------------------

/// Default root URL used by both OpenAI client flavours (Responses and Completions).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

// ---------------------------------------------------------------
// Responses API extension (the default client flavour)
// ---------------------------------------------------------------

/// Zero-sized marker that selects the OpenAI Responses API for a client.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Zero-sized builder marker whose `build` yields an [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;

// ---------------------------------------------------------------
// Chat Completions API extension
// ---------------------------------------------------------------

/// Zero-sized marker that selects the OpenAI Chat Completions API for a client.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Zero-sized builder marker whose `build` yields an [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
41pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
46pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52    const VERIFY_PATH: &'static str = "/models";
53}
54
55impl Provider for OpenAICompletionsExt {
56    type Builder = OpenAICompletionsExtBuilder;
57    const VERIFY_PATH: &'static str = "/models";
58}
59
60impl<H> Capabilities<H> for OpenAIResponsesExt {
61    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
62    type Embeddings = Capable<super::EmbeddingModel<H>>;
63    type Transcription = Capable<super::TranscriptionModel<H>>;
64    type ModelListing = Nothing;
65    #[cfg(feature = "image")]
66    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
67    #[cfg(feature = "audio")]
68    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
69}
70
71impl<H> Capabilities<H> for OpenAICompletionsExt {
72    type Completion = Capable<super::completion::CompletionModel<H>>;
73    type Embeddings = Capable<super::EmbeddingModel<H>>;
74    type Transcription = Capable<super::TranscriptionModel<H>>;
75    type ModelListing = Nothing;
76    #[cfg(feature = "image")]
77    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
78    #[cfg(feature = "audio")]
79    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
80}
81
82impl DebugExt for OpenAIResponsesExt {}
83
84impl DebugExt for OpenAICompletionsExt {}
85
86impl ProviderBuilder for OpenAIResponsesExtBuilder {
87    type Extension<H>
88        = OpenAIResponsesExt
89    where
90        H: HttpClientExt;
91    type ApiKey = OpenAIApiKey;
92
93    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
94
95    fn build<H>(
96        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
97    ) -> http_client::Result<Self::Extension<H>>
98    where
99        H: HttpClientExt,
100    {
101        Ok(OpenAIResponsesExt)
102    }
103}
104
105impl ProviderBuilder for OpenAICompletionsExtBuilder {
106    type Extension<H>
107        = OpenAICompletionsExt
108    where
109        H: HttpClientExt;
110    type ApiKey = OpenAIApiKey;
111
112    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
113
114    fn build<H>(
115        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
116    ) -> http_client::Result<Self::Extension<H>>
117    where
118        H: HttpClientExt,
119    {
120        Ok(OpenAICompletionsExt)
121    }
122}
123
124impl<H> Client<H>
125where
126    H: HttpClientExt
127        + Clone
128        + std::fmt::Debug
129        + Default
130        + WasmCompatSend
131        + WasmCompatSync
132        + 'static,
133{
134    /// Create an extractor builder with the given completion model.
135    /// Uses the OpenAI Responses API (default behavior).
136    pub fn extractor<U>(
137        &self,
138        model: impl Into<String>,
139    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
140    where
141        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
142    {
143        ExtractorBuilder::new(self.completion_model(model))
144    }
145
146    /// Create a Completions API client from this Responses API client.
147    /// Useful for switching to the traditional Chat Completions API.
148    pub fn completions_api(self) -> CompletionsClient<H> {
149        self.with_ext(OpenAICompletionsExt)
150    }
151}
152
153#[cfg(not(target_family = "wasm"))]
154impl Client<reqwest::Client> {
155    /// WebSocket mode currently uses a native `tokio-tungstenite` transport and does
156    /// not reuse custom `HttpClientExt` backends, so this API is only exposed for the
157    /// default `reqwest::Client` transport.
158    pub fn responses_websocket_builder(
159        &self,
160        model: impl Into<String>,
161    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
162        super::responses_api::websocket::ResponsesWebSocketSessionBuilder::new(
163            self.completion_model(model),
164        )
165    }
166
167    /// This API is OpenAI-specific and only available on non-wasm targets in `rig-core`.
168    pub async fn responses_websocket(
169        &self,
170        model: impl Into<String>,
171    ) -> Result<
172        super::responses_api::websocket::ResponsesWebSocketSession,
173        crate::completion::CompletionError,
174    > {
175        self.responses_websocket_builder(model).connect().await
176    }
177}
178
179impl<H> CompletionsClient<H>
180where
181    H: HttpClientExt
182        + Clone
183        + std::fmt::Debug
184        + Default
185        + WasmCompatSend
186        + WasmCompatSync
187        + 'static,
188{
189    /// Create an extractor builder with the given completion model.
190    /// Uses the OpenAI Chat Completions API.
191    pub fn extractor<U>(
192        &self,
193        model: impl Into<String>,
194    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
195    where
196        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
197    {
198        ExtractorBuilder::new(self.completion_model(model))
199    }
200
201    /// Create a Responses API client from this Completions API client.
202    /// Useful for switching to the newer Responses API.
203    pub fn responses_api(self) -> Client<H> {
204        self.with_ext(OpenAIResponsesExt)
205    }
206}
207
208impl ProviderClient for Client {
209    type Input = OpenAIApiKey;
210
211    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
212    /// Panics if the environment variable is not set.
213    fn from_env() -> Self {
214        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
215        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
216
217        let mut builder = Client::builder().api_key(&api_key);
218
219        if let Some(base) = base_url {
220            builder = builder.base_url(&base);
221        }
222
223        builder.build().unwrap()
224    }
225
226    fn from_val(input: Self::Input) -> Self {
227        Self::new(input).unwrap()
228    }
229}
230
231impl ProviderClient for CompletionsClient {
232    type Input = OpenAIApiKey;
233
234    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
235    /// Panics if the environment variable is not set.
236    fn from_env() -> Self {
237        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
238        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
239
240        let mut builder = CompletionsClient::builder().api_key(&api_key);
241
242        if let Some(base) = base_url {
243            builder = builder.base_url(&base);
244        }
245
246        builder.build().unwrap()
247    }
248
249    fn from_val(input: Self::Input) -> Self {
250        Self::new(input).unwrap()
251    }
252}
253
254#[derive(Debug, Deserialize)]
255pub struct ApiErrorResponse {
256    pub(crate) message: String,
257}
258
259#[derive(Debug, Deserialize)]
260#[serde(untagged)]
261pub(crate) enum ApiResponse<T> {
262    Ok(T),
263    Err(ApiErrorResponse),
264}
265
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into an OpenAI [`Message`].
    ///
    /// On failure this panics with the JSON path plus the line/column reported by
    /// `serde_json`. Thanks to `#[track_caller]`, the panic is attributed to the
    /// calling test line.
    #[track_caller]
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        match deserialize(jd) {
            Ok(parsed) => parsed,
            Err(err) => panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            ),
        }
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::many(vec![
                UserContent::Text {
                    text: "What's in this image?".to_string(),
                },
                UserContent::Image {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: ImageDetail::default(),
                    },
                },
            ])
            .unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        // Single non-text content should still serialize as array
        assert!(serialized["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::openai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}