1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 extractor::ExtractorBuilder,
7 http_client::{self, HttpClientExt},
8 prelude::CompletionClient,
9 wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
/// Default base URL for the OpenAI API.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

/// Zero-sized marker selecting the OpenAI Responses API flavor of the client.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExt;

/// Builder marker paired with [`OpenAIResponsesExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExtBuilder;

/// Zero-sized marker selecting the legacy Chat Completions flavor of the client.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExt;

/// Builder marker paired with [`OpenAICompletionsExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExtBuilder;

// OpenAI authenticates with a bearer token carrying the API key.
type OpenAIApiKey = BearerAuth;

/// OpenAI client using the Responses API (the default surface).
pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
/// Builder for [`Client`].
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;

/// OpenAI client using the legacy Chat Completions API.
pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
/// Builder for [`CompletionsClient`].
pub type CompletionsClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
// Both API flavors share the same verification endpoint (`/models`),
// used by the generic client to check connectivity/credentials.
impl Provider for OpenAIResponsesExt {
    type Builder = OpenAIResponsesExtBuilder;
    const VERIFY_PATH: &'static str = "/models";
}

impl Provider for OpenAICompletionsExt {
    type Builder = OpenAICompletionsExtBuilder;
    const VERIFY_PATH: &'static str = "/models";
}
59
// Capability wiring for the Responses-API client: completions go through the
// Responses model; everything else uses the shared OpenAI model types.
impl<H> Capabilities<H> for OpenAIResponsesExt {
    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    type ModelListing = Capable<super::OpenAIModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}

// Capability wiring for the legacy Chat Completions client. Note the generic
// embedding model parameterized over this extension; the rest matches the
// Responses-API wiring above.
impl<H> Capabilities<H> for OpenAICompletionsExt {
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Capable<super::GenericEmbeddingModel<OpenAICompletionsExt, H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    type ModelListing = Capable<super::OpenAIModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}
81
// Default `DebugExt` behavior suffices — both extensions are zero-sized
// markers with no state of their own.
impl DebugExt for OpenAIResponsesExt {}

impl DebugExt for OpenAICompletionsExt {}
85
impl ProviderBuilder for OpenAIResponsesExtBuilder {
    // The extension is the same zero-sized marker regardless of HTTP client.
    type Extension<H>
        = OpenAIResponsesExt
    where
        H: HttpClientExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;

    /// Produce the extension value; infallible because the marker is stateless
    /// (the builder is ignored).
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OpenAIResponsesExt)
    }
}
104
impl ProviderBuilder for OpenAICompletionsExtBuilder {
    // The extension is the same zero-sized marker regardless of HTTP client.
    type Extension<H>
        = OpenAICompletionsExt
    where
        H: HttpClientExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;

