async_sink/ext/
unfold.rs

1use super::Sink;
2use crate::unfold_state::UnfoldState;
3use core::fmt;
4use core::future::Future;
5use core::pin::Pin;
6use core::task::{Context, Poll};
7
/// Sink for the [`unfold`] function.
///
/// Holds the user-supplied closure together with the current state, which is
/// one of `UnfoldState::Value` (idle, holding the accumulator),
/// `UnfoldState::Future` (an in-flight future returned by the closure) or
/// `UnfoldState::Empty` (after the closure's future resolved to an error).
#[must_use = "sinks do nothing unless polled"]
pub struct Unfold<T, F, Fut> {
    // Closure called by `start_send` with the current value and the new item;
    // never treated as pinned.
    function: F,
    // Current state; structurally pinned, because it may contain the
    // closure's future while that future is being polled.
    state: UnfoldState<T, Fut>,
}
14
15impl<T, F, Fut> Unfold<T, F, Fut> {
16    // Helper to get a mutable reference to the `function` field and a
17    // pinned mutable reference to the `state` field.
18    //
19    // # Safety
20    //
21    // This is `unsafe` because it returns a `Pin` to one of the fields of the
22    // struct. The caller must ensure that they don't move the struct while this
23    // `Pin` is in use.
24    unsafe fn project(self: Pin<&mut Self>) -> (&mut F, Pin<&mut UnfoldState<T, Fut>>) {
25        let this = self.get_unchecked_mut();
26        (&mut this.function, Pin::new_unchecked(&mut this.state))
27    }
28}
29
30impl<T, F, Fut> fmt::Debug for Unfold<T, F, Fut>
31where
32    T: fmt::Debug,
33    Fut: fmt::Debug,
34{
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        f.debug_struct("Unfold")
37            .field("state", &self.state)
38            .finish()
39    }
40}
41
42/// Create a sink from a function which processes one item at a time.
43///
44/// # Examples
45///
46/// ```
47/// use core::pin::pin;
48/// use async_sink::SinkExt;
49/// use tokio::sync::Mutex;
50/// use std::sync::Arc;
51///
52/// #[tokio::main]
53/// async fn main() {
54/// let output: Arc<Mutex<Vec<usize>>> = Arc::new(tokio::sync::Mutex::new(Vec::new()));
55///
56/// let unfold = async_sink::unfold(0, |mut sum, i: usize| {
57///    let cb_output = output.clone();
58///    async move {
59///         sum += i;
60///         cb_output.clone().lock().await.push(sum);
61///         Ok::<_, core::convert::Infallible>(sum)
62///    }
63/// });
64/// let mut unfold = pin!(unfold);
65/// let input: [usize; 3] = [5, 15, 35];
66/// assert!(unfold.send_all(&mut tokio_stream::iter(input.iter().copied().map(|i| Ok(i)))).await.is_ok());
67/// assert_eq!(output.lock().await.as_slice(),input.iter().scan(0, |state, &x|
68///   { *state += x; Some(*state) }).collect::<Vec<usize>>().as_slice()
69/// );
70/// }
71/// ```
72pub fn unfold<T, F, Fut, Item, E>(init: T, function: F) -> Unfold<T, F, Fut>
73where
74    F: FnMut(T, Item) -> Fut,
75    Fut: Future<Output = Result<T, E>>,
76{
77    Unfold {
78        function,
79        state: UnfoldState::Value { value: init },
80    }
81}
82
83impl<T, F, Fut, Item, E> Sink<Item> for Unfold<T, F, Fut>
84where
85    E: core::error::Error,
86    F: FnMut(T, Item) -> Fut,
87    Fut: Future<Output = Result<T, E>>,
88{
89    type Error = E;
90
91    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92        self.poll_flush(cx)
93    }
94
95    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
96        let (function, state_pin) = unsafe { self.project() };
97        let state_mut = unsafe { state_pin.get_unchecked_mut() };
98
99        let value = match state_mut {
100            UnfoldState::Value { .. } => {
101                if let UnfoldState::Value { value } = unsafe { core::ptr::read(state_mut) } {
102                    value
103                } else {
104                    unreachable!()
105                }
106            }
107            _ => panic!("start_send called without poll_ready being called first"),
108        };
109
110        let future = function(value, item);
111        unsafe { core::ptr::write(state_mut, UnfoldState::Future { future }) };
112        Ok(())
113    }
114
115    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116        let (_, state_pin) = unsafe { self.project() };
117        let state_mut = unsafe { state_pin.get_unchecked_mut() };
118
119        if let UnfoldState::Future { future } = state_mut {
120            let result = match unsafe { Pin::new_unchecked(future) }.poll(cx) {
121                Poll::Ready(result) => result,
122                Poll::Pending => return Poll::Pending,
123            };
124
125            // The future is finished, so we can replace the state.
126            // First, destruct the old state.
127            let _old_state = unsafe { core::ptr::read(state_mut) };
128
129            match result {
130                Ok(state) => {
131                    unsafe { core::ptr::write(state_mut, UnfoldState::Value { value: state }) };
132                    Poll::Ready(Ok(()))
133                }
134                Err(err) => {
135                    unsafe { core::ptr::write(state_mut, UnfoldState::Empty) };
136                    Poll::Ready(Err(err))
137                }
138            }
139        } else {
140            Poll::Ready(Ok(()))
141        }
142    }
143
144    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145        self.poll_flush(cx)
146    }
147}