//! async-openai `client.rs` — config container and HTTP execution for API calls.

1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14    config::{Config, OpenAIConfig},
15    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
16    executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
/// Immutable ingredients needed to (re)build an HTTP request.
///
/// Stored behind an `Arc` by the request factories so a fresh
/// `reqwest::Request` can be constructed on every invocation
/// (e.g. per retry attempt — TODO confirm against executor/retry code).
struct RequestParts {
    // `reqwest::Client` is internally reference-counted, cheap to clone.
    request_client: reqwest::Client,
    method: reqwest::Method,
    // Fully resolved URL (config base joined with the request path).
    url: String,
    // Config headers merged with any per-request headers.
    headers: HeaderMap,
    // Config query parameters followed by per-request query parameters.
    query: Vec<(String, String)>,
}
28
29impl RequestParts {
30    fn build_request_builder(&self) -> reqwest::RequestBuilder {
31        self.request_client
32            .request(self.method.clone(), self.url.clone())
33            .query(&self.query)
34            .headers(self.headers.clone())
35    }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
#[derive(Clone)]
/// Client is a container for config and HTTP execution
/// used to make API calls.
pub struct Client<C: Config> {
    // Kept alongside the executor so request factories can still *build*
    // `reqwest::Request`s even when execution is routed through middleware.
    request_client: reqwest::Client,
    // Type-erased executor: plain reqwest or a tower service (middleware feature).
    executor: SharedExecutor,
    // Provider configuration: base URL, auth headers, default query params.
    config: C,
}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98    C: Config + std::fmt::Debug,
99{
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("Client")
102            .field("request_client", &self.request_client)
103            .field("config", &self.config)
104            .finish()
105    }
106}
107
108impl<C: Config> Default for Client<C>
109where
110    C: Default,
111{
112    fn default() -> Self {
113        let request_client = reqwest::Client::new();
114        Self {
115            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116            request_client,
117            config: C::default(),
118        }
119    }
120}
121
impl Client<OpenAIConfig> {
    /// Client with default [OpenAIConfig]
    ///
    /// Equivalent to `Client::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
128
129impl<C: Config> Client<C> {
130    /// Create client with a custom HTTP client and config.
131    pub fn build(http_client: reqwest::Client, config: C) -> Self {
132        Self {
133            executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134            request_client: http_client,
135            config,
136        }
137    }
138
139    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
140    pub fn with_config(config: C) -> Self {
141        let request_client = reqwest::Client::new();
142        Self {
143            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144            request_client,
145            config,
146        }
147    }
148
149    /// Provide your own [client] to make HTTP requests with.
150    ///
151    /// [client]: reqwest::Client
152    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153        self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154        self.request_client = http_client;
155        self
156    }
157
    /// Provide your own tower-compatible service to execute HTTP requests.
    ///
    /// The service receives an [`HttpRequestFactory`] and must produce a
    /// [`reqwest::Response`]; its error type is converted into [`OpenAIError`].
    /// Note that the stored `reqwest::Client` is still used to *build*
    /// requests — only execution is routed through `service`.
    #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
    pub fn with_http_service<S>(mut self, service: S) -> Self
    where
        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
        S::Future: Send + 'static,
        S::Error: Into<OpenAIError> + Send + Sync + 'static,
    {
        // This is the public middleware escape hatch. We erase the concrete
        // tower stack here so the rest of the client does not become generic
        // over the service type, which would otherwise leak through every API
        // group and make the crate much harder to use.
        self.executor = Arc::new(TowerExecutor::new(service));
        self
    }
173
    /// Provide your own tower-compatible service to execute HTTP requests.
    ///
    /// wasm variant of the native `with_http_service`; identical behavior,
    /// relaxed bounds (no `Send`/`Sync`).
    #[cfg(all(feature = "middleware", target_family = "wasm"))]
    pub fn with_http_service<S>(mut self, service: S) -> Self
    where
        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
        S::Future: 'static,
        S::Error: Into<OpenAIError> + 'static,
    {
        // wasm futures produced by reqwest are not `Send`, so the wasm version
        // intentionally avoids native thread-safety bounds. Users are still
        // responsible for choosing tower layers that work in their wasm
        // runtime.
        self.executor = Arc::new(TowerExecutor::new(service));
        self
    }
189
    // API groups: each accessor borrows the client and returns a lightweight
    // handle scoped to one API family. All are feature-gated.

