1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 extractor::ExtractorBuilder,
7 http_client::{self, HttpClientExt},
8 prelude::CompletionClient,
9 wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
/// Default base URL for OpenAI's HTTP API.
///
/// Used as `BASE_URL` by both provider builders below; `from_env` lets the
/// `OPENAI_BASE_URL` environment variable override it.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
19
/// Client extension marker selecting OpenAI's Responses API.
///
/// This is the extension behind the default [`Client`] alias.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Builder marker that produces an [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;

/// Client extension marker selecting OpenAI's Chat Completions API.
///
/// This is the extension behind the [`CompletionsClient`] alias.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Builder marker that produces an [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
/// API key type shared by both OpenAI client flavours.
// NOTE(review): represented as `BearerAuth`, presumably sent as an
// `Authorization: Bearer <key>` header — confirm in `client`.
type OpenAIApiKey = BearerAuth;

/// OpenAI client whose completions go through the Responses API.
///
/// `H` is the HTTP client implementation (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;

/// Builder for [`Client`], configured with an [`OpenAIApiKey`].
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;

/// OpenAI client whose completions go through the Chat Completions API.
///
/// `H` is the HTTP client implementation (defaults to `reqwest::Client`).
pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;

/// Builder for [`CompletionsClient`], configured with an [`OpenAIApiKey`].
pub type CompletionsClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
/// Registers the Responses-API extension with its builder.
// NOTE(review): `VERIFY_PATH` is presumably the endpoint the generic client
// hits to check connectivity/credentials — confirm in `client::Provider`.
impl Provider for OpenAIResponsesExt {
    type Builder = OpenAIResponsesExtBuilder;
    const VERIFY_PATH: &'static str = "/models";
}

/// Registers the Chat-Completions extension with its builder; uses the same
/// `VERIFY_PATH` as the Responses-API extension.
impl Provider for OpenAICompletionsExt {
    type Builder = OpenAICompletionsExtBuilder;
    const VERIFY_PATH: &'static str = "/models";
}
59
/// Model types exposed by the Responses-API client.
///
/// Completions use `responses_api::ResponsesCompletionModel`. Embeddings,
/// transcription and the feature-gated image/audio generation use the shared
/// OpenAI models. Model listing is marked `Nothing`, i.e. not offered.
impl<H> Capabilities<H> for OpenAIResponsesExt {
    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    type ModelListing = Nothing;
    // Only present when the crate is built with the `image` feature.
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    // Only present when the crate is built with the `audio` feature.
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}

/// Model types exposed by the Chat-Completions client.
///
/// Identical to `OpenAIResponsesExt`'s capabilities except that completions
/// use `completion::CompletionModel`.
impl<H> Capabilities<H> for OpenAICompletionsExt {
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    type ModelListing = Nothing;
    // Only present when the crate is built with the `image` feature.
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    // Only present when the crate is built with the `audio` feature.
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}
81
// Both extensions rely entirely on `DebugExt`'s default items.
impl DebugExt for OpenAIResponsesExt {}

impl DebugExt for OpenAICompletionsExt {}
85
86impl ProviderBuilder for OpenAIResponsesExtBuilder {
87 type Extension<H>
88 = OpenAIResponsesExt
89 where
90 H: HttpClientExt;
91 type ApiKey = OpenAIApiKey;
92
93 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
94
95 fn build<H>(
96 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
97 ) -> http_client::Result<Self::Extension<H>>
98 where
99 H: HttpClientExt,
100 {
101 Ok(OpenAIResponsesExt)
102 }
103}
104
105impl ProviderBuilder for OpenAICompletionsExtBuilder {
106 type Extension<H>
107 = OpenAICompletionsExt
108 where
109 H: HttpClientExt;
110 type ApiKey = OpenAIApiKey;
111
112 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
113
114 fn build<H>(
115 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
116 ) -> http_client::Result<Self::Extension<H>>
117 where
118 H: HttpClientExt,
119 {
120 Ok(OpenAICompletionsExt)
121 }
122}
123
124impl<H> Client<H>
125where
126 H: HttpClientExt
127 + Clone
128 + std::fmt::Debug
129 + Default
130 + WasmCompatSend
131 + WasmCompatSync
132 + 'static,
133{
134 pub fn extractor<U>(
137 &self,
138 model: impl Into<String>,
139 ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
140 where
141 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
142 {
143 ExtractorBuilder::new(self.completion_model(model))
144 }
145
146 pub fn completions_api(self) -> CompletionsClient<H> {
149 self.with_ext(OpenAICompletionsExt)
150 }
151}
152
153impl<H> CompletionsClient<H>
154where
155 H: HttpClientExt
156 + Clone
157 + std::fmt::Debug
158 + Default
159 + WasmCompatSend
160 + WasmCompatSync
161 + 'static,
162{
163 pub fn extractor<U>(
166 &self,
167 model: impl Into<String>,
168 ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
169 where
170 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
171 {
172 ExtractorBuilder::new(self.completion_model(model))
173 }
174
175 pub fn responses_api(self) -> Client<H> {
178 self.with_ext(OpenAIResponsesExt)
179 }
180}
181
182impl ProviderClient for Client {
183 type Input = OpenAIApiKey;
184
185 fn from_env() -> Self {
188 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
189 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
190
191 let mut builder = Client::builder().api_key(&api_key);
192
193 if let Some(base) = base_url {
194 builder = builder.base_url(&base);
195 }
196
197 builder.build().unwrap()
198 }
199
200 fn from_val(input: Self::Input) -> Self {
201 Self::new(input).unwrap()
202 }
203}
204
205impl ProviderClient for CompletionsClient {
206 type Input = OpenAIApiKey;
207
208 fn from_env() -> Self {
211 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
212 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
213
214 let mut builder = CompletionsClient::builder().api_key(&api_key);
215
216 if let Some(base) = base_url {
217 builder = builder.base_url(&base);
218 }
219
220 builder.build().unwrap()
221 }
222
223 fn from_val(input: Self::Input) -> Self {
224 Self::new(input).unwrap()
225 }
226}
227
228#[derive(Debug, Deserialize)]
229pub struct ApiErrorResponse {
230 pub(crate) message: String,
231}
232
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order:
/// `Ok(T)` is attempted first, and `Err` only if `T` fails to deserialize.
/// Do not reorder the variants.
// NOTE(review): if `T` can deserialize from an error body (e.g. every field is
// optional or defaulted), errors would be parsed as `Ok` — confirm per use site.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
239
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`.
    ///
    /// On failure it panics with the failing JSON path and line/column, so a
    /// broken fixture points straight at the offending field.
    fn deserialize_or_panic<T>(json: &str) -> T
    where
        T: serde::de::DeserializeOwned,
    {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = deserialize_or_panic(assistant_message_json);
        let assistant_message2: Message = deserialize_or_panic(assistant_message_json2);
        let assistant_message3: Message = deserialize_or_panic(assistant_message_json3);
        let user_message: Message = deserialize_or_panic(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // `"tool_calls": null` must deserialize to an empty list.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                // `"content": null` must deserialize to an empty list.
                assert!(content.is_empty());
                assert!(refusal.is_none());
                // The stringified `arguments` must be parsed into a JSON value.
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must round-trip to the original generic messages.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must round-trip to the original provider messages.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::many(vec![
                UserContent::Text {
                    text: "What's in this image?".to_string(),
                },
                UserContent::Image {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: ImageDetail::default(),
                    },
                },
            ])
            .unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::openai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}