rig/providers/openai/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
15// ================================================================
16// Main OpenAI Client
17// ================================================================
/// Default base URL for the OpenAI API; both provider builders in this file
/// use it as their `ProviderBuilder::BASE_URL`.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
20// ================================================================
21// OpenAI Responses API Extension
22// ================================================================
/// Zero-sized provider extension that selects the OpenAI Responses API.
///
/// Its `Capabilities` impl uses `responses_api::ResponsesCompletionModel` as
/// the completion model.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExt;
25
/// Zero-sized builder extension whose `ProviderBuilder::Output` is
/// [`OpenAIResponsesExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExtBuilder;
28
29// ================================================================
30// OpenAI Completions API Extension
31// ================================================================
/// Zero-sized provider extension that selects the OpenAI Chat Completions API.
///
/// Its `Capabilities` impl uses `completion::CompletionModel` as the
/// completion model.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExt;
34
/// Zero-sized builder extension whose `ProviderBuilder::Output` is
/// [`OpenAICompletionsExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExtBuilder;
37
/// API key type shared by both OpenAI clients; an alias for [`BearerAuth`].
type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
/// OpenAI client backed by [`OpenAIResponsesExt`] (the Responses API).
/// The HTTP client type `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
/// Builder for [`Client`], parameterized with [`OpenAIResponsesExtBuilder`]
/// and the bearer-auth API key type. `H` defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
/// OpenAI client backed by [`OpenAICompletionsExt`] (the Chat Completions API).
/// The HTTP client type `H` defaults to `reqwest::Client`.
pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
/// Builder for [`CompletionsClient`], parameterized with
/// [`OpenAICompletionsExtBuilder`] and the bearer-auth API key type.
/// `H` defaults to `reqwest::Client`.
pub type CompletionsClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52
53    const VERIFY_PATH: &'static str = "/models";
54
55    fn build<H>(
56        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
57    ) -> http_client::Result<Self> {
58        Ok(Self)
59    }
60}
61
62impl Provider for OpenAICompletionsExt {
63    type Builder = OpenAICompletionsExtBuilder;
64
65    const VERIFY_PATH: &'static str = "/models";
66
67    fn build<H>(
68        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
69    ) -> http_client::Result<Self> {
70        Ok(Self)
71    }
72}
73
/// Model types exposed by a Responses API client.
///
/// Completions go through `responses_api::ResponsesCompletionModel`; the
/// embedding and transcription models are the provider-wide ones. Image and
/// audio generation are declared only when the `image` / `audio` cargo
/// features are enabled.
impl<H> Capabilities<H> for OpenAIResponsesExt {
    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}