    /// To call [Models] group related APIs using this client.
    #[cfg(feature = "model")]
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    /// To call [Completions] group related APIs using this client.
    #[cfg(feature = "completions")]
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    /// To call [Chat] group related APIs using this client.
    #[cfg(feature = "chat-completion")]
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    /// To call [Images] group related APIs using this client.
    #[cfg(feature = "image")]
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    /// To call [Moderations] group related APIs using this client.
    #[cfg(feature = "moderation")]
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    /// To call [Files] group related APIs using this client.
    #[cfg(feature = "file")]
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    /// To call [Uploads] group related APIs using this client.
    #[cfg(feature = "upload")]
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    /// To call [FineTuning] group related APIs using this client.
    #[cfg(feature = "finetuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    /// To call [Embeddings] group related APIs using this client.
    #[cfg(feature = "embedding")]
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    /// To call [Audio] group related APIs using this client.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    /// To call [Videos] group related APIs using this client.
    #[cfg(feature = "video")]
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    /// To call [Assistants] group related APIs using this client.
    #[cfg(feature = "assistant")]
    #[deprecated(
        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
    )]
    #[allow(deprecated)]
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    /// To call [Threads] group related APIs using this client.
    #[cfg(feature = "assistant")]
    #[deprecated(
        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
    )]
    #[allow(deprecated)]
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    /// To call [VectorStores] group related APIs using this client.
    #[cfg(feature = "vectorstore")]
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    /// To call [Batches] group related APIs using this client.
    #[cfg(feature = "batch")]
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    /// To call [Admin] group related APIs using this client.
    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
    #[cfg(feature = "administration")]
    pub fn admin(&self) -> Admin<'_, C> {
        Admin::new(self)
    }

    /// To call [Responses] group related APIs using this client.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    /// To call [Conversations] group related APIs using this client.
    #[cfg(feature = "responses")]
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    /// To call [Containers] group related APIs using this client.
    #[cfg(feature = "container")]
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    /// To call [Skills] group related APIs using this client.
    #[cfg(feature = "skill")]
    pub fn skills(&self) -> Skills<'_, C> {
        Skills::new(self)
    }

    /// To call [Evals] group related APIs using this client.
    #[cfg(feature = "evals")]
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    /// To call [Chatkit] group related APIs using this client.
    #[cfg(feature = "chatkit")]
    pub fn chatkit(&self) -> Chatkit<'_, C> {
        Chatkit::new(self)
    }

    /// To call [Realtime] group related APIs using this client.
    #[cfg(feature = "realtime")]
    pub fn realtime(&self) -> Realtime<'_, C> {
        Realtime::new(self)
    }

    /// Borrow the configuration this client was built with.
    pub fn config(&self) -> &C {
        &self.config
    }
