1#![allow(clippy::missing_panics_doc)]
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll, Waker};
5use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc};
6
7use ntex_bytes::{BufMut, Bytes, BytesMut};
8use ntex_util::time::{Millis, sleep};
9
10use crate::{Handle, IoContext, IoStream, IoTaskStatus, Readiness, types};
11
/// Shared, lockable slot holding at most one registered [`Waker`].
///
/// A polling task stores its waker in the slot; the other side calls
/// [`AtomicWaker::wake`] to notify it.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);

impl AtomicWaker {
    /// Takes the registered waker, if any, and wakes it.
    ///
    /// The slot is left empty afterwards, so calling `wake` again is a no-op
    /// until a new waker has been stored.
    fn wake(&self) {
        // Take the waker in its own statement: the mutex guard is a temporary
        // of this expression and is released at the end of the statement,
        // i.e. *before* the wake-up runs. Waking while the guard is still
        // alive would dead-lock any waker that touches this slot again
        // (`std::sync::Mutex` is not re-entrant).
        let waker = self.0.lock().unwrap().borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl fmt::Debug for AtomicWaker {
    /// Prints only the type name; the stored waker is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
28
29#[derive(Debug)]
31pub struct IoTest {
32 tp: Type,
33 peer_addr: Option<net::SocketAddr>,
34 state: Arc<Mutex<RefCell<State>>>,
35 local: Arc<Mutex<RefCell<Channel>>>,
36 remote: Arc<Mutex<RefCell<Channel>>>,
37}
38
39bitflags::bitflags! {
40 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
41 struct IoTestFlags: u8 {
42 const FLUSHED = 0b0000_0001;
43 const CLOSED = 0b0000_0010;
44 }
45}
46
/// Role of an [`IoTest`] endpoint.
///
/// Only the original endpoints (`Client`, `Server`) record themselves in the
/// shared drop flags when dropped; clones do not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Type {
    /// First endpoint returned by `IoTest::create`.
    Client,
    /// Second endpoint returned by `IoTest::create`.
    Server,
    /// Clone of a client endpoint (or of another client clone).
    ClientClone,
    /// Clone of a server endpoint (or of another server clone).
    ServerClone,
}
54
/// Drop flags shared by both endpoints of a test connection.
#[derive(Debug, Default, Clone, Copy)]
struct State {
    /// `true` once the original client endpoint has been dropped.
    client_dropped: bool,
    /// `true` once the original server endpoint has been dropped.
    server_dropped: bool,
}
60
61#[derive(Default, Debug)]
62struct Channel {
63 buf: BytesMut,
64 buf_cap: usize,
65 flags: IoTestFlags,
66 waker: AtomicWaker,
67 read: IoTestState,
68 write: IoTestState,
69}
70
71unsafe impl Sync for Channel {}
72unsafe impl Send for Channel {}
73
74impl Channel {
75 fn is_closed(&self) -> bool {
76 self.flags.contains(IoTestFlags::CLOSED)
77 }
78}
79
80impl Default for IoTestFlags {
81 fn default() -> Self {
82 IoTestFlags::empty()
83 }
84}
85
/// Scripted outcome of the next read or write on a [`Channel`].
///
/// The poll functions take the value out with `mem::take`, so every variant
/// is consumed by one poll and replaced by `Ok`, except where noted.
#[derive(Default, Debug)]
enum IoTestState {
    /// Normal operation: move data if possible, otherwise return `Pending`.
    #[default]
    Ok,
    /// Return `Poll::Pending` from the next poll.
    Pending,
    /// Report end of stream (`Ok(0)`); `poll_read_buf` restores this state
    /// after taking it, so reads keep reporting it.
    Close,
    /// Fail the next poll with the contained error.
    Err(io::Error),
}
94
95impl IoTest {
96 pub fn create() -> (IoTest, IoTest) {
98 let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
99 let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
100 let state = Arc::new(Mutex::new(RefCell::new(State::default())));
101
102 (
103 IoTest {
104 tp: Type::Client,
105 peer_addr: None,
106 local: local.clone(),
107 remote: remote.clone(),
108 state: state.clone(),
109 },
110 IoTest {
111 state,
112 peer_addr: None,
113 tp: Type::Server,
114 local: remote,
115 remote: local,
116 },
117 )
118 }
119
120 pub fn is_client_dropped(&self) -> bool {
122 self.state.lock().unwrap().borrow().client_dropped
123 }
124
125 pub fn is_server_dropped(&self) -> bool {
127 self.state.lock().unwrap().borrow().server_dropped
128 }
129
130 pub fn is_closed(&self) -> bool {
132 self.remote.lock().unwrap().borrow().is_closed()
133 }
134
135 #[must_use]
137 pub fn set_peer_addr(mut self, addr: net::SocketAddr) -> Self {
138 self.peer_addr = Some(addr);
139 self
140 }
141
142 pub fn read_pending(&self) {
144 self.remote.lock().unwrap().borrow_mut().read = IoTestState::Pending;
145 }
146
147 pub fn read_error(&self, err: io::Error) {
149 let channel = self.remote.lock().unwrap();
150 channel.borrow_mut().read = IoTestState::Err(err);
151 channel.borrow().waker.wake();
152 }
153
154 pub fn write_error(&self, err: io::Error) {
156 self.local.lock().unwrap().borrow_mut().write = IoTestState::Err(err);
157 self.remote.lock().unwrap().borrow().waker.wake();
158 }
159
160 pub fn local_buffer<F, R>(&self, f: F) -> R
162 where
163 F: FnOnce(&mut BytesMut) -> R,
164 {
165 let guard = self.local.lock().unwrap();
166 let mut ch = guard.borrow_mut();
167 f(&mut ch.buf)
168 }
169
170 pub fn remote_buffer<F, R>(&self, f: F) -> R
172 where
173 F: FnOnce(&mut BytesMut) -> R,
174 {
175 let guard = self.remote.lock().unwrap();
176 let mut ch = guard.borrow_mut();
177 f(&mut ch.buf)
178 }
179
180 pub async fn close(&self) {
182 {
183 let guard = self.remote.lock().unwrap();
184 let mut remote = guard.borrow_mut();
185 remote.read = IoTestState::Close;
186 remote.waker.wake();
187 log::debug!("close remote socket");
188 }
189 sleep(Millis(35)).await;
190 }
191
192 pub fn write<T: AsRef<[u8]>>(&self, data: T) {
194 let guard = self.remote.lock().unwrap();
195 let mut write = guard.borrow_mut();
196 write.buf.extend_from_slice(data.as_ref());
197 write.waker.wake();
198 }
199
200 pub fn remote_buffer_cap(&self, cap: usize) {
202 self.local.lock().unwrap().borrow_mut().buf_cap = cap;
204 self.remote.lock().unwrap().borrow().waker.wake();
206 }
207
208 pub fn read_any(&self) -> Bytes {
210 self.local.lock().unwrap().borrow_mut().buf.take()
211 }
212
213 pub async fn read(&self) -> Result<Bytes, io::Error> {
215 if self.local.lock().unwrap().borrow().buf.is_empty() {
216 poll_fn(|cx| {
217 let guard = self.local.lock().unwrap();
218 let read = guard.borrow_mut();
219 if read.buf.is_empty() {
220 let closed = match self.tp {
221 Type::Client | Type::ClientClone => {
222 self.is_server_dropped() || read.is_closed()
223 }
224 Type::Server | Type::ServerClone => self.is_client_dropped(),
225 };
226 if closed {
227 Poll::Ready(())
228 } else {
229 *read.waker.0.lock().unwrap().borrow_mut() =
230 Some(cx.waker().clone());
231 drop(read);
232 drop(guard);
233 Poll::Pending
234 }
235 } else {
236 Poll::Ready(())
237 }
238 })
239 .await;
240 }
241 Ok(self.local.lock().unwrap().borrow_mut().buf.take())
242 }
243
244 pub fn poll_read_buf(
245 &self,
246 cx: &mut Context<'_>,
247 buf: &mut BytesMut,
248 ) -> Poll<io::Result<usize>> {
249 let guard = self.local.lock().unwrap();
250 let mut ch = guard.borrow_mut();
251 *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());
252
253 if !ch.buf.is_empty() {
254 let size = std::cmp::min(ch.buf.len(), buf.remaining_mut());
255 let b = ch.buf.split_to(size);
256 buf.put_slice(&b);
257 return Poll::Ready(Ok(size));
258 }
259
260 match mem::take(&mut ch.read) {
261 IoTestState::Ok | IoTestState::Pending => Poll::Pending,
262 IoTestState::Close => {
263 ch.read = IoTestState::Close;
264 Poll::Ready(Ok(0))
265 }
266 IoTestState::Err(e) => Poll::Ready(Err(e)),
267 }
268 }
269
270 pub fn poll_write_buf(
271 &self,
272 cx: &mut Context<'_>,
273 buf: &[u8],
274 ) -> Poll<io::Result<usize>> {
275 let guard = self.remote.lock().unwrap();
276 let mut ch = guard.borrow_mut();
277
278 match mem::take(&mut ch.write) {
279 IoTestState::Ok => {
280 let cap = cmp::min(buf.len(), ch.buf_cap);
281 if cap > 0 {
282 ch.buf.extend(&buf[..cap]);
283 ch.buf_cap -= cap;
284 ch.flags.remove(IoTestFlags::FLUSHED);
285 ch.waker.wake();
286 Poll::Ready(Ok(cap))
287 } else {
288 *self
289 .local
290 .lock()
291 .unwrap()
292 .borrow_mut()
293 .waker
294 .0
295 .lock()
296 .unwrap()
297 .borrow_mut() = Some(cx.waker().clone());
298 Poll::Pending
299 }
300 }
301 IoTestState::Close => Poll::Ready(Ok(0)),
302 IoTestState::Pending => {
303 *self
304 .local
305 .lock()
306 .unwrap()
307 .borrow_mut()
308 .waker
309 .0
310 .lock()
311 .unwrap()
312 .borrow_mut() = Some(cx.waker().clone());
313 Poll::Pending
314 }
315 IoTestState::Err(e) => Poll::Ready(Err(e)),
316 }
317 }
318}
319
320impl Clone for IoTest {
321 fn clone(&self) -> Self {
322 let tp = match self.tp {
323 Type::Server => Type::ServerClone,
324 Type::Client => Type::ClientClone,
325 val => val,
326 };
327
328 IoTest {
329 tp,
330 local: self.local.clone(),
331 remote: self.remote.clone(),
332 state: self.state.clone(),
333 peer_addr: self.peer_addr,
334 }
335 }
336}
337
338impl Drop for IoTest {
339 fn drop(&mut self) {
340 let mut state = *self.state.lock().unwrap().borrow();
341 match self.tp {
342 Type::Server => state.server_dropped = true,
343 Type::Client => state.client_dropped = true,
344 _ => (),
345 }
346 *self.state.lock().unwrap().borrow_mut() = state;
347
348 let guard = self.remote.lock().unwrap();
349 let mut remote = guard.borrow_mut();
350 remote.read = IoTestState::Close;
351 remote.waker.wake();
352 log::debug!("drop remote socket");
353 }
354}
355
356impl IoStream for IoTest {
357 fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
358 let io = Rc::new(self);
359 ntex_util::spawn(run(io.clone(), ctx));
360 Some(Box::new(io))
361 }
362}
363
364impl Handle for Rc<IoTest> {
365 fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
366 if id == any::TypeId::of::<types::PeerAddr>()
367 && let Some(addr) = self.peer_addr
368 {
369 return Some(Box::new(types::PeerAddr(addr)));
370 }
371 None
372 }
373}
374
/// Reason the io task left its main loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Status {
    /// Write readiness reported `Shutdown`; `run` passes `flush = true` to
    /// `IoContext::shutdown` in this case.
    Shutdown,
    /// The task must stop without flushing.
    Terminate,
}
380
381async fn run(io: Rc<IoTest>, ctx: IoContext) {
382 let st = poll_fn(|cx| turn(&io, &ctx, cx)).await;
383
384 log::debug!("{}: Shuting down io", ctx.tag());
385 if !ctx.is_stopped() {
386 let flush = st == Status::Shutdown;
387 poll_fn(|cx| {
388 if write(&io, &ctx, cx) == Poll::Ready(Status::Terminate) {
389 Poll::Ready(())
390 } else {
391 ctx.shutdown(flush, cx)
392 }
393 })
394 .await;
395 }
396
397 io.local
399 .lock()
400 .unwrap()
401 .borrow_mut()
402 .flags
403 .insert(IoTestFlags::CLOSED);
404
405 log::debug!("{}: Shutdown complete", ctx.tag());
406 if !ctx.is_stopped() {
407 ctx.stop(None);
408 }
409}
410
411fn turn(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status> {
412 let read = match ctx.poll_read_ready(cx) {
413 Poll::Ready(Readiness::Ready) => read(io, ctx, cx),
414 Poll::Ready(Readiness::Shutdown | Readiness::Terminate) => Poll::Ready(()),
415 Poll::Pending => Poll::Pending,
416 };
417
418 let write = match ctx.poll_write_ready(cx) {
419 Poll::Ready(Readiness::Ready) => write(io, ctx, cx),
420 Poll::Ready(Readiness::Shutdown) => Poll::Ready(Status::Shutdown),
421 Poll::Ready(Readiness::Terminate) => Poll::Ready(Status::Terminate),
422 Poll::Pending => Poll::Pending,
423 };
424
425 if read.is_pending() && write.is_pending() {
426 Poll::Pending
427 } else if write.is_ready() {
428 write
429 } else {
430 Poll::Ready(Status::Terminate)
431 }
432}
433
434fn write(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status> {
435 if let Some(mut buf) = ctx.get_write_buf() {
436 let result = write_io(io, &mut buf, cx, ctx.tag());
437 if ctx.release_write_buf(buf, result) == IoTaskStatus::Stop {
438 Poll::Ready(Status::Terminate)
439 } else {
440 Poll::Pending
441 }
442 } else {
443 Poll::Pending
444 }
445}
446
447fn read(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<()> {
448 let mut buf = ctx.get_read_buf();
449
450 let mut n = 0;
452 loop {
453 ctx.resize_read_buf(&mut buf);
454
455 let result = match io.poll_read_buf(cx, &mut buf) {
456 Poll::Pending => {
457 if n > 0 {
458 Poll::Ready(Ok(()))
459 } else {
460 Poll::Pending
461 }
462 }
463 Poll::Ready(Ok(size)) => {
464 n += size;
465 if size > 0 && buf.remaining_mut() > 0 {
466 continue;
467 }
468 if size == 0 {
469 Poll::Ready(Err(None))
470 } else {
471 Poll::Ready(Ok(()))
472 }
473 }
474 Poll::Ready(Err(err)) => Poll::Ready(Err(Some(err))),
475 };
476
477 return if matches!(ctx.release_read_buf(n, buf, result), IoTaskStatus::Stop) {
478 Poll::Ready(())
479 } else {
480 Poll::Pending
481 };
482 }
483}
484
485pub(super) fn write_io(
487 io: &IoTest,
488 buf: &mut BytesMut,
489 cx: &mut Context<'_>,
490 tag: &'static str,
491) -> Poll<io::Result<usize>> {
492 let len = buf.len();
493
494 if len != 0 {
495 log::debug!("{tag}: flushing framed transport: {len}");
496
497 let mut written = 0;
498 while let Poll::Ready(n) = io.poll_write_buf(cx, &buf[written..])? {
499 if n == 0 {
500 log::trace!("{tag}: disconnected during flush, written {written}");
501 return Poll::Ready(Err(io::Error::new(
502 io::ErrorKind::WriteZero,
503 "failed to write frame to transport",
504 )));
505 }
506 written += n;
507 if written == len {
508 break;
509 }
510 }
511 log::debug!("{tag}: flushed {written} bytes");
512 if written > 0 {
513 Poll::Ready(Ok(written))
514 } else {
515 Poll::Pending
516 }
517 } else {
518 Poll::Pending
519 }
520}
521
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use ntex_util::future::lazy;

    /// Smoke test covering endpoint roles, `Debug` output, pending reads and
    /// writes, drop tracking and the peer-address query.
    #[ntex::test]
    async fn basic() {
        let (client, server) = IoTest::create();

        // Roles of the originals and of their clones.
        assert_eq!(client.tp, Type::Client);
        assert_eq!(client.clone().tp, Type::ClientClone);
        assert_eq!(server.tp, Type::Server);
        assert_eq!(server.clone().tp, Type::ServerClone);

        // Debug representations.
        assert!(format!("{server:?}").contains("IoTest"));
        assert!(format!("{:?}", AtomicWaker::default()).contains("AtomicWaker"));

        // Reading from an empty channel is pending.
        server.read_pending();
        let mut dst = BytesMut::new();
        let read_poll = lazy(|cx| client.poll_read_buf(cx, &mut dst)).await;
        assert!(read_poll.is_pending());

        // Writing is pending as well.
        server.read_pending();
        let write_poll = lazy(|cx| server.poll_write_buf(cx, b"123")).await;
        assert!(write_poll.is_pending());

        // Dropping the original client is recorded in the shared state.
        assert!(!server.is_client_dropped());
        drop(client);
        assert!(server.is_client_dropped());

        // Dropping the original server is visible through its clone.
        let server_clone = server.clone();
        assert!(!server_clone.is_server_dropped());
        drop(server);
        assert!(server_clone.is_server_dropped());

        let write_poll = lazy(|cx| server_clone.poll_write_buf(cx, b"123")).await;
        assert!(write_poll.is_pending());

        // The configured peer address is available through `Io::query`.
        let (client, _) = IoTest::create();
        let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let client = crate::Io::from(client.set_peer_addr(addr));
        let item = client.query::<crate::types::PeerAddr>();
        assert!(format!("{item:?}").contains("QueryItem(127.0.0.1:8080)"));
    }
}