1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 admin::Admin,
11 chatkit::Chatkit,
12 config::{Config, OpenAIConfig},
13 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
14 file::Files,
15 image::Images,
16 moderation::Moderations,
17 traits::AsyncTryFrom,
18 Assistants, Audio, Batches, Chat, Completions, Containers, Conversations, Embeddings, Evals,
19 FineTuning, Models, Responses, Threads, Uploads, Usage, VectorStores, Videos,
20};
21
22#[cfg(feature = "realtime")]
23use crate::Realtime;
24
25#[derive(Debug, Clone, Default)]
26pub struct Client<C: Config> {
29 http_client: reqwest::Client,
30 config: C,
31 backoff: backoff::ExponentialBackoff,
32}
33
34impl Client<OpenAIConfig> {
35 pub fn new() -> Self {
37 Self::default()
38 }
39}
40
41impl<C: Config> Client<C> {
42 pub fn build(
44 http_client: reqwest::Client,
45 config: C,
46 backoff: backoff::ExponentialBackoff,
47 ) -> Self {
48 Self {
49 http_client,
50 config,
51 backoff,
52 }
53 }
54
55 pub fn with_config(config: C) -> Self {
57 Self {
58 http_client: reqwest::Client::new(),
59 config,
60 backoff: Default::default(),
61 }
62 }
63
64 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68 self.http_client = http_client;
69 self
70 }
71
72 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74 self.backoff = backoff;
75 self
76 }
77
78 pub fn models(&self) -> Models<'_, C> {
82 Models::new(self)
83 }
84
85 pub fn completions(&self) -> Completions<'_, C> {
87 Completions::new(self)
88 }
89
90 pub fn chat(&self) -> Chat<'_, C> {
92 Chat::new(self)
93 }
94
95 pub fn images(&self) -> Images<'_, C> {
97 Images::new(self)
98 }
99
100 pub fn moderations(&self) -> Moderations<'_, C> {
102 Moderations::new(self)
103 }
104
105 pub fn files(&self) -> Files<'_, C> {
107 Files::new(self)
108 }
109
110 pub fn uploads(&self) -> Uploads<'_, C> {
112 Uploads::new(self)
113 }
114
115 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
117 FineTuning::new(self)
118 }
119
120 pub fn embeddings(&self) -> Embeddings<'_, C> {
122 Embeddings::new(self)
123 }
124
125 pub fn audio(&self) -> Audio<'_, C> {
127 Audio::new(self)
128 }
129
130 pub fn videos(&self) -> Videos<'_, C> {
132 Videos::new(self)
133 }
134
135 pub fn assistants(&self) -> Assistants<'_, C> {
137 Assistants::new(self)
138 }
139
140 pub fn threads(&self) -> Threads<'_, C> {
142 Threads::new(self)
143 }
144
145 pub fn vector_stores(&self) -> VectorStores<'_, C> {
147 VectorStores::new(self)
148 }
149
150 pub fn batches(&self) -> Batches<'_, C> {
152 Batches::new(self)
153 }
154
155 pub fn admin(&self) -> Admin<'_, C> {
158 Admin::new(self)
159 }
160
161 pub fn usage(&self) -> Usage<'_, C> {
163 Usage::new(self)
164 }
165
166 pub fn responses(&self) -> Responses<'_, C> {
168 Responses::new(self)
169 }
170
171 pub fn conversations(&self) -> Conversations<'_, C> {
173 Conversations::new(self)
174 }
175
176 pub fn containers(&self) -> Containers<'_, C> {
178 Containers::new(self)
179 }
180
181 pub fn evals(&self) -> Evals<'_, C> {
183 Evals::new(self)
184 }
185
186 pub fn chatkit(&self) -> Chatkit<'_, C> {
187 Chatkit::new(self)
188 }
189
190 #[cfg(feature = "realtime")]
191 pub fn realtime(&self) -> Realtime<'_, C> {
193 Realtime::new(self)
194 }
195
196 pub fn config(&self) -> &C {
197 &self.config
198 }
199
200 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
202 where
203 O: DeserializeOwned,
204 {
205 let request_maker = || async {
206 Ok(self
207 .http_client
208 .get(self.config.url(path))
209 .query(&self.config.query())
210 .headers(self.config.headers())
211 .build()?)
212 };
213
214 self.execute(request_maker).await
215 }
216
217 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
219 where
220 O: DeserializeOwned,
221 Q: Serialize + ?Sized,
222 {
223 let request_maker = || async {
224 Ok(self
225 .http_client
226 .get(self.config.url(path))
227 .query(&self.config.query())
228 .query(query)
229 .headers(self.config.headers())
230 .build()?)
231 };
232
233 self.execute(request_maker).await
234 }
235
236 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
238 where
239 O: DeserializeOwned,
240 {
241 let request_maker = || async {
242 Ok(self
243 .http_client
244 .delete(self.config.url(path))
245 .query(&self.config.query())
246 .headers(self.config.headers())
247 .build()?)
248 };
249
250 self.execute(request_maker).await
251 }
252
253 pub(crate) async fn get_raw(&self, path: &str) -> Result<(Bytes, HeaderMap), OpenAIError> {
255 let request_maker = || async {
256 Ok(self
257 .http_client
258 .get(self.config.url(path))
259 .query(&self.config.query())
260 .headers(self.config.headers())
261 .build()?)
262 };
263
264 self.execute_raw(request_maker).await
265 }
266
267 pub(crate) async fn get_raw_with_query<Q>(
268 &self,
269 path: &str,
270 query: &Q,
271 ) -> Result<(Bytes, HeaderMap), OpenAIError>
272 where
273 Q: Serialize + ?Sized,
274 {
275 let request_maker = || async {
276 Ok(self
277 .http_client
278 .get(self.config.url(path))
279 .query(&self.config.query())
280 .query(query)
281 .headers(self.config.headers())
282 .build()?)
283 };
284
285 self.execute_raw(request_maker).await
286 }
287
288 pub(crate) async fn post_raw<I>(
290 &self,
291 path: &str,
292 request: I,
293 ) -> Result<(Bytes, HeaderMap), OpenAIError>
294 where
295 I: Serialize,
296 {
297 let request_maker = || async {
298 Ok(self
299 .http_client
300 .post(self.config.url(path))
301 .query(&self.config.query())
302 .headers(self.config.headers())
303 .json(&request)
304 .build()?)
305 };
306
307 self.execute_raw(request_maker).await
308 }
309
310 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
312 where
313 I: Serialize,
314 O: DeserializeOwned,
315 {
316 let request_maker = || async {
317 Ok(self
318 .http_client
319 .post(self.config.url(path))
320 .query(&self.config.query())
321 .headers(self.config.headers())
322 .json(&request)
323 .build()?)
324 };
325
326 self.execute(request_maker).await
327 }
328
329 pub(crate) async fn post_form_raw<F>(
331 &self,
332 path: &str,
333 form: F,
334 ) -> Result<(Bytes, HeaderMap), OpenAIError>
335 where
336 Form: AsyncTryFrom<F, Error = OpenAIError>,
337 F: Clone,
338 {
339 let request_maker = || async {
340 Ok(self
341 .http_client
342 .post(self.config.url(path))
343 .query(&self.config.query())
344 .headers(self.config.headers())
345 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
346 .build()?)
347 };
348
349 self.execute_raw(request_maker).await
350 }
351
352 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
354 where
355 O: DeserializeOwned,
356 Form: AsyncTryFrom<F, Error = OpenAIError>,
357 F: Clone,
358 {
359 let request_maker = || async {
360 Ok(self
361 .http_client
362 .post(self.config.url(path))
363 .query(&self.config.query())
364 .headers(self.config.headers())
365 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
366 .build()?)
367 };
368
369 self.execute(request_maker).await
370 }
371
372 pub(crate) async fn post_form_stream<O, F>(
373 &self,
374 path: &str,
375 form: F,
376 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
377 where
378 F: Clone,
379 Form: AsyncTryFrom<F, Error = OpenAIError>,
380 O: DeserializeOwned + std::marker::Send + 'static,
381 {
382 let response = self
385 .http_client
386 .post(self.config.url(path))
387 .query(&self.config.query())
388 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
389 .headers(self.config.headers())
390 .send()
391 .await
392 .map_err(OpenAIError::Reqwest)?;
393
394 if !response.status().is_success() {
396 return Err(read_response(response).await.unwrap_err());
397 }
398
399 let stream = response
401 .bytes_stream()
402 .map(|result| result.map_err(std::io::Error::other));
403 let event_stream = eventsource_stream::EventStream::new(stream);
404
405 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
407
408 tokio::spawn(async move {
409 use futures::StreamExt;
410 let mut event_stream = std::pin::pin!(event_stream);
411
412 while let Some(event_result) = event_stream.next().await {
413 match event_result {
414 Err(e) => {
415 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
416 StreamError::EventStream(e.to_string()),
417 )))) {
418 break;
419 }
420 }
421 Ok(event) => {
422 if event.data == "[DONE]" {
424 break;
425 }
426
427 let response = match serde_json::from_str::<O>(&event.data) {
428 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
429 Ok(output) => Ok(output),
430 };
431
432 if let Err(_e) = tx.send(response) {
433 break;
434 }
435 }
436 }
437 }
438 });
439
440 Ok(Box::pin(
441 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
442 ))
443 }
444
445 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
451 where
452 M: Fn() -> Fut,
453 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
454 {
455 let client = self.http_client.clone();
456
457 backoff::future::retry(self.backoff.clone(), || async {
458 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
459 let response = client
460 .execute(request)
461 .await
462 .map_err(OpenAIError::Reqwest)
463 .map_err(backoff::Error::Permanent)?;
464
465 let status = response.status();
466
467 match read_response(response).await {
468 Ok((bytes, headers)) => Ok((bytes, headers)),
469 Err(e) => {
470 match e {
471 OpenAIError::ApiError(api_error) => {
472 if status.is_server_error() {
473 Err(backoff::Error::Transient {
474 err: OpenAIError::ApiError(api_error),
475 retry_after: None,
476 })
477 } else if status.as_u16() == 429
478 && api_error.r#type != Some("insufficient_quota".to_string())
479 {
480 tracing::warn!("Rate limited: {}", api_error.message);
482 Err(backoff::Error::Transient {
483 err: OpenAIError::ApiError(api_error),
484 retry_after: None,
485 })
486 } else {
487 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
488 }
489 }
490 _ => Err(backoff::Error::Permanent(e)),
491 }
492 }
493 }
494 })
495 .await
496 }
497
498 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
504 where
505 O: DeserializeOwned,
506 M: Fn() -> Fut,
507 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
508 {
509 let (bytes, _headers) = self.execute_raw(request_maker).await?;
510
511 let response: O = serde_json::from_slice(bytes.as_ref())
512 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
513
514 Ok(response)
515 }
516
517 pub(crate) async fn post_stream<I, O>(
519 &self,
520 path: &str,
521 request: I,
522 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
523 where
524 I: Serialize,
525 O: DeserializeOwned + std::marker::Send + 'static,
526 {
527 let event_source = self
528 .http_client
529 .post(self.config.url(path))
530 .query(&self.config.query())
531 .headers(self.config.headers())
532 .json(&request)
533 .eventsource()
534 .unwrap();
535
536 stream(event_source).await
537 }
538
539 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
540 &self,
541 path: &str,
542 request: I,
543 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
544 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
545 where
546 I: Serialize,
547 O: DeserializeOwned + std::marker::Send + 'static,
548 {
549 let event_source = self
550 .http_client
551 .post(self.config.url(path))
552 .query(&self.config.query())
553 .headers(self.config.headers())
554 .json(&request)
555 .eventsource()
556 .unwrap();
557
558 stream_mapped_raw_events(event_source, event_mapper).await
559 }
560
561 pub(crate) async fn _get_stream<Q, O>(
563 &self,
564 path: &str,
565 query: &Q,
566 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
567 where
568 Q: Serialize + ?Sized,
569 O: DeserializeOwned + std::marker::Send + 'static,
570 {
571 let event_source = self
572 .http_client
573 .get(self.config.url(path))
574 .query(query)
575 .query(&self.config.query())
576 .headers(self.config.headers())
577 .eventsource()
578 .unwrap();
579
580 stream(event_source).await
581 }
582}
583
584async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
585 let status = response.status();
586 let headers = response.headers().clone();
587 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
588
589 if status.is_server_error() {
590 let message: String = String::from_utf8_lossy(&bytes).into_owned();
592 tracing::warn!("Server error: {status} - {message}");
593 return Err(OpenAIError::ApiError(ApiError {
594 message,
595 r#type: None,
596 param: None,
597 code: None,
598 }));
599 }
600
601 if !status.is_success() {
603 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
604 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
605
606 return Err(OpenAIError::ApiError(wrapped_error.error));
607 }
608
609 Ok((bytes, headers))
610}
611
612async fn map_stream_error(value: EventSourceError) -> OpenAIError {
613 match value {
614 EventSourceError::InvalidStatusCode(status_code, response) => {
615 read_response(response).await.expect_err(&format!(
616 "Unreachable because read_response returns err when status_code {status_code} is invalid"
617 ))
618 }
619 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
620 }
621}
622
623pub(crate) async fn stream<O>(
626 mut event_source: EventSource,
627) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
628where
629 O: DeserializeOwned + std::marker::Send + 'static,
630{
631 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
632
633 tokio::spawn(async move {
634 while let Some(ev) = event_source.next().await {
635 match ev {
636 Err(e) => {
637 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
638 break;
640 }
641 }
642 Ok(event) => match event {
643 Event::Message(message) => {
644 if message.data == "[DONE]" {
645 break;
646 }
647
648 let response = match serde_json::from_str::<O>(&message.data) {
649 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
650 Ok(output) => Ok(output),
651 };
652
653 if let Err(_e) = tx.send(response) {
654 break;
656 }
657 }
658 Event::Open => continue,
659 },
660 }
661 }
662
663 event_source.close();
664 });
665
666 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
667}
668
669pub(crate) async fn stream_mapped_raw_events<O>(
670 mut event_source: EventSource,
671 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
672) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
673where
674 O: DeserializeOwned + std::marker::Send + 'static,
675{
676 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
677
678 tokio::spawn(async move {
679 while let Some(ev) = event_source.next().await {
680 match ev {
681 Err(e) => {
682 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
683 break;
685 }
686 }
687 Ok(event) => match event {
688 Event::Message(message) => {
689 let mut done = false;
690
691 if message.data == "[DONE]" {
692 done = true;
693 }
694
695 let response = event_mapper(message);
696
697 if let Err(_e) = tx.send(response) {
698 break;
700 }
701
702 if done {
703 break;
704 }
705 }
706 Event::Open => continue,
707 },
708 }
709 }
710
711 event_source.close();
712 });
713
714 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
715}