1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 traits::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
17 Embeddings, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
18 VectorStores, Videos,
19};
20
21#[derive(Debug, Clone, Default)]
22pub struct Client<C: Config> {
25 http_client: reqwest::Client,
26 config: C,
27 backoff: backoff::ExponentialBackoff,
28}
29
30impl Client<OpenAIConfig> {
31 pub fn new() -> Self {
33 Self::default()
34 }
35}
36
37impl<C: Config> Client<C> {
38 pub fn build(
40 http_client: reqwest::Client,
41 config: C,
42 backoff: backoff::ExponentialBackoff,
43 ) -> Self {
44 Self {
45 http_client,
46 config,
47 backoff,
48 }
49 }
50
51 pub fn with_config(config: C) -> Self {
53 Self {
54 http_client: reqwest::Client::new(),
55 config,
56 backoff: Default::default(),
57 }
58 }
59
60 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64 self.http_client = http_client;
65 self
66 }
67
68 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
70 self.backoff = backoff;
71 self
72 }
73
74 pub fn models(&self) -> Models<'_, C> {
78 Models::new(self)
79 }
80
81 pub fn completions(&self) -> Completions<'_, C> {
83 Completions::new(self)
84 }
85
86 pub fn chat(&self) -> Chat<'_, C> {
88 Chat::new(self)
89 }
90
91 pub fn images(&self) -> Images<'_, C> {
93 Images::new(self)
94 }
95
96 pub fn moderations(&self) -> Moderations<'_, C> {
98 Moderations::new(self)
99 }
100
101 pub fn files(&self) -> Files<'_, C> {
103 Files::new(self)
104 }
105
106 pub fn uploads(&self) -> Uploads<'_, C> {
108 Uploads::new(self)
109 }
110
111 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
113 FineTuning::new(self)
114 }
115
116 pub fn embeddings(&self) -> Embeddings<'_, C> {
118 Embeddings::new(self)
119 }
120
121 pub fn audio(&self) -> Audio<'_, C> {
123 Audio::new(self)
124 }
125
126 pub fn videos(&self) -> Videos<'_, C> {
128 Videos::new(self)
129 }
130
131 pub fn assistants(&self) -> Assistants<'_, C> {
133 Assistants::new(self)
134 }
135
136 pub fn threads(&self) -> Threads<'_, C> {
138 Threads::new(self)
139 }
140
141 pub fn vector_stores(&self) -> VectorStores<'_, C> {
143 VectorStores::new(self)
144 }
145
146 pub fn batches(&self) -> Batches<'_, C> {
148 Batches::new(self)
149 }
150
151 pub fn audit_logs(&self) -> AuditLogs<'_, C> {
153 AuditLogs::new(self)
154 }
155
156 pub fn invites(&self) -> Invites<'_, C> {
158 Invites::new(self)
159 }
160
161 pub fn users(&self) -> Users<'_, C> {
163 Users::new(self)
164 }
165
166 pub fn projects(&self) -> Projects<'_, C> {
168 Projects::new(self)
169 }
170
171 pub fn responses(&self) -> Responses<'_, C> {
173 Responses::new(self)
174 }
175
176 pub fn conversations(&self) -> Conversations<'_, C> {
178 Conversations::new(self)
179 }
180
181 pub fn containers(&self) -> Containers<'_, C> {
183 Containers::new(self)
184 }
185
186 pub fn config(&self) -> &C {
187 &self.config
188 }
189
190 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
192 where
193 O: DeserializeOwned,
194 {
195 let request_maker = || async {
196 Ok(self
197 .http_client
198 .get(self.config.url(path))
199 .query(&self.config.query())
200 .headers(self.config.headers())
201 .build()?)
202 };
203
204 self.execute(request_maker).await
205 }
206
207 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
209 where
210 O: DeserializeOwned,
211 Q: Serialize + ?Sized,
212 {
213 let request_maker = || async {
214 Ok(self
215 .http_client
216 .get(self.config.url(path))
217 .query(&self.config.query())
218 .query(query)
219 .headers(self.config.headers())
220 .build()?)
221 };
222
223 self.execute(request_maker).await
224 }
225
226 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
228 where
229 O: DeserializeOwned,
230 {
231 let request_maker = || async {
232 Ok(self
233 .http_client
234 .delete(self.config.url(path))
235 .query(&self.config.query())
236 .headers(self.config.headers())
237 .build()?)
238 };
239
240 self.execute(request_maker).await
241 }
242
243 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
245 let request_maker = || async {
246 Ok(self
247 .http_client
248 .get(self.config.url(path))
249 .query(&self.config.query())
250 .headers(self.config.headers())
251 .build()?)
252 };
253
254 self.execute_raw(request_maker).await
255 }
256
257 pub(crate) async fn get_raw_with_query<Q>(
258 &self,
259 path: &str,
260 query: &Q,
261 ) -> Result<Bytes, OpenAIError>
262 where
263 Q: Serialize + ?Sized,
264 {
265 let request_maker = || async {
266 Ok(self
267 .http_client
268 .get(self.config.url(path))
269 .query(&self.config.query())
270 .query(query)
271 .headers(self.config.headers())
272 .build()?)
273 };
274
275 self.execute_raw(request_maker).await
276 }
277
278 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
280 where
281 I: Serialize,
282 {
283 let request_maker = || async {
284 Ok(self
285 .http_client
286 .post(self.config.url(path))
287 .query(&self.config.query())
288 .headers(self.config.headers())
289 .json(&request)
290 .build()?)
291 };
292
293 self.execute_raw(request_maker).await
294 }
295
296 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
298 where
299 I: Serialize,
300 O: DeserializeOwned,
301 {
302 let request_maker = || async {
303 Ok(self
304 .http_client
305 .post(self.config.url(path))
306 .query(&self.config.query())
307 .headers(self.config.headers())
308 .json(&request)
309 .build()?)
310 };
311
312 self.execute(request_maker).await
313 }
314
315 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
317 where
318 Form: AsyncTryFrom<F, Error = OpenAIError>,
319 F: Clone,
320 {
321 let request_maker = || async {
322 Ok(self
323 .http_client
324 .post(self.config.url(path))
325 .query(&self.config.query())
326 .headers(self.config.headers())
327 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
328 .build()?)
329 };
330
331 self.execute_raw(request_maker).await
332 }
333
334 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
336 where
337 O: DeserializeOwned,
338 Form: AsyncTryFrom<F, Error = OpenAIError>,
339 F: Clone,
340 {
341 let request_maker = || async {
342 Ok(self
343 .http_client
344 .post(self.config.url(path))
345 .query(&self.config.query())
346 .headers(self.config.headers())
347 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
348 .build()?)
349 };
350
351 self.execute(request_maker).await
352 }
353
354 pub(crate) async fn post_form_stream<O, F>(
355 &self,
356 path: &str,
357 form: F,
358 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
359 where
360 F: Clone,
361 Form: AsyncTryFrom<F, Error = OpenAIError>,
362 O: DeserializeOwned + std::marker::Send + 'static,
363 {
364 let response = self
367 .http_client
368 .post(self.config.url(path))
369 .query(&self.config.query())
370 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
371 .headers(self.config.headers())
372 .send()
373 .await
374 .map_err(OpenAIError::Reqwest)?;
375
376 if !response.status().is_success() {
378 return Err(read_response(response).await.unwrap_err());
379 }
380
381 let stream = response
383 .bytes_stream()
384 .map(|result| result.map_err(std::io::Error::other));
385 let event_stream = eventsource_stream::EventStream::new(stream);
386
387 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
389
390 tokio::spawn(async move {
391 use futures::StreamExt;
392 let mut event_stream = std::pin::pin!(event_stream);
393
394 while let Some(event_result) = event_stream.next().await {
395 match event_result {
396 Err(e) => {
397 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
398 StreamError::EventStream(e.to_string()),
399 )))) {
400 break;
401 }
402 }
403 Ok(event) => {
404 if event.data == "[DONE]" {
406 break;
407 }
408
409 let response = match serde_json::from_str::<O>(&event.data) {
410 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
411 Ok(output) => Ok(output),
412 };
413
414 if let Err(_e) = tx.send(response) {
415 break;
416 }
417 }
418 }
419 }
420 });
421
422 Ok(Box::pin(
423 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
424 ))
425 }
426
427 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
433 where
434 M: Fn() -> Fut,
435 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
436 {
437 let client = self.http_client.clone();
438
439 backoff::future::retry(self.backoff.clone(), || async {
440 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
441 let response = client
442 .execute(request)
443 .await
444 .map_err(OpenAIError::Reqwest)
445 .map_err(backoff::Error::Permanent)?;
446
447 let status = response.status();
448
449 match read_response(response).await {
450 Ok(bytes) => Ok(bytes),
451 Err(e) => {
452 match e {
453 OpenAIError::ApiError(api_error) => {
454 if status.is_server_error() {
455 Err(backoff::Error::Transient {
456 err: OpenAIError::ApiError(api_error),
457 retry_after: None,
458 })
459 } else if status.as_u16() == 429
460 && api_error.r#type != Some("insufficient_quota".to_string())
461 {
462 tracing::warn!("Rate limited: {}", api_error.message);
464 Err(backoff::Error::Transient {
465 err: OpenAIError::ApiError(api_error),
466 retry_after: None,
467 })
468 } else {
469 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
470 }
471 }
472 _ => Err(backoff::Error::Permanent(e)),
473 }
474 }
475 }
476 })
477 .await
478 }
479
480 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
486 where
487 O: DeserializeOwned,
488 M: Fn() -> Fut,
489 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
490 {
491 let bytes = self.execute_raw(request_maker).await?;
492
493 let response: O = serde_json::from_slice(bytes.as_ref())
494 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
495
496 Ok(response)
497 }
498
499 pub(crate) async fn post_stream<I, O>(
501 &self,
502 path: &str,
503 request: I,
504 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
505 where
506 I: Serialize,
507 O: DeserializeOwned + std::marker::Send + 'static,
508 {
509 let event_source = self
510 .http_client
511 .post(self.config.url(path))
512 .query(&self.config.query())
513 .headers(self.config.headers())
514 .json(&request)
515 .eventsource()
516 .unwrap();
517
518 stream(event_source).await
519 }
520
521 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
522 &self,
523 path: &str,
524 request: I,
525 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
526 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
527 where
528 I: Serialize,
529 O: DeserializeOwned + std::marker::Send + 'static,
530 {
531 let event_source = self
532 .http_client
533 .post(self.config.url(path))
534 .query(&self.config.query())
535 .headers(self.config.headers())
536 .json(&request)
537 .eventsource()
538 .unwrap();
539
540 stream_mapped_raw_events(event_source, event_mapper).await
541 }
542
543 pub(crate) async fn _get_stream<Q, O>(
545 &self,
546 path: &str,
547 query: &Q,
548 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
549 where
550 Q: Serialize + ?Sized,
551 O: DeserializeOwned + std::marker::Send + 'static,
552 {
553 let event_source = self
554 .http_client
555 .get(self.config.url(path))
556 .query(query)
557 .query(&self.config.query())
558 .headers(self.config.headers())
559 .eventsource()
560 .unwrap();
561
562 stream(event_source).await
563 }
564}
565
566async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
567 let status = response.status();
568 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
569
570 if status.is_server_error() {
571 let message: String = String::from_utf8_lossy(&bytes).into_owned();
573 tracing::warn!("Server error: {status} - {message}");
574 return Err(OpenAIError::ApiError(ApiError {
575 message,
576 r#type: None,
577 param: None,
578 code: None,
579 }));
580 }
581
582 if !status.is_success() {
584 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
585 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
586
587 return Err(OpenAIError::ApiError(wrapped_error.error));
588 }
589
590 Ok(bytes)
591}
592
593async fn map_stream_error(value: EventSourceError) -> OpenAIError {
594 match value {
595 EventSourceError::InvalidStatusCode(status_code, response) => {
596 read_response(response).await.expect_err(&format!(
597 "Unreachable because read_response returns err when status_code {status_code} is invalid"
598 ))
599 }
600 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
601 }
602}
603
604pub(crate) async fn stream<O>(
607 mut event_source: EventSource,
608) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
609where
610 O: DeserializeOwned + std::marker::Send + 'static,
611{
612 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
613
614 tokio::spawn(async move {
615 while let Some(ev) = event_source.next().await {
616 match ev {
617 Err(e) => {
618 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
619 break;
621 }
622 }
623 Ok(event) => match event {
624 Event::Message(message) => {
625 if message.data == "[DONE]" {
626 break;
627 }
628
629 let response = match serde_json::from_str::<O>(&message.data) {
630 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
631 Ok(output) => Ok(output),
632 };
633
634 if let Err(_e) = tx.send(response) {
635 break;
637 }
638 }
639 Event::Open => continue,
640 },
641 }
642 }
643
644 event_source.close();
645 });
646
647 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
648}
649
650pub(crate) async fn stream_mapped_raw_events<O>(
651 mut event_source: EventSource,
652 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
653) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
654where
655 O: DeserializeOwned + std::marker::Send + 'static,
656{
657 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
658
659 tokio::spawn(async move {
660 while let Some(ev) = event_source.next().await {
661 match ev {
662 Err(e) => {
663 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
664 break;
666 }
667 }
668 Ok(event) => match event {
669 Event::Message(message) => {
670 let mut done = false;
671
672 if message.data == "[DONE]" {
673 done = true;
674 }
675
676 let response = event_mapper(message);
677
678 if let Err(_e) = tx.send(response) {
679 break;
681 }
682
683 if done {
684 break;
685 }
686 }
687 Event::Open => continue,
688 },
689 }
690 }
691
692 event_source.close();
693 });
694
695 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
696}