ogre_stream_ext/
finalization_callback_ext.rs

//! Adds a new Stream combinator that calls a single function once the Stream either ends or is
//! cancelled (dropped before it ends).
//!
//! See [crate::StreamWithFinalizationCallbacks] for a version that distinguishes between
//! completion and cancellation.
5
6use std::{
7    pin::Pin,
8    task::{Context, Poll},
9};
10use std::future::Future;
11use std::sync::atomic::{AtomicBool, Ordering};
12use futures::Stream;
13
14/// A Stream wrapper that calls a closure once the Stream either ends or is cancelled:
15pub struct StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
16where
17    S: Stream,
18    FinalizationFn: FnOnce() -> FnFut,
19    FnFut: Future<Output = ()> + Send + 'static,
20{
21    inner: S,
22    finalization_fn: Option<FinalizationFn>,
23    /// Needed to avoid double-firing when the stream ends gracefully in one thread and
24    /// is immediately dropped by another
25    finalized: AtomicBool,
26}
27
28impl<S, FinalizationFn, FnFut> StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
29where
30    S: Stream,
31    FinalizationFn: FnOnce() -> FnFut,
32    FnFut: Future<Output = ()> + Send + 'static,
33{
34    /// Construct a new wrapper that calls `finalization_fn` once when either:
35    ///  - `inner.poll_next()` returns `Ready(None)`, or
36    ///  - the wrapper is dropped before seeing `None`.
37    pub fn new(inner: S, finalization_fn: FinalizationFn) -> Self {
38        StreamWithFinalizationCallback {
39            inner,
40            finalization_fn: Some(finalization_fn),
41            finalized: AtomicBool::new(false),
42        }
43    }
44}
45
46impl<S, FinalizationFn, FnFut, T> Stream for StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
47where
48    S: Stream<Item = T> + Unpin,
49    FinalizationFn: FnOnce() -> FnFut,
50    FnFut: Future<Output = ()> + Send + 'static,
51{
52    type Item = T;
53
54    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
55        // SAFETY:
56        //   - We only call `get_unchecked_mut()` because we know:
57        //     1) `StreamWithCallbacks<…>` is structurally pinned (it won’t be moved after pinned),
58        //     2) We never move `inner` or the callback fields out of that pinned memory except by taking them (which is OK),
59        //     3) `inner: S` is `Unpin`, so it’s safe to create a `Pin<&mut S>` from `&mut inner`.
60        //
61        // In other words, after calling `get_unchecked_mut()`, we are free to mutate
62        // the fields through `this`, and then re-pin `inner` via `Pin::new(&mut this.inner)`.
63        let this: &mut Self = unsafe { self.get_unchecked_mut() };
64
65        match Pin::new(&mut this.inner).poll_next(cx) {
66            Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
67
68            Poll::Ready(None) => {
69                // The inner stream is done. If we have not yet called `finalization_fn`, do so now.
70                if let Some(finalization_fn) = this.finalization_fn.take() {
71                    let finalized = this.finalized.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(true)).unwrap_or_default();
72                    if !finalized {
73                        tokio::runtime::Handle::current().spawn(finalization_fn());
74                    }
75                }
76                Poll::Ready(None)
77            }
78
79            Poll::Pending => Poll::Pending,
80        }
81    }
82}
83
84impl<S, FinalizationFn, FnFut> Drop for StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
85where
86    S: Stream,
87    FinalizationFn: FnOnce() -> FnFut,
88    FnFut: Future<Output = ()> + Send + 'static,
89{
90    fn drop(&mut self) {
91        // If we never reached the “finished” state, that means the user dropped the stream early —
92        // so we call `finalization_fn` if it’s still present.
93        if let Some(finalization_fn) = self.finalization_fn.take() {
94            let finalized = self.finalized.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(true)).unwrap_or_default();
95            if !finalized {
96                let handle = tokio::runtime::Handle::current();
97                let _guard = handle.enter();
98                handle.spawn(finalization_fn());
99            }
100        }
101    }
102}
103
104/// The extension trait that adds `.on_complete_or_cancellation(...)` to all `Stream`s.
105///
106/// To use it, do:
107///    use ogre_stream_ext::StreamExtFinalizationCallback;
108///    use futures::StreamExt; // if you also want `.map()`, `.filter()`, etc.
109///
110/// Then:
111///    let lock = ...;
112///    mystream
113///       .map(|x| …)
114///       .on_complete_or_cancellation(move || future::ready(drop(lock)))
115pub trait StreamExtFinalizationCallback: Stream + Sized {
116    fn on_complete_or_cancellation<FinalizationFn, FnFut>(
117        self,
118        finalization_fn: FinalizationFn,
119    ) -> StreamWithFinalizationCallback<Self, FinalizationFn, FnFut>
120    where
121        FinalizationFn: FnOnce() -> FnFut,
122        FnFut: Future<Output=()> + Send + 'static,
123    {
124        StreamWithFinalizationCallback::new(self, finalization_fn)
125    }
126}
127
128impl<S: Stream> StreamExtFinalizationCallback for S {}