1#![allow(clippy::let_underscore_future)]
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll, Waker};
5use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc};
6
7use ntex_bytes::{Buf, BufMut, Bytes, BytesVec};
8use ntex_util::time::{sleep, Millis};
9
10use crate::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};
11
/// Shared, lockable slot holding at most one [`Waker`].
///
/// Code in this file registers a task by storing `Some(waker)` into the slot
/// through `.0`; [`AtomicWaker::wake`] empties the slot and wakes that task.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);

impl AtomicWaker {
    /// Wakes and clears the registered waker, if any; a no-op on an empty slot.
    ///
    /// The waker is moved out and the lock released *before* `Waker::wake` runs.
    /// In the previous form the guard lived for the whole `if let` body, so a
    /// waker that re-registered itself synchronously would deadlock on this
    /// non-reentrant mutex.
    fn wake(&self) {
        let waker = self.0.lock().unwrap().borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl fmt::Debug for AtomicWaker {
    /// Prints only the type name; the registered waker is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
28
/// In-memory, bidirectional stream endpoint for tests.
///
/// A connected pair is built by `IoTest::create`. Each endpoint reads from its
/// `local` channel and writes into its `remote` channel; the peer holds the same
/// two channels with the roles swapped.
#[derive(Debug)]
pub struct IoTest {
    /// Which endpoint this handle is (client or server, original or clone).
    tp: Type,
    /// Address reported for `types::PeerAddr` queries; set via `set_peer_addr`.
    peer_addr: Option<net::SocketAddr>,
    /// Drop flags shared with the peer and all clones.
    state: Arc<Mutex<RefCell<State>>>,
    /// Channel this endpoint reads from (the peer's `remote`).
    local: Arc<Mutex<RefCell<Channel>>>,
    /// Channel this endpoint writes into (the peer's `local`).
    remote: Arc<Mutex<RefCell<Channel>>>,
}
38
bitflags::bitflags! {
    /// Per-channel status bits.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct IoTestFlags: u8 {
        /// Cleared by `poll_write_buf` whenever it appends data. Nothing in this
        /// file sets it — NOTE(review): possibly vestigial, confirm.
        const FLUSHED = 0b0000_0001;
        /// Set by `AsyncWrite::shutdown` on the writer's local channel;
        /// checked by `Channel::is_closed`.
        const CLOSED = 0b0000_0010;
    }
}
46
/// Identity of an `IoTest` handle.
///
/// `IoTest::create` returns one `Client` and one `Server`; cloning either yields
/// the matching `*Clone` variant. Only the non-clone variants set the drop flags
/// in `State`, and `IoTest::read` uses the variant to decide what counts as closed.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum Type {
    Client,
    Server,
    ClientClone,
    ServerClone,
}
54
/// Drop flags shared by both endpoints of a pair (and their clones).
#[derive(Copy, Clone, Default, Debug)]
struct State {
    /// Set when the original `Client` endpoint is dropped.
    client_dropped: bool,
    /// Set when the original `Server` endpoint is dropped.
    server_dropped: bool,
}
60
/// One direction of an `IoTest` pair.
///
/// The endpoint whose `local` is this channel reads from it; its peer (whose
/// `remote` is this channel) writes into it.
#[derive(Default, Debug)]
struct Channel {
    /// Bytes written into this channel and not yet read.
    buf: BytesVec,
    /// Remaining number of bytes `poll_write_buf` may append to `buf`; decremented
    /// on every write. Starts at 0 (derived `Default`), so writes stay pending
    /// until `remote_buffer_cap` sets it.
    buf_cap: usize,
    /// Status bits (`CLOSED`, `FLUSHED`).
    flags: IoTestFlags,
    /// Single waker slot for the reading endpoint's tasks; both its pending reads
    /// and its pending writes register here.
    waker: AtomicWaker,
    /// Read-side state consulted by `poll_read_buf` when `buf` is empty.
    read: IoTestState,
    /// Write-side state consulted by `poll_write_buf` of the endpoint whose
    /// `remote` is this channel.
    write: IoTestState,
}

// SAFETY: NOTE(review) — asserted by hand rather than derived. Within this file
// every `Channel` is reached only through `Arc<Mutex<RefCell<Channel>>>`, so all
// access is serialized by the mutex. Whether these impls are needed at all
// (i.e. whether `BytesVec` is already `Send`/`Sync`) cannot be determined from
// this file — confirm before relying on them elsewhere.
unsafe impl Sync for Channel {}
unsafe impl Send for Channel {}
73
74impl Channel {
75 fn is_closed(&self) -> bool {
76 self.flags.contains(IoTestFlags::CLOSED)
77 }
78}
79
80impl Default for IoTestFlags {
81 fn default() -> Self {
82 IoTestFlags::empty()
83 }
84}
85
/// Scripted outcome for the next read or write on a channel.
///
/// `poll_read_buf` / `poll_write_buf` take the state with `mem::take`, so most
/// variants apply once and then revert to `Ok`.
#[derive(Debug, Default)]
enum IoTestState {
    /// Normal operation.
    #[default]
    Ok,
    /// Report `Pending` once.
    Pending,
    /// Reading yields `Ok(0)` and the state is put back (sticky);
    /// writing yields `Ok(0)` once.
    Close,
    /// Returned once as the operation's error.
    Err(io::Error),
}
94
95impl IoTest {
96 pub fn create() -> (IoTest, IoTest) {
98 let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
99 let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
100 let state = Arc::new(Mutex::new(RefCell::new(State::default())));
101
102 (
103 IoTest {
104 tp: Type::Client,
105 peer_addr: None,
106 local: local.clone(),
107 remote: remote.clone(),
108 state: state.clone(),
109 },
110 IoTest {
111 state,
112 peer_addr: None,
113 tp: Type::Server,
114 local: remote,
115 remote: local,
116 },
117 )
118 }
119
120 pub fn is_client_dropped(&self) -> bool {
121 self.state.lock().unwrap().borrow().client_dropped
122 }
123
124 pub fn is_server_dropped(&self) -> bool {
125 self.state.lock().unwrap().borrow().server_dropped
126 }
127
128 pub fn is_closed(&self) -> bool {
130 self.remote.lock().unwrap().borrow().is_closed()
131 }
132
133 pub fn set_peer_addr(mut self, addr: net::SocketAddr) -> Self {
135 self.peer_addr = Some(addr);
136 self
137 }
138
139 pub fn read_pending(&self) {
141 self.remote.lock().unwrap().borrow_mut().read = IoTestState::Pending;
142 }
143
144 pub fn read_error(&self, err: io::Error) {
146 let channel = self.remote.lock().unwrap();
147 channel.borrow_mut().read = IoTestState::Err(err);
148 channel.borrow().waker.wake();
149 }
150
151 pub fn write_error(&self, err: io::Error) {
153 self.local.lock().unwrap().borrow_mut().write = IoTestState::Err(err);
154 self.remote.lock().unwrap().borrow().waker.wake();
155 }
156
157 pub fn local_buffer<F, R>(&self, f: F) -> R
159 where
160 F: FnOnce(&mut BytesVec) -> R,
161 {
162 let guard = self.local.lock().unwrap();
163 let mut ch = guard.borrow_mut();
164 f(&mut ch.buf)
165 }
166
167 pub fn remote_buffer<F, R>(&self, f: F) -> R
169 where
170 F: FnOnce(&mut BytesVec) -> R,
171 {
172 let guard = self.remote.lock().unwrap();
173 let mut ch = guard.borrow_mut();
174 f(&mut ch.buf)
175 }
176
177 pub async fn close(&self) {
179 {
180 let guard = self.remote.lock().unwrap();
181 let mut remote = guard.borrow_mut();
182 remote.read = IoTestState::Close;
183 remote.waker.wake();
184 log::trace!("close remote socket");
185 }
186 sleep(Millis(35)).await;
187 }
188
189 pub fn write<T: AsRef<[u8]>>(&self, data: T) {
191 let guard = self.remote.lock().unwrap();
192 let mut write = guard.borrow_mut();
193 write.buf.extend_from_slice(data.as_ref());
194 write.waker.wake();
195 }
196
197 pub fn remote_buffer_cap(&self, cap: usize) {
199 self.local.lock().unwrap().borrow_mut().buf_cap = cap;
201 self.remote.lock().unwrap().borrow().waker.wake();
203 }
204
205 pub fn read_any(&self) -> Bytes {
207 self.local.lock().unwrap().borrow_mut().buf.split().freeze()
208 }
209
210 pub async fn read(&self) -> Result<Bytes, io::Error> {
212 if self.local.lock().unwrap().borrow().buf.is_empty() {
213 poll_fn(|cx| {
214 let guard = self.local.lock().unwrap();
215 let read = guard.borrow_mut();
216 if read.buf.is_empty() {
217 let closed = match self.tp {
218 Type::Client | Type::ClientClone => {
219 self.is_server_dropped() || read.is_closed()
220 }
221 Type::Server | Type::ServerClone => self.is_client_dropped(),
222 };
223 if closed {
224 Poll::Ready(())
225 } else {
226 *read.waker.0.lock().unwrap().borrow_mut() =
227 Some(cx.waker().clone());
228 drop(read);
229 drop(guard);
230 Poll::Pending
231 }
232 } else {
233 Poll::Ready(())
234 }
235 })
236 .await;
237 }
238 Ok(self.local.lock().unwrap().borrow_mut().buf.split().freeze())
239 }
240
241 pub fn poll_read_buf(
242 &self,
243 cx: &mut Context<'_>,
244 buf: &mut BytesVec,
245 ) -> Poll<io::Result<usize>> {
246 let guard = self.local.lock().unwrap();
247 let mut ch = guard.borrow_mut();
248 *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());
249
250 if !ch.buf.is_empty() {
251 let size = std::cmp::min(ch.buf.len(), buf.remaining_mut());
252 let b = ch.buf.split_to(size);
253 buf.put_slice(&b);
254 return Poll::Ready(Ok(size));
255 }
256
257 match mem::take(&mut ch.read) {
258 IoTestState::Ok => Poll::Pending,
259 IoTestState::Close => {
260 ch.read = IoTestState::Close;
261 Poll::Ready(Ok(0))
262 }
263 IoTestState::Pending => Poll::Pending,
264 IoTestState::Err(e) => Poll::Ready(Err(e)),
265 }
266 }
267
268 pub fn poll_write_buf(
269 &self,
270 cx: &mut Context<'_>,
271 buf: &[u8],
272 ) -> Poll<io::Result<usize>> {
273 let guard = self.remote.lock().unwrap();
274 let mut ch = guard.borrow_mut();
275
276 match mem::take(&mut ch.write) {
277 IoTestState::Ok => {
278 let cap = cmp::min(buf.len(), ch.buf_cap);
279 if cap > 0 {
280 ch.buf.extend(&buf[..cap]);
281 ch.buf_cap -= cap;
282 ch.flags.remove(IoTestFlags::FLUSHED);
283 ch.waker.wake();
284 Poll::Ready(Ok(cap))
285 } else {
286 *self
287 .local
288 .lock()
289 .unwrap()
290 .borrow_mut()
291 .waker
292 .0
293 .lock()
294 .unwrap()
295 .borrow_mut() = Some(cx.waker().clone());
296 Poll::Pending
297 }
298 }
299 IoTestState::Close => Poll::Ready(Ok(0)),
300 IoTestState::Pending => {
301 *self
302 .local
303 .lock()
304 .unwrap()
305 .borrow_mut()
306 .waker
307 .0
308 .lock()
309 .unwrap()
310 .borrow_mut() = Some(cx.waker().clone());
311 Poll::Pending
312 }
313 IoTestState::Err(e) => Poll::Ready(Err(e)),
314 }
315 }
316}
317
318impl Clone for IoTest {
319 fn clone(&self) -> Self {
320 let tp = match self.tp {
321 Type::Server => Type::ServerClone,
322 Type::Client => Type::ClientClone,
323 val => val,
324 };
325
326 IoTest {
327 tp,
328 local: self.local.clone(),
329 remote: self.remote.clone(),
330 state: self.state.clone(),
331 peer_addr: self.peer_addr,
332 }
333 }
334}
335
336impl Drop for IoTest {
337 fn drop(&mut self) {
338 let mut state = *self.state.lock().unwrap().borrow();
339 match self.tp {
340 Type::Server => state.server_dropped = true,
341 Type::Client => state.client_dropped = true,
342 _ => (),
343 }
344 *self.state.lock().unwrap().borrow_mut() = state;
345
346 let guard = self.remote.lock().unwrap();
347 let mut remote = guard.borrow_mut();
348 remote.read = IoTestState::Close;
349 remote.waker.wake();
350 log::trace!("drop remote socket");
351 }
352}
353
354impl IoStream for IoTest {
355 fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
356 let io = Rc::new(self);
357
358 let mut rio = Read(io.clone());
359 let _ = ntex_util::spawn(async move {
360 read.handle(&mut rio).await;
361 });
362
363 let mut wio = Write(io.clone());
364 let _ = ntex_util::spawn(async move {
365 write.handle(&mut wio).await;
366 });
367
368 Some(Box::new(io))
369 }
370}
371
372impl Handle for Rc<IoTest> {
373 fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
374 if id == any::TypeId::of::<types::PeerAddr>() {
375 if let Some(addr) = self.peer_addr {
376 return Some(Box::new(types::PeerAddr(addr)));
377 }
378 }
379 None
380 }
381}
382
383struct Read(Rc<IoTest>);
385
386impl crate::AsyncRead for Read {
387 async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
388 let result = poll_fn(|cx| self.0.poll_read_buf(cx, &mut buf)).await;
390 (buf, result)
391 }
392}
393
394struct Write(Rc<IoTest>);
396
397impl crate::AsyncWrite for Write {
398 async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
399 poll_fn(|cx| {
400 if let Some(mut b) = buf.take() {
401 let result = write_io(&self.0, &mut b, cx);
402 buf.set(b);
403 result
404 } else {
405 Poll::Ready(Ok(()))
406 }
407 })
408 .await
409 }
410
411 async fn flush(&mut self) -> io::Result<()> {
412 Ok(())
413 }
414
415 async fn shutdown(&mut self) -> io::Result<()> {
416 self.0
418 .local
419 .lock()
420 .unwrap()
421 .borrow_mut()
422 .flags
423 .insert(IoTestFlags::CLOSED);
424 Ok(())
425 }
426}
427
428pub(super) fn write_io(
430 io: &IoTest,
431 buf: &mut BytesVec,
432 cx: &mut Context<'_>,
433) -> Poll<io::Result<()>> {
434 let len = buf.len();
435
436 if len != 0 {
437 log::trace!("flushing framed transport: {len}");
438
439 let mut written = 0;
440 let result = loop {
441 break match io.poll_write_buf(cx, &buf[written..]) {
442 Poll::Ready(Ok(n)) => {
443 if n == 0 {
444 log::trace!("disconnected during flush, written {written}");
445 Poll::Ready(Err(io::Error::new(
446 io::ErrorKind::WriteZero,
447 "failed to write frame to transport",
448 )))
449 } else {
450 written += n;
451 if written == len {
452 buf.clear();
453 Poll::Ready(Ok(()))
454 } else {
455 continue;
456 }
457 }
458 }
459 Poll::Pending => {
460 buf.advance(written);
462 Poll::Pending
463 }
464 Poll::Ready(Err(e)) => {
465 log::trace!("error during flush: {e}");
466 Poll::Ready(Err(e))
467 }
468 };
469 };
470 log::trace!("flushed {written} bytes");
471 result
472 } else {
473 Poll::Ready(Ok(()))
474 }
475}
476
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use ntex_util::future::lazy;

    #[ntex::test]
    async fn basic() {
        // Original endpoints are tagged Client/Server; clones get the *Clone tag.
        // Each temporary clone here is dropped at the end of its statement, which
        // sets the read state of that handle's `remote` channel to Close.
        let (client, server) = IoTest::create();
        assert_eq!(client.tp, Type::Client);
        assert_eq!(client.clone().tp, Type::ClientClone);
        assert_eq!(server.tp, Type::Server);
        assert_eq!(server.clone().tp, Type::ServerClone);
        assert!(format!("{server:?}").contains("IoTest"));
        assert!(format!("{:?}", AtomicWaker::default()).contains("AtomicWaker"));

        // read_pending replaces the client's read state (set to Close by the
        // dropped server clone above) with Pending; with nothing buffered the
        // client's poll_read_buf therefore reports Pending.
        server.read_pending();
        let mut buf = BytesVec::new();
        let res = lazy(|cx| client.poll_read_buf(cx, &mut buf)).await;
        assert!(res.is_pending());

        // No write capacity has been granted (buf_cap defaults to 0), so the
        // server's write stays pending. The read state set here does not affect
        // poll_write_buf.
        server.read_pending();
        let res = lazy(|cx| server.poll_write_buf(cx, b"123")).await;
        assert!(res.is_pending());

        // Dropping an original endpoint is visible to the peer.
        assert!(!server.is_client_dropped());
        drop(client);
        assert!(server.is_client_dropped());

        // A clone observes the shared drop flags after the original is dropped.
        let server2 = server.clone();
        assert!(!server2.is_server_dropped());
        drop(server);
        assert!(server2.is_server_dropped());

        // Still no write capacity toward the client's channel.
        let res = lazy(|cx| server2.poll_write_buf(cx, b"123")).await;
        assert!(res.is_pending());

        // The peer address set on the endpoint is reachable through Io::query.
        let (client, _) = IoTest::create();
        let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let client = crate::Io::new(client.set_peer_addr(addr));
        let item = client.query::<crate::types::PeerAddr>();
        assert!(format!("{item:?}").contains("QueryItem(127.0.0.1:8080)"));
    }
}