// azure_speech/stream_ext.rs

use crate::callback::Callback;
use core::fmt;
use core::pin::Pin;
use core::task::{ready, Context, Poll};
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::pin;
use tokio_stream::{Stream, StreamExt as _};

10pin_project! {
11/// Stream for the [`stop_after`](stop_after) method.
12#[must_use = "streams do nothing unless polled"]
13    pub struct StopAfter<St, F> {
14        #[pin]
15        stream: St,
16        predicate: F,
17        done: bool,
18    }
19}
20
21impl<St, F> StopAfter<St, F> {
22    pub(super) fn new(stream: St, predicate: F) -> Self {
23        Self {
24            stream,
25            predicate,
26            done: false,
27        }
28    }
29}
30
31impl<St, F> fmt::Debug for StopAfter<St, F>
32where
33    St: fmt::Debug,
34{
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        f.debug_struct("StopAfter")
37            .field("stream", &self.stream)
38            .field("done", &self.done)
39            .finish()
40    }
41}
42
43impl<St, F> Stream for StopAfter<St, F>
44where
45    St: Stream,
46    F: FnMut(&St::Item) -> bool,
47{
48    type Item = St::Item;
49
50    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
51        if !*self.as_mut().project().done {
52            self.as_mut().project().stream.poll_next(cx).map(|ready| {
53                let ready = ready.map(|item| {
54                    if (self.as_mut().project().predicate)(&item) {
55                        *self.as_mut().project().done = true;
56                    }
57                    item
58                });
59                ready
60            })
61        } else {
62            Poll::Ready(None)
63        }
64    }
65
66    fn size_hint(&self) -> (usize, Option<usize>) {
67        if self.done {
68            return (0, Some(0));
69        }
70
71        let (_, upper) = self.stream.size_hint();
72
73        (0, upper)
74    }
75}
76
77/// An extension trait for `Stream` that provides a variety of convenient combinator functions.
78pub trait StreamExt: Stream
79where
80    Self: 'static,
81{
82    /// Takes elements from this stream until the provided predicate resolves to `true`.
83    ///
84    /// This function operates similarly to `Iterator::take_while`, extracting elements from the
85    /// stream until the predicate `f` evaluates to `false`. Unlike `Iterator::take_while`, this function
86    /// also returns the last evaluated element for which the predicate was `true`, marking the stream as done afterwards.
87    /// Once an element causes the predicate to return false, the stream will consistently return that it is finished.
88    ///
89    /// # Examples
90    ///
91    /// Basic usage:
92    ///
93    /// ```
94    /// use tokio_stream::{self as stream, StreamExt as _};
95    /// use azure_speech::StreamExt;
96    ///
97    /// #[tokio::main]
98    /// async fn main() {
99    ///     
100    /// let mut stream = stream::iter(1..=5).stop_after(|&x| x >= 3);
101    ///
102    ///     assert_eq!(Some(1), stream.next().await);
103    ///     assert_eq!(Some(2), stream.next().await);
104    ///     assert_eq!(Some(3), stream.next().await);
105    ///     // Since 4 > 3, the stream is now considered done
106    ///     assert_eq!(None, stream.next().await);
107    /// }
108    /// ```
109    ///
110    /// This function is particularly useful when you need to process elements of a stream up to a certain point,
111    /// and then stop processing, including the element that caused the stop condition.
112    fn stop_after<F>(self, f: F) -> StopAfter<Self, F>
113    where
114        F: FnMut(&Self::Item) -> bool,
115        Self: Sized,
116    {
117        StopAfter::new(self, f)
118    }
119
120    /// Calls the provided callback for each item in the stream.
121    fn use_callbacks<C>(self, callback: C) -> impl Future<Output = ()>
122    where
123        Self: Sized + Send + Sync,
124        C: Callback<Item = Self::Item> + 'static,
125    {
126        async move {
127            let mut _self = pin!(self);
128            while let Some(event) = _self.next().await {
129                callback.on_event(event).await;
130            }
131        }
132    }
133}
134
135impl<St: ?Sized + 'static> StreamExt for St where St: Stream {}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::callback::Callback;
    use std::sync::{Arc, Mutex};

    /// Test callback that records every item it receives, in order.
    struct CollectCallback(Arc<Mutex<Vec<i32>>>);

    impl Callback for CollectCallback {
        type Item = i32;
        fn on_event(&self, item: Self::Item) -> impl Future<Output = ()> {
            let data = self.0.clone();
            async move {
                data.lock().unwrap().push(item);
            }
        }
    }

    #[tokio::test]
    async fn test_stop_after_includes_trigger() {
        let stream = tokio_stream::iter(1..=5).stop_after(|&x| x >= 3);
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_stop_after_first_item_triggers() {
        let collected: Vec<_> = tokio_stream::iter(1..=5)
            .stop_after(|_| true)
            .collect()
            .await;
        assert_eq!(collected, vec![1]);
    }

    #[tokio::test]
    async fn test_stop_after_predicate_never_matches() {
        let collected: Vec<_> = tokio_stream::iter(1..=3)
            .stop_after(|_| false)
            .collect()
            .await;
        assert_eq!(collected, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_stop_after_empty_stream() {
        let collected: Vec<i32> = tokio_stream::iter(Vec::<i32>::new())
            .stop_after(|_| true)
            .collect()
            .await;
        assert!(collected.is_empty());
    }

    #[tokio::test]
    async fn test_stop_after_stays_done() {
        let mut stream = tokio_stream::iter(1..=5).stop_after(|&x| x == 1);
        assert_eq!(stream.next().await, Some(1));
        // Once stopped, the adapter advertises and delivers no further items.
        assert_eq!(stream.size_hint(), (0, Some(0)));
        assert_eq!(stream.next().await, None);
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn test_use_callbacks_collects_items() {
        let store = Arc::new(Mutex::new(Vec::new()));
        let cb = CollectCallback(store.clone());
        tokio_stream::iter(1..=3).use_callbacks(cb).await;
        assert_eq!(*store.lock().unwrap(), vec![1, 2, 3]);
    }
}
169}