Skip to main content

rig/providers/openai/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
15// ================================================================
16// Main OpenAI Client
17// ================================================================
18const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
20// ================================================================
21// OpenAI Responses API Extension
22// ================================================================
23#[derive(Debug, Default, Clone, Copy)]
24pub struct OpenAIResponsesExt;
25
26#[derive(Debug, Default, Clone, Copy)]
27pub struct OpenAIResponsesExtBuilder;
28
29// ================================================================
30// OpenAI Completions API Extension
31// ================================================================
32#[derive(Debug, Default, Clone, Copy)]
33pub struct OpenAICompletionsExt;
34
35#[derive(Debug, Default, Clone, Copy)]
36pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
41pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
46pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52
53    const VERIFY_PATH: &'static str = "/models";
54
55    fn build<H>(
56        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
57    ) -> http_client::Result<Self> {
58        Ok(Self)
59    }
60}
61
62impl Provider for OpenAICompletionsExt {
63    type Builder = OpenAICompletionsExtBuilder;
64
65    const VERIFY_PATH: &'static str = "/models";
66
67    fn build<H>(
68        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
69    ) -> http_client::Result<Self> {
70        Ok(Self)
71    }
72}
73
74impl<H> Capabilities<H> for OpenAIResponsesExt {
75    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
76    type Embeddings = Capable<super::EmbeddingModel<H>>;
77    type Transcription = Capable<super::TranscriptionModel<H>>;
78    type ModelListing = Nothing;
79    #[cfg(feature = "image")]
80    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
81    #[cfg(feature = "audio")]
82    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
83}
84
85impl<H> Capabilities<H> for OpenAICompletionsExt {
86    type Completion = Capable<super::completion::CompletionModel<H>>;
87    type Embeddings = Capable<super::EmbeddingModel<H>>;
88    type Transcription = Capable<super::TranscriptionModel<H>>;
89    type ModelListing = Nothing;
90    #[cfg(feature = "image")]
91    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
92    #[cfg(feature = "audio")]
93    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
94}
95
96impl DebugExt for OpenAIResponsesExt {}
97
98impl DebugExt for OpenAICompletionsExt {}
99
100impl ProviderBuilder for OpenAIResponsesExtBuilder {
101    type Output = OpenAIResponsesExt;
102    type ApiKey = OpenAIApiKey;
103
104    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
105}
106
107impl ProviderBuilder for OpenAICompletionsExtBuilder {
108    type Output = OpenAICompletionsExt;
109    type ApiKey = OpenAIApiKey;
110
111    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
112}
113
114impl<H> Client<H>
115where
116    H: HttpClientExt
117        + Clone
118        + std::fmt::Debug
119        + Default
120        + WasmCompatSend
121        + WasmCompatSync
122        + 'static,
123{
124    /// Create an extractor builder with the given completion model.
125    /// Uses the OpenAI Responses API (default behavior).
126    pub fn extractor<U>(
127        &self,
128        model: impl Into<String>,
129    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
130    where
131        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
132    {
133        ExtractorBuilder::new(self.completion_model(model))
134    }
135
136    /// Create a Completions API client from this Responses API client.
137    /// Useful for switching to the traditional Chat Completions API.
138    pub fn completions_api(self) -> CompletionsClient<H> {
139        self.with_ext(OpenAICompletionsExt)
140    }
141}
142
143impl<H> CompletionsClient<H>
144where
145    H: HttpClientExt
146        + Clone
147        + std::fmt::Debug
148        + Default
149        + WasmCompatSend
150        + WasmCompatSync
151        + 'static,
152{
153    /// Create an extractor builder with the given completion model.
154    /// Uses the OpenAI Chat Completions API.
155    pub fn extractor<U>(
156        &self,
157        model: impl Into<String>,
158    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
159    where
160        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
161    {
162        ExtractorBuilder::new(self.completion_model(model))
163    }
164
165    /// Create a Responses API client from this Completions API client.
166    /// Useful for switching to the newer Responses API.
167    pub fn responses_api(self) -> Client<H> {
168        self.with_ext(OpenAIResponsesExt)
169    }
170}
171
172impl ProviderClient for Client {
173    type Input = OpenAIApiKey;
174
175    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
176    /// Panics if the environment variable is not set.
177    fn from_env() -> Self {
178        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
179        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
180
181        let mut builder = Client::builder().api_key(&api_key);
182
183        if let Some(base) = base_url {
184            builder = builder.base_url(&base);
185        }
186
187        builder.build().unwrap()
188    }
189
190    fn from_val(input: Self::Input) -> Self {
191        Self::new(input).unwrap()
192    }
193}
194
195impl ProviderClient for CompletionsClient {
196    type Input = OpenAIApiKey;
197
198    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
199    /// Panics if the environment variable is not set.
200    fn from_env() -> Self {
201        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
202        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
203
204        let mut builder = CompletionsClient::builder().api_key(&api_key);
205
206        if let Some(base) = base_url {
207            builder = builder.base_url(&base);
208        }
209
210        builder.build().unwrap()
211    }
212
213    fn from_val(input: Self::Input) -> Self {
214        Self::new(input).unwrap()
215    }
216}
217
218#[derive(Debug, Deserialize)]
219pub struct ApiErrorResponse {
220    pub(crate) message: String,
221}
222
223#[derive(Debug, Deserialize)]
224#[serde(untagged)]
225pub(crate) enum ApiResponse<T> {
226    Ok(T),
227    Err(ApiErrorResponse),
228}
229
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes an OpenAI [`Message`] from `json`.
    ///
    /// On failure, panics with the JSON path and the line/column of the first
    /// error, so a broken fixture points straight at the offending field.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::many(vec![
                UserContent::Text {
                    text: "What's in this image?".to_string(),
                },
                UserContent::Image {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: ImageDetail::default(),
                    },
                },
            ])
            .unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        // Single non-text content should still serialize as array
        assert!(serialized["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::openai::Client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: crate::providers::openai::Client =
            crate::providers::openai::Client::builder()
                .api_key("dummy-key")
                .build()
                .expect("Client::builder() failed");
    }
}