entelix_graph/
finalizing_stream.rs1use std::pin::Pin;
27use std::task::{Context, Poll};
28
29use futures::Stream;
30use pin_project_lite::pin_project;
31
32pin_project! {
33 pub struct FinalizingStream<St, F>
38 where
39 F: FnOnce(),
40 {
41 #[pin]
42 inner: St,
43 done: bool,
44 finalize: Option<F>,
45 }
46
47 impl<St, F> PinnedDrop for FinalizingStream<St, F>
48 where
49 F: FnOnce(),
50 {
51 fn drop(this: Pin<&mut Self>) {
52 let proj = this.project();
53 if !*proj.done && let Some(f) = proj.finalize.take() {
54 f();
55 }
56 }
57 }
58}
59
60impl<St, F> FinalizingStream<St, F>
61where
62 F: FnOnce(),
63{
64 pub const fn new(inner: St, finalize: F) -> Self {
66 Self {
67 inner,
68 done: false,
69 finalize: Some(finalize),
70 }
71 }
72}
73
74impl<St, F> Stream for FinalizingStream<St, F>
75where
76 St: Stream,
77 F: FnOnce(),
78{
79 type Item = St::Item;
80
81 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
82 let proj = self.project();
83 if *proj.done {
91 return Poll::Ready(None);
92 }
93 match proj.inner.poll_next(cx) {
94 Poll::Ready(None) => {
95 *proj.done = true;
96 proj.finalize.take();
98 Poll::Ready(None)
99 }
100 other => other,
101 }
102 }
103
104 fn size_hint(&self) -> (usize, Option<usize>) {
105 if self.done {
106 (0, Some(0))
107 } else {
108 self.inner.size_hint()
109 }
110 }
111}
112
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::StreamExt;
    use futures::stream;

    use super::*;

    /// Builds a finalizer that increments `counter` when it runs, so tests can
    /// observe whether (and how often) it fired.
    fn finalizer(counter: &Arc<AtomicUsize>) -> impl FnOnce() + use<> {
        let counter = Arc::clone(counter);
        move || {
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn finalizer_does_not_fire_on_normal_completion() {
        let counter = Arc::new(AtomicUsize::new(0));
        let inner = stream::iter(vec![1, 2, 3]);
        let mut s = FinalizingStream::new(inner, finalizer(&counter));
        while s.next().await.is_some() {}
        drop(s);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn finalizer_fires_on_early_drop() {
        let counter = Arc::new(AtomicUsize::new(0));
        let inner = stream::iter(0..1000);
        let mut s = FinalizingStream::new(inner, finalizer(&counter));
        let _ = s.next().await;
        // Yielding items must not run the finalizer; only the drop should.
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        drop(s);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn finalizer_fires_on_drop_without_polling() {
        let counter = Arc::new(AtomicUsize::new(0));
        let inner = stream::iter(0..10);
        let s = FinalizingStream::new(inner, finalizer(&counter));
        drop(s);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn poll_after_completion_returns_none_without_polling_inner() {
        // Yields one item, then `None`, then panics if polled any further.
        struct PanicAfterNone {
            yielded: bool,
            ended: bool,
        }
        impl Stream for PanicAfterNone {
            type Item = u32;
            fn poll_next(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                if !self.yielded {
                    self.yielded = true;
                    Poll::Ready(Some(7))
                } else if !self.ended {
                    self.ended = true;
                    Poll::Ready(None)
                } else {
                    panic!("inner stream polled past completion");
                }
            }
        }

        let counter = Arc::new(AtomicUsize::new(0));
        let mut s = FinalizingStream::new(
            PanicAfterNone {
                yielded: false,
                ended: false,
            },
            finalizer(&counter),
        );
        assert_eq!(s.next().await, Some(7));
        assert_eq!(s.next().await, None);
        // These would panic if they reached the inner stream.
        assert_eq!(s.next().await, None);
        assert_eq!(s.next().await, None);
        drop(s);
        // Completion suppresses the finalizer even after extra polls.
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn finalizer_runs_at_most_once() {
        let counter = Arc::new(AtomicUsize::new(0));
        let inner = stream::iter(vec![1]);
        let mut s = FinalizingStream::new(inner, finalizer(&counter));
        // First poll yields the item; second observes `None` and discards the
        // finalizer without running it.
        let _ = s.next().await;
        let _ = s.next().await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        drop(s);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "completion suppresses finalizer"
        );
    }
}