1use std::{any, fmt, hash, io};
2
3use ntex_bytes::BytesVec;
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::{
9 Decoded, Filter, FilterCtx, Flags, IoConfig, IoRef, OnDisconnect, WriteBuf, timer,
10 types,
11};
12
13impl IoRef {
14 #[inline]
15 pub fn tag(&self) -> &'static str {
17 self.0.cfg.get().tag()
18 }
19
20 #[inline]
21 #[doc(hidden)]
22 pub fn flags(&self) -> Flags {
24 self.0.flags.get()
25 }
26
27 #[inline]
28 pub(crate) fn filter(&self) -> &dyn Filter {
30 self.0.filter()
31 }
32
33 #[inline]
34 pub fn cfg(&self) -> &IoConfig {
36 self.0.cfg.get()
37 }
38
39 #[inline]
40 pub fn shared(&self) -> SharedCfg {
42 self.0.cfg.get().config.shared()
43 }
44
45 #[inline]
46 pub fn is_closed(&self) -> bool {
48 self.0
49 .flags
50 .get()
51 .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
52 }
53
54 #[inline]
55 pub fn is_wr_backpressure(&self) -> bool {
57 self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
58 }
59
60 #[inline]
61 pub fn wake(&self) {
63 self.0.dispatch_task.wake();
64 }
65
66 #[inline]
67 pub fn close(&self) {
71 self.0.init_shutdown();
72 }
73
74 #[inline]
75 pub fn force_close(&self) {
80 log::trace!("{}: Force close io stream object", self.tag());
81 self.0.insert_flags(
82 Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
83 );
84 self.0.read_task.wake();
85 self.0.write_task.wake();
86 self.0.dispatch_task.wake();
87 }
88
89 #[inline]
90 pub fn want_shutdown(&self) {
92 if !self
93 .0
94 .flags
95 .get()
96 .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
97 {
98 log::trace!(
99 "{}: Initiate io shutdown {:?}",
100 self.tag(),
101 self.0.flags.get()
102 );
103 self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
104 self.0.read_task.wake();
105 }
106 }
107
108 #[inline]
109 pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
111 if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
112 types::QueryItem::new(item)
113 } else {
114 types::QueryItem::empty()
115 }
116 }
117
118 #[inline]
119 pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
121 where
122 U: Encoder,
123 {
124 if !self.is_closed() {
125 self.with_write_buf(|buf| {
126 self.cfg().write_buf().resize(buf);
128
129 codec.encode_vec(item, buf)
131 })
132 .unwrap_or_else(|err| {
135 log::trace!(
136 "{}: Got io error while encoding, error: {:?}",
137 self.tag(),
138 err
139 );
140 self.0.io_stopped(Some(err));
141 Ok(())
142 })
143 } else {
144 log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
145 Ok(())
146 }
147 }
148
149 #[inline]
150 pub fn decode<U>(
152 &self,
153 codec: &U,
154 ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
155 where
156 U: Decoder,
157 {
158 self.0
159 .buffer
160 .with_read_destination(self, |buf| codec.decode_vec(buf))
161 }
162
163 #[inline]
164 pub fn decode_item<U>(
166 &self,
167 codec: &U,
168 ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
169 where
170 U: Decoder,
171 {
172 self.0.buffer.with_read_destination(self, |buf| {
173 let len = buf.len();
174 codec.decode_vec(buf).map(|item| Decoded {
175 item,
176 remains: buf.len(),
177 consumed: len - buf.len(),
178 })
179 })
180 }
181
182 #[inline]
183 pub fn write(&self, src: &[u8]) -> io::Result<()> {
185 self.with_write_buf(|buf| buf.extend_from_slice(src))
186 }
187
188 #[inline]
189 pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
191 where
192 F: FnOnce(&WriteBuf<'_>) -> R,
193 {
194 let ctx = FilterCtx::new(self, &self.0.buffer);
195 let result = ctx.write_buf(f);
196 self.0.filter().process_write_buf(ctx)?;
197 Ok(result)
198 }
199
200 #[inline]
201 pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
203 where
204 F: FnOnce(&mut BytesVec) -> R,
205 {
206 if self.0.flags.get().contains(Flags::IO_STOPPED) {
207 Err(self.0.error_or_disconnected())
208 } else {
209 let result = self.0.buffer.with_write_source(self, f);
210 self.0
211 .filter()
212 .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
213 Ok(result)
214 }
215 }
216
217 #[doc(hidden)]
218 #[inline]
219 pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
221 where
222 F: FnOnce(Option<&mut BytesVec>) -> R,
223 {
224 self.0.buffer.with_write_destination(self, f)
225 }
226
227 #[inline]
228 pub fn with_read_buf<F, R>(&self, f: F) -> R
230 where
231 F: FnOnce(&mut BytesVec) -> R,
232 {
233 self.0.buffer.with_read_destination(self, f)
234 }
235
236 #[inline]
237 pub fn notify_dispatcher(&self) {
239 self.0.dispatch_task.wake();
240 log::trace!("{}: Timer, notify dispatcher", self.tag());
241 }
242
243 #[inline]
244 pub fn notify_timeout(&self) {
246 self.0.notify_timeout()
247 }
248
249 #[inline]
250 pub fn timer_handle(&self) -> timer::TimerHandle {
252 self.0.timeout.get()
253 }
254
255 #[inline]
256 pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
258 let cur_hnd = self.0.timeout.get();
259
260 if !timeout.is_zero() {
261 if cur_hnd.is_set() {
262 let hnd = timer::update(cur_hnd, timeout, self);
263 if hnd != cur_hnd {
264 log::trace!("{}: Update timer {:?}", self.tag(), timeout);
265 self.0.timeout.set(hnd);
266 }
267 hnd
268 } else {
269 log::trace!("{}: Start timer {:?}", self.tag(), timeout);
270 let hnd = timer::register(timeout, self);
271 self.0.timeout.set(hnd);
272 hnd
273 }
274 } else {
275 if cur_hnd.is_set() {
276 self.0.timeout.set(timer::TimerHandle::ZERO);
277 timer::unregister(cur_hnd, self);
278 }
279 timer::TimerHandle::ZERO
280 }
281 }
282
283 #[inline]
284 pub fn stop_timer(&self) {
286 let hnd = self.0.timeout.get();
287 if hnd.is_set() {
288 log::trace!("{}: Stop timer", self.tag());
289 self.0.timeout.set(timer::TimerHandle::ZERO);
290 timer::unregister(hnd, self)
291 }
292 }
293
294 #[inline]
295 pub fn on_disconnect(&self) -> OnDisconnect {
297 OnDisconnect::new(self.0.clone())
298 }
299}
300
301impl Eq for IoRef {}
302
303impl PartialEq for IoRef {
304 #[inline]
305 fn eq(&self, other: &Self) -> bool {
306 self.0.eq(&other.0)
307 }
308}
309
310impl hash::Hash for IoRef {
311 #[inline]
312 fn hash<H: hash::Hasher>(&self, state: &mut H) {
313 self.0.hash(state);
314 }
315}
316
317impl fmt::Debug for IoRef {
318 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319 f.debug_struct("IoRef")
320 .field("state", self.0.as_ref())
321 .finish()
322 }
323}
324
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{Millis, sleep};

    use super::*;
    use crate::{FilterCtx, FilterReadStatus, Io, testing::IoTest};

    // The same request payload, as bytes and as text.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").contains("Io {"));
        assert!(format!("{:?}", state.get_ref()).contains("IoRef {"));

        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
        }

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.force_close();
        assert!(state.flags().contains(Flags::IO_STOPPED));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::from(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    #[ntex::test]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        // both waiters must resolve once the peer has closed the connection
        waiter.await;
        waiter2.await;

        // a waiter created after the disconnect is ready immediately
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        waiter.await;
    }

    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    /// Test filter layer that records byte counts and the order in which
    /// layers are invoked, then forwards to the wrapped layer.
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(|b| b.len()).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // both layers see every byte, so counts are doubled
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        // one clone per filter layer plus the local handle
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        // the outermost layer (idx 1) runs first on both paths
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}