1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6use std::fmt::Debug;
7
8#[cfg(feature = "image")]
9use super::image_generation::ImageGenerationModel;
10use super::transcription::TranscriptionModel;
11
12use crate::{
13 client::{
14 CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
15 VerifyError,
16 },
17 extractor::ExtractorBuilder,
18 http_client::{self, HttpClientExt},
19 providers::openai::CompletionModel,
20};
21
22#[cfg(feature = "audio")]
23use crate::client::AudioGenerationClient;
24#[cfg(feature = "image")]
25use crate::client::ImageGenerationClient;
26
27use bytes::Bytes;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30
/// Default API root used by [`ClientBuilder::new`] and [`ClientBuilder::new_with_client`].
///
/// Callers can point the client elsewhere with [`ClientBuilder::base_url`] or, via
/// `from_env`, the `OPENAI_BASE_URL` environment variable.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
35
36pub struct ClientBuilder<'a, T = reqwest::Client> {
37 api_key: &'a str,
38 base_url: &'a str,
39 http_client: T,
40}
41
42impl<'a, T> ClientBuilder<'a, T>
43where
44 T: Default,
45{
46 pub fn new(api_key: &'a str) -> Self {
47 Self {
48 api_key,
49 base_url: OPENAI_API_BASE_URL,
50 http_client: Default::default(),
51 }
52 }
53}
54
55impl<'a, T> ClientBuilder<'a, T> {
56 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
57 ClientBuilder {
58 api_key,
59 base_url: OPENAI_API_BASE_URL,
60 http_client,
61 }
62 }
63
64 pub fn base_url(mut self, base_url: &'a str) -> Self {
65 self.base_url = base_url;
66 self
67 }
68
69 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
70 ClientBuilder {
71 api_key: self.api_key,
72 base_url: self.base_url,
73 http_client,
74 }
75 }
76 pub fn build(self) -> Client<T> {
77 Client {
78 base_url: self.base_url.to_string(),
79 api_key: self.api_key.to_string(),
80 http_client: self.http_client,
81 }
82 }
83}
84
85#[derive(Clone)]
86pub struct Client<T = reqwest::Client> {
87 base_url: String,
88 api_key: String,
89 pub(crate) http_client: T,
90}
91
92impl<T> Debug for Client<T>
93where
94 T: Debug,
95{
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("Client")
98 .field("base_url", &self.base_url)
99 .field("http_client", &self.http_client)
100 .field("api_key", &"<REDACTED>")
101 .finish()
102 }
103}
104
105impl Client<reqwest::Client> {
106 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
117 ClientBuilder::new(api_key)
118 }
119
120 pub fn new(api_key: &str) -> Self {
123 Self::builder(api_key).build()
124 }
125
126 pub fn from_env() -> Self {
127 <Self as ProviderClient>::from_env()
128 }
129}
130
131impl<T> Client<T>
132where
133 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
134{
135 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
136 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
137
138 http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
139 }
140
141 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
142 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
143
144 http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
145 }
146
147 pub(crate) async fn send<U, R>(
148 &self,
149 req: http_client::Request<U>,
150 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
151 where
152 U: Into<Bytes> + Send,
153 R: From<Bytes> + Send + 'static,
154 {
155 self.http_client.send(req).await
156 }
157
158 pub fn extractor_completions_api<U>(
162 &self,
163 model: &str,
164 ) -> ExtractorBuilder<CompletionModel<T>, U>
165 where
166 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
167 CompletionModel<T>: crate::completion::CompletionModel,
168 {
169 ExtractorBuilder::new(self.completion_model(model).completions_api())
170 }
171}
172
173impl<T> ProviderClient for Client<T>
174where
175 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
176{
177 fn from_env() -> Self {
180 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
181 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
182
183 match base_url {
184 Some(url) => ClientBuilder::<T>::new(&api_key).base_url(&url).build(),
185 None => ClientBuilder::<T>::new(&api_key).build(),
186 }
187 }
188
189 fn from_val(input: crate::client::ProviderValue) -> Self {
190 let crate::client::ProviderValue::Simple(api_key) = input else {
191 panic!("Incorrect provider value type")
192 };
193
194 ClientBuilder::<T>::new(&api_key).build()
195 }
196}
197
198impl<T> CompletionClient for Client<T>
199where
200 T: HttpClientExt + std::fmt::Debug + Clone + Default + Send + 'static,
201{
202 type CompletionModel = super::responses_api::ResponsesCompletionModel<T>;
203 fn completion_model(&self, model: &str) -> Self::CompletionModel {
215 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
216 }
217}
218
219impl<T> EmbeddingsClient for Client<T>
220where
221 T: HttpClientExt + std::fmt::Debug + Clone + Default + Send + 'static,
222{
223 type EmbeddingModel = EmbeddingModel<T>;
224 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
225 let ndims = match model {
226 TEXT_EMBEDDING_3_LARGE => 3072,
227 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
228 _ => 0,
229 };
230 EmbeddingModel::new(self.clone(), model, ndims)
231 }
232
233 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
234 EmbeddingModel::new(self.clone(), model, ndims)
235 }
236}
237
238impl<T> TranscriptionClient for Client<T>
239where
240 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
241{
242 type TranscriptionModel = TranscriptionModel<T>;
243 fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
255 TranscriptionModel::new(self.clone(), model)
256 }
257}
258
#[cfg(feature = "image")]
impl<T> ImageGenerationClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type ImageGenerationModel = ImageGenerationModel<T>;

    /// Creates an image-generation model named `model` that owns a clone of this client.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
280
#[cfg(feature = "audio")]
impl<T> AudioGenerationClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Send + Default + 'static,
{
    type AudioGenerationModel = AudioGenerationModel<T>;

    /// Creates an audio-generation model named `model` that owns a clone of this client.
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
302
303impl<T> VerifyClient for Client<T>
304where
305 T: HttpClientExt + Clone + std::fmt::Debug + Send + Default + 'static,
306{
307 #[cfg_attr(feature = "worker", worker::send)]
308 async fn verify(&self) -> Result<(), VerifyError> {
309 let req = self
310 .get("/models")?
311 .body(http_client::NoBody)
312 .map_err(|e| VerifyError::HttpError(e.into()))?;
313
314 let response = self.send(req).await?;
315
316 match response.status() {
317 reqwest::StatusCode::OK => Ok(()),
318 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
319 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
320 let text = http_client::text(response).await?;
321 Err(VerifyError::ProviderError(text))
322 }
323 _ => {
324 Ok(())
326 }
327 }
328 }
329}
330
331#[derive(Debug, Deserialize)]
332pub struct ApiErrorResponse {
333 pub(crate) message: String,
334}
335
336#[derive(Debug, Deserialize)]
337#[serde(untagged)]
338pub(crate) enum ApiResponse<T> {
339 Ok(T),
340 Err(ApiErrorResponse),
341}
342
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Parses `json` into a provider [`Message`].
    ///
    /// On failure it panics with the failing JSON path and line/column, so a broken
    /// fixture is easy to pinpoint.
    fn parse_message(json: &str) -> Message {
        let mut de = serde_json::Deserializer::from_str(json);
        deserialize(&mut de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        // Assistant message whose content is a plain string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message whose content is an array of parts, with null tool calls.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message that only carries a tool call (null content and refusal).
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image and audio parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let (first, second) = {
            let mut parts = content.into_iter();
            (parts.next().unwrap(), parts.next().unwrap())
        };
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Generic -> provider.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Provider -> generic must round-trip to the original values.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // Provider -> generic.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Generic -> provider must round-trip to the original values.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}