    /// Produce the extension value; infallible because the marker is stateless
    /// (the builder is ignored).
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OpenAICompletionsExt)
    }
}
123
impl<H> Client<H>
where
    H: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Create an [`ExtractorBuilder`] that extracts values of type `U` using
    /// the named Responses-API completion model.
    pub fn extractor<U>(
        &self,
        model: impl Into<String>,
    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
    {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Convert this client (consuming it) into one that targets the legacy
    /// Chat Completions API instead of the Responses API.
    pub fn completions_api(self) -> CompletionsClient<H> {
        self.with_ext(OpenAICompletionsExt)
    }
}
152
// WebSocket sessions are only offered on native targets with the `websocket`
// feature, and only for the default `reqwest` HTTP client.
#[cfg(all(not(target_family = "wasm"), feature = "websocket"))]
impl Client<reqwest::Client> {
    /// Create a builder for a Responses-API WebSocket session bound to `model`.
    pub fn responses_websocket_builder(
        &self,
        model: impl Into<String>,
    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
        super::responses_api::websocket::ResponsesWebSocketSessionBuilder::new(
            self.completion_model(model),
        )
    }

    /// Connect a Responses-API WebSocket session for `model` using the
    /// builder's default settings.
    pub async fn responses_websocket(
        &self,
        model: impl Into<String>,
    ) -> Result<
        super::responses_api::websocket::ResponsesWebSocketSession,
        crate::completion::CompletionError,
    > {
        self.responses_websocket_builder(model).connect().await
    }
}
178
impl<H> CompletionsClient<H>
where
    H: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Create an [`ExtractorBuilder`] that extracts values of type `U` using
    /// the named Chat Completions model.
    pub fn extractor<U>(
        &self,
        model: impl Into<String>,
    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
    {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Convert this client (consuming it) into one that targets the Responses
    /// API instead of the legacy Chat Completions API.
    pub fn responses_api(self) -> Client<H> {
        self.with_ext(OpenAIResponsesExt)
    }
}
207
208impl ProviderClient for Client {
209 type Input = OpenAIApiKey;
210 type Error = crate::client::ProviderClientError;
211
212 fn from_env() -> Result<Self, Self::Error> {
214 let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
215 let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
216
217 let mut builder = Client::builder().api_key(&api_key);
218
219 if let Some(base) = base_url {
220 builder = builder.base_url(&base);
221 }
222
223 builder.build().map_err(Into::into)
224 }
225
226 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
227 Self::new(input).map_err(Into::into)
228 }
229}
230
231impl ProviderClient for CompletionsClient {
232 type Input = OpenAIApiKey;
233 type Error = crate::client::ProviderClientError;
234
235 fn from_env() -> Result<Self, Self::Error> {
237 let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?;
238 let api_key = crate::client::required_env_var("OPENAI_API_KEY")?;
239
240 let mut builder = CompletionsClient::builder().api_key(&api_key);
241
242 if let Some(base) = base_url {
243 builder = builder.base_url(&base);
244 }
245
246 builder.build().map_err(Into::into)
247 }
248
249 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
250 Self::new(input).map_err(Into::into)
251 }
252}
253
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error description from the API.
    pub(crate) message: String,
}

/// Either a successful response body `T` or an API error payload.
///
/// With `#[serde(untagged)]`, deserialization first attempts `T` and falls
/// back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
265
#[cfg(test)]
mod tests {
    use crate::client::{CompletionClient, EmbeddingsClient};
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserialize `json` into a [`Message`], panicking with the failing JSON
    /// path plus line/column on error (via `serde_path_to_error`).
    ///
    /// Extracted to avoid repeating the deserialize-and-panic boilerplate for
    /// every fixture in `test_deserialize_message`.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        // Assistant message with plain string content.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with array-of-parts content and a null tool_calls.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message carrying only a tool call (null content/refusal).
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image and audio content parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // A JSON `null` for tool_calls deserializes to an empty vec.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Round-trip: internal `message::Message` -> provider `Message` -> back.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Round-trip: provider `Message` -> internal `message::Message` -> back.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    /// A single text part should serialize as a bare JSON string, matching the
    /// compact form OpenAI accepts.
    #[test]
    fn test_user_message_single_text_serializes_as_string() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world".to_string(),
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert_eq!(serialized["content"], "Hello world");
    }

    /// Multiple content parts must serialize as a JSON array of parts.
    #[test]
    fn test_user_message_multiple_parts_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::many(vec![
                UserContent::Text {
                    text: "What's in this image?".to_string(),
                },
                UserContent::Image {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: ImageDetail::default(),
                    },
                },
            ])
            .unwrap(),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
        assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
    }

    /// A single non-text part (image) still requires the array form.
    #[test]
    fn test_user_message_single_image_serializes_as_array() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: ImageDetail::default(),
                },
            }),
            name: None,
        };

        let serialized = serde_json::to_value(&user_message).unwrap();

        assert_eq!(serialized["role"], "user");
        assert!(serialized["content"].is_array());
    }

    /// Smoke test: both construction paths succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::openai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// Compile-time regression check for the legacy completion-model type path.
    #[test]
    fn test_legacy_chat_completion_model_type_annotation_still_compiles() {
        let client = crate::providers::openai::Client::new("dummy-key")
            .expect("Client::new() failed")
            .completions_api();

        let _model: crate::providers::openai::completion::CompletionModel<reqwest::Client> =
            client.completion_model("gpt-4o");
    }

    /// Compile-time regression check for the legacy embedding-model type path.
    #[test]
    fn test_legacy_embedding_model_type_annotation_still_compiles() {
        let client =
            crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");

        let _model: crate::providers::openai::EmbeddingModel<reqwest::Client> =
            client.embedding_model(crate::providers::openai::TEXT_EMBEDDING_3_SMALL);
    }
}