83
/// Model types exposed by a Chat Completions API client.
///
/// Completions go through `completion::CompletionModel`; the embedding and
/// transcription models are the provider-wide ones. Image and audio
/// generation are declared only when the `image` / `audio` cargo features
/// are enabled.
impl<H> Capabilities<H> for OpenAICompletionsExt {
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}
93
// Empty impl: the Responses extension overrides nothing in `DebugExt` and
// relies entirely on the trait's provided items.
impl DebugExt for OpenAIResponsesExt {}
95
// Empty impl: the Completions extension overrides nothing in `DebugExt` and
// relies entirely on the trait's provided items.
impl DebugExt for OpenAICompletionsExt {}
97
/// Builder configuration for the Responses API client: produces an
/// [`OpenAIResponsesExt`], authenticates with the bearer-auth key type, and
/// defaults to the public OpenAI base URL.
impl ProviderBuilder for OpenAIResponsesExtBuilder {
    type Output = OpenAIResponsesExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
}
104
/// Builder configuration for the Chat Completions API client: produces an
/// [`OpenAICompletionsExt`], authenticates with the bearer-auth key type, and
/// defaults to the public OpenAI base URL.
impl ProviderBuilder for OpenAICompletionsExtBuilder {
    type Output = OpenAICompletionsExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
}
111
112impl<H> Client<H>
113where
114    H: HttpClientExt
115        + Clone
116        + std::fmt::Debug
117        + Default
118        + WasmCompatSend
119        + WasmCompatSync
120        + 'static,
121{
122    /// Create an extractor builder with the given completion model.
123    /// Uses the OpenAI Responses API (default behavior).
124    pub fn extractor<U>(
125        &self,
126        model: impl Into<String>,
127    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
128    where
129        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
130    {
131        ExtractorBuilder::new(self.completion_model(model))
132    }
133
134    /// Create a Completions API client from this Responses API client.
135    /// Useful for switching to the traditional Chat Completions API.
136    pub fn completions_api(self) -> CompletionsClient<H> {
137        client::Client::from_parts(
138            self.base_url().to_string(),
139            self.headers().clone(),
140            self.http_client().clone(),
141            OpenAICompletionsExt,
142        )
143    }
144}
145
146impl<H> CompletionsClient<H>
147where
148    H: HttpClientExt
149        + Clone
150        + std::fmt::Debug
151        + Default
152        + WasmCompatSend
153        + WasmCompatSync
154        + 'static,
155{
156    /// Create an extractor builder with the given completion model.
157    /// Uses the OpenAI Chat Completions API.
158    pub fn extractor<U>(
159        &self,
160        model: impl Into<String>,
161    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
162    where
163        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
164    {
165        ExtractorBuilder::new(self.completion_model(model))
166    }
167
168    /// Create a Responses API client from this Completions API client.
169    /// Useful for switching to the newer Responses API.
170    pub fn responses_api(self) -> Client<H> {
171        client::Client::from_parts(
172            self.base_url().to_string(),
173            self.headers().clone(),
174            self.http_client().clone(),
175            OpenAIResponsesExt,
176        )
177    }
178}
179
180impl ProviderClient for Client {
181    type Input = OpenAIApiKey;
182
183    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
184    /// Panics if the environment variable is not set.
185    fn from_env() -> Self {
186        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
187        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
188
189        let mut builder = Client::builder().api_key(&api_key);
190
191        if let Some(base) = base_url {
192            builder = builder.base_url(&base);
193        }
194
195        builder.build().unwrap()
196    }
197
198    fn from_val(input: Self::Input) -> Self {
199        Self::new(input).unwrap()
200    }
201}
202
203impl ProviderClient for CompletionsClient {
204    type Input = OpenAIApiKey;
205
206    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
207    /// Panics if the environment variable is not set.
208    fn from_env() -> Self {
209        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
210        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
211
212        let mut builder = CompletionsClient::builder().api_key(&api_key);
213
214        if let Some(base) = base_url {
215            builder = builder.base_url(&base);
216        }
217
218        builder.build().unwrap()
219    }
220
221    fn from_val(input: Self::Input) -> Self {
222        Self::new(input).unwrap()
223    }
224}
225
/// Error body deserialized from a payload that carries a `message` string.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error text taken from the payload's `message` field.
    pub(crate) message: String,
}
230
/// Either a successful payload of type `T` or an [`ApiErrorResponse`].
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in
/// declaration order: the body is first parsed as `T`, and only if that fails
/// is it parsed as an error response.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The body deserialized as the expected payload type.
    Ok(T),
    /// The body deserialized as an error response instead.
    Err(ApiErrorResponse),
}
237
238#[cfg(test)]
239mod tests {
240    use crate::message::ImageDetail;
241    use crate::providers::openai::{
242        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
243    };
244    use crate::{OneOrMany, message};
245    use serde_path_to_error::deserialize;
246
247    #[test]
248    fn test_deserialize_message() {
249        let assistant_message_json = r#"
250        {
251            "role": "assistant",
252            "content": "\n\nHello there, how may I assist you today?"
253        }
254        "#;
255
256        let assistant_message_json2 = r#"
257        {
258            "role": "assistant",
259            "content": [
260                {
261                    "type": "text",
262                    "text": "\n\nHello there, how may I assist you today?"
263                }
264            ],
265            "tool_calls": null
266        }
267        "#;
268
269        let assistant_message_json3 = r#"
270        {
271            "role": "assistant",
272            "tool_calls": [
273                {
274                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
275                    "type": "function",
276                    "function": {
277                        "name": "subtract",
278                        "arguments": "{\"x\": 2, \"y\": 5}"
279                    }
280                }
281            ],
282            "content": null,
283            "refusal": null
284        }
285        "#;
286
287        let user_message_json = r#"
288        {
289            "role": "user",
290            "content": [
291                {
292                    "type": "text",
293                    "text": "What's in this image?"
294                },
295                {
296                    "type": "image_url",
297                    "image_url": {
298                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
299                    }
300                },
301                {
302                    "type": "audio",
303                    "input_audio": {
304                        "data": "...",
305                        "format": "mp3"
306                    }
307                }
308            ]
309        }
310        "#;
311
312        let assistant_message: Message = {
313            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
314            deserialize(jd).unwrap_or_else(|err| {
315                panic!(
316                    "Deserialization error at {} ({}:{}): {}",
317                    err.path(),
318                    err.inner().line(),
319                    err.inner().column(),
320                    err
321                );
322            })
323        };
324
325        let assistant_message2: Message = {
326            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
327            deserialize(jd).unwrap_or_else(|err| {
328                panic!(
329                    "Deserialization error at {} ({}:{}): {}",
330                    err.path(),
331                    err.inner().line(),
332                    err.inner().column(),
333                    err
334                );
335            })
336        };
337
338        let assistant_message3: Message = {
339            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
340                &mut serde_json::Deserializer::from_str(assistant_message_json3);
341            deserialize(jd).unwrap_or_else(|err| {
342                panic!(
343                    "Deserialization error at {} ({}:{}): {}",
344                    err.path(),
345                    err.inner().line(),
346                    err.inner().column(),
347                    err
348                );
349            })
350        };
351
352        let user_message: Message = {
353            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
354            deserialize(jd).unwrap_or_else(|err| {
355                panic!(
356                    "Deserialization error at {} ({}:{}): {}",
357                    err.path(),
358                    err.inner().line(),
359                    err.inner().column(),
360                    err
361                );
362            })
363        };
364
365        match assistant_message {
366            Message::Assistant { content, .. } => {
367                assert_eq!(
368                    content[0],
369                    AssistantContent::Text {
370                        text: "\n\nHello there, how may I assist you today?".to_string()
371                    }
372                );
373            }
374            _ => panic!("Expected assistant message"),
375        }
376
377        match assistant_message2 {
378            Message::Assistant {
379                content,
380                tool_calls,
381                ..
382            } => {
383                assert_eq!(
384                    content[0],
385                    AssistantContent::Text {
386                        text: "\n\nHello there, how may I assist you today?".to_string()
387                    }
388                );
389
390                assert_eq!(tool_calls, vec![]);
391            }
392            _ => panic!("Expected assistant message"),
393        }
394
395        match assistant_message3 {
396            Message::Assistant {
397                content,
398                tool_calls,
399                refusal,
400                ..
401            } => {
402                assert!(content.is_empty());
403                assert!(refusal.is_none());
404                assert_eq!(
405                    tool_calls[0],
406                    ToolCall {
407                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
408                        r#type: ToolType::Function,
409                        function: Function {
410                            name: "subtract".to_string(),
411                            arguments: serde_json::json!({"x": 2, "y": 5}),
412                        },
413                    }
414                );
415            }
416            _ => panic!("Expected assistant message"),
417        }
418
419        match user_message {
420            Message::User { content, .. } => {
421                let (first, second) = {
422                    let mut iter = content.into_iter();
423                    (iter.next().unwrap(), iter.next().unwrap())
424                };
425                assert_eq!(
426                    first,
427                    UserContent::Text {
428                        text: "What's in this image?".to_string()
429                    }
430                );
431                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
432            }
433            _ => panic!("Expected user message"),
434        }
435    }
436
437    #[test]
438    fn test_message_to_message_conversion() {
439        let user_message = message::Message::User {
440            content: OneOrMany::one(message::UserContent::text("Hello")),
441        };
442
443        let assistant_message = message::Message::Assistant {
444            id: None,
445            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
446        };
447
448        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
449        let converted_assistant_message: Vec<Message> =
450            assistant_message.clone().try_into().unwrap();
451
452        match converted_user_message[0].clone() {
453            Message::User { content, .. } => {
454                assert_eq!(
455                    content.first(),
456                    UserContent::Text {
457                        text: "Hello".to_string()
458                    }
459                );
460            }
461            _ => panic!("Expected user message"),
462        }
463
464        match converted_assistant_message[0].clone() {
465            Message::Assistant { content, .. } => {
466                assert_eq!(
467                    content[0].clone(),
468                    AssistantContent::Text {
469                        text: "Hi there!".to_string()
470                    }
471                );
472            }
473            _ => panic!("Expected assistant message"),
474        }
475
476        let original_user_message: message::Message =
477            converted_user_message[0].clone().try_into().unwrap();
478        let original_assistant_message: message::Message =
479            converted_assistant_message[0].clone().try_into().unwrap();
480
481        assert_eq!(original_user_message, user_message);
482        assert_eq!(original_assistant_message, assistant_message);
483    }
484
485    #[test]
486    fn test_message_from_message_conversion() {
487        let user_message = Message::User {
488            content: OneOrMany::one(UserContent::Text {
489                text: "Hello".to_string(),
490            }),
491            name: None,
492        };
493
494        let assistant_message = Message::Assistant {
495            content: vec![AssistantContent::Text {
496                text: "Hi there!".to_string(),
497            }],
498            refusal: None,
499            audio: None,
500            name: None,
501            tool_calls: vec![],
502        };
503
504        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
505        let converted_assistant_message: message::Message =
506            assistant_message.clone().try_into().unwrap();
507
508        match converted_user_message.clone() {
509            message::Message::User { content } => {
510                assert_eq!(content.first(), message::UserContent::text("Hello"));
511            }
512            _ => panic!("Expected user message"),
513        }
514
515        match converted_assistant_message.clone() {
516            message::Message::Assistant { content, .. } => {
517                assert_eq!(
518                    content.first(),
519                    message::AssistantContent::text("Hi there!")
520                );
521            }
522            _ => panic!("Expected assistant message"),
523        }
524
525        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
526        let original_assistant_message: Vec<Message> =
527            converted_assistant_message.try_into().unwrap();
528
529        assert_eq!(original_user_message[0], user_message);
530        assert_eq!(original_assistant_message[0], assistant_message);
531    }
532}