// ntex_util/channel/bstream.rs
use std::cell::{Cell, RefCell};
3use std::task::{Context, Poll};
4use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
5
6use ntex_bytes::Bytes;
7
8use crate::{Stream, task::LocalWaker};
9
10const MAX_BUFFER_SIZE: usize = 32_768;
12
/// State of the receiving side, as reported to senders by
/// [`Sender::ready()`] / [`Sender::poll_ready()`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Status {
    /// The stream has been marked finished (EOF); no more data is wanted.
    Eof,
    /// The receiver can accept more data.
    Ready,
    /// The receiving side was dropped, or the stream failed with an error.
    Dropped,
}
23
24pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
29 let inner = Rc::new(Inner::new(false));
30
31 (
32 Sender {
33 inner: Rc::downgrade(&inner),
34 },
35 Receiver { inner },
36 )
37}
38
39pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
44 let inner = Rc::new(Inner::new(true));
45
46 (
47 Sender {
48 inner: Rc::downgrade(&inner),
49 },
50 Receiver { inner },
51 )
52}
53
54pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
56 let rx = Receiver {
57 inner: Rc::new(Inner::new(true)),
58 };
59 if let Some(data) = data {
60 rx.put(data);
61 }
62 rx
63}
64
/// Receiving half of the byte-stream channel.
///
/// Holds the only strong reference to the shared state; dropping it
/// closes the channel for all senders.
#[derive(Debug)]
pub struct Receiver<E> {
    inner: Rc<Inner<E>>,
}
73
74impl<E> Receiver<E> {
75 #[inline]
79 pub fn max_buffer_size(&self, size: usize) {
80 self.inner.max_buffer_size.set(size);
81 }
82
83 #[inline]
85 pub fn put(&self, data: Bytes) {
86 self.inner.unread_data(data);
87 }
88
89 #[inline]
90 pub fn is_eof(&self) -> bool {
92 self.inner.flags.get().contains(Flags::EOF)
93 }
94
95 #[inline]
96 pub async fn read(&self) -> Option<Result<Bytes, E>> {
98 poll_fn(|cx| self.poll_read(cx)).await
99 }
100
101 #[inline]
102 pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
104 if let Some(data) = self.inner.get_data() {
105 Poll::Ready(Some(Ok(data)))
106 } else if let Some(err) = self.inner.err.take() {
107 self.inner.insert_flag(Flags::EOF);
108 Poll::Ready(Some(Err(err)))
109 } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
110 Poll::Ready(None)
111 } else {
112 self.inner.recv_task.register(cx.waker());
113 Poll::Pending
114 }
115 }
116}
117
118impl<E> Stream for Receiver<E> {
119 type Item = Result<Bytes, E>;
120
121 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122 self.poll_read(cx)
123 }
124}
125
impl<E> Drop for Receiver<E> {
    fn drop(&mut self) {
        // Wake a sender parked in `poll_ready()`; after this drop the
        // senders' `Weak::upgrade()` fails and they observe `Dropped`.
        self.inner.send_task.wake();
    }
}
131
/// Sending half of the byte-stream channel.
///
/// Holds only a `Weak` reference to the shared state: every operation
/// becomes a no-op once the `Receiver` has been dropped.
#[derive(Debug)]
pub struct Sender<E> {
    inner: Weak<Inner<E>>,
}
140
141impl<E> Clone for Sender<E> {
142 fn clone(&self) -> Self {
143 Self {
144 inner: self.inner.clone(),
145 }
146 }
147}
148
149impl<E> Drop for Sender<E> {
150 fn drop(&mut self) {
151 if self.inner.weak_count() == 1
152 && let Some(shared) = self.inner.upgrade()
153 {
154 shared.insert_flag(Flags::EOF | Flags::SENDER_GONE);
155 }
156 }
157}
158
159impl<E> Sender<E> {
160 pub fn is_closed(&self) -> bool {
162 self.inner.strong_count() == 0
163 }
164
165 pub fn set_error(&self, err: E) {
167 if let Some(shared) = self.inner.upgrade() {
168 shared.set_error(err);
169 }
170 }
171
172 pub fn feed_eof(&self) {
174 if let Some(shared) = self.inner.upgrade() {
175 shared.feed_eof();
176 }
177 }
178
179 pub fn feed_data(&self, data: Bytes) {
181 if let Some(shared) = self.inner.upgrade() {
182 shared.feed_data(data);
183 }
184 }
185
186 pub async fn ready(&self) -> Status {
188 poll_fn(|cx| self.poll_ready(cx)).await
189 }
190
191 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
193 if let Some(shared) = self.inner.upgrade() {
194 let flags = shared.flags.get();
195 if flags.contains(Flags::NEED_READ) {
196 Poll::Ready(Status::Ready)
197 } else if flags.contains(Flags::SENDER_GONE | Flags::ERROR) {
198 Poll::Ready(Status::Dropped)
199 } else if flags.intersects(Flags::EOF) {
200 Poll::Ready(Status::Eof)
201 } else {
202 shared.send_task.register(cx.waker());
203 Poll::Pending
204 }
205 } else {
206 Poll::Ready(Status::Dropped)
208 }
209 }
210}
211
bitflags::bitflags! {
    /// Internal channel state bits held in `Inner::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        /// Stream is finished; no more data will be produced.
        const EOF = 0b0000_0001;
        /// An error has been recorded in `Inner::err`.
        const ERROR = 0b0000_0010;
        /// Buffered length is below the limit; senders may feed data.
        const NEED_READ = 0b0000_0100;
        /// The last `Sender` handle was dropped.
        const SENDER_GONE = 0b0000_1000;
    }
}
221
/// State shared between the `Sender` handles and the `Receiver`.
/// Single-threaded: interior mutability via `Cell`/`RefCell`, wakers via
/// `LocalWaker`.
struct Inner<E> {
    // Total number of bytes currently buffered across `items`.
    len: Cell<usize>,
    // Channel state bits; see `Flags`.
    flags: Cell<Flags>,
    // Pending error, delivered once to the receiver.
    err: Cell<Option<E>>,
    // FIFO of buffered chunks.
    items: RefCell<VecDeque<Bytes>>,
    // Soft limit on `len` before back-pressure clears `NEED_READ`.
    max_buffer_size: Cell<usize>,
    // Waker of a task parked in `Receiver::poll_read()`.
    recv_task: LocalWaker,
    // Waker of a task parked in `Sender::poll_ready()`.
    send_task: LocalWaker,
}
231
impl<E> Inner<E> {
    /// Creates the shared state. `eof = true` starts the stream already
    /// finished; otherwise the receiver advertises `NEED_READ`.
    fn new(eof: bool) -> Self {
        let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
        Inner {
            flags: Cell::new(flags),
            len: Cell::new(0),
            err: Cell::new(None),
            items: RefCell::new(VecDeque::new()),
            recv_task: LocalWaker::new(),
            send_task: LocalWaker::new(),
            max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
        }
    }

