1use std::{any, fmt, hash, io};
2
3use ntex_bytes::{BytesVec, PoolRef};
4use ntex_codec::{Decoder, Encoder};
5use ntex_util::time::Seconds;
6
7use crate::{
8 timer, types, Decoded, Filter, FilterCtx, Flags, IoRef, OnDisconnect, WriteBuf,
9};
10
11impl IoRef {
12 #[inline]
13 #[doc(hidden)]
14 pub fn flags(&self) -> Flags {
16 self.0.flags.get()
17 }
18
19 #[inline]
20 pub(crate) fn filter(&self) -> &dyn Filter {
22 self.0.filter()
23 }
24
25 #[inline]
26 pub fn memory_pool(&self) -> PoolRef {
28 self.0.pool.get()
29 }
30
31 #[inline]
32 pub fn is_closed(&self) -> bool {
34 self.0
35 .flags
36 .get()
37 .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
38 }
39
40 #[inline]
41 pub fn is_wr_backpressure(&self) -> bool {
43 self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
44 }
45
46 #[inline]
47 pub fn wake(&self) {
49 self.0.dispatch_task.wake();
50 }
51
52 #[inline]
53 pub fn close(&self) {
57 self.0.insert_flags(Flags::DSP_STOP);
58 self.0.init_shutdown();
59 }
60
61 #[inline]
62 pub fn force_close(&self) {
67 log::trace!("{}: Force close io stream object", self.tag());
68 self.0.insert_flags(
69 Flags::DSP_STOP
70 | Flags::IO_STOPPED
71 | Flags::IO_STOPPING
72 | Flags::IO_STOPPING_FILTERS,
73 );
74 self.0.read_task.wake();
75 self.0.write_task.wake();
76 self.0.dispatch_task.wake();
77 }
78
79 #[inline]
80 pub fn want_shutdown(&self) {
82 if !self
83 .0
84 .flags
85 .get()
86 .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
87 {
88 log::trace!(
89 "{}: Initiate io shutdown {:?}",
90 self.tag(),
91 self.0.flags.get()
92 );
93 self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
94 self.0.read_task.wake();
95 }
96 }
97
98 #[inline]
99 pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
101 if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
102 types::QueryItem::new(item)
103 } else {
104 types::QueryItem::empty()
105 }
106 }
107
108 #[inline]
109 pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
111 where
112 U: Encoder,
113 {
114 if !self.is_closed() {
115 self.with_write_buf(|buf| {
116 self.memory_pool().resize_write_buf(buf);
118
119 codec.encode_vec(item, buf)
121 })
122 .unwrap_or_else(|err| {
125 log::trace!(
126 "{}: Got io error while encoding, error: {:?}",
127 self.tag(),
128 err
129 );
130 self.0.io_stopped(Some(err));
131 Ok(())
132 })
133 } else {
134 log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
135 Ok(())
136 }
137 }
138
139 #[inline]
140 pub fn decode<U>(
142 &self,
143 codec: &U,
144 ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
145 where
146 U: Decoder,
147 {
148 self.0
149 .buffer
150 .with_read_destination(self, |buf| codec.decode_vec(buf))
151 }
152
153 #[inline]
154 pub fn decode_item<U>(
156 &self,
157 codec: &U,
158 ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
159 where
160 U: Decoder,
161 {
162 self.0.buffer.with_read_destination(self, |buf| {
163 let len = buf.len();
164 codec.decode_vec(buf).map(|item| Decoded {
165 item,
166 remains: buf.len(),
167 consumed: len - buf.len(),
168 })
169 })
170 }
171
172 #[inline]
173 pub fn write(&self, src: &[u8]) -> io::Result<()> {
175 self.with_write_buf(|buf| buf.extend_from_slice(src))
176 }
177
178 #[inline]
179 pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
181 where
182 F: FnOnce(&WriteBuf<'_>) -> R,
183 {
184 let ctx = FilterCtx::new(self, &self.0.buffer);
185 let result = ctx.write_buf(f);
186 self.0.filter().process_write_buf(ctx)?;
187 Ok(result)
188 }
189
190 #[inline]
191 pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
193 where
194 F: FnOnce(&mut BytesVec) -> R,
195 {
196 if self.0.flags.get().contains(Flags::IO_STOPPED) {
197 Err(self.0.error_or_disconnected())
198 } else {
199 let result = self.0.buffer.with_write_source(self, f);
200 self.0
201 .filter()
202 .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
203 Ok(result)
204 }
205 }
206
207 #[doc(hidden)]
208 #[inline]
209 pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
211 where
212 F: FnOnce(Option<&mut BytesVec>) -> R,
213 {
214 self.0.buffer.with_write_destination(self, f)
215 }
216
217 #[inline]
218 pub fn with_read_buf<F, R>(&self, f: F) -> R
220 where
221 F: FnOnce(&mut BytesVec) -> R,
222 {
223 self.0.buffer.with_read_destination(self, f)
224 }
225
226 #[inline]
227 pub fn notify_dispatcher(&self) {
229 self.0.dispatch_task.wake();
230 log::trace!("{}: Timer, notify dispatcher", self.tag());
231 }
232
233 #[inline]
234 pub fn notify_timeout(&self) {
236 self.0.notify_timeout()
237 }
238
239 #[doc(hidden)]
240 #[deprecated]
241 #[inline]
242 pub fn notify_keepalive(&self) {
244 self.0.notify_timeout()
245 }
246
247 #[inline]
248 pub fn timer_handle(&self) -> timer::TimerHandle {
250 self.0.timeout.get()
251 }
252
253 #[inline]
254 pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
256 let cur_hnd = self.0.timeout.get();
257
258 if !timeout.is_zero() {
259 if cur_hnd.is_set() {
260 let hnd = timer::update(cur_hnd, timeout, self);
261 if hnd != cur_hnd {
262 log::trace!("{}: Update timer {:?}", self.tag(), timeout);
263 self.0.timeout.set(hnd);
264 }
265 hnd
266 } else {
267 log::trace!("{}: Start timer {:?}", self.tag(), timeout);
268 let hnd = timer::register(timeout, self);
269 self.0.timeout.set(hnd);
270 hnd
271 }
272 } else {
273 if cur_hnd.is_set() {
274 self.0.timeout.set(timer::TimerHandle::ZERO);
275 timer::unregister(cur_hnd, self);
276 }
277 timer::TimerHandle::ZERO
278 }
279 }
280
281 #[inline]
282 pub fn stop_timer(&self) {
284 let hnd = self.0.timeout.get();
285 if hnd.is_set() {
286 log::trace!("{}: Stop timer", self.tag());
287 self.0.timeout.set(timer::TimerHandle::ZERO);
288 timer::unregister(hnd, self)
289 }
290 }
291
292 #[inline]
293 pub fn tag(&self) -> &'static str {
295 self.0.tag.get()
296 }
297
298 #[inline]
299 pub fn set_tag(&self, tag: &'static str) {
301 self.0.tag.set(tag)
302 }
303
304 #[inline]
305 pub fn on_disconnect(&self) -> OnDisconnect {
307 OnDisconnect::new(self.0.clone())
308 }
309}
310
311impl Eq for IoRef {}
312
313impl PartialEq for IoRef {
314 #[inline]
315 fn eq(&self, other: &Self) -> bool {
316 self.0.eq(&other.0)
317 }
318}
319
320impl hash::Hash for IoRef {
321 #[inline]
322 fn hash<H: hash::Hasher>(&self, state: &mut H) {
323 self.0.hash(state);
324 }
325}
326
327impl fmt::Debug for IoRef {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.debug_struct("IoRef")
330 .field("state", self.0.as_ref())
331 .finish()
332 }
333}
334
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::poll_fn, future::Future, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{sleep, Millis};

    use super::*;
    use crate::{testing::IoTest, FilterCtx, FilterReadStatus, Io};

    // Test payload, as bytes and as text (same content).
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Covers receive, equality, `Debug` output, error propagation,
    /// plain write/decode and `force_close`.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::new(server);
        assert_eq!(state.get_ref(), state.get_ref());

        // data written by the client is received as a single frame
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").contains("Io {"));
        assert!(format!("{:?}", state.get_ref()).contains("IoRef {"));

        // nothing more to read yet
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // a read error surfaces through `recv` and stops the io
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // same read error, observed through `poll_recv`
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        // raw write reaches the client
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // explicit decode after read readiness
        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // a write error surfaces through `send` and stops the io
        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // force close sets the stop flags immediately
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    /// Read readiness is reported once per batch of incoming data.
    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        // take everything out of the read buffer
        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    /// `OnDisconnect` futures resolve after the peer closes or a read error occurs.
    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);

        // both the waiter and its clone are pending while the connection is open
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // a waiter created after the disconnect is ready on first poll
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // a read error also resolves the waiter
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    /// Writing after the peer closed must fail.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());

        let buffered = state.with_write_buf(|buf| buf.extend_from_slice(BIN));
        assert!(buffered.is_err());
    }

    /// Test filter that records byte counts and the order in which
    /// layers are invoked, then forwards to the wrapped layer.
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            let seen = self.in_bytes.get();
            self.in_bytes.set(seen + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            let seen = self.out_bytes.get();
            // length of the source write buffer, 0 when there is none
            let pending =
                ctx.write_buf(|buf| buf.with_src(|b| b.as_ref().map_or(0, |b| b.len())));
            self.out_bytes.set(seen + pending);
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    /// A single `Counter` layer sees every byte read and written.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::new(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    /// Two stacked `Counter` layers behind a sealed filter: both layers
    /// count the traffic, and dropping the io releases their shared counters.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::new(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // each of the two layers counted the same traffic
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        let dest_len = state.with_write_dest_buf(|b| b.map_or(0, |b| b.len()));
        assert_eq!(dest_len, 0);

        // one reference held here plus one per layer; layers go away with the io
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}