1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use futures::stream::Filter;
8use futures::{stream::StreamExt, Stream};
9use pin_project::pin_project;
10use reqwest::multipart::Form;
11use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
12use serde::{de::DeserializeOwned, Serialize};
13
14use crate::util::async_convert::AsyncTryFrom;
15use crate::{
16 config::{Config, OpenAIConfig},
17 error::{map_deserialization_error, OpenAIError, WrappedError},
18 file::Files,
19 image::Images,
20 moderation::Moderations,
21 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
22 Models, Projects, Threads, Users, VectorStores,
23};
24
25#[derive(Debug, Clone)]
26pub struct Client<C: Config> {
29 http_client: reqwest::Client,
30 config: C,
31}
32
33impl Client<OpenAIConfig> {
34 pub fn new() -> Self {
36 Self {
37 http_client: reqwest::Client::new(),
38 config: OpenAIConfig::default(),
39 }
40 }
41}
42
43impl<C: Config> Client<C> {
44 pub fn build(http_client: reqwest::Client, config: C) -> Self {
46 Self {
47 http_client,
48 config,
49 }
50 }
51
52 pub fn with_config(config: C) -> Self {
54 Self {
55 http_client: reqwest::Client::new(),
56 config,
57 }
58 }
59
60 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64 self.http_client = http_client;
65 self
66 }
67
68 pub fn models(&self) -> Models<C> {
72 Models::new(self)
73 }
74
75 pub fn completions(&self) -> Completions<C> {
77 Completions::new(self)
78 }
79
80 pub fn chat(&self) -> Chat<C> {
82 Chat::new(self)
83 }
84
85 pub fn images(&self) -> Images<C> {
87 Images::new(self)
88 }
89
90 pub fn moderations(&self) -> Moderations<C> {
92 Moderations::new(self)
93 }
94
95 pub fn files(&self) -> Files<C> {
97 Files::new(self)
98 }
99
100 pub fn fine_tuning(&self) -> FineTuning<C> {
102 FineTuning::new(self)
103 }
104
105 pub fn embeddings(&self) -> Embeddings<C> {
107 Embeddings::new(self)
108 }
109
110 pub fn audio(&self) -> Audio<C> {
112 Audio::new(self)
113 }
114
115 pub fn assistants(&self) -> Assistants<C> {
117 Assistants::new(self)
118 }
119
120 pub fn threads(&self) -> Threads<C> {
122 Threads::new(self)
123 }
124
125 pub fn vector_stores(&self) -> VectorStores<C> {
127 VectorStores::new(self)
128 }
129
130 pub fn batches(&self) -> Batches<C> {
132 Batches::new(self)
133 }
134
135 pub fn audit_logs(&self) -> AuditLogs<C> {
137 AuditLogs::new(self)
138 }
139
140 pub fn invites(&self) -> Invites<C> {
142 Invites::new(self)
143 }
144
145 pub fn users(&self) -> Users<C> {
147 Users::new(self)
148 }
149
150 pub fn projects(&self) -> Projects<C> {
152 Projects::new(self)
153 }
154
155 pub fn config(&self) -> &C {
156 &self.config
157 }
158
159 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
161 where
162 O: DeserializeOwned,
163 {
164 let request_maker = || async {
165 Ok(self
166 .http_client
167 .get(self.config.url(path))
168 .query(&self.config.query())
169 .headers(self.config.headers())
170 .build()?)
171 };
172
173 self.execute(request_maker).await
174 }
175
176 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
178 where
179 O: DeserializeOwned,
180 Q: Serialize + ?Sized,
181 {
182 let request_maker = || async {
183 Ok(self
184 .http_client
185 .get(self.config.url(path))
186 .query(&self.config.query())
187 .query(query)
188 .headers(self.config.headers())
189 .build()?)
190 };
191
192 self.execute(request_maker).await
193 }
194
195 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
197 where
198 O: DeserializeOwned,
199 {
200 let request_maker = || async {
201 Ok(self
202 .http_client
203 .delete(self.config.url(path))
204 .query(&self.config.query())
205 .headers(self.config.headers())
206 .build()?)
207 };
208
209 self.execute(request_maker).await
210 }
211
212 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
214 let request_maker = || async {
215 Ok(self
216 .http_client
217 .get(self.config.url(path))
218 .query(&self.config.query())
219 .headers(self.config.headers())
220 .build()?)
221 };
222
223 self.execute_raw(request_maker).await
224 }
225
226 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
228 where
229 I: Serialize,
230 {
231 let request_maker = || async {
232 Ok(self
233 .http_client
234 .post(self.config.url(path))
235 .query(&self.config.query())
236 .headers(self.config.headers())
237 .json(&request)
238 .build()?)
239 };
240
241 self.execute_raw(request_maker).await
242 }
243
244 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
246 where
247 I: Serialize,
248 O: DeserializeOwned,
249 {
250 let request_maker = || async {
251 Ok(self
252 .http_client
253 .post(self.config.url(path))
254 .query(&self.config.query())
255 .headers(self.config.headers())
256 .json(&request)
257 .build()?)
258 };
259
260 self.execute(request_maker).await
261 }
262
263 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
265 where
266 Form: AsyncTryFrom<F, Error = OpenAIError>,
267 F: Clone,
268 {
269 let request_maker = || async {
270 Ok(self
271 .http_client
272 .post(self.config.url(path))
273 .query(&self.config.query())
274 .headers(self.config.headers())
275 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
276 .build()?)
277 };
278
279 self.execute_raw(request_maker).await
280 }
281
282 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
284 where
285 O: DeserializeOwned,
286 Form: AsyncTryFrom<F, Error = OpenAIError>,
287 F: Clone,
288 {
289 let request_maker = || async {
290 Ok(self
291 .http_client
292 .post(self.config.url(path))
293 .query(&self.config.query())
294 .headers(self.config.headers())
295 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
296 .build()?)
297 };
298
299 self.execute(request_maker).await
300 }
301
302 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
308 where
309 M: Fn() -> Fut,
310 Fut: future::Future<Output = Result<reqwest::Request, OpenAIError>>,
311 {
312 let client = self.http_client.clone();
313
314 let request = request_maker().await?;
315 let response = client
316 .execute(request)
317 .await
318 .map_err(OpenAIError::Reqwest)?;
319
320 let status = response.status();
321 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
322
323 if !status.is_success() {
325 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
326 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
327
328 if status.as_u16() == 429
329 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
332 {
333 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
335 return Err(OpenAIError::ApiError(wrapped_error.error));
336 } else {
337 return Err(OpenAIError::ApiError(wrapped_error.error));
338 }
339 }
340
341 Ok(bytes)
342 }
343
344 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
350 where
351 O: DeserializeOwned,
352 M: Fn() -> Fut,
353 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
354 {
355 let bytes = self.execute_raw(request_maker).await?;
356
357 let response: O = serde_json::from_slice(bytes.as_ref())
358 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
359
360 Ok(response)
361 }
362
363 pub(crate) async fn post_stream<I, O>(&self, path: &str, request: I) -> OpenAIEventStream<O>
365 where
366 I: Serialize,
367 O: DeserializeOwned + Send + 'static,
368 {
369 let event_source = self
370 .http_client
371 .post(self.config.url(path))
372 .query(&self.config.query())
373 .headers(self.config.headers())
374 .json(&request)
375 .eventsource()
376 .unwrap();
377
378 OpenAIEventStream::new(event_source)
379 }
380
381 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
382 &self,
383 path: &str,
384 request: I,
385 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
386 ) -> OpenAIEventMappedStream<O>
387 where
388 I: Serialize,
389 O: DeserializeOwned + Send + 'static,
390 {
391 let event_source = self
392 .http_client
393 .post(self.config.url(path))
394 .query(&self.config.query())
395 .headers(self.config.headers())
396 .json(&request)
397 .eventsource()
398 .unwrap();
399
400 OpenAIEventMappedStream::new(event_source, event_mapper)
401 }
402
403 pub(crate) async fn _get_stream<Q, O>(&self, path: &str, query: &Q) -> OpenAIEventStream<O>
405 where
406 Q: Serialize + ?Sized,
407 O: DeserializeOwned + Send + 'static,
408 {
409 let event_source = self
410 .http_client
411 .get(self.config.url(path))
412 .query(query)
413 .query(&self.config.query())
414 .headers(self.config.headers())
415 .eventsource()
416 .unwrap();
417
418 OpenAIEventStream::new(event_source)
419 }
420}
421
422#[pin_project]
425pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
426 #[pin]
427 stream: Filter<
428 EventSource,
429 future::Ready<bool>,
430 fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
431 >,
432 done: bool,
433 _phantom_data: PhantomData<O>,
434}
435
436impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
437 pub(crate) fn new(event_source: EventSource) -> Self {
438 Self {
439 stream: event_source.filter(|result|
440 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
442 done: false,
443 _phantom_data: PhantomData,
444 }
445 }
446}
447
448impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
449 type Item = Result<O, OpenAIError>;
450
451 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
452 let this = self.project();
453 if *this.done {
454 return Poll::Ready(None);
455 }
456 let stream: Pin<&mut _> = this.stream;
457 match stream.poll_next(cx) {
458 Poll::Ready(response) => {
459 match response {
460 None => Poll::Ready(None), Some(result) => match result {
462 Ok(event) => match event {
463 Event::Open => unreachable!(), Event::Message(message) => {
465 if message.data == "[DONE]" {
466 *this.done = true;
467 Poll::Ready(None) } else {
469 match serde_json::from_str::<O>(&message.data) {
471 Err(e) => {
472 *this.done = true;
473 Poll::Ready(Some(Err(map_deserialization_error(
474 e,
475 &message.data.as_bytes(),
476 ))))
477 }
478 Ok(output) => Poll::Ready(Some(Ok(output))),
479 }
480 }
481 }
482 },
483 Err(e) => {
484 *this.done = true;
485 Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
486 }
487 },
488 }
489 }
490 Poll::Pending => Poll::Pending,
491 }
492 }
493}
494
495#[pin_project]
496pub struct OpenAIEventMappedStream<O>
497where
498 O: Send + 'static,
499{
500 #[pin]
501 stream: Filter<
502 EventSource,
503 future::Ready<bool>,
504 fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
505 >,
506 event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
507 done: bool,
508 _phantom_data: PhantomData<O>,
509}
510
511impl<O> OpenAIEventMappedStream<O>
512where
513 O: Send + 'static,
514{
515 pub(crate) fn new<M>(event_source: EventSource, event_mapper: M) -> Self
516 where
517 M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
518 {
519 Self {
520 stream: event_source.filter(|result|
521 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
523 done: false,
524 event_mapper: Box::new(event_mapper),
525 _phantom_data: PhantomData,
526 }
527 }
528}
529
530impl<O> Stream for OpenAIEventMappedStream<O>
531where
532 O: Send + 'static,
533{
534 type Item = Result<O, OpenAIError>;
535
536 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
537 let this = self.project();
538 if *this.done {
539 return Poll::Ready(None);
540 }
541 let stream: Pin<&mut _> = this.stream;
542 match stream.poll_next(cx) {
543 Poll::Ready(response) => {
544 match response {
545 None => Poll::Ready(None), Some(result) => match result {
547 Ok(event) => match event {
548 Event::Open => unreachable!(), Event::Message(message) => {
550 if message.data == "[DONE]" {
551 *this.done = true;
552 }
553 let response = (this.event_mapper)(message);
554 match response {
555 Ok(output) => Poll::Ready(Some(Ok(output))),
556 Err(_) => Poll::Ready(None),
557 }
558 }
559 },
560 Err(e) => {
561 *this.done = true;
562 Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
563 }
564 },
565 }
566 }
567 Poll::Pending => Poll::Pending,
568 }
569 }
570}
571
572