use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};

use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};

use crate::{AsyncRead, AsyncWrite, Flags, IoRef, IoTaskStatus, Readiness};

pub struct ReadContext(IoRef, Cell<Option<Sleep>>);

impl fmt::Debug for ReadContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ReadContext").field("io", &self.0).finish()
    }
}

impl ReadContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone(), Cell::new(None))
    }

    #[doc(hidden)]
    #[inline]
    pub fn context(&self) -> IoContext {
        IoContext::new(&self.0)
    }

    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

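    // Wait until the io stream is stopped (or stopping); while filters are
    // being shut down, keep driving their shutdown on every wakeup.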
    async fn wait_for_close(&self) {
        poll_fn(|cx| {
            let flags = self.0.flags();

            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
                Poll::Ready(())
            } else {
                self.0 .0.read_task.register(cx.waker());
                if flags.contains(Flags::IO_STOPPING_FILTERS) {
                    self.shutdown_filters(cx);
                }
                Poll::Pending
            }
        })
        .await
    }

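    /// Drive the read half of the io stream.
    ///
    /// Waits for read readiness, reads from `io` into a pooled buffer,
    /// runs new bytes through the filter chain and wakes the dispatcher
    /// when data becomes available or the stream stops.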
    pub async fn handle<T>(&self, io: &mut T)
    where
        T: AsyncRead,
    {
        let inner = &self.0 .0;

        loop {
            let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
            if result == Readiness::Terminate {
                log::trace!("{}: Read task is instructed to shutdown", self.tag());
                break;
            }

            let mut buf = if inner.flags.get().is_read_buf_ready() {
                // the dispatcher has not consumed the current read buffer yet,
                // so read into a fresh buffer from the pool
                inner.pool.get().get_read_buf()
            } else {
                inner
                    .buffer
                    .get_read_source()
                    .unwrap_or_else(|| inner.pool.get().get_read_buf())
            };

            // make sure the buffer has enough free space
            let (hw, lw) = self.0.memory_pool().read_params().unpack();
            let remaining = buf.remaining_mut();
            if remaining <= lw {
                buf.reserve(hw - remaining);
            }
            let total = buf.len();

            let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
                Either::Left(res) => res,
                Either::Right(_) => {
                    log::trace!("{}: Read io is closed, stop read task", self.tag());
                    break;
                }
            };

            let total2 = buf.len();
            let nbytes = total2.saturating_sub(total);
            let total = total2;

            if let Some(mut first_buf) = inner.buffer.get_read_source() {
                first_buf.extend_from_slice(&buf);
                inner.buffer.set_read_source(&self.0, first_buf);
            } else {
                inner.buffer.set_read_source(&self.0, buf);
            }

            if nbytes > 0 {
                let filter = self.0.filter();
                let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) {
                    Ok(status) => {
                        if status.nbytes > 0 {
                            // if the buffer has grown past the high watermark the
                            // dispatcher is not keeping up, enable read back-pressure
                            if hw < inner.buffer.read_destination_size() {
                                log::trace!(
                                    "{}: Io read buffer is too large {}, enable read back-pressure",
                                    self.0.tag(),
                                    total
                                );
                                inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                            } else {
                                inner.insert_flags(Flags::BUF_R_READY);
                            }
                            log::trace!(
                                "{}: New {} bytes available, wakeup dispatcher",
                                self.0.tag(),
                                nbytes
                            );
                            inner.dispatch_task.wake();
                        } else if inner.flags.get().is_waiting_for_read() {
                            // the dispatcher is waiting for a read event,
                            // wake it even though no new data was released
                            inner.dispatch_task.wake();
                        }

                        if status.need_write {
                            // the filter queued data that must be written out
                            filter.process_write_buf(&self.0, &inner.buffer, 0)
                        } else {
                            Ok(())
                        }
                    }
                    Err(err) => Err(err),
                };

                if let Err(err) = res {
                    inner.dispatch_task.wake();
                    inner.io_stopped(Some(err));
                    inner.insert_flags(Flags::BUF_R_READY);
                }
            }

            match result {
                Ok(0) => {
                    log::trace!("{}: Tcp stream is disconnected", self.tag());
                    inner.io_stopped(None);
                    break;
                }
                Ok(_) => {
                    if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
                        lazy(|cx| self.shutdown_filters(cx)).await;
                    }
                }
                Err(err) => {
                    log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
                    inner.io_stopped(Some(err));
                    break;
                }
            }
        }
    }

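    // Drive filter shutdown; if filters cannot complete right away,
    // force the stop once the disconnect timeout elapses.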
    fn shutdown_filters(&self, cx: &mut Context<'_>) {
        let st = &self.0 .0;
        let filter = self.0.filter();

        match filter.shutdown(&self.0, &st.buffer, 0) {
            Ok(Poll::Ready(())) => {
                st.dispatch_task.wake();
                st.insert_flags(Flags::IO_STOPPING);
            }
            Ok(Poll::Pending) => {
                let flags = st.flags.get();

                // if read is paused or the read buffer is full and unread,
                // the dispatcher cannot make progress; force the stop
                if flags.contains(Flags::RD_PAUSED)
                    || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                {
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                } else {
                    // bound filter shutdown by the disconnect timeout
                    let timeout = self
                        .1
                        .take()
                        .unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
                    if timeout.poll_elapsed(cx).is_ready() {
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    } else {
                        self.1.set(Some(timeout));
                    }
                }
            }
            Err(err) => {
                st.io_stopped(Some(err));
            }
        }
        if let Err(err) = filter.process_write_buf(&self.0, &st.buffer, 0) {
            st.io_stopped(Some(err));
        }
    }
}

#[derive(Debug)]
pub struct WriteContext(IoRef);

#[derive(Debug)]
pub struct WriteContextBuf {
    io: IoRef,
    buf: Option<BytesVec>,
}

impl WriteContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone())
    }

    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    async fn ready(&self) -> Readiness {
        poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
    }

    fn close(&self, err: Option<io::Error>) {
        self.0 .0.io_stopped(err);
    }

    async fn when_stopped(&self) {
        poll_fn(|cx| {
            if self.0.flags().is_stopped() {
                Poll::Ready(())
            } else {
                self.0 .0.write_task.register(cx.waker());
                Poll::Pending
            }
        })
        .await
    }

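    /// Drive the write half of the io stream.
    ///
    /// Waits for write readiness and flushes the shared write buffer to
    /// `io`; on shutdown, flushes and shuts the stream down within the
    /// configured disconnect timeout.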
    pub async fn handle<T>(&self, io: &mut T)
    where
        T: AsyncWrite,
    {
        let mut buf = WriteContextBuf {
            io: self.0.clone(),
            buf: None,
        };

        loop {
            match self.ready().await {
                Readiness::Ready => {
                    match select(io.write(&mut buf), self.when_stopped()).await {
                        Either::Left(Ok(_)) => continue,
                        Either::Left(Err(e)) => self.close(Some(e)),
                        Either::Right(_) => return,
                    }
                }
                Readiness::Shutdown => {
                    log::trace!("{}: Write task is instructed to shutdown", self.tag());

                    let fut = async {
                        io.write(&mut buf).await?;
                        io.flush().await?;
                        io.shutdown().await?;
                        Ok(())
                    };
                    match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
                        Either::Left(_) => self.close(None),
                        Either::Right(res) => self.close(res.err()),
                    }
                }
                Readiness::Terminate => {
                    log::trace!("{}: Write task is instructed to terminate", self.tag());
                    self.close(io.shutdown().await.err());
                }
            }
            return;
        }
    }
}

impl WriteContextBuf {
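    /// Return a (possibly partially written) buffer to the write context.
    ///
    /// Empty buffers are released back to the pool; remaining bytes are
    /// queued as pending write data, and write back-pressure flags are
    /// updated based on the total amount of queued data.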
    pub fn set(&mut self, mut buf: BytesVec) {
        if buf.is_empty() {
            self.io.memory_pool().release_write_buf(buf);
        } else if let Some(b) = self.buf.take() {
            buf.extend_from_slice(&b);
            self.io.memory_pool().release_write_buf(b);
            self.buf = Some(buf);
        } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
            // the write destination is already set, keep the rest locally
            self.buf = Some(b);
        }

        // recalculate the total amount of queued data and update write flags
        let inner = &self.io.0;
        let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
            + inner.buffer.write_destination_size();
        let mut flags = inner.flags.get();

        if len == 0 {
            if flags.is_waiting_for_write() {
                flags.waiting_for_write_is_done();
                inner.dispatch_task.wake();
            }
            flags.insert(Flags::WR_PAUSED);
            inner.flags.set(flags);
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE)
            && len < inner.pool.get().write_params_high() << 1
        {
            // below twice the high watermark, release write back-pressure
            flags.remove(Flags::BUF_W_BACKPRESSURE);
            inner.flags.set(flags);
            inner.dispatch_task.wake();
        }
    }

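    /// Take the next chunk of data to write: the locally stashed buffer,
    /// if any, otherwise the shared write destination buffer.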
    pub fn take(&mut self) -> Option<BytesVec> {
        if let Some(buf) = self.buf.take() {
            Some(buf)
        } else {
            self.io.0.buffer.get_write_destination()
        }
    }
}

pub struct IoContext(IoRef);

impl fmt::Debug for IoContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IoContext").field("io", &self.0).finish()
    }
}

impl IoContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone())
    }

    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    #[doc(hidden)]
    pub fn flags(&self) -> crate::flags::Flags {
        self.0.flags()
    }

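    /// Check readiness for read operations, driving filter shutdown first
    /// if one is in progress.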
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.shutdown_filters();
        self.0.filter().poll_read_ready(cx)
    }

    #[inline]
    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.0.filter().poll_write_ready(cx)
    }

    #[inline]
    pub fn stopped(&self, e: Option<io::Error>) {
        self.0 .0.io_stopped(e);
    }

    #[inline]
    pub fn is_stopped(&self) -> bool {
        self.0.flags().is_stopped()
    }

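    /// Gracefully shut down the io stream.
    ///
    /// Waits (bounded by the disconnect timeout) for filters to finish
    /// shutting down and, if `flush_buf` is set, for the write buffer to
    /// drain before returning.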
    pub async fn shutdown(&self, flush_buf: bool) {
        let st = &self.0 .0;
        let mut timeout = None;

        poll_fn(|cx| {
            let flags = self.0.flags();
            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
                Poll::Ready(())
            } else {
                st.write_task.register(cx.waker());
                if flags.contains(Flags::IO_STOPPING_FILTERS) {
                    if timeout.is_none() {
                        timeout = Some(sleep(st.disconnect_timeout.get()));
                    }
                    if timeout.as_ref().unwrap().poll_elapsed(cx).is_ready() {
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                        return Poll::Ready(());
                    }
                }
                Poll::Pending
            }
        })
        .await;

        if flush_buf && !st.flags.get().contains(Flags::WR_PAUSED) {
            st.insert_flags(Flags::WR_TASK_WAIT);

            poll_fn(|cx| {
                let flags = st.flags.get();
                if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
                    Poll::Ready(())
                } else {
                    st.write_task.register(cx.waker());
                    if timeout.is_none() {
                        timeout = Some(sleep(st.disconnect_timeout.get()));
                    }
                    if timeout.as_ref().unwrap().poll_elapsed(cx).is_ready() {
                        Poll::Ready(())
                    } else {
                        Poll::Pending
                    }
                }
            })
            .await;
        }
    }

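    /// Get a buffer to read into, together with the `(high, low)`
    /// watermarks of the memory pool's read parameters.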
    pub fn get_read_buf(&self) -> (BytesVec, usize, usize) {
        let inner = &self.0 .0;

        let buf = if inner.flags.get().is_read_buf_ready() {
            // the dispatcher has not consumed the current read buffer yet,
            // hand out a fresh buffer from the pool
            inner.pool.get().get_read_buf()
        } else {
            inner
                .buffer
                .get_read_source()
                .unwrap_or_else(|| inner.pool.get().get_read_buf())
        };

        let (hw, lw) = self.0.memory_pool().read_params().unpack();
        (buf, hw, lw)
    }

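    /// Release the read buffer back after a read operation.
    ///
    /// `nbytes` is the number of newly read bytes and `result` the io
    /// outcome of the read; returns whether the read task should keep
    /// reading or pause.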
    pub fn release_read_buf(
        &self,
        nbytes: usize,
        buf: BytesVec,
        result: Poll<io::Result<()>>,
    ) -> IoTaskStatus {
        let inner = &self.0 .0;
        let orig_size = inner.buffer.read_destination_size();
        let hw = self.0.memory_pool().read_params().unpack().0;

        if let Some(mut first_buf) = inner.buffer.get_read_source() {
            first_buf.extend_from_slice(&buf);
            inner.buffer.set_read_source(&self.0, first_buf);
        } else {
            inner.buffer.set_read_source(&self.0, buf);
        }

        // run new bytes through the filter chain
        let st_res = if nbytes > 0 {
            match self
                .0
                .filter()
                .process_read_buf(&self.0, &inner.buffer, 0, nbytes)
            {
                Ok(status) => {
                    let buffer_size = inner.buffer.read_destination_size();
                    if buffer_size.saturating_sub(orig_size) > 0 {
                        // if the buffer has grown past the high watermark,
                        // enable read back-pressure
                        if buffer_size >= hw {
                            log::trace!(
                                "{}: Io read buffer is too large {}, enable read back-pressure",
                                self.tag(),
                                buffer_size
                            );
                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                        } else {
                            inner.insert_flags(Flags::BUF_R_READY);
                        }
                        log::trace!(
                            "{}: New {} bytes available, wakeup dispatcher",
                            self.tag(),
                            buffer_size
                        );
                        inner.dispatch_task.wake();
                    } else {
                        if buffer_size >= hw {
                            // reading may be paused because of back-pressure,
                            // but the filter released no new data for the
                            // dispatcher; wake the read task so it can proceed
                            inner.read_task.wake();
                        }
                        if inner.flags.get().is_waiting_for_read() {
                            // the dispatcher is waiting for a read event,
                            // wake it even though no new data was released
                            inner.dispatch_task.wake();
                        }
                    }

                    if status.need_write {
                        // the filter queued data that must be written out
                        self.0.filter().process_write_buf(&self.0, &inner.buffer, 0)
                    } else {
                        Ok(())
                    }
                }
                Err(err) => {
                    inner.insert_flags(Flags::BUF_R_READY);
                    Err(err)
                }
            }
        } else {
            Ok(())
        };

        match result {
            Poll::Ready(Ok(_)) => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else if nbytes == 0 {
                    // zero bytes read means the peer disconnected
                    inner.io_stopped(None);
                    IoTaskStatus::Pause
                } else {
                    IoTaskStatus::Io
                }
            }
            Poll::Ready(Err(e)) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
            Poll::Pending => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    IoTaskStatus::Io
                }
            }
        }
    }

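    /// Get the current write destination buffer, if it contains any data.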
    pub fn get_write_buf(&self) -> Option<BytesVec> {
        self.0 .0.buffer.get_write_destination().and_then(|buf| {
            if buf.is_empty() {
                None
            } else {
                Some(buf)
            }
        })
    }

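    /// Release the write buffer back after a write operation.
    ///
    /// `result` carries the number of bytes written; leftover bytes are
    /// re-queued, back-pressure flags are updated and the returned status
    /// tells the write task whether to continue or pause.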
    pub fn release_write_buf(
        &self,
        mut buf: BytesVec,
        result: Poll<io::Result<usize>>,
    ) -> IoTaskStatus {
        let result = match result {
            Poll::Ready(Ok(0)) => {
                log::trace!("{}: Disconnected during flush", self.tag());
                Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                ))
            }
            Poll::Ready(Ok(n)) => {
                if n == buf.len() {
                    buf.clear();
                    Ok(0)
                } else {
                    buf.advance(n);
                    Ok(buf.len())
                }
            }
            Poll::Ready(Err(e)) => Err(e),
            Poll::Pending => Ok(buf.len()),
        };

        let inner = &self.0 .0;

        // merge in any data queued while the write was in progress
        let result = match result {
            Ok(0) => {
                // everything was written, release the buffer
                self.0.memory_pool().release_write_buf(buf);
                Ok(inner.buffer.write_destination_size())
            }
            Ok(_) => {
                // append newly queued data after the unwritten remainder
                if let Some(b) = inner.buffer.get_write_destination() {
                    buf.extend_from_slice(&b);
                    self.0.memory_pool().release_write_buf(b);
                }
                let l = buf.len();
                inner.buffer.set_write_destination(buf);
                Ok(l)
            }
            Err(e) => Err(e),
        };

        match result {
            Ok(0) => {
                let mut flags = inner.flags.get();

                // no more data to write, pause the write task
                flags.insert(Flags::WR_PAUSED);

                if flags.is_task_waiting_for_write() {
                    flags.task_waiting_for_write_is_done();
                    inner.write_task.wake();
                }

                if flags.is_waiting_for_write() {
                    flags.waiting_for_write_is_done();
                    inner.dispatch_task.wake();
                }
                inner.flags.set(flags);
                IoTaskStatus::Pause
            }
            Ok(len) => {
                // release write back-pressure once the queued data drops
                // below twice the high watermark
                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
                    && len < inner.pool.get().write_params_high() << 1
                {
                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
                    inner.dispatch_task.wake();
                }
                IoTaskStatus::Io
            }
            Err(e) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
        }
    }

    fn shutdown_filters(&self) {
        let io = &self.0;
        let st = &self.0 .0;
        let flags = st.flags.get();
        if flags.contains(Flags::IO_STOPPING_FILTERS)
            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
        {
            match io.filter().shutdown(io, &st.buffer, 0) {
                Ok(Poll::Ready(())) => {
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                }
                Ok(Poll::Pending) => {
                    // if read is paused or the read buffer is full and unread,
                    // the dispatcher cannot make progress; force the stop
                    let flags = st.flags.get();
                    if flags.contains(Flags::RD_PAUSED)
                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                    {
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    }
                }
                Err(err) => {
                    st.io_stopped(Some(err));
                }
            }
            if let Err(err) = io.filter().process_write_buf(io, &st.buffer, 0) {
                st.io_stopped(Some(err));
            }
        }
    }
}

impl Clone for IoContext {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}