341
342    fn build_request_parts(
343        &self,
344        method: reqwest::Method,
345        path: &str,
346        request_options: &RequestOptions,
347    ) -> Arc<RequestParts> {
348        let url = if let Some(path) = request_options.path() {
349            self.config.url(path.as_str())
350        } else {
351            self.config.url(path)
352        };
353        let mut headers = self.config.headers();
354        if let Some(request_headers) = request_options.headers() {
355            headers.extend(request_headers.clone());
356        }
357
358        let mut query = self
359            .config
360            .query()
361            .into_iter()
362            .map(|(key, value)| (key.to_string(), value.to_string()))
363            .collect::<Vec<_>>();
364        query.extend_from_slice(request_options.query());
365
366        Arc::new(RequestParts {
367            request_client: self.request_client.clone(),
368            method,
369            url,
370            headers,
371            query,
372        })
373    }
374
375    fn build_request_factory(
376        &self,
377        method: reqwest::Method,
378        path: &str,
379        request_options: &RequestOptions,
380    ) -> HttpRequestFactory {
381        let request_parts = self.build_request_parts(method, path, request_options);
382
383        HttpRequestFactory::new(move || {
384            let request_parts = request_parts.clone();
385
386            async move {
387                let request = request_parts.build_request_builder().build()?;
388                Ok(request)
389            }
390        })
391    }
392
393    fn build_request_factory_with_json<I>(
394        &self,
395        method: reqwest::Method,
396        path: &str,
397        request: I,
398        request_options: &RequestOptions,
399    ) -> Result<HttpRequestFactory, OpenAIError>
400    where
401        I: Serialize,
402    {
403        // JSON bodies are materialized once so the base BYOT path can keep
404        // accepting borrowed inputs. Middleware-enabled BYOT still adds owned
405        // replay bounds in the macro, but the core client does not force those
406        // bounds onto non-middleware users.
407        let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
408            OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
409        })?);
410        let request_parts = self.build_request_parts(method, path, request_options);
411
412        Ok(HttpRequestFactory::new(move || {
413            let request_parts = request_parts.clone();
414            let request = request.clone();
415
416            async move {
417                let request_builder = request_parts
418                    .build_request_builder()
419                    .header(reqwest::header::CONTENT_TYPE, "application/json")
420                    .body(request.clone());
421
422                Ok(request_builder.build()?)
423            }
424        }))
425    }
426
    /// Build a request factory for a multipart form body.
    ///
    /// Multipart is the reason the factory exists: a built multipart `Form`
    /// cannot be reused, so the factory keeps the cloneable *source* value
    /// `F` and converts it into a fresh `Form` on every call.
    fn build_request_factory_with_form<F>(
        &self,
        method: reqwest::Method,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<HttpRequestFactory, OpenAIError>
    where
        F: Clone + crate::traits::MaybeSend + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        // `Mutex` is only here to make the captured state `Sync` on native
        // targets; the guard is dropped before any await point below.
        #[cfg(not(target_family = "wasm"))]
        let form = Arc::new(Mutex::new(form));
        let request_parts = self.build_request_parts(method, path, request_options);

        Ok(HttpRequestFactory::new(move || {
            let request_parts = request_parts.clone();
            let form = form.clone();

            async move {
                // Native: clone the source value out from under the mutex
                // synchronously, so the lock is never held across an await.
                #[cfg(not(target_family = "wasm"))]
                let form = form
                    .lock()
                    .expect("multipart request factory mutex poisoned")
                    .clone();
                // wasm: no `Sync` requirement, so the source is cloned directly.
                #[cfg(target_family = "wasm")]
                let form = form.clone();
                // Convert the cloneable source into an actual multipart `Form`.
                let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
                let request_builder = request_parts.build_request_builder().multipart(form);

                Ok(request_builder.build()?)
            }
        }))
    }
