Skip to main content

ntex_io/
testing.rs

1//! utilities and helpers for testing
2#![allow(clippy::missing_panics_doc)]
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll, Waker};
5use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc};
6
7use ntex_bytes::{BufMut, Bytes, BytesMut};
8use ntex_util::time::{Millis, sleep};
9
10use crate::{Handle, IoContext, IoStream, IoTaskStatus, Readiness, types};
11
/// Lockable slot holding at most one task `Waker`, shared through an `Arc`.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);

impl AtomicWaker {
    /// Wake the stored waker, if any, leaving the slot empty.
    fn wake(&self) {
        let slot = self.0.lock().unwrap();
        let taken = slot.borrow_mut().take();
        if let Some(waker) = taken {
            waker.wake();
        }
    }
}

impl fmt::Debug for AtomicWaker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The stored waker is opaque; only the type name is printed.
        f.write_str("AtomicWaker")
    }
}
28
/// Async io stream
///
/// One end of an in-memory connection created by [`IoTest::create`].
#[derive(Debug)]
pub struct IoTest {
    /// Which end this handle is (client or server, primary or clone).
    tp: Type,
    /// Address answered for `types::PeerAddr` queries, if configured.
    peer_addr: Option<net::SocketAddr>,
    /// Drop flags shared by both ends of the pair.
    state: Arc<Mutex<RefCell<State>>>,
    /// Channel this end reads from (`poll_read_buf`); the peer writes into it.
    local: Arc<Mutex<RefCell<Channel>>>,
    /// Channel this end writes into (`poll_write_buf`); the peer reads from it.
    remote: Arc<Mutex<RefCell<Channel>>>,
}
38
bitflags::bitflags! {
    /// Status bits of a single [`Channel`].
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct IoTestFlags: u8 {
        /// Cleared whenever data is written into the channel; nothing in this
        /// file sets it.
        const FLUSHED = 0b0000_0001;
        /// Inserted by `run` on the io's `local` channel once shutdown completes.
        const CLOSED  = 0b0000_0010;
    }
}
46
/// Role of an [`IoTest`] handle.
///
/// Only the primary `Client`/`Server` handles update the shared drop flags;
/// handles produced by `Clone` carry the `*Clone` variants.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum Type {
    Client,
    Server,
    ClientClone,
    ServerClone,
}
54
/// Drop tracking shared by both ends of an [`IoTest`] pair.
#[derive(Copy, Clone, Default, Debug)]
struct State {
    /// Set when the primary `Type::Client` handle is dropped.
    client_dropped: bool,
    /// Set when the primary `Type::Server` handle is dropped.
    server_dropped: bool,
}
60
/// One direction of the in-memory connection: written by one end, read by the other.
#[derive(Default, Debug)]
struct Channel {
    /// Bytes written into this channel that have not been read yet.
    buf: BytesMut,
    /// Number of bytes a writer may still append. `poll_write_buf` decrements it
    /// and `remote_buffer_cap` sets it. It starts at 0, so writes return
    /// `Pending` until a budget is granted.
    buf_cap: usize,
    /// Status bits (`CLOSED`, `FLUSHED`).
    flags: IoTestFlags,
    /// Waker of the task most recently parked on this channel.
    waker: AtomicWaker,
    /// Injected state consulted by the reader once `buf` is empty.
    read: IoTestState,
    /// Injected state consulted by the writer of this channel.
    write: IoTestState,
}
70
// SAFETY: NOTE(review) these impls bypass the compiler's auto-trait check.
// Every access to a `Channel` in this file goes through the enclosing `Mutex`.
// Whether `BytesMut` (and the other fields) are themselves sound to move
// across threads is not established here; confirm against `ntex_bytes`.
unsafe impl Sync for Channel {}
unsafe impl Send for Channel {}
73
74impl Channel {
75    fn is_closed(&self) -> bool {
76        self.flags.contains(IoTestFlags::CLOSED)
77    }
78}
79
80impl Default for IoTestFlags {
81    fn default() -> Self {
82        IoTestFlags::empty()
83    }
84}
85
/// Behaviour injected into the reads or writes on a channel.
///
/// The poll functions take the state with `mem::take`, so `Pending` and `Err`
/// apply once before falling back to `Ok`.
#[derive(Debug, Default)]
enum IoTestState {
    /// Normal operation.
    #[default]
    Ok,
    /// Report `Poll::Pending`.
    Pending,
    /// Report end of stream (`Ok(0)`).
    Close,
    /// Fail with the stored error.
    Err(io::Error),
}
94
95impl IoTest {
96    /// Create a two interconnected streams
97    pub fn create() -> (IoTest, IoTest) {
98        let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
99        let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
100        let state = Arc::new(Mutex::new(RefCell::new(State::default())));
101
102        (
103            IoTest {
104                tp: Type::Client,
105                peer_addr: None,
106                local: local.clone(),
107                remote: remote.clone(),
108                state: state.clone(),
109            },
110            IoTest {
111                state,
112                peer_addr: None,
113                tp: Type::Server,
114                local: remote,
115                remote: local,
116            },
117        )
118    }
119
120    /// Check if client is dropped
121    pub fn is_client_dropped(&self) -> bool {
122        self.state.lock().unwrap().borrow().client_dropped
123    }
124
125    /// Check if server is dropped
126    pub fn is_server_dropped(&self) -> bool {
127        self.state.lock().unwrap().borrow().server_dropped
128    }
129
130    /// Check if channel is closed from remoote side
131    pub fn is_closed(&self) -> bool {
132        self.remote.lock().unwrap().borrow().is_closed()
133    }
134
135    /// Set peer addr
136    #[must_use]
137    pub fn set_peer_addr(mut self, addr: net::SocketAddr) -> Self {
138        self.peer_addr = Some(addr);
139        self
140    }
141
142    /// Set read to Pending state
143    pub fn read_pending(&self) {
144        self.remote.lock().unwrap().borrow_mut().read = IoTestState::Pending;
145    }
146
147    /// Set read to error
148    pub fn read_error(&self, err: io::Error) {
149        let channel = self.remote.lock().unwrap();
150        channel.borrow_mut().read = IoTestState::Err(err);
151        channel.borrow().waker.wake();
152    }
153
154    /// Set write error on remote side
155    pub fn write_error(&self, err: io::Error) {
156        self.local.lock().unwrap().borrow_mut().write = IoTestState::Err(err);
157        self.remote.lock().unwrap().borrow().waker.wake();
158    }
159
160    /// Access read buffer.
161    pub fn local_buffer<F, R>(&self, f: F) -> R
162    where
163        F: FnOnce(&mut BytesMut) -> R,
164    {
165        let guard = self.local.lock().unwrap();
166        let mut ch = guard.borrow_mut();
167        f(&mut ch.buf)
168    }
169
170    /// Access remote buffer.
171    pub fn remote_buffer<F, R>(&self, f: F) -> R
172    where
173        F: FnOnce(&mut BytesMut) -> R,
174    {
175        let guard = self.remote.lock().unwrap();
176        let mut ch = guard.borrow_mut();
177        f(&mut ch.buf)
178    }
179
180    /// Closed remote side.
181    pub async fn close(&self) {
182        {
183            let guard = self.remote.lock().unwrap();
184            let mut remote = guard.borrow_mut();
185            remote.read = IoTestState::Close;
186            remote.waker.wake();
187            log::debug!("close remote socket");
188        }
189        sleep(Millis(35)).await;
190    }
191
192    /// Add extra data to the remote buffer and notify reader
193    pub fn write<T: AsRef<[u8]>>(&self, data: T) {
194        let guard = self.remote.lock().unwrap();
195        let mut write = guard.borrow_mut();
196        write.buf.extend_from_slice(data.as_ref());
197        write.waker.wake();
198    }
199
200    /// Read any available data
201    pub fn remote_buffer_cap(&self, cap: usize) {
202        // change cap
203        self.local.lock().unwrap().borrow_mut().buf_cap = cap;
204        // wake remote
205        self.remote.lock().unwrap().borrow().waker.wake();
206    }
207
208    /// Read any available data
209    pub fn read_any(&self) -> Bytes {
210        self.local.lock().unwrap().borrow_mut().buf.take()
211    }
212
213    /// Read data, if data is not available wait for it
214    pub async fn read(&self) -> Result<Bytes, io::Error> {
215        if self.local.lock().unwrap().borrow().buf.is_empty() {
216            poll_fn(|cx| {
217                let guard = self.local.lock().unwrap();
218                let read = guard.borrow_mut();
219                if read.buf.is_empty() {
220                    let closed = match self.tp {
221                        Type::Client | Type::ClientClone => {
222                            self.is_server_dropped() || read.is_closed()
223                        }
224                        Type::Server | Type::ServerClone => self.is_client_dropped(),
225                    };
226                    if closed {
227                        Poll::Ready(())
228                    } else {
229                        *read.waker.0.lock().unwrap().borrow_mut() =
230                            Some(cx.waker().clone());
231                        drop(read);
232                        drop(guard);
233                        Poll::Pending
234                    }
235                } else {
236                    Poll::Ready(())
237                }
238            })
239            .await;
240        }
241        Ok(self.local.lock().unwrap().borrow_mut().buf.take())
242    }
243
244    pub fn poll_read_buf(
245        &self,
246        cx: &mut Context<'_>,
247        buf: &mut BytesMut,
248    ) -> Poll<io::Result<usize>> {
249        let guard = self.local.lock().unwrap();
250        let mut ch = guard.borrow_mut();
251        *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());
252
253        if !ch.buf.is_empty() {
254            let size = std::cmp::min(ch.buf.len(), buf.remaining_mut());
255            let b = ch.buf.split_to(size);
256            buf.put_slice(&b);
257            return Poll::Ready(Ok(size));
258        }
259
260        match mem::take(&mut ch.read) {
261            IoTestState::Ok | IoTestState::Pending => Poll::Pending,
262            IoTestState::Close => {
263                ch.read = IoTestState::Close;
264                Poll::Ready(Ok(0))
265            }
266            IoTestState::Err(e) => Poll::Ready(Err(e)),
267        }
268    }
269
270    pub fn poll_write_buf(
271        &self,
272        cx: &mut Context<'_>,
273        buf: &[u8],
274    ) -> Poll<io::Result<usize>> {
275        let guard = self.remote.lock().unwrap();
276        let mut ch = guard.borrow_mut();
277
278        match mem::take(&mut ch.write) {
279            IoTestState::Ok => {
280                let cap = cmp::min(buf.len(), ch.buf_cap);
281                if cap > 0 {
282                    ch.buf.extend(&buf[..cap]);
283                    ch.buf_cap -= cap;
284                    ch.flags.remove(IoTestFlags::FLUSHED);
285                    ch.waker.wake();
286                    Poll::Ready(Ok(cap))
287                } else {
288                    *self
289                        .local
290                        .lock()
291                        .unwrap()
292                        .borrow_mut()
293                        .waker
294                        .0
295                        .lock()
296                        .unwrap()
297                        .borrow_mut() = Some(cx.waker().clone());
298                    Poll::Pending
299                }
300            }
301            IoTestState::Close => Poll::Ready(Ok(0)),
302            IoTestState::Pending => {
303                *self
304                    .local
305                    .lock()
306                    .unwrap()
307                    .borrow_mut()
308                    .waker
309                    .0
310                    .lock()
311                    .unwrap()
312                    .borrow_mut() = Some(cx.waker().clone());
313                Poll::Pending
314            }
315            IoTestState::Err(e) => Poll::Ready(Err(e)),
316        }
317    }
318}
319
320impl Clone for IoTest {
321    fn clone(&self) -> Self {
322        let tp = match self.tp {
323            Type::Server => Type::ServerClone,
324            Type::Client => Type::ClientClone,
325            val => val,
326        };
327
328        IoTest {
329            tp,
330            local: self.local.clone(),
331            remote: self.remote.clone(),
332            state: self.state.clone(),
333            peer_addr: self.peer_addr,
334        }
335    }
336}
337
338impl Drop for IoTest {
339    fn drop(&mut self) {
340        let mut state = *self.state.lock().unwrap().borrow();
341        match self.tp {
342            Type::Server => state.server_dropped = true,
343            Type::Client => state.client_dropped = true,
344            _ => (),
345        }
346        *self.state.lock().unwrap().borrow_mut() = state;
347
348        let guard = self.remote.lock().unwrap();
349        let mut remote = guard.borrow_mut();
350        remote.read = IoTestState::Close;
351        remote.waker.wake();
352        log::debug!("drop remote socket");
353    }
354}
355
356impl IoStream for IoTest {
357    fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
358        let io = Rc::new(self);
359        ntex_util::spawn(run(io.clone(), ctx));
360        Some(Box::new(io))
361    }
362}
363
364impl Handle for Rc<IoTest> {
365    fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
366        if id == any::TypeId::of::<types::PeerAddr>()
367            && let Some(addr) = self.peer_addr
368        {
369            return Some(Box::new(types::PeerAddr(addr)));
370        }
371        None
372    }
373}
374
/// Outcome of the main io loop; decides how `run` shuts the stream down.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Status {
    /// Shut down with `flush = true`.
    Shutdown,
    /// Shut down with `flush = false`.
    Terminate,
}
380
381async fn run(io: Rc<IoTest>, ctx: IoContext) {
382    let st = poll_fn(|cx| turn(&io, &ctx, cx)).await;
383
384    log::debug!("{}: Shuting down io", ctx.tag());
385    if !ctx.is_stopped() {
386        let flush = st == Status::Shutdown;
387        poll_fn(|cx| {
388            if write(&io, &ctx, cx) == Poll::Ready(Status::Terminate) {
389                Poll::Ready(())
390            } else {
391                ctx.shutdown(flush, cx)
392            }
393        })
394        .await;
395    }
396
397    // shutdown WRITE side
398    io.local
399        .lock()
400        .unwrap()
401        .borrow_mut()
402        .flags
403        .insert(IoTestFlags::CLOSED);
404
405    log::debug!("{}: Shutdown complete", ctx.tag());
406    if !ctx.is_stopped() {
407        ctx.stop(None);
408    }
409}
410
411fn turn(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status> {
412    let read = match ctx.poll_read_ready(cx) {
413        Poll::Ready(Readiness::Ready) => read(io, ctx, cx),
414        Poll::Ready(Readiness::Shutdown | Readiness::Terminate) => Poll::Ready(()),
415        Poll::Pending => Poll::Pending,
416    };
417
418    let write = match ctx.poll_write_ready(cx) {
419        Poll::Ready(Readiness::Ready) => write(io, ctx, cx),
420        Poll::Ready(Readiness::Shutdown) => Poll::Ready(Status::Shutdown),
421        Poll::Ready(Readiness::Terminate) => Poll::Ready(Status::Terminate),
422        Poll::Pending => Poll::Pending,
423    };
424
425    if read.is_pending() && write.is_pending() {
426        Poll::Pending
427    } else if write.is_ready() {
428        write
429    } else {
430        Poll::Ready(Status::Terminate)
431    }
432}
433
434fn write(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status> {
435    if let Some(mut buf) = ctx.get_write_buf() {
436        let result = write_io(io, &mut buf, cx, ctx.tag());
437        if ctx.release_write_buf(buf, result) == IoTaskStatus::Stop {
438            Poll::Ready(Status::Terminate)
439        } else {
440            Poll::Pending
441        }
442    } else {
443        Poll::Pending
444    }
445}
446
447fn read(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<()> {
448    let mut buf = ctx.get_read_buf();
449
450    // read data from socket
451    let mut n = 0;
452    loop {
453        ctx.resize_read_buf(&mut buf);
454
455        let result = match io.poll_read_buf(cx, &mut buf) {
456            Poll::Pending => {
457                if n > 0 {
458                    Poll::Ready(Ok(()))
459                } else {
460                    Poll::Pending
461                }
462            }
463            Poll::Ready(Ok(size)) => {
464                n += size;
465                if size > 0 && buf.remaining_mut() > 0 {
466                    continue;
467                }
468                if size == 0 {
469                    Poll::Ready(Err(None))
470                } else {
471                    Poll::Ready(Ok(()))
472                }
473            }
474            Poll::Ready(Err(err)) => Poll::Ready(Err(Some(err))),
475        };
476
477        return if matches!(ctx.release_read_buf(n, buf, result), IoTaskStatus::Stop) {
478            Poll::Ready(())
479        } else {
480            Poll::Pending
481        };
482    }
483}
484
485/// Flush write buffer to underlying I/O stream.
486pub(super) fn write_io(
487    io: &IoTest,
488    buf: &mut BytesMut,
489    cx: &mut Context<'_>,
490    tag: &'static str,
491) -> Poll<io::Result<usize>> {
492    let len = buf.len();
493
494    if len != 0 {
495        log::debug!("{tag}: flushing framed transport: {len}");
496
497        let mut written = 0;
498        while let Poll::Ready(n) = io.poll_write_buf(cx, &buf[written..])? {
499            if n == 0 {
500                log::trace!("{tag}: disconnected during flush, written {written}");
501                return Poll::Ready(Err(io::Error::new(
502                    io::ErrorKind::WriteZero,
503                    "failed to write frame to transport",
504                )));
505            }
506            written += n;
507            if written == len {
508                break;
509            }
510        }
511        log::debug!("{tag}: flushed {written} bytes");
512        if written > 0 {
513            Poll::Ready(Ok(written))
514        } else {
515            Poll::Pending
516        }
517    } else {
518        Poll::Pending
519    }
520}
521
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use ntex_util::future::lazy;

    #[ntex::test]
    async fn basic() {
        let (client, server) = IoTest::create();

        // Roles of the primary handles and of their clones.
        assert_eq!(client.tp, Type::Client);
        assert_eq!(client.clone().tp, Type::ClientClone);
        assert_eq!(server.tp, Type::Server);
        assert_eq!(server.clone().tp, Type::ServerClone);

        // Debug output of the stream and of the waker slot.
        let server_dbg = format!("{server:?}");
        assert!(server_dbg.contains("IoTest"));
        let waker_dbg = format!("{:?}", AtomicWaker::default());
        assert!(waker_dbg.contains("AtomicWaker"));

        // Nothing buffered for the client: its read pends.
        server.read_pending();
        let mut scratch = BytesMut::new();
        let polled = lazy(|cx| client.poll_read_buf(cx, &mut scratch)).await;
        assert!(polled.is_pending());

        // No write budget granted yet: the server's write pends.
        server.read_pending();
        let polled = lazy(|cx| server.poll_write_buf(cx, b"123")).await;
        assert!(polled.is_pending());

        // Dropping the primary client flips the shared flag.
        assert!(!server.is_client_dropped());
        drop(client);
        assert!(server.is_client_dropped());

        // Dropping the primary server is visible through its clone.
        let server_clone = server.clone();
        assert!(!server_clone.is_server_dropped());
        drop(server);
        assert!(server_clone.is_server_dropped());

        let polled = lazy(|cx| server_clone.poll_write_buf(cx, b"123")).await;
        assert!(polled.is_pending());

        // The configured peer address is answered by the `PeerAddr` query.
        let (client, _) = IoTest::create();
        let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let client = crate::Io::from(client.set_peer_addr(addr));
        let item = client.query::<crate::types::PeerAddr>();
        assert!(format!("{item:?}").contains("QueryItem(127.0.0.1:8080)"));
    }
}
565}