1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Error, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, OpenAIError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 util::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17 Models, Projects, Threads, Users, VectorStores,
18};
19
/// Client for the OpenAI API, generic over the [`Config`] that supplies each request's
/// URL, query parameters and headers.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    // HTTP client used to send every request.
    http_client: reqwest::Client,
    // Source of the request URL (`url`), base query parameters (`query`) and headers (`headers`).
    config: C,
    // Retry policy; only HTTP 429 responses whose error type is not `insufficient_quota`
    // are retried with it.
    backoff: backoff::ExponentialBackoff,
}
28
29impl Client<OpenAIConfig> {
30 pub fn new() -> Self {
32 Self::default()
33 }
34}
35
36impl<C: Config> Client<C> {
37 pub fn build(
39 http_client: reqwest::Client,
40 config: C,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn with_config(config: C) -> Self {
52 Self {
53 http_client: reqwest::Client::new(),
54 config,
55 backoff: Default::default(),
56 }
57 }
58
59 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63 self.http_client = http_client;
64 self
65 }
66
67 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69 self.backoff = backoff;
70 self
71 }
72
73 pub fn models(&self) -> Models<C> {
77 Models::new(self)
78 }
79
80 pub fn completions(&self) -> Completions<C> {
82 Completions::new(self)
83 }
84
85 pub fn chat(&self) -> Chat<C> {
87 Chat::new(self)
88 }
89
90 pub fn images(&self) -> Images<C> {
92 Images::new(self)
93 }
94
95 pub fn moderations(&self) -> Moderations<C> {
97 Moderations::new(self)
98 }
99
100 pub fn files(&self) -> Files<C> {
102 Files::new(self)
103 }
104
105 pub fn fine_tuning(&self) -> FineTuning<C> {
107 FineTuning::new(self)
108 }
109
110 pub fn embeddings(&self) -> Embeddings<C> {
112 Embeddings::new(self)
113 }
114
115 pub fn audio(&self) -> Audio<C> {
117 Audio::new(self)
118 }
119
120 pub fn assistants(&self) -> Assistants<C> {
122 Assistants::new(self)
123 }
124
125 pub fn threads(&self) -> Threads<C> {
127 Threads::new(self)
128 }
129
130 pub fn vector_stores(&self) -> VectorStores<C> {
132 VectorStores::new(self)
133 }
134
135 pub fn batches(&self) -> Batches<C> {
137 Batches::new(self)
138 }
139
140 pub fn audit_logs(&self) -> AuditLogs<C> {
142 AuditLogs::new(self)
143 }
144
145 pub fn invites(&self) -> Invites<C> {
147 Invites::new(self)
148 }
149
150 pub fn users(&self) -> Users<C> {
152 Users::new(self)
153 }
154
155 pub fn projects(&self) -> Projects<C> {
157 Projects::new(self)
158 }
159
160 pub fn config(&self) -> &C {
161 &self.config
162 }
163
164 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
166 where
167 O: DeserializeOwned,
168 {
169 let request_maker = || async {
170 Ok(self
171 .http_client
172 .get(self.config.url(path))
173 .query(&self.config.query())
174 .headers(self.config.headers())
175 .build()?)
176 };
177
178 self.execute(request_maker).await
179 }
180
181 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
183 where
184 O: DeserializeOwned,
185 Q: Serialize + ?Sized,
186 {
187 let request_maker = || async {
188 Ok(self
189 .http_client
190 .get(self.config.url(path))
191 .query(&self.config.query())
192 .query(query)
193 .headers(self.config.headers())
194 .build()?)
195 };
196
197 self.execute(request_maker).await
198 }
199
200 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
202 where
203 O: DeserializeOwned,
204 {
205 let request_maker = || async {
206 Ok(self
207 .http_client
208 .delete(self.config.url(path))
209 .query(&self.config.query())
210 .headers(self.config.headers())
211 .build()?)
212 };
213
214 self.execute(request_maker).await
215 }
216
217 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
219 let request_maker = || async {
220 Ok(self
221 .http_client
222 .get(self.config.url(path))
223 .query(&self.config.query())
224 .headers(self.config.headers())
225 .build()?)
226 };
227
228 self.execute_raw(request_maker).await
229 }
230
231 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
233 where
234 I: Serialize,
235 {
236 let request_maker = || async {
237 Ok(self
238 .http_client
239 .post(self.config.url(path))
240 .query(&self.config.query())
241 .headers(self.config.headers())
242 .json(&request)
243 .build()?)
244 };
245
246 self.execute_raw(request_maker).await
247 }
248
249 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
251 where
252 I: Serialize,
253 O: DeserializeOwned,
254 {
255 let request_maker = || async {
256 Ok(self
257 .http_client
258 .post(self.config.url(path))
259 .query(&self.config.query())
260 .headers(self.config.headers())
261 .json(&request)
262 .build()?)
263 };
264
265 self.execute(request_maker).await
266 }
267
268 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
270 where
271 Form: AsyncTryFrom<F, Error = OpenAIError>,
272 F: Clone,
273 {
274 let request_maker = || async {
275 Ok(self
276 .http_client
277 .post(self.config.url(path))
278 .query(&self.config.query())
279 .headers(self.config.headers())
280 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
281 .build()?)
282 };
283
284 self.execute_raw(request_maker).await
285 }
286
287 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
289 where
290 O: DeserializeOwned,
291 Form: AsyncTryFrom<F, Error = OpenAIError>,
292 F: Clone,
293 {
294 let request_maker = || async {
295 Ok(self
296 .http_client
297 .post(self.config.url(path))
298 .query(&self.config.query())
299 .headers(self.config.headers())
300 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
301 .build()?)
302 };
303
304 self.execute(request_maker).await
305 }
306
307 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
313 where
314 M: Fn() -> Fut,
315 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
316 {
317 let client = self.http_client.clone();
318
319 backoff::future::retry(self.backoff.clone(), || async {
320 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
321 let response = client
322 .execute(request)
323 .await
324 .map_err(OpenAIError::Reqwest)
325 .map_err(backoff::Error::Permanent)?;
326
327 let status = response.status();
328 let bytes = response
329 .bytes()
330 .await
331 .map_err(OpenAIError::Reqwest)
332 .map_err(backoff::Error::Permanent)?;
333
334 if !status.is_success() {
336 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
337 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
338 .map_err(backoff::Error::Permanent)?;
339
340 if status.as_u16() == 429
341 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
344 {
345 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
347 return Err(backoff::Error::Transient {
348 err: OpenAIError::ApiError(wrapped_error.error),
349 retry_after: None,
350 });
351 } else {
352 return Err(backoff::Error::Permanent(OpenAIError::ApiError(
353 wrapped_error.error,
354 )));
355 }
356 }
357
358 Ok(bytes)
359 })
360 .await
361 }
362
363 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
369 where
370 O: DeserializeOwned,
371 M: Fn() -> Fut,
372 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
373 {
374 let bytes = self.execute_raw(request_maker).await?;
375
376 let response: O = serde_json::from_slice(bytes.as_ref())
377 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
378
379 Ok(response)
380 }
381
382 pub(crate) async fn post_stream<I, O>(
384 &self,
385 path: &str,
386 request: I,
387 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
388 where
389 I: Serialize,
390 O: DeserializeOwned + std::marker::Send + 'static,
391 {
392 let event_source = self
393 .http_client
394 .post(self.config.url(path))
395 .query(&self.config.query())
396 .headers(self.config.headers())
397 .json(&request)
398 .eventsource()
399 .unwrap();
400
401 stream(event_source).await
402 }
403
404 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
405 &self,
406 path: &str,
407 request: I,
408 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
409 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
410 where
411 I: Serialize,
412 O: DeserializeOwned + std::marker::Send + 'static,
413 {
414 let event_source = self
415 .http_client
416 .post(self.config.url(path))
417 .query(&self.config.query())
418 .headers(self.config.headers())
419 .json(&request)
420 .eventsource()
421 .unwrap();
422
423 stream_mapped_raw_events(event_source, event_mapper).await
424 }
425
426 pub(crate) async fn _get_stream<Q, O>(
428 &self,
429 path: &str,
430 query: &Q,
431 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
432 where
433 Q: Serialize + ?Sized,
434 O: DeserializeOwned + std::marker::Send + 'static,
435 {
436 let event_source = self
437 .http_client
438 .get(self.config.url(path))
439 .query(query)
440 .query(&self.config.query())
441 .headers(self.config.headers())
442 .eventsource()
443 .unwrap();
444
445 stream(event_source).await
446 }
447}
448
449async fn handle_eventsource_error(e: Error) -> Result<(), OpenAIError> {
450 let error_text = e.to_string();
451 if let Error::InvalidStatusCode(_code, response) = e {
452 if let Ok(text) = response.text().await {
453 let api_error = serde_json::from_str::<WrappedError>(&text);
454 if let Ok(e) = api_error {
455 return Err(OpenAIError::ApiError(e.error));
456 }
457 }
458 }
459
460 Err(OpenAIError::StreamError(error_text))
461}
462
463pub(crate) async fn stream<O>(
466 mut event_source: EventSource,
467) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
468where
469 O: DeserializeOwned + std::marker::Send + 'static,
470{
471 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
472
473 tokio::spawn(async move {
474 while let Some(ev) = event_source.next().await {
475 match ev {
476 Err(e) => {
477 if let Err(e) = handle_eventsource_error(e).await {
478 if let Err(_e) = tx.send(Err(e)) {
479 break;
481 }
482 }
483 }
484 Ok(event) => match event {
485 Event::Message(message) => {
486 if message.data == "[DONE]" {
487 break;
488 }
489
490 let response = match serde_json::from_str::<O>(&message.data) {
491 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
492 Ok(output) => Ok(output),
493 };
494
495 if let Err(_e) = tx.send(response) {
496 break;
498 }
499 }
500 Event::Open => continue,
501 },
502 }
503 }
504
505 event_source.close();
506 });
507
508 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
509}
510
511pub(crate) async fn stream_mapped_raw_events<O>(
512 mut event_source: EventSource,
513 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
514) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
515where
516 O: DeserializeOwned + std::marker::Send + 'static,
517{
518 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
519
520 tokio::spawn(async move {
521 while let Some(ev) = event_source.next().await {
522 match ev {
523 Err(e) => {
524 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
525 break;
527 }
528 }
529 Ok(event) => match event {
530 Event::Message(message) => {
531 let mut done = false;
532
533 if message.data == "[DONE]" {
534 done = true;
535 }
536
537 let response = event_mapper(message);
538
539 if let Err(_e) = tx.send(response) {
540 break;
542 }
543
544 if done {
545 break;
546 }
547 }
548 Event::Open => continue,
549 },
550 }
551 }
552
553 event_source.close();
554 });
555
556 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
557}