ntex_util/channel/
bstream.rs

1//! Bytes stream
2use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{Stream, task::LocalWaker};
9
/// Default upper bound, in bytes, for data buffered inside the stream (32 KiB).
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Readiness state reported to the sending side of a bytes stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Status {
    /// The stream has reached end-of-file.
    Eof,
    /// The stream can accept more data.
    Ready,
    /// The receiver side has been dropped.
    Dropped,
}
22
23/// Create bytes stream.
24///
25/// This method construct two objects responsible for bytes stream
26/// generation.
27pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
28    let inner = Rc::new(Inner::new(false));
29
30    (
31        Sender {
32            inner: Rc::downgrade(&inner),
33        },
34        Receiver { inner },
35    )
36}
37
38/// Create closed bytes stream.
39///
40/// This method construct two objects responsible for bytes stream
41/// generation.
42pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
43    let inner = Rc::new(Inner::new(true));
44
45    (
46        Sender {
47            inner: Rc::downgrade(&inner),
48        },
49        Receiver { inner },
50    )
51}
52
53/// Create empty stream
54pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
55    let rx = Receiver {
56        inner: Rc::new(Inner::new(true)),
57    };
58    if let Some(data) = data {
59        rx.put(data);
60    }
61    rx
62}
63
64/// Buffered stream of byte chunks
65///
66/// Payload stores chunks in a vector. Chunks can be received with
67/// `.read()` method.
68#[derive(Debug)]
69pub struct Receiver<E> {
70    inner: Rc<Inner<E>>,
71}
72
73impl<E> Receiver<E> {
74    /// Set size of stream buffer
75    ///
76    /// By default buffer size is set to 32Kb
77    #[inline]
78    pub fn max_buffer_size(&self, size: usize) {
79        self.inner.max_buffer_size.set(size);
80    }
81
82    /// Put unused data back to stream
83    #[inline]
84    pub fn put(&self, data: Bytes) {
85        self.inner.unread_data(data);
86    }
87
88    #[inline]
89    /// Check if stream is eof
90    pub fn is_eof(&self) -> bool {
91        self.inner.flags.get().contains(Flags::EOF)
92    }
93
94    #[inline]
95    /// Read next available bytes chunk
96    pub async fn read(&self) -> Option<Result<Bytes, E>> {
97        poll_fn(|cx| self.poll_read(cx)).await
98    }
99
100    #[inline]
101    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
102        if let Some(data) = self.inner.get_data() {
103            Poll::Ready(Some(Ok(data)))
104        } else if let Some(err) = self.inner.err.take() {
105            self.inner.insert_flag(Flags::EOF);
106            Poll::Ready(Some(Err(err)))
107        } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
108            Poll::Ready(None)
109        } else {
110            self.inner.recv_task.register(cx.waker());
111            Poll::Pending
112        }
113    }
114
115    #[doc(hidden)]
116    #[deprecated]
117    #[inline]
118    pub fn max_size(&self, size: usize) {
119        self.max_buffer_size(size);
120    }
121}
122
123impl<E> Stream for Receiver<E> {
124    type Item = Result<Bytes, E>;
125
126    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127        self.poll_read(cx)
128    }
129}
130
131impl<E> Drop for Receiver<E> {
132    fn drop(&mut self) {
133        self.inner.send_task.wake();
134    }
135}
136
137/// Sender part of the payload stream
138///
139/// It is possible to feed data from a cloned sender, but the readiness
140/// check applies only to the most recently called one.
141#[derive(Debug)]
142pub struct Sender<E> {
143    inner: Weak<Inner<E>>,
144}
145
146impl<E> Clone for Sender<E> {
147    fn clone(&self) -> Self {
148        Self {
149            inner: self.inner.clone(),
150        }
151    }
152}
153
154impl<E> Drop for Sender<E> {
155    fn drop(&mut self) {
156        if self.inner.weak_count() == 1 {
157            if let Some(shared) = self.inner.upgrade() {
158                shared.insert_flag(Flags::EOF | Flags::SENDER_GONE);
159            }
160        }
161    }
162}
163
164impl<E> Sender<E> {
165    /// Set stream error
166    pub fn set_error(&self, err: E) {
167        if let Some(shared) = self.inner.upgrade() {
168            shared.set_error(err);
169        }
170    }
171
172    /// Set stream eof
173    pub fn feed_eof(&self) {
174        if let Some(shared) = self.inner.upgrade() {
175            shared.feed_eof();
176        }
177    }
178
179    /// Add chunk to the stream
180    pub fn feed_data(&self, data: Bytes) {
181        if let Some(shared) = self.inner.upgrade() {
182            shared.feed_data(data)
183        }
184    }
185
186    /// Check stream readiness
187    pub async fn ready(&self) -> Status {
188        poll_fn(|cx| self.poll_ready(cx)).await
189    }
190
191    /// Check stream readiness
192    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
193        if let Some(shared) = self.inner.upgrade() {
194            let flags = shared.flags.get();
195            if flags.contains(Flags::NEED_READ) {
196                Poll::Ready(Status::Ready)
197            } else if flags.contains(Flags::SENDER_GONE | Flags::ERROR) {
198                Poll::Ready(Status::Dropped)
199            } else if flags.intersects(Flags::EOF) {
200                Poll::Ready(Status::Eof)
201            } else {
202                shared.send_task.register(cx.waker());
203                Poll::Pending
204            }
205        } else {
206            // receiver is gone
207            Poll::Ready(Status::Dropped)
208        }
209    }
210}
211
212bitflags::bitflags! {
213    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
214    struct Flags: u8 {
215        const EOF         = 0b0000_0001;
216        const ERROR       = 0b0000_0010;
217        const NEED_READ   = 0b0000_0100;
218        const SENDER_GONE = 0b0000_1000;
219    }
220}
221
222struct Inner<E> {
223    len: Cell<usize>,
224    flags: Cell<Flags>,
225    err: Cell<Option<E>>,
226    items: RefCell<VecDeque<Bytes>>,
227    max_buffer_size: Cell<usize>,
228    recv_task: LocalWaker,
229    send_task: LocalWaker,
230}
231
232impl<E> Inner<E> {
233    fn new(eof: bool) -> Self {
234        let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
235        Inner {
236            flags: Cell::new(flags),
237            len: Cell::new(0),
238            err: Cell::new(None),
239            items: RefCell::new(VecDeque::new()),
240            recv_task: LocalWaker::new(),
241            send_task: LocalWaker::new(),
242            max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
243        }
244    }
245
246    fn insert_flag(&self, f: Flags) {
247        let mut flags = self.flags.get();
248        flags.insert(f);
249        self.flags.set(flags);
250    }
251
252    fn remove_flag(&self, f: Flags) {
253        let mut flags = self.flags.get();
254        flags.remove(f);
255        self.flags.set(flags);
256    }
257
258    fn set_error(&self, err: E) {
259        self.err.set(Some(err));
260        self.insert_flag(Flags::ERROR);
261        self.recv_task.wake();
262        self.send_task.wake();
263    }
264
265    fn feed_eof(&self) {
266        self.insert_flag(Flags::EOF);
267        self.recv_task.wake();
268        self.send_task.wake();
269    }
270
271    fn feed_data(&self, data: Bytes) {
272        let len = self.len.get() + data.len();
273        self.len.set(len);
274        self.items.borrow_mut().push_back(data);
275        self.recv_task.wake();
276
277        if len >= self.max_buffer_size.get() {
278            self.remove_flag(Flags::NEED_READ);
279        }
280    }
281
282    fn get_data(&self) -> Option<Bytes> {
283        self.items.borrow_mut().pop_front().inspect(|data| {
284            let len = self.len.get() - data.len();
285
286            // check size of stream buffer,
287            // if stream has more space wake up sender
288            self.len.set(len);
289            if len < self.max_buffer_size.get() {
290                self.insert_flag(Flags::NEED_READ);
291                self.send_task.wake();
292            }
293        })
294    }
295
296    fn unread_data(&self, data: Bytes) {
297        if !data.is_empty() {
298            self.len.set(self.len.get() + data.len());
299            self.items.borrow_mut().push_front(data);
300        }
301    }
302}
303
304impl<E> fmt::Debug for Inner<E> {
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        f.debug_struct("Inner")
307            .field("len", &self.len)
308            .field("flags", &self.flags)
309            .field("items", &self.items.borrow())
310            .field("max_buffer_size", &self.max_buffer_size)
311            .field("recv_task", &self.recv_task)
312            .field("send_task", &self.send_task)
313            .finish()
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_eof() {
        // A stream created closed yields no chunks at all.
        let (_, rx) = eof::<()>();
        let next = rx.read().await;
        assert!(next.is_none());
    }

    #[ntex::test]
    async fn test_unread_data() {
        // Data put back into the stream is accounted for and read first.
        let (_, payload) = channel::<()>();
        let chunk = Bytes::from("data");

        payload.put(chunk.clone());
        assert_eq!(payload.inner.len.get(), 4);

        let got = poll_fn(|cx| payload.poll_read(cx)).await.unwrap().unwrap();
        assert_eq!(chunk, got);
    }

    #[ntex::test]
    async fn test_sender_clone() {
        // EOF is signalled only once the last sender goes away.
        let (tx, payload) = channel::<()>();
        assert!(!payload.is_eof());

        let tx2 = tx.clone();
        assert!(!payload.is_eof());

        drop(tx2);
        assert!(!payload.is_eof());

        drop(tx);
        assert!(payload.is_eof());
    }
}