Skip to main content

async_openai/
client.rs

1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::{stream::StreamExt, Stream};
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8#[cfg(not(target_family = "wasm"))]
9use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Serialize};
11
12#[cfg(not(target_family = "wasm"))]
13use crate::error::StreamError;
14use crate::{
15    config::{Config, OpenAIConfig},
16    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
21#[cfg(feature = "administration")]
22use crate::admin::Admin;
23#[cfg(feature = "chatkit")]
24use crate::chatkit::Chatkit;
25#[cfg(feature = "file")]
26use crate::file::Files;
27#[cfg(feature = "image")]
28use crate::image::Images;
29#[cfg(feature = "moderation")]
30use crate::moderation::Moderations;
31#[cfg(feature = "assistant")]
32#[allow(deprecated)]
33use crate::Assistants;
34#[cfg(feature = "audio")]
35use crate::Audio;
36#[cfg(feature = "batch")]
37use crate::Batches;
38#[cfg(feature = "chat-completion")]
39use crate::Chat;
40#[cfg(feature = "completions")]
41use crate::Completions;
42#[cfg(feature = "container")]
43use crate::Containers;
44#[cfg(feature = "responses")]
45use crate::Conversations;
46#[cfg(feature = "embedding")]
47use crate::Embeddings;
48#[cfg(feature = "evals")]
49use crate::Evals;
50#[cfg(feature = "finetuning")]
51use crate::FineTuning;
52#[cfg(feature = "model")]
53use crate::Models;
54#[cfg(feature = "realtime")]
55use crate::Realtime;
56#[cfg(feature = "responses")]
57use crate::Responses;
58#[cfg(feature = "assistant")]
59#[allow(deprecated)]
60use crate::Threads;
61#[cfg(feature = "upload")]
62use crate::Uploads;
63#[cfg(feature = "vectorstore")]
64use crate::VectorStores;
65#[cfg(feature = "video")]
66use crate::Videos;
67
68#[derive(Debug, Clone)]
69/// Client is a container for config, backoff and http_client
70/// used to make API calls.
71pub struct Client<C: Config> {
72    http_client: reqwest::Client,
73    config: C,
74    #[cfg(not(target_family = "wasm"))]
75    backoff: backoff::ExponentialBackoff,
76}
77
78impl<C: Config> Default for Client<C>
79where
80    C: Default,
81{
82    fn default() -> Self {
83        Self {
84            http_client: reqwest::Client::new(),
85            config: C::default(),
86            #[cfg(not(target_family = "wasm"))]
87            backoff: Default::default(),
88        }
89    }
90}
91
92impl Client<OpenAIConfig> {
93    /// Client with default [OpenAIConfig]
94    pub fn new() -> Self {
95        Self::default()
96    }
97}
98
99impl<C: Config> Client<C> {
100    /// Create client with a custom HTTP client, OpenAI config, and backoff.
101    #[cfg(not(target_family = "wasm"))]
102    pub fn build(
103        http_client: reqwest::Client,
104        config: C,
105        backoff: backoff::ExponentialBackoff,
106    ) -> Self {
107        Self {
108            http_client,
109            config,
110            backoff,
111        }
112    }
113
114    /// Create client with a custom HTTP client and config (WASM version without backoff).
115    #[cfg(target_family = "wasm")]
116    pub fn build(http_client: reqwest::Client, config: C) -> Self {
117        Self {
118            http_client,
119            config,
120        }
121    }
122
123    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
124    pub fn with_config(config: C) -> Self {
125        Self {
126            http_client: reqwest::Client::new(),
127            config,
128            #[cfg(not(target_family = "wasm"))]
129            backoff: Default::default(),
130        }
131    }
132
133    /// Provide your own [client] to make HTTP requests with.
134    ///
135    /// [client]: reqwest::Client
136    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
137        self.http_client = http_client;
138        self
139    }
140
141    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
142    #[cfg(not(target_family = "wasm"))]
143    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
144        self.backoff = backoff;
145        self
146    }
147
148    // API groups
149
150    /// To call [Models] group related APIs using this client.
151    #[cfg(feature = "model")]
152    pub fn models(&self) -> Models<'_, C> {
153        Models::new(self)
154    }
155
156    /// To call [Completions] group related APIs using this client.
157    #[cfg(feature = "completions")]
158    pub fn completions(&self) -> Completions<'_, C> {
159        Completions::new(self)
160    }
161
162    /// To call [Chat] group related APIs using this client.
163    #[cfg(feature = "chat-completion")]
164    pub fn chat(&self) -> Chat<'_, C> {
165        Chat::new(self)
166    }
167
168    /// To call [Images] group related APIs using this client.
169    #[cfg(feature = "image")]
170    pub fn images(&self) -> Images<'_, C> {
171        Images::new(self)
172    }
173
174    /// To call [Moderations] group related APIs using this client.
175    #[cfg(feature = "moderation")]
176    pub fn moderations(&self) -> Moderations<'_, C> {
177        Moderations::new(self)
178    }
179
180    /// To call [Files] group related APIs using this client.
181    #[cfg(feature = "file")]
182    pub fn files(&self) -> Files<'_, C> {
183        Files::new(self)
184    }
185
186    /// To call [Uploads] group related APIs using this client.
187    #[cfg(feature = "upload")]
188    pub fn uploads(&self) -> Uploads<'_, C> {
189        Uploads::new(self)
190    }
191
192    /// To call [FineTuning] group related APIs using this client.
193    #[cfg(feature = "finetuning")]
194    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
195        FineTuning::new(self)
196    }
197
198    /// To call [Embeddings] group related APIs using this client.
199    #[cfg(feature = "embedding")]
200    pub fn embeddings(&self) -> Embeddings<'_, C> {
201        Embeddings::new(self)
202    }
203
204    /// To call [Audio] group related APIs using this client.
205    #[cfg(feature = "audio")]
206    pub fn audio(&self) -> Audio<'_, C> {
207        Audio::new(self)
208    }
209
210    /// To call [Videos] group related APIs using this client.
211    #[cfg(feature = "video")]
212    pub fn videos(&self) -> Videos<'_, C> {
213        Videos::new(self)
214    }
215
216    /// To call [Assistants] group related APIs using this client.
217    #[cfg(feature = "assistant")]
218    #[deprecated(
219        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
220    )]
221    #[allow(deprecated)]
222    pub fn assistants(&self) -> Assistants<'_, C> {
223        Assistants::new(self)
224    }
225
226    /// To call [Threads] group related APIs using this client.
227    #[cfg(feature = "assistant")]
228    #[deprecated(
229        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
230    )]
231    #[allow(deprecated)]
232    pub fn threads(&self) -> Threads<'_, C> {
233        Threads::new(self)
234    }
235
236    /// To call [VectorStores] group related APIs using this client.
237    #[cfg(feature = "vectorstore")]
238    pub fn vector_stores(&self) -> VectorStores<'_, C> {
239        VectorStores::new(self)
240    }
241
242    /// To call [Batches] group related APIs using this client.
243    #[cfg(feature = "batch")]
244    pub fn batches(&self) -> Batches<'_, C> {
245        Batches::new(self)
246    }
247
248    /// To call [Admin] group related APIs using this client.
249    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
250    #[cfg(feature = "administration")]
251    pub fn admin(&self) -> Admin<'_, C> {
252        Admin::new(self)
253    }
254
255    /// To call [Responses] group related APIs using this client.
256    #[cfg(feature = "responses")]
257    pub fn responses(&self) -> Responses<'_, C> {
258        Responses::new(self)
259    }
260
261    /// To call [Conversations] group related APIs using this client.
262    #[cfg(feature = "responses")]
263    pub fn conversations(&self) -> Conversations<'_, C> {
264        Conversations::new(self)
265    }
266
267    /// To call [Containers] group related APIs using this client.
268    #[cfg(feature = "container")]
269    pub fn containers(&self) -> Containers<'_, C> {
270        Containers::new(self)
271    }
272
273    /// To call [Evals] group related APIs using this client.
274    #[cfg(feature = "evals")]
275    pub fn evals(&self) -> Evals<'_, C> {
276        Evals::new(self)
277    }
278
279    #[cfg(feature = "chatkit")]
280    pub fn chatkit(&self) -> Chatkit<'_, C> {
281        Chatkit::new(self)
282    }
283
284    /// To call [Realtime] group related APIs using this client.
285    #[cfg(feature = "realtime")]
286    pub fn realtime(&self) -> Realtime<'_, C> {
287        Realtime::new(self)
288    }
289
290    pub fn config(&self) -> &C {
291        &self.config
292    }
293
294    /// Helper function to build a request builder with common configuration
295    fn build_request_builder(
296        &self,
297        method: reqwest::Method,
298        path: &str,
299        request_options: &RequestOptions,
300    ) -> reqwest::RequestBuilder {
301        let mut request_builder = if let Some(path) = request_options.path() {
302            self.http_client
303                .request(method, self.config.url(path.as_str()))
304        } else {
305            self.http_client.request(method, self.config.url(path))
306        };
307
308        request_builder = request_builder
309            .query(&self.config.query())
310            .headers(self.config.headers());
311
312        if let Some(headers) = request_options.headers() {
313            request_builder = request_builder.headers(headers.clone());
314        }
315
316        if !request_options.query().is_empty() {
317            request_builder = request_builder.query(request_options.query());
318        }
319
320        request_builder
321    }
322
323    /// Make a GET request to {path} and deserialize the response body
324    #[allow(unused)]
325    pub(crate) async fn get<O>(
326        &self,
327        path: &str,
328        request_options: &RequestOptions,
329    ) -> Result<O, OpenAIError>
330    where
331        O: DeserializeOwned,
332    {
333        let request_maker = || async {
334            Ok(self
335                .build_request_builder(reqwest::Method::GET, path, request_options)
336                .build()?)
337        };
338
339        self.execute(request_maker).await
340    }
341
342    /// Make a DELETE request to {path} and deserialize the response body
343    #[allow(unused)]
344    pub(crate) async fn delete<O>(
345        &self,
346        path: &str,
347        request_options: &RequestOptions,
348    ) -> Result<O, OpenAIError>
349    where
350        O: DeserializeOwned,
351    {
352        let request_maker = || async {
353            Ok(self
354                .build_request_builder(reqwest::Method::DELETE, path, request_options)
355                .build()?)
356        };
357
358        self.execute(request_maker).await
359    }
360
361    /// Make a GET request to {path} and return the response body
362    #[allow(unused)]
363    pub(crate) async fn get_raw(
364        &self,
365        path: &str,
366        request_options: &RequestOptions,
367    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
368        let request_maker = || async {
369            Ok(self
370                .build_request_builder(reqwest::Method::GET, path, request_options)
371                .build()?)
372        };
373
374        self.execute_raw(request_maker).await
375    }
376
377    /// Make a POST request to {path} and return the response body
378    #[allow(unused)]
379    pub(crate) async fn post_raw<I>(
380        &self,
381        path: &str,
382        request: I,
383        request_options: &RequestOptions,
384    ) -> Result<(Bytes, HeaderMap), OpenAIError>
385    where
386        I: Serialize,
387    {
388        let request_maker = || async {
389            Ok(self
390                .build_request_builder(reqwest::Method::POST, path, request_options)
391                .json(&request)
392                .build()?)
393        };
394
395        self.execute_raw(request_maker).await
396    }
397
398    /// Make a POST request to {path} and deserialize the response body
399    #[allow(unused)]
400    pub(crate) async fn post<I, O>(
401        &self,
402        path: &str,
403        request: I,
404        request_options: &RequestOptions,
405    ) -> Result<O, OpenAIError>
406    where
407        I: Serialize,
408        O: DeserializeOwned,
409    {
410        let request_maker = || async {
411            Ok(self
412                .build_request_builder(reqwest::Method::POST, path, request_options)
413                .json(&request)
414                .build()?)
415        };
416
417        self.execute(request_maker).await
418    }
419
420    /// POST a form at {path} and return the response body
421    #[allow(unused)]
422    pub(crate) async fn post_form_raw<F>(
423        &self,
424        path: &str,
425        form: F,
426        request_options: &RequestOptions,
427    ) -> Result<(Bytes, HeaderMap), OpenAIError>
428    where
429        Form: AsyncTryFrom<F, Error = OpenAIError>,
430        F: Clone,
431    {
432        let request_maker = || async {
433            Ok(self
434                .build_request_builder(reqwest::Method::POST, path, request_options)
435                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
436                .build()?)
437        };
438
439        self.execute_raw(request_maker).await
440    }
441
442    /// POST a form at {path} and deserialize the response body
443    #[allow(unused)]
444    pub(crate) async fn post_form<O, F>(
445        &self,
446        path: &str,
447        form: F,
448        request_options: &RequestOptions,
449    ) -> Result<O, OpenAIError>
450    where
451        O: DeserializeOwned,
452        Form: AsyncTryFrom<F, Error = OpenAIError>,
453        F: Clone,
454    {
455        let request_maker = || async {
456            Ok(self
457                .build_request_builder(reqwest::Method::POST, path, request_options)
458                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
459                .build()?)
460        };
461
462        self.execute(request_maker).await
463    }
464
465    #[allow(unused)]
466    #[cfg(not(target_family = "wasm"))]
467    pub(crate) async fn post_form_stream<O, F>(
468        &self,
469        path: &str,
470        form: F,
471        request_options: &RequestOptions,
472    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
473    where
474        F: Clone,
475        Form: AsyncTryFrom<F, Error = OpenAIError>,
476        O: DeserializeOwned + std::marker::Send + 'static,
477    {
478        // Build and execute request manually since multipart::Form is not Clone
479        // and .eventsource() requires cloneability
480        let request_builder = self
481            .build_request_builder(reqwest::Method::POST, path, request_options)
482            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
483
484        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
485
486        // Check for error status
487        if !response.status().is_success() {
488            return Err(read_response(response).await.unwrap_err());
489        }
490
491        // Convert response body to EventSource stream
492        let stream = response
493            .bytes_stream()
494            .map(|result| result.map_err(std::io::Error::other));
495        let event_stream = eventsource_stream::EventStream::new(stream);
496
497        // Convert EventSource stream to our expected format
498        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
499
500        tokio::spawn(async move {
501            use futures::StreamExt;
502            let mut event_stream = std::pin::pin!(event_stream);
503
504            while let Some(event_result) = event_stream.next().await {
505                match event_result {
506                    Err(e) => {
507                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
508                            StreamError::EventStream(e.to_string()),
509                        )))) {
510                            break;
511                        }
512                    }
513                    Ok(event) => {
514                        // eventsource_stream::Event is a struct with data field
515                        if event.data == "[DONE]" {
516                            break;
517                        }
518
519                        let response = match serde_json::from_str::<O>(&event.data) {
520                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
521                            Ok(output) => Ok(output),
522                        };
523
524                        if let Err(_e) = tx.send(response) {
525                            break;
526                        }
527                    }
528                }
529            }
530        });
531
532        Ok(Box::pin(
533            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
534        ))
535    }
536
537    /// Execute a HTTP request and retry on rate limit (non-WASM version with backoff)
538    ///
539    /// request_maker serves one purpose: to be able to create request again
540    /// to retry API call after getting rate limited. request_maker is async because
541    /// reqwest::multipart::Form is created by async calls to read files for uploads.
542    #[cfg(not(target_family = "wasm"))]
543    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
544    where
545        M: Fn() -> Fut,
546        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
547    {
548        let client = self.http_client.clone();
549
550        backoff::future::retry(self.backoff.clone(), || async {
551            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
552            let response = client
553                .execute(request)
554                .await
555                .map_err(OpenAIError::Reqwest)
556                .map_err(backoff::Error::Permanent)?;
557
558            let status = response.status();
559
560            match read_response(response).await {
561                Ok((bytes, headers)) => Ok((bytes, headers)),
562                Err(e) => {
563                    match e {
564                        OpenAIError::ApiError(api_error) => {
565                            if status.is_server_error() {
566                                Err(backoff::Error::Transient {
567                                    err: OpenAIError::ApiError(api_error),
568                                    retry_after: None,
569                                })
570                            } else if status.as_u16() == 429
571                                && api_error.r#type != Some("insufficient_quota".to_string())
572                            {
573                                // Rate limited retry...
574                                tracing::warn!("Rate limited: {}", api_error.message);
575                                Err(backoff::Error::Transient {
576                                    err: OpenAIError::ApiError(api_error),
577                                    retry_after: None,
578                                })
579                            } else {
580                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
581                            }
582                        }
583                        _ => Err(backoff::Error::Permanent(e)),
584                    }
585                }
586            }
587        })
588        .await
589    }
590
591    /// Execute a HTTP request (WASM version - single attempt, no retry)
592    #[cfg(target_family = "wasm")]
593    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
594    where
595        M: Fn() -> Fut,
596        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
597    {
598        let request = request_maker().await?;
599        let response = self
600            .http_client
601            .execute(request)
602            .await
603            .map_err(OpenAIError::Reqwest)?;
604
605        read_response(response).await
606    }
607
608    /// Execute a HTTP request and retry on rate limit
609    ///
610    /// request_maker serves one purpose: to be able to create request again
611    /// to retry API call after getting rate limited. request_maker is async because
612    /// reqwest::multipart::Form is created by async calls to read files for uploads.
613    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
614    where
615        O: DeserializeOwned,
616        M: Fn() -> Fut,
617        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
618    {
619        let (bytes, _headers) = self.execute_raw(request_maker).await?;
620
621        let response: O = serde_json::from_slice(bytes.as_ref())
622            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
623
624        Ok(response)
625    }
626
627    /// Make HTTP POST request to receive SSE
628    #[allow(unused)]
629    #[cfg(not(target_family = "wasm"))]
630    pub(crate) async fn post_stream<I, O>(
631        &self,
632        path: &str,
633        request: I,
634        request_options: &RequestOptions,
635    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
636    where
637        I: Serialize,
638        O: DeserializeOwned + std::marker::Send + 'static,
639    {
640        let request_builder = self
641            .build_request_builder(reqwest::Method::POST, path, request_options)
642            .json(&request);
643
644        let event_source = request_builder.eventsource().unwrap();
645
646        stream(event_source).await
647    }
648
649    #[allow(unused)]
650    #[cfg(not(target_family = "wasm"))]
651    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
652        &self,
653        path: &str,
654        request: I,
655        request_options: &RequestOptions,
656        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
657    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
658    where
659        I: Serialize,
660        O: DeserializeOwned + std::marker::Send + 'static,
661    {
662        let request_builder = self
663            .build_request_builder(reqwest::Method::POST, path, request_options)
664            .json(&request);
665
666        let event_source = request_builder.eventsource().unwrap();
667
668        stream_mapped_raw_events(event_source, event_mapper).await
669    }
670
671    /// Make HTTP GET request to receive SSE
672    #[allow(unused)]
673    #[cfg(not(target_family = "wasm"))]
674    pub(crate) async fn get_stream<O>(
675        &self,
676        path: &str,
677        request_options: &RequestOptions,
678    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
679    where
680        O: DeserializeOwned + std::marker::Send + 'static,
681    {
682        let request_builder =
683            self.build_request_builder(reqwest::Method::GET, path, request_options);
684
685        let event_source = request_builder.eventsource().unwrap();
686
687        stream(event_source).await
688    }
689}
690
691async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
692    let status = response.status();
693    let headers = response.headers().clone();
694    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
695
696    if status.is_server_error() {
697        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
698        let message: String = String::from_utf8_lossy(&bytes).into_owned();
699        tracing::warn!("Server error: {status} - {message}");
700        return Err(OpenAIError::ApiError(ApiError {
701            message,
702            r#type: None,
703            param: None,
704            code: None,
705        }));
706    }
707
708    // Deserialize response body from either error object or actual response object
709    if !status.is_success() {
710        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
711            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
712
713        return Err(OpenAIError::ApiError(wrapped_error.error));
714    }
715
716    Ok((bytes, headers))
717}
718
719#[cfg(not(target_family = "wasm"))]
720async fn map_stream_error(value: EventSourceError) -> OpenAIError {
721    match value {
722        EventSourceError::InvalidStatusCode(status_code, response) => {
723            read_response(response).await.expect_err(&format!(
724                "Unreachable because read_response returns err when status_code {status_code} is invalid"
725            ))
726        }
727        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
728    }
729}
730
731/// Request which responds with SSE.
732/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
733#[cfg(not(target_family = "wasm"))]
734pub(crate) async fn stream<O>(
735    mut event_source: EventSource,
736) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
737where
738    O: DeserializeOwned + std::marker::Send + 'static,
739{
740    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
741
742    tokio::spawn(async move {
743        while let Some(ev) = event_source.next().await {
744            match ev {
745                Err(e) => {
746                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
747                    // https://github.com/64bit/async-openai/issues/456
748                    match &e {
749                        EventSourceError::StreamEnded => {
750                            break;
751                        }
752                        _ => {
753                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
754                                // rx dropped
755                                break;
756                            }
757                        }
758                    }
759                }
760                Ok(event) => match event {
761                    Event::Message(message) => {
762                        if message.data == "[DONE]" {
763                            break;
764                        }
765
766                        let response = match serde_json::from_str::<O>(&message.data) {
767                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
768                            Ok(output) => Ok(output),
769                        };
770
771                        if let Err(_e) = tx.send(response) {
772                            // rx dropped
773                            break;
774                        }
775                    }
776                    Event::Open => continue,
777                },
778            }
779        }
780
781        event_source.close();
782    });
783
784    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
785}
786
787#[cfg(not(target_family = "wasm"))]
788pub(crate) async fn stream_mapped_raw_events<O>(
789    mut event_source: EventSource,
790    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
791) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
792where
793    O: DeserializeOwned + std::marker::Send + 'static,
794{
795    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
796
797    tokio::spawn(async move {
798        while let Some(ev) = event_source.next().await {
799            match ev {
800                Err(e) => {
801                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
802                    // https://github.com/64bit/async-openai/issues/456
803                    match &e {
804                        EventSourceError::StreamEnded => {
805                            break;
806                        }
807                        _ => {
808                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
809                                // rx dropped
810                                break;
811                            }
812                        }
813                    }
814                }
815                Ok(event) => match event {
816                    Event::Message(message) => {
817                        let mut done = false;
818
819                        if message.data == "[DONE]" {
820                            done = true;
821                        }
822
823                        let response = event_mapper(message);
824
825                        if let Err(_e) = tx.send(response) {
826                            // rx dropped
827                            break;
828                        }
829
830                        if done {
831                            break;
832                        }
833                    }
834                    Event::Open => continue,
835                },
836            }
837        }
838
839        event_source.close();
840    });
841
842    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
843}