1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use futures::stream::Filter;
8use futures::{Stream, stream::StreamExt};
9use pin_project::pin_project;
10use reqwest::multipart::Form;
11use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{
15 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
16 Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
17 config::{Config, OpenAIConfig},
18 error::{OpenAIError, WrappedError, map_deserialization_error},
19 file::Files,
20 image::Images,
21 moderation::Moderations,
22 traits::AsyncTryFrom,
23};
24
/// API client that pairs a [`reqwest::Client`] with a [`Config`] implementation.
///
/// The configuration supplies the URL, query parameters and headers that the
/// request helpers on this type attach to every request they build.
#[derive(Debug, Clone)]
pub struct Client<C: Config> {
    /// HTTP client used to build and execute every request.
    http_client: reqwest::Client,
    /// Source of the request URL, query parameters and headers.
    config: C,
}
32
33impl Client<OpenAIConfig> {
34 pub fn new() -> Self {
36 Self {
37 http_client: reqwest::Client::new(),
38 config: OpenAIConfig::default(),
39 }
40 }
41}
42
43impl<C: Config> Client<C> {
44 pub fn build(http_client: reqwest::Client, config: C) -> Self {
46 Self {
47 http_client,
48 config,
49 }
50 }
51
52 pub fn with_config(config: C) -> Self {
54 Self {
55 http_client: reqwest::Client::new(),
56 config,
57 }
58 }
59
60 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64 self.http_client = http_client;
65 self
66 }
67
68 pub fn models(&self) -> Models<C> {
72 Models::new(self)
73 }
74
75 pub fn completions(&self) -> Completions<C> {
77 Completions::new(self)
78 }
79
80 pub fn chat(&self) -> Chat<C> {
82 Chat::new(self)
83 }
84
85 pub fn images(&self) -> Images<C> {
87 Images::new(self)
88 }
89
90 pub fn moderations(&self) -> Moderations<C> {
92 Moderations::new(self)
93 }
94
95 pub fn files(&self) -> Files<C> {
97 Files::new(self)
98 }
99
100 pub fn uploads(&self) -> Uploads<C> {
102 Uploads::new(self)
103 }
104
105 pub fn fine_tuning(&self) -> FineTuning<C> {
107 FineTuning::new(self)
108 }
109
110 pub fn embeddings(&self) -> Embeddings<C> {
112 Embeddings::new(self)
113 }
114
115 pub fn audio(&self) -> Audio<C> {
117 Audio::new(self)
118 }
119
120 pub fn assistants(&self) -> Assistants<C> {
122 Assistants::new(self)
123 }
124
125 pub fn threads(&self) -> Threads<C> {
127 Threads::new(self)
128 }
129
130 pub fn vector_stores(&self) -> VectorStores<C> {
132 VectorStores::new(self)
133 }
134
135 pub fn batches(&self) -> Batches<C> {
137 Batches::new(self)
138 }
139
140 pub fn audit_logs(&self) -> AuditLogs<C> {
142 AuditLogs::new(self)
143 }
144
145 pub fn invites(&self) -> Invites<C> {
147 Invites::new(self)
148 }
149
150 pub fn users(&self) -> Users<C> {
152 Users::new(self)
153 }
154
155 pub fn projects(&self) -> Projects<C> {
157 Projects::new(self)
158 }
159
160 pub fn responses(&self) -> Responses<C> {
162 Responses::new(self)
163 }
164
165 pub fn config(&self) -> &C {
166 &self.config
167 }
168
169 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
171 where
172 O: DeserializeOwned,
173 {
174 let request_maker = async || {
175 Ok(self
176 .http_client
177 .get(self.config.url(path))
178 .query(&self.config.query())
179 .headers(self.config.headers())
180 .build()?)
181 };
182
183 self.execute(request_maker).await
184 }
185
186 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
188 where
189 O: DeserializeOwned,
190 Q: Serialize + ?Sized,
191 {
192 let request_maker = async || {
193 Ok(self
194 .http_client
195 .get(self.config.url(path))
196 .query(&self.config.query())
197 .query(query)
198 .headers(self.config.headers())
199 .build()?)
200 };
201
202 self.execute(request_maker).await
203 }
204
205 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
207 where
208 O: DeserializeOwned,
209 {
210 let request_maker = async || {
211 Ok(self
212 .http_client
213 .delete(self.config.url(path))
214 .query(&self.config.query())
215 .headers(self.config.headers())
216 .build()?)
217 };
218
219 self.execute(request_maker).await
220 }
221
222 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
224 let request_maker = async || {
225 Ok(self
226 .http_client
227 .get(self.config.url(path))
228 .query(&self.config.query())
229 .headers(self.config.headers())
230 .build()?)
231 };
232
233 self.execute_raw(request_maker).await
234 }
235
236 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
238 where
239 I: Serialize,
240 {
241 let request_maker = async || {
242 Ok(self
243 .http_client
244 .post(self.config.url(path))
245 .query(&self.config.query())
246 .headers(self.config.headers())
247 .json(&request)
248 .build()?)
249 };
250
251 self.execute_raw(request_maker).await
252 }
253
254 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
256 where
257 I: Serialize,
258 O: DeserializeOwned,
259 {
260 let request_maker = async || {
261 Ok(self
262 .http_client
263 .post(self.config.url(path))
264 .query(&self.config.query())
265 .headers(self.config.headers())
266 .json(&request)
267 .build()?)
268 };
269
270 self.execute(request_maker).await
271 }
272
273 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
275 where
276 Form: AsyncTryFrom<F, Error = OpenAIError>,
277 F: Clone,
278 {
279 let request_maker = async || {
280 Ok(self
281 .http_client
282 .post(self.config.url(path))
283 .query(&self.config.query())
284 .headers(self.config.headers())
285 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
286 .build()?)
287 };
288
289 self.execute_raw(request_maker).await
290 }
291
292 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
294 where
295 O: DeserializeOwned,
296 Form: AsyncTryFrom<F, Error = OpenAIError>,
297 F: Clone,
298 {
299 let request_maker = async || {
300 Ok(self
301 .http_client
302 .post(self.config.url(path))
303 .query(&self.config.query())
304 .headers(self.config.headers())
305 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
306 .build()?)
307 };
308
309 self.execute(request_maker).await
310 }
311
312 async fn execute_raw(
318 &self,
319 request_maker: impl AsyncFn() -> Result<reqwest::Request, OpenAIError>,
320 ) -> Result<Bytes, OpenAIError> {
321 let client = self.http_client.clone();
322
323 let request = request_maker().await?;
324 let response = client
325 .execute(request)
326 .await
327 .map_err(OpenAIError::Reqwest)?;
328
329 let status = response.status();
330 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
331
332 if !status.is_success() {
334 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
335 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
336
337 if status.as_u16() == 429
338 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
341 {
342 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
344 return Err(OpenAIError::ApiError(wrapped_error.error));
345 } else {
346 return Err(OpenAIError::ApiError(wrapped_error.error));
347 }
348 }
349
350 Ok(bytes)
351 }
352
353 async fn execute<O>(
359 &self,
360 request_maker: impl AsyncFn() -> Result<reqwest::Request, OpenAIError>,
361 ) -> Result<O, OpenAIError>
362 where
363 O: DeserializeOwned,
364 {
365 let bytes = self.execute_raw(request_maker).await?;
366
367 let response: O = serde_json::from_slice(bytes.as_ref())
368 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
369
370 Ok(response)
371 }
372
373 pub(crate) async fn post_stream<I, O>(&self, path: &str, request: I) -> OpenAIEventStream<O>
375 where
376 I: Serialize,
377 O: DeserializeOwned + Send + 'static,
378 {
379 let event_source = self
380 .http_client
381 .post(self.config.url(path))
382 .query(&self.config.query())
383 .headers(self.config.headers())
384 .json(&request)
385 .eventsource()
386 .unwrap();
387
388 OpenAIEventStream::new(event_source)
389 }
390
391 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
392 &self,
393 path: &str,
394 request: I,
395 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
396 ) -> OpenAIEventStream<O>
397 where
398 I: Serialize,
399 O: DeserializeOwned + Send + 'static,
400 {
401 let event_source = self
402 .http_client
403 .post(self.config.url(path))
404 .query(&self.config.query())
405 .headers(self.config.headers())
406 .json(&request)
407 .eventsource()
408 .unwrap();
409
410 OpenAIEventStream::with_event_mapping(event_source, event_mapper)
411 }
412
413 pub(crate) async fn _get_stream<Q, O>(&self, path: &str, query: &Q) -> OpenAIEventStream<O>
415 where
416 Q: Serialize + ?Sized,
417 O: DeserializeOwned + Send + 'static,
418 {
419 let event_source = self
420 .http_client
421 .get(self.config.url(path))
422 .query(query)
423 .query(&self.config.query())
424 .headers(self.config.headers())
425 .eventsource()
426 .unwrap();
427
428 OpenAIEventStream::new(event_source)
429 }
430}
431
/// Stream adapter that turns the events of an [`EventSource`] into
/// `Result<O, OpenAIError>` items.
///
/// The constructors in this file wrap the event source in a filter that drops
/// `Event::Open`, so only messages and errors reach the `Stream` implementation.
#[pin_project]
pub struct OpenAIEventStream<O>
where
    O: DeserializeOwned + Send + 'static,
{
    /// Underlying event source with `Event::Open` filtered out. The filter
    /// predicate is stored as a plain `fn` pointer so the type can be named.
    #[pin]
    stream: Filter<
        EventSource,
        future::Ready<bool>,
        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
    >,
    /// Optional conversion from a raw event to `O`; when `None`, the message
    /// data is deserialized as JSON instead.
    event_mapper:
        Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
    /// Once set, every further poll returns `None` without polling `stream`.
    done: bool,
    /// Marker tying the stream to its output type `O`.
    _phantom_data: PhantomData<O>,
}
451
452impl<O> OpenAIEventStream<O>
453where
454 O: DeserializeOwned + Send + 'static,
455{
456 pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
457 where
458 M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
459 {
460 Self {
461 stream: event_source.filter(|result|
462 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
464 done: false,
465 event_mapper: Some(Box::new(event_mapper)),
466 _phantom_data: PhantomData,
467 }
468 }
469
470 pub(crate) fn new(event_source: EventSource) -> Self {
471 Self {
472 stream: event_source.filter(|result|
473 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
475 done: false,
476 event_mapper: None,
477 _phantom_data: PhantomData,
478 }
479 }
480}
481
482impl<O> Stream for OpenAIEventStream<O>
483where
484 O: DeserializeOwned + Send + 'static,
485{
486 type Item = Result<O, OpenAIError>;
487
488 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
489 let this = self.project();
490 if *this.done {
491 return Poll::Ready(None);
492 }
493 let stream: Pin<&mut _> = this.stream;
494 match stream.poll_next(cx) {
495 Poll::Ready(response) => {
496 match response {
497 None => Poll::Ready(None), Some(result) => match result {
499 Ok(event) => match event {
500 Event::Open => unreachable!(), Event::Message(message) => {
502 if let Some(event_mapper) = this.event_mapper.as_ref() {
503 if message.data == "[DONE]" {
504 *this.done = true;
505 }
506 let response = event_mapper(message);
507 match response {
508 Ok(output) => Poll::Ready(Some(Ok(output))),
509 Err(_) => Poll::Ready(None),
510 }
511 } else {
512 if message.data == "[DONE]" {
513 *this.done = true;
514 Poll::Ready(None) } else {
516 match serde_json::from_str::<O>(&message.data) {
518 Err(e) => {
519 *this.done = true;
520 Poll::Ready(Some(Err(map_deserialization_error(
521 e,
522 &message.data.as_bytes(),
523 ))))
524 }
525 Ok(output) => Poll::Ready(Some(Ok(output))),
526 }
527 }
528 }
529 }
530 },
531 Err(e) => {
532 *this.done = true;
533 Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
534 }
535 },
536 }
537 }
538 Poll::Pending => Poll::Pending,
539 }
540 }
541}
542
543