1use super::Sink;
2use crate::unfold_state::UnfoldState;
3use core::fmt;
4use core::future::Future;
5use core::pin::Pin;
6use core::task::{Context, Poll};
7
/// Sink returned by the [`unfold`] function.
///
/// Every item sent into the sink is handed, together with the current state
/// value, to the stored closure; the future the closure returns is then driven
/// by the sink's poll methods and yields either the next state value or an
/// error.
#[must_use = "sinks do nothing unless polled"]
pub struct Unfold<T, F, Fut> {
    // Closure called as `function(state_value, item)` for each item sent.
    function: F,
    // Current phase of the sink: an idle state value, the in-flight future
    // produced by `function`, or empty.
    state: UnfoldState<T, Fut>,
}
14
15impl<T, F, Fut> Unfold<T, F, Fut> {
16 unsafe fn project(self: Pin<&mut Self>) -> (&mut F, Pin<&mut UnfoldState<T, Fut>>) {
25 let this = self.get_unchecked_mut();
26 (&mut this.function, Pin::new_unchecked(&mut this.state))
27 }
28}
29
30impl<T, F, Fut> fmt::Debug for Unfold<T, F, Fut>
31where
32 T: fmt::Debug,
33 Fut: fmt::Debug,
34{
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 f.debug_struct("Unfold")
37 .field("state", &self.state)
38 .finish()
39 }
40}
41
42pub fn unfold<T, F, Fut, Item, E>(init: T, function: F) -> Unfold<T, F, Fut>
73where
74 F: FnMut(T, Item) -> Fut,
75 Fut: Future<Output = Result<T, E>>,
76{
77 Unfold {
78 function,
79 state: UnfoldState::Value { value: init },
80 }
81}
82
83impl<T, F, Fut, Item, E> Sink<Item> for Unfold<T, F, Fut>
84where
85 E: core::error::Error,
86 F: FnMut(T, Item) -> Fut,
87 Fut: Future<Output = Result<T, E>>,
88{
89 type Error = E;
90
91 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 self.poll_flush(cx)
93 }
94
95 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
96 let (function, state_pin) = unsafe { self.project() };
97 let state_mut = unsafe { state_pin.get_unchecked_mut() };
98
99 let value = match state_mut {
100 UnfoldState::Value { .. } => {
101 if let UnfoldState::Value { value } = unsafe { core::ptr::read(state_mut) } {
102 value
103 } else {
104 unreachable!()
105 }
106 }
107 _ => panic!("start_send called without poll_ready being called first"),
108 };
109
110 let future = function(value, item);
111 unsafe { core::ptr::write(state_mut, UnfoldState::Future { future }) };
112 Ok(())
113 }
114
115 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 let (_, state_pin) = unsafe { self.project() };
117 let state_mut = unsafe { state_pin.get_unchecked_mut() };
118
119 if let UnfoldState::Future { future } = state_mut {
120 let result = match unsafe { Pin::new_unchecked(future) }.poll(cx) {
121 Poll::Ready(result) => result,
122 Poll::Pending => return Poll::Pending,
123 };
124
125 let _old_state = unsafe { core::ptr::read(state_mut) };
128
129 match result {
130 Ok(state) => {
131 unsafe { core::ptr::write(state_mut, UnfoldState::Value { value: state }) };
132 Poll::Ready(Ok(()))
133 }
134 Err(err) => {
135 unsafe { core::ptr::write(state_mut, UnfoldState::Empty) };
136 Poll::Ready(Err(err))
137 }
138 }
139 } else {
140 Poll::Ready(Ok(()))
141 }
142 }
143
144 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145 self.poll_flush(cx)
146 }
147}