use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};

use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};

use crate::{AsyncRead, AsyncWrite, FilterCtx, Flags, IoRef, IoTaskStatus, Readiness};

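/// Handle of the read half of an io object; used by the read io task
/// to pull bytes from the underlying stream into the io read buffer.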
pub struct ReadContext(IoRef, Cell<Option<Sleep>>);

impl fmt::Debug for ReadContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ReadContext").field("io", &self.0).finish()
    }
}

impl ReadContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone(), Cell::new(None))
    }

    #[doc(hidden)]
    #[inline]
    pub fn context(&self) -> IoContext {
        IoContext::new(&self.0)
    }

    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    // wait until the io object gets stopped or starts stopping
    async fn wait_for_close(&self) {
        poll_fn(|cx| {
            let flags = self.0.flags();

            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
                Poll::Ready(())
            } else {
                self.0 .0.read_task.register(cx.waker());
                if flags.contains(Flags::IO_STOPPING_FILTERS) {
                    self.shutdown_filters(cx);
                }
                Poll::Pending
            }
        })
        .await
    }

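    /// Drive the read side of an io stream.
    ///
    /// Runs until the filter reports `Readiness::Terminate`, the peer
    /// disconnects (zero-byte read) or an io error occurs. A minimal
    /// sketch of a transport driver, assuming a `stream` type that
    /// implements `AsyncRead` (hypothetical, not part of this module):
    ///
    /// ```ignore
    /// let read = ReadContext::new(&io_ref);
    /// // runs as a separate task, alongside the write task
    /// read.handle(&mut stream).await;
    /// ```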
    pub async fn handle<T>(&self, io: &mut T)
    where
        T: AsyncRead,
    {
        let inner = &self.0 .0;

        loop {
            let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
            if result == Readiness::Terminate {
                log::trace!("{}: Read task is instructed to shutdown", self.tag());
                break;
            }

            // the dispatcher still owns the read buffer, allocate a fresh one
            let mut buf = if inner.flags.get().is_read_buf_ready() {
                inner.pool.get().get_read_buf()
            } else {
                inner
                    .buffer
                    .get_read_source()
                    .unwrap_or_else(|| inner.pool.get().get_read_buf())
            };

            // make sure the buffer has enough free space
            let (hw, lw) = self.0.memory_pool().read_params().unpack();
            let remaining = buf.remaining_mut();
            if remaining <= lw {
                buf.reserve(hw - remaining);
            }
            let total = buf.len();

            let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
                Either::Left(res) => res,
                Either::Right(_) => {
                    log::trace!("{}: Read io is closed, stop read task", self.tag());
                    break;
                }
            };

            let total2 = buf.len();
            let nbytes = total2.saturating_sub(total);
            let total = total2;

            // return the buffer as the read source, merging it with any
            // buffer that has been installed in the meantime
            if let Some(mut first_buf) = inner.buffer.get_read_source() {
                first_buf.extend_from_slice(&buf);
                inner.buffer.set_read_source(&self.0, first_buf);
            } else {
                inner.buffer.set_read_source(&self.0, buf);
            }

            if nbytes > 0 {
                let filter = self.0.filter();
                let res = match filter
                    .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
                {
                    Ok(status) => {
                        if status.nbytes > 0 {
                            // read buffer is too large, enable read back-pressure
                            if hw < inner.buffer.read_destination_size() {
                                log::trace!(
                                    "{}: Io read buffer is too large {}, enable read back-pressure",
                                    self.0.tag(),
                                    total
                                );
                                inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                            } else {
                                inner.insert_flags(Flags::BUF_R_READY);
                            }
                            log::trace!(
                                "{}: New {} bytes available, wakeup dispatcher",
                                self.0.tag(),
                                nbytes
                            );
                            inner.dispatch_task.wake();
                        } else if inner.flags.get().is_waiting_for_read() {
                            inner.dispatch_task.wake();
                        }

                        if status.need_write {
                            filter.process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
                        } else {
                            Ok(())
                        }
                    }
                    Err(err) => Err(err),
                };

                if let Err(err) = res {
                    inner.dispatch_task.wake();
                    inner.io_stopped(Some(err));
                    inner.insert_flags(Flags::BUF_R_READY);
                }
            }

            match result {
                Ok(0) => {
                    log::trace!("{}: Tcp stream is disconnected", self.tag());
                    inner.io_stopped(None);
                    break;
                }
                Ok(_) => {
                    if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
                        lazy(|cx| self.shutdown_filters(cx)).await;
                    }
                }
                Err(err) => {
                    log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
                    inner.io_stopped(Some(err));
                    break;
                }
            }
        }
    }

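    // Graceful filter shutdown for the read task: polls the filter chain
    // shutdown and arms a disconnect timeout while it is pending.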
    fn shutdown_filters(&self, cx: &mut Context<'_>) {
        let st = &self.0 .0;
        let filter = self.0.filter();

        match filter.shutdown(FilterCtx::new(&self.0, &st.buffer)) {
            Ok(Poll::Ready(())) => {
                st.dispatch_task.wake();
                st.insert_flags(Flags::IO_STOPPING);
            }
            Ok(Poll::Pending) => {
                let flags = st.flags.get();

                // we have to stop the io task if the dispatcher
                // is not able to process incoming bytes
                if flags.contains(Flags::RD_PAUSED)
                    || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                {
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                } else {
                    // filter shutdown timeout
                    let timeout = self
                        .1
                        .take()
                        .unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
                    if timeout.poll_elapsed(cx).is_ready() {
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    } else {
                        self.1.set(Some(timeout));
                    }
                }
            }
            Err(err) => {
                st.io_stopped(Some(err));
            }
        }
        if let Err(err) = filter.process_write_buf(FilterCtx::new(&self.0, &st.buffer)) {
            st.io_stopped(Some(err));
        }
    }
}

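/// Handle of the write half of an io object; used by the write io task
/// to flush the io write buffer to the underlying stream.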
#[derive(Debug)]
pub struct WriteContext(IoRef);

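/// Buffer handed to `AsyncWrite::write()`; it lends out buffered write
/// data and accepts back whatever could not be written.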
#[derive(Debug)]
pub struct WriteContextBuf {
    io: IoRef,
    buf: Option<BytesVec>,
}

impl WriteContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone())
    }

    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    async fn ready(&self) -> Readiness {
        poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
    }

    fn close(&self, err: Option<io::Error>) {
        self.0 .0.io_stopped(err);
    }

    async fn when_stopped(&self) {
        poll_fn(|cx| {
            if self.0.flags().is_stopped() {
                Poll::Ready(())
            } else {
                self.0 .0.write_task.register(cx.waker());
                Poll::Pending
            }
        })
        .await
    }

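    /// Drive the write side of an io stream.
    ///
    /// Waits for write readiness and flushes buffered data; on
    /// `Readiness::Shutdown` it flushes and shuts the stream down within
    /// the configured disconnect timeout. A minimal sketch, assuming a
    /// `stream` type that implements `AsyncWrite` (hypothetical):
    ///
    /// ```ignore
    /// let write = WriteContext::new(&io_ref);
    /// // runs as a separate task, alongside the read task
    /// write.handle(&mut stream).await;
    /// ```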
    pub async fn handle<T>(&self, io: &mut T)
    where
        T: AsyncWrite,
    {
        let mut buf = WriteContextBuf {
            io: self.0.clone(),
            buf: None,
        };

        loop {
            match self.ready().await {
                Readiness::Ready => {
                    // write to io stream until it gets stopped
                    match select(io.write(&mut buf), self.when_stopped()).await {
                        Either::Left(Ok(_)) => continue,
                        Either::Left(Err(e)) => self.close(Some(e)),
                        Either::Right(_) => return,
                    }
                }
                Readiness::Shutdown => {
                    log::trace!("{}: Write task is instructed to shutdown", self.tag());

                    // flush remaining data and shutdown the stream,
                    // bounded by the disconnect timeout
                    let fut = async {
                        io.write(&mut buf).await?;
                        io.flush().await?;
                        io.shutdown().await?;
                        Ok(())
                    };
                    match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
                        Either::Left(_) => self.close(None),
                        Either::Right(res) => self.close(res.err()),
                    }
                }
                Readiness::Terminate => {
                    log::trace!("{}: Write task is instructed to terminate", self.tag());
                    self.close(io.shutdown().await.err());
                }
            }
            return;
        }
    }
}

impl WriteContextBuf {
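    /// Store a buffer returned by the io stream after a write attempt.
    ///
    /// An empty buffer is released back to the memory pool; otherwise the
    /// remaining bytes are merged with any data buffered in the meantime,
    /// and write back-pressure flags are updated from the resulting size.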
    pub fn set(&mut self, mut buf: BytesVec) {
        if buf.is_empty() {
            self.io.memory_pool().release_write_buf(buf);
        } else if let Some(b) = self.buf.take() {
            buf.extend_from_slice(&b);
            self.io.memory_pool().release_write_buf(b);
            self.buf = Some(buf);
        } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
            // write destination is already set, keep the rest locally
            self.buf = Some(b);
        }

        let inner = &self.io.0;
        let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
            + inner.buffer.write_destination_size();
        let mut flags = inner.flags.get();

        if len == 0 {
            if flags.is_waiting_for_write() {
                flags.waiting_for_write_is_done();
                inner.dispatch_task.wake();
            }
            flags.insert(Flags::WR_PAUSED);
            inner.flags.set(flags);
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE)
            && len < inner.pool.get().write_params_high() << 1
        {
            // write buffer is smaller than high watermark, disable back-pressure
            flags.remove(Flags::BUF_W_BACKPRESSURE);
            inner.flags.set(flags);
            inner.dispatch_task.wake();
        }
    }

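    /// Take the next chunk of data to write, if any.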
    pub fn take(&mut self) -> Option<BytesVec> {
        if let Some(buf) = self.buf.take() {
            Some(buf)
        } else {
            self.io.0.buffer.get_write_destination()
        }
    }
}

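/// Low-level io context, used by external io task implementations
/// (e.g. poll- or uring-based drivers) to exchange buffers with the
/// io object directly.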
pub struct IoContext(IoRef);

impl fmt::Debug for IoContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IoContext").field("io", &self.0).finish()
    }
}

impl IoContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone())
    }

    #[doc(hidden)]
    #[inline]
    pub fn id(&self) -> usize {
        self.0 .0.as_ref() as *const _ as usize
    }

    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    #[doc(hidden)]
    pub fn flags(&self) -> crate::flags::Flags {
        self.0.flags()
    }

    /// Check readiness for read operations
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.shutdown_filters();
        self.0.filter().poll_read_ready(cx)
    }

    /// Check readiness for write operations
    #[inline]
    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.0.filter().poll_write_ready(cx)
    }

    /// Stop io tasks
    #[inline]
    pub fn stop(&self, e: Option<io::Error>) {
        self.0 .0.io_stopped(e);
    }

    #[deprecated(since = "2.14.1")]
    #[doc(hidden)]
    #[inline]
    pub fn init_shutdown(&self) {
        self.0 .0.init_shutdown();
    }

    /// Check if io is stopped
    #[inline]
    pub fn is_stopped(&self) -> bool {
        self.0.flags().is_stopped()
    }

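    /// Poll shutdown progress of the io stream; returns `Poll::Ready(())`
    /// once the stream is stopped (or, with `flush` set, once pending
    /// write data has been flushed first).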
    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
        let st = &self.0 .0;

        let flags = self.0.flags();
        if !flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
            st.write_task.register(cx.waker());
            return Poll::Pending;
        }

        if flush && !flags.contains(Flags::IO_STOPPED) {
            if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
                return Poll::Ready(());
            }
            st.insert_flags(Flags::WR_TASK_WAIT);
            st.write_task.register(cx.waker());
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }

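    /// Get a read buffer to fill from the io stream; the buffer must be
    /// returned with `release_read_buf()`.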
    pub fn get_read_buf(&self) -> BytesVec {
        let inner = &self.0 .0;

        if inner.flags.get().is_read_buf_ready() {
            // read buffer is still not processed by the dispatcher,
            // we cannot touch it
            inner.pool.get().get_read_buf()
        } else {
            inner
                .buffer
                .get_read_source()
                .unwrap_or_else(|| inner.pool.get().get_read_buf())
        }
    }

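    /// Return a buffer obtained from `get_read_buf()` together with the
    /// result of the read operation, and report whether the io task should
    /// keep reading (`IoTaskStatus::Io`) or pause (`IoTaskStatus::Pause`).
    /// A minimal driver-loop sketch, assuming a hypothetical non-blocking
    /// `poll_read` helper that appends to the buffer:
    ///
    /// ```ignore
    /// let mut buf = ctx.get_read_buf();
    /// let before = buf.len();
    /// let result = poll_read(&mut buf); // hypothetical: Poll<io::Result<()>>
    /// let nbytes = buf.len() - before;
    /// match ctx.release_read_buf(nbytes, buf, result) {
    ///     IoTaskStatus::Io => { /* keep reading */ }
    ///     IoTaskStatus::Pause => { /* wait for read readiness */ }
    /// }
    /// ```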
    pub fn release_read_buf(
        &self,
        nbytes: usize,
        buf: BytesVec,
        result: Poll<io::Result<()>>,
    ) -> IoTaskStatus {
        let inner = &self.0 .0;
        let orig_size = inner.buffer.read_destination_size();
        let hw = self.0.memory_pool().read_params().unpack().0;

        // return the buffer as the read source, merging it with any
        // buffer that has been installed in the meantime
        if let Some(mut first_buf) = inner.buffer.get_read_source() {
            first_buf.extend_from_slice(&buf);
            inner.buffer.set_read_source(&self.0, first_buf);
        } else {
            inner.buffer.set_read_source(&self.0, buf);
        }

        let mut full = false;

        let st_res = if nbytes > 0 {
            match self
                .0
                .filter()
                .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
            {
                Ok(status) => {
                    let buffer_size = inner.buffer.read_destination_size();
                    if buffer_size.saturating_sub(orig_size) > 0 {
                        // read buffer is too large, enable read back-pressure
                        if buffer_size >= hw {
                            log::trace!(
                                "{}: Io read buffer is too large {}, enable read back-pressure",
                                self.tag(),
                                buffer_size
                            );
                            full = true;
                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                        } else {
                            inner.insert_flags(Flags::BUF_R_READY);
                        }
                        log::trace!(
                            "{}: New {} bytes available, wakeup dispatcher",
                            self.tag(),
                            buffer_size
                        );
                        inner.dispatch_task.wake();
                    } else {
                        if buffer_size >= hw {
                            full = true;
                            inner.read_task.wake();
                        }
                        if inner.flags.get().is_waiting_for_read() {
                            inner.dispatch_task.wake();
                        }
                    }

                    if status.need_write {
                        self.0
                            .filter()
                            .process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
                    } else {
                        Ok(())
                    }
                }
                Err(err) => Err(err),
            }
        } else {
            Ok(())
        };

        match result {
            Poll::Ready(Ok(_)) => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else if nbytes == 0 {
                    // zero-byte read on a ready io stream means disconnect
                    inner.io_stopped(None);
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    if full {
                        IoTaskStatus::Pause
                    } else {
                        IoTaskStatus::Io
                    }
                }
            }
            Poll::Ready(Err(e)) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
            Poll::Pending => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    if full {
                        IoTaskStatus::Pause
                    } else {
                        IoTaskStatus::Io
                    }
                }
            }
        }
    }

    /// Get write buffer with data to flush, if any.
    pub fn get_write_buf(&self) -> Option<BytesVec> {
        self.0 .0.buffer.get_write_destination().and_then(|buf| {
            if buf.is_empty() {
                None
            } else {
                Some(buf)
            }
        })
    }

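    /// Return a buffer obtained from `get_write_buf()` together with the
    /// result of the write operation (number of bytes written); data that
    /// was not written is kept for the next write attempt. Returns whether
    /// the io task should keep writing or pause.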
    pub fn release_write_buf(
        &self,
        mut buf: BytesVec,
        result: Poll<io::Result<usize>>,
    ) -> IoTaskStatus {
        let result = match result {
            Poll::Ready(Ok(0)) => {
                log::trace!("{}: Disconnected during flush", self.tag());
                Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                ))
            }
            Poll::Ready(Ok(n)) => {
                if n == buf.len() {
                    buf.clear();
                    Ok(0)
                } else {
                    buf.advance(n);
                    Ok(buf.len())
                }
            }
            Poll::Ready(Err(e)) => Err(e),
            Poll::Pending => Ok(buf.len()),
        };

        let inner = &self.0 .0;

        // release the io write buffer, or merge not-yet-written data
        // with the current write destination buffer
        let result = match result {
            Ok(0) => {
                self.0.memory_pool().release_write_buf(buf);
                Ok(inner.buffer.write_destination_size())
            }
            Ok(_) => {
                if let Some(b) = inner.buffer.get_write_destination() {
                    buf.extend_from_slice(&b);
                    self.0.memory_pool().release_write_buf(b);
                }
                let l = buf.len();
                inner.buffer.set_write_destination(buf);
                Ok(l)
            }
            Err(e) => Err(e),
        };

        match result {
            Ok(0) => {
                // all data has been written
                let mut flags = inner.flags.get();

                flags.insert(Flags::WR_PAUSED);

                if flags.is_task_waiting_for_write() {
                    flags.task_waiting_for_write_is_done();
                    inner.write_task.wake();
                }

                if flags.is_waiting_for_write() {
                    flags.waiting_for_write_is_done();
                    inner.dispatch_task.wake();
                }
                inner.flags.set(flags);
                IoTaskStatus::Pause
            }
            Ok(len) => {
                // write buffer is smaller than high watermark, disable back-pressure
                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
                    && len < inner.pool.get().write_params_high() << 1
                {
                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
                    inner.dispatch_task.wake();
                }
                if self.is_stopped() {
                    IoTaskStatus::Pause
                } else {
                    IoTaskStatus::Io
                }
            }
            Err(e) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
        }
    }

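    // Graceful filter shutdown, driven from `poll_read_ready()`; unlike
    // `ReadContext::shutdown_filters` it does not arm a timeout.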
    fn shutdown_filters(&self) {
        let io = &self.0;
        let st = &self.0 .0;
        let flags = st.flags.get();
        if flags.contains(Flags::IO_STOPPING_FILTERS)
            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
        {
            match io.filter().shutdown(FilterCtx::new(io, &st.buffer)) {
                Ok(Poll::Ready(())) => {
                    st.write_task.wake();
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                }
                Ok(Poll::Pending) => {
                    // we have to stop the io task if the dispatcher
                    // is not able to process incoming bytes
                    let flags = st.flags.get();
                    if flags.contains(Flags::RD_PAUSED)
                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                    {
                        st.write_task.wake();
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    }
                }
                Err(err) => {
                    st.io_stopped(Some(err));
                }
            }
            if let Err(err) = io
                .filter()
                .process_write_buf(FilterCtx::new(io, &st.buffer))
            {
                st.io_stopped(Some(err));
            }
        }
    }
}

impl Clone for IoContext {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}