Skip to main content

async_openai/
client.rs

1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14    config::{Config, OpenAIConfig},
15    error::{map_deserialization_error, ApiError, ApiErrorResponse, OpenAIError, WrappedError},
16    executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
21struct RequestParts {
22    request_client: reqwest::Client,
23    method: reqwest::Method,
24    url: String,
25    headers: HeaderMap,
26    query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30    fn build_request_builder(&self) -> reqwest::RequestBuilder {
31        self.request_client
32            .request(self.method.clone(), self.url.clone())
33            .query(&self.query)
34            .headers(self.headers.clone())
35    }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
87#[derive(Clone)]
88/// Client is a container for config and HTTP execution
89/// used to make API calls.
90pub struct Client<C: Config> {
91    request_client: reqwest::Client,
92    executor: SharedExecutor,
93    config: C,
94}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98    C: Config + std::fmt::Debug,
99{
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("Client")
102            .field("request_client", &self.request_client)
103            .field("config", &self.config)
104            .finish()
105    }
106}
107
108impl<C: Config> Default for Client<C>
109where
110    C: Default,
111{
112    fn default() -> Self {
113        let request_client = reqwest::Client::new();
114        Self {
115            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116            request_client,
117            config: C::default(),
118        }
119    }
120}
121
122impl Client<OpenAIConfig> {
123    /// Client with default [OpenAIConfig]
124    pub fn new() -> Self {
125        Self::default()
126    }
127}
128
129impl<C: Config> Client<C> {
130    /// Create client with a custom HTTP client and config.
131    pub fn build(http_client: reqwest::Client, config: C) -> Self {
132        Self {
133            executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134            request_client: http_client,
135            config,
136        }
137    }
138
139    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
140    pub fn with_config(config: C) -> Self {
141        let request_client = reqwest::Client::new();
142        Self {
143            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144            request_client,
145            config,
146        }
147    }
148
149    /// Provide your own [client] to make HTTP requests with.
150    ///
151    /// [client]: reqwest::Client
152    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153        self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154        self.request_client = http_client;
155        self
156    }
157
158    /// Provide your own tower-compatible service to execute HTTP requests.
159    #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
160    pub fn with_http_service<S>(mut self, service: S) -> Self
161    where
162        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
163        S::Future: Send + 'static,
164        S::Error: Into<OpenAIError> + Send + Sync + 'static,
165    {
166        // This is the public middleware escape hatch. We erase the concrete
167        // tower stack here so the rest of the client does not become generic
168        // over the service type, which would otherwise leak through every API
169        // group and make the crate much harder to use.
170        self.executor = Arc::new(TowerExecutor::new(service));
171        self
172    }
173
174    /// Provide your own tower-compatible service to execute HTTP requests.
175    #[cfg(all(feature = "middleware", target_family = "wasm"))]
176    pub fn with_http_service<S>(mut self, service: S) -> Self
177    where
178        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
179        S::Future: 'static,
180        S::Error: Into<OpenAIError> + 'static,
181    {
182        // wasm futures produced by reqwest are not `Send`, so the wasm version
183        // intentionally avoids native thread-safety bounds. Users are still
184        // responsible for choosing tower layers that work in their wasm
185        // runtime.
186        self.executor = Arc::new(TowerExecutor::new(service));
187        self
188    }
189
190    // API groups
191
192    /// To call [Models] group related APIs using this client.
193    #[cfg(feature = "model")]
194    pub fn models(&self) -> Models<'_, C> {
195        Models::new(self)
196    }
197
198    /// To call [Completions] group related APIs using this client.
199    #[cfg(feature = "completions")]
200    pub fn completions(&self) -> Completions<'_, C> {
201        Completions::new(self)
202    }
203
204    /// To call [Chat] group related APIs using this client.
205    #[cfg(feature = "chat-completion")]
206    pub fn chat(&self) -> Chat<'_, C> {
207        Chat::new(self)
208    }
209
210    /// To call [Images] group related APIs using this client.
211    #[cfg(feature = "image")]
212    pub fn images(&self) -> Images<'_, C> {
213        Images::new(self)
214    }
215
216    /// To call [Moderations] group related APIs using this client.
217    #[cfg(feature = "moderation")]
218    pub fn moderations(&self) -> Moderations<'_, C> {
219        Moderations::new(self)
220    }
221
222    /// To call [Files] group related APIs using this client.
223    #[cfg(feature = "file")]
224    pub fn files(&self) -> Files<'_, C> {
225        Files::new(self)
226    }
227
228    /// To call [Uploads] group related APIs using this client.
229    #[cfg(feature = "upload")]
230    pub fn uploads(&self) -> Uploads<'_, C> {
231        Uploads::new(self)
232    }
233
234    /// To call [FineTuning] group related APIs using this client.
235    #[cfg(feature = "finetuning")]
236    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
237        FineTuning::new(self)
238    }
239
240    /// To call [Embeddings] group related APIs using this client.
241    #[cfg(feature = "embedding")]
242    pub fn embeddings(&self) -> Embeddings<'_, C> {
243        Embeddings::new(self)
244    }
245
246    /// To call [Audio] group related APIs using this client.
247    #[cfg(feature = "audio")]
248    pub fn audio(&self) -> Audio<'_, C> {
249        Audio::new(self)
250    }
251
252    /// To call [Videos] group related APIs using this client.
253    #[cfg(feature = "video")]
254    pub fn videos(&self) -> Videos<'_, C> {
255        Videos::new(self)
256    }
257
258    /// To call [Assistants] group related APIs using this client.
259    #[cfg(feature = "assistant")]
260    #[deprecated(
261        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
262    )]
263    #[allow(deprecated)]
264    pub fn assistants(&self) -> Assistants<'_, C> {
265        Assistants::new(self)
266    }
267
268    /// To call [Threads] group related APIs using this client.
269    #[cfg(feature = "assistant")]
270    #[deprecated(
271        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
272    )]
273    #[allow(deprecated)]
274    pub fn threads(&self) -> Threads<'_, C> {
275        Threads::new(self)
276    }
277
278    /// To call [VectorStores] group related APIs using this client.
279    #[cfg(feature = "vectorstore")]
280    pub fn vector_stores(&self) -> VectorStores<'_, C> {
281        VectorStores::new(self)
282    }
283
284    /// To call [Batches] group related APIs using this client.
285    #[cfg(feature = "batch")]
286    pub fn batches(&self) -> Batches<'_, C> {
287        Batches::new(self)
288    }
289
290    /// To call [Admin] group related APIs using this client.
291    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
292    #[cfg(feature = "administration")]
293    pub fn admin(&self) -> Admin<'_, C> {
294        Admin::new(self)
295    }
296
297    /// To call [Responses] group related APIs using this client.
298    #[cfg(feature = "responses")]
299    pub fn responses(&self) -> Responses<'_, C> {
300        Responses::new(self)
301    }
302
303    /// To call [Conversations] group related APIs using this client.
304    #[cfg(feature = "responses")]
305    pub fn conversations(&self) -> Conversations<'_, C> {
306        Conversations::new(self)
307    }
308
309    /// To call [Containers] group related APIs using this client.
310    #[cfg(feature = "container")]
311    pub fn containers(&self) -> Containers<'_, C> {
312        Containers::new(self)
313    }
314
315    /// To call [Skills] group related APIs using this client.
316    #[cfg(feature = "skill")]
317    pub fn skills(&self) -> Skills<'_, C> {
318        Skills::new(self)
319    }
320
321    /// To call [Evals] group related APIs using this client.
322    #[cfg(feature = "evals")]
323    pub fn evals(&self) -> Evals<'_, C> {
324        Evals::new(self)
325    }
326
327    #[cfg(feature = "chatkit")]
328    pub fn chatkit(&self) -> Chatkit<'_, C> {
329        Chatkit::new(self)
330    }
331
332    /// To call [Realtime] group related APIs using this client.
333    #[cfg(feature = "realtime")]
334    pub fn realtime(&self) -> Realtime<'_, C> {
335        Realtime::new(self)
336    }
337
338    pub fn config(&self) -> &C {
339        &self.config
340    }
341
342    fn build_request_parts(
343        &self,
344        method: reqwest::Method,
345        path: &str,
346        request_options: &RequestOptions,
347    ) -> Arc<RequestParts> {
348        let url = if let Some(path) = request_options.path() {
349            self.config.url(path.as_str())
350        } else {
351            self.config.url(path)
352        };
353        let mut headers = self.config.headers();
354        if let Some(request_headers) = request_options.headers() {
355            headers.extend(request_headers.clone());
356        }
357
358        let mut query = self
359            .config
360            .query()
361            .into_iter()
362            .map(|(key, value)| (key.to_string(), value.to_string()))
363            .collect::<Vec<_>>();
364        query.extend_from_slice(request_options.query());
365
366        Arc::new(RequestParts {
367            request_client: self.request_client.clone(),
368            method,
369            url,
370            headers,
371            query,
372        })
373    }
374
375    fn build_request_factory(
376        &self,
377        method: reqwest::Method,
378        path: &str,
379        request_options: &RequestOptions,
380    ) -> HttpRequestFactory {
381        let request_parts = self.build_request_parts(method, path, request_options);
382
383        HttpRequestFactory::new(move || {
384            let request_parts = request_parts.clone();
385
386            async move {
387                let request = request_parts.build_request_builder().build()?;
388                Ok(request)
389            }
390        })
391    }
392
393    fn build_request_factory_with_json<I>(
394        &self,
395        method: reqwest::Method,
396        path: &str,
397        request: I,
398        request_options: &RequestOptions,
399    ) -> Result<HttpRequestFactory, OpenAIError>
400    where
401        I: Serialize,
402    {
403        // JSON bodies are materialized once so the base BYOT path can keep
404        // accepting borrowed inputs.
405        let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
406            OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
407        })?);
408        let request_parts = self.build_request_parts(method, path, request_options);
409
410        Ok(HttpRequestFactory::new(move || {
411            let request_parts = request_parts.clone();
412            let request = request.clone();
413
414            async move {
415                let request_builder = request_parts
416                    .build_request_builder()
417                    .header(reqwest::header::CONTENT_TYPE, "application/json")
418                    .body(request.clone());
419
420                Ok(request_builder.build()?)
421            }
422        }))
423    }
424
425    fn build_request_factory_with_form<F>(
426        &self,
427        method: reqwest::Method,
428        path: &str,
429        form: F,
430        request_options: &RequestOptions,
431    ) -> Result<HttpRequestFactory, OpenAIError>
432    where
433        F: Clone + crate::traits::MaybeSend + 'static,
434        Form: AsyncTryFrom<F, Error = OpenAIError>,
435    {
436        // Multipart is the reason the factory exists.
437        //
438        // `Mutex` is only here to make the captured state `Sync` on native targets.
439        #[cfg(not(target_family = "wasm"))]
440        let form = Arc::new(Mutex::new(form));
441        let request_parts = self.build_request_parts(method, path, request_options);
442
443        Ok(HttpRequestFactory::new(move || {
444            let request_parts = request_parts.clone();
445            let form = form.clone();
446
447            async move {
448                #[cfg(not(target_family = "wasm"))]
449                let form = form
450                    .lock()
451                    .expect("multipart request factory mutex poisoned")
452                    .clone();
453                #[cfg(target_family = "wasm")]
454                let form = form.clone();
455                let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
456                let request_builder = request_parts.build_request_builder().multipart(form);
457
458                Ok(request_builder.build()?)
459            }
460        }))
461    }
462
463    /// Make a GET request to {path} and deserialize the response body
464    #[allow(unused)]
465    pub(crate) async fn get<O>(
466        &self,
467        path: &str,
468        request_options: &RequestOptions,
469    ) -> Result<O, OpenAIError>
470    where
471        O: DeserializeOwned,
472    {
473        let request_factory =
474            self.build_request_factory(reqwest::Method::GET, path, request_options);
475        self.execute(request_factory).await
476    }
477
478    /// Make a DELETE request to {path} and deserialize the response body
479    #[allow(unused)]
480    pub(crate) async fn delete<O>(
481        &self,
482        path: &str,
483        request_options: &RequestOptions,
484    ) -> Result<O, OpenAIError>
485    where
486        O: DeserializeOwned,
487    {
488        let request_factory =
489            self.build_request_factory(reqwest::Method::DELETE, path, request_options);
490        self.execute(request_factory).await
491    }
492
493    /// Make a GET request to {path} and return the response body
494    #[allow(unused)]
495    pub(crate) async fn get_raw(
496        &self,
497        path: &str,
498        request_options: &RequestOptions,
499    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
500        let request_factory =
501            self.build_request_factory(reqwest::Method::GET, path, request_options);
502        self.execute_raw(request_factory).await
503    }
504
505    /// Make a POST request to {path} and return the response body
506    #[allow(unused)]
507    pub(crate) async fn post_raw<I>(
508        &self,
509        path: &str,
510        request: I,
511        request_options: &RequestOptions,
512    ) -> Result<(Bytes, HeaderMap), OpenAIError>
513    where
514        I: Serialize,
515    {
516        let request_factory = self.build_request_factory_with_json(
517            reqwest::Method::POST,
518            path,
519            request,
520            request_options,
521        )?;
522        self.execute_raw(request_factory).await
523    }
524
525    /// Make a POST request to {path} and deserialize the response body
526    #[allow(unused)]
527    pub(crate) async fn post<I, O>(
528        &self,
529        path: &str,
530        request: I,
531        request_options: &RequestOptions,
532    ) -> Result<O, OpenAIError>
533    where
534        I: Serialize,
535        O: DeserializeOwned,
536    {
537        let request_factory = self.build_request_factory_with_json(
538            reqwest::Method::POST,
539            path,
540            request,
541            request_options,
542        )?;
543        self.execute(request_factory).await
544    }
545
546    /// POST a form at {path} and return the response body
547    #[allow(unused)]
548    pub(crate) async fn post_form_raw<F>(
549        &self,
550        path: &str,
551        form: F,
552        request_options: &RequestOptions,
553    ) -> Result<(Bytes, HeaderMap), OpenAIError>
554    where
555        F: Clone + crate::traits::MaybeSend + 'static,
556        Form: AsyncTryFrom<F, Error = OpenAIError>,
557    {
558        let request_factory = self.build_request_factory_with_form(
559            reqwest::Method::POST,
560            path,
561            form,
562            request_options,
563        )?;
564        self.execute_raw(request_factory).await
565    }
566
567    /// POST a form at {path} and deserialize the response body
568    #[allow(unused)]
569    pub(crate) async fn post_form<O, F>(
570        &self,
571        path: &str,
572        form: F,
573        request_options: &RequestOptions,
574    ) -> Result<O, OpenAIError>
575    where
576        O: DeserializeOwned,
577        F: Clone + crate::traits::MaybeSend + 'static,
578        Form: AsyncTryFrom<F, Error = OpenAIError>,
579    {
580        let request_factory = self.build_request_factory_with_form(
581            reqwest::Method::POST,
582            path,
583            form,
584            request_options,
585        )?;
586        self.execute(request_factory).await
587    }
588
589    #[allow(unused)]
590    pub(crate) async fn post_form_stream<O, F>(
591        &self,
592        path: &str,
593        form: F,
594        request_options: &RequestOptions,
595    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
596    where
597        F: Clone + crate::traits::MaybeSend + 'static,
598        Form: AsyncTryFrom<F, Error = OpenAIError>,
599        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
600    {
601        let request_factory = self.build_request_factory_with_form(
602            reqwest::Method::POST,
603            path,
604            form,
605            request_options,
606        )?;
607
608        self.execute_stream(request_factory).await
609    }
610
611    async fn execute_raw(
612        &self,
613        request_factory: HttpRequestFactory,
614    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
615        let response = self.execute_response(request_factory).await?;
616        read_response(response).await
617    }
618
619    async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
620    where
621        O: DeserializeOwned,
622    {
623        let (bytes, _headers) = self.execute_raw(request_factory).await?;
624
625        let response: O = serde_json::from_slice(bytes.as_ref())
626            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
627
628        Ok(response)
629    }
630
631    async fn execute_response(
632        &self,
633        request_factory: HttpRequestFactory,
634    ) -> Result<Response, OpenAIError> {
635        self.executor.execute(request_factory).await
636    }
637
638    async fn execute_stream<O>(
639        &self,
640        request_factory: HttpRequestFactory,
641    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
642    where
643        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
644    {
645        let response = self.execute_response(request_factory).await?;
646        Ok(stream(response).await)
647    }
648
649    async fn execute_stream_mapped_raw_events<O>(
650        &self,
651        request_factory: HttpRequestFactory,
652        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
653            + crate::traits::MaybeSend
654            + 'static,
655    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
656    where
657        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
658    {
659        let response = self.execute_response(request_factory).await?;
660        Ok(stream_mapped_raw_events(response, event_mapper).await)
661    }
662
663    /// Make HTTP POST request to receive SSE
664    #[allow(unused)]
665    pub(crate) async fn post_stream<I, O>(
666        &self,
667        path: &str,
668        request: I,
669        request_options: &RequestOptions,
670    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
671    where
672        I: Serialize,
673        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
674    {
675        let request_factory = self.build_request_factory_with_json(
676            reqwest::Method::POST,
677            path,
678            request,
679            request_options,
680        )?;
681        // Stream setup is still request/response first. We only create the SSE
682        // stream after the HTTP layer has returned a response object.
683        self.execute_stream(request_factory).await
684    }
685
686    #[allow(unused)]
687    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
688        &self,
689        path: &str,
690        request: I,
691        request_options: &RequestOptions,
692        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
693            + crate::traits::MaybeSend
694            + 'static,
695    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
696    where
697        I: Serialize,
698        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
699    {
700        let request_factory = self.build_request_factory_with_json(
701            reqwest::Method::POST,
702            path,
703            request,
704            request_options,
705        )?;
706        self.execute_stream_mapped_raw_events(request_factory, event_mapper)
707            .await
708    }
709
710    /// Make HTTP GET request to receive SSE
711    #[allow(unused)]
712    pub(crate) async fn get_stream<O>(
713        &self,
714        path: &str,
715        request_options: &RequestOptions,
716    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
717    where
718        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
719    {
720        let request_factory =
721            self.build_request_factory(reqwest::Method::GET, path, request_options);
722        self.execute_stream(request_factory).await
723    }
724}
725
726async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
727    let status = response.status();
728    let headers = response.headers().clone();
729    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
730
731    if status.is_server_error() {
732        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
733        let message: String = String::from_utf8_lossy(&bytes).into_owned();
734        tracing::warn!("Server error: {status} - {message}");
735        return Err(OpenAIError::ApiError(ApiErrorResponse {
736            status_code: status,
737            api_error: ApiError {
738                message,
739                r#type: None,
740                param: None,
741                code: None,
742            },
743        }));
744    }
745
746    // Deserialize response body from either error object or actual response object
747    if !status.is_success() {
748        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
749            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
750
751        return Err(OpenAIError::ApiError(ApiErrorResponse {
752            status_code: status,
753            api_error: wrapped_error.error,
754        }));
755    }
756
757    Ok((bytes, headers))
758}
759
760/// Request which responds with SSE.
761/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
762pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
763where
764    O: DeserializeOwned + crate::traits::MaybeSend + 'static,
765{
766    stream_mapped_raw_events(response, |event| {
767        serde_json::from_str::<O>(&event.data)
768            .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
769    })
770    .await
771}
772
/// wasm implementation: convert an SSE response into a stream of mapped items.
///
/// - A non-success status yields a single error item (from `read_response`).
/// - A `[DONE]` data payload ends the stream, and `keepalive` events are skipped.
/// - Every other event is passed to `event_mapper`.
/// - After an event-stream error the error is yielded once and the stream
///   terminates. This matches the native implementation, which stops at the
///   first such error instead of continuing to poll a failed stream.
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    if !response.status().is_success() {
        return Box::pin(futures::stream::once(async move {
            match read_response(response).await {
                Ok(_) => Err(OpenAIError::InvalidArgument(
                    "stream request failed without an error body".into(),
                )),
                Err(error) => Err(error),
            }
        }));
    }

    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // The unfold state is `None` once an error has been yielded, which ends
    // the stream on the next poll.
    Box::pin(futures::stream::unfold(
        Some((event_stream, event_mapper)),
        |state| async move {
            let (mut event_stream, event_mapper) = match state {
                Some(state) => state,
                None => return None,
            };

            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        return Some((
                            Err(OpenAIError::StreamError(Box::new(
                                StreamError::EventStream(error.to_string()),
                            ))),
                            None,
                        ));
                    }
                    None => return None,
                };

                if event.data == "[DONE]" {
                    return None;
                }

                if event.event == "keepalive" {
                    continue;
                }

                let response = event_mapper(event);
                return Some((response, Some((event_stream, event_mapper))));
            }
        },
    ))
}
828
829#[cfg(not(target_family = "wasm"))]
830pub(crate) async fn stream_mapped_raw_events<O>(
831    response: Response,
832    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
833) -> crate::types::stream::StreamResponse<O>
834where
835    O: DeserializeOwned + std::marker::Send + 'static,
836{
837    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
838
839    tokio::spawn(async move {
840        if !response.status().is_success() {
841            if let Err(e) = read_response(response).await {
842                let _ = tx.send(Err(e));
843            }
844            return;
845        }
846        let byte_stream = response
847            .bytes_stream()
848            .map(|r| r.map_err(std::io::Error::other));
849        let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
850
851        while let Some(ev) = event_stream.next().await {
852            let event = match ev {
853                Ok(e) => e,
854                Err(e) => {
855                    let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
856                        StreamError::EventStream(e.to_string()),
857                    ))));
858                    break;
859                }
860            };
861            if event.data == "[DONE]" {
862                break;
863            }
864
865            if event.event == "keepalive" {
866                continue;
867            }
868
869            let response = event_mapper(event);
870
871            if tx.send(response).is_err() {
872                break;
873            }
874        }
875    });
876
877    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
878}
879
#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
mod tests {
    //! Tests for routing `Client` requests through a caller-supplied tower
    //! service installed with `Client::with_http_service`.
    //!
    //! Each test installs a `service_fn` as the mock service. It does the following:
    //! - builds the outgoing request from the `HttpRequestFactory`;
    //! - checks the request's path (and method where relevant);
    //! - counts the call;
    //! - answers with a canned response built by `mock_response`.

    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use futures::StreamExt;
    use http::Response as HttpResponse;
    use serde_json::json;
    use tower::{service_fn, ServiceBuilder};