    /// Sets the given flag bits (read-modify-write through the `Cell`).
    fn insert_flag(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    /// Clears the given flag bits.
    fn remove_flag(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.remove(f);
        self.flags.set(flags);
    }

    /// Stores an error and wakes both sides so they can observe it.
    fn set_error(&self, err: E) {
        self.err.set(Some(err));
        self.insert_flag(Flags::ERROR);
        self.recv_task.wake();
        self.send_task.wake();
    }

    /// Marks end-of-stream and wakes both sides.
    fn feed_eof(&self) {
        self.insert_flag(Flags::EOF);
        self.recv_task.wake();
        self.send_task.wake();
    }

    /// Appends a chunk, wakes the receiver, and applies back-pressure
    /// (clears `NEED_READ`) once the buffered length reaches the limit.
    fn feed_data(&self, data: Bytes) {
        let len = self.len.get() + data.len();
        self.len.set(len);
        self.items.borrow_mut().push_back(data);
        self.recv_task.wake();

        if len >= self.max_buffer_size.get() {
            self.remove_flag(Flags::NEED_READ);
        }
    }

    /// Pops the next chunk; once the buffered length drops under the
    /// limit, re-enables `NEED_READ` and wakes a parked sender.
    fn get_data(&self) -> Option<Bytes> {
        self.items.borrow_mut().pop_front().inspect(|data| {
            let len = self.len.get() - data.len();

            self.len.set(len);
            if len < self.max_buffer_size.get() {
                self.insert_flag(Flags::NEED_READ);
                self.send_task.wake();
            }
        })
    }

    /// Puts a chunk back at the front of the buffer (used by
    /// `Receiver::put()`); empty chunks are ignored. Does not wake either
    /// task — the data came from the receiving side itself.
    fn unread_data(&self, data: Bytes) {
        if !data.is_empty() {
            self.len.set(self.len.get() + data.len());
            self.items.borrow_mut().push_front(data);
        }
    }
}
303
304impl<E> fmt::Debug for Inner<E> {
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 f.debug_struct("Inner")
307 .field("len", &self.len)
308 .field("flags", &self.flags)
309 .field("items", &self.items.borrow())
310 .field("max_buffer_size", &self.max_buffer_size)
311 .field("recv_task", &self.recv_task)
312 .field("send_task", &self.send_task)
313 .finish()
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    // A channel created with `eof()` is finished from the start:
    // reads yield `None` and senders see `Status::Eof`.
    #[ntex::test]
    async fn test_eof() {
        let (tx, rx) = eof::<()>();
        rx.max_buffer_size(100);
        assert!(rx.read().await.is_none());
        assert_eq!(tx.ready().await, Status::Eof);
    }

    // Dropping one side of the channel is observable from the other.
    #[ntex::test]
    async fn test_closed() {
        let (tx, rx) = channel::<()>();
        assert!(!tx.is_closed());
        drop(rx);
        assert!(tx.is_closed());

        let (tx, rx) = channel::<()>();
        drop(tx);
        assert_eq!(rx.read().await, None);
    }

    // Data pushed back with `put()` is counted in `len` and readable.
    #[ntex::test]
    async fn test_unread_data() {
        let (_, payload) = channel::<()>();

        payload.put(Bytes::from("data"));
        assert_eq!(payload.inner.len.get(), 4);
        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| payload.poll_read(cx)).await.unwrap().unwrap()
        );
    }

    // EOF is signalled only when the *last* sender clone is dropped.
    #[ntex::test]
    async fn test_sender_clone() {
        let (sender, payload) = channel::<()>();
        assert!(!payload.is_eof());
        let sender2 = sender.clone();
        assert!(!payload.is_eof());
        drop(sender2);
        assert!(!payload.is_eof());
        drop(sender);
        assert!(payload.is_eof());
    }
}