1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 extractor::ExtractorBuilder,
7 http_client::{self, HttpClientExt},
8 prelude::CompletionClient,
9 wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
/// Default root URL for requests made by both OpenAI client flavours.
///
/// Used as `ProviderBuilder::BASE_URL` by both builders; `from_env` can override it through
/// the `OPENAI_BASE_URL` environment variable via the builder's `base_url`.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

/// Stateless provider extension whose completion model is
/// `responses_api::ResponsesCompletionModel` (OpenAI's Responses API).
///
/// This is the extension behind the default `Client` alias.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Builder-side marker that produces an [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;

/// Stateless provider extension whose completion model is `completion::CompletionModel`
/// (OpenAI's Completions-style API).
///
/// This is the extension behind the `CompletionsClient` alias.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Builder-side marker that produces an [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43 client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48 client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51 type Builder = OpenAIResponsesExtBuilder;
52
53 const VERIFY_PATH: &'static str = "/models";
54
55 fn build<H>(
56 _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
57 ) -> http_client::Result<Self> {
58 Ok(Self)
59 }
60}
61
62impl Provider for OpenAICompletionsExt {
63 type Builder = OpenAICompletionsExtBuilder;
64
65 const VERIFY_PATH: &'static str = "/models";
66
67 fn build<H>(
68 _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
69 ) -> http_client::Result<Self> {
70 Ok(Self)
71 }
72}
73
74impl<H> Capabilities<H> for OpenAIResponsesExt {
75 type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
76 type Embeddings = Capable<super::EmbeddingModel<H>>;
77 type Transcription = Capable<super::TranscriptionModel<H>>;
78 type ModelListing = Nothing;
79 #[cfg(feature = "image")]
80 type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
81 #[cfg(feature = "audio")]
82 type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
83}
84
85impl<H> Capabilities<H> for OpenAICompletionsExt {
86 type Completion = Capable<super::completion::CompletionModel<H>>;
87 type Embeddings = Capable<super::EmbeddingModel<H>>;
88 type Transcription = Capable<super::TranscriptionModel<H>>;
89 type ModelListing = Nothing;
90 #[cfg(feature = "image")]
91 type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
92 #[cfg(feature = "audio")]
93 type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
94}
95
96impl DebugExt for OpenAIResponsesExt {}
97
98impl DebugExt for OpenAICompletionsExt {}
99
100impl ProviderBuilder for OpenAIResponsesExtBuilder {
101 type Output = OpenAIResponsesExt;
102 type ApiKey = OpenAIApiKey;
103
104 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
105}
106
107impl ProviderBuilder for OpenAICompletionsExtBuilder {
108 type Output = OpenAICompletionsExt;
109 type ApiKey = OpenAIApiKey;
110
111 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
112}
113
114impl<H> Client<H>
115where
116 H: HttpClientExt
117 + Clone
118 + std::fmt::Debug
119 + Default
120 + WasmCompatSend
121 + WasmCompatSync
122 + 'static,
123{
124 pub fn extractor<U>(
127 &self,
128 model: impl Into<String>,
129 ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
130 where
131 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
132 {
133 ExtractorBuilder::new(self.completion_model(model))
134 }
135
136 pub fn completions_api(self) -> CompletionsClient<H> {
139 self.with_ext(OpenAICompletionsExt)
140 }
141}
142
143impl<H> CompletionsClient<H>
144where
145 H: HttpClientExt
146 + Clone
147 + std::fmt::Debug
148 + Default
149 + WasmCompatSend
150 + WasmCompatSync
151 + 'static,
152{
153 pub fn extractor<U>(
156 &self,
157 model: impl Into<String>,
158 ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
159 where
160 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
161 {
162 ExtractorBuilder::new(self.completion_model(model))
163 }
164
165 pub fn responses_api(self) -> Client<H> {
168 self.with_ext(OpenAIResponsesExt)
169 }
170}
171
172impl ProviderClient for Client {
173 type Input = OpenAIApiKey;
174
175 fn from_env() -> Self {
178 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
179 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
180
181 let mut builder = Client::builder().api_key(&api_key);
182
183 if let Some(base) = base_url {
184 builder = builder.base_url(&base);
185 }
186
187 builder.build().unwrap()
188 }
189
190 fn from_val(input: Self::Input) -> Self {
191 Self::new(input).unwrap()
192 }
193}
194
195impl ProviderClient for CompletionsClient {
196 type Input = OpenAIApiKey;
197
198 fn from_env() -> Self {
201 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
202 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
203
204 let mut builder = CompletionsClient::builder().api_key(&api_key);
205
206 if let Some(base) = base_url {
207 builder = builder.base_url(&base);
208 }
209
210 builder.build().unwrap()
211 }
212
213 fn from_val(input: Self::Input) -> Self {
214 Self::new(input).unwrap()
215 }
216}
217
218#[derive(Debug, Deserialize)]
219pub struct ApiErrorResponse {
220 pub(crate) message: String,
221}
222
223#[derive(Debug, Deserialize)]
224#[serde(untagged)]
225pub(crate) enum ApiResponse<T> {
226 Ok(T),
227 Err(ApiErrorResponse),
228}
229
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes an OpenAI [`Message`] from `json`, panicking with the failing JSON path and
    /// line/column if it does not parse.
    fn parse_message(json: &str) -> Message {
        let de = &mut serde_json::Deserializer::from_str(json);
        deserialize(de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // Parse every fixture up front so a parse failure is reported before any assertion.
        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        // Plain-string assistant content.
        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );

        // Array-form assistant content with `tool_calls: null`.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        // Tool-call-only assistant message with null content/refusal.
        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        // Multi-part user message: only the text and image parts are checked.
        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Round-trip back to the generic message type.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Round-trip back to the OpenAI message type.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let parts = vec![
            UserContent::Text {
                text: "What's in this image?".to_string(),
            },
            UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            },
        ];
        let user_message = Message::User {
            content: OneOrMany::many(parts).unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        let content = &serialized["content"];
        assert!(content.is_array());
        assert_eq!(content.as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let image = UserContent::Image {
            image_url: ImageUrl {
                url: "https://example.com/image.jpg".to_string(),
                detail: ImageDetail::default(),
            },
        };
        let user_message = Message::User {
            content: OneOrMany::one(image),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        use crate::providers::openai::Client;

        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}