// ntex_util/channel/bstream.rs
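//! In-memory stream of `Bytes` chunks with advisory back-pressure.
//!
//! A [`Sender`] feeds chunks, an error, or end-of-stream into a shared buffer;
//! the [`Receiver`] implements [`Stream`] and yields the chunks in order. Once
//! the buffered length reaches `max_buffer_size`, [`Sender::poll_ready()`]
//! stops reporting [`Status::Ready`] until the receiver drains the buffer
//! below the limit.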
use std::cell::{Cell, RefCell};
use std::task::{Context, Poll};
use std::{collections::VecDeque, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};

use ntex_bytes::Bytes;

use crate::{Stream, task::LocalWaker};

const MAX_BUFFER_SIZE: usize = 32_768;

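/// Channel status reported by [`Sender::poll_ready()`].
///
/// `Ready` means more data can be buffered, `Eof` means the stream has been
/// terminated, and `Dropped` means the receiver is gone or an error occurred.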
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Status {
    Eof,
    Ready,
    Dropped,
}

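/// Creates an in-memory bytes stream.
///
/// The [`Sender`] feeds `Bytes` chunks into a shared buffer and the
/// [`Receiver`] yields them in order. Back-pressure is advisory: once the
/// buffered length reaches the configured `max_buffer_size`,
/// [`Sender::poll_ready()`] stops reporting [`Status::Ready`] until the
/// receiver drains the buffer.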
pub fn channel<E>() -> (Sender<E>, Receiver<E>) {
    let inner = Rc::new(Inner::new(false));

    (
        Sender {
            inner: Rc::downgrade(&inner),
        },
        Receiver { inner },
    )
}

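/// Creates a channel whose stream is already terminated.
///
/// The receiver yields `None` on the first read.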
pub fn eof<E>() -> (Sender<E>, Receiver<E>) {
    let inner = Rc::new(Inner::new(true));

    (
        Sender {
            inner: Rc::downgrade(&inner),
        },
        Receiver { inner },
    )
}

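/// Creates a stand-alone, already-terminated receiver.
///
/// If `data` is provided it is buffered and returned by the first read;
/// after that the stream reports end-of-stream.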
pub fn empty<E>(data: Option<Bytes>) -> Receiver<E> {
    let rx = Receiver {
        inner: Rc::new(Inner::new(true)),
    };
    if let Some(data) = data {
        rx.put(data);
    }
    rx
}

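/// Receiving side of the bytes stream.
///
/// Buffered chunks can be read with [`read()`](Receiver::read) or through the
/// [`Stream`] implementation.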
#[derive(Debug)]
pub struct Receiver<E> {
    inner: Rc<Inner<E>>,
}

impl<E> Receiver<E> {
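    /// Sets the buffer size (in bytes) above which the sender is asked to
    /// pause; the default is 32 KiB (`MAX_BUFFER_SIZE`).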
    #[inline]
    pub fn max_buffer_size(&self, size: usize) {
        self.inner.max_buffer_size.set(size);
    }

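    /// Puts a chunk back at the front of the buffer ("unread" data), so it is
    /// returned by the next read.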
    #[inline]
    pub fn put(&self, data: Bytes) {
        self.inner.unread_data(data);
    }

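    /// Returns `true` if the EOF flag is set: end-of-stream was signalled,
    /// all senders were dropped, or an error has already been delivered.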
    #[inline]
    pub fn is_eof(&self) -> bool {
        self.inner.flags.get().contains(Flags::EOF)
    }

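    /// Reads the next chunk, waiting if the buffer is currently empty.
    ///
    /// Returns `None` at end-of-stream and `Some(Err(..))` if an error was
    /// set by the sender.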
    #[inline]
    pub async fn read(&self) -> Option<Result<Bytes, E>> {
        poll_fn(|cx| self.poll_read(cx)).await
    }

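    /// Poll-based variant of [`read()`](Receiver::read): buffered data is
    /// returned first, then a pending error, then `None` once the stream has
    /// terminated; otherwise the task waker is registered and `Pending` is
    /// returned.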
    #[inline]
    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, E>>> {
        if let Some(data) = self.inner.get_data() {
            Poll::Ready(Some(Ok(data)))
        } else if let Some(err) = self.inner.err.take() {
            self.inner.insert_flag(Flags::EOF);
            Poll::Ready(Some(Err(err)))
        } else if self.inner.flags.get().intersects(Flags::EOF | Flags::ERROR) {
            Poll::Ready(None)
        } else {
            self.inner.recv_task.register(cx.waker());
            Poll::Pending
        }
    }
}

impl<E> Stream for Receiver<E> {
    type Item = Result<Bytes, E>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll_read(cx)
    }
}

impl<E> Drop for Receiver<E> {
    fn drop(&mut self) {
        self.inner.send_task.wake();
    }
}

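/// Sending side of the bytes stream.
///
/// Holds only a weak reference to the shared buffer, so it becomes inert once
/// the receiver is dropped. Cloning creates another sender for the same
/// stream; the stream is terminated when the last sender is dropped.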
#[derive(Debug)]
pub struct Sender<E> {
    inner: Weak<Inner<E>>,
}

impl<E> Clone for Sender<E> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
        }
    }
}

impl<E> Drop for Sender<E> {
    fn drop(&mut self) {
        if self.inner.weak_count() == 1 {
            if let Some(shared) = self.inner.upgrade() {
                shared.insert_flag(Flags::EOF | Flags::SENDER_GONE);
            }
        }
    }
}

impl<E> Sender<E> {
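    /// Sets an error; it is delivered to the receiver after any data that is
    /// already buffered.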
    pub fn set_error(&self, err: E) {
        if let Some(shared) = self.inner.upgrade() {
            shared.set_error(err);
        }
    }

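    /// Signals end-of-stream and wakes the receiver.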
    pub fn feed_eof(&self) {
        if let Some(shared) = self.inner.upgrade() {
            shared.feed_eof();
        }
    }

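    /// Appends a chunk to the shared buffer and wakes the receiver.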
    pub fn feed_data(&self, data: Bytes) {
        if let Some(shared) = self.inner.upgrade() {
            shared.feed_data(data);
        }
    }

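    /// Waits until the channel can accept more data or reports that it has
    /// terminated; the async counterpart of [`poll_ready()`](Sender::poll_ready).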
    pub async fn ready(&self) -> Status {
        poll_fn(|cx| self.poll_ready(cx)).await
    }

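    /// Reports the channel status: `Status::Ready` while more data can be
    /// buffered, `Status::Eof` after end-of-stream, and `Status::Dropped` if
    /// the receiver is gone or an error has been set; otherwise the waker is
    /// registered and `Pending` is returned.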
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Status> {
        if let Some(shared) = self.inner.upgrade() {
            let flags = shared.flags.get();
            if flags.contains(Flags::NEED_READ) {
                Poll::Ready(Status::Ready)
            } else if flags.intersects(Flags::SENDER_GONE | Flags::ERROR) {
                Poll::Ready(Status::Dropped)
            } else if flags.intersects(Flags::EOF) {
                Poll::Ready(Status::Eof)
            } else {
                shared.send_task.register(cx.waker());
                Poll::Pending
            }
        } else {
            Poll::Ready(Status::Dropped)
        }
    }
}

bitflags::bitflags! {
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        const EOF = 0b0000_0001;
        const ERROR = 0b0000_0010;
        const NEED_READ = 0b0000_0100;
        const SENDER_GONE = 0b0000_1000;
    }
}

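/// State shared between the sender(s) and the receiver.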
struct Inner<E> {
    len: Cell<usize>,
    flags: Cell<Flags>,
    err: Cell<Option<E>>,
    items: RefCell<VecDeque<Bytes>>,
    max_buffer_size: Cell<usize>,
    recv_task: LocalWaker,
    send_task: LocalWaker,
}

impl<E> Inner<E> {
    fn new(eof: bool) -> Self {
        let flags = if eof { Flags::EOF } else { Flags::NEED_READ };
        Inner {
            flags: Cell::new(flags),
            len: Cell::new(0),
            err: Cell::new(None),
            items: RefCell::new(VecDeque::new()),
            recv_task: LocalWaker::new(),
            send_task: LocalWaker::new(),
            max_buffer_size: Cell::new(MAX_BUFFER_SIZE),
        }
    }

    fn insert_flag(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    fn remove_flag(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.remove(f);
        self.flags.set(flags);
    }

    fn set_error(&self, err: E) {
        self.err.set(Some(err));
        self.insert_flag(Flags::ERROR);
        self.recv_task.wake();
        self.send_task.wake();
    }

    fn feed_eof(&self) {
        self.insert_flag(Flags::EOF);
        self.recv_task.wake();
        self.send_task.wake();
    }

    fn feed_data(&self, data: Bytes) {
        let len = self.len.get() + data.len();
        self.len.set(len);
        self.items.borrow_mut().push_back(data);
        self.recv_task.wake();

        if len >= self.max_buffer_size.get() {
            self.remove_flag(Flags::NEED_READ);
        }
    }

    fn get_data(&self) -> Option<Bytes> {
        self.items.borrow_mut().pop_front().inspect(|data| {
            let len = self.len.get() - data.len();

            self.len.set(len);
            if len < self.max_buffer_size.get() {
                self.insert_flag(Flags::NEED_READ);
                self.send_task.wake();
            }
        })
    }

    fn unread_data(&self, data: Bytes) {
        if !data.is_empty() {
            self.len.set(self.len.get() + data.len());
            self.items.borrow_mut().push_front(data);
        }
    }
}

impl<E> fmt::Debug for Inner<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Inner")
            .field("len", &self.len)
            .field("flags", &self.flags)
            .field("items", &self.items.borrow())
            .field("max_buffer_size", &self.max_buffer_size)
            .field("recv_task", &self.recv_task)
            .field("send_task", &self.send_task)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_eof() {
        let (_, rx) = eof::<()>();
        assert!(rx.read().await.is_none());
    }

    #[ntex::test]
    async fn test_unread_data() {
        let (_, payload) = channel::<()>();

        payload.put(Bytes::from("data"));
        assert_eq!(payload.inner.len.get(), 4);
        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| payload.poll_read(cx)).await.unwrap().unwrap()
        );
    }

    #[ntex::test]
    async fn test_sender_clone() {
        let (sender, payload) = channel::<()>();
        assert!(!payload.is_eof());
        let sender2 = sender.clone();
        assert!(!payload.is_eof());
        drop(sender2);
        assert!(!payload.is_eof());
        drop(sender);
        assert!(payload.is_eof());
    }
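
    // A sketch (not part of the original test suite) of the advisory
    // back-pressure behaviour: once the buffered length reaches
    // `max_buffer_size` the sender stops reporting readiness, and draining
    // the buffer restores it. The test name is illustrative only.
    #[ntex::test]
    async fn test_backpressure_sketch() {
        let (tx, rx) = channel::<()>();
        rx.max_buffer_size(4);

        // buffer reaches the limit: the sender is no longer ready
        tx.feed_data(Bytes::from("data"));
        assert!(poll_fn(|cx| Poll::Ready(tx.poll_ready(cx).is_pending())).await);

        // reading drains the buffer below the limit and restores readiness
        assert_eq!(Bytes::from("data"), rx.read().await.unwrap().unwrap());
        assert_eq!(tx.ready().await, Status::Ready);
    }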
}