ntex_util/channel/
bstream.rs

1//! Bytes stream
2use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{Stream, task::LocalWaker};
9
/// Default limit for buffered data in a stream: 32 KiB.
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Readiness of a bytes stream, as seen from the sending side.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Status {
    /// The stream is closed; no more data should be fed
    Eof,
    /// The stream can accept more data
    Ready,
    /// The receiving side is gone
    Dropped,
}
22
23/// Create bytes stream.
24///
25/// This method construct two objects responsible for bytes stream
26/// generation.
27pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
28    let inner = Rc::new(Inner::new(false));
29
30    (
31        Sender {
32            inner: Rc::downgrade(&inner),
33        },
34        Receiver { inner },
35    )
36}
37
38/// Create closed bytes stream.
39///
40/// This method construct two objects responsible for bytes stream
41/// generation.
42pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
43    let inner = Rc::new(Inner::new(true));
44
45    (
46        Sender {
47            inner: Rc::downgrade(&inner),
48        },
49        Receiver { inner },
50    )
51}
52
53/// Create empty stream
54pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
55    let rx = Receiver {
56        inner: Rc::new(Inner::new(true)),
57    };
58    if let Some(data) = data {
59        rx.put(data);
60    }
61    rx
62}
63
64/// Buffered stream of byte chunks
65///
66/// Payload stores chunks in a vector. Chunks can be received with
67/// `.read()` method.
68#[derive(Debug)]
69pub struct Receiver<E> {
70    inner: Rc<Inner<E>>,
71}
72
73impl<E> Receiver<E> {
74    /// Set size of stream buffer
75    ///
76    /// By default buffer size is set to 32Kb
77    #[inline]
78    pub fn max_buffer_size(&self, size: usize) {
79        self.inner.max_buffer_size.set(size);
80    }
81
82    /// Put unused data back to stream
83    #[inline]
84    pub fn put(&self, data: Bytes) {
85        self.inner.unread_data(data);
86    }
87
88    #[inline]
89    /// Check if stream is eof
90    pub fn is_eof(&self) -> bool {
91        self.inner.flags.get().contains(Flags::EOF)
92    }
93
94    #[inline]
95    /// Read next available bytes chunk
96    pub async fn read(&self) -> Option<Result<Bytes, E>> {
97        poll_fn(|cx| self.poll_read(cx)).await
98    }
99
100    #[inline]
101    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
102        if let Some(data) = self.inner.get_data() {
103            Poll::Ready(Some(Ok(data)))
104        } else if let Some(err) = self.inner.err.take() {
105            self.inner.insert_flag(Flags::EOF);
106            Poll::Ready(Some(Err(err)))
107        } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
108            Poll::Ready(None)
109        } else {
110            self.inner.recv_task.register(cx.waker());
111            Poll::Pending
112        }
113    }
114}
115
116impl<E> Stream for Receiver<E> {
117    type Item = Result<Bytes, E>;
118
119    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
120        self.poll_read(cx)
121    }
122}
123
124impl<E> Drop for Receiver<E> {
125    fn drop(&mut self) {
126        self.inner.send_task.wake();
127    }
128}
129
130/// Sender part of the payload stream
131///
132/// It is possible to feed data from a cloned sender, but the readiness
133/// check applies only to the most recently called one.
134#[derive(Debug)]
135pub struct Sender<E> {
136    inner: Weak<Inner<E>>,
137}
138
139impl<E> Clone for Sender<E> {
140    fn clone(&self) -> Self {
141        Self {
142            inner: self.inner.clone(),
143        }
144    }
145}
146
147impl<E> Drop for Sender<E> {
148    fn drop(&mut self) {
149        if self.inner.weak_count() == 1 {
150            if let Some(shared) = self.inner.upgrade() {
151                shared.insert_flag(Flags::EOF | Flags::SENDER_GONE);
152            }
153        }
154    }
155}
156
157impl<E> Sender<E> {
158    /// Set stream error
159    pub fn set_error(&self, err: E) {
160        if let Some(shared) = self.inner.upgrade() {
161            shared.set_error(err);
162        }
163    }
164
165    /// Set stream eof
166    pub fn feed_eof(&self) {
167        if let Some(shared) = self.inner.upgrade() {
168            shared.feed_eof();
169        }
170    }
171
172    /// Add chunk to the stream
173    pub fn feed_data(&self, data: Bytes) {
174        if let Some(shared) = self.inner.upgrade() {
175            shared.feed_data(data)
176        }
177    }
178
179    /// Check stream readiness
180    pub async fn ready(&self) -> Status {
181        poll_fn(|cx| self.poll_ready(cx)).await
182    }
183
184    /// Check stream readiness
185    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
186        if let Some(shared) = self.inner.upgrade() {
187            let flags = shared.flags.get();
188            if flags.contains(Flags::NEED_READ) {
189                Poll::Ready(Status::Ready)
190            } else if flags.contains(Flags::SENDER_GONE | Flags::ERROR) {
191                Poll::Ready(Status::Dropped)
192            } else if flags.intersects(Flags::EOF) {
193                Poll::Ready(Status::Eof)
194            } else {
195                shared.send_task.register(cx.waker());
196                Poll::Pending
197            }
198        } else {
199            // receiver is gone
200            Poll::Ready(Status::Dropped)
201        }
202    }
203}
204
205bitflags::bitflags! {
206    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
207    struct Flags: u8 {
208        const EOF         = 0b0000_0001;
209        const ERROR       = 0b0000_0010;
210        const NEED_READ   = 0b0000_0100;
211        const SENDER_GONE = 0b0000_1000;
212    }
213}
214
215struct Inner<E> {
216    len: Cell<usize>,
217    flags: Cell<Flags>,
218    err: Cell<Option<E>>,
219    items: RefCell<VecDeque<Bytes>>,
220    max_buffer_size: Cell<usize>,
221    recv_task: LocalWaker,
222    send_task: LocalWaker,
223}
224
225impl<E> Inner<E> {
226    fn new(eof: bool) -> Self {
227        let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
228        Inner {
229            flags: Cell::new(flags),
230            len: Cell::new(0),
231            err: Cell::new(None),
232            items: RefCell::new(VecDeque::new()),
233            recv_task: LocalWaker::new(),
234            send_task: LocalWaker::new(),
235            max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
236        }
237    }
238
239    fn insert_flag(&self, f: Flags) {
240        let mut flags = self.flags.get();
241        flags.insert(f);
242        self.flags.set(flags);
243    }
244
245    fn remove_flag(&self, f: Flags) {
246        let mut flags = self.flags.get();
247        flags.remove(f);
248        self.flags.set(flags);
249    }
250
251    fn set_error(&self, err: E) {
252        self.err.set(Some(err));
253        self.insert_flag(Flags::ERROR);
254        self.recv_task.wake();
255        self.send_task.wake();
256    }
257
258    fn feed_eof(&self) {
259        self.insert_flag(Flags::EOF);
260        self.recv_task.wake();
261        self.send_task.wake();
262    }
263
264    fn feed_data(&self, data: Bytes) {
265        let len = self.len.get() + data.len();
266        self.len.set(len);
267        self.items.borrow_mut().push_back(data);
268        self.recv_task.wake();
269
270        if len >= self.max_buffer_size.get() {
271            self.remove_flag(Flags::NEED_READ);
272        }
273    }
274
275    fn get_data(&self) -> Option<Bytes> {
276        self.items.borrow_mut().pop_front().inspect(|data| {
277            let len = self.len.get() - data.len();
278
279            // check size of stream buffer,
280            // if stream has more space wake up sender
281            self.len.set(len);
282            if len < self.max_buffer_size.get() {
283                self.insert_flag(Flags::NEED_READ);
284                self.send_task.wake();
285            }
286        })
287    }
288
289    fn unread_data(&self, data: Bytes) {
290        if !data.is_empty() {
291            self.len.set(self.len.get() + data.len());
292            self.items.borrow_mut().push_front(data);
293        }
294    }
295}
296
297impl<E> fmt::Debug for Inner<E> {
298    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
299        f.debug_struct("Inner")
300            .field("len", &self.len)
301            .field("flags", &self.flags)
302            .field("items", &self.items.borrow())
303            .field("max_buffer_size", &self.max_buffer_size)
304            .field("recv_task", &self.recv_task)
305            .field("send_task", &self.send_task)
306            .finish()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_eof() {
        let (_, rx) = eof::<()>();
        assert!(rx.read().await.is_none());
    }

    #[ntex::test]
    async fn test_eof_status() {
        // the sender of a closed stream is told that no more data is wanted
        let (tx, _rx) = eof::<()>();
        assert_eq!(tx.ready().await, Status::Eof);
    }

    #[ntex::test]
    async fn test_empty() {
        let rx = empty::<()>(Some(Bytes::from("data")));
        assert!(rx.is_eof());
        assert_eq!(rx.read().await.unwrap().unwrap(), Bytes::from("data"));
        assert!(rx.read().await.is_none());
    }

    #[ntex::test]
    async fn test_unread_data() {
        let (_, payload) = channel::<()>();

        payload.put(Bytes::from("data"));
        assert_eq!(payload.inner.len.get(), 4);
        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| payload.poll_read(cx)).await.unwrap().unwrap()
        );
    }

    #[ntex::test]
    async fn test_error() {
        let (tx, rx) = channel::<&'static str>();
        tx.feed_data(Bytes::from("a"));
        tx.set_error("boom");

        // buffered data is delivered before the error
        assert_eq!(rx.read().await.unwrap().unwrap(), Bytes::from("a"));
        assert_eq!(rx.read().await.unwrap().unwrap_err(), "boom");

        // the error terminates the stream
        assert!(rx.read().await.is_none());
        assert!(rx.is_eof());
    }

    #[ntex::test]
    async fn test_backpressure() {
        let (tx, rx) = channel::<()>();
        rx.max_buffer_size(4);
        assert_eq!(tx.ready().await, Status::Ready);

        // filling the buffer up to the limit makes the sender wait
        tx.feed_data(Bytes::from("data"));
        assert!(poll_fn(|cx| Poll::Ready(tx.poll_ready(cx).is_pending())).await);

        // draining the buffer makes the sender ready again
        assert_eq!(rx.read().await.unwrap().unwrap(), Bytes::from("data"));
        assert_eq!(tx.ready().await, Status::Ready);
    }

    #[ntex::test]
    async fn test_receiver_dropped() {
        let (tx, rx) = channel::<()>();
        drop(rx);
        assert_eq!(tx.ready().await, Status::Dropped);

        // feeding a stream whose receiver is gone is a no-op
        tx.feed_data(Bytes::from("data"));
    }

    #[ntex::test]
    async fn test_sender_clone() {
        let (sender, payload) = channel::<()>();
        assert!(!payload.is_eof());
        let sender2 = sender.clone();
        assert!(!payload.is_eof());
        drop(sender2);
        assert!(!payload.is_eof());
        drop(sender);
        assert!(payload.is_eof());
    }
}