// async_openai/client.rs

1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::{stream::StreamExt, Stream};
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8#[cfg(not(target_family = "wasm"))]
9use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Serialize};
11
12#[cfg(not(target_family = "wasm"))]
13use crate::error::StreamError;
14use crate::{
15    config::{Config, OpenAIConfig},
16    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
21#[cfg(feature = "administration")]
22use crate::admin::Admin;
23#[cfg(feature = "chatkit")]
24use crate::chatkit::Chatkit;
25#[cfg(feature = "file")]
26use crate::file::Files;
27#[cfg(feature = "image")]
28use crate::image::Images;
29#[cfg(feature = "moderation")]
30use crate::moderation::Moderations;
31#[cfg(feature = "assistant")]
32use crate::Assistants;
33#[cfg(feature = "audio")]
34use crate::Audio;
35#[cfg(feature = "batch")]
36use crate::Batches;
37#[cfg(feature = "chat-completion")]
38use crate::Chat;
39#[cfg(feature = "completions")]
40use crate::Completions;
41#[cfg(feature = "container")]
42use crate::Containers;
43#[cfg(feature = "responses")]
44use crate::Conversations;
45#[cfg(feature = "embedding")]
46use crate::Embeddings;
47#[cfg(feature = "evals")]
48use crate::Evals;
49#[cfg(feature = "finetuning")]
50use crate::FineTuning;
51#[cfg(feature = "model")]
52use crate::Models;
53#[cfg(feature = "realtime")]
54use crate::Realtime;
55#[cfg(feature = "responses")]
56use crate::Responses;
57#[cfg(feature = "assistant")]
58use crate::Threads;
59#[cfg(feature = "upload")]
60use crate::Uploads;
61#[cfg(feature = "vectorstore")]
62use crate::VectorStores;
63#[cfg(feature = "video")]
64use crate::Videos;
65
66#[derive(Debug, Clone)]
67/// Client is a container for config, backoff and http_client
68/// used to make API calls.
69pub struct Client<C: Config> {
70    http_client: reqwest::Client,
71    config: C,
72    #[cfg(not(target_family = "wasm"))]
73    backoff: backoff::ExponentialBackoff,
74}
75
76impl<C: Config> Default for Client<C>
77where
78    C: Default,
79{
80    fn default() -> Self {
81        Self {
82            http_client: reqwest::Client::new(),
83            config: C::default(),
84            #[cfg(not(target_family = "wasm"))]
85            backoff: Default::default(),
86        }
87    }
88}
89
90impl Client<OpenAIConfig> {
91    /// Client with default [OpenAIConfig]
92    pub fn new() -> Self {
93        Self::default()
94    }
95}
96
97impl<C: Config> Client<C> {
98    /// Create client with a custom HTTP client, OpenAI config, and backoff.
99    #[cfg(not(target_family = "wasm"))]
100    pub fn build(
101        http_client: reqwest::Client,
102        config: C,
103        backoff: backoff::ExponentialBackoff,
104    ) -> Self {
105        Self {
106            http_client,
107            config,
108            backoff,
109        }
110    }
111
112    /// Create client with a custom HTTP client and config (WASM version without backoff).
113    #[cfg(target_family = "wasm")]
114    pub fn build(http_client: reqwest::Client, config: C) -> Self {
115        Self {
116            http_client,
117            config,
118        }
119    }
120
121    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
122    pub fn with_config(config: C) -> Self {
123        Self {
124            http_client: reqwest::Client::new(),
125            config,
126            #[cfg(not(target_family = "wasm"))]
127            backoff: Default::default(),
128        }
129    }
130
131    /// Provide your own [client] to make HTTP requests with.
132    ///
133    /// [client]: reqwest::Client
134    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
135        self.http_client = http_client;
136        self
137    }
138
139    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
140    #[cfg(not(target_family = "wasm"))]
141    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
142        self.backoff = backoff;
143        self
144    }
145
146    // API groups
147
148    /// To call [Models] group related APIs using this client.
149    #[cfg(feature = "model")]
150    pub fn models(&self) -> Models<'_, C> {
151        Models::new(self)
152    }
153
154    /// To call [Completions] group related APIs using this client.
155    #[cfg(feature = "completions")]
156    pub fn completions(&self) -> Completions<'_, C> {
157        Completions::new(self)
158    }
159
160    /// To call [Chat] group related APIs using this client.
161    #[cfg(feature = "chat-completion")]
162    pub fn chat(&self) -> Chat<'_, C> {
163        Chat::new(self)
164    }
165
166    /// To call [Images] group related APIs using this client.
167    #[cfg(feature = "image")]
168    pub fn images(&self) -> Images<'_, C> {
169        Images::new(self)
170    }
171
172    /// To call [Moderations] group related APIs using this client.
173    #[cfg(feature = "moderation")]
174    pub fn moderations(&self) -> Moderations<'_, C> {
175        Moderations::new(self)
176    }
177
178    /// To call [Files] group related APIs using this client.
179    #[cfg(feature = "file")]
180    pub fn files(&self) -> Files<'_, C> {
181        Files::new(self)
182    }
183
184    /// To call [Uploads] group related APIs using this client.
185    #[cfg(feature = "upload")]
186    pub fn uploads(&self) -> Uploads<'_, C> {
187        Uploads::new(self)
188    }
189
190    /// To call [FineTuning] group related APIs using this client.
191    #[cfg(feature = "finetuning")]
192    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
193        FineTuning::new(self)
194    }
195
196    /// To call [Embeddings] group related APIs using this client.
197    #[cfg(feature = "embedding")]
198    pub fn embeddings(&self) -> Embeddings<'_, C> {
199        Embeddings::new(self)
200    }
201
202    /// To call [Audio] group related APIs using this client.
203    #[cfg(feature = "audio")]
204    pub fn audio(&self) -> Audio<'_, C> {
205        Audio::new(self)
206    }
207
208    /// To call [Videos] group related APIs using this client.
209    #[cfg(feature = "video")]
210    pub fn videos(&self) -> Videos<'_, C> {
211        Videos::new(self)
212    }
213
214    /// To call [Assistants] group related APIs using this client.
215    #[cfg(feature = "assistant")]
216    pub fn assistants(&self) -> Assistants<'_, C> {
217        Assistants::new(self)
218    }
219
220    /// To call [Threads] group related APIs using this client.
221    #[cfg(feature = "assistant")]
222    pub fn threads(&self) -> Threads<'_, C> {
223        Threads::new(self)
224    }
225
226    /// To call [VectorStores] group related APIs using this client.
227    #[cfg(feature = "vectorstore")]
228    pub fn vector_stores(&self) -> VectorStores<'_, C> {
229        VectorStores::new(self)
230    }
231
232    /// To call [Batches] group related APIs using this client.
233    #[cfg(feature = "batch")]
234    pub fn batches(&self) -> Batches<'_, C> {
235        Batches::new(self)
236    }
237
238    /// To call [Admin] group related APIs using this client.
239    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
240    #[cfg(feature = "administration")]
241    pub fn admin(&self) -> Admin<'_, C> {
242        Admin::new(self)
243    }
244
245    /// To call [Responses] group related APIs using this client.
246    #[cfg(feature = "responses")]
247    pub fn responses(&self) -> Responses<'_, C> {
248        Responses::new(self)
249    }
250
251    /// To call [Conversations] group related APIs using this client.
252    #[cfg(feature = "responses")]
253    pub fn conversations(&self) -> Conversations<'_, C> {
254        Conversations::new(self)
255    }
256
257    /// To call [Containers] group related APIs using this client.
258    #[cfg(feature = "container")]
259    pub fn containers(&self) -> Containers<'_, C> {
260        Containers::new(self)
261    }
262
263    /// To call [Evals] group related APIs using this client.
264    #[cfg(feature = "evals")]
265    pub fn evals(&self) -> Evals<'_, C> {
266        Evals::new(self)
267    }
268
269    #[cfg(feature = "chatkit")]
270    pub fn chatkit(&self) -> Chatkit<'_, C> {
271        Chatkit::new(self)
272    }
273
274    /// To call [Realtime] group related APIs using this client.
275    #[cfg(feature = "realtime")]
276    pub fn realtime(&self) -> Realtime<'_, C> {
277        Realtime::new(self)
278    }
279
280    pub fn config(&self) -> &C {
281        &self.config
282    }
283
284    /// Helper function to build a request builder with common configuration
285    fn build_request_builder(
286        &self,
287        method: reqwest::Method,
288        path: &str,
289        request_options: &RequestOptions,
290    ) -> reqwest::RequestBuilder {
291        let mut request_builder = if let Some(path) = request_options.path() {
292            self.http_client
293                .request(method, self.config.url(path.as_str()))
294        } else {
295            self.http_client.request(method, self.config.url(path))
296        };
297
298        request_builder = request_builder
299            .query(&self.config.query())
300            .headers(self.config.headers());
301
302        if let Some(headers) = request_options.headers() {
303            request_builder = request_builder.headers(headers.clone());
304        }
305
306        if !request_options.query().is_empty() {
307            request_builder = request_builder.query(request_options.query());
308        }
309
310        request_builder
311    }
312
313    /// Make a GET request to {path} and deserialize the response body
314    #[allow(unused)]
315    pub(crate) async fn get<O>(
316        &self,
317        path: &str,
318        request_options: &RequestOptions,
319    ) -> Result<O, OpenAIError>
320    where
321        O: DeserializeOwned,
322    {
323        let request_maker = || async {
324            Ok(self
325                .build_request_builder(reqwest::Method::GET, path, request_options)
326                .build()?)
327        };
328
329        self.execute(request_maker).await
330    }
331
332    /// Make a DELETE request to {path} and deserialize the response body
333    #[allow(unused)]
334    pub(crate) async fn delete<O>(
335        &self,
336        path: &str,
337        request_options: &RequestOptions,
338    ) -> Result<O, OpenAIError>
339    where
340        O: DeserializeOwned,
341    {
342        let request_maker = || async {
343            Ok(self
344                .build_request_builder(reqwest::Method::DELETE, path, request_options)
345                .build()?)
346        };
347
348        self.execute(request_maker).await
349    }
350
351    /// Make a GET request to {path} and return the response body
352    #[allow(unused)]
353    pub(crate) async fn get_raw(
354        &self,
355        path: &str,
356        request_options: &RequestOptions,
357    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
358        let request_maker = || async {
359            Ok(self
360                .build_request_builder(reqwest::Method::GET, path, request_options)
361                .build()?)
362        };
363
364        self.execute_raw(request_maker).await
365    }
366
367    /// Make a POST request to {path} and return the response body
368    #[allow(unused)]
369    pub(crate) async fn post_raw<I>(
370        &self,
371        path: &str,
372        request: I,
373        request_options: &RequestOptions,
374    ) -> Result<(Bytes, HeaderMap), OpenAIError>
375    where
376        I: Serialize,
377    {
378        let request_maker = || async {
379            Ok(self
380                .build_request_builder(reqwest::Method::POST, path, request_options)
381                .json(&request)
382                .build()?)
383        };
384
385        self.execute_raw(request_maker).await
386    }
387
388    /// Make a POST request to {path} and deserialize the response body
389    #[allow(unused)]
390    pub(crate) async fn post<I, O>(
391        &self,
392        path: &str,
393        request: I,
394        request_options: &RequestOptions,
395    ) -> Result<O, OpenAIError>
396    where
397        I: Serialize,
398        O: DeserializeOwned,
399    {
400        let request_maker = || async {
401            Ok(self
402                .build_request_builder(reqwest::Method::POST, path, request_options)
403                .json(&request)
404                .build()?)
405        };
406
407        self.execute(request_maker).await
408    }
409
410    /// POST a form at {path} and return the response body
411    #[allow(unused)]
412    pub(crate) async fn post_form_raw<F>(
413        &self,
414        path: &str,
415        form: F,
416        request_options: &RequestOptions,
417    ) -> Result<(Bytes, HeaderMap), OpenAIError>
418    where
419        Form: AsyncTryFrom<F, Error = OpenAIError>,
420        F: Clone,
421    {
422        let request_maker = || async {
423            Ok(self
424                .build_request_builder(reqwest::Method::POST, path, request_options)
425                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
426                .build()?)
427        };
428
429        self.execute_raw(request_maker).await
430    }
431
432    /// POST a form at {path} and deserialize the response body
433    #[allow(unused)]
434    pub(crate) async fn post_form<O, F>(
435        &self,
436        path: &str,
437        form: F,
438        request_options: &RequestOptions,
439    ) -> Result<O, OpenAIError>
440    where
441        O: DeserializeOwned,
442        Form: AsyncTryFrom<F, Error = OpenAIError>,
443        F: Clone,
444    {
445        let request_maker = || async {
446            Ok(self
447                .build_request_builder(reqwest::Method::POST, path, request_options)
448                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
449                .build()?)
450        };
451
452        self.execute(request_maker).await
453    }
454
455    #[allow(unused)]
456    #[cfg(not(target_family = "wasm"))]
457    pub(crate) async fn post_form_stream<O, F>(
458        &self,
459        path: &str,
460        form: F,
461        request_options: &RequestOptions,
462    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
463    where
464        F: Clone,
465        Form: AsyncTryFrom<F, Error = OpenAIError>,
466        O: DeserializeOwned + std::marker::Send + 'static,
467    {
468        // Build and execute request manually since multipart::Form is not Clone
469        // and .eventsource() requires cloneability
470        let request_builder = self
471            .build_request_builder(reqwest::Method::POST, path, request_options)
472            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
473
474        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
475
476        // Check for error status
477        if !response.status().is_success() {
478            return Err(read_response(response).await.unwrap_err());
479        }
480
481        // Convert response body to EventSource stream
482        let stream = response
483            .bytes_stream()
484            .map(|result| result.map_err(std::io::Error::other));
485        let event_stream = eventsource_stream::EventStream::new(stream);
486
487        // Convert EventSource stream to our expected format
488        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
489
490        tokio::spawn(async move {
491            use futures::StreamExt;
492            let mut event_stream = std::pin::pin!(event_stream);
493
494            while let Some(event_result) = event_stream.next().await {
495                match event_result {
496                    Err(e) => {
497                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
498                            StreamError::EventStream(e.to_string()),
499                        )))) {
500                            break;
501                        }
502                    }
503                    Ok(event) => {
504                        // eventsource_stream::Event is a struct with data field
505                        if event.data == "[DONE]" {
506                            break;
507                        }
508
509                        let response = match serde_json::from_str::<O>(&event.data) {
510                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
511                            Ok(output) => Ok(output),
512                        };
513
514                        if let Err(_e) = tx.send(response) {
515                            break;
516                        }
517                    }
518                }
519            }
520        });
521
522        Ok(Box::pin(
523            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
524        ))
525    }
526
527    /// Execute a HTTP request and retry on rate limit (non-WASM version with backoff)
528    ///
529    /// request_maker serves one purpose: to be able to create request again
530    /// to retry API call after getting rate limited. request_maker is async because
531    /// reqwest::multipart::Form is created by async calls to read files for uploads.
532    #[cfg(not(target_family = "wasm"))]
533    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
534    where
535        M: Fn() -> Fut,
536        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
537    {
538        let client = self.http_client.clone();
539
540        backoff::future::retry(self.backoff.clone(), || async {
541            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
542            let response = client
543                .execute(request)
544                .await
545                .map_err(OpenAIError::Reqwest)
546                .map_err(backoff::Error::Permanent)?;
547
548            let status = response.status();
549
550            match read_response(response).await {
551                Ok((bytes, headers)) => Ok((bytes, headers)),
552                Err(e) => {
553                    match e {
554                        OpenAIError::ApiError(api_error) => {
555                            if status.is_server_error() {
556                                Err(backoff::Error::Transient {
557                                    err: OpenAIError::ApiError(api_error),
558                                    retry_after: None,
559                                })
560                            } else if status.as_u16() == 429
561                                && api_error.r#type != Some("insufficient_quota".to_string())
562                            {
563                                // Rate limited retry...
564                                tracing::warn!("Rate limited: {}", api_error.message);
565                                Err(backoff::Error::Transient {
566                                    err: OpenAIError::ApiError(api_error),
567                                    retry_after: None,
568                                })
569                            } else {
570                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
571                            }
572                        }
573                        _ => Err(backoff::Error::Permanent(e)),
574                    }
575                }
576            }
577        })
578        .await
579    }
580
581    /// Execute a HTTP request (WASM version - single attempt, no retry)
582    #[cfg(target_family = "wasm")]
583    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
584    where
585        M: Fn() -> Fut,
586        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
587    {
588        let request = request_maker().await?;
589        let response = self
590            .http_client
591            .execute(request)
592            .await
593            .map_err(OpenAIError::Reqwest)?;
594
595        read_response(response).await
596    }
597
598    /// Execute a HTTP request and retry on rate limit
599    ///
600    /// request_maker serves one purpose: to be able to create request again
601    /// to retry API call after getting rate limited. request_maker is async because
602    /// reqwest::multipart::Form is created by async calls to read files for uploads.
603    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
604    where
605        O: DeserializeOwned,
606        M: Fn() -> Fut,
607        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
608    {
609        let (bytes, _headers) = self.execute_raw(request_maker).await?;
610
611        let response: O = serde_json::from_slice(bytes.as_ref())
612            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
613
614        Ok(response)
615    }
616
617    /// Make HTTP POST request to receive SSE
618    #[allow(unused)]
619    #[cfg(not(target_family = "wasm"))]
620    pub(crate) async fn post_stream<I, O>(
621        &self,
622        path: &str,
623        request: I,
624        request_options: &RequestOptions,
625    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
626    where
627        I: Serialize,
628        O: DeserializeOwned + std::marker::Send + 'static,
629    {
630        let request_builder = self
631            .build_request_builder(reqwest::Method::POST, path, request_options)
632            .json(&request);
633
634        let event_source = request_builder.eventsource().unwrap();
635
636        stream(event_source).await
637    }
638
639    #[allow(unused)]
640    #[cfg(not(target_family = "wasm"))]
641    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
642        &self,
643        path: &str,
644        request: I,
645        request_options: &RequestOptions,
646        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
647    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
648    where
649        I: Serialize,
650        O: DeserializeOwned + std::marker::Send + 'static,
651    {
652        let request_builder = self
653            .build_request_builder(reqwest::Method::POST, path, request_options)
654            .json(&request);
655
656        let event_source = request_builder.eventsource().unwrap();
657
658        stream_mapped_raw_events(event_source, event_mapper).await
659    }
660
661    /// Make HTTP GET request to receive SSE
662    #[allow(unused)]
663    #[cfg(not(target_family = "wasm"))]
664    pub(crate) async fn get_stream<O>(
665        &self,
666        path: &str,
667        request_options: &RequestOptions,
668    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
669    where
670        O: DeserializeOwned + std::marker::Send + 'static,
671    {
672        let request_builder =
673            self.build_request_builder(reqwest::Method::GET, path, request_options);
674
675        let event_source = request_builder.eventsource().unwrap();
676
677        stream(event_source).await
678    }
679}
680
681async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
682    let status = response.status();
683    let headers = response.headers().clone();
684    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
685
686    if status.is_server_error() {
687        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
688        let message: String = String::from_utf8_lossy(&bytes).into_owned();
689        tracing::warn!("Server error: {status} - {message}");
690        return Err(OpenAIError::ApiError(ApiError {
691            message,
692            r#type: None,
693            param: None,
694            code: None,
695        }));
696    }
697
698    // Deserialize response body from either error object or actual response object
699    if !status.is_success() {
700        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
701            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
702
703        return Err(OpenAIError::ApiError(wrapped_error.error));
704    }
705
706    Ok((bytes, headers))
707}
708
709#[cfg(not(target_family = "wasm"))]
710async fn map_stream_error(value: EventSourceError) -> OpenAIError {
711    match value {
712        EventSourceError::InvalidStatusCode(status_code, response) => {
713            read_response(response).await.expect_err(&format!(
714                "Unreachable because read_response returns err when status_code {status_code} is invalid"
715            ))
716        }
717        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
718    }
719}
720
721/// Request which responds with SSE.
722/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
723#[cfg(not(target_family = "wasm"))]
724pub(crate) async fn stream<O>(
725    mut event_source: EventSource,
726) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
727where
728    O: DeserializeOwned + std::marker::Send + 'static,
729{
730    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
731
732    tokio::spawn(async move {
733        while let Some(ev) = event_source.next().await {
734            match ev {
735                Err(e) => {
736                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
737                    // https://github.com/64bit/async-openai/issues/456
738                    match &e {
739                        EventSourceError::StreamEnded => {
740                            break;
741                        }
742                        _ => {
743                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
744                                // rx dropped
745                                break;
746                            }
747                        }
748                    }
749                }
750                Ok(event) => match event {
751                    Event::Message(message) => {
752                        if message.data == "[DONE]" {
753                            break;
754                        }
755
756                        let response = match serde_json::from_str::<O>(&message.data) {
757                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
758                            Ok(output) => Ok(output),
759                        };
760
761                        if let Err(_e) = tx.send(response) {
762                            // rx dropped
763                            break;
764                        }
765                    }
766                    Event::Open => continue,
767                },
768            }
769        }
770
771        event_source.close();
772    });
773
774    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
775}
776
777#[cfg(not(target_family = "wasm"))]
778pub(crate) async fn stream_mapped_raw_events<O>(
779    mut event_source: EventSource,
780    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
781) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
782where
783    O: DeserializeOwned + std::marker::Send + 'static,
784{
785    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
786
787    tokio::spawn(async move {
788        while let Some(ev) = event_source.next().await {
789            match ev {
790                Err(e) => {
791                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
792                    // https://github.com/64bit/async-openai/issues/456
793                    match &e {
794                        EventSourceError::StreamEnded => {
795                            break;
796                        }
797                        _ => {
798                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
799                                // rx dropped
800                                break;
801                            }
802                        }
803                    }
804                }
805                Ok(event) => match event {
806                    Event::Message(message) => {
807                        let mut done = false;
808
809                        if message.data == "[DONE]" {
810                            done = true;
811                        }
812
813                        let response = event_mapper(message);
814
815                        if let Err(_e) = tx.send(response) {
816                            // rx dropped
817                            break;
818                        }
819
820                        if done {
821                            break;
822                        }
823                    }
824                    Event::Open => continue,
825                },
826            }
827        }
828
829        event_source.close();
830    });
831
832    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
833}