ntex_util/channel/
bstream.rs1use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{Stream, task::LocalWaker};
9
/// Default number of buffered bytes (32 KiB) at which the sender stops being
/// reported as ready. Can be changed per stream with `Receiver::max_buffer_size`.
const MAX_BUFFER_SIZE: usize = 32 * 1024;
12
/// Outcome of waiting for the receiver side via `Sender::ready` / `Sender::poll_ready`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Status {
    /// The receiver is asking for more data (its buffer is below the limit or empty).
    Ready,
    /// The receiver no longer exists; data fed from now on is discarded.
    Dropped,
}
20
21pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
26 let inner = Rc::new(Inner::new(false));
27
28 (
29 Sender {
30 inner: Rc::downgrade(&inner),
31 },
32 Receiver { inner },
33 )
34}
35
36pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
41 let inner = Rc::new(Inner::new(true));
42
43 (
44 Sender {
45 inner: Rc::downgrade(&inner),
46 },
47 Receiver { inner },
48 )
49}
50
51pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
53 let rx = Receiver {
54 inner: Rc::new(Inner::new(true)),
55 };
56 if let Some(data) = data {
57 rx.put(data);
58 }
59 rx
60}
61
62#[derive(Debug)]
67pub struct Receiver<E> {
68 inner: Rc<Inner<E>>,
69}
70
71impl<E> Receiver<E> {
72 #[inline]
76 pub fn max_buffer_size(&self, size: usize) {
77 self.inner.max_buffer_size.set(size);
78 }
79
80 #[inline]
82 pub fn put(&self, data: Bytes) {
83 self.inner.unread_data(data);
84 }
85
86 #[inline]
87 pub fn is_eof(&self) -> bool {
89 self.inner.flags.get().contains(Flags::EOF)
90 }
91
92 #[inline]
93 pub async fn read(&self) -> Option<Result<Bytes, E>> {
95 poll_fn(|cx| self.poll_read(cx)).await
96 }
97
98 #[inline]
99 pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
100 if let Some(data) = self.inner.get_data() {
101 Poll::Ready(Some(Ok(data)))
102 } else if let Some(err) = self.inner.err.take() {
103 self.inner.insert_flag(Flags::EOF);
104 Poll::Ready(Some(Err(err)))
105 } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
106 Poll::Ready(None)
107 } else {
108 self.inner.recv_task.register(cx.waker());
109 Poll::Pending
110 }
111 }
112
113 #[doc(hidden)]
114 #[deprecated]
115 #[inline]
116 pub fn max_size(&self, size: usize) {
117 self.max_buffer_size(size);
118 }
119}
120
121impl<E> Stream for Receiver<E> {
122 type Item = Result<Bytes, E>;
123
124 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
125 self.poll_read(cx)
126 }
127}
128
129#[derive(Debug)]
134pub struct Sender<E> {
135 inner: Weak<Inner<E>>,
136}
137
138impl<E> Drop for Sender<E> {
139 fn drop(&mut self) {
140 if let Some(shared) = self.inner.upgrade() {
141 if self.inner.weak_count() == 1 {
142 shared.insert_flag(Flags::EOF);
143 }
144 }
145 }
146}
147
148impl<E> Clone for Sender<E> {
149 fn clone(&self) -> Self {
150 Self {
151 inner: self.inner.clone(),
152 }
153 }
154}
155
156impl<E> Sender<E> {
157 pub fn set_error(&self, err: E) {
159 if let Some(shared) = self.inner.upgrade() {
160 shared.set_error(err);
161 }
162 }
163
164 pub fn feed_eof(&self) {
166 if let Some(shared) = self.inner.upgrade() {
167 shared.feed_eof();
168 }
169 }
170
171 pub fn feed_data(&self, data: Bytes) {
173 if let Some(shared) = self.inner.upgrade() {
174 shared.feed_data(data)
175 }
176 }
177
178 pub async fn ready(&self) -> Status {
180 poll_fn(|cx| self.poll_ready(cx)).await
181 }
182
183 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
185 if let Some(shared) = self.inner.upgrade() {
188 if shared.flags.get().contains(Flags::NEED_READ) {
189 Poll::Ready(Status::Ready)
190 } else {
191 shared.send_task.register(cx.waker());
192 Poll::Pending
193 }
194 } else {
195 Poll::Ready(Status::Dropped)
196 }
197 }
198}
199
200bitflags::bitflags! {
201 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
202 struct Flags: u8 {
203 const EOF = 0b0000_0001;
204 const ERROR = 0b0000_0010;
205 const NEED_READ = 0b0000_0100;
206 const SENDER_GONE = 0b0000_1000;
207 }
208}
209
210struct Inner<E> {
211 len: Cell<usize>,
212 flags: Cell<Flags>,
213 err: Cell<Option<E>>,
214 items: RefCell<VecDeque<Bytes>>,
215 max_buffer_size: Cell<usize>,
216 recv_task: LocalWaker,
217 send_task: LocalWaker,
218}
219
220impl<E> Inner<E> {
221 fn new(eof: bool) -> Self {
222 let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
223 Inner {
224 flags: Cell::new(flags),
225 len: Cell::new(0),
226 err: Cell::new(None),
227 items: RefCell::new(VecDeque::new()),
228 recv_task: LocalWaker::new(),
229 send_task: LocalWaker::new(),
230 max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
231 }
232 }
233
234 fn insert_flag(&self, f: Flags) {
235 let mut flags = self.flags.get();
236 flags.insert(f);
237 self.flags.set(flags);
238 }
239
240 fn remove_flag(&self, f: Flags) {
241 let mut flags = self.flags.get();
242 flags.remove(f);
243 self.flags.set(flags);
244 }
245
246 fn set_error(&self, err: E) {
247 self.err.set(Some(err));
248 self.insert_flag(Flags::ERROR);
249 self.recv_task.wake()
250 }
251
252 fn feed_eof(&self) {
253 self.insert_flag(Flags::EOF);
254 self.recv_task.wake()
255 }
256
257 fn feed_data(&self, data: Bytes) {
258 let len = self.len.get() + data.len();
259 self.len.set(len);
260 self.items.borrow_mut().push_back(data);
261 self.recv_task.wake();
262
263 if len >= self.max_buffer_size.get() {
264 self.remove_flag(Flags::NEED_READ);
265 }
266 }
267
268 fn get_data(&self) -> Option<Bytes> {
269 if let Some(data) = self.items.borrow_mut().pop_front() {
270 let len = self.len.get() - data.len();
271
272 self.len.set(len);
275 if len < self.max_buffer_size.get() {
276 self.insert_flag(Flags::NEED_READ);
277 self.send_task.wake();
278 }
279 Some(data)
280 } else {
281 self.insert_flag(Flags::NEED_READ);
282 self.send_task.wake();
283 None
284 }
285 }
286
287 fn unread_data(&self, data: Bytes) {
288 if !data.is_empty() {
289 self.len.set(self.len.get() + data.len());
290 self.items.borrow_mut().push_front(data);
291 }
292 }
293}
294
295impl<E> fmt::Debug for Inner<E> {
296 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297 f.debug_struct("Inner")
298 .field("len", &self.len)
299 .field("flags", &self.flags)
300 .field("items", &self.items.borrow())
301 .field("max_buffer_size", &self.max_buffer_size)
302 .field("recv_task", &self.recv_task)
303 .field("send_task", &self.send_task)
304 .finish()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// A receiver created by `eof()` reports end-of-stream immediately.
    #[ntex::test]
    async fn test_eof() {
        let (_, rx) = eof::<()>();
        let item = rx.read().await;
        assert!(item.is_none());
    }

    /// Data pushed back with `put` is counted and returned by the next read.
    #[ntex::test]
    async fn test_unread_data() {
        let (_, rx) = channel::<()>();

        rx.put(Bytes::from("data"));
        assert_eq!(rx.inner.len.get(), 4);

        let expected = Bytes::from("data");
        let chunk = poll_fn(|cx| rx.poll_read(cx)).await.unwrap().unwrap();
        assert_eq!(expected, chunk);
    }

    /// The stream only reaches EOF after every sender clone has been dropped.
    #[ntex::test]
    async fn test_sender_clone() {
        let (tx, rx) = channel::<()>();
        assert!(!rx.is_eof());

        let tx2 = tx.clone();
        assert!(!rx.is_eof());

        drop(tx2);
        assert!(!rx.is_eof());

        drop(tx);
        assert!(rx.is_eof());
    }
}