1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 admin::Admin,
11 chatkit::Chatkit,
12 config::{Config, OpenAIConfig},
13 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
14 file::Files,
15 image::Images,
16 moderation::Moderations,
17 traits::AsyncTryFrom,
18 Assistants, Audio, Batches, Chat, Completions, Containers, Conversations, Embeddings, Evals,
19 FineTuning, Models, RequestOptions, Responses, Threads, Uploads, Usage, VectorStores, Videos,
20};
21
22#[cfg(feature = "realtime")]
23use crate::Realtime;
24
25#[derive(Debug, Clone, Default)]
26pub struct Client<C: Config> {
29 http_client: reqwest::Client,
30 config: C,
31 backoff: backoff::ExponentialBackoff,
32}
33
34impl Client<OpenAIConfig> {
35 pub fn new() -> Self {
37 Self::default()
38 }
39}
40
41impl<C: Config> Client<C> {
42 pub fn build(
44 http_client: reqwest::Client,
45 config: C,
46 backoff: backoff::ExponentialBackoff,
47 ) -> Self {
48 Self {
49 http_client,
50 config,
51 backoff,
52 }
53 }
54
55 pub fn with_config(config: C) -> Self {
57 Self {
58 http_client: reqwest::Client::new(),
59 config,
60 backoff: Default::default(),
61 }
62 }
63
64 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68 self.http_client = http_client;
69 self
70 }
71
72 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74 self.backoff = backoff;
75 self
76 }
77
78 pub fn models(&self) -> Models<'_, C> {
82 Models::new(self)
83 }
84
85 pub fn completions(&self) -> Completions<'_, C> {
87 Completions::new(self)
88 }
89
90 pub fn chat(&self) -> Chat<'_, C> {
92 Chat::new(self)
93 }
94
95 pub fn images(&self) -> Images<'_, C> {
97 Images::new(self)
98 }
99
100 pub fn moderations(&self) -> Moderations<'_, C> {
102 Moderations::new(self)
103 }
104
105 pub fn files(&self) -> Files<'_, C> {
107 Files::new(self)
108 }
109
110 pub fn uploads(&self) -> Uploads<'_, C> {
112 Uploads::new(self)
113 }
114
115 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
117 FineTuning::new(self)
118 }
119
120 pub fn embeddings(&self) -> Embeddings<'_, C> {
122 Embeddings::new(self)
123 }
124
125 pub fn audio(&self) -> Audio<'_, C> {
127 Audio::new(self)
128 }
129
130 pub fn videos(&self) -> Videos<'_, C> {
132 Videos::new(self)
133 }
134
135 pub fn assistants(&self) -> Assistants<'_, C> {
137 Assistants::new(self)
138 }
139
140 pub fn threads(&self) -> Threads<'_, C> {
142 Threads::new(self)
143 }
144
145 pub fn vector_stores(&self) -> VectorStores<'_, C> {
147 VectorStores::new(self)
148 }
149
150 pub fn batches(&self) -> Batches<'_, C> {
152 Batches::new(self)
153 }
154
155 pub fn admin(&self) -> Admin<'_, C> {
158 Admin::new(self)
159 }
160
161 pub fn usage(&self) -> Usage<'_, C> {
163 Usage::new(self)
164 }
165
166 pub fn responses(&self) -> Responses<'_, C> {
168 Responses::new(self)
169 }
170
171 pub fn conversations(&self) -> Conversations<'_, C> {
173 Conversations::new(self)
174 }
175
176 pub fn containers(&self) -> Containers<'_, C> {
178 Containers::new(self)
179 }
180
181 pub fn evals(&self) -> Evals<'_, C> {
183 Evals::new(self)
184 }
185
186 pub fn chatkit(&self) -> Chatkit<'_, C> {
187 Chatkit::new(self)
188 }
189
190 #[cfg(feature = "realtime")]
191 pub fn realtime(&self) -> Realtime<'_, C> {
193 Realtime::new(self)
194 }
195
196 pub fn config(&self) -> &C {
197 &self.config
198 }
199
200 fn build_request_builder(
202 &self,
203 method: reqwest::Method,
204 path: &str,
205 request_options: &RequestOptions,
206 ) -> reqwest::RequestBuilder {
207 let mut request_builder = if let Some(path) = request_options.path() {
208 self.http_client
209 .request(method, self.config.url(path.as_str()))
210 } else {
211 self.http_client.request(method, self.config.url(path))
212 };
213
214 request_builder = request_builder
215 .query(&self.config.query())
216 .headers(self.config.headers());
217
218 if let Some(headers) = request_options.headers() {
219 request_builder = request_builder.headers(headers.clone());
220 }
221
222 if !request_options.query().is_empty() {
223 request_builder = request_builder.query(request_options.query());
224 }
225
226 request_builder
227 }
228
229 pub(crate) async fn get<O>(
231 &self,
232 path: &str,
233 request_options: &RequestOptions,
234 ) -> Result<O, OpenAIError>
235 where
236 O: DeserializeOwned,
237 {
238 let request_maker = || async {
239 Ok(self
240 .build_request_builder(reqwest::Method::GET, path, request_options)
241 .build()?)
242 };
243
244 self.execute(request_maker).await
245 }
246
247 pub(crate) async fn delete<O>(
249 &self,
250 path: &str,
251 request_options: &RequestOptions,
252 ) -> Result<O, OpenAIError>
253 where
254 O: DeserializeOwned,
255 {
256 let request_maker = || async {
257 Ok(self
258 .build_request_builder(reqwest::Method::DELETE, path, request_options)
259 .build()?)
260 };
261
262 self.execute(request_maker).await
263 }
264
265 pub(crate) async fn get_raw(
267 &self,
268 path: &str,
269 request_options: &RequestOptions,
270 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
271 let request_maker = || async {
272 Ok(self
273 .build_request_builder(reqwest::Method::GET, path, request_options)
274 .build()?)
275 };
276
277 self.execute_raw(request_maker).await
278 }
279
280 pub(crate) async fn post_raw<I>(
282 &self,
283 path: &str,
284 request: I,
285 request_options: &RequestOptions,
286 ) -> Result<(Bytes, HeaderMap), OpenAIError>
287 where
288 I: Serialize,
289 {
290 let request_maker = || async {
291 Ok(self
292 .build_request_builder(reqwest::Method::POST, path, request_options)
293 .json(&request)
294 .build()?)
295 };
296
297 self.execute_raw(request_maker).await
298 }
299
300 pub(crate) async fn post<I, O>(
302 &self,
303 path: &str,
304 request: I,
305 request_options: &RequestOptions,
306 ) -> Result<O, OpenAIError>
307 where
308 I: Serialize,
309 O: DeserializeOwned,
310 {
311 let request_maker = || async {
312 Ok(self
313 .build_request_builder(reqwest::Method::POST, path, request_options)
314 .json(&request)
315 .build()?)
316 };
317
318 self.execute(request_maker).await
319 }
320
321 pub(crate) async fn post_form_raw<F>(
323 &self,
324 path: &str,
325 form: F,
326 request_options: &RequestOptions,
327 ) -> Result<(Bytes, HeaderMap), OpenAIError>
328 where
329 Form: AsyncTryFrom<F, Error = OpenAIError>,
330 F: Clone,
331 {
332 let request_maker = || async {
333 Ok(self
334 .build_request_builder(reqwest::Method::POST, path, request_options)
335 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
336 .build()?)
337 };
338
339 self.execute_raw(request_maker).await
340 }
341
342 pub(crate) async fn post_form<O, F>(
344 &self,
345 path: &str,
346 form: F,
347 request_options: &RequestOptions,
348 ) -> Result<O, OpenAIError>
349 where
350 O: DeserializeOwned,
351 Form: AsyncTryFrom<F, Error = OpenAIError>,
352 F: Clone,
353 {
354 let request_maker = || async {
355 Ok(self
356 .build_request_builder(reqwest::Method::POST, path, request_options)
357 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
358 .build()?)
359 };
360
361 self.execute(request_maker).await
362 }
363
364 pub(crate) async fn post_form_stream<O, F>(
365 &self,
366 path: &str,
367 form: F,
368 request_options: &RequestOptions,
369 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
370 where
371 F: Clone,
372 Form: AsyncTryFrom<F, Error = OpenAIError>,
373 O: DeserializeOwned + std::marker::Send + 'static,
374 {
375 let request_builder = self
378 .build_request_builder(reqwest::Method::POST, path, request_options)
379 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
380
381 let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
382
383 if !response.status().is_success() {
385 return Err(read_response(response).await.unwrap_err());
386 }
387
388 let stream = response
390 .bytes_stream()
391 .map(|result| result.map_err(std::io::Error::other));
392 let event_stream = eventsource_stream::EventStream::new(stream);
393
394 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
396
397 tokio::spawn(async move {
398 use futures::StreamExt;
399 let mut event_stream = std::pin::pin!(event_stream);
400
401 while let Some(event_result) = event_stream.next().await {
402 match event_result {
403 Err(e) => {
404 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
405 StreamError::EventStream(e.to_string()),
406 )))) {
407 break;
408 }
409 }
410 Ok(event) => {
411 if event.data == "[DONE]" {
413 break;
414 }
415
416 let response = match serde_json::from_str::<O>(&event.data) {
417 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
418 Ok(output) => Ok(output),
419 };
420
421 if let Err(_e) = tx.send(response) {
422 break;
423 }
424 }
425 }
426 }
427 });
428
429 Ok(Box::pin(
430 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
431 ))
432 }
433
434 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
440 where
441 M: Fn() -> Fut,
442 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
443 {
444 let client = self.http_client.clone();
445
446 backoff::future::retry(self.backoff.clone(), || async {
447 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
448 let response = client
449 .execute(request)
450 .await
451 .map_err(OpenAIError::Reqwest)
452 .map_err(backoff::Error::Permanent)?;
453
454 let status = response.status();
455
456 match read_response(response).await {
457 Ok((bytes, headers)) => Ok((bytes, headers)),
458 Err(e) => {
459 match e {
460 OpenAIError::ApiError(api_error) => {
461 if status.is_server_error() {
462 Err(backoff::Error::Transient {
463 err: OpenAIError::ApiError(api_error),
464 retry_after: None,
465 })
466 } else if status.as_u16() == 429
467 && api_error.r#type != Some("insufficient_quota".to_string())
468 {
469 tracing::warn!("Rate limited: {}", api_error.message);
471 Err(backoff::Error::Transient {
472 err: OpenAIError::ApiError(api_error),
473 retry_after: None,
474 })
475 } else {
476 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
477 }
478 }
479 _ => Err(backoff::Error::Permanent(e)),
480 }
481 }
482 }
483 })
484 .await
485 }
486
487 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
493 where
494 O: DeserializeOwned,
495 M: Fn() -> Fut,
496 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
497 {
498 let (bytes, _headers) = self.execute_raw(request_maker).await?;
499
500 let response: O = serde_json::from_slice(bytes.as_ref())
501 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
502
503 Ok(response)
504 }
505
506 pub(crate) async fn post_stream<I, O>(
508 &self,
509 path: &str,
510 request: I,
511 request_options: &RequestOptions,
512 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
513 where
514 I: Serialize,
515 O: DeserializeOwned + std::marker::Send + 'static,
516 {
517 let request_builder = self
518 .build_request_builder(reqwest::Method::POST, path, request_options)
519 .json(&request);
520
521 let event_source = request_builder.eventsource().unwrap();
522
523 stream(event_source).await
524 }
525
526 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
527 &self,
528 path: &str,
529 request: I,
530 request_options: &RequestOptions,
531 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
532 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
533 where
534 I: Serialize,
535 O: DeserializeOwned + std::marker::Send + 'static,
536 {
537 let request_builder = self
538 .build_request_builder(reqwest::Method::POST, path, request_options)
539 .json(&request);
540
541 let event_source = request_builder.eventsource().unwrap();
542
543 stream_mapped_raw_events(event_source, event_mapper).await
544 }
545
546 pub(crate) async fn _get_stream<Q, O>(
548 &self,
549 path: &str,
550 query: &Q,
551 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
552 where
553 Q: Serialize + ?Sized,
554 O: DeserializeOwned + std::marker::Send + 'static,
555 {
556 let event_source = self
557 .http_client
558 .get(self.config.url(path))
559 .query(query)
560 .query(&self.config.query())
561 .headers(self.config.headers())
562 .eventsource()
563 .unwrap();
564
565 stream(event_source).await
566 }
567}
568
569async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
570 let status = response.status();
571 let headers = response.headers().clone();
572 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
573
574 if status.is_server_error() {
575 let message: String = String::from_utf8_lossy(&bytes).into_owned();
577 tracing::warn!("Server error: {status} - {message}");
578 return Err(OpenAIError::ApiError(ApiError {
579 message,
580 r#type: None,
581 param: None,
582 code: None,
583 }));
584 }
585
586 if !status.is_success() {
588 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
589 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
590
591 return Err(OpenAIError::ApiError(wrapped_error.error));
592 }
593
594 Ok((bytes, headers))
595}
596
597async fn map_stream_error(value: EventSourceError) -> OpenAIError {
598 match value {
599 EventSourceError::InvalidStatusCode(status_code, response) => {
600 read_response(response).await.expect_err(&format!(
601 "Unreachable because read_response returns err when status_code {status_code} is invalid"
602 ))
603 }
604 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
605 }
606}
607
608pub(crate) async fn stream<O>(
611 mut event_source: EventSource,
612) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
613where
614 O: DeserializeOwned + std::marker::Send + 'static,
615{
616 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
617
618 tokio::spawn(async move {
619 while let Some(ev) = event_source.next().await {
620 match ev {
621 Err(e) => {
622 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
623 break;
625 }
626 }
627 Ok(event) => match event {
628 Event::Message(message) => {
629 if message.data == "[DONE]" {
630 break;
631 }
632
633 let response = match serde_json::from_str::<O>(&message.data) {
634 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
635 Ok(output) => Ok(output),
636 };
637
638 if let Err(_e) = tx.send(response) {
639 break;
641 }
642 }
643 Event::Open => continue,
644 },
645 }
646 }
647
648 event_source.close();
649 });
650
651 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
652}
653
654pub(crate) async fn stream_mapped_raw_events<O>(
655 mut event_source: EventSource,
656 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
657) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
658where
659 O: DeserializeOwned + std::marker::Send + 'static,
660{
661 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
662
663 tokio::spawn(async move {
664 while let Some(ev) = event_source.next().await {
665 match ev {
666 Err(e) => {
667 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
668 break;
670 }
671 }
672 Ok(event) => match event {
673 Event::Message(message) => {
674 let mut done = false;
675
676 if message.data == "[DONE]" {
677 done = true;
678 }
679
680 let response = event_mapper(message);
681
682 if let Err(_e) = tx.send(response) {
683 break;
685 }
686
687 if done {
688 break;
689 }
690 }
691 Event::Open => continue,
692 },
693 }
694 }
695
696 event_source.close();
697 });
698
699 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
700}