use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};

use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};

use crate::{AsyncRead, AsyncWrite, FilterCtx, Flags, IoRef, IoTaskStatus, Readiness};

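/// Context for the read io task.
///
/// Drives the filter chain for incoming data and caches the disconnect
/// timeout used while filters are shutting down.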
pub struct ReadContext(IoRef, Cell<Option<Sleep>>);

impl fmt::Debug for ReadContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ReadContext").field("io", &self.0).finish()
    }
}

impl ReadContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone(), Cell::new(None))
    }

    #[doc(hidden)]
    #[inline]
    pub fn context(&self) -> IoContext {
        IoContext::new(&self.0)
    }

    /// Io tag
    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    /// Wait until io is closed or filters request shutdown.
    async fn wait_for_close(&self) {
        poll_fn(|cx| {
            let flags = self.0.flags();

            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
                Poll::Ready(())
            } else {
                self.0 .0.read_task.register(cx.waker());
                if flags.contains(Flags::IO_STOPPING_FILTERS) {
                    self.shutdown_filters(cx);
                }
                Poll::Pending
            }
        })
        .await
    }

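    /// Drive reads from the io stream.
    ///
    /// Waits for read readiness, reads into a pooled buffer, runs the
    /// filter chain over new bytes and wakes the dispatcher. Completes
    /// when the io stream is closed, fails, or is asked to terminate.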
    pub async fn handle<T>(&self, io: &mut T)
    where
        T: AsyncRead,
    {
        let inner = &self.0 .0;

        loop {
            let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
            if result == Readiness::Terminate {
                log::trace!("{}: Read task is instructed to shutdown", self.tag());
                break;
            }

            let mut buf = if inner.flags.get().is_read_buf_ready() {
                // read buffer is still not processed by dispatcher,
                // so it cannot be touched; use a new buffer
                inner.pool.get().get_read_buf()
            } else {
                inner
                    .buffer
                    .get_read_source()
                    .unwrap_or_else(|| inner.pool.get().get_read_buf())
            };

            // make sure the read buffer has enough free space
            let (hw, lw) = self.0.memory_pool().read_params().unpack();
            let remaining = buf.remaining_mut();
            if remaining <= lw {
                buf.reserve(hw - remaining);
            }
            let total = buf.len();

            let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
                Either::Left(res) => res,
                Either::Right(_) => {
                    log::trace!("{}: Read io is closed, stop read task", self.tag());
                    break;
                }
            };

            let total2 = buf.len();
            let nbytes = total2.saturating_sub(total);
            let total = total2;

            if let Some(mut first_buf) = inner.buffer.get_read_source() {
                first_buf.extend_from_slice(&buf);
                inner.buffer.set_read_source(&self.0, first_buf);
            } else {
                inner.buffer.set_read_source(&self.0, buf);
            }

            if nbytes > 0 {
                let filter = self.0.filter();
                let res = match filter
                    .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
                {
                    Ok(status) => {
                        if status.nbytes > 0 {
                            // check read back-pressure
                            if hw < inner.buffer.read_destination_size() {
                                log::trace!(
                                    "{}: Io read buffer is too large {}, enable read back-pressure",
                                    self.0.tag(),
                                    total
                                );
                                inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                            } else {
                                inner.insert_flags(Flags::BUF_R_READY);
                            }
                            log::trace!(
                                "{}: New {} bytes available, wakeup dispatcher",
                                self.0.tag(),
                                nbytes
                            );
                            inner.dispatch_task.wake();
                        } else if inner.flags.get().is_waiting_for_read() {
                            // dispatcher is waiting for a read event,
                            // wake it even if filters produced no new bytes
                            inner.dispatch_task.wake();
                        }

                        // filters may have written data while processing
                        // incoming bytes; flush the write buffer
                        if status.need_write {
                            filter.process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
                        } else {
                            Ok(())
                        }
                    }
                    Err(err) => Err(err),
                };

                if let Err(err) = res {
                    inner.dispatch_task.wake();
                    inner.io_stopped(Some(err));
                    inner.insert_flags(Flags::BUF_R_READY);
                }
            }

            match result {
                Ok(0) => {
                    log::trace!("{}: Tcp stream is disconnected", self.tag());
                    inner.io_stopped(None);
                    break;
                }
                Ok(_) => {
                    if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
                        lazy(|cx| self.shutdown_filters(cx)).await;
                    }
                }
                Err(err) => {
                    log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
                    inner.io_stopped(Some(err));
                    break;
                }
            }
        }
    }

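    /// Poll filters shutdown, arming the disconnect timeout while the
    /// shutdown is in progress.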
    fn shutdown_filters(&self, cx: &mut Context<'_>) {
        let st = &self.0 .0;
        let filter = self.0.filter();

        match filter.shutdown(FilterCtx::new(&self.0, &st.buffer)) {
            Ok(Poll::Ready(())) => {
                st.dispatch_task.wake();
                st.insert_flags(Flags::IO_STOPPING);
            }
            Ok(Poll::Pending) => {
                let flags = st.flags.get();

                // if read buffer is full, wake up dispatcher to handle
                // already received data, filters cannot proceed otherwise
                if flags.contains(Flags::RD_PAUSED)
                    || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                {
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                } else {
                    // shutdown is not complete, enforce the disconnect
                    // timeout before stopping the io
                    let timeout = self
                        .1
                        .take()
                        .unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
                    if timeout.poll_elapsed(cx).is_ready() {
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    } else {
                        self.1.set(Some(timeout));
                    }
                }
            }
            Err(err) => {
                st.io_stopped(Some(err));
            }
        }
        if let Err(err) = filter.process_write_buf(FilterCtx::new(&self.0, &st.buffer)) {
            st.io_stopped(Some(err));
        }
    }
}

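/// Context for the write io task.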
#[derive(Debug)]
pub struct WriteContext(IoRef);

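/// Write buffer handed to `AsyncWrite::write()`; combines a locally
/// cached buffer with the shared write destination buffer.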
#[derive(Debug)]
pub struct WriteContextBuf {
    io: IoRef,
    buf: Option<BytesVec>,
}

impl WriteContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone())
    }

    /// Io tag
    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    /// Wait for write readiness from the filter chain.
    async fn ready(&self) -> Readiness {
        poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
    }

    /// Stop io tasks, optionally with an error.
    fn close(&self, err: Option<io::Error>) {
        self.0 .0.io_stopped(err);
    }

    /// Wait until io is fully stopped.
    async fn when_stopped(&self) {
        poll_fn(|cx| {
            if self.0.flags().is_stopped() {
                Poll::Ready(())
            } else {
                self.0 .0.write_task.register(cx.waker());
                Poll::Pending
            }
        })
        .await
    }

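    /// Drive writes to the io stream.
    ///
    /// Waits for write readiness and flushes buffered data; on shutdown,
    /// flushes and closes the stream within the disconnect timeout.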
    pub async fn handle<T>(&self, io: &mut T)
    where
        T: AsyncWrite,
    {
        let mut buf = WriteContextBuf {
            io: self.0.clone(),
            buf: None,
        };

        loop {
            match self.ready().await {
                Readiness::Ready => {
                    // write io stream
                    match select(io.write(&mut buf), self.when_stopped()).await {
                        Either::Left(Ok(_)) => continue,
                        Either::Left(Err(e)) => self.close(Some(e)),
                        Either::Right(_) => return,
                    }
                }
                Readiness::Shutdown => {
                    log::trace!("{}: Write task is instructed to shutdown", self.tag());

                    let fut = async {
                        // write remaining data, then flush and shutdown the stream
                        io.write(&mut buf).await?;
                        io.flush().await?;
                        io.shutdown().await?;
                        Ok(())
                    };
                    match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
                        Either::Left(_) => self.close(None),
                        Either::Right(res) => self.close(res.err()),
                    }
                }
                Readiness::Terminate => {
                    log::trace!("{}: Write task is instructed to terminate", self.tag());
                    self.close(io.shutdown().await.err());
                }
            }
            return;
        }
    }
}

impl WriteContextBuf {
    /// Put the write buffer back after a write attempt.
    ///
    /// Empty buffers are released to the pool; otherwise the buffer is
    /// stored (merged with any locally cached data) for the next write.
    pub fn set(&mut self, mut buf: BytesVec) {
        if buf.is_empty() {
            self.io.memory_pool().release_write_buf(buf);
        } else if let Some(b) = self.buf.take() {
            buf.extend_from_slice(&b);
            self.io.memory_pool().release_write_buf(b);
            self.buf = Some(buf);
        } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
            // write buffer is already set
            self.buf = Some(b);
        }

        // check write back-pressure and waiters
        let inner = &self.io.0;
        let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
            + inner.buffer.write_destination_size();
        let mut flags = inner.flags.get();

        if len == 0 {
            if flags.is_waiting_for_write() {
                flags.waiting_for_write_is_done();
                inner.dispatch_task.wake();
            }
            flags.insert(Flags::WR_PAUSED);
            inner.flags.set(flags);
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE)
            && len < inner.pool.get().write_params_high() << 1
        {
            flags.remove(Flags::BUF_W_BACKPRESSURE);
            inner.flags.set(flags);
            inner.dispatch_task.wake();
        }
    }

    /// Take buffered data for writing to the io stream.
    pub fn take(&mut self) -> Option<BytesVec> {
        if let Some(buf) = self.buf.take() {
            Some(buf)
        } else {
            self.io.0.buffer.get_write_destination()
        }
    }
}

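/// Low-level access to the io machinery, intended for custom io task
/// implementations (e.g. completion-based drivers that cannot go
/// through `AsyncRead`/`AsyncWrite`).
///
/// A minimal sketch of a read loop, assuming a hypothetical `os_read()`
/// primitive that appends bytes to the buffer and reports the result
/// (the name is illustrative, not part of this crate):
///
/// ```ignore
/// async fn drive_read(ctx: IoContext) {
///     loop {
///         // wait until filters allow reading
///         if poll_fn(|cx| ctx.poll_read_ready(cx)).await == Readiness::Terminate {
///             break;
///         }
///         let mut buf = ctx.get_read_buf();
///         let before = buf.len();
///         let result = os_read(&mut buf); // hypothetical driver call
///         let nbytes = buf.len() - before;
///         if matches!(ctx.release_read_buf(nbytes, buf, result), IoTaskStatus::Pause) {
///             // stopped or back-pressure; a real driver would wait
///             // for readiness again instead of exiting
///             break;
///         }
///     }
/// }
/// ```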
pub struct IoContext(IoRef);

impl fmt::Debug for IoContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IoContext").field("io", &self.0).finish()
    }
}

impl IoContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone())
    }

    #[doc(hidden)]
    #[inline]
    pub fn id(&self) -> usize {
        self.0 .0.as_ref() as *const _ as usize
    }

    /// Io tag
    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    /// Current io flags
    #[doc(hidden)]
    pub fn flags(&self) -> crate::flags::Flags {
        self.0.flags()
    }

    /// Check readiness for read operations
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.shutdown_filters();
        self.0.filter().poll_read_ready(cx)
    }

    /// Check readiness for write operations
    #[inline]
    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.0.filter().poll_write_ready(cx)
    }

    /// Stop io tasks, optionally reporting an io error
    #[inline]
    pub fn stop(&self, e: Option<io::Error>) {
        self.0 .0.io_stopped(e);
    }

    /// Initiate graceful io shutdown
    #[inline]
    pub fn init_shutdown(&self) {
        self.0 .0.init_shutdown();
    }

    /// Check if io is stopped
    #[inline]
    pub fn is_stopped(&self) -> bool {
        self.0.flags().is_stopped()
    }

    /// Wait for io stream shutdown; if `flush` is set, wait until the
    /// write buffer is fully flushed as well
    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
        let st = &self.0 .0;

        let flags = self.0.flags();
        if !flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
            st.write_task.register(cx.waker());
            return Poll::Pending;
        }

        if flush && !flags.contains(Flags::IO_STOPPED) {
            if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
                return Poll::Ready(());
            }
            st.insert_flags(Flags::WR_TASK_WAIT);
            st.write_task.register(cx.waker());
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }

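    /// Get a read buffer for incoming data.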
    pub fn get_read_buf(&self) -> BytesVec {
        let inner = &self.0 .0;

        if inner.flags.get().is_read_buf_ready() {
            // read buffer is still not processed by dispatcher,
            // so it cannot be touched; use a new buffer
            inner.pool.get().get_read_buf()
        } else {
            inner
                .buffer
                .get_read_source()
                .unwrap_or_else(|| inner.pool.get().get_read_buf())
        }
    }

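    /// Return the read buffer after an io read operation.
    ///
    /// `nbytes` is the number of newly read bytes appended to `buf`;
    /// `result` is the outcome of the read. Returns whether the read
    /// task should continue reading or pause.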
    pub fn release_read_buf(
        &self,
        nbytes: usize,
        buf: BytesVec,
        result: Poll<io::Result<()>>,
    ) -> IoTaskStatus {
        let inner = &self.0 .0;
        let orig_size = inner.buffer.read_destination_size();
        let hw = self.0.memory_pool().read_params().unpack().0;

        if let Some(mut first_buf) = inner.buffer.get_read_source() {
            first_buf.extend_from_slice(&buf);
            inner.buffer.set_read_source(&self.0, first_buf);
        } else {
            inner.buffer.set_read_source(&self.0, buf);
        }

        let mut full = false;

        let st_res = if nbytes > 0 {
            // handle incoming data, run it through the filter chain
            match self
                .0
                .filter()
                .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
            {
                Ok(status) => {
                    let buffer_size = inner.buffer.read_destination_size();
                    if buffer_size.saturating_sub(orig_size) > 0 {
                        // check read back-pressure
                        if buffer_size >= hw {
                            log::trace!(
                                "{}: Io read buffer is too large {}, enable read back-pressure",
                                self.tag(),
                                buffer_size
                            );
                            full = true;
                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                        } else {
                            inner.insert_flags(Flags::BUF_R_READY);
                        }
                        log::trace!(
                            "{}: New {} bytes available, wakeup dispatcher",
                            self.tag(),
                            buffer_size
                        );
                        inner.dispatch_task.wake();
                    } else {
                        if buffer_size >= hw {
                            // read buffer is full but filters produced
                            // no new bytes, wake up read task
                            full = true;
                            inner.read_task.wake();
                        }
                        if inner.flags.get().is_waiting_for_read() {
                            // dispatcher is waiting for a read event
                            inner.dispatch_task.wake();
                        }
                    }

                    // filters may have written data while processing
                    // incoming bytes; flush the write buffer
                    if status.need_write {
                        self.0
                            .filter()
                            .process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
                    } else {
                        Ok(())
                    }
                }
                Err(err) => Err(err),
            }
        } else {
            Ok(())
        };

        match result {
            Poll::Ready(Ok(_)) => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else if nbytes == 0 {
                    // zero-length read indicates the peer disconnected
                    inner.io_stopped(None);
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    if full {
                        IoTaskStatus::Pause
                    } else {
                        IoTaskStatus::Io
                    }
                }
            }
            Poll::Ready(Err(e)) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
            Poll::Pending => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    if full {
                        IoTaskStatus::Pause
                    } else {
                        IoTaskStatus::Io
                    }
                }
            }
        }
    }

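    /// Get the write buffer, if there is data to write.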
    pub fn get_write_buf(&self) -> Option<BytesVec> {
        self.0 .0.buffer.get_write_destination().and_then(|buf| {
            if buf.is_empty() {
                None
            } else {
                Some(buf)
            }
        })
    }

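    /// Return the write buffer after an io write operation.
    ///
    /// `result` is the number of bytes written from the front of `buf`.
    /// Returns whether the write task should continue or pause.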
    pub fn release_write_buf(
        &self,
        mut buf: BytesVec,
        result: Poll<io::Result<usize>>,
    ) -> IoTaskStatus {
        let result = match result {
            Poll::Ready(Ok(0)) => {
                log::trace!("{}: Disconnected during flush", self.tag());
                Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                ))
            }
            Poll::Ready(Ok(n)) => {
                if n == buf.len() {
                    buf.clear();
                    Ok(0)
                } else {
                    buf.advance(n);
                    Ok(buf.len())
                }
            }
            Poll::Ready(Err(e)) => Err(e),
            Poll::Pending => Ok(buf.len()),
        };

        let inner = &self.0 .0;

        // set write buffer back
        let result = match result {
            Ok(0) => {
                // everything has been written, release the buffer
                self.0.memory_pool().release_write_buf(buf);
                Ok(inner.buffer.write_destination_size())
            }
            Ok(_) => {
                // merge remaining data with any data buffered while
                // the write was in flight
                if let Some(b) = inner.buffer.get_write_destination() {
                    buf.extend_from_slice(&b);
                    self.0.memory_pool().release_write_buf(b);
                }
                let l = buf.len();
                inner.buffer.set_write_destination(buf);
                Ok(l)
            }
            Err(e) => Err(e),
        };

        match result {
            Ok(0) => {
                // all data has been written, pause the write task
                let mut flags = inner.flags.get();

                flags.insert(Flags::WR_PAUSED);

                if flags.is_task_waiting_for_write() {
                    flags.task_waiting_for_write_is_done();
                    inner.write_task.wake();
                }

                if flags.is_waiting_for_write() {
                    flags.waiting_for_write_is_done();
                    inner.dispatch_task.wake();
                }
                inner.flags.set(flags);
                IoTaskStatus::Pause
            }
            Ok(len) => {
                // if write buffer is smaller than the high watermark,
                // turn off write back-pressure
                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
                    && len < inner.pool.get().write_params_high() << 1
                {
                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
                    inner.dispatch_task.wake();
                }
                if self.is_stopped() {
                    IoTaskStatus::Pause
                } else {
                    IoTaskStatus::Io
                }
            }
            Err(e) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
        }
    }

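    /// Gracefully shutdown filters once the io is stopping.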
    fn shutdown_filters(&self) {
        let io = &self.0;
        let st = &self.0 .0;
        let flags = st.flags.get();
        if flags.contains(Flags::IO_STOPPING_FILTERS)
            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
        {
            match io.filter().shutdown(FilterCtx::new(io, &st.buffer)) {
                Ok(Poll::Ready(())) => {
                    st.write_task.wake();
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                }
                Ok(Poll::Pending) => {
                    // if read buffer is full, wake up dispatcher to handle
                    // already received data, filters cannot proceed otherwise
                    let flags = st.flags.get();
                    if flags.contains(Flags::RD_PAUSED)
                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                    {
                        st.write_task.wake();
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    }
                }
                Err(err) => {
                    st.io_stopped(Some(err));
                }
            }
            if let Err(err) = io
                .filter()
                .process_write_buf(FilterCtx::new(io, &st.buffer))
            {
                st.io_stopped(Some(err));
            }
        }
    }
}

impl Clone for IoContext {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}