1use std::{any, fmt, hash, io};
2
3use ntex_bytes::BytesMut;
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::{
9 Decoded, Filter, FilterCtx, Flags, IoConfig, IoRef, OnDisconnect, WriteBuf, timer,
10 types,
11};
12
13impl IoRef {
14 #[inline]
15 pub fn tag(&self) -> &'static str {
17 self.0.cfg.get().tag()
18 }
19
20 #[inline]
21 #[doc(hidden)]
22 pub fn flags(&self) -> Flags {
24 self.0.flags.get()
25 }
26
27 #[inline]
28 pub(crate) fn filter(&self) -> &dyn Filter {
30 self.0.filter()
31 }
32
33 #[inline]
34 pub fn cfg(&self) -> &'static IoConfig {
36 self.0.cfg.get()
37 }
38
39 #[inline]
40 pub fn shared(&self) -> SharedCfg {
42 self.0.cfg.get().config.shared()
43 }
44
45 #[inline]
46 pub fn is_closed(&self) -> bool {
48 self.0
49 .flags
50 .get()
51 .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
52 }
53
54 #[inline]
55 pub fn is_wr_backpressure(&self) -> bool {
57 self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
58 }
59
60 #[inline]
61 pub fn wake(&self) {
63 self.0.dispatch_task.wake();
64 }
65
66 #[inline]
67 pub fn close(&self) {
71 self.0.init_shutdown();
72 }
73
74 #[inline]
75 pub fn force_close(&self) {
80 log::trace!("{}: Force close io stream object", self.tag());
81 self.0.insert_flags(
82 Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
83 );
84 self.0.read_task.wake();
85 self.0.write_task.wake();
86 self.0.dispatch_task.wake();
87 }
88
89 #[inline]
90 pub fn want_shutdown(&self) {
92 if !self
93 .0
94 .flags
95 .get()
96 .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
97 {
98 log::trace!(
99 "{}: Initiate io shutdown {:?}",
100 self.tag(),
101 self.0.flags.get()
102 );
103 self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
104 self.0.read_task.wake();
105 }
106 }
107
108 #[inline]
109 pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
111 if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
112 types::QueryItem::new(item)
113 } else {
114 types::QueryItem::empty()
115 }
116 }
117
118 #[inline]
119 pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
121 where
122 U: Encoder,
123 {
124 if self.is_closed() {
125 log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
126 Ok(())
127 } else {
128 self.with_write_buf(|buf| {
129 self.cfg().write_buf().resize(buf);
131
132 codec.encode(item, buf)
134 })
135 .unwrap_or_else(|err| {
138 log::trace!(
139 "{}: Got io error while encoding, error: {:?}",
140 self.tag(),
141 err
142 );
143 self.0.io_stopped(Some(err));
144 Ok(())
145 })
146 }
147 }
148
149 #[inline]
150 pub fn decode<U>(
152 &self,
153 codec: &U,
154 ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
155 where
156 U: Decoder,
157 {
158 self.0
159 .buffer
160 .with_read_destination(self, |buf| codec.decode(buf))
161 }
162
163 #[inline]
164 pub fn decode_item<U>(
166 &self,
167 codec: &U,
168 ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
169 where
170 U: Decoder,
171 {
172 self.0.buffer.with_read_destination(self, |buf| {
173 let len = buf.len();
174 codec.decode(buf).map(|item| Decoded {
175 item,
176 remains: buf.len(),
177 consumed: len - buf.len(),
178 })
179 })
180 }
181
182 #[inline]
183 pub fn write(&self, src: &[u8]) -> io::Result<()> {
185 self.with_write_buf(|buf| buf.extend_from_slice(src))
186 }
187
188 #[inline]
189 pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
191 where
192 F: FnOnce(&WriteBuf<'_>) -> R,
193 {
194 let ctx = FilterCtx::new(self, &self.0.buffer);
195 let result = ctx.write_buf(f);
196 self.0.filter().process_write_buf(ctx)?;
197 Ok(result)
198 }
199
200 #[inline]
201 pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
203 where
204 F: FnOnce(&mut BytesMut) -> R,
205 {
206 if self.0.flags.get().contains(Flags::IO_STOPPED) {
207 Err(self.0.error_or_disconnected())
208 } else {
209 let result = self.0.buffer.with_write_source(self, f);
210 self.0
211 .filter()
212 .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
213 Ok(result)
214 }
215 }
216
217 #[doc(hidden)]
218 #[inline]
219 pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
221 where
222 F: FnOnce(Option<&mut BytesMut>) -> R,
223 {
224 self.0.buffer.with_write_destination(self, f)
225 }
226
227 #[inline]
228 pub fn with_read_buf<F, R>(&self, f: F) -> R
230 where
231 F: FnOnce(&mut BytesMut) -> R,
232 {
233 self.0.buffer.with_read_destination(self, f)
234 }
235
236 #[inline]
237 pub fn notify_dispatcher(&self) {
239 self.0.dispatch_task.wake();
240 log::trace!("{}: Timer, notify dispatcher", self.tag());
241 }
242
243 #[inline]
244 pub fn notify_timeout(&self) {
246 self.0.notify_timeout();
247 }
248
249 #[inline]
250 pub fn timer_handle(&self) -> timer::TimerHandle {
252 self.0.timeout.get()
253 }
254
255 #[inline]
256 pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
258 let cur_hnd = self.0.timeout.get();
259
260 if timeout.is_zero() {
261 if cur_hnd.is_set() {
262 self.0.timeout.set(timer::TimerHandle::ZERO);
263 timer::unregister(cur_hnd, self);
264 }
265 timer::TimerHandle::ZERO
266 } else if cur_hnd.is_set() {
267 let hnd = timer::update(cur_hnd, timeout, self);
268 if hnd != cur_hnd {
269 log::trace!("{}: Update timer {:?}", self.tag(), timeout);
270 self.0.timeout.set(hnd);
271 }
272 hnd
273 } else {
274 log::trace!("{}: Start timer {:?}", self.tag(), timeout);
275 let hnd = timer::register(timeout, self);
276 self.0.timeout.set(hnd);
277 hnd
278 }
279 }
280
281 #[inline]
282 pub fn stop_timer(&self) {
284 let hnd = self.0.timeout.get();
285 if hnd.is_set() {
286 log::trace!("{}: Stop timer", self.tag());
287 self.0.timeout.set(timer::TimerHandle::ZERO);
288 timer::unregister(hnd, self);
289 }
290 }
291
292 #[inline]
293 pub fn on_disconnect(&self) -> OnDisconnect {
295 OnDisconnect::new(self.0.clone())
296 }
297}
298
// Equality is delegated to the inner shared state (see the `PartialEq` impl
// below), which is reflexive, so the `Eq` marker trait holds.
impl Eq for IoRef {}
300
impl PartialEq for IoRef {
    // Two `IoRef`s are equal when their inner shared state objects compare
    // equal; comparison is fully delegated to the wrapped value.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}
307
impl hash::Hash for IoRef {
    // Hashing is delegated to the inner shared state, keeping `Hash`
    // consistent with the `PartialEq` impl above.
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
314
315impl fmt::Debug for IoRef {
316 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317 f.debug_struct("IoRef")
318 .field("state", self.0.as_ref())
319 .finish()
320 }
321}
322
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{Millis, sleep};

    use super::*;
    use crate::{FilterCtx, FilterReadStatus, Io, testing::IoTest};

    // The same request payload viewed as bytes and as text; used by most
    // tests below (BytesCodec yields the raw bytes back).
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    // Smoke test: recv/poll_recv, Debug formatting, read/write error
    // propagation, raw write/decode, and force_close flag behavior.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        // data was written before recv, so a full frame is available
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        // no pending data -> poll_recv must be pending until the client writes
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // a read error surfaces through recv and stops the io stream
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // same read-error scenario, but observed through poll_recv
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
        }

        // raw write on the server side is visible on the client side
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // read_ready + decode returns the buffered frame
        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // a write error surfaces through send and stops the io stream
        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // force_close sets both the stopped and stopping flags
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.force_close();
        assert!(state.flags().contains(Flags::IO_STOPPED));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    // poll_read_ready is edge-triggered: ready once after new data arrives,
    // pending again until more data is written.
    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::from(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        // drain the read buffer
        let item = io.with_read_buf(BytesMut::take);
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    // on_disconnect futures (including clones) resolve on peer close or on a
    // read error; waiters created after disconnect resolve immediately.
    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        // a cloned waiter behaves like the original
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // waiter created after the disconnect resolves immediately
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // a read error also counts as a disconnect
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    // Writing to a closed io must fail instead of silently buffering.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    // Test filter layer that counts bytes passing through and records the
    // order in which layers are invoked (by their `idx`).
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            // record invocation order and accumulate incoming byte count,
            // then delegate to the inner layer
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            // record invocation order and accumulate the current write
            // source buffer length, then delegate to the inner layer
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(BytesMut::len).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    // A single Counter layer sees all bytes flowing in both directions.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    // Two stacked (then sealed) Counter layers double-count bytes and must
    // be invoked outermost-first; dropping the io releases all Rc clones.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // each of the two layers counted the same traffic once
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(state.with_write_dest_buf(|b| b.map_or(0, |b| b.len())), 0);

        // io still holds clones of the counters; dropping it releases them
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}