1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 extractor::ExtractorBuilder,
7 http_client::{self, HttpClientExt},
8 prelude::CompletionClient,
9 wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
/// Default base URL for the hosted OpenAI REST API.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

/// Marker extension selecting OpenAI's Responses API for completions.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExt;

/// Builder counterpart of [`OpenAIResponsesExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIResponsesExtBuilder;

/// Marker extension selecting OpenAI's Chat Completions API.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExt;

/// Builder counterpart of [`OpenAICompletionsExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAICompletionsExtBuilder;

// OpenAI authenticates with a bearer token carried by `BearerAuth`.
type OpenAIApiKey = BearerAuth;

/// Client that talks to the Responses API (the default for this provider).
pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
/// Builder for [`Client`].
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;

/// Client that talks to the Chat Completions API.
pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
/// Builder for [`CompletionsClient`].
pub type CompletionsClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
impl Provider for OpenAIResponsesExt {
    type Builder = OpenAIResponsesExtBuilder;
    // `/models` is used as a cheap authenticated endpoint to verify the key.
    const VERIFY_PATH: &'static str = "/models";
}

impl Provider for OpenAICompletionsExt {
    type Builder = OpenAICompletionsExtBuilder;
    // Same verification endpoint as the Responses-API extension.
    const VERIFY_PATH: &'static str = "/models";
}
59
// Capability wiring for the Responses-API client.
impl<H> Capabilities<H> for OpenAIResponsesExt {
    type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    // Model listing is not wired up for this provider.
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}

// Capability wiring for the Chat-Completions client. Only the completion
// model type differs from the Responses-API wiring above; all other
// capabilities are shared between the two API surfaces.
impl<H> Capabilities<H> for OpenAICompletionsExt {
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Capable<super::TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
}
81
// The extensions carry no state, so the default `DebugExt` output suffices.
impl DebugExt for OpenAIResponsesExt {}

impl DebugExt for OpenAICompletionsExt {}
85
impl ProviderBuilder for OpenAIResponsesExtBuilder {
    type Extension<H>
        = OpenAIResponsesExt
    where
        H: HttpClientExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;

    /// Build the extension. The extension is a stateless marker, so this
    /// never actually fails; the `Result` is required by the trait.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OpenAIResponsesExt)
    }
}

impl ProviderBuilder for OpenAICompletionsExtBuilder {
    type Extension<H>
        = OpenAICompletionsExt
    where
        H: HttpClientExt;
    type ApiKey = OpenAIApiKey;

    const BASE_URL: &'static str = OPENAI_API_BASE_URL;

