1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6use std::fmt::Debug;
7
8#[cfg(feature = "image")]
9use super::image_generation::ImageGenerationModel;
10use super::transcription::TranscriptionModel;
11
12use crate::{
13 client::{
14 CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
15 VerifyError,
16 },
17 extractor::ExtractorBuilder,
18 http_client::{self, HttpClientExt},
19 providers::openai::CompletionModel,
20};
21
22#[cfg(feature = "audio")]
23use crate::client::AudioGenerationClient;
24#[cfg(feature = "image")]
25use crate::client::ImageGenerationClient;
26
27use bytes::Bytes;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30
31const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
35
36pub struct ClientBuilder<'a, T = reqwest::Client> {
37 api_key: &'a str,
38 base_url: &'a str,
39 http_client: T,
40}
41
42impl<'a, T> ClientBuilder<'a, T>
43where
44 T: Default,
45{
46 pub fn new(api_key: &'a str) -> Self {
47 Self {
48 api_key,
49 base_url: OPENAI_API_BASE_URL,
50 http_client: Default::default(),
51 }
52 }
53}
54
55impl<'a, T> ClientBuilder<'a, T> {
56 pub fn base_url(mut self, base_url: &'a str) -> Self {
57 self.base_url = base_url;
58 self
59 }
60
61 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
62 ClientBuilder {
63 api_key: self.api_key,
64 base_url: self.base_url,
65 http_client,
66 }
67 }
68 pub fn build(self) -> Client<T> {
69 Client {
70 base_url: self.base_url.to_string(),
71 api_key: self.api_key.to_string(),
72 http_client: self.http_client,
73 }
74 }
75}
76
77#[derive(Clone)]
78pub struct Client<T = reqwest::Client> {
79 base_url: String,
80 api_key: String,
81 http_client: T,
82}
83
84impl<T> Debug for Client<T>
85where
86 T: Debug,
87{
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("Client")
90 .field("base_url", &self.base_url)
91 .field("http_client", &self.http_client)
92 .field("api_key", &"<REDACTED>")
93 .finish()
94 }
95}
96
97impl<T> Client<T>
98where
99 T: Default,
100{
101 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
112 ClientBuilder::new(api_key)
113 }
114
115 pub fn new(api_key: &str) -> Self {
118 Self::builder(api_key).build()
119 }
120}
121
122impl<T> Client<T>
123where
124 T: HttpClientExt,
125{
126 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
127 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
128
129 http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
130 }
131
132 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
133 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
134
135 http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
136 }
137
138 pub(crate) async fn send<U, R>(
139 &self,
140 req: http_client::Request<U>,
141 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
142 where
143 U: Into<Bytes> + Send,
144 R: From<Bytes> + Send + 'static,
145 {
146 self.http_client.send(req).await
147 }
148}
149
150impl Client<reqwest::Client> {
151 pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder {
152 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
153
154 self.http_client.post(url).bearer_auth(&self.api_key)
155 }
156
157 pub fn extractor_completions_api<U>(
161 &self,
162 model: &str,
163 ) -> ExtractorBuilder<CompletionModel<reqwest::Client>, U>
164 where
165 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
166 {
167 ExtractorBuilder::new(self.completion_model(model).completions_api())
168 }
169}
170
171impl Client<reqwest::Client> {}
172
173impl ProviderClient for Client<reqwest::Client> {
174 fn from_env() -> Self {
177 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
178 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
179
180 match base_url {
181 Some(url) => Self::builder(&api_key).base_url(&url).build(),
182 None => Self::new(&api_key),
183 }
184 }
185
186 fn from_val(input: crate::client::ProviderValue) -> Self {
187 let crate::client::ProviderValue::Simple(api_key) = input else {
188 panic!("Incorrect provider value type")
189 };
190 Self::new(&api_key)
191 }
192}
193
194impl CompletionClient for Client<reqwest::Client> {
195 type CompletionModel = super::responses_api::ResponsesCompletionModel<reqwest::Client>;
196 fn completion_model(
208 &self,
209 model: &str,
210 ) -> super::responses_api::ResponsesCompletionModel<reqwest::Client> {
211 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
212 }
213}
214
215impl EmbeddingsClient for Client<reqwest::Client> {
216 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
217 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
218 let ndims = match model {
219 TEXT_EMBEDDING_3_LARGE => 3072,
220 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
221 _ => 0,
222 };
223 EmbeddingModel::new(self.clone(), model, ndims)
224 }
225
226 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
227 EmbeddingModel::new(self.clone(), model, ndims)
228 }
229}
230
231impl TranscriptionClient for Client<reqwest::Client> {
232 type TranscriptionModel = TranscriptionModel<reqwest::Client>;
233 fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
245 TranscriptionModel::new(self.clone(), model)
246 }
247}
248
249#[cfg(feature = "image")]
250impl ImageGenerationClient for Client<reqwest::Client> {
251 type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;
252 fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
264 ImageGenerationModel::new(self.clone(), model)
265 }
266}
267
268#[cfg(feature = "audio")]
269impl AudioGenerationClient for Client<reqwest::Client> {
270 type AudioGenerationModel = AudioGenerationModel<reqwest::Client>;
271 fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
283 AudioGenerationModel::new(self.clone(), model)
284 }
285}
286
287impl VerifyClient for Client<reqwest::Client> {
288 #[cfg_attr(feature = "worker", worker::send)]
289 async fn verify(&self) -> Result<(), VerifyError> {
290 let req = self
291 .get("/models")?
292 .body(http_client::NoBody)
293 .map_err(|e| VerifyError::HttpError(e.into()))?;
294
295 let response = self.send(req).await?;
296
297 match response.status() {
298 reqwest::StatusCode::OK => Ok(()),
299 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
300 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
301 let text = http_client::text(response).await?;
302 Err(VerifyError::ProviderError(text))
303 }
304 _ => {
305 Ok(())
307 }
308 }
309 }
310}
311
312#[derive(Debug, Deserialize)]
313pub struct ApiErrorResponse {
314 pub(crate) message: String,
315}
316
317#[derive(Debug, Deserialize)]
318#[serde(untagged)]
319pub(crate) enum ApiResponse<T> {
320 Ok(T),
321 Err(ApiErrorResponse),
322}
323
324#[cfg(test)]
325mod tests {
326 use crate::message::ImageDetail;
327 use crate::providers::openai::{
328 AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
329 };
330 use crate::{OneOrMany, message};
331 use serde_path_to_error::deserialize;
332
333 #[test]
334 fn test_deserialize_message() {
335 let assistant_message_json = r#"
336 {
337 "role": "assistant",
338 "content": "\n\nHello there, how may I assist you today?"
339 }
340 "#;
341
342 let assistant_message_json2 = r#"
343 {
344 "role": "assistant",
345 "content": [
346 {
347 "type": "text",
348 "text": "\n\nHello there, how may I assist you today?"
349 }
350 ],
351 "tool_calls": null
352 }
353 "#;
354
355 let assistant_message_json3 = r#"
356 {
357 "role": "assistant",
358 "tool_calls": [
359 {
360 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
361 "type": "function",
362 "function": {
363 "name": "subtract",
364 "arguments": "{\"x\": 2, \"y\": 5}"
365 }
366 }
367 ],
368 "content": null,
369 "refusal": null
370 }
371 "#;
372
373 let user_message_json = r#"
374 {
375 "role": "user",
376 "content": [
377 {
378 "type": "text",
379 "text": "What's in this image?"
380 },
381 {
382 "type": "image_url",
383 "image_url": {
384 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
385 }
386 },
387 {
388 "type": "audio",
389 "input_audio": {
390 "data": "...",
391 "format": "mp3"
392 }
393 }
394 ]
395 }
396 "#;
397
398 let assistant_message: Message = {
399 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
400 deserialize(jd).unwrap_or_else(|err| {
401 panic!(
402 "Deserialization error at {} ({}:{}): {}",
403 err.path(),
404 err.inner().line(),
405 err.inner().column(),
406 err
407 );
408 })
409 };
410
411 let assistant_message2: Message = {
412 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
413 deserialize(jd).unwrap_or_else(|err| {
414 panic!(
415 "Deserialization error at {} ({}:{}): {}",
416 err.path(),
417 err.inner().line(),
418 err.inner().column(),
419 err
420 );
421 })
422 };
423
424 let assistant_message3: Message = {
425 let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
426 &mut serde_json::Deserializer::from_str(assistant_message_json3);
427 deserialize(jd).unwrap_or_else(|err| {
428 panic!(
429 "Deserialization error at {} ({}:{}): {}",
430 err.path(),
431 err.inner().line(),
432 err.inner().column(),
433 err
434 );
435 })
436 };
437
438 let user_message: Message = {
439 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
440 deserialize(jd).unwrap_or_else(|err| {
441 panic!(
442 "Deserialization error at {} ({}:{}): {}",
443 err.path(),
444 err.inner().line(),
445 err.inner().column(),
446 err
447 );
448 })
449 };
450
451 match assistant_message {
452 Message::Assistant { content, .. } => {
453 assert_eq!(
454 content[0],
455 AssistantContent::Text {
456 text: "\n\nHello there, how may I assist you today?".to_string()
457 }
458 );
459 }
460 _ => panic!("Expected assistant message"),
461 }
462
463 match assistant_message2 {
464 Message::Assistant {
465 content,
466 tool_calls,
467 ..
468 } => {
469 assert_eq!(
470 content[0],
471 AssistantContent::Text {
472 text: "\n\nHello there, how may I assist you today?".to_string()
473 }
474 );
475
476 assert_eq!(tool_calls, vec![]);
477 }
478 _ => panic!("Expected assistant message"),
479 }
480
481 match assistant_message3 {
482 Message::Assistant {
483 content,
484 tool_calls,
485 refusal,
486 ..
487 } => {
488 assert!(content.is_empty());
489 assert!(refusal.is_none());
490 assert_eq!(
491 tool_calls[0],
492 ToolCall {
493 id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
494 r#type: ToolType::Function,
495 function: Function {
496 name: "subtract".to_string(),
497 arguments: serde_json::json!({"x": 2, "y": 5}),
498 },
499 }
500 );
501 }
502 _ => panic!("Expected assistant message"),
503 }
504
505 match user_message {
506 Message::User { content, .. } => {
507 let (first, second) = {
508 let mut iter = content.into_iter();
509 (iter.next().unwrap(), iter.next().unwrap())
510 };
511 assert_eq!(
512 first,
513 UserContent::Text {
514 text: "What's in this image?".to_string()
515 }
516 );
517 assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
518 }
519 _ => panic!("Expected user message"),
520 }
521 }
522
523 #[test]
524 fn test_message_to_message_conversion() {
525 let user_message = message::Message::User {
526 content: OneOrMany::one(message::UserContent::text("Hello")),
527 };
528
529 let assistant_message = message::Message::Assistant {
530 id: None,
531 content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
532 };
533
534 let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
535 let converted_assistant_message: Vec<Message> =
536 assistant_message.clone().try_into().unwrap();
537
538 match converted_user_message[0].clone() {
539 Message::User { content, .. } => {
540 assert_eq!(
541 content.first(),
542 UserContent::Text {
543 text: "Hello".to_string()
544 }
545 );
546 }
547 _ => panic!("Expected user message"),
548 }
549
550 match converted_assistant_message[0].clone() {
551 Message::Assistant { content, .. } => {
552 assert_eq!(
553 content[0].clone(),
554 AssistantContent::Text {
555 text: "Hi there!".to_string()
556 }
557 );
558 }
559 _ => panic!("Expected assistant message"),
560 }
561
562 let original_user_message: message::Message =
563 converted_user_message[0].clone().try_into().unwrap();
564 let original_assistant_message: message::Message =
565 converted_assistant_message[0].clone().try_into().unwrap();
566
567 assert_eq!(original_user_message, user_message);
568 assert_eq!(original_assistant_message, assistant_message);
569 }
570
571 #[test]
572 fn test_message_from_message_conversion() {
573 let user_message = Message::User {
574 content: OneOrMany::one(UserContent::Text {
575 text: "Hello".to_string(),
576 }),
577 name: None,
578 };
579
580 let assistant_message = Message::Assistant {
581 content: vec![AssistantContent::Text {
582 text: "Hi there!".to_string(),
583 }],
584 refusal: None,
585 audio: None,
586 name: None,
587 tool_calls: vec![],
588 };
589
590 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
591 let converted_assistant_message: message::Message =
592 assistant_message.clone().try_into().unwrap();
593
594 match converted_user_message.clone() {
595 message::Message::User { content } => {
596 assert_eq!(content.first(), message::UserContent::text("Hello"));
597 }
598 _ => panic!("Expected user message"),
599 }
600
601 match converted_assistant_message.clone() {
602 message::Message::Assistant { content, .. } => {
603 assert_eq!(
604 content.first(),
605 message::AssistantContent::text("Hi there!")
606 );
607 }
608 _ => panic!("Expected assistant message"),
609 }
610
611 let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
612 let original_assistant_message: Vec<Message> =
613 converted_assistant_message.try_into().unwrap();
614
615 assert_eq!(original_user_message[0], user_message);
616 assert_eq!(original_assistant_message[0], assistant_message);
617 }
618}