ntex_util/channel/
bstream.rs1use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{task::LocalWaker, Stream};
9
/// Default limit (32 KiB) on the number of buffered, not yet read bytes.
///
/// Once the buffer reaches this size, `Sender::poll_ready` stops reporting
/// ready until the receiver drains it below the limit again. The limit is
/// advisory: `Sender::feed_data` still accepts data beyond it.
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Result of waiting for a bytes stream sender to become ready.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The receiver wants more data.
    Ready,
    /// The receiver has been dropped; data fed from now on is discarded.
    Dropped,
}
20
21pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
26 let inner = Rc::new(Inner::new(false));
27
28 (
29 Sender {
30 inner: Rc::downgrade(&inner),
31 },
32 Receiver { inner },
33 )
34}
35
36pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
41 let inner = Rc::new(Inner::new(true));
42
43 (
44 Sender {
45 inner: Rc::downgrade(&inner),
46 },
47 Receiver { inner },
48 )
49}
50
51pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
53 let rx = Receiver {
54 inner: Rc::new(Inner::new(true)),
55 };
56 if let Some(data) = data {
57 rx.put(data);
58 }
59 rx
60}
61
62#[derive(Debug)]
67pub struct Receiver<E> {
68 inner: Rc<Inner<E>>,
69}
70
71impl<E> Receiver<E> {
72 #[inline]
76 pub fn max_size(&self, size: usize) {
77 self.inner.max_size.set(size);
78 }
79
80 #[inline]
82 pub fn put(&self, data: Bytes) {
83 self.inner.unread_data(data);
84 }
85
86 #[inline]
87 pub async fn read(&self) -> Option<Result<Bytes, E>> {
89 poll_fn(|cx| self.poll_read(cx)).await
90 }
91
92 #[inline]
93 pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
94 if let Some(data) = self.inner.items.borrow_mut().pop_front() {
95 let len = self.inner.len.get() - data.len();
96 self.inner.len.set(len);
97 let need_read = if len < self.inner.max_size.get() {
98 self.inner.insert_flag(Flags::NEED_READ);
99 true
100 } else {
101 self.inner.remove_flag(Flags::NEED_READ);
102 false
103 };
104 if need_read {
105 self.inner.rx_task.register(cx.waker());
106 self.inner.tx_task.wake();
107 }
108 Poll::Ready(Some(Ok(data)))
109 } else if let Some(err) = self.inner.err.take() {
110 Poll::Ready(Some(Err(err)))
111 } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
112 Poll::Ready(None)
113 } else {
114 self.inner.insert_flag(Flags::NEED_READ);
115 self.inner.rx_task.register(cx.waker());
116 self.inner.tx_task.wake();
117 Poll::Pending
118 }
119 }
120}
121
122impl<E> Stream for Receiver<E> {
123 type Item = Result<Bytes, E>;
124
125 fn poll_next(
126 self: Pin<&mut Self>,
127 cx: &mut Context<'_>,
128 ) -> Poll<Option<Result<Bytes, E>>> {
129 self.poll_read(cx)
130 }
131}
132
133#[derive(Debug)]
135pub struct Sender<E> {
136 inner: Weak<Inner<E>>,
137}
138
139impl<E> Drop for Sender<E> {
140 fn drop(&mut self) {
141 if let Some(shared) = self.inner.upgrade() {
142 shared.insert_flag(Flags::EOF);
143 }
144 }
145}
146
147impl<E> Sender<E> {
148 pub fn set_error(&self, err: E) {
150 if let Some(shared) = self.inner.upgrade() {
151 shared.set_error(err);
152 }
153 }
154
155 pub fn feed_eof(&self) {
157 if let Some(shared) = self.inner.upgrade() {
158 shared.feed_eof();
159 }
160 }
161
162 pub fn feed_data(&self, data: Bytes) {
164 if let Some(shared) = self.inner.upgrade() {
165 shared.feed_data(data)
166 }
167 }
168
169 pub async fn ready(&self) -> Status {
171 poll_fn(|cx| self.poll_ready(cx)).await
172 }
173
174 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
176 if let Some(shared) = self.inner.upgrade() {
179 if shared.flags.get().contains(Flags::NEED_READ) {
180 Poll::Ready(Status::Ready)
181 } else {
182 shared.tx_task.register(cx.waker());
183 Poll::Pending
184 }
185 } else {
186 Poll::Ready(Status::Dropped)
187 }
188 }
189}
190
191bitflags::bitflags! {
192 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
193 struct Flags: u8 {
194 const EOF = 0b0000_0001;
195 const ERROR = 0b0000_0010;
196 const NEED_READ = 0b0000_0100;
197 const SENDER_GONE = 0b0000_1000;
198 }
199}
200
201struct Inner<E> {
202 len: Cell<usize>,
203 flags: Cell<Flags>,
204 err: Cell<Option<E>>,
205 items: RefCell<VecDeque<Bytes>>,
206 max_size: Cell<usize>,
207 rx_task: LocalWaker,
208 tx_task: LocalWaker,
209}
210
211impl<E> Inner<E> {
212 fn new(eof: bool) -> Self {
213 let flags = if eof {
214 Flags::EOF | Flags::NEED_READ
215 } else {
216 Flags::NEED_READ
217 };
218 Inner {
219 flags: Cell::new(flags),
220 len: Cell::new(0),
221 err: Cell::new(None),
222 items: RefCell::new(VecDeque::new()),
223 rx_task: LocalWaker::new(),
224 tx_task: LocalWaker::new(),
225 max_size: Cell::new(MAX_BUFFER_SIZE),
226 }
227 }
228
229 fn insert_flag(&self, f: Flags) {
230 let mut flags = self.flags.get();
231 flags.insert(f);
232 self.flags.set(flags);
233 }
234
235 fn remove_flag(&self, f: Flags) {
236 let mut flags = self.flags.get();
237 flags.remove(f);
238 self.flags.set(flags);
239 }
240
241 fn set_error(&self, err: E) {
242 self.err.set(Some(err));
243 self.insert_flag(Flags::ERROR);
244 self.rx_task.wake()
245 }
246
247 fn feed_eof(&self) {
248 self.insert_flag(Flags::EOF);
249 self.rx_task.wake()
250 }
251
252 fn feed_data(&self, data: Bytes) {
253 let len = self.len.get() + data.len();
254 self.len.set(len);
255 self.items.borrow_mut().push_back(data);
256 if len < self.max_size.get() {
257 self.insert_flag(Flags::NEED_READ);
258 } else {
259 self.remove_flag(Flags::NEED_READ);
260 }
261 self.rx_task.wake();
262 }
263
264 fn unread_data(&self, data: Bytes) {
265 if !data.is_empty() {
266 self.len.set(self.len.get() + data.len());
267 self.items.borrow_mut().push_front(data);
268 }
269 }
270}
271
272impl<E> fmt::Debug for Inner<E> {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.debug_struct("Inner")
275 .field("len", &self.len)
276 .field("flags", &self.flags)
277 .field("items", &self.items.borrow())
278 .field("max_size", &self.max_size)
279 .field("rx_task", &self.rx_task)
280 .field("tx_task", &self.tx_task)
281 .finish()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    // A stream created with `eof` and no data reports end of stream at once.
    #[ntex_macros::rt_test2]
    async fn test_eof() {
        let (_, rx) = eof::<()>();
        let first = rx.read().await;
        assert!(first.is_none());
    }

    // Data pushed back with `put` is counted in `len` and read back first.
    #[ntex_macros::rt_test2]
    async fn test_unread_data() {
        let (_, payload) = channel::<()>();
        let expected = Bytes::from("data");

        payload.put(expected.clone());
        assert_eq!(payload.inner.len.get(), 4);

        let item = poll_fn(|cx| payload.poll_read(cx)).await;
        assert_eq!(expected, item.unwrap().unwrap());
    }
}