1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 extractor::ExtractorBuilder,
7 http_client::{self, HttpClientExt},
8 prelude::CompletionClient,
9 wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
/// Default base URL for both OpenAI client flavours; can be overridden on the builder.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

/// Stateless provider extension for the OpenAI client that serves completions
/// through the Responses API.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Stateless builder marker whose `build` yields an [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;

/// Stateless provider extension for the OpenAI client that serves completions
/// through the (Chat) Completions API.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Stateless builder marker whose `build` yields an [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = crate::markers::Missing> =
43 client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = crate::markers::Missing> =
48 client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51 type Builder = OpenAIResponsesExtBuilder;
52 const VERIFY_PATH: &'static str = "/models";
53}
54
55impl Provider for OpenAICompletionsExt {
56 type Builder = OpenAICompletionsExtBuilder;
57 const VERIFY_PATH: &'static str = "/models";
58}
59
60impl<H> Capabilities<H> for OpenAIResponsesExt {
61 type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
62 type Embeddings = Capable<super::EmbeddingModel<H>>;
63 type Transcription = Capable<super::TranscriptionModel<H>>;
64 type ModelListing = Capable<super::OpenAIModelLister<H>>;
65 #[cfg(feature = "image")]
66 type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
67 #[cfg(feature = "audio")]
68 type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
69 type Rerank = Nothing;
70}
71
72impl<H> Capabilities<H> for OpenAICompletionsExt {
73 type Completion = Capable<super::completion::CompletionModel<H>>;
74 type Embeddings = Capable<super::GenericEmbeddingModel<OpenAICompletionsExt, H>>;
75 type Transcription = Capable<super::TranscriptionModel<H>>;
76 type ModelListing = Capable<super::OpenAIModelLister<H>>;
77 #[cfg(feature = "image")]
78 type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
79 #[cfg(feature = "audio")]
80 type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
81 type Rerank = Nothing;
82}
83
84impl DebugExt for OpenAIResponsesExt {}
85
86impl DebugExt for OpenAICompletionsExt {}
87
88impl ProviderBuilder for OpenAIResponsesExtBuilder {
89 type Extension<H>
90 = OpenAIResponsesExt
91 where
92 H: HttpClientExt;
93 type ApiKey = OpenAIApiKey;
94
95 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
96
97 fn build<H>(
98 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
99 ) -> http_client::Result<Self::Extension<H>>
100 where
101 H: HttpClientExt,
102 {
103 Ok(OpenAIResponsesExt)
104 }
105}
106
107impl ProviderBuilder for OpenAICompletionsExtBuilder {
108 type Extension<H>
109 = OpenAICompletionsExt
110 where
111 H: HttpClientExt;
112 type ApiKey = OpenAIApiKey;
113
114 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
115
116 fn build<H>(
117 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
118 ) -> http_client::Result<Self::Extension<H>>
119 where
120 H: HttpClientExt,
121 {
122 Ok(OpenAICompletionsExt)
123 }
124}
125
126impl<H> Client<H>
127where
128 H: HttpClientExt
129 + Clone
130 + std::fmt::Debug
131 + Default
132 + WasmCompatSend
133 + WasmCompatSync
134 + 'static,
135{
136 pub fn extractor<U>(
139 &self,
140 model: impl Into<String>,
141 ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
142 where
143 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
144 {
145 ExtractorBuilder::new(self.completion_model(model))
146 }
147
148 pub fn completions_api(self) -> CompletionsClient<H> {
151 self.with_ext(OpenAICompletionsExt)
152 }
153}
154
#[cfg(all(not(target_family = "wasm"), feature = "websocket"))]
impl Client<reqwest::Client> {
    /// Returns an unconnected WebSocket session builder for the Responses-API model `model`.
    pub fn responses_websocket_builder(
        &self,
        model: impl Into<String>,
    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
        use super::responses_api::websocket::ResponsesWebSocketSessionBuilder;

        let completion_model = self.completion_model(model);
        ResponsesWebSocketSessionBuilder::new(completion_model)
    }

    /// Builds a WebSocket session for `model` and connects it immediately.
    ///
    /// # Errors
    /// Returns the `crate::completion::CompletionError` reported by the builder's `connect`.
    pub async fn responses_websocket(
        &self,
        model: impl Into<String>,
    ) -> Result<
        super::responses_api::websocket::ResponsesWebSocketSession,
        crate::completion::CompletionError,
    > {
        let session_builder = self.responses_websocket_builder(model);
        session_builder.connect().await
    }
}
180
181impl<H> CompletionsClient<H>
182where
183 H: HttpClientExt
184 + Clone
185 + std::fmt::Debug
186 + Default
187 + WasmCompatSend
188 + WasmCompatSync
189 + 'static,
190{
191 pub fn extractor<U>(
194 &self,
195 model: impl Into<String>,
196 ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
197 where
198 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
199 {
200 ExtractorBuilder::new(self.completion_model(model))
201 }
202
203 pub fn responses_api(self) -> Client<H> {
206 self.with_ext(OpenAIResponsesExt)
207 }
208}
209
210impl ProviderClient for Client {
211 type Input = OpenAIApiKey;
212 type Error = crate::client::ProviderClientError;
213
214 fn from_env() -> Result<Self, Self::Error> {
216 let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
217 let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
218
219 let mut builder = Client::builder().api_key(&api_key);
220
221 if let Some(base) = base_url {
222 builder = builder.base_url(&base);
223 }
224
225 builder.build().map_err(Into::into)
226 }
227
228 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
229 Self::new(input).map_err(Into::into)
230 }
231}
232
233impl ProviderClient for CompletionsClient {
234 type Input = OpenAIApiKey;
235 type Error = crate::client::ProviderClientError;
236
237 fn from_env() -> Result<Self, Self::Error> {
239 let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
240 let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
241
242 let mut builder = CompletionsClient::builder().api_key(&api_key);
243
244 if let Some(base) = base_url {
245 builder = builder.base_url(&base);
246 }
247
248 builder.build().map_err(Into::into)
249 }
250
251 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
252 Self::new(input).map_err(Into::into)
253 }
254}
255
256#[derive(Debug, Deserialize)]
257pub struct ApiErrorResponse {
258 pub(crate) message: String,
259}
260
261#[derive(Debug, Deserialize)]
262#[serde(untagged)]
263pub(crate) enum ApiResponse<T> {
264 Ok(T),
265 Err(ApiErrorResponse),
266}
267
#[cfg(test)]
mod tests {
    use crate::client::{CompletionClient, EmbeddingsClient};
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes an OpenAI wire-format [`Message`] from `json`.
    ///
    /// # Panics
    /// Panics with the JSON path and line/column of the first deserialization error.
    /// This locates fixture mistakes far more precisely than a bare `unwrap`.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let mut parts = content.into_iter();
                let first = parts.next().expect("expected a text part");
                let second = parts.next().expect("expected an image part");
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default()
                        }
                    }
                );
                // The fixture's third part is the `input_audio` payload. Previously it was
                // parsed but never checked, so assert it survives and nothing else follows.
                assert!(parts.next().is_some(), "expected an audio part");
                assert!(
                    parts.next().is_none(),
                    "expected exactly three content parts"
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::many(vec![
                UserContent::Text {
                    text: "What's in this image?".to_string(),
                },
                UserContent::Image {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: ImageDetail::default(),
                    },
                },
            ])
            .unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::openai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    #[test]
    fn test_legacy_chat_completion_model_type_annotation_still_compiles() {
        let client = crate::providers::openai::Client::new("dummy-key")
            .expect("Client::new() failed")
            .completions_api();

        let _model: crate::providers::openai::completion::CompletionModel<reqwest::Client> =
            client.completion_model("gpt-4o");
    }

    #[test]
    fn test_legacy_embedding_model_type_annotation_still_compiles() {
        let client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");

        let _model: crate::providers::openai::EmbeddingModel<reqwest::Client> =
            client.embedding_model(crate::providers::openai::TEXT_EMBEDDING_3_SMALL);
    }
}