Skip to main content

rig/providers/openai/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    extractor::ExtractorBuilder,
7    http_client::{self, HttpClientExt},
8    prelude::CompletionClient,
9    wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
// ================================================================
// Main OpenAI Client
// ================================================================
/// Default OpenAI REST endpoint. Both provider builders in this module expose it as
/// their `BASE_URL`.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
// ================================================================
// OpenAI Responses API Extension
// ================================================================

/// Stateless marker extension selecting the OpenAI Responses API.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Stateless marker builder that produces [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;
28
// ================================================================
// OpenAI Completions API Extension
// ================================================================

/// Stateless marker extension selecting the OpenAI Chat Completions API.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Stateless marker builder that produces [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40// Responses API client (default)
41pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45// Completions API client
46pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51    type Builder = OpenAIResponsesExtBuilder;
52    const VERIFY_PATH: &'static str = "/models";
53}
54
55impl Provider for OpenAICompletionsExt {
56    type Builder = OpenAICompletionsExtBuilder;
57    const VERIFY_PATH: &'static str = "/models";
58}
59
60impl<H> Capabilities<H> for OpenAIResponsesExt {
61    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
62    type Embeddings = Capable<super::EmbeddingModel<H>>;
63    type Transcription = Capable<super::TranscriptionModel<H>>;
64    type ModelListing = Nothing;
65    #[cfg(feature = "image")]
66    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
67    #[cfg(feature = "audio")]
68    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
69}
70
71impl<H> Capabilities<H> for OpenAICompletionsExt {
72    type Completion = Capable<super::completion::CompletionModel<H>>;
73    type Embeddings = Capable<super::EmbeddingModel<H>>;
74    type Transcription = Capable<super::TranscriptionModel<H>>;
75    type ModelListing = Nothing;
76    #[cfg(feature = "image")]
77    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
78    #[cfg(feature = "audio")]
79    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
80}
81
82impl DebugExt for OpenAIResponsesExt {}
83
84impl DebugExt for OpenAICompletionsExt {}
85
86impl ProviderBuilder for OpenAIResponsesExtBuilder {
87    type Extension<H>
88        = OpenAIResponsesExt
89    where
90        H: HttpClientExt;
91    type ApiKey = OpenAIApiKey;
92
93    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
94
95    fn build<H>(
96        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
97    ) -> http_client::Result<Self::Extension<H>>
98    where
99        H: HttpClientExt,
100    {
101        Ok(OpenAIResponsesExt)
102    }
103}
104
105impl ProviderBuilder for OpenAICompletionsExtBuilder {
106    type Extension<H>
107        = OpenAICompletionsExt
108    where
109        H: HttpClientExt;
110    type ApiKey = OpenAIApiKey;
111
112    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
113
114    fn build<H>(
115        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
116    ) -> http_client::Result<Self::Extension<H>>
117    where
118        H: HttpClientExt,
119    {
120        Ok(OpenAICompletionsExt)
121    }
122}
123
124impl<H> Client<H>
125where
126    H: HttpClientExt
127        + Clone
128        + std::fmt::Debug
129        + Default
130        + WasmCompatSend
131        + WasmCompatSync
132        + 'static,
133{
134    /// Create an extractor builder with the given completion model.
135    /// Uses the OpenAI Responses API (default behavior).
136    pub fn extractor<U>(
137        &self,
138        model: impl Into<String>,
139    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
140    where
141        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
142    {
143        ExtractorBuilder::new(self.completion_model(model))
144    }
145
146    /// Create a Completions API client from this Responses API client.
147    /// Useful for switching to the traditional Chat Completions API.
148    pub fn completions_api(self) -> CompletionsClient<H> {
149        self.with_ext(OpenAICompletionsExt)
150    }
151}
152
153impl<H> CompletionsClient<H>
154where
155    H: HttpClientExt
156        + Clone
157        + std::fmt::Debug
158        + Default
159        + WasmCompatSend
160        + WasmCompatSync
161        + 'static,
162{
163    /// Create an extractor builder with the given completion model.
164    /// Uses the OpenAI Chat Completions API.
165    pub fn extractor<U>(
166        &self,
167        model: impl Into<String>,
168    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
169    where
170        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
171    {
172        ExtractorBuilder::new(self.completion_model(model))
173    }
174
175    /// Create a Responses API client from this Completions API client.
176    /// Useful for switching to the newer Responses API.
177    pub fn responses_api(self) -> Client<H> {
178        self.with_ext(OpenAIResponsesExt)
179    }
180}
181
182impl ProviderClient for Client {
183    type Input = OpenAIApiKey;
184
185    /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable.
186    /// Panics if the environment variable is not set.
187    fn from_env() -> Self {
188        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
189        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
190
191        let mut builder = Client::builder().api_key(&api_key);
192
193        if let Some(base) = base_url {
194            builder = builder.base_url(&base);
195        }
196
197        builder.build().unwrap()
198    }
199
200    fn from_val(input: Self::Input) -> Self {
201        Self::new(input).unwrap()
202    }
203}
204
205impl ProviderClient for CompletionsClient {
206    type Input = OpenAIApiKey;
207
208    /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable.
209    /// Panics if the environment variable is not set.
210    fn from_env() -> Self {
211        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
212        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
213
214        let mut builder = CompletionsClient::builder().api_key(&api_key);
215
216        if let Some(base) = base_url {
217            builder = builder.base_url(&base);
218        }
219
220        builder.build().unwrap()
221    }
222
223    fn from_val(input: Self::Input) -> Self {
224        Self::new(input).unwrap()
225    }
226}
227
228#[derive(Debug, Deserialize)]
229pub struct ApiErrorResponse {
230    pub(crate) message: String,
231}
232
233#[derive(Debug, Deserialize)]
234#[serde(untagged)]
235pub(crate) enum ApiResponse<T> {
236    Ok(T),
237    Err(ApiErrorResponse),
238}
239
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Assistant text used by several deserialization fixtures.
    const GREETING: &str = "\n\nHello there, how may I assist you today?";

    /// Image URL embedded in the multimodal user-message fixture.
    const BOARDWALK_IMAGE_URL: &str = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";

    /// Deserializes `json` into `T`. On failure it panics with the JSON path and the
    /// line/column of the error, so a broken fixture is easy to locate.
    fn parse_json<'a, T>(json: &'a str) -> T
    where
        T: serde::Deserialize<'a>,
    {
        let mut de = serde_json::Deserializer::from_str(json);
        deserialize(&mut de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        // Assistant content given as a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant content given as an array of parts, with an explicit null tool_calls.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Tool-call-only assistant message with null content and refusal.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // Multimodal user message (text, image, audio).
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        let expected_call = ToolCall {
            id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::json!({"x": 2, "y": 5}),
            },
        };
        assert_eq!(tool_calls[0], expected_call);

        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::Image {
                image_url: ImageUrl {
                    url: BOARDWALK_IMAGE_URL.to_string(),
                    detail: ImageDetail::default(),
                }
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Generic -> OpenAI.
        let openai_user: Vec<Message> = user_message.clone().try_into().unwrap();
        let openai_assistant: Vec<Message> = assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = openai_user[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = openai_assistant[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // OpenAI -> generic must round-trip to the original messages.
        let round_tripped_user: message::Message = openai_user[0].clone().try_into().unwrap();
        let round_tripped_assistant: message::Message =
            openai_assistant[0].clone().try_into().unwrap();

        assert_eq!(round_tripped_user, user_message);
        assert_eq!(round_tripped_assistant, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // OpenAI -> generic.
        let generic_user: message::Message = user_message.clone().try_into().unwrap();
        let generic_assistant: message::Message = assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = generic_user.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = generic_assistant.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Generic -> OpenAI must round-trip to the original messages.
        let round_tripped_user: Vec<Message> = generic_user.try_into().unwrap();
        let round_tripped_assistant: Vec<Message> = generic_assistant.try_into().unwrap();

        assert_eq!(round_tripped_user[0], user_message);
        assert_eq!(round_tripped_assistant[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let json = serde_json::to_value(&user_message).unwrap();

        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let parts = vec![
            UserContent::Text {
                text: "What's in this image?".to_string(),
            },
            UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            },
        ];
        let user_message = Message::User {
            content: OneOrMany::many(parts).unwrap(),
            name: None,
        };

        let json = serde_json::to_value(&user_message).unwrap();

        assert_eq!(json["role"], "user");
        assert!(json["content"].is_array());
        assert_eq!(json["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let image = UserContent::Image {
            image_url: ImageUrl {
                url: "https://example.com/image.jpg".to_string(),
                detail: ImageDetail::default(),
            },
        };
        let user_message = Message::User {
            content: OneOrMany::one(image),
            name: None,
        };

        let json = serde_json::to_value(&user_message).unwrap();

        assert_eq!(json["role"], "user");
        // Only a lone text part collapses to a string; any other single part stays an array.
        assert!(json["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        use crate::providers::openai::Client;

        let _client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}