1use super::event::Event;
2use crate::error::Error;
3use futures_util::{Stream, StreamExt};
4
5pub trait SseStreamExt<T, E>: Stream<Item = Result<T, E>> + Sized
26where
27 E: Into<Error>,
28{
29 fn cast_events<F>(self, mut f: F) -> impl Stream<Item = Result<Event, Error>> + Send
34 where
35 F: FnMut(T) -> Result<Event, Error> + Send,
36 T: Send,
37 E: Send,
38 Self: Send,
39 {
40 self.map(move |result| result.map_err(Into::into).and_then(&mut f))
41 }
42}
43
44impl<S, T, E> SseStreamExt<T, E> for S
46where
47 S: Stream<Item = Result<T, E>> + Sized,
48 E: Into<Error>,
49{
50}
51
#[cfg(test)]
mod tests {
    // The glob import brings in the trait under test and `Event` (the same type
    // `cast_events` yields), so no `super::super::` paths are needed.
    use super::*;
    use crate::error::Error;
    use futures_util::StreamExt;

    /// Shared mapping closure: builds an event whose `data` is the source string.
    fn to_event(s: String) -> Result<Event, Error> {
        Event::new("id", "msg").map(|e| e.data(s))
    }

    #[tokio::test]
    async fn cast_events_maps_items() {
        let stream = futures_util::stream::iter(vec![
            Ok::<_, Error>("hello".to_string()),
            Ok("world".to_string()),
        ]);

        let results: Vec<Result<Event, Error>> = stream.cast_events(to_event).collect().await;

        // Assert on failures explicitly instead of filtering them away, so a
        // mapping error is reported as such rather than as a length mismatch.
        assert_eq!(results.len(), 2);
        assert!(
            results.iter().all(Result::is_ok),
            "every source item should map to an event"
        );
        let events: Vec<&Event> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
        assert_eq!(events[0].data.as_deref(), Some("hello"));
        assert_eq!(events[1].data.as_deref(), Some("world"));
    }

    #[tokio::test]
    async fn cast_events_propagates_source_errors() {
        let stream = futures_util::stream::iter(vec![
            Ok::<_, Error>("ok".to_string()),
            Err(Error::internal("boom")),
        ]);

        let results: Vec<Result<Event, Error>> = stream.cast_events(to_event).collect().await;

        assert_eq!(results.len(), 2);
        assert!(results[0].is_ok());
        assert!(results[1].is_err());
        assert_eq!(results[1].as_ref().unwrap_err().message(), "boom");
    }

    #[tokio::test]
    async fn cast_events_keeps_mapping_after_an_error() {
        // An error item must not terminate the stream; later items still map.
        let stream = futures_util::stream::iter(vec![
            Err::<String, _>(Error::internal("boom")),
            Ok("after".to_string()),
        ]);

        let results: Vec<Result<Event, Error>> = stream.cast_events(to_event).collect().await;

        assert_eq!(results.len(), 2);
        assert!(results[0].is_err());
        assert!(matches!(
            &results[1],
            Ok(event) if event.data.as_deref() == Some("after")
        ));
    }

    #[tokio::test]
    async fn cast_events_propagates_closure_errors() {
        let stream = futures_util::stream::iter(vec![Ok::<_, Error>("ok".to_string())]);

        let results: Vec<Result<Event, Error>> = stream
            .cast_events(|_| Err(Error::bad_request("bad")))
            .collect()
            .await;

        assert_eq!(results.len(), 1);
        assert!(results[0].is_err());
        assert_eq!(results[0].as_ref().unwrap_err().message(), "bad");
    }
}