1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 admin::Admin,
11 chatkit::Chatkit,
12 config::{Config, OpenAIConfig},
13 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
14 file::Files,
15 image::Images,
16 moderation::Moderations,
17 traits::AsyncTryFrom,
18 Assistants, Audio, Batches, Chat, Completions, Containers, Conversations, Embeddings, Evals,
19 FineTuning, Models, RequestOptions, Responses, Threads, Uploads, VectorStores, Videos,
20};
21
22#[cfg(feature = "realtime")]
23use crate::Realtime;
24
/// HTTP client wrapper parameterized by a [`Config`].
///
/// The `Config` supplies the request URL, default headers and default query parameters.
/// The `reqwest::Client` sends the requests. The exponential backoff policy decides how
/// non-streaming requests are retried.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    /// Underlying HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Supplies the URL (`url`), default headers (`headers`) and default query (`query`).
    config: C,
    /// Retry policy used by the non-streaming request path (`execute_raw`).
    backoff: backoff::ExponentialBackoff,
}
33
34impl Client<OpenAIConfig> {
35 pub fn new() -> Self {
37 Self::default()
38 }
39}
40
41impl<C: Config> Client<C> {
42 pub fn build(
44 http_client: reqwest::Client,
45 config: C,
46 backoff: backoff::ExponentialBackoff,
47 ) -> Self {
48 Self {
49 http_client,
50 config,
51 backoff,
52 }
53 }
54
55 pub fn with_config(config: C) -> Self {
57 Self {
58 http_client: reqwest::Client::new(),
59 config,
60 backoff: Default::default(),
61 }
62 }
63
64 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68 self.http_client = http_client;
69 self
70 }
71
72 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74 self.backoff = backoff;
75 self
76 }
77
78 pub fn models(&self) -> Models<'_, C> {
82 Models::new(self)
83 }
84
85 pub fn completions(&self) -> Completions<'_, C> {
87 Completions::new(self)
88 }
89
90 pub fn chat(&self) -> Chat<'_, C> {
92 Chat::new(self)
93 }
94
95 pub fn images(&self) -> Images<'_, C> {
97 Images::new(self)
98 }
99
100 pub fn moderations(&self) -> Moderations<'_, C> {
102 Moderations::new(self)
103 }
104
105 pub fn files(&self) -> Files<'_, C> {
107 Files::new(self)
108 }
109
110 pub fn uploads(&self) -> Uploads<'_, C> {
112 Uploads::new(self)
113 }
114
115 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
117 FineTuning::new(self)
118 }
119
120 pub fn embeddings(&self) -> Embeddings<'_, C> {
122 Embeddings::new(self)
123 }
124
125 pub fn audio(&self) -> Audio<'_, C> {
127 Audio::new(self)
128 }
129
130 pub fn videos(&self) -> Videos<'_, C> {
132 Videos::new(self)
133 }
134
135 pub fn assistants(&self) -> Assistants<'_, C> {
137 Assistants::new(self)
138 }
139
140 pub fn threads(&self) -> Threads<'_, C> {
142 Threads::new(self)
143 }
144
145 pub fn vector_stores(&self) -> VectorStores<'_, C> {
147 VectorStores::new(self)
148 }
149
150 pub fn batches(&self) -> Batches<'_, C> {
152 Batches::new(self)
153 }
154
155 pub fn admin(&self) -> Admin<'_, C> {
158 Admin::new(self)
159 }
160
161 pub fn responses(&self) -> Responses<'_, C> {
163 Responses::new(self)
164 }
165
166 pub fn conversations(&self) -> Conversations<'_, C> {
168 Conversations::new(self)
169 }
170
171 pub fn containers(&self) -> Containers<'_, C> {
173 Containers::new(self)
174 }
175
176 pub fn evals(&self) -> Evals<'_, C> {
178 Evals::new(self)
179 }
180
181 pub fn chatkit(&self) -> Chatkit<'_, C> {
182 Chatkit::new(self)
183 }
184
185 #[cfg(feature = "realtime")]
186 pub fn realtime(&self) -> Realtime<'_, C> {
188 Realtime::new(self)
189 }
190
191 pub fn config(&self) -> &C {
192 &self.config
193 }
194
195 fn build_request_builder(
197 &self,
198 method: reqwest::Method,
199 path: &str,
200 request_options: &RequestOptions,
201 ) -> reqwest::RequestBuilder {
202 let mut request_builder = if let Some(path) = request_options.path() {
203 self.http_client
204 .request(method, self.config.url(path.as_str()))
205 } else {
206 self.http_client.request(method, self.config.url(path))
207 };
208
209 request_builder = request_builder
210 .query(&self.config.query())
211 .headers(self.config.headers());
212
213 if let Some(headers) = request_options.headers() {
214 request_builder = request_builder.headers(headers.clone());
215 }
216
217 if !request_options.query().is_empty() {
218 request_builder = request_builder.query(request_options.query());
219 }
220
221 request_builder
222 }
223
224 pub(crate) async fn get<O>(
226 &self,
227 path: &str,
228 request_options: &RequestOptions,
229 ) -> Result<O, OpenAIError>
230 where
231 O: DeserializeOwned,
232 {
233 let request_maker = || async {
234 Ok(self
235 .build_request_builder(reqwest::Method::GET, path, request_options)
236 .build()?)
237 };
238
239 self.execute(request_maker).await
240 }
241
242 pub(crate) async fn delete<O>(
244 &self,
245 path: &str,
246 request_options: &RequestOptions,
247 ) -> Result<O, OpenAIError>
248 where
249 O: DeserializeOwned,
250 {
251 let request_maker = || async {
252 Ok(self
253 .build_request_builder(reqwest::Method::DELETE, path, request_options)
254 .build()?)
255 };
256
257 self.execute(request_maker).await
258 }
259
260 pub(crate) async fn get_raw(
262 &self,
263 path: &str,
264 request_options: &RequestOptions,
265 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
266 let request_maker = || async {
267 Ok(self
268 .build_request_builder(reqwest::Method::GET, path, request_options)
269 .build()?)
270 };
271
272 self.execute_raw(request_maker).await
273 }
274
275 pub(crate) async fn post_raw<I>(
277 &self,
278 path: &str,
279 request: I,
280 request_options: &RequestOptions,
281 ) -> Result<(Bytes, HeaderMap), OpenAIError>
282 where
283 I: Serialize,
284 {
285 let request_maker = || async {
286 Ok(self
287 .build_request_builder(reqwest::Method::POST, path, request_options)
288 .json(&request)
289 .build()?)
290 };
291
292 self.execute_raw(request_maker).await
293 }
294
295 pub(crate) async fn post<I, O>(
297 &self,
298 path: &str,
299 request: I,
300 request_options: &RequestOptions,
301 ) -> Result<O, OpenAIError>
302 where
303 I: Serialize,
304 O: DeserializeOwned,
305 {
306 let request_maker = || async {
307 Ok(self
308 .build_request_builder(reqwest::Method::POST, path, request_options)
309 .json(&request)
310 .build()?)
311 };
312
313 self.execute(request_maker).await
314 }
315
316 pub(crate) async fn post_form_raw<F>(
318 &self,
319 path: &str,
320 form: F,
321 request_options: &RequestOptions,
322 ) -> Result<(Bytes, HeaderMap), OpenAIError>
323 where
324 Form: AsyncTryFrom<F, Error = OpenAIError>,
325 F: Clone,
326 {
327 let request_maker = || async {
328 Ok(self
329 .build_request_builder(reqwest::Method::POST, path, request_options)
330 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
331 .build()?)
332 };
333
334 self.execute_raw(request_maker).await
335 }
336
337 pub(crate) async fn post_form<O, F>(
339 &self,
340 path: &str,
341 form: F,
342 request_options: &RequestOptions,
343 ) -> Result<O, OpenAIError>
344 where
345 O: DeserializeOwned,
346 Form: AsyncTryFrom<F, Error = OpenAIError>,
347 F: Clone,
348 {
349 let request_maker = || async {
350 Ok(self
351 .build_request_builder(reqwest::Method::POST, path, request_options)
352 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
353 .build()?)
354 };
355
356 self.execute(request_maker).await
357 }
358
359 pub(crate) async fn post_form_stream<O, F>(
360 &self,
361 path: &str,
362 form: F,
363 request_options: &RequestOptions,
364 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
365 where
366 F: Clone,
367 Form: AsyncTryFrom<F, Error = OpenAIError>,
368 O: DeserializeOwned + std::marker::Send + 'static,
369 {
370 let request_builder = self
373 .build_request_builder(reqwest::Method::POST, path, request_options)
374 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
375
376 let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
377
378 if !response.status().is_success() {
380 return Err(read_response(response).await.unwrap_err());
381 }
382
383 let stream = response
385 .bytes_stream()
386 .map(|result| result.map_err(std::io::Error::other));
387 let event_stream = eventsource_stream::EventStream::new(stream);
388
389 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
391
392 tokio::spawn(async move {
393 use futures::StreamExt;
394 let mut event_stream = std::pin::pin!(event_stream);
395
396 while let Some(event_result) = event_stream.next().await {
397 match event_result {
398 Err(e) => {
399 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
400 StreamError::EventStream(e.to_string()),
401 )))) {
402 break;
403 }
404 }
405 Ok(event) => {
406 if event.data == "[DONE]" {
408 break;
409 }
410
411 let response = match serde_json::from_str::<O>(&event.data) {
412 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
413 Ok(output) => Ok(output),
414 };
415
416 if let Err(_e) = tx.send(response) {
417 break;
418 }
419 }
420 }
421 }
422 });
423
424 Ok(Box::pin(
425 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
426 ))
427 }
428
429 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
435 where
436 M: Fn() -> Fut,
437 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
438 {
439 let client = self.http_client.clone();
440
441 backoff::future::retry(self.backoff.clone(), || async {
442 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
443 let response = client
444 .execute(request)
445 .await
446 .map_err(OpenAIError::Reqwest)
447 .map_err(backoff::Error::Permanent)?;
448
449 let status = response.status();
450
451 match read_response(response).await {
452 Ok((bytes, headers)) => Ok((bytes, headers)),
453 Err(e) => {
454 match e {
455 OpenAIError::ApiError(api_error) => {
456 if status.is_server_error() {
457 Err(backoff::Error::Transient {
458 err: OpenAIError::ApiError(api_error),
459 retry_after: None,
460 })
461 } else if status.as_u16() == 429
462 && api_error.r#type != Some("insufficient_quota".to_string())
463 {
464 tracing::warn!("Rate limited: {}", api_error.message);
466 Err(backoff::Error::Transient {
467 err: OpenAIError::ApiError(api_error),
468 retry_after: None,
469 })
470 } else {
471 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
472 }
473 }
474 _ => Err(backoff::Error::Permanent(e)),
475 }
476 }
477 }
478 })
479 .await
480 }
481
482 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
488 where
489 O: DeserializeOwned,
490 M: Fn() -> Fut,
491 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
492 {
493 let (bytes, _headers) = self.execute_raw(request_maker).await?;
494
495 let response: O = serde_json::from_slice(bytes.as_ref())
496 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
497
498 Ok(response)
499 }
500
501 pub(crate) async fn post_stream<I, O>(
503 &self,
504 path: &str,
505 request: I,
506 request_options: &RequestOptions,
507 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
508 where
509 I: Serialize,
510 O: DeserializeOwned + std::marker::Send + 'static,
511 {
512 let request_builder = self
513 .build_request_builder(reqwest::Method::POST, path, request_options)
514 .json(&request);
515
516 let event_source = request_builder.eventsource().unwrap();
517
518 stream(event_source).await
519 }
520
521 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
522 &self,
523 path: &str,
524 request: I,
525 request_options: &RequestOptions,
526 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
527 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
528 where
529 I: Serialize,
530 O: DeserializeOwned + std::marker::Send + 'static,
531 {
532 let request_builder = self
533 .build_request_builder(reqwest::Method::POST, path, request_options)
534 .json(&request);
535
536 let event_source = request_builder.eventsource().unwrap();
537
538 stream_mapped_raw_events(event_source, event_mapper).await
539 }
540
541 pub(crate) async fn get_stream<O>(
543 &self,
544 path: &str,
545 request_options: &RequestOptions,
546 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
547 where
548 O: DeserializeOwned + std::marker::Send + 'static,
549 {
550 let request_builder =
551 self.build_request_builder(reqwest::Method::GET, path, request_options);
552
553 let event_source = request_builder.eventsource().unwrap();
554
555 stream(event_source).await
556 }
557}
558
559async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
560 let status = response.status();
561 let headers = response.headers().clone();
562 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
563
564 if status.is_server_error() {
565 let message: String = String::from_utf8_lossy(&bytes).into_owned();
567 tracing::warn!("Server error: {status} - {message}");
568 return Err(OpenAIError::ApiError(ApiError {
569 message,
570 r#type: None,
571 param: None,
572 code: None,
573 }));
574 }
575
576 if !status.is_success() {
578 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
579 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
580
581 return Err(OpenAIError::ApiError(wrapped_error.error));
582 }
583
584 Ok((bytes, headers))
585}
586
587async fn map_stream_error(value: EventSourceError) -> OpenAIError {
588 match value {
589 EventSourceError::InvalidStatusCode(status_code, response) => {
590 read_response(response).await.expect_err(&format!(
591 "Unreachable because read_response returns err when status_code {status_code} is invalid"
592 ))
593 }
594 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
595 }
596}
597
598pub(crate) async fn stream<O>(
601 mut event_source: EventSource,
602) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
603where
604 O: DeserializeOwned + std::marker::Send + 'static,
605{
606 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
607
608 tokio::spawn(async move {
609 while let Some(ev) = event_source.next().await {
610 match ev {
611 Err(e) => {
612 match &e {
615 EventSourceError::StreamEnded => {
616 break;
617 }
618 _ => {
619 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
620 break;
622 }
623 }
624 }
625 }
626 Ok(event) => match event {
627 Event::Message(message) => {
628 if message.data == "[DONE]" {
629 break;
630 }
631
632 let response = match serde_json::from_str::<O>(&message.data) {
633 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
634 Ok(output) => Ok(output),
635 };
636
637 if let Err(_e) = tx.send(response) {
638 break;
640 }
641 }
642 Event::Open => continue,
643 },
644 }
645 }
646
647 event_source.close();
648 });
649
650 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
651}
652
653pub(crate) async fn stream_mapped_raw_events<O>(
654 mut event_source: EventSource,
655 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
656) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
657where
658 O: DeserializeOwned + std::marker::Send + 'static,
659{
660 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
661
662 tokio::spawn(async move {
663 while let Some(ev) = event_source.next().await {
664 match ev {
665 Err(e) => {
666 match &e {
669 EventSourceError::StreamEnded => {
670 break;
671 }
672 _ => {
673 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
674 break;
676 }
677 }
678 }
679 }
680 Ok(event) => match event {
681 Event::Message(message) => {
682 let mut done = false;
683
684 if message.data == "[DONE]" {
685 done = true;
686 }
687
688 let response = event_mapper(message);
689
690 if let Err(_e) = tx.send(response) {
691 break;
693 }
694
695 if done {
696 break;
697 }
698 }
699 Event::Open => continue,
700 },
701 }
702 }
703
704 event_source.close();
705 });
706
707 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
708}