1use std::pin::Pin;
12
13use bytes::Bytes;
14use futures::{Stream, stream::StreamExt};
15use reqwest::multipart::Form;
16use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
17use serde::{Serialize, de::DeserializeOwned};
18
19use crate::{
20 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
21 Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
22 config::{Config, OpenAIConfig},
23 error::{ApiError, OpenAIError, WrappedError, map_deserialization_error},
24 file::Files,
25 image::Images,
26 moderation::Moderations,
27 traits::AsyncTryFrom,
28};
29
30#[derive(Debug, Clone, Default)]
31pub struct Client<C: Config> {
34 http_client: reqwest::Client,
35 config: C,
36 backoff: backoff::ExponentialBackoff,
37}
38
39impl Client<OpenAIConfig> {
40 pub fn new() -> Self {
42 Self::default()
43 }
44}
45
46impl<C: Config> Client<C> {
47 pub fn build(
49 http_client: reqwest::Client,
50 config: C,
51 backoff: backoff::ExponentialBackoff,
52 ) -> Self {
53 Self {
54 http_client,
55 config,
56 backoff,
57 }
58 }
59
60 pub fn with_config(config: C) -> Self {
62 Self {
63 http_client: reqwest::Client::new(),
64 config,
65 backoff: Default::default(),
66 }
67 }
68
69 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
73 self.http_client = http_client;
74 self
75 }
76
77 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
79 self.backoff = backoff;
80 self
81 }
82
83 pub fn models(&self) -> Models<C> {
87 Models::new(self)
88 }
89
90 pub fn completions(&self) -> Completions<C> {
92 Completions::new(self)
93 }
94
95 pub fn chat(&self) -> Chat<C> {
97 Chat::new(self)
98 }
99
100 pub fn images(&self) -> Images<C> {
102 Images::new(self)
103 }
104
105 pub fn moderations(&self) -> Moderations<C> {
107 Moderations::new(self)
108 }
109
110 pub fn files(&self) -> Files<C> {
112 Files::new(self)
113 }
114
115 pub fn uploads(&self) -> Uploads<C> {
117 Uploads::new(self)
118 }
119
120 pub fn fine_tuning(&self) -> FineTuning<C> {
122 FineTuning::new(self)
123 }
124
125 pub fn embeddings(&self) -> Embeddings<C> {
127 Embeddings::new(self)
128 }
129
130 pub fn audio(&self) -> Audio<C> {
132 Audio::new(self)
133 }
134
135 pub fn assistants(&self) -> Assistants<C> {
137 Assistants::new(self)
138 }
139
140 pub fn threads(&self) -> Threads<C> {
142 Threads::new(self)
143 }
144
145 pub fn vector_stores(&self) -> VectorStores<C> {
147 VectorStores::new(self)
148 }
149
150 pub fn batches(&self) -> Batches<C> {
152 Batches::new(self)
153 }
154
155 pub fn audit_logs(&self) -> AuditLogs<C> {
157 AuditLogs::new(self)
158 }
159
160 pub fn invites(&self) -> Invites<C> {
162 Invites::new(self)
163 }
164
165 pub fn users(&self) -> Users<C> {
167 Users::new(self)
168 }
169
170 pub fn projects(&self) -> Projects<C> {
172 Projects::new(self)
173 }
174
175 pub fn responses(&self) -> Responses<C> {
177 Responses::new(self)
178 }
179
180 pub fn config(&self) -> &C {
181 &self.config
182 }
183
184 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
186 where
187 O: DeserializeOwned,
188 {
189 let request_maker = || async {
190 Ok(self
191 .http_client
192 .get(self.config.url(path))
193 .query(&self.config.query())
194 .headers(self.config.headers())
195 .build()?)
196 };
197
198 self.execute(request_maker).await
199 }
200
201 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
203 where
204 O: DeserializeOwned,
205 Q: Serialize + ?Sized,
206 {
207 let request_maker = || async {
208 Ok(self
209 .http_client
210 .get(self.config.url(path))
211 .query(&self.config.query())
212 .query(query)
213 .headers(self.config.headers())
214 .build()?)
215 };
216
217 self.execute(request_maker).await
218 }
219
220 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
222 where
223 O: DeserializeOwned,
224 {
225 let request_maker = || async {
226 Ok(self
227 .http_client
228 .delete(self.config.url(path))
229 .query(&self.config.query())
230 .headers(self.config.headers())
231 .build()?)
232 };
233
234 self.execute(request_maker).await
235 }
236
237 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
239 let request_maker = || async {
240 Ok(self
241 .http_client
242 .get(self.config.url(path))
243 .query(&self.config.query())
244 .headers(self.config.headers())
245 .build()?)
246 };
247
248 self.execute_raw(request_maker).await
249 }
250
251 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
253 where
254 I: Serialize,
255 {
256 let request_maker = || async {
257 Ok(self
258 .http_client
259 .post(self.config.url(path))
260 .query(&self.config.query())
261 .headers(self.config.headers())
262 .json(&request)
263 .build()?)
264 };
265
266 self.execute_raw(request_maker).await
267 }
268
269 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
271 where
272 I: Serialize,
273 O: DeserializeOwned,
274 {
275 let request_maker = || async {
276 Ok(self
277 .http_client
278 .post(self.config.url(path))
279 .query(&self.config.query())
280 .headers(self.config.headers())
281 .json(&request)
282 .build()?)
283 };
284
285 self.execute(request_maker).await
286 }
287
288 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
290 where
291 Form: AsyncTryFrom<F, Error = OpenAIError>,
292 F: Clone,
293 {
294 let request_maker = || async {
295 Ok(self
296 .http_client
297 .post(self.config.url(path))
298 .query(&self.config.query())
299 .headers(self.config.headers())
300 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
301 .build()?)
302 };
303
304 self.execute_raw(request_maker).await
305 }
306
307 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
309 where
310 O: DeserializeOwned,
311 Form: AsyncTryFrom<F, Error = OpenAIError>,
312 F: Clone,
313 {
314 let request_maker = || async {
315 Ok(self
316 .http_client
317 .post(self.config.url(path))
318 .query(&self.config.query())
319 .headers(self.config.headers())
320 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
321 .build()?)
322 };
323
324 self.execute(request_maker).await
325 }
326
327 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
333 where
334 M: Fn() -> Fut,
335 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
336 {
337 let client = self.http_client.clone();
338
339 backoff::future::retry(self.backoff.clone(), || async {
340 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
341 let response = client
342 .execute(request)
343 .await
344 .map_err(OpenAIError::Reqwest)
345 .map_err(backoff::Error::Permanent)?;
346
347 let status = response.status();
348 let bytes = response
349 .bytes()
350 .await
351 .map_err(OpenAIError::Reqwest)
352 .map_err(backoff::Error::Permanent)?;
353
354 if status.is_server_error() {
355 let message: String = String::from_utf8_lossy(&bytes).into_owned();
357 tracing::warn!("Server error: {status} - {message}");
358 return Err(backoff::Error::Transient {
359 err: OpenAIError::ApiError(ApiError {
360 message,
361 r#type: None,
362 param: None,
363 code: None,
364 }),
365 retry_after: None,
366 });
367 }
368
369 if !status.is_success() {
371 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
372 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
373 .map_err(backoff::Error::Permanent)?;
374
375 if status.as_u16() == 429
376 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
379 {
380 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
382 return Err(backoff::Error::Transient {
383 err: OpenAIError::ApiError(wrapped_error.error),
384 retry_after: None,
385 });
386 } else {
387 return Err(backoff::Error::Permanent(OpenAIError::ApiError(
388 wrapped_error.error,
389 )));
390 }
391 }
392
393 Ok(bytes)
394 })
395 .await
396 }
397
398 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
404 where
405 O: DeserializeOwned,
406 M: Fn() -> Fut,
407 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
408 {
409 let bytes = self.execute_raw(request_maker).await?;
410
411 let response: O = serde_json::from_slice(bytes.as_ref())
412 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
413
414 Ok(response)
415 }
416
417 pub(crate) async fn post_stream<I, O>(
419 &self,
420 path: &str,
421 request: I,
422 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
423 where
424 I: Serialize,
425 O: DeserializeOwned + std::marker::Send + 'static,
426 {
427 let event_source = self
428 .http_client
429 .post(self.config.url(path))
430 .query(&self.config.query())
431 .headers(self.config.headers())
432 .json(&request)
433 .eventsource()
434 .unwrap();
435
436 stream(event_source).await
437 }
438
439 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
440 &self,
441 path: &str,
442 request: I,
443 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
444 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
445 where
446 I: Serialize,
447 O: DeserializeOwned + std::marker::Send + 'static,
448 {
449 let event_source = self
450 .http_client
451 .post(self.config.url(path))
452 .query(&self.config.query())
453 .headers(self.config.headers())
454 .json(&request)
455 .eventsource()
456 .unwrap();
457
458 stream_mapped_raw_events(event_source, event_mapper).await
459 }
460
461 pub(crate) async fn _get_stream<Q, O>(
463 &self,
464 path: &str,
465 query: &Q,
466 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
467 where
468 Q: Serialize + ?Sized,
469 O: DeserializeOwned + std::marker::Send + 'static,
470 {
471 let event_source = self
472 .http_client
473 .get(self.config.url(path))
474 .query(query)
475 .query(&self.config.query())
476 .headers(self.config.headers())
477 .eventsource()
478 .unwrap();
479
480 stream(event_source).await
481 }
482}
483
484pub(crate) async fn stream<O>(
487 mut event_source: EventSource,
488) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
489where
490 O: DeserializeOwned + std::marker::Send + 'static,
491{
492 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
493
494 tokio::spawn(async move {
495 while let Some(ev) = event_source.next().await {
496 match ev {
497 Err(e) => {
498 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
499 break;
501 }
502 }
503 Ok(event) => match event {
504 Event::Message(message) => {
505 if message.data == "[DONE]" {
506 break;
507 }
508
509 let response = match serde_json::from_str::<O>(&message.data) {
510 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
511 Ok(output) => Ok(output),
512 };
513
514 if let Err(_e) = tx.send(response) {
515 break;
517 }
518 }
519 Event::Open => continue,
520 },
521 }
522 }
523
524 event_source.close();
525 });
526
527 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
528}
529
530pub(crate) async fn stream_mapped_raw_events<O>(
531 mut event_source: EventSource,
532 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
533) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
534where
535 O: DeserializeOwned + std::marker::Send + 'static,
536{
537 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
538
539 tokio::spawn(async move {
540 while let Some(ev) = event_source.next().await {
541 match ev {
542 Err(e) => {
543 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
544 break;
546 }
547 }
548 Ok(event) => match event {
549 Event::Message(message) => {
550 let mut done = false;
551
552 if message.data == "[DONE]" {
553 done = true;
554 }
555
556 let response = event_mapper(message);
557
558 if let Err(_e) = tx.send(response) {
559 break;
561 }
562
563 if done {
564 break;
565 }
566 }
567 Event::Open => continue,
568 },
569 }
570 }
571
572 event_source.close();
573 });
574
575 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
576}