ntex_util/channel/bstream.rs

1//! Bytes stream
2use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{task::LocalWaker, Stream};
9
/// Default limit, in bytes, on the data buffered in a stream before the
/// sender is asked to wait (32 KiB).
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Readiness reported by [`Sender::ready`] and [`Sender::poll_ready`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The receiver is alive and is asking for more data.
    Ready,
    /// The receiver has been dropped; anything fed from now on is discarded.
    Dropped,
}
20
21/// Create bytes stream.
22///
23/// This method construct two objects responsible for bytes stream
24/// generation.
25pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
26    let inner = Rc::new(Inner::new(false));
27
28    (
29        Sender {
30            inner: Rc::downgrade(&inner),
31        },
32        Receiver { inner },
33    )
34}
35
36/// Create closed bytes stream.
37///
38/// This method construct two objects responsible for bytes stream
39/// generation.
40pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
41    let inner = Rc::new(Inner::new(true));
42
43    (
44        Sender {
45            inner: Rc::downgrade(&inner),
46        },
47        Receiver { inner },
48    )
49}
50
51/// Create empty stream
52pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
53    let rx = Receiver {
54        inner: Rc::new(Inner::new(true)),
55    };
56    if let Some(data) = data {
57        rx.put(data);
58    }
59    rx
60}
61
62/// Buffered stream of byte chunks
63///
64/// Payload stores chunks in a vector. Chunks can be received with
65/// `.read()` method.
66#[derive(Debug)]
67pub struct Receiver<E> {
68    inner: Rc<Inner<E>>,
69}
70
71impl<E> Receiver<E> {
72    /// Set max stream size
73    ///
74    /// By default max buffer size is set to 32Kb
75    #[inline]
76    pub fn max_size(&self, size: usize) {
77        self.inner.max_size.set(size);
78    }
79
80    /// Put unused data back to stream
81    #[inline]
82    pub fn put(&self, data: Bytes) {
83        self.inner.unread_data(data);
84    }
85
86    #[inline]
87    /// Read next available bytes chunk
88    pub async fn read(&self) -> Option<Result<Bytes, E>> {
89        poll_fn(|cx| self.poll_read(cx)).await
90    }
91
92    #[inline]
93    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
94        if let Some(data) = self.inner.items.borrow_mut().pop_front() {
95            let len = self.inner.len.get() - data.len();
96            self.inner.len.set(len);
97            let need_read = if len < self.inner.max_size.get() {
98                self.inner.insert_flag(Flags::NEED_READ);
99                true
100            } else {
101                self.inner.remove_flag(Flags::NEED_READ);
102                false
103            };
104            if need_read {
105                self.inner.rx_task.register(cx.waker());
106                self.inner.tx_task.wake();
107            }
108            Poll::Ready(Some(Ok(data)))
109        } else if let Some(err) = self.inner.err.take() {
110            Poll::Ready(Some(Err(err)))
111        } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
112            Poll::Ready(None)
113        } else {
114            self.inner.insert_flag(Flags::NEED_READ);
115            self.inner.rx_task.register(cx.waker());
116            self.inner.tx_task.wake();
117            Poll::Pending
118        }
119    }
120}
121
122impl<E> Stream for Receiver<E> {
123    type Item = Result<Bytes, E>;
124
125    fn poll_next(
126        self: Pin<&mut Self>,
127        cx: &mut Context<'_>,
128    ) -> Poll<Option<Result<Bytes, E>>> {
129        self.poll_read(cx)
130    }
131}
132
133/// Sender part of the payload stream
134#[derive(Debug)]
135pub struct Sender<E> {
136    inner: Weak<Inner<E>>,
137}
138
139impl<E> Drop for Sender<E> {
140    fn drop(&mut self) {
141        if let Some(shared) = self.inner.upgrade() {
142            shared.insert_flag(Flags::EOF);
143        }
144    }
145}
146
147impl<E> Sender<E> {
148    /// Set stream error
149    pub fn set_error(&self, err: E) {
150        if let Some(shared) = self.inner.upgrade() {
151            shared.set_error(err);
152        }
153    }
154
155    /// Set stream eof
156    pub fn feed_eof(&self) {
157        if let Some(shared) = self.inner.upgrade() {
158            shared.feed_eof();
159        }
160    }
161
162    /// Add chunk to the stream
163    pub fn feed_data(&self, data: Bytes) {
164        if let Some(shared) = self.inner.upgrade() {
165            shared.feed_data(data)
166        }
167    }
168
169    /// Check stream readiness
170    pub async fn ready(&self) -> Status {
171        poll_fn(|cx| self.poll_ready(cx)).await
172    }
173
174    /// Check stream readiness
175    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
176        // we check only if Payload (other side) is alive,
177        // otherwise always return true (consume payload)
178        if let Some(shared) = self.inner.upgrade() {
179            if shared.flags.get().contains(Flags::NEED_READ) {
180                Poll::Ready(Status::Ready)
181            } else {
182                shared.tx_task.register(cx.waker());
183                Poll::Pending
184            }
185        } else {
186            Poll::Ready(Status::Dropped)
187        }
188    }
189}
190
191bitflags::bitflags! {
192    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
193    struct Flags: u8 {
194        const EOF         = 0b0000_0001;
195        const ERROR       = 0b0000_0010;
196        const NEED_READ   = 0b0000_0100;
197        const SENDER_GONE = 0b0000_1000;
198    }
199}
200
201struct Inner<E> {
202    len: Cell<usize>,
203    flags: Cell<Flags>,
204    err: Cell<Option<E>>,
205    items: RefCell<VecDeque<Bytes>>,
206    max_size: Cell<usize>,
207    rx_task: LocalWaker,
208    tx_task: LocalWaker,
209}
210
211impl<E> Inner<E> {
212    fn new(eof: bool) -> Self {
213        let flags = if eof {
214            Flags::EOF | Flags::NEED_READ
215        } else {
216            Flags::NEED_READ
217        };
218        Inner {
219            flags: Cell::new(flags),
220            len: Cell::new(0),
221            err: Cell::new(None),
222            items: RefCell::new(VecDeque::new()),
223            rx_task: LocalWaker::new(),
224            tx_task: LocalWaker::new(),
225            max_size: Cell::new(MAX_BUFFER_SIZE),
226        }
227    }
228
229    fn insert_flag(&self, f: Flags) {
230        let mut flags = self.flags.get();
231        flags.insert(f);
232        self.flags.set(flags);
233    }
234
235    fn remove_flag(&self, f: Flags) {
236        let mut flags = self.flags.get();
237        flags.remove(f);
238        self.flags.set(flags);
239    }
240
241    fn set_error(&self, err: E) {
242        self.err.set(Some(err));
243        self.insert_flag(Flags::ERROR);
244        self.rx_task.wake()
245    }
246
247    fn feed_eof(&self) {
248        self.insert_flag(Flags::EOF);
249        self.rx_task.wake()
250    }
251
252    fn feed_data(&self, data: Bytes) {
253        let len = self.len.get() + data.len();
254        self.len.set(len);
255        self.items.borrow_mut().push_back(data);
256        if len < self.max_size.get() {
257            self.insert_flag(Flags::NEED_READ);
258        } else {
259            self.remove_flag(Flags::NEED_READ);
260        }
261        self.rx_task.wake();
262    }
263
264    fn unread_data(&self, data: Bytes) {
265        if !data.is_empty() {
266            self.len.set(self.len.get() + data.len());
267            self.items.borrow_mut().push_front(data);
268        }
269    }
270}
271
272impl<E> fmt::Debug for Inner<E> {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        f.debug_struct("Inner")
275            .field("len", &self.len)
276            .field("flags", &self.flags)
277            .field("items", &self.items.borrow())
278            .field("max_size", &self.max_size)
279            .field("rx_task", &self.rx_task)
280            .field("tx_task", &self.tx_task)
281            .finish()
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex_macros::rt_test2]
    async fn test_eof() {
        // A stream created by `eof()` is finished from the start.
        let (_, rx) = eof::<()>();
        assert!(rx.read().await.is_none());
    }

    #[ntex_macros::rt_test2]
    async fn test_unread_data() {
        let (_, rx) = channel::<()>();

        rx.put(Bytes::from("data"));
        // Data put back counts towards the buffered length.
        assert_eq!(rx.inner.len.get(), 4);

        let chunk = poll_fn(|cx| rx.poll_read(cx)).await;
        assert_eq!(chunk.unwrap().unwrap(), Bytes::from("data"));
    }
}