use std::{cell::Cell, fmt, io, task::Context, task::Poll};

use ntex_bytes::{Buf, BytesVec};
use ntex_util::time::{Sleep, sleep};

use crate::{FilterCtx, Flags, IoRef, IoTaskStatus, Readiness};

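/// Io task context.
///
/// An `IoContext` is handed to the io driver task that performs the actual
/// reads and writes. It exposes readiness polling, buffer checkout/release
/// and graceful shutdown handling on top of the shared io state.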
pub struct IoContext(IoRef, Cell<Option<Sleep>>);

impl fmt::Debug for IoContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IoContext").field("io", &self.0).finish()
    }
}

impl IoContext {
    pub(crate) fn new(io: &IoRef) -> Self {
        Self(io.clone(), Cell::new(None))
    }

    #[doc(hidden)]
    #[inline]
    pub fn id(&self) -> usize {
        self.0.0.as_ref() as *const _ as usize
    }

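    /// Returns the tag associated with this io object; the tag is used as a
    /// prefix in log messages.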
    #[inline]
    pub fn tag(&self) -> &'static str {
        self.0.tag()
    }

    #[inline]
    #[doc(hidden)]
    pub fn flags(&self) -> crate::flags::Flags {
        self.0.flags()
    }

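    /// Polls read readiness of the io stream.
    ///
    /// Also drives the filter shutdown sequence when a graceful stop has
    /// been requested (see `shutdown_filters` below).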
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.shutdown_filters(cx);
        self.0.filter().poll_read_ready(cx)
    }

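    /// Polls write readiness of the io stream.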
    #[inline]
    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
        self.0.filter().poll_write_ready(cx)
    }

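    /// Stops the io object, optionally recording `e` as the cause.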
    #[inline]
    pub fn stop(&self, e: Option<io::Error>) {
        self.0.0.io_stopped(e);
    }

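    /// Checks whether the io object has been stopped.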
    #[inline]
    pub fn is_stopped(&self) -> bool {
        self.0.flags().is_stopped()
    }

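    /// Gracefully shuts down the io stream.
    ///
    /// With `flush` set, completion is deferred until the write buffer has
    /// been fully drained (writing is paused); otherwise completion waits
    /// only for the stopping/stopped flags. Returns `Poll::Pending` while
    /// shutdown is still in progress.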
    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
        let st = &self.0.0;
        let flags = self.0.flags();

        if flush && !flags.contains(Flags::IO_STOPPED) {
            // with flush requested, wait for the write buffer to drain first
            if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
                return Poll::Ready(());
            }
            st.insert_flags(Flags::WR_TASK_WAIT);
            st.write_task.register(cx.waker());
            Poll::Pending
        } else if !flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
            st.write_task.register(cx.waker());
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }

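    /// Checks out a buffer for the next read operation.
    ///
    /// If the dispatcher has not yet consumed the previous read buffer, a
    /// fresh buffer is taken from the pool; otherwise the pending read
    /// source buffer is reused.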
    pub fn get_read_buf(&self) -> BytesVec {
        let inner = &self.0.0;

        if inner.flags.get().is_read_buf_ready() {
            // dispatcher still owns unread data, take a fresh buffer
            inner.read_buf().get()
        } else {
            inner
                .buffer
                .get_read_source()
                .unwrap_or_else(|| inner.read_buf().get())
        }
    }

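    /// Adjusts the capacity of `buf` according to the configured read
    /// buffer parameters.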
    pub fn resize_read_buf(&self, buf: &mut BytesVec) {
        self.0.0.read_buf().resize(buf);
    }

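    /// Returns the read buffer back to the io object after a read operation.
    ///
    /// `nbytes` is the number of bytes read into `buf` and `result` is the
    /// outcome of the read operation itself. New data is run through the
    /// filter chain, back-pressure flags are updated, and the returned
    /// `IoTaskStatus` tells the driver whether to continue reading (`Io`),
    /// pause until the dispatcher catches up (`Pause`) or stop (`Stop`).
    ///
    /// A rough sketch of driver-side usage (`read_from_socket` is a
    /// hypothetical io call, not part of this crate):
    ///
    /// ```ignore
    /// let mut buf = ctx.get_read_buf();
    /// let n = read_from_socket(&mut buf);
    /// match ctx.release_read_buf(n, buf, Poll::Ready(Ok(()))) {
    ///     IoTaskStatus::Io => { /* keep reading */ }
    ///     IoTaskStatus::Pause => { /* back-pressure, yield to dispatcher */ }
    ///     IoTaskStatus::Stop => { /* io stopped, exit the task */ }
    /// }
    /// ```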
    pub fn release_read_buf(
        &self,
        nbytes: usize,
        buf: BytesVec,
        result: Poll<Result<(), Option<io::Error>>>,
    ) -> IoTaskStatus {
        let inner = &self.0.0;
        let orig_size = inner.buffer.read_destination_size();
        let hw = self.0.cfg().read_buf().high;

        // append new data to the existing read source buffer, if any
        if let Some(mut first_buf) = inner.buffer.get_read_source() {
            first_buf.extend_from_slice(&buf);
            inner.buffer.set_read_source(&self.0, first_buf);
        } else {
            inner.buffer.set_read_source(&self.0, buf);
        }

        let mut full = false;

        // run new data through the filter chain
        let st_res = if nbytes > 0 {
            match self
                .0
                .filter()
                .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
            {
                Ok(status) => {
                    let buffer_size = inner.buffer.read_destination_size();
                    if buffer_size.saturating_sub(orig_size) > 0 {
                        if buffer_size >= hw {
                            log::trace!(
                                "{}: Io read buffer is too large {}, enable read back-pressure",
                                self.tag(),
                                buffer_size
                            );
                            full = true;
                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                        } else {
                            inner.insert_flags(Flags::BUF_R_READY);
                        }
                        log::trace!(
                            "{}: New {} bytes available, wakeup dispatcher",
                            self.tag(),
                            buffer_size
                        );
                        inner.dispatch_task.wake();
                    } else {
                        if buffer_size >= hw {
                            full = true;
                            inner.read_task.wake();
                        }
                        if inner.flags.get().is_waiting_for_read() {
                            inner.dispatch_task.wake();
                        }
                    }

                    if status.need_write {
                        self.0
                            .filter()
                            .process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
                    } else {
                        Ok(())
                    }
                }
                Err(err) => Err(err),
            }
        } else {
            Ok(())
        };

        match result {
            Poll::Ready(Ok(_)) => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Stop
                } else if nbytes == 0 {
                    // zero-length read means the peer closed the connection
                    inner.io_stopped(None);
                    IoTaskStatus::Stop
                } else if full {
                    IoTaskStatus::Pause
                } else {
                    IoTaskStatus::Io
                }
            }
            Poll::Ready(Err(e)) => {
                inner.io_stopped(e);
                IoTaskStatus::Stop
            }
            Poll::Pending => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Stop
                } else if full {
                    IoTaskStatus::Pause
                } else {
                    IoTaskStatus::Io
                }
            }
        }
    }

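    /// Checks out the pending write buffer, if it contains any data.
    ///
    /// Returns `None` when there is nothing to write.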
    #[inline]
    pub fn get_write_buf(&self) -> Option<BytesVec> {
        self.0
            .0
            .buffer
            .get_write_destination()
            .and_then(|buf| if buf.is_empty() { None } else { Some(buf) })
    }

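    /// Returns the write buffer back to the io object after a write operation.
    ///
    /// `result` is the outcome of the write: the number of bytes written,
    /// an error, or `Poll::Pending` if the write could not complete yet.
    /// Written bytes are dropped from the buffer, leftovers are merged with
    /// any data queued in the meantime, and write back-pressure is released
    /// once the buffer drains below the configured watermark. A write of
    /// zero bytes is treated as a disconnect.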
    pub fn release_write_buf(
        &self,
        mut buf: BytesVec,
        result: Poll<io::Result<usize>>,
    ) -> IoTaskStatus {
        let result = match result {
            Poll::Ready(Ok(0)) => {
                log::trace!("{}: Disconnected during flush", self.tag());
                Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                ))
            }
            Poll::Ready(Ok(n)) => {
                if n == buf.len() {
                    // everything was written out
                    buf.clear();
                    Ok(0)
                } else {
                    // partial write, drop written bytes and keep the rest
                    buf.advance(n);
                    Ok(buf.len())
                }
            }
            Poll::Ready(Err(e)) => Err(e),
            Poll::Pending => Ok(buf.len()),
        };

        let inner = &self.0.0;

        let result = match result {
            Ok(0) => {
                // buffer is fully flushed, return it to the pool
                self.0.cfg().write_buf().release(buf);
                Ok(inner.buffer.write_destination_size())
            }
            Ok(_) => {
                // merge data queued by the dispatcher in the meantime
                if let Some(b) = inner.buffer.get_write_destination() {
                    buf.extend_from_slice(&b);
                    self.0.cfg().write_buf().release(b);
                }
                let l = buf.len();
                inner.buffer.set_write_destination(buf);
                Ok(l)
            }
            Err(e) => Err(e),
        };

        match result {
            Ok(0) => {
                let mut flags = inner.flags.get();

                flags.insert(Flags::WR_PAUSED);

                if flags.is_task_waiting_for_write() {
                    flags.task_waiting_for_write_is_done();
                    inner.write_task.wake();
                }

                if flags.is_waiting_for_write() {
                    flags.waiting_for_write_is_done();
                    inner.dispatch_task.wake();
                }
                inner.flags.set(flags);
                if self.is_stopped() {
                    IoTaskStatus::Stop
                } else {
                    IoTaskStatus::Pause
                }
            }
            Ok(len) => {
                // release write back-pressure once enough of the buffer drained
                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
                    && len < inner.write_buf().half
                {
                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
                    inner.dispatch_task.wake();
                }
                IoTaskStatus::Io
            }
            Err(e) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Stop
            }
        }
    }

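    /// Drives the filter shutdown sequence once a graceful stop was requested.
    ///
    /// When a filter cannot complete its shutdown immediately, a disconnect
    /// timeout is armed; if it elapses before the filter finishes, the io
    /// object is forced into the stopping state.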
    fn shutdown_filters(&self, cx: &mut Context<'_>) {
        let io = &self.0;
        let st = &self.0.0;
        let flags = st.flags.get();
        if flags.contains(Flags::IO_STOPPING_FILTERS)
            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
        {
            match io.filter().shutdown(FilterCtx::new(io, &st.buffer)) {
                Ok(Poll::Ready(())) => {
                    st.dispatch_task.wake();
                    st.insert_flags(Flags::IO_STOPPING);
                }
                Ok(Poll::Pending) => {
                    let flags = st.flags.get();
                    // if read is paused or the read buffer is not consumed,
                    // the filter is unlikely to make progress; proceed to stop
                    if flags.contains(Flags::RD_PAUSED)
                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
                    {
                        st.dispatch_task.wake();
                        st.insert_flags(Flags::IO_STOPPING);
                    } else {
                        // arm the disconnect timeout for filter shutdown
                        let timeout = self
                            .1
                            .take()
                            .unwrap_or_else(|| sleep(io.cfg().disconnect_timeout()));
                        if timeout.poll_elapsed(cx).is_ready() {
                            st.dispatch_task.wake();
                            st.insert_flags(Flags::IO_STOPPING);
                        } else {
                            self.1.set(Some(timeout));
                        }
                    }
                }
                Err(err) => {
                    st.io_stopped(Some(err));
                }
            }
            if let Err(err) = io
                .filter()
                .process_write_buf(FilterCtx::new(io, &st.buffer))
            {
                st.io_stopped(Some(err));
            }
        }
    }
}

impl Clone for IoContext {
    fn clone(&self) -> Self {
        Self(self.0.clone(), Cell::new(None))
    }
}