use std::{any, fmt, hash, io};

use ntex_bytes::BytesVec;
use ntex_codec::{Decoder, Encoder};
use ntex_service::cfg::SharedCfg;
use ntex_util::time::Seconds;

use crate::{
    Decoded, Filter, FilterCtx, Flags, IoConfig, IoRef, OnDisconnect, WriteBuf, timer,
    types,
};

impl IoRef {
    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.cfg.get().tag()
    }

    #[inline]
    #[doc(hidden)]
    pub fn flags(&self) -> Flags {
        self.0.flags.get()
    }

    #[inline]
    pub(crate) fn filter(&self) -> &dyn Filter {
        self.0.filter()
    }

    #[inline]
    pub fn cfg(&self) -> &IoConfig {
        self.0.cfg.get()
    }

    #[inline]
    pub fn shared(&self) -> SharedCfg {
        self.0.cfg.get().config.shared()
    }

    #[inline]
    pub fn is_closed(&self) -> bool {
        self.0
            .flags
            .get()
            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
    }

    #[inline]
    pub fn is_wr_backpressure(&self) -> bool {
        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
    }

    #[inline]
    pub fn wake(&self) {
        self.0.dispatch_task.wake();
    }

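    /// Gracefully close the connection.
    ///
    /// Notify the dispatcher to stop and initiate the io stream shutdown
    /// process.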
    #[inline]
    pub fn close(&self) {
        self.0.insert_flags(Flags::DSP_STOP);
        self.0.init_shutdown();
    }

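    /// Force close the connection.
    ///
    /// Marks the stream as stopped immediately and wakes the read, write and
    /// dispatcher tasks; no graceful shutdown period is observed.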
    #[inline]
    pub fn force_close(&self) {
        log::trace!("{}: Force close io stream object", self.tag());
        self.0.insert_flags(
            Flags::DSP_STOP
                | Flags::IO_STOPPED
                | Flags::IO_STOPPING
                | Flags::IO_STOPPING_FILTERS,
        );
        self.0.read_task.wake();
        self.0.write_task.wake();
        self.0.dispatch_task.wake();
    }

    #[inline]
    pub fn want_shutdown(&self) {
        if !self
            .0
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.tag(),
                self.0.flags.get()
            );
            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.0.read_task.wake();
        }
    }

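    /// Query a filter-specific item by its type.
    ///
    /// The request is forwarded through the filter chain; if no filter can
    /// answer it, an empty `QueryItem` is returned.
    ///
    /// A minimal usage sketch (`types::PeerAddr` is used purely for
    /// illustration):
    ///
    /// ```ignore
    /// // returns an empty QueryItem when no filter provides the requested type
    /// let peer = io.query::<ntex_io::types::PeerAddr>();
    /// ```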
    #[inline]
    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
            types::QueryItem::new(item)
        } else {
            types::QueryItem::empty()
        }
    }

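    /// Encode an item with the provided codec and push it into the io write
    /// buffer.
    ///
    /// If the stream is already closed or closing, the frame is skipped and
    /// `Ok(())` is returned. An io error raised while accessing the write
    /// buffer stops the stream instead of being returned to the caller.
    ///
    /// A minimal usage sketch; `BytesCodec` here is just an illustrative
    /// codec:
    ///
    /// ```ignore
    /// use ntex_bytes::Bytes;
    /// use ntex_codec::BytesCodec;
    ///
    /// fn send_ping(io: &ntex_io::IoRef) {
    ///     // queue a frame; io errors are handled internally by stopping the stream
    ///     let _ = io.encode(Bytes::from_static(b"ping"), &BytesCodec);
    /// }
    /// ```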
    #[inline]
    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
    where
        U: Encoder,
    {
        if !self.is_closed() {
            self.with_write_buf(|buf| {
                self.cfg().write_buf().resize(buf);

                codec.encode_vec(item, buf)
            })
            .unwrap_or_else(|err| {
                log::trace!(
                    "{}: Got io error while encoding, error: {:?}",
                    self.tag(),
                    err
                );
                self.0.io_stopped(Some(err));
                Ok(())
            })
        } else {
            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
            Ok(())
        }
    }

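    /// Attempt to decode a single frame from the io read buffer using the
    /// provided codec.
    ///
    /// Returns `Ok(None)` when the buffer does not yet contain a complete
    /// frame.
    ///
    /// A minimal usage sketch, again assuming `BytesCodec` as the codec:
    ///
    /// ```ignore
    /// use ntex_codec::BytesCodec;
    ///
    /// fn try_read_frame(io: &ntex_io::IoRef) {
    ///     // Ok(None) means no complete frame is buffered yet
    ///     if let Ok(Some(frame)) = io.decode(&BytesCodec) {
    ///         println!("received {} bytes", frame.len());
    ///     }
    /// }
    /// ```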
    #[inline]
    pub fn decode<U>(
        &self,
        codec: &U,
    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0
            .buffer
            .with_read_destination(self, |buf| codec.decode_vec(buf))
    }

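    /// Decode a frame like `decode()`, additionally reporting the number of
    /// bytes consumed from the read buffer and the number of bytes remaining.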
    #[inline]
    pub fn decode_item<U>(
        &self,
        codec: &U,
    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0.buffer.with_read_destination(self, |buf| {
            let len = buf.len();
            codec.decode_vec(buf).map(|item| Decoded {
                item,
                remains: buf.len(),
                consumed: len - buf.len(),
            })
        })
    }

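    /// Append raw bytes to the io write buffer and run the filter chain's
    /// write processing.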
    #[inline]
    pub fn write(&self, src: &[u8]) -> io::Result<()> {
        self.with_write_buf(|buf| buf.extend_from_slice(src))
    }

    #[inline]
    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&WriteBuf<'_>) -> R,
    {
        let ctx = FilterCtx::new(self, &self.0.buffer);
        let result = ctx.write_buf(f);
        self.0.filter().process_write_buf(ctx)?;
        Ok(result)
    }

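    /// Get mutable access to the io write source buffer.
    ///
    /// Fails with the stored io error (or a disconnect error) if the stream
    /// is already stopped; otherwise the closure result is returned after the
    /// filter chain's write processing has run.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// fn write_raw(io: &ntex_io::IoRef) -> std::io::Result<()> {
    ///     // equivalent to IoRef::write(); shown here for illustration
    ///     io.with_write_buf(|buf| buf.extend_from_slice(b"raw bytes"))
    /// }
    /// ```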
    #[inline]
    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        if self.0.flags.get().contains(Flags::IO_STOPPED) {
            Err(self.0.error_or_disconnected())
        } else {
            let result = self.0.buffer.with_write_source(self, f);
            self.0
                .filter()
                .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
            Ok(result)
        }
    }

    #[doc(hidden)]
    #[inline]
    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(Option<&mut BytesVec>) -> R,
    {
        self.0.buffer.with_write_destination(self, f)
    }

    #[inline]
    pub fn with_read_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        self.0.buffer.with_read_destination(self, f)
    }

    #[inline]
    pub fn notify_dispatcher(&self) {
        self.0.dispatch_task.wake();
        log::trace!("{}: Timer, notify dispatcher", self.tag());
    }

    #[inline]
    pub fn notify_timeout(&self) {
        self.0.notify_timeout()
    }

    #[inline]
    pub fn timer_handle(&self) -> timer::TimerHandle {
        self.0.timeout.get()
    }

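    /// Register this io object with the shared timer for the given timeout,
    /// returning the resulting timer handle.
    ///
    /// If a timer is already set it is updated; a zero timeout unregisters
    /// the current timer and returns `TimerHandle::ZERO`.
    ///
    /// A minimal usage sketch (30 seconds is an arbitrary example value):
    ///
    /// ```ignore
    /// use ntex_util::time::Seconds;
    ///
    /// fn arm_timer(io: &ntex_io::IoRef) {
    ///     let _handle = io.start_timer(Seconds(30));
    ///     // ... later, drop the registration again
    ///     io.stop_timer();
    /// }
    /// ```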
    #[inline]
    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
        let cur_hnd = self.0.timeout.get();

        if !timeout.is_zero() {
            if cur_hnd.is_set() {
                let hnd = timer::update(cur_hnd, timeout, self);
                if hnd != cur_hnd {
                    log::trace!("{}: Update timer {:?}", self.tag(), timeout);
                    self.0.timeout.set(hnd);
                }
                hnd
            } else {
                log::trace!("{}: Start timer {:?}", self.tag(), timeout);
                let hnd = timer::register(timeout, self);
                self.0.timeout.set(hnd);
                hnd
            }
        } else {
            if cur_hnd.is_set() {
                self.0.timeout.set(timer::TimerHandle::ZERO);
                timer::unregister(cur_hnd, self);
            }
            timer::TimerHandle::ZERO
        }
    }

    #[inline]
    pub fn stop_timer(&self) {
        let hnd = self.0.timeout.get();
        if hnd.is_set() {
            log::trace!("{}: Stop timer", self.tag());
            self.0.timeout.set(timer::TimerHandle::ZERO);
            timer::unregister(hnd, self)
        }
    }

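    /// Get a notification future that resolves once the io stream is
    /// disconnected.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// async fn wait_for_peer(io: &ntex_io::IoRef) {
    ///     // resolves with () once the peer disconnects or the stream is stopped
    ///     io.on_disconnect().await;
    /// }
    /// ```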
    #[inline]
    pub fn on_disconnect(&self) -> OnDisconnect {
        OnDisconnect::new(self.0.clone())
    }
}

impl Eq for IoRef {}

impl PartialEq for IoRef {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl hash::Hash for IoRef {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}

impl fmt::Debug for IoRef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IoRef")
            .field("state", self.0.as_ref())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{Millis, sleep};

    use super::*;
    use crate::{FilterCtx, FilterReadStatus, Io, testing::IoTest};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::from(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(|b| b.len()).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}