rig/providers/openai/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
// ================================================================
// Main OpenAI Client
// ================================================================

/// Default base URL shared by both OpenAI provider builders in this module.
///
/// A client can point somewhere else through its builder's `base_url`; the
/// `from_env` constructors do that when `OPENAI_BASE_URL` is set.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
// ================================================================
// OpenAI Responses API Extension
// ================================================================

/// Provider extension that routes completions through OpenAI's Responses API.
///
/// This is the extension behind the default [`Client`] alias. Its completion
/// capability is `responses_api::ResponsesCompletionModel`.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExt;

/// Builder-side marker that produces an [`OpenAIResponsesExt`]; used by
/// [`ClientBuilder`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExtBuilder;
28
// ================================================================
// OpenAI Completions API Extension
// ================================================================

/// Provider extension that routes completions through OpenAI's Chat
/// Completions API.
///
/// Used by the [`CompletionsClient`] alias. Its completion capability is
/// `completion::CompletionModel`.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExt;

/// Builder-side marker that produces an [`OpenAICompletionsExt`]; used by
/// [`CompletionsClientBuilder`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
41pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
46pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52
53    const VERIFY_PATH: &'static str = "/models";
54
55    fn build<H>(
56        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
57    ) -> http_client::Result<Self> {
58        Ok(Self)
59    }
60}
61
62impl Provider for OpenAICompletionsExt {
63    type Builder = OpenAICompletionsExtBuilder;
64
65    const VERIFY_PATH: &'static str = "/models";
66
67    fn build<H>(
68        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
69    ) -> http_client::Result<Self> {
70        Ok(Self)
71    }
72}
73
74impl<H> Capabilities<H> for OpenAIResponsesExt {
75    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
76    type Embeddings = Capable<super::EmbeddingModel<H>>;
77    type Transcription = Capable<super::TranscriptionModel<H>>;
78    #[cfg(feature = "image")]
79    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
80    #[cfg(feature = "audio")]
81    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
82}
83
84impl<H> Capabilities<H> for OpenAICompletionsExt {
85    type Completion = Capable<super::completion::CompletionModel<H>>;
86    type Embeddings = Capable<super::EmbeddingModel<H>>;
87    type Transcription = Capable<super::TranscriptionModel<H>>;
88    #[cfg(feature = "image")]
89    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
90    #[cfg(feature = "audio")]
91    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
92}
93
94impl DebugExt for OpenAIResponsesExt {}
95
96impl DebugExt for OpenAICompletionsExt {}
97
98impl ProviderBuilder for OpenAIResponsesExtBuilder {
99    type Output = OpenAIResponsesExt;
100    type ApiKey = OpenAIApiKey;
101
102    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
103}
104
105impl ProviderBuilder for OpenAICompletionsExtBuilder {
106    type Output = OpenAICompletionsExt;
107    type ApiKey = OpenAIApiKey;
108
109    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
110}
111
112impl<H> Client<H>
113where
114    H: HttpClientExt
115        + Clone
116        + std::fmt::Debug
117        + Default
118        + WasmCompatSend
119        + WasmCompatSync
120        + 'static,
121{
122    /// Create an extractor builder with the given completion model.
123    /// Uses the OpenAI Responses API (default behavior).
124    pub fn extractor<U>(
125        &self,
126        model: impl Into<String>,
127    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
128    where
129        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
130    {
131        ExtractorBuilder::new(self.completion_model(model))
132    }
133
134    /// Create a Completions API client from this Responses API client.
135    /// Useful for switching to the traditional Chat Completions API.
136    pub fn completions_api(self) -> CompletionsClient<H> {
137        self.with_ext(OpenAICompletionsExt)
138    }
139}
140
141impl<H> CompletionsClient<H>
142where
143    H: HttpClientExt
144        + Clone
145        + std::fmt::Debug
146        + Default
147        + WasmCompatSend
148        + WasmCompatSync
149        + 'static,
150{
151    /// Create an extractor builder with the given completion model.
152    /// Uses the OpenAI Chat Completions API.
153    pub fn extractor<U>(
154        &self,
155        model: impl Into<String>,
156    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
157    where
158        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
159    {
160        ExtractorBuilder::new(self.completion_model(model))
161    }
162
163    /// Create a Responses API client from this Completions API client.
164    /// Useful for switching to the newer Responses API.
165    pub fn responses_api(self) -> Client<H> {
166        self.with_ext(OpenAIResponsesExt)
167    }
168}
169
170impl ProviderClient for Client {
171    type Input = OpenAIApiKey;
172
173    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
174    /// Panics if the environment variable is not set.
175    fn from_env() -> Self {
176        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
177        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
178
179        let mut builder = Client::builder().api_key(&api_key);
180
181        if let Some(base) = base_url {
182            builder = builder.base_url(&base);
183        }
184
185        builder.build().unwrap()
186    }
187
188    fn from_val(input: Self::Input) -> Self {
189        Self::new(input).unwrap()
190    }
191}
192
193impl ProviderClient for CompletionsClient {
194    type Input = OpenAIApiKey;
195
196    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
197    /// Panics if the environment variable is not set.
198    fn from_env() -> Self {
199        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
200        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
201
202        let mut builder = CompletionsClient::builder().api_key(&api_key);
203
204        if let Some(base) = base_url {
205            builder = builder.base_url(&base);
206        }
207
208        builder.build().unwrap()
209    }
210
211    fn from_val(input: Self::Input) -> Self {
212        Self::new(input).unwrap()
213    }
214}
215
216#[derive(Debug, Deserialize)]
217pub struct ApiErrorResponse {
218    pub(crate) message: String,
219}
220
221#[derive(Debug, Deserialize)]
222#[serde(untagged)]
223pub(crate) enum ApiResponse<T> {
224    Ok(T),
225    Err(ApiErrorResponse),
226}
227
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Parses `json` into a provider [`Message`].
    ///
    /// On failure, panics with the JSON path plus the line and column of the
    /// offending token.
    fn parse_message(json: &str) -> Message {
        let mut de = serde_json::Deserializer::from_str(json);
        deserialize(&mut de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // Parse everything up front so a parse failure is reported before any
        // assertion runs.
        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        // Plain-string content.
        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );

        // Array content with `"tool_calls": null`.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        // Tool call with `"content": null` and `"refusal": null`.
        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        // Mixed user content: only the text and image parts are checked.
        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::Image {
                image_url: ImageUrl {
                    url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                    detail: ImageDetail::default()
                }
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Generic -> provider messages.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Provider -> generic messages must round-trip back to the originals.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // Provider -> generic messages.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Generic -> provider messages must round-trip back to the originals.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}
522}