1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 traits::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17 Models, Projects, Responses, Threads, Uploads, Users, VectorStores, Videos,
18};
19
/// API client: bundles the HTTP client, the provider configuration and the retry
/// policy that every request made through it uses.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    /// HTTP client that sends every request.
    http_client: reqwest::Client,
    /// Provider configuration; supplies the request URL (`url`), the base query
    /// parameters (`query`) and the headers (`headers`) for each call.
    config: C,
    /// Retry policy handed to `backoff::future::retry` for non-streaming requests.
    backoff: backoff::ExponentialBackoff,
}
28
29impl Client<OpenAIConfig> {
30 pub fn new() -> Self {
32 Self::default()
33 }
34}
35
36impl<C: Config> Client<C> {
37 pub fn build(
39 http_client: reqwest::Client,
40 config: C,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn with_config(config: C) -> Self {
52 Self {
53 http_client: reqwest::Client::new(),
54 config,
55 backoff: Default::default(),
56 }
57 }
58
59 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63 self.http_client = http_client;
64 self
65 }
66
67 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69 self.backoff = backoff;
70 self
71 }
72
73 pub fn models(&self) -> Models<C> {
77 Models::new(self)
78 }
79
80 pub fn completions(&self) -> Completions<C> {
82 Completions::new(self)
83 }
84
85 pub fn chat(&self) -> Chat<C> {
87 Chat::new(self)
88 }
89
90 pub fn images(&self) -> Images<C> {
92 Images::new(self)
93 }
94
95 pub fn moderations(&self) -> Moderations<C> {
97 Moderations::new(self)
98 }
99
100 pub fn files(&self) -> Files<C> {
102 Files::new(self)
103 }
104
105 pub fn uploads(&self) -> Uploads<C> {
107 Uploads::new(self)
108 }
109
110 pub fn fine_tuning(&self) -> FineTuning<C> {
112 FineTuning::new(self)
113 }
114
115 pub fn embeddings(&self) -> Embeddings<C> {
117 Embeddings::new(self)
118 }
119
120 pub fn audio(&self) -> Audio<C> {
122 Audio::new(self)
123 }
124
125 pub fn videos(&self) -> Videos<C> {
127 Videos::new(self)
128 }
129
130 pub fn assistants(&self) -> Assistants<C> {
132 Assistants::new(self)
133 }
134
135 pub fn threads(&self) -> Threads<C> {
137 Threads::new(self)
138 }
139
140 pub fn vector_stores(&self) -> VectorStores<C> {
142 VectorStores::new(self)
143 }
144
145 pub fn batches(&self) -> Batches<C> {
147 Batches::new(self)
148 }
149
150 pub fn audit_logs(&self) -> AuditLogs<C> {
152 AuditLogs::new(self)
153 }
154
155 pub fn invites(&self) -> Invites<C> {
157 Invites::new(self)
158 }
159
160 pub fn users(&self) -> Users<C> {
162 Users::new(self)
163 }
164
165 pub fn projects(&self) -> Projects<C> {
167 Projects::new(self)
168 }
169
170 pub fn responses(&self) -> Responses<C> {
172 Responses::new(self)
173 }
174
175 pub fn config(&self) -> &C {
176 &self.config
177 }
178
179 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
181 where
182 O: DeserializeOwned,
183 {
184 let request_maker = || async {
185 Ok(self
186 .http_client
187 .get(self.config.url(path))
188 .query(&self.config.query())
189 .headers(self.config.headers())
190 .build()?)
191 };
192
193 self.execute(request_maker).await
194 }
195
196 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
198 where
199 O: DeserializeOwned,
200 Q: Serialize + ?Sized,
201 {
202 let request_maker = || async {
203 Ok(self
204 .http_client
205 .get(self.config.url(path))
206 .query(&self.config.query())
207 .query(query)
208 .headers(self.config.headers())
209 .build()?)
210 };
211
212 self.execute(request_maker).await
213 }
214
215 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
217 where
218 O: DeserializeOwned,
219 {
220 let request_maker = || async {
221 Ok(self
222 .http_client
223 .delete(self.config.url(path))
224 .query(&self.config.query())
225 .headers(self.config.headers())
226 .build()?)
227 };
228
229 self.execute(request_maker).await
230 }
231
232 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
234 let request_maker = || async {
235 Ok(self
236 .http_client
237 .get(self.config.url(path))
238 .query(&self.config.query())
239 .headers(self.config.headers())
240 .build()?)
241 };
242
243 self.execute_raw(request_maker).await
244 }
245
246 pub(crate) async fn get_raw_with_query<Q>(
247 &self,
248 path: &str,
249 query: &Q,
250 ) -> Result<Bytes, OpenAIError>
251 where
252 Q: Serialize + ?Sized,
253 {
254 let request_maker = || async {
255 Ok(self
256 .http_client
257 .get(self.config.url(path))
258 .query(&self.config.query())
259 .query(query)
260 .headers(self.config.headers())
261 .build()?)
262 };
263
264 self.execute_raw(request_maker).await
265 }
266
267 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
269 where
270 I: Serialize,
271 {
272 let request_maker = || async {
273 Ok(self
274 .http_client
275 .post(self.config.url(path))
276 .query(&self.config.query())
277 .headers(self.config.headers())
278 .json(&request)
279 .build()?)
280 };
281
282 self.execute_raw(request_maker).await
283 }
284
285 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
287 where
288 I: Serialize,
289 O: DeserializeOwned,
290 {
291 let request_maker = || async {
292 Ok(self
293 .http_client
294 .post(self.config.url(path))
295 .query(&self.config.query())
296 .headers(self.config.headers())
297 .json(&request)
298 .build()?)
299 };
300
301 self.execute(request_maker).await
302 }
303
304 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
306 where
307 Form: AsyncTryFrom<F, Error = OpenAIError>,
308 F: Clone,
309 {
310 let request_maker = || async {
311 Ok(self
312 .http_client
313 .post(self.config.url(path))
314 .query(&self.config.query())
315 .headers(self.config.headers())
316 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
317 .build()?)
318 };
319
320 self.execute_raw(request_maker).await
321 }
322
323 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
325 where
326 O: DeserializeOwned,
327 Form: AsyncTryFrom<F, Error = OpenAIError>,
328 F: Clone,
329 {
330 let request_maker = || async {
331 Ok(self
332 .http_client
333 .post(self.config.url(path))
334 .query(&self.config.query())
335 .headers(self.config.headers())
336 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
337 .build()?)
338 };
339
340 self.execute(request_maker).await
341 }
342
343 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
349 where
350 M: Fn() -> Fut,
351 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
352 {
353 let client = self.http_client.clone();
354
355 backoff::future::retry(self.backoff.clone(), || async {
356 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
357 let response = client
358 .execute(request)
359 .await
360 .map_err(OpenAIError::Reqwest)
361 .map_err(backoff::Error::Permanent)?;
362
363 let status = response.status();
364
365 match read_response(response).await {
366 Ok(bytes) => Ok(bytes),
367 Err(e) => {
368 match e {
369 OpenAIError::ApiError(api_error) => {
370 if status.is_server_error() {
371 Err(backoff::Error::Transient {
372 err: OpenAIError::ApiError(api_error),
373 retry_after: None,
374 })
375 } else if status.as_u16() == 429
376 && api_error.r#type != Some("insufficient_quota".to_string())
377 {
378 tracing::warn!("Rate limited: {}", api_error.message);
380 Err(backoff::Error::Transient {
381 err: OpenAIError::ApiError(api_error),
382 retry_after: None,
383 })
384 } else {
385 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
386 }
387 }
388 _ => Err(backoff::Error::Permanent(e)),
389 }
390 }
391 }
392 })
393 .await
394 }
395
396 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
402 where
403 O: DeserializeOwned,
404 M: Fn() -> Fut,
405 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
406 {
407 let bytes = self.execute_raw(request_maker).await?;
408
409 let response: O = serde_json::from_slice(bytes.as_ref())
410 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
411
412 Ok(response)
413 }
414
415 pub(crate) async fn post_stream<I, O>(
417 &self,
418 path: &str,
419 request: I,
420 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
421 where
422 I: Serialize,
423 O: DeserializeOwned + std::marker::Send + 'static,
424 {
425 let event_source = self
426 .http_client
427 .post(self.config.url(path))
428 .query(&self.config.query())
429 .headers(self.config.headers())
430 .json(&request)
431 .eventsource()
432 .unwrap();
433
434 stream(event_source).await
435 }
436
437 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
438 &self,
439 path: &str,
440 request: I,
441 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
442 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
443 where
444 I: Serialize,
445 O: DeserializeOwned + std::marker::Send + 'static,
446 {
447 let event_source = self
448 .http_client
449 .post(self.config.url(path))
450 .query(&self.config.query())
451 .headers(self.config.headers())
452 .json(&request)
453 .eventsource()
454 .unwrap();
455
456 stream_mapped_raw_events(event_source, event_mapper).await
457 }
458
459 pub(crate) async fn _get_stream<Q, O>(
461 &self,
462 path: &str,
463 query: &Q,
464 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
465 where
466 Q: Serialize + ?Sized,
467 O: DeserializeOwned + std::marker::Send + 'static,
468 {
469 let event_source = self
470 .http_client
471 .get(self.config.url(path))
472 .query(query)
473 .query(&self.config.query())
474 .headers(self.config.headers())
475 .eventsource()
476 .unwrap();
477
478 stream(event_source).await
479 }
480}
481
482async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
483 let status = response.status();
484 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
485
486 if status.is_server_error() {
487 let message: String = String::from_utf8_lossy(&bytes).into_owned();
489 tracing::warn!("Server error: {status} - {message}");
490 return Err(OpenAIError::ApiError(ApiError {
491 message,
492 r#type: None,
493 param: None,
494 code: None,
495 }));
496 }
497
498 if !status.is_success() {
500 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
501 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
502
503 return Err(OpenAIError::ApiError(wrapped_error.error));
504 }
505
506 Ok(bytes)
507}
508
509async fn map_stream_error(value: EventSourceError) -> OpenAIError {
510 match value {
511 EventSourceError::InvalidStatusCode(status_code, response) => {
512 read_response(response).await.expect_err(&format!(
513 "Unreachable because read_response returns err when status_code {status_code} is invalid"
514 ))
515 }
516 _ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
517 }
518}
519
520pub(crate) async fn stream<O>(
523 mut event_source: EventSource,
524) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
525where
526 O: DeserializeOwned + std::marker::Send + 'static,
527{
528 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
529
530 tokio::spawn(async move {
531 while let Some(ev) = event_source.next().await {
532 match ev {
533 Err(e) => {
534 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
535 break;
537 }
538 }
539 Ok(event) => match event {
540 Event::Message(message) => {
541 if message.data == "[DONE]" {
542 break;
543 }
544
545 let response = match serde_json::from_str::<O>(&message.data) {
546 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
547 Ok(output) => Ok(output),
548 };
549
550 if let Err(_e) = tx.send(response) {
551 break;
553 }
554 }
555 Event::Open => continue,
556 },
557 }
558 }
559
560 event_source.close();
561 });
562
563 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
564}
565
566pub(crate) async fn stream_mapped_raw_events<O>(
567 mut event_source: EventSource,
568 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
569) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
570where
571 O: DeserializeOwned + std::marker::Send + 'static,
572{
573 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
574
575 tokio::spawn(async move {
576 while let Some(ev) = event_source.next().await {
577 match ev {
578 Err(e) => {
579 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
580 break;
582 }
583 }
584 Ok(event) => match event {
585 Event::Message(message) => {
586 let mut done = false;
587
588 if message.data == "[DONE]" {
589 done = true;
590 }
591
592 let response = event_mapper(message);
593
594 if let Err(_e) = tx.send(response) {
595 break;
597 }
598
599 if done {
600 break;
601 }
602 }
603 Event::Open => continue,
604 },
605 }
606 }
607
608 event_source.close();
609 });
610
611 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
612}