// ogre_stream_ext/finalization_callback_ext.rs
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};

use futures::Stream;
14pub struct StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
16where
17 S: Stream,
18 FinalizationFn: FnOnce() -> FnFut,
19 FnFut: Future<Output = ()> + Send + 'static,
20{
21 inner: S,
22 finalization_fn: Option<FinalizationFn>,
23 finalized: AtomicBool,
26}
27
28impl<S, FinalizationFn, FnFut> StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
29where
30 S: Stream,
31 FinalizationFn: FnOnce() -> FnFut,
32 FnFut: Future<Output = ()> + Send + 'static,
33{
34 pub fn new(inner: S, finalization_fn: FinalizationFn) -> Self {
38 StreamWithFinalizationCallback {
39 inner,
40 finalization_fn: Some(finalization_fn),
41 finalized: AtomicBool::new(false),
42 }
43 }
44}
45
46impl<S, FinalizationFn, FnFut, T> Stream for StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
47where
48 S: Stream<Item = T> + Unpin,
49 FinalizationFn: FnOnce() -> FnFut,
50 FnFut: Future<Output = ()> + Send + 'static,
51{
52 type Item = T;
53
54 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
55 let this: &mut Self = unsafe { self.get_unchecked_mut() };
64
65 match Pin::new(&mut this.inner).poll_next(cx) {
66 Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
67
68 Poll::Ready(None) => {
69 if let Some(finalization_fn) = this.finalization_fn.take() {
71 let finalized = this.finalized.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(true)).unwrap_or_default();
72 if !finalized {
73 tokio::runtime::Handle::current().spawn(finalization_fn());
74 }
75 }
76 Poll::Ready(None)
77 }
78
79 Poll::Pending => Poll::Pending,
80 }
81 }
82}
83
84impl<S, FinalizationFn, FnFut> Drop for StreamWithFinalizationCallback<S, FinalizationFn, FnFut>
85where
86 S: Stream,
87 FinalizationFn: FnOnce() -> FnFut,
88 FnFut: Future<Output = ()> + Send + 'static,
89{
90 fn drop(&mut self) {
91 if let Some(finalization_fn) = self.finalization_fn.take() {
94 let finalized = self.finalized.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(true)).unwrap_or_default();
95 if !finalized {
96 let handle = tokio::runtime::Handle::current();
97 let _guard = handle.enter();
98 handle.spawn(finalization_fn());
99 }
100 }
101 }
102}
103
104pub trait StreamExtFinalizationCallback: Stream + Sized {
116 fn on_complete_or_cancellation<FinalizationFn, FnFut>(
117 self,
118 finalization_fn: FinalizationFn,
119 ) -> StreamWithFinalizationCallback<Self, FinalizationFn, FnFut>
120 where
121 FinalizationFn: FnOnce() -> FnFut,
122 FnFut: Future<Output=()> + Send + 'static,
123 {
124 StreamWithFinalizationCallback::new(self, finalization_fn)
125 }
126}
127
128impl<S: Stream> StreamExtFinalizationCallback for S {}