    use super::Client;
    use crate::{
        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
    };

    /// Error body the retry tests return with a 429 on the first attempt.
    const RATE_LIMIT_BODY: &str =
        r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#;

    /// Client config shared by all tests.
    ///
    /// It points at a placeholder host because every request is answered by the
    /// mock service rather than sent over the network.
    fn test_config() -> OpenAIConfig {
        OpenAIConfig::new()
            .with_api_base("http://example.test")
            .with_api_key("test-key")
    }

    /// Builds a canned `reqwest::Response` with the given status code, `content-type`
    /// header and body.
    fn mock_response(status: u16, content_type: &str, body: &'static str) -> reqwest::Response {
        HttpResponse::builder()
            .status(status)
            .header("content-type", content_type)
            .body(reqwest::Body::from(body))
            // Only test-controlled, valid status codes and header values reach here.
            .expect("mock response parts must be valid")
            .into()
    }

    /// A unary GET passes through the installed middleware exactly once, and the
    /// client deserializes the JSON body it returns.
    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(mock_response(
                            200,
                            "application/json",
                            r#"{"object":"list","data":[{"id":"model"}]}"#,
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// A streaming POST is opened through the installed middleware, and the
    /// first SSE `data:` event is decoded into JSON.
    #[tokio::test]
    async fn stream_requests_open_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/responses");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(mock_response(
                            200,
                            "text/event-stream",
                            "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let mut stream = client
            .post_stream::<_, serde_json::Value>(
                "/responses",
                json!({ "stream": true }),
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();

        assert_eq!(first, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// A 429 on the first attempt is retried by `SimpleRetryPolicy::default()`,
    /// and the client receives the second (200) response.
    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        // `fetch_add` returns the previous value, so attempt 0 is the first call.
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        let response = if attempt == 0 {
                            mock_response(429, "application/json", RATE_LIMIT_BODY)
                        } else {
                            mock_response(
                                200,
                                "application/json",
                                r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
                            )
                        };

                        Ok::<reqwest::Response, OpenAIError>(response)
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    /// Multipart input whose conversion into a `Form` increments a shared counter.
    /// The counter lets the test observe how many times a form was built.
    ///
    /// NOTE(review): `Clone` is presumably required by `post_form`'s bounds so the
    /// input can be converted again on each retry attempt. Confirm against its signature.
    #[derive(Clone)]
    struct RetryableMultipartInput {
        conversions: Arc<AtomicUsize>,
    }

    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
        type Error = OpenAIError;

        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
            value.conversions.fetch_add(1, Ordering::SeqCst);
            Ok(reqwest::multipart::Form::new().text("field", "value"))
        }
    }

    /// When a multipart POST is retried, the form is converted again for every
    /// attempt, not reused. A `Form` body is consumed once it is sent.
    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        let response = if attempt == 0 {
                            mock_response(429, "application/json", RATE_LIMIT_BODY)
                        } else {
                            mock_response(200, "application/json", r#"{"ok":true}"#)
                        };

                        Ok::<reqwest::Response, OpenAIError>(response)
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: Arc::clone(&conversion_count),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
}