use std::{any, fmt, hash, io};

use ntex_bytes::{BytesVec, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Seconds;

use crate::{
    timer, types, Decoded, Filter, FilterCtx, Flags, IoRef, OnDisconnect, WriteBuf,
};

impl IoRef {
    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.0.flags.get()
    }

    #[inline]
    /// Get current io filter
    pub(crate) fn filter(&self) -> &dyn Filter {
        self.0.filter()
    }

    #[inline]
    /// Get memory pool
    pub fn memory_pool(&self) -> PoolRef {
        self.0.pool.get()
    }

    #[inline]
    /// Check if io stream is closed
    pub fn is_closed(&self) -> bool {
        self.0
            .flags
            .get()
            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
    }

    #[inline]
    /// Check if write back-pressure is enabled
    pub fn is_wr_backpressure(&self) -> bool {
        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
    }

    #[inline]
    /// Wake dispatcher task
    pub fn wake(&self) {
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Gracefully close connection
    ///
    /// Notify dispatcher and initiate io stream shutdown process.
    pub fn close(&self) {
        self.0.insert_flags(Flags::DSP_STOP);
        self.0.init_shutdown();
    }

    #[inline]
    /// Force close connection
    ///
    /// Dispatcher does not wait for uncompleted responses; the io stream is
    /// terminated without any graceful period.
    pub fn force_close(&self) {
        log::trace!("{}: Force close io stream object", self.tag());
        self.0.insert_flags(
            Flags::DSP_STOP
                | Flags::IO_STOPPED
                | Flags::IO_STOPPING
                | Flags::IO_STOPPING_FILTERS,
        );
        self.0.read_task.wake();
        self.0.write_task.wake();
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub fn want_shutdown(&self) {
        if !self
            .0
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.tag(),
                self.0.flags.get()
            );
            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.0.read_task.wake();
        }
    }

    #[inline]
    /// Query filter specific data
    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
            types::QueryItem::new(item)
        } else {
            types::QueryItem::empty()
        }
    }

    #[inline]
    /// Encode item, write it to the write buffer and wake up write task
    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
    where
        U: Encoder,
    {
        if !self.is_closed() {
            self.with_write_buf(|buf| {
                // make sure the write buffer has enough capacity
                self.memory_pool().resize_write_buf(buf);

                // encode item into the write buffer
                codec.encode_vec(item, buf)
            })
            .unwrap_or_else(|err| {
                // io error while processing the write buffer, stop io tasks
                log::trace!(
                    "{}: Got io error while encoding, error: {:?}",
                    self.tag(),
                    err
                );
                self.0.io_stopped(Some(err));
                Ok(())
            })
        } else {
            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
            Ok(())
        }
    }

    #[inline]
    /// Attempt to decode a frame from the read buffer
    pub fn decode<U>(
        &self,
        codec: &U,
    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0
            .buffer
            .with_read_destination(self, |buf| codec.decode_vec(buf))
    }

    #[inline]
    /// Attempt to decode a frame from the read buffer, reporting how many
    /// bytes were consumed and how many remain
    pub fn decode_item<U>(
        &self,
        codec: &U,
    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0.buffer.with_read_destination(self, |buf| {
            let len = buf.len();
            codec.decode_vec(buf).map(|item| Decoded {
                item,
                remains: buf.len(),
                consumed: len - buf.len(),
            })
        })
    }

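    // Typical frame round-trip (illustrative sketch only, mirroring the codec
    // usage in the tests below; `io` is an `IoRef`, `BytesCodec` comes from
    // `ntex_codec`):
    //
    //     io.encode(Bytes::from_static(b"ping"), &BytesCodec)?;
    //     if let Some(frame) = io.decode(&BytesCodec)? {
    //         // handle the decoded frame
    //     }
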
    #[inline]
    /// Write bytes to the write buffer and wake up write task
    pub fn write(&self, src: &[u8]) -> io::Result<()> {
        self.with_write_buf(|buf| buf.extend_from_slice(src))
    }

    #[inline]
    /// Run a closure with access to the write buffer and process the result
    /// through the filter chain
    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&WriteBuf<'_>) -> R,
    {
        let ctx = FilterCtx::new(self, &self.0.buffer);
        let result = ctx.write_buf(f);
        self.0.filter().process_write_buf(ctx)?;
        Ok(result)
    }

    #[inline]
    /// Get mut access to the source write buffer
    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        if self.0.flags.get().contains(Flags::IO_STOPPED) {
            Err(self.0.error_or_disconnected())
        } else {
            let result = self.0.buffer.with_write_source(self, f);
            self.0
                .filter()
                .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
            Ok(result)
        }
    }

    #[doc(hidden)]
    #[inline]
    /// Get mut access to the destination write buffer
    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(Option<&mut BytesVec>) -> R,
    {
        self.0.buffer.with_write_destination(self, f)
    }

    #[inline]
    /// Get mut access to the destination read buffer
    pub fn with_read_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        self.0.buffer.with_read_destination(self, f)
    }

    #[inline]
    /// Wake up the dispatcher task
    pub fn notify_dispatcher(&self) {
        self.0.dispatch_task.wake();
        log::trace!("{}: Timer, notify dispatcher", self.tag());
    }

    #[inline]
    /// Notify the dispatcher about a timeout
    pub fn notify_timeout(&self) {
        self.0.notify_timeout()
    }

    #[doc(hidden)]
    #[inline]
    /// Check if a keep-alive timeout occurred
    pub fn is_keepalive(&self) -> bool {
        self.0.flags.get().is_keepalive()
    }

    #[inline]
    /// Notify the dispatcher about a keep-alive timeout
    pub fn notify_keepalive(&self) {
        self.0.insert_flags(Flags::DSP_KEEPALIVE);
        self.0.notify_timeout()
    }

    #[inline]
    /// Get current timer handle
    pub fn timer_handle(&self) -> timer::TimerHandle {
        self.0.timeout.get()
    }

    #[inline]
    /// Start or update the io stream timer
    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
        let cur_hnd = self.0.timeout.get();

        if !timeout.is_zero() {
            if cur_hnd.is_set() {
                let hnd = timer::update(cur_hnd, timeout, self);
                if hnd != cur_hnd {
                    log::trace!("{}: Update timer {:?}", self.tag(), timeout);
                    self.0.timeout.set(hnd);
                }
                hnd
            } else {
                log::trace!("{}: Start timer {:?}", self.tag(), timeout);
                let hnd = timer::register(timeout, self);
                self.0.timeout.set(hnd);
                hnd
            }
        } else {
            if cur_hnd.is_set() {
                self.0.timeout.set(timer::TimerHandle::ZERO);
                timer::unregister(cur_hnd, self);
            }
            timer::TimerHandle::ZERO
        }
    }

    #[inline]
    /// Stop the io stream timer
    pub fn stop_timer(&self) {
        let hnd = self.0.timeout.get();
        if hnd.is_set() {
            log::trace!("{}: Stop timer", self.tag());
            self.0.timeout.set(timer::TimerHandle::ZERO);
            timer::unregister(hnd, self)
        }
    }

    #[inline]
    /// Get current io stream tag
    pub fn tag(&self) -> &'static str {
        self.0.tag.get()
    }

    #[inline]
    /// Set io stream tag
    pub fn set_tag(&self, tag: &'static str) {
        self.0.tag.set(tag)
    }

    #[inline]
    /// Get a future that resolves when the io stream gets disconnected
    pub fn on_disconnect(&self) -> OnDisconnect {
        OnDisconnect::new(self.0.clone())
    }
}

impl Eq for IoRef {}

impl PartialEq for IoRef {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl hash::Hash for IoRef {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}

impl fmt::Debug for IoRef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IoRef")
            .field("state", self.0.as_ref())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::poll_fn, future::Future, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{sleep, Millis};

    use super::*;
    use crate::{testing::IoTest, FilterCtx, FilterReadStatus, Io};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::new(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

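    // Additional coverage sketch (not part of the original suite): exercises
    // `encode` and `decode_item` with the same `IoTest`/`BytesCodec` harness
    // used by the tests above.
    #[ntex::test]
    async fn encode_decode_item() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);

        // encode writes the frame into the write buffer and wakes the write task
        state.encode(Bytes::from_static(b"test"), &BytesCodec).unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // decode_item reports consumed/remaining byte counts alongside the frame
        client.write(TEXT);
        state.read_ready().await.unwrap();
        let decoded = state.decode_item(&BytesCodec).unwrap();
        assert_eq!(decoded.item.unwrap(), Bytes::from_static(BIN));
        assert_eq!(decoded.consumed, BIN.len());
        assert_eq!(decoded.remains, 0);
    }
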
    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

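    // Additional coverage sketch (not part of the original suite): checks that
    // `notify_keepalive` raises the DSP_KEEPALIVE flag and that `set_tag`/`tag`
    // round-trip a static tag.
    #[ntex::test]
    async fn keepalive_and_tag() {
        let (_client, server) = IoTest::create();
        let state = Io::new(server);

        assert!(!state.flags().contains(Flags::DSP_KEEPALIVE));
        state.notify_keepalive();
        assert!(state.flags().contains(Flags::DSP_KEEPALIVE));

        state.set_tag("TEST");
        assert_eq!(state.tag(), "TEST");
    }
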
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(state
            .with_write_buf(|buf| buf.extend_from_slice(BIN))
            .is_err());
    }

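    // Additional coverage sketch (not part of the original suite): exercises the
    // `start_timer`/`timer_handle`/`stop_timer` bookkeeping without waiting for
    // a timeout to fire.
    #[ntex::test]
    async fn timer_handle_roundtrip() {
        let (_client, server) = IoTest::create();
        let state = Io::new(server);

        let hnd = state.start_timer(Seconds(1));
        assert!(hnd.is_set());
        assert!(state.timer_handle() == hnd);

        state.stop_timer();
        assert!(!state.timer_handle().is_set());
    }
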
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(|b| b.len()).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::new(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::new(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        assert_eq!(Rc::strong_count(&in_bytes), 3);

        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}