    /// Build the extension; infallible for the same reason as above.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(OpenAICompletionsExt)
    }
}
123
impl<H> Client<H>
where
    H: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Create an [`ExtractorBuilder`] that extracts values of type `U`
    /// using the Responses-API completion model named by `model`.
    pub fn extractor<U>(
        &self,
        model: impl Into<String>,
    ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
    {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Consume this client and return one that targets the Chat Completions
    /// API instead, keeping the same HTTP client and configuration.
    pub fn completions_api(self) -> CompletionsClient<H> {
        self.with_ext(OpenAICompletionsExt)
    }
}
152
// WebSocket sessions are only available off-wasm with the reqwest client.
#[cfg(not(target_family = "wasm"))]
impl Client<reqwest::Client> {
    /// Create a builder for a realtime WebSocket session against the
    /// Responses API, backed by the completion model named by `model`.
    pub fn responses_websocket_builder(
        &self,
        model: impl Into<String>,
    ) -> super::responses_api::websocket::ResponsesWebSocketSessionBuilder {
        super::responses_api::websocket::ResponsesWebSocketSessionBuilder::new(
            self.completion_model(model),
        )
    }

    /// Connect a WebSocket session for `model` using the builder defaults.
    ///
    /// # Errors
    /// Returns a [`crate::completion::CompletionError`] if the connection
    /// cannot be established.
    pub async fn responses_websocket(
        &self,
        model: impl Into<String>,
    ) -> Result<
        super::responses_api::websocket::ResponsesWebSocketSession,
        crate::completion::CompletionError,
    > {
        self.responses_websocket_builder(model).connect().await
    }
}
178
impl<H> CompletionsClient<H>
where
    H: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Create an [`ExtractorBuilder`] that extracts values of type `U`
    /// using the Chat Completions model named by `model`.
    pub fn extractor<U>(
        &self,
        model: impl Into<String>,
    ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
    {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Consume this client and return one that targets the Responses API
    /// instead, keeping the same HTTP client and configuration.
    pub fn responses_api(self) -> Client<H> {
        self.with_ext(OpenAIResponsesExt)
    }
}
207
208impl ProviderClient for Client {
209 type Input = OpenAIApiKey;
210
211 fn from_env() -> Self {
214 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
215 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
216
217 let mut builder = Client::builder().api_key(&api_key);
218
219 if let Some(base) = base_url {
220 builder = builder.base_url(&base);
221 }
222
223 builder.build().unwrap()
224 }
225
226 fn from_val(input: Self::Input) -> Self {
227 Self::new(input).unwrap()
228 }
229}
230
231impl ProviderClient for CompletionsClient {
232 type Input = OpenAIApiKey;
233
234 fn from_env() -> Self {
237 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
238 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
239
240 let mut builder = CompletionsClient::builder().api_key(&api_key);
241
242 if let Some(base) = base_url {
243 builder = builder.base_url(&base);
244 }
245
246 builder.build().unwrap()
247 }
248
249 fn from_val(input: Self::Input) -> Self {
250 Self::new(input).unwrap()
251 }
252}
253
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Human-readable description of what went wrong.
    pub(crate) message: String,
}

/// Either a successful payload or an API-level error. `untagged` makes
/// serde try the variants in declaration order, so `Ok(T)` wins whenever
/// the body matches the success shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
265
266#[cfg(test)]
267mod tests {
268 use crate::message::ImageDetail;
269 use crate::providers::openai::{
270 AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
271 };
272 use crate::{OneOrMany, message};
273 use serde_path_to_error::deserialize;
274
275 #[test]
276 fn test_deserialize_message() {
277 let assistant_message_json = r#"
278 {
279 "role": "assistant",
280 "content": "\n\nHello there, how may I assist you today?"
281 }
282 "#;
283
284 let assistant_message_json2 = r#"
285 {
286 "role": "assistant",
287 "content": [
288 {
289 "type": "text",
290 "text": "\n\nHello there, how may I assist you today?"
291 }
292 ],
293 "tool_calls": null
294 }
295 "#;
296
297 let assistant_message_json3 = r#"
298 {
299 "role": "assistant",
300 "tool_calls": [
301 {
302 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
303 "type": "function",
304 "function": {
305 "name": "subtract",
306 "arguments": "{\"x\": 2, \"y\": 5}"
307 }
308 }
309 ],
310 "content": null,
311 "refusal": null
312 }
313 "#;
314
315 let user_message_json = r#"
316 {
317 "role": "user",
318 "content": [
319 {
320 "type": "text",
321 "text": "What's in this image?"
322 },
323 {
324 "type": "image_url",
325 "image_url": {
326 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
327 }
328 },
329 {
330 "type": "audio",
331 "input_audio": {
332 "data": "...",
333 "format": "mp3"
334 }
335 }
336 ]
337 }
338 "#;
339
340 let assistant_message: Message = {
341 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
342 deserialize(jd).unwrap_or_else(|err| {
343 panic!(
344 "Deserialization error at {} ({}:{}): {}",
345 err.path(),
346 err.inner().line(),
347 err.inner().column(),
348 err
349 );
350 })
351 };
352
353 let assistant_message2: Message = {
354 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
355 deserialize(jd).unwrap_or_else(|err| {
356 panic!(
357 "Deserialization error at {} ({}:{}): {}",
358 err.path(),
359 err.inner().line(),
360 err.inner().column(),
361 err
362 );
363 })
364 };
365
366 let assistant_message3: Message = {
367 let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
368 &mut serde_json::Deserializer::from_str(assistant_message_json3);
369 deserialize(jd).unwrap_or_else(|err| {
370 panic!(
371 "Deserialization error at {} ({}:{}): {}",
372 err.path(),
373 err.inner().line(),
374 err.inner().column(),
375 err
376 );
377 })
378 };
379
380 let user_message: Message = {
381 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
382 deserialize(jd).unwrap_or_else(|err| {
383 panic!(
384 "Deserialization error at {} ({}:{}): {}",
385 err.path(),
386 err.inner().line(),
387 err.inner().column(),
388 err
389 );
390 })
391 };
392
393 match assistant_message {
394 Message::Assistant { content, .. } => {
395 assert_eq!(
396 content[0],
397 AssistantContent::Text {
398 text: "\n\nHello there, how may I assist you today?".to_string()
399 }
400 );
401 }
402 _ => panic!("Expected assistant message"),
403 }
404
405 match assistant_message2 {
406 Message::Assistant {
407 content,
408 tool_calls,
409 ..
410 } => {
411 assert_eq!(
412 content[0],
413 AssistantContent::Text {
414 text: "\n\nHello there, how may I assist you today?".to_string()
415 }
416 );
417
418 assert_eq!(tool_calls, vec![]);
419 }
420 _ => panic!("Expected assistant message"),
421 }
422
423 match assistant_message3 {
424 Message::Assistant {
425 content,
426 tool_calls,
427 refusal,
428 ..
429 } => {
430 assert!(content.is_empty());
431 assert!(refusal.is_none());
432 assert_eq!(
433 tool_calls[0],
434 ToolCall {
435 id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
436 r#type: ToolType::Function,
437 function: Function {
438 name: "subtract".to_string(),
439 arguments: serde_json::json!({"x": 2, "y": 5}),
440 },
441 }
442 );
443 }
444 _ => panic!("Expected assistant message"),
445 }
446
447 match user_message {
448 Message::User { content, .. } => {
449 let (first, second) = {
450 let mut iter = content.into_iter();
451 (iter.next().unwrap(), iter.next().unwrap())
452 };
453 assert_eq!(
454 first,
455 UserContent::Text {
456 text: "What's in this image?".to_string()
457 }
458 );
459 assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
460 }
461 _ => panic!("Expected user message"),
462 }
463 }
464
465 #[test]
466 fn test_message_to_message_conversion() {
467 let user_message = message::Message::User {
468 content: OneOrMany::one(message::UserContent::text("Hello")),
469 };
470
471 let assistant_message = message::Message::Assistant {
472 id: None,
473 content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
474 };
475
476 let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
477 let converted_assistant_message: Vec<Message> =
478 assistant_message.clone().try_into().unwrap();
479
480 match converted_user_message[0].clone() {
481 Message::User { content, .. } => {
482 assert_eq!(
483 content.first(),
484 UserContent::Text {
485 text: "Hello".to_string()
486 }
487 );
488 }
489 _ => panic!("Expected user message"),
490 }
491
492 match converted_assistant_message[0].clone() {
493 Message::Assistant { content, .. } => {
494 assert_eq!(
495 content[0].clone(),
496 AssistantContent::Text {
497 text: "Hi there!".to_string()
498 }
499 );
500 }
501 _ => panic!("Expected assistant message"),
502 }
503
504 let original_user_message: message::Message =
505 converted_user_message[0].clone().try_into().unwrap();
506 let original_assistant_message: message::Message =
507 converted_assistant_message[0].clone().try_into().unwrap();
508
509 assert_eq!(original_user_message, user_message);
510 assert_eq!(original_assistant_message, assistant_message);
511 }
512
513 #[test]
514 fn test_message_from_message_conversion() {
515 let user_message = Message::User {
516 content: OneOrMany::one(UserContent::Text {
517 text: "Hello".to_string(),
518 }),
519 name: None,
520 };
521
522 let assistant_message = Message::Assistant {
523 content: vec![AssistantContent::Text {
524 text: "Hi there!".to_string(),
525 }],
526 refusal: None,
527 audio: None,
528 name: None,
529 tool_calls: vec![],
530 };
531
532 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
533 let converted_assistant_message: message::Message =
534 assistant_message.clone().try_into().unwrap();
535
536 match converted_user_message.clone() {
537 message::Message::User { content } => {
538 assert_eq!(content.first(), message::UserContent::text("Hello"));
539 }
540 _ => panic!("Expected user message"),
541 }
542
543 match converted_assistant_message.clone() {
544 message::Message::Assistant { content, .. } => {
545 assert_eq!(
546 content.first(),
547 message::AssistantContent::text("Hi there!")
548 );
549 }
550 _ => panic!("Expected assistant message"),
551 }
552
553 let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
554 let original_assistant_message: Vec<Message> =
555 converted_assistant_message.try_into().unwrap();
556
557 assert_eq!(original_user_message[0], user_message);
558 assert_eq!(original_assistant_message[0], assistant_message);
559 }
560
561 #[test]
562 fn test_user_message_single_text_serializes_as_string() {
563 let user_message = Message::User {
564 content: OneOrMany::one(UserContent::Text {
565 text: "Hello world".to_string(),
566 }),
567 name: None,
568 };
569
570 let serialized = serde_json::to_value(&user_message).unwrap();
571
572 assert_eq!(serialized["role"], "user");
573 assert_eq!(serialized["content"], "Hello world");
574 }
575
576 #[test]
577 fn test_user_message_multiple_parts_serializes_as_array() {
578 let user_message = Message::User {
579 content: OneOrMany::many(vec![
580 UserContent::Text {
581 text: "What's in this image?".to_string(),
582 },
583 UserContent::Image {
584 image_url: ImageUrl {
585 url: "https://example.com/image.jpg".to_string(),
586 detail: ImageDetail::default(),
587 },
588 },
589 ])
590 .unwrap(),
591 name: None,
592 };
593
594 let serialized = serde_json::to_value(&user_message).unwrap();
595
596 assert_eq!(serialized["role"], "user");
597 assert!(serialized["content"].is_array());
598 assert_eq!(serialized["content"].as_array().unwrap().len(), 2);
599 }
600
601 #[test]
602 fn test_user_message_single_image_serializes_as_array() {
603 let user_message = Message::User {
604 content: OneOrMany::one(UserContent::Image {
605 image_url: ImageUrl {
606 url: "https://example.com/image.jpg".to_string(),
607 detail: ImageDetail::default(),
608 },
609 }),
610 name: None,
611 };
612
613 let serialized = serde_json::to_value(&user_message).unwrap();
614
615 assert_eq!(serialized["role"], "user");
616 assert!(serialized["content"].is_array());
618 }
619 #[test]
620 fn test_client_initialization() {
621 let _client =
622 crate::providers::openai::Client::new("dummy-key").expect("Client::new() failed");
623 let _client_from_builder = crate::providers::openai::Client::builder()
624 .api_key("dummy-key")
625 .build()
626 .expect("Client::builder() failed");
627 }
628}