Skip to main content

ntex_util/channel/
bstream.rs

1//! Bytes stream
2use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{Stream, task::LocalWaker};
9
/// Default upper bound, in bytes, on data buffered in a stream (32 KiB).
///
/// Once the buffered length reaches this value the sender is asked to wait
/// until the receiver drains some chunks. It can be changed per stream with
/// `Receiver::max_buffer_size`.
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Readiness of a byte stream as seen from the sending half.
///
/// Produced by `Sender::ready` and `Sender::poll_ready`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Status {
    /// The stream has been marked as finished; no more data is expected.
    Eof,
    /// The buffer has room, so the sender may feed more bytes.
    Ready,
    /// The receiving half has been dropped; fed data would never be read.
    Dropped,
}
23
24/// Creates a byte stream.
25///
26/// This method constructs two objects responsible for generating
27/// and consuming the byte stream.
28pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
29    let inner = Rc::new(Inner::new(false));
30
31    (
32        Sender {
33            inner: Rc::downgrade(&inner),
34        },
35        Receiver { inner },
36    )
37}
38
39/// Create closed bytes stream.
40///
41/// This method construct two objects responsible for bytes stream
42/// generation.
43pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
44    let inner = Rc::new(Inner::new(true));
45
46    (
47        Sender {
48            inner: Rc::downgrade(&inner),
49        },
50        Receiver { inner },
51    )
52}
53
54/// Creates an empty byte stream.
55pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
56    let rx = Receiver {
57        inner: Rc::new(Inner::new(true)),
58    };
59    if let Some(data) = data {
60        rx.put(data);
61    }
62    rx
63}
64
65/// A buffered stream of byte chunks.
66///
67/// Incoming payload data is stored internally as a vector of chunks.
68/// Chunks can be retrieved incrementally using the `.read()` method.
69#[derive(Debug)]
70pub struct Receiver<E> {
71    inner: Rc<Inner<E>>,
72}
73
74impl<E> Receiver<E> {
75    /// Sets the size of the stream buffer.
76    ///
77    /// By default, the buffer size is 32 KB.
78    #[inline]
79    pub fn max_buffer_size(&self, size: usize) {
80        self.inner.max_buffer_size.set(size);
81    }
82
83    /// Puts unused data back into the stream.
84    #[inline]
85    pub fn put(&self, data: Bytes) {
86        self.inner.unread_data(data);
87    }
88
89    #[inline]
90    /// Returns `true` if the stream has reached EOF.
91    pub fn is_eof(&self) -> bool {
92        self.inner.flags.get().contains(Flags::EOF)
93    }
94
95    #[inline]
96    /// Reads the next available chunk of bytes from the stream.
97    pub async fn read(&self) -> Option<Result<Bytes, E>> {
98        poll_fn(|cx| self.poll_read(cx)).await
99    }
100
101    #[inline]
102    /// Attempts to read the next available chunk of bytes from the stream.
103    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
104        if let Some(data) = self.inner.get_data() {
105            Poll::Ready(Some(Ok(data)))
106        } else if let Some(err) = self.inner.err.take() {
107            self.inner.insert_flag(Flags::EOF);
108            Poll::Ready(Some(Err(err)))
109        } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
110            Poll::Ready(None)
111        } else {
112            self.inner.recv_task.register(cx.waker());
113            Poll::Pending
114        }
115    }
116}
117
118impl<E> Stream for Receiver<E> {
119    type Item = Result<Bytes, E>;
120
121    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122        self.poll_read(cx)
123    }
124}
125
126impl<E> Drop for Receiver<E> {
127    fn drop(&mut self) {
128        self.inner.send_task.wake();
129    }
130}
131
132/// Sender side of the byte stream.
133///
134/// It is possible to send data from a cloned sender, but readiness
135/// checks apply only to the most recently used instance.
136#[derive(Debug)]
137pub struct Sender<E> {
138    inner: Weak<Inner<E>>,
139}
140
141impl<E> Clone for Sender<E> {
142    fn clone(&self) -> Self {
143        Self {
144            inner: self.inner.clone(),
145        }
146    }
147}
148
149impl<E> Drop for Sender<E> {
150    fn drop(&mut self) {
151        if self.inner.weak_count() == 1
152            && let Some(shared) = self.inner.upgrade()
153        {
154            shared.insert_flag(Flags::EOF | Flags::SENDER_GONE);
155        }
156    }
157}
158
159impl<E> Sender<E> {
160    /// Returns whether this channel is closed.
161    pub fn is_closed(&self) -> bool {
162        self.inner.strong_count() == 0
163    }
164
165    /// Sets the stream error.
166    pub fn set_error(&self, err: E) {
167        if let Some(shared) = self.inner.upgrade() {
168            shared.set_error(err);
169        }
170    }
171
172    /// Marks the stream as EOF.
173    pub fn feed_eof(&self) {
174        if let Some(shared) = self.inner.upgrade() {
175            shared.feed_eof();
176        }
177    }
178
179    /// Adds a chunk to the stream.
180    pub fn feed_data(&self, data: Bytes) {
181        if let Some(shared) = self.inner.upgrade() {
182            shared.feed_data(data);
183        }
184    }
185
186    /// Checks whether the stream is ready for operation.
187    pub async fn ready(&self) -> Status {
188        poll_fn(|cx| self.poll_ready(cx)).await
189    }
190
191    /// Checks whether the stream is ready for operation.
192    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
193        if let Some(shared) = self.inner.upgrade() {
194            let flags = shared.flags.get();
195            if flags.contains(Flags::NEED_READ) {
196                Poll::Ready(Status::Ready)
197            } else if flags.contains(Flags::SENDER_GONE | Flags::ERROR) {
198                Poll::Ready(Status::Dropped)
199            } else if flags.intersects(Flags::EOF) {
200                Poll::Ready(Status::Eof)
201            } else {
202                shared.send_task.register(cx.waker());
203                Poll::Pending
204            }
205        } else {
206            // receiver is gone
207            Poll::Ready(Status::Dropped)
208        }
209    }
210}
211
212bitflags::bitflags! {
213    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
214    struct Flags: u8 {
215        const EOF         = 0b0000_0001;
216        const ERROR       = 0b0000_0010;
217        const NEED_READ   = 0b0000_0100;
218        const SENDER_GONE = 0b0000_1000;
219    }
220}
221
222struct Inner<E> {
223    len: Cell<usize>,
224    flags: Cell<Flags>,
225    err: Cell<Option<E>>,
226    items: RefCell<VecDeque<Bytes>>,
227    max_buffer_size: Cell<usize>,
228    recv_task: LocalWaker,
229    send_task: LocalWaker,
230}
231
232impl<E> Inner<E> {
233    fn new(eof: bool) -> Self {
234        let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
235        Inner {
236            flags: Cell::new(flags),
237            len: Cell::new(0),
238            err: Cell::new(None),
239            items: RefCell::new(VecDeque::new()),
240            recv_task: LocalWaker::new(),
241            send_task: LocalWaker::new(),
242            max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
243        }
244    }
245
246    fn insert_flag(&self, f: Flags) {
247        let mut flags = self.flags.get();
248        flags.insert(f);
249        self.flags.set(flags);
250    }
251
252    fn remove_flag(&self, f: Flags) {
253        let mut flags = self.flags.get();
254        flags.remove(f);
255        self.flags.set(flags);
256    }
257
258    fn set_error(&self, err: E) {
259        self.err.set(Some(err));
260        self.insert_flag(Flags::ERROR);
261        self.recv_task.wake();
262        self.send_task.wake();
263    }
264
265    fn feed_eof(&self) {
266        self.insert_flag(Flags::EOF);
267        self.recv_task.wake();
268        self.send_task.wake();
269    }
270
271    fn feed_data(&self, data: Bytes) {
272        let len = self.len.get() + data.len();
273        self.len.set(len);
274        self.items.borrow_mut().push_back(data);
275        self.recv_task.wake();
276
277        if len >= self.max_buffer_size.get() {
278            self.remove_flag(Flags::NEED_READ);
279        }
280    }
281
282    fn get_data(&self) -> Option<Bytes> {
283        self.items.borrow_mut().pop_front().inspect(|data| {
284            let len = self.len.get() - data.len();
285
286            // check size of stream buffer,
287            // if stream has more space wake up sender
288            self.len.set(len);
289            if len < self.max_buffer_size.get() {
290                self.insert_flag(Flags::NEED_READ);
291                self.send_task.wake();
292            }
293        })
294    }
295
296    fn unread_data(&self, data: Bytes) {
297        if !data.is_empty() {
298            self.len.set(self.len.get() + data.len());
299            self.items.borrow_mut().push_front(data);
300        }
301    }
302}
303
304impl<E> fmt::Debug for Inner<E> {
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        f.debug_struct("Inner")
307            .field("len", &self.len)
308            .field("flags", &self.flags)
309            .field("items", &self.items.borrow())
310            .field("max_buffer_size", &self.max_buffer_size)
311            .field("recv_task", &self.recv_task)
312            .field("send_task", &self.send_task)
313            .finish()
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_eof() {
        let (sender, receiver) = eof::<()>();
        receiver.max_buffer_size(100);

        let chunk = receiver.read().await;
        assert!(chunk.is_none());

        let status = sender.ready().await;
        assert_eq!(status, Status::Eof);
    }

    #[ntex::test]
    async fn test_closed() {
        // the sender notices when the receiver is dropped
        let (sender, receiver) = channel::<()>();
        assert!(!sender.is_closed());
        drop(receiver);
        assert!(sender.is_closed());

        // the receiver sees end-of-stream once the only sender is dropped
        let (sender, receiver) = channel::<()>();
        drop(sender);
        let next = receiver.read().await;
        assert_eq!(next, None);
    }

    #[ntex::test]
    async fn test_unread_data() {
        let (_, receiver) = channel::<()>();
        let expected = Bytes::from("data");

        receiver.put(expected.clone());
        assert_eq!(receiver.inner.len.get(), 4);

        let chunk = poll_fn(|cx| receiver.poll_read(cx)).await;
        assert_eq!(chunk.unwrap().unwrap(), expected);
    }

    #[ntex::test]
    async fn test_sender_clone() {
        let (first, receiver) = channel::<()>();
        assert!(!receiver.is_eof());

        let second = first.clone();
        assert!(!receiver.is_eof());

        // dropping one of two clones must not close the stream
        drop(second);
        assert!(!receiver.is_eof());

        // dropping the last clone does
        drop(first);
        assert!(receiver.is_eof());
    }
}