use crate::{
    client::{
        self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
        ProviderClient,
    },
    extractor::ExtractorBuilder,
    http_client::{self, HttpClientExt},
    prelude::CompletionClient,
    wasm_compat::{WasmCompatSend, WasmCompatSync},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

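/// Provider marker for OpenAI's Responses API; this is the backend used by the
/// default [`Client`] type alias.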
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExtBuilder;

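/// Provider marker for OpenAI's Chat Completions API, used by [`CompletionsClient`].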
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExtBuilder;

type OpenAIApiKey = BearerAuth;

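/// OpenAI client that uses the Responses API for completions (the default in this module).
///
/// A minimal usage sketch (assumes `OPENAI_API_KEY` is set in the environment; not compiled
/// as a doctest):
///
/// ```rust,ignore
/// // Build a client from the environment, then optionally switch it to the
/// // Chat Completions API.
/// let client = Client::from_env();
/// let completions_client = client.completions_api();
/// ```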
pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;

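/// OpenAI client that uses the Chat Completions API instead of the Responses API.
/// Obtainable from a default [`Client`] via `completions_api()`.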
pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
pub type CompletionsClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;

impl Provider for OpenAIResponsesExt {
    type Builder = OpenAIResponsesExtBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl Provider for OpenAICompletionsExt {
    type Builder = OpenAICompletionsExtBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for OpenAIResponsesExt {
    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}

impl<H> Capabilities<H> for OpenAICompletionsExt {
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}

impl DebugExt for OpenAIResponsesExt {}

impl DebugExt for OpenAICompletionsExt {}

impl ProviderBuilder for OpenAIResponsesExtBuilder {
    type Output = OpenAIResponsesExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
}

impl ProviderBuilder for OpenAICompletionsExtBuilder {
    type Output = OpenAICompletionsExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;
}

impl<H> Client<H>
where
    H: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Create an extractor builder that uses a Responses API completion model to
    /// extract values of type `U` from text.
    pub fn extractor<U>(
        &self,
        model: impl Into<String>,
    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
    {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Convert this client into one that uses the Chat Completions API.
    pub fn completions_api(self) -> CompletionsClient<H> {
        self.with_ext(OpenAICompletionsExt)
    }
}

impl<H> CompletionsClient<H>
where
    H: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Create an extractor builder that uses a Chat Completions model to extract
    /// values of type `U` from text.
    pub fn extractor<U>(
        &self,
        model: impl Into<String>,
    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
    {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Convert this client back into one that uses the Responses API.
    pub fn responses_api(self) -> Client<H> {
        self.with_ext(OpenAIResponsesExt)
    }
}

impl ProviderClient for Client {
    type Input = OpenAIApiKey;

    fn from_env() -> Self {
        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");

        let mut builder = Client::builder().api_key(&api_key);

        if let Some(base) = base_url {
            builder = builder.base_url(&base);
        }

        builder.build().unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}

impl ProviderClient for CompletionsClient {
    type Input = OpenAIApiKey;

    fn from_env() -> Self {
        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");

        let mut builder = CompletionsClient::builder().api_key(&api_key);

        if let Some(base) = base_url {
            builder = builder.base_url(&base);
        }

        builder.build().unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}

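/// Error body returned by the OpenAI API.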
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub(crate) message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}