external_buffered_stream/
lib.rs

1mod buffer;
2mod error;
3mod serde;
4mod runtime;
5
6pub use buffer::*;
7pub use error::*;
8pub use serde::*;
9
10use std::{
11    marker::PhantomData,
12    pin::Pin,
13    sync::{
14        Arc,
15        atomic::{AtomicBool, Ordering},
16    },
17    task::{Context, Poll},
18};
19
20use futures::{FutureExt, SinkExt, Stream, StreamExt, channel::mpsc};
21
22pub struct ExternalBufferedStream<T, B, S>
23where
24    T: Send,
25    B: ExternalBuffer<T>,
26    S: Stream<Item = T>,
27{
28    buffer: Arc<B>,
29    _source: PhantomData<S>,
30    notify: mpsc::UnboundedReceiver<()>,
31    stop_flag: Arc<AtomicBool>,
32}
33
34impl<T, B, S> ExternalBufferedStream<T, B, S>
35where
36    T: Send,
37    B: ExternalBuffer<T> + 'static,
38    S: Stream<Item = T> + Send + 'static,
39{
40    pub fn new(source: S, buffer: B) -> Self {
41        let source = Box::pin(source);
42
43        let buffer = Arc::new(buffer);
44        let buffer_clone = buffer.clone();
45
46        let (notify_tx, notify_rx) = mpsc::unbounded::<()>();
47
48        let stop_flag = Arc::new(AtomicBool::new(false));
49        let stop_flag_clone = stop_flag.clone();
50
51        let handle_source = async move {
52            let mut source = source;
53            let mut notify_tx = notify_tx;
54            while let Some(item) = source.next().await {
55                match buffer_clone.push(item).await {
56                    Ok(()) => match notify_tx.send(()).await {
57                        Ok(_) => {}
58                        Err(e) => {
59                            log::error!("Failed to notify: {:?}", e);
60                            break;
61                        }
62                    },
63                    Err(e) => {
64                        log::error!("Failed to push item to buffer: {:?}", e);
65                        break;
66                    }
67                }
68            }
69            log::info!("Source stream is ended");
70            stop_flag_clone.store(true, Ordering::SeqCst);
71            _ = notify_tx.send(())
72        };
73        runtime::spawn(handle_source);
74
75        ExternalBufferedStream {
76            buffer,
77            _source: PhantomData,
78            notify: notify_rx,
79            stop_flag,
80        }
81    }
82}
83
84impl<T, B, S> Stream for ExternalBufferedStream<T, B, S>
85where
86    T: Send,
87    B: ExternalBuffer<T> + 'static,
88    S: Stream<Item = T> + Send + 'static,
89{
90    type Item = T;
91
92    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
93        // S is PhantomData, so here is safe to get mut
94        let this = unsafe { self.get_unchecked_mut() };
95
96        loop {
97            match this.buffer.shift().poll_unpin(ctx) {
98                Poll::Ready(next) => match next {
99                    Ok(Some(item)) => return Poll::Ready(Some(item)),
100                    Ok(None) => {
101                        let mut wait = (&mut this.notify).next();
102                        match wait.poll_unpin(ctx) {
103                            Poll::Ready(_) => {
104                                if this.stop_flag.load(Ordering::SeqCst) {
105                                    break Poll::Ready(None);
106                                } else {
107                                    continue;
108                                }
109                            }
110                            Poll::Pending => return Poll::Pending,
111                        }
112                    }
113                    Err(err) => {
114                        log::error!("poll external buffer error: {}", err);
115                        return Poll::Ready(None);
116                    }
117                },
118                Poll::Pending => return Poll::Pending,
119            }
120        }
121    }
122}
123
/// Wraps `stream` in an [`ExternalBufferedStream`] backed by an `ExternalBufferSled` created
/// from `path`.
///
/// # Errors
///
/// Propagates any error returned by `ExternalBufferSled::new(path)`.
// NOTE(review): gating on the `default` feature removes this function for users who build with
// `default-features = false`, even if they want the sled buffer. A dedicated feature would be
// clearer, but that needs a Cargo.toml change.
#[cfg(feature = "default")]
pub fn create_external_buffered_stream<T, S, P>(
    stream: S,
    path: P,
) -> Result<ExternalBufferedStream<T, ExternalBufferSled, S>, Error>
where
    T: ExternalBufferSerde + Send + 'static,
    // `Sync` is deliberately not required: `ExternalBufferedStream::new` only needs `Send`.
    // Requiring it would reject common `!Sync` streams such as `futures::stream::BoxStream`.
    S: Stream<Item = T> + Send + 'static,
    P: AsRef<std::path::Path>,
{
    let buffer = ExternalBufferSled::new(path)?;
    Ok(ExternalBufferedStream::new(stream, buffer))
}
139
/// Wraps `stream` in an [`ExternalBufferedStream`] backed by a new `ExternalBufferQueue`.
///
/// # Errors
///
/// This function never returns `Err`.
#[cfg(feature = "queue")]
pub fn create_queued_stream<T, S>(
    stream: S,
) -> Result<ExternalBufferedStream<T, ExternalBufferQueue<T>, S>, Error>
where
    T: Ord + Send + 'static,
    // `Sync` is deliberately not required: `ExternalBufferedStream::new` only needs `Send`.
    // Requiring it would reject common `!Sync` streams such as `futures::stream::BoxStream`.
    S: Stream<Item = T> + Send + 'static,
{
    Ok(ExternalBufferedStream::new(stream, ExternalBufferQueue::new()))
}