464
465    /// Make a GET request to {path} and deserialize the response body
466    #[allow(unused)]
467    pub(crate) async fn get<O>(
468        &self,
469        path: &str,
470        request_options: &RequestOptions,
471    ) -> Result<O, OpenAIError>
472    where
473        O: DeserializeOwned,
474    {
475        let request_factory =
476            self.build_request_factory(reqwest::Method::GET, path, request_options);
477        self.execute(request_factory).await
478    }
479
480    /// Make a DELETE request to {path} and deserialize the response body
481    #[allow(unused)]
482    pub(crate) async fn delete<O>(
483        &self,
484        path: &str,
485        request_options: &RequestOptions,
486    ) -> Result<O, OpenAIError>
487    where
488        O: DeserializeOwned,
489    {
490        let request_factory =
491            self.build_request_factory(reqwest::Method::DELETE, path, request_options);
492        self.execute(request_factory).await
493    }
494
495    /// Make a GET request to {path} and return the response body
496    #[allow(unused)]
497    pub(crate) async fn get_raw(
498        &self,
499        path: &str,
500        request_options: &RequestOptions,
501    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
502        let request_factory =
503            self.build_request_factory(reqwest::Method::GET, path, request_options);
504        self.execute_raw(request_factory).await
505    }
506
507    /// Make a POST request to {path} and return the response body
508    #[allow(unused)]
509    pub(crate) async fn post_raw<I>(
510        &self,
511        path: &str,
512        request: I,
513        request_options: &RequestOptions,
514    ) -> Result<(Bytes, HeaderMap), OpenAIError>
515    where
516        I: Serialize,
517    {
518        let request_factory = self.build_request_factory_with_json(
519            reqwest::Method::POST,
520            path,
521            request,
522            request_options,
523        )?;
524        self.execute_raw(request_factory).await
525    }
526
527    /// Make a POST request to {path} and deserialize the response body
528    #[allow(unused)]
529    pub(crate) async fn post<I, O>(
530        &self,
531        path: &str,
532        request: I,
533        request_options: &RequestOptions,
534    ) -> Result<O, OpenAIError>
535    where
536        I: Serialize,
537        O: DeserializeOwned,
538    {
539        let request_factory = self.build_request_factory_with_json(
540            reqwest::Method::POST,
541            path,
542            request,
543            request_options,
544        )?;
545        self.execute(request_factory).await
546    }
547
548    /// POST a form at {path} and return the response body
549    #[allow(unused)]
550    pub(crate) async fn post_form_raw<F>(
551        &self,
552        path: &str,
553        form: F,
554        request_options: &RequestOptions,
555    ) -> Result<(Bytes, HeaderMap), OpenAIError>
556    where
557        F: Clone + crate::traits::MaybeSend + 'static,
558        Form: AsyncTryFrom<F, Error = OpenAIError>,
559    {
560        let request_factory = self.build_request_factory_with_form(
561            reqwest::Method::POST,
562            path,
563            form,
564            request_options,
565        )?;
566        self.execute_raw(request_factory).await
567    }
568
569    /// POST a form at {path} and deserialize the response body
570    #[allow(unused)]
571    pub(crate) async fn post_form<O, F>(
572        &self,
573        path: &str,
574        form: F,
575        request_options: &RequestOptions,
576    ) -> Result<O, OpenAIError>
577    where
578        O: DeserializeOwned,
579        F: Clone + crate::traits::MaybeSend + 'static,
580        Form: AsyncTryFrom<F, Error = OpenAIError>,
581    {
582        let request_factory = self.build_request_factory_with_form(
583            reqwest::Method::POST,
584            path,
585            form,
586            request_options,
587        )?;
588        self.execute(request_factory).await
589    }
590
591    #[allow(unused)]
592    pub(crate) async fn post_form_stream<O, F>(
593        &self,
594        path: &str,
595        form: F,
596        request_options: &RequestOptions,
597    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
598    where
599        F: Clone + crate::traits::MaybeSend + 'static,
600        Form: AsyncTryFrom<F, Error = OpenAIError>,
601        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
602    {
603        let request_factory = self.build_request_factory_with_form(
604            reqwest::Method::POST,
605            path,
606            form,
607            request_options,
608        )?;
609
610        self.execute_stream(request_factory).await
611    }
612
613    async fn execute_raw(
614        &self,
615        request_factory: HttpRequestFactory,
616    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
617        let response = self.execute_response(request_factory).await?;
618        read_response(response).await
619    }
620
621    async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
622    where
623        O: DeserializeOwned,
624    {
625        let (bytes, _headers) = self.execute_raw(request_factory).await?;
626
627        let response: O = serde_json::from_slice(bytes.as_ref())
628            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
629
630        Ok(response)
631    }
632
    /// Hand the factory to the configured executor (plain reqwest or a
    /// middleware service) and await its response.
    async fn execute_response(
        &self,
        request_factory: HttpRequestFactory,
    ) -> Result<Response, OpenAIError> {
        self.executor.execute(request_factory).await
    }
639
640    async fn execute_stream<O>(
641        &self,
642        request_factory: HttpRequestFactory,
643    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
644    where
645        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
646    {
647        let response = self.execute_response(request_factory).await?;
648        Ok(stream(response).await)
649    }
650
651    async fn execute_stream_mapped_raw_events<O>(
652        &self,
653        request_factory: HttpRequestFactory,
654        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
655            + crate::traits::MaybeSend
656            + 'static,
657    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
658    where
659        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
660    {
661        let response = self.execute_response(request_factory).await?;
662        Ok(stream_mapped_raw_events(response, event_mapper).await)
663    }
664
665    /// Make HTTP POST request to receive SSE
666    #[allow(unused)]
667    pub(crate) async fn post_stream<I, O>(
668        &self,
669        path: &str,
670        request: I,
671        request_options: &RequestOptions,
672    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
673    where
674        I: Serialize,
675        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
676    {
677        let request_factory = self.build_request_factory_with_json(
678            reqwest::Method::POST,
679            path,
680            request,
681            request_options,
682        )?;
683        // Stream setup is still request/response first. We only create the SSE
684        // stream after the HTTP layer has returned a response object.
685        self.execute_stream(request_factory).await
686    }
687
688    #[allow(unused)]
689    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
690        &self,
691        path: &str,
692        request: I,
693        request_options: &RequestOptions,
694        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
695            + crate::traits::MaybeSend
696            + 'static,
697    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
698    where
699        I: Serialize,
700        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
701    {
702        let request_factory = self.build_request_factory_with_json(
703            reqwest::Method::POST,
704            path,
705            request,
706            request_options,
707        )?;
708        self.execute_stream_mapped_raw_events(request_factory, event_mapper)
709            .await
710    }
711
712    /// Make HTTP GET request to receive SSE
713    #[allow(unused)]
714    pub(crate) async fn get_stream<O>(
715        &self,
716        path: &str,
717        request_options: &RequestOptions,
718    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
719    where
720        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
721    {
722        let request_factory =
723            self.build_request_factory(reqwest::Method::GET, path, request_options);
724        self.execute_stream(request_factory).await
725    }
726}
727
728async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
729    let status = response.status();
730    let headers = response.headers().clone();
731    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
732
733    if status.is_server_error() {
734        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
735        let message: String = String::from_utf8_lossy(&bytes).into_owned();
736        tracing::warn!("Server error: {status} - {message}");
737        return Err(OpenAIError::ApiError(ApiError {
738            message,
739            r#type: None,
740            param: None,
741            code: None,
742        }));
743    }
744
745    // Deserialize response body from either error object or actual response object
746    if !status.is_success() {
747        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
748            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
749
750        return Err(OpenAIError::ApiError(wrapped_error.error));
751    }
752
753    Ok((bytes, headers))
754}
755
756/// Request which responds with SSE.
757/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
758pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
759where
760    O: DeserializeOwned + crate::traits::MaybeSend + 'static,
761{
762    stream_mapped_raw_events(response, |event| {
763        serde_json::from_str::<O>(&event.data)
764            .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
765    })
766    .await
767}
768
#[cfg(target_family = "wasm")]
/// Adapt an HTTP response into a stream of mapped SSE events (wasm build).
///
/// wasm cannot spawn a background task, so the event stream is driven lazily
/// from inside `futures::stream::unfold` as the consumer polls.
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    // Non-2xx: surface the API error (parsed by `read_response`) as a
    // single-item error stream instead of opening an SSE stream.
    if !response.status().is_success() {
        return Box::pin(futures::stream::once(async move {
            match read_response(response).await {
                Ok(_) => Err(OpenAIError::InvalidArgument(
                    "stream request failed without an error body".into(),
                )),
                Err(error) => Err(error),
            }
        }));
    }

    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // Unfold state: (stream, mapper, finished). `finished` becomes true after
    // the `[DONE]` sentinel (or a stream error) has been yielded, so the next
    // poll terminates the stream.
    Box::pin(futures::stream::unfold(
        (event_stream, event_mapper, false),
        |(mut event_stream, event_mapper, finished)| async move {
            if finished {
                return None;
            }

            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        // Transport/parse failure: yield the error once, then end.
                        return Some((
                            Err(OpenAIError::StreamError(Box::new(
                                StreamError::EventStream(error.to_string()),
                            ))),
                            (event_stream, event_mapper, true),
                        ));
                    }
                    None => return None,
                };

                // `[DONE]` only marks the stream finished for the *next* poll.
                let done = event.data == "[DONE]";

                // Keepalive events carry no payload; skip without yielding.
                if event.event == "keepalive" {
                    continue;
                }

                // NOTE(review): the `[DONE]` sentinel is also passed through
                // `event_mapper`, so mappers are expected to handle it —
                // confirm against the JSON mapper in `stream`.
                let response = event_mapper(event);
                return Some((response, (event_stream, event_mapper, done)));
            }
        },
    ))
}
826
#[cfg(not(target_family = "wasm"))]
/// Adapt an HTTP response into a stream of mapped SSE events (native build).
///
/// A background tokio task pulls SSE events off the response and forwards
/// the mapped results through an unbounded channel; the returned stream is
/// the receiving end.
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        // Non-2xx: forward the API error parsed by `read_response` (if any)
        // and end the stream without opening an SSE stream.
        if !response.status().is_success() {
            if let Err(e) = read_response(response).await {
                let _ = tx.send(Err(e));
            }
            return;
        }
        let byte_stream = response
            .bytes_stream()
            .map(|r| r.map_err(std::io::Error::other));
        let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));

        while let Some(ev) = event_stream.next().await {
            let event = match ev {
                Ok(e) => e,
                Err(e) => {
                    // Transport/parse failure: forward once, then stop.
                    let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
                        StreamError::EventStream(e.to_string()),
                    ))));
                    break;
                }
            };
            // `[DONE]` terminates the loop *after* being forwarded below.
            let done = event.data == "[DONE]";

            // Keepalive events carry no payload; skip without forwarding.
            if event.event == "keepalive" {
                continue;
            }

            // NOTE(review): the `[DONE]` sentinel is also passed through
            // `event_mapper` before the break — mappers are expected to
            // handle it; confirm against the JSON mapper in `stream`.
            let response = event_mapper(event);

            // Receiver dropped: consumer is gone, stop producing.
            if tx.send(response).is_err() {
                break;
            }

            if done {
                break;
            }
        }
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
879
880#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
881mod tests {
882    use std::sync::{
883        atomic::{AtomicUsize, Ordering},
884        Arc,
885    };
886
887    use futures::StreamExt;
888    use http::Response as HttpResponse;
889    use serde_json::json;
890    use tower::{service_fn, ServiceBuilder};
891
892    use super::Client;
893    use crate::{
894        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
895        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
896    };
897
    // Verifies that a plain (non-streaming) request is executed through a
    // user-supplied tower service: the service sees the built request,
    // and its canned JSON body is deserialized normally.
    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        // Counter proves the middleware service was actually invoked.
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        // The factory builds the fully-resolved request here.
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        // Return a canned JSON response instead of hitting the network.
                        Ok::<reqwest::Response, OpenAIError>(
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
                                ))
                                .unwrap()
                                .into(),
                        )
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }
937
938    #[tokio::test]
939    async fn stream_requests_open_through_middleware_service() {
940        let request_count = Arc::new(AtomicUsize::new(0));
941        let service = {
942            let request_count = request_count.clone();
943            ServiceBuilder::new()
944                .concurrency_limit(1)
945                .service(service_fn(move |factory: HttpRequestFactory| {
946                    let request_count = request_count.clone();
947                    async move {
948                        let request = factory.build().await?;
949                        assert_eq!(request.url().path(), "/responses");
950                        request_count.fetch_add(1, Ordering::SeqCst);
951                        Ok::<reqwest::Response, OpenAIError>(
952                            HttpResponse::builder()
953                                .status(200)
954                                .header("content-type", "text/event-stream")
955                                .body(reqwest::Body::from(
956                                    "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
957                                ))
958                                .unwrap()
959                                .into(),
960                        )
961                    }
962                }))
963        };
964
965        let client = Client::with_config(
966            OpenAIConfig::new()
967                .with_api_base("http://example.test")
968                .with_api_key("test-key"),
969        )
970        .with_http_service(service);
971
972        let mut stream = client
973            .post_stream::<_, serde_json::Value>(
974                "/responses",
975                json!({ "stream": true }),
976                &RequestOptions::new(),
977            )
978            .await
979            .unwrap();
980
981        let first = stream.next().await.unwrap().unwrap();
982
983        assert_eq!(first, json!({ "ok": true }));
984        assert_eq!(request_count.load(Ordering::SeqCst), 1);
985    }
986
987    #[tokio::test]
988    async fn middleware_retry_policy_retries_429_responses() {
989        let request_count = Arc::new(AtomicUsize::new(0));
990        let service = {
991            let request_count = request_count.clone();
992            ServiceBuilder::new()
993                .retry(SimpleRetryPolicy::default())
994                .service(service_fn(move |factory: HttpRequestFactory| {
995                    let request_count = request_count.clone();
996                    async move {
997                        let request = factory.build().await?;
998                        assert_eq!(request.url().path(), "/models");
999                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);
1000
1001                        let response = if attempt == 0 {
1002                            HttpResponse::builder()
1003                                .status(429)
1004                                .header("content-type", "application/json")
1005                                .body(reqwest::Body::from(
1006                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
1007                                ))
1008                                .unwrap()
1009                        } else {
1010                            HttpResponse::builder()
1011                                .status(200)
1012                                .header("content-type", "application/json")
1013                                .body(reqwest::Body::from(
1014                                    r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
1015                                ))
1016                                .unwrap()
1017                        };
1018
1019                        Ok::<reqwest::Response, OpenAIError>(response.into())
1020                    }
1021                }))
1022        };
1023
1024        let client = Client::with_config(
1025            OpenAIConfig::new()
1026                .with_api_base("http://example.test")
1027                .with_api_key("test-key"),
1028        )
1029        .with_http_service(service);
1030
1031        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();
1032
1033        assert_eq!(value["data"][0]["id"], "retry-model");
1034        assert_eq!(request_count.load(Ordering::SeqCst), 2);
1035    }
1036
1037    #[derive(Clone)]
1038    struct RetryableMultipartInput {
1039        conversions: Arc<AtomicUsize>,
1040    }
1041
1042    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
1043        type Error = OpenAIError;
1044
1045        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
1046            value.conversions.fetch_add(1, Ordering::SeqCst);
1047            Ok(reqwest::multipart::Form::new().text("field", "value"))
1048        }
1049    }
1050
1051    #[tokio::test]
1052    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
1053        let request_count = Arc::new(AtomicUsize::new(0));
1054        let conversion_count = Arc::new(AtomicUsize::new(0));
1055
1056        let service = {
1057            let request_count = request_count.clone();
1058            ServiceBuilder::new()
1059                .retry(SimpleRetryPolicy::default())
1060                .service(service_fn(move |factory: HttpRequestFactory| {
1061                    let request_count = request_count.clone();
1062                    async move {
1063                        let request = factory.build().await?;
1064                        assert_eq!(request.method(), reqwest::Method::POST);
1065                        assert_eq!(request.url().path(), "/files");
1066                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);
1067
1068                        let response = if attempt == 0 {
1069                            HttpResponse::builder()
1070                                .status(429)
1071                                .header("content-type", "application/json")
1072                                .body(reqwest::Body::from(
1073                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
1074                                ))
1075                                .unwrap()
1076                        } else {
1077                            HttpResponse::builder()
1078                                .status(200)
1079                                .header("content-type", "application/json")
1080                                .body(reqwest::Body::from(r#"{"ok":true}"#))
1081                                .unwrap()
1082                        };
1083
1084                        Ok::<reqwest::Response, OpenAIError>(response.into())
1085                    }
1086                }))
1087        };
1088
1089        let client = Client::with_config(
1090            OpenAIConfig::new()
1091                .with_api_base("http://example.test")
1092                .with_api_key("test-key"),
1093        )
1094        .with_http_service(service);
1095
1096        let value: serde_json::Value = client
1097            .post_form(
1098                "/files",
1099                RetryableMultipartInput {
1100                    conversions: conversion_count.clone(),
1101                },
1102                &RequestOptions::new(),
1103            )
1104            .await
1105            .unwrap();
1106
1107        assert_eq!(value, json!({ "ok": true }));
1108        assert_eq!(request_count.load(Ordering::SeqCst), 2);
1109        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
1110    }
1111}