1use std::{fmt, io, task::Context, task::Poll};
2
3use ntex_bytes::{BytePages, BytesMut};
4use ntex_util::time::sleep;
5
6use crate::{Flags, Id, IoRef, IoTaskStatus, Readiness, io::IoState};
7
8pub struct IoContext(IoRef);
10
11impl fmt::Debug for IoContext {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 f.debug_struct("IoContext").field("io", &self.0).finish()
14 }
15}
16
17impl IoContext {
18 pub(crate) fn new(io: IoRef) -> Self {
19 Self(io)
20 }
21
22 fn st(&self) -> &IoState {
23 &self.0.0
24 }
25
26 #[doc(hidden)]
27 #[inline]
28 pub fn id(&self) -> Id {
29 self.0.id()
30 }
31
32 #[inline]
33 pub fn tag(&self) -> &'static str {
35 self.0.tag()
36 }
37
38 #[doc(hidden)]
39 pub fn flags(&self) -> Flags {
41 self.0.flags()
42 }
43
44 #[inline]
45 pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
47 self.shutdown_filters(cx);
48 self.0.filter().poll_read_ready(cx)
49 }
50
51 #[inline]
52 pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
54 self.0.filter().poll_write_ready(cx)
55 }
56
57 pub fn stop(&self, e: Option<io::Error>) {
59 self.st().terminate_connection(e);
60 }
61
62 pub fn is_stopped(&self) -> bool {
64 self.st().flags.is_closed()
65 }
66
67 pub fn get_read_buf(&self) -> BytesMut {
69 let st = self.st();
70
71 if st.flags.is_read_ready() {
72 st.get_read_buf()
75 } else if let Some(mut buf) = st.buffer.get_read_buf() {
76 self.0.resize_read_buf(&mut buf);
77 buf
78 } else {
79 st.get_read_buf()
80 }
81 }
82
83 pub fn resize_read_buf(&self, buf: &mut BytesMut) {
85 self.0.resize_read_buf(buf);
86 }
87
88 pub fn update_read_status(
93 &self,
94 buf: BytesMut,
95 status: io::Result<usize>,
96 ) -> IoTaskStatus {
97 let st = self.st();
98 let orig = st.buffer.read_dst_size();
99
100 #[cfg(feature = "trace")]
101 log::trace!(
102 "{}: read-status == {status:?} orig:{orig:?} flags:{:?}",
103 st.tag(),
104 st.flags
105 );
106
107 st.buffer.set_read_buf(buf, self.0.cfg());
109
110 let result = status.and_then(|nbytes| {
112 if nbytes == 0 {
113 return Ok(());
114 }
115 st.buffer.process_read_buf(&self.0, nbytes).map(|status| {
116 let size = st.buffer.read_dst_size();
117
118 if size > orig {
120 if st.is_rd_backpressure_needed(size) {
121 log::trace!("{}: Read buf({size}), enable back-pressure", st.tag());
122 st.flags.set_read_ready_and_backpressure();
123 } else {
124 st.flags.set_read_ready();
125 }
126 #[cfg(feature = "trace")]
127 log::trace!("{}: New {size} bytes available", st.tag());
128 st.wake_dispatch_task();
129 }
130
131 if st.flags.is_read_notify() {
132 st.wake_dispatch_task();
135 st.flags.set_read_notifed();
136 }
137
138 if status.wants_write {
140 if let Err(err) = st.buffer.process_write_buf_force(&self.0) {
141 st.terminate_connection(Some(err));
142 } else {
143 self.0.consolidate_write_state(false);
144 }
145 }
146
147 if status.notify {
149 self.0.call_notify();
150 }
151 })
152 });
153
154 if let Err(err) = result {
155 st.terminate_connection(Some(err));
156 IoTaskStatus::Stop
157 } else if st.flags.is_closed() {
158 IoTaskStatus::Stop
159 } else if st.flags.is_read_paused_or_backpressure() {
160 IoTaskStatus::Pause
161 } else {
162 IoTaskStatus::Io
163 }
164 }
165
166 pub fn with_write_buf<F, R>(&self, f: F) -> R
168 where
169 F: FnOnce(&mut BytePages) -> R,
170 {
171 if let Err(e) = self.st().buffer.process_write_buf(&self.0) {
173 self.st().terminate_connection(Some(e));
174 }
175
176 self.st().buffer.with_write_dst(|buffer| f(buffer))
177 }
178
179 pub fn update_write_status(&self, status: io::Result<bool>) -> IoTaskStatus {
184 let st = &self.st();
185
186 #[cfg(feature = "trace")]
187 log::trace!(
188 "{}: write-status == {status:?} buf:{} flags:{:?}",
189 st.tag(),
190 st.buffer.write_buf_size(),
191 st.flags
192 );
193
194 match status {
195 Ok(written) => {
196 let len = st.buffer.write_buf_size();
197 if st.flags.is_write_flush() {
199 if len == 0 {
201 st.wake_dispatch_task();
202 }
203 } else if st.flags.is_wr_backpressure()
204 && st.should_disable_wr_backpressure(len)
205 {
206 st.wake_dispatch_task();
208 }
209
210 if written && st.flags.is_write_notify() {
212 st.flags.unset_write_notify();
213 st.wake_read_task();
214 st.wake_write_task();
215 }
216
217 if st.flags.is_closed() {
218 IoTaskStatus::Stop
219 } else if len == 0 {
220 st.flags.set_write_paused();
222 if st.flags.is_stopping_filters() {
223 st.wake_read_task();
224 }
225 IoTaskStatus::Pause
226 } else {
227 st.flags.unset_write_paused();
228 IoTaskStatus::Io
229 }
230 }
231 Err(err) => {
232 st.terminate_connection(Some(err));
233 IoTaskStatus::Stop
234 }
235 }
236 }
237
238 pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
240 let st = self.st();
241 if flush && !st.flags.is_stopping() {
242 if st.flags.is_write_paused() {
243 return Poll::Ready(());
244 }
245 st.flags.set_write_notify();
246 st.read_task.register(cx.waker());
247 st.write_task.register(cx.waker());
248 Poll::Pending
249 } else if !st.flags.is_closed() {
250 st.read_task.register(cx.waker());
251 st.write_task.register(cx.waker());
252 Poll::Pending
253 } else {
254 Poll::Ready(())
255 }
256 }
257
258 fn shutdown_filters(&self, cx: &mut Context<'_>) {
259 let st = &self.st();
260 if !st.flags.is_shutting_down_filters() {
261 return;
262 }
263
264 let ready = match st.buffer.process_shutdown(&self.0) {
266 Ok(Poll::Ready(())) => true,
267 Ok(Poll::Pending) => false,
268 Err(err) => {
269 st.terminate_connection(Some(err));
270 return;
271 }
272 };
273 self.0.consolidate_write_state(true);
274
275 #[cfg(feature = "trace")]
276 log::trace!(
277 "{}: shutdown filters, done:{ready:?} wr-buf:{:?}, flags:{:?}",
278 st.tag(),
279 st.buffer.write_buf_size(),
280 st.flags,
281 );
282
283 if ready && st.flags.is_write_paused() && !st.flags.is_wr_send_scheduled() {
285 st.filters_stopped();
286 } else if st.flags.is_read_paused() || st.flags.is_read_ready_and_backpressure() {
287 st.filters_stopped();
290 } else {
291 let timeout = st
293 .shutdown_timeout
294 .take()
295 .unwrap_or_else(|| sleep(st.cfg.disconnect_timeout()));
296 if timeout.poll_elapsed(cx).is_ready() {
297 st.filters_stopped();
298 } else {
299 st.shutdown_timeout.set(Some(timeout));
300 }
301 }
302 }
303
304 pub fn notify(&self) {
306 self.0.0.wake_read_task();
307 }
308}
309
310impl Clone for IoContext {
311 fn clone(&self) -> Self {
312 Self(self.0.clone())
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Io, testing::IoTest};

    /// Smoke test for a context built from a fresh test I/O. It checks that
    /// the flags can be read, the id is not the default, and the `Debug`
    /// output names the type.
    #[ntex::test]
    async fn ctx_basics() {
        // The client half is not bound and is dropped right away; only the
        // server side is used.
        let (_, server) = IoTest::create();

        let io = Io::from(server);
        let ctx = IoContext::new(io.get_ref());

        let _ = ctx.flags();
        assert!(ctx.id() != Id::default());

        let rendered = format!("{ctx:?}");
        assert!(rendered.contains("IoContext"));
    }
}