1use std::{any, fmt, hash, io};
2
3use ntex_bytes::{BytesVec, PoolRef};
4use ntex_codec::{Decoder, Encoder};
5use ntex_util::time::Seconds;
6
7use crate::{timer, types, Decoded, Filter, Flags, IoRef, OnDisconnect, WriteBuf};
8
9impl IoRef {
10 #[inline]
11 #[doc(hidden)]
12 pub fn flags(&self) -> Flags {
14 self.0.flags.get()
15 }
16
17 #[inline]
18 pub(crate) fn filter(&self) -> &dyn Filter {
20 self.0.filter()
21 }
22
23 #[inline]
24 pub fn memory_pool(&self) -> PoolRef {
26 self.0.pool.get()
27 }
28
29 #[inline]
30 pub fn is_closed(&self) -> bool {
32 self.0
33 .flags
34 .get()
35 .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
36 }
37
38 #[inline]
39 pub fn is_wr_backpressure(&self) -> bool {
41 self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
42 }
43
44 #[inline]
45 pub fn wake(&self) {
47 self.0.dispatch_task.wake();
48 }
49
50 #[inline]
51 pub fn close(&self) {
55 self.0.insert_flags(Flags::DSP_STOP);
56 self.0.init_shutdown();
57 }
58
59 #[inline]
60 pub fn force_close(&self) {
65 log::trace!("{}: Force close io stream object", self.tag());
66 self.0.insert_flags(
67 Flags::DSP_STOP
68 | Flags::IO_STOPPED
69 | Flags::IO_STOPPING
70 | Flags::IO_STOPPING_FILTERS,
71 );
72 self.0.read_task.wake();
73 self.0.write_task.wake();
74 self.0.dispatch_task.wake();
75 }
76
77 #[inline]
78 pub fn want_shutdown(&self) {
80 if !self
81 .0
82 .flags
83 .get()
84 .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
85 {
86 log::trace!(
87 "{}: Initiate io shutdown {:?}",
88 self.tag(),
89 self.0.flags.get()
90 );
91 self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
92 self.0.read_task.wake();
93 }
94 }
95
96 #[inline]
97 pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
99 if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
100 types::QueryItem::new(item)
101 } else {
102 types::QueryItem::empty()
103 }
104 }
105
106 #[inline]
107 pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
109 where
110 U: Encoder,
111 {
112 if !self.is_closed() {
113 self.with_write_buf(|buf| {
114 self.memory_pool().resize_write_buf(buf);
116
117 codec.encode_vec(item, buf)
119 })
120 .unwrap_or_else(|err| {
123 log::trace!(
124 "{}: Got io error while encoding, error: {:?}",
125 self.tag(),
126 err
127 );
128 self.0.io_stopped(Some(err));
129 Ok(())
130 })
131 } else {
132 log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
133 Ok(())
134 }
135 }
136
137 #[inline]
138 pub fn decode<U>(
140 &self,
141 codec: &U,
142 ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
143 where
144 U: Decoder,
145 {
146 self.0
147 .buffer
148 .with_read_destination(self, |buf| codec.decode_vec(buf))
149 }
150
151 #[inline]
152 pub fn decode_item<U>(
154 &self,
155 codec: &U,
156 ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
157 where
158 U: Decoder,
159 {
160 self.0.buffer.with_read_destination(self, |buf| {
161 let len = buf.len();
162 codec.decode_vec(buf).map(|item| Decoded {
163 item,
164 remains: buf.len(),
165 consumed: len - buf.len(),
166 })
167 })
168 }
169
170 #[inline]
171 pub fn write(&self, src: &[u8]) -> io::Result<()> {
173 self.with_write_buf(|buf| buf.extend_from_slice(src))
174 }
175
176 #[inline]
177 pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
179 where
180 F: FnOnce(&WriteBuf<'_>) -> R,
181 {
182 let result = self.0.buffer.write_buf(self, 0, f);
183 self.0.filter().process_write_buf(self, &self.0.buffer, 0)?;
184 Ok(result)
185 }
186
187 #[inline]
188 pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
190 where
191 F: FnOnce(&mut BytesVec) -> R,
192 {
193 if self.0.flags.get().contains(Flags::IO_STOPPED) {
194 Err(self.0.error_or_disconnected())
195 } else {
196 let result = self.0.buffer.with_write_source(self, f);
197 self.0.filter().process_write_buf(self, &self.0.buffer, 0)?;
198 Ok(result)
199 }
200 }
201
202 #[doc(hidden)]
203 #[inline]
204 pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
206 where
207 F: FnOnce(Option<&mut BytesVec>) -> R,
208 {
209 self.0.buffer.with_write_destination(self, f)
210 }
211
212 #[inline]
213 pub fn with_read_buf<F, R>(&self, f: F) -> R
215 where
216 F: FnOnce(&mut BytesVec) -> R,
217 {
218 self.0.buffer.with_read_destination(self, f)
219 }
220
221 #[inline]
222 pub fn notify_dispatcher(&self) {
224 self.0.dispatch_task.wake();
225 log::trace!("{}: Timer, notify dispatcher", self.tag());
226 }
227
228 #[inline]
229 pub fn notify_timeout(&self) {
231 self.0.notify_timeout()
232 }
233
234 #[inline]
235 pub fn timer_handle(&self) -> timer::TimerHandle {
237 self.0.timeout.get()
238 }
239
240 #[inline]
241 pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
243 let cur_hnd = self.0.timeout.get();
244
245 if !timeout.is_zero() {
246 if cur_hnd.is_set() {
247 let hnd = timer::update(cur_hnd, timeout, self);
248 if hnd != cur_hnd {
249 log::debug!("{}: Update timer {:?}", self.tag(), timeout);
250 self.0.timeout.set(hnd);
251 }
252 hnd
253 } else {
254 log::debug!("{}: Start timer {:?}", self.tag(), timeout);
255 let hnd = timer::register(timeout, self);
256 self.0.timeout.set(hnd);
257 hnd
258 }
259 } else {
260 if cur_hnd.is_set() {
261 self.0.timeout.set(timer::TimerHandle::ZERO);
262 timer::unregister(cur_hnd, self);
263 }
264 timer::TimerHandle::ZERO
265 }
266 }
267
268 #[inline]
269 pub fn stop_timer(&self) {
271 let hnd = self.0.timeout.get();
272 if hnd.is_set() {
273 log::debug!("{}: Stop timer", self.tag());
274 self.0.timeout.set(timer::TimerHandle::ZERO);
275 timer::unregister(hnd, self)
276 }
277 }
278
279 #[inline]
280 pub fn tag(&self) -> &'static str {
282 self.0.tag.get()
283 }
284
285 #[inline]
286 pub fn set_tag(&self, tag: &'static str) {
288 self.0.tag.set(tag)
289 }
290
291 #[inline]
292 pub fn on_disconnect(&self) -> OnDisconnect {
294 OnDisconnect::new(self.0.clone())
295 }
296}
297
298impl Eq for IoRef {}
299
300impl PartialEq for IoRef {
301 #[inline]
302 fn eq(&self, other: &Self) -> bool {
303 self.0.eq(&other.0)
304 }
305}
306
307impl hash::Hash for IoRef {
308 #[inline]
309 fn hash<H: hash::Hasher>(&self, state: &mut H) {
310 self.0.hash(state);
311 }
312}
313
314impl fmt::Debug for IoRef {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 f.debug_struct("IoRef")
317 .field("state", self.0.as_ref())
318 .finish()
319 }
320}
321
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::poll_fn, future::Future, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{sleep, Millis};

    use super::*;
    use crate::{testing::IoTest, FilterLayer, Io, ReadBuf};

    /// Sample request payload as bytes; `TEXT` is the same payload as `&str`.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Exercises the basic `IoRef` helpers: equality, `Debug` output,
    /// receiving, read/write errors, `write`, `decode` and `force_close`.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::new(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{:?}", state).contains("Io {"));
        assert!(format!("{:?}", state.get_ref()).contains("IoRef {"));

        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        // NOTE(review): if the poll is still pending, this check is silently skipped.
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // a read error surfaces from `recv` and stops the io
        client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);

        client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        // NOTE(review): if the poll is still pending, these checks are silently skipped.
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // a write error surfaces from `send` and stops the io
        client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    /// Read readiness becomes ready when data arrives and pending again after
    /// it has been observed.
    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    /// `on_disconnect` futures (and their clones) resolve on peer close and on
    /// read errors; a fresh future after disconnect is immediately ready.
    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
        assert_eq!(waiter.await, ());
    }

    /// Writes to an io whose peer has closed must fail.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(state
            .with_write_buf(|buf| buf.extend_from_slice(BIN))
            .is_err());
    }

    /// Pass-through filter that counts bytes and records the order in which
    /// filter layers are invoked.
    #[derive(Debug)]
    struct Counter {
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl FilterLayer for Counter {
        const BUFFERS: bool = false;

        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + buf.nbytes());
            Ok(buf.nbytes())
        }

        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes
                .set(self.out_bytes.get() + buf.with_dst(|b| b.len()));
            Ok(())
        }
    }

    /// A single filter layer sees all inbound and outbound bytes.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let counter = Counter {
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        };
        let _ = format!("{:?}", counter);
        let io = Io::new(server).add_filter(counter);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    /// Two stacked filters on a sealed io: both see all bytes, and dropping the
    /// io releases the filters.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::new(server)
            .add_filter(Counter {
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .add_filter(Counter {
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        // one handle held by the test plus one clone per Counter layer
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        // reads pass through layer 1 first; writes pass through layer 2 first
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[2, 1][..]);
    }
}