ntex_util/channel/
bstream.rs

1//! Bytes stream
2use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{Stream, task::LocalWaker};
9
/// Default back-pressure threshold: 32 KiB.
///
/// Once the number of queued bytes reaches this value, senders are told to
/// wait until the receiver drains the buffer. It can be overridden per stream
/// with `Receiver::max_buffer_size`.
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Result of waiting for stream readiness via [`Sender::ready`] or
/// [`Sender::poll_ready`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The receiver is alive and has room for more data.
    Ready,
    /// The receiver has been dropped; anything fed afterwards is discarded.
    Dropped,
}
20
21/// Create bytes stream.
22///
23/// This method construct two objects responsible for bytes stream
24/// generation.
25pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
26    let inner = Rc::new(Inner::new(false));
27
28    (
29        Sender {
30            inner: Rc::downgrade(&inner),
31        },
32        Receiver { inner },
33    )
34}
35
36/// Create closed bytes stream.
37///
38/// This method construct two objects responsible for bytes stream
39/// generation.
40pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
41    let inner = Rc::new(Inner::new(true));
42
43    (
44        Sender {
45            inner: Rc::downgrade(&inner),
46        },
47        Receiver { inner },
48    )
49}
50
51/// Create empty stream
52pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
53    let rx = Receiver {
54        inner: Rc::new(Inner::new(true)),
55    };
56    if let Some(data) = data {
57        rx.put(data);
58    }
59    rx
60}
61
62/// Buffered stream of byte chunks
63///
64/// Payload stores chunks in a vector. Chunks can be received with
65/// `.read()` method.
66#[derive(Debug)]
67pub struct Receiver<E> {
68    inner: Rc<Inner<E>>,
69}
70
71impl<E> Receiver<E> {
72    /// Set size of stream buffer
73    ///
74    /// By default buffer size is set to 32Kb
75    #[inline]
76    pub fn max_buffer_size(&self, size: usize) {
77        self.inner.max_buffer_size.set(size);
78    }
79
80    /// Put unused data back to stream
81    #[inline]
82    pub fn put(&self, data: Bytes) {
83        self.inner.unread_data(data);
84    }
85
86    #[inline]
87    /// Check if stream is eof
88    pub fn is_eof(&self) -> bool {
89        self.inner.flags.get().contains(Flags::EOF)
90    }
91
92    #[inline]
93    /// Read next available bytes chunk
94    pub async fn read(&self) -> Option<Result<Bytes, E>> {
95        poll_fn(|cx| self.poll_read(cx)).await
96    }
97
98    #[inline]
99    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
100        if let Some(data) = self.inner.get_data() {
101            Poll::Ready(Some(Ok(data)))
102        } else if let Some(err) = self.inner.err.take() {
103            self.inner.insert_flag(Flags::EOF);
104            Poll::Ready(Some(Err(err)))
105        } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
106            Poll::Ready(None)
107        } else {
108            self.inner.recv_task.register(cx.waker());
109            Poll::Pending
110        }
111    }
112
113    #[doc(hidden)]
114    #[deprecated]
115    #[inline]
116    pub fn max_size(&self, size: usize) {
117        self.max_buffer_size(size);
118    }
119}
120
121impl<E> Stream for Receiver<E> {
122    type Item = Result<Bytes, E>;
123
124    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
125        self.poll_read(cx)
126    }
127}
128
129/// Sender part of the payload stream
130///
131/// It is possible to feed data from a cloned sender, but the readiness
132/// check applies only to the most recently called one.
133#[derive(Debug)]
134pub struct Sender<E> {
135    inner: Weak<Inner<E>>,
136}
137
138impl<E> Drop for Sender<E> {
139    fn drop(&mut self) {
140        if let Some(shared) = self.inner.upgrade() {
141            if self.inner.weak_count() == 1 {
142                shared.insert_flag(Flags::EOF);
143            }
144        }
145    }
146}
147
148impl<E> Clone for Sender<E> {
149    fn clone(&self) -> Self {
150        Self {
151            inner: self.inner.clone(),
152        }
153    }
154}
155
156impl<E> Sender<E> {
157    /// Set stream error
158    pub fn set_error(&self, err: E) {
159        if let Some(shared) = self.inner.upgrade() {
160            shared.set_error(err);
161        }
162    }
163
164    /// Set stream eof
165    pub fn feed_eof(&self) {
166        if let Some(shared) = self.inner.upgrade() {
167            shared.feed_eof();
168        }
169    }
170
171    /// Add chunk to the stream
172    pub fn feed_data(&self, data: Bytes) {
173        if let Some(shared) = self.inner.upgrade() {
174            shared.feed_data(data)
175        }
176    }
177
178    /// Check stream readiness
179    pub async fn ready(&self) -> Status {
180        poll_fn(|cx| self.poll_ready(cx)).await
181    }
182
183    /// Check stream readiness
184    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
185        // we check if Payload (other side) is alive,
186        // otherwise always return true (consume payload)
187        if let Some(shared) = self.inner.upgrade() {
188            if shared.flags.get().contains(Flags::NEED_READ) {
189                Poll::Ready(Status::Ready)
190            } else {
191                shared.send_task.register(cx.waker());
192                Poll::Pending
193            }
194        } else {
195            Poll::Ready(Status::Dropped)
196        }
197    }
198}
199
200bitflags::bitflags! {
201    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
202    struct Flags: u8 {
203        const EOF         = 0b0000_0001;
204        const ERROR       = 0b0000_0010;
205        const NEED_READ   = 0b0000_0100;
206        const SENDER_GONE = 0b0000_1000;
207    }
208}
209
210struct Inner<E> {
211    len: Cell<usize>,
212    flags: Cell<Flags>,
213    err: Cell<Option<E>>,
214    items: RefCell<VecDeque<Bytes>>,
215    max_buffer_size: Cell<usize>,
216    recv_task: LocalWaker,
217    send_task: LocalWaker,
218}
219
220impl<E> Inner<E> {
221    fn new(eof: bool) -> Self {
222        let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
223        Inner {
224            flags: Cell::new(flags),
225            len: Cell::new(0),
226            err: Cell::new(None),
227            items: RefCell::new(VecDeque::new()),
228            recv_task: LocalWaker::new(),
229            send_task: LocalWaker::new(),
230            max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
231        }
232    }
233
234    fn insert_flag(&self, f: Flags) {
235        let mut flags = self.flags.get();
236        flags.insert(f);
237        self.flags.set(flags);
238    }
239
240    fn remove_flag(&self, f: Flags) {
241        let mut flags = self.flags.get();
242        flags.remove(f);
243        self.flags.set(flags);
244    }
245
246    fn set_error(&self, err: E) {
247        self.err.set(Some(err));
248        self.insert_flag(Flags::ERROR);
249        self.recv_task.wake()
250    }
251
252    fn feed_eof(&self) {
253        self.insert_flag(Flags::EOF);
254        self.recv_task.wake()
255    }
256
257    fn feed_data(&self, data: Bytes) {
258        let len = self.len.get() + data.len();
259        self.len.set(len);
260        self.items.borrow_mut().push_back(data);
261        self.recv_task.wake();
262
263        if len >= self.max_buffer_size.get() {
264            self.remove_flag(Flags::NEED_READ);
265        }
266    }
267
268    fn get_data(&self) -> Option<Bytes> {
269        if let Some(data) = self.items.borrow_mut().pop_front() {
270            let len = self.len.get() - data.len();
271
272            // check size of stream buffer,
273            // if stream has more space wake up sender
274            self.len.set(len);
275            if len < self.max_buffer_size.get() {
276                self.insert_flag(Flags::NEED_READ);
277                self.send_task.wake();
278            }
279            Some(data)
280        } else {
281            self.insert_flag(Flags::NEED_READ);
282            self.send_task.wake();
283            None
284        }
285    }
286
287    fn unread_data(&self, data: Bytes) {
288        if !data.is_empty() {
289            self.len.set(self.len.get() + data.len());
290            self.items.borrow_mut().push_front(data);
291        }
292    }
293}
294
295impl<E> fmt::Debug for Inner<E> {
296    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297        f.debug_struct("Inner")
298            .field("len", &self.len)
299            .field("flags", &self.flags)
300            .field("items", &self.items.borrow())
301            .field("max_buffer_size", &self.max_buffer_size)
302            .field("recv_task", &self.recv_task)
303            .field("send_task", &self.send_task)
304            .finish()
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_eof() {
        let (_, rx) = eof::<()>();
        let item = rx.read().await;
        assert!(item.is_none());
    }

    #[ntex::test]
    async fn test_unread_data() {
        let (_, rx) = channel::<()>();
        let expected = Bytes::from("data");

        rx.put(expected.clone());
        assert_eq!(rx.inner.len.get(), 4);

        let chunk = poll_fn(|cx| rx.poll_read(cx)).await.unwrap().unwrap();
        assert_eq!(chunk, expected);
    }

    #[ntex::test]
    async fn test_sender_clone() {
        let (tx, rx) = channel::<()>();
        assert!(!rx.is_eof());

        let tx2 = tx.clone();
        assert!(!rx.is_eof());

        // dropping one of two senders must not end the stream
        drop(tx2);
        assert!(!rx.is_eof());

        // dropping the last sender does
        drop(tx);
        assert!(rx.is_eof());
    }
}