ntex_io/
testing.rs

//! Utilities and helpers for testing: an in-memory, bidirectional [`IoTest`] stream pair.
2#![allow(clippy::let_underscore_future)]
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll, Waker};
5use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc};
6
7use ntex_bytes::{Buf, BufMut, Bytes, BytesVec};
8use ntex_util::time::{sleep, Millis};
9
10use crate::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};
11
/// A shareable slot holding at most one [`Waker`].
///
/// Other code in this module registers a waker by writing
/// `*slot.0.lock().unwrap().borrow_mut() = Some(waker)`.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);

impl AtomicWaker {
    /// Take the registered waker, if any, and wake it.
    ///
    /// The waker is removed from the slot, so a second call without a new
    /// registration is a no-op. The lock is released *before* `wake()` runs:
    /// `std::sync::Mutex` is not reentrant, so a waker that re-registers on this
    /// same slot while waking must not find the lock still held.
    fn wake(&self) {
        let waker = self.0.lock().unwrap().borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl fmt::Debug for AtomicWaker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
28
/// In-memory async io stream for tests.
///
/// One end of a bidirectional pair built by [`IoTest::create`]: each end's
/// `local` channel is the other end's `remote` channel.
#[derive(Debug)]
pub struct IoTest {
    /// Which end this handle is (client/server) and whether it is a clone.
    tp: Type,
    /// Address returned for `PeerAddr` queries, if set via `set_peer_addr`.
    peer_addr: Option<net::SocketAddr>,
    /// Drop flags shared by both ends of the pair.
    state: Arc<Mutex<RefCell<State>>>,
    /// Channel this end reads from (the peer's `remote`).
    local: Arc<Mutex<RefCell<Channel>>>,
    /// Channel this end writes into (the peer's `local`).
    remote: Arc<Mutex<RefCell<Channel>>>,
}
38
bitflags::bitflags! {
    /// Status bits kept on each `Channel`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct IoTestFlags: u8 {
        // NOTE(review): in the visible code this flag is only ever removed
        // (in `poll_write_buf`), never inserted — confirm it is still needed.
        const FLUSHED = 0b0000_0001;
        /// Set by the write half's `shutdown`; reported by `Channel::is_closed`.
        const CLOSED  = 0b0000_0010;
    }
}
46
/// Which end of the pair an `IoTest` handle is, and whether it is a clone.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Type {
    /// The client end returned by `IoTest::create`.
    Client,
    /// The server end returned by `IoTest::create`.
    Server,
    /// A clone of the client end; dropping it does not mark the client dropped.
    ClientClone,
    /// A clone of the server end; dropping it does not mark the server dropped.
    ServerClone,
}
54
/// Drop bookkeeping shared by both ends of an `IoTest` pair.
#[derive(Debug, Default, Clone, Copy)]
struct State {
    /// Set once the original client end (not a clone) has been dropped.
    client_dropped: bool,
    /// Set once the original server end (not a clone) has been dropped.
    server_dropped: bool,
}
60
61#[derive(Default, Debug)]
62struct Channel {
63    buf: BytesVec,
64    buf_cap: usize,
65    flags: IoTestFlags,
66    waker: AtomicWaker,
67    read: IoTestState,
68    write: IoTestState,
69}
70
// SAFETY: `Channel` is only reached through `Arc<Mutex<RefCell<Channel>>>`
// (the `local`/`remote` fields of `IoTest`), so the `Mutex` serializes every
// cross-thread access to it.
// NOTE(review): `Mutex<RefCell<Channel>>` only requires `Channel: Send`; whether
// either manual impl is needed at all depends on `BytesVec`'s auto traits —
// confirm, and drop the `Sync` impl if it is unnecessary.
unsafe impl Sync for Channel {}
unsafe impl Send for Channel {}
73
74impl Channel {
75    fn is_closed(&self) -> bool {
76        self.flags.contains(IoTestFlags::CLOSED)
77    }
78}
79
80impl Default for IoTestFlags {
81    fn default() -> Self {
82        IoTestFlags::empty()
83    }
84}
85
/// Scripted outcome for the next read or write on a `Channel`.
#[derive(Default, Debug)]
enum IoTestState {
    /// Normal operation (the default).
    #[default]
    Ok,
    /// Report `Poll::Pending`.
    Pending,
    /// Report end-of-stream (`Ok(0)`).
    Close,
    /// Fail with the contained error.
    Err(io::Error),
}
94
95impl IoTest {
96    /// Create a two interconnected streams
97    pub fn create() -> (IoTest, IoTest) {
98        let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
99        let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
100        let state = Arc::new(Mutex::new(RefCell::new(State::default())));
101
102        (
103            IoTest {
104                tp: Type::Client,
105                peer_addr: None,
106                local: local.clone(),
107                remote: remote.clone(),
108                state: state.clone(),
109            },
110            IoTest {
111                state,
112                peer_addr: None,
113                tp: Type::Server,
114                local: remote,
115                remote: local,
116            },
117        )
118    }
119
120    pub fn is_client_dropped(&self) -> bool {
121        self.state.lock().unwrap().borrow().client_dropped
122    }
123
124    pub fn is_server_dropped(&self) -> bool {
125        self.state.lock().unwrap().borrow().server_dropped
126    }
127
128    /// Check if channel is closed from remoote side
129    pub fn is_closed(&self) -> bool {
130        self.remote.lock().unwrap().borrow().is_closed()
131    }
132
133    /// Set peer addr
134    pub fn set_peer_addr(mut self, addr: net::SocketAddr) -> Self {
135        self.peer_addr = Some(addr);
136        self
137    }
138
139    /// Set read to Pending state
140    pub fn read_pending(&self) {
141        self.remote.lock().unwrap().borrow_mut().read = IoTestState::Pending;
142    }
143
144    /// Set read to error
145    pub fn read_error(&self, err: io::Error) {
146        let channel = self.remote.lock().unwrap();
147        channel.borrow_mut().read = IoTestState::Err(err);
148        channel.borrow().waker.wake();
149    }
150
151    /// Set write error on remote side
152    pub fn write_error(&self, err: io::Error) {
153        self.local.lock().unwrap().borrow_mut().write = IoTestState::Err(err);
154        self.remote.lock().unwrap().borrow().waker.wake();
155    }
156
157    /// Access read buffer.
158    pub fn local_buffer<F, R>(&self, f: F) -> R
159    where
160        F: FnOnce(&mut BytesVec) -> R,
161    {
162        let guard = self.local.lock().unwrap();
163        let mut ch = guard.borrow_mut();
164        f(&mut ch.buf)
165    }
166
167    /// Access remote buffer.
168    pub fn remote_buffer<F, R>(&self, f: F) -> R
169    where
170        F: FnOnce(&mut BytesVec) -> R,
171    {
172        let guard = self.remote.lock().unwrap();
173        let mut ch = guard.borrow_mut();
174        f(&mut ch.buf)
175    }
176
177    /// Closed remote side.
178    pub async fn close(&self) {
179        {
180            let guard = self.remote.lock().unwrap();
181            let mut remote = guard.borrow_mut();
182            remote.read = IoTestState::Close;
183            remote.waker.wake();
184            log::trace!("close remote socket");
185        }
186        sleep(Millis(35)).await;
187    }
188
189    /// Add extra data to the remote buffer and notify reader
190    pub fn write<T: AsRef<[u8]>>(&self, data: T) {
191        let guard = self.remote.lock().unwrap();
192        let mut write = guard.borrow_mut();
193        write.buf.extend_from_slice(data.as_ref());
194        write.waker.wake();
195    }
196
197    /// Read any available data
198    pub fn remote_buffer_cap(&self, cap: usize) {
199        // change cap
200        self.local.lock().unwrap().borrow_mut().buf_cap = cap;
201        // wake remote
202        self.remote.lock().unwrap().borrow().waker.wake();
203    }
204
205    /// Read any available data
206    pub fn read_any(&self) -> Bytes {
207        self.local.lock().unwrap().borrow_mut().buf.split().freeze()
208    }
209
210    /// Read data, if data is not available wait for it
211    pub async fn read(&self) -> Result<Bytes, io::Error> {
212        if self.local.lock().unwrap().borrow().buf.is_empty() {
213            poll_fn(|cx| {
214                let guard = self.local.lock().unwrap();
215                let read = guard.borrow_mut();
216                if read.buf.is_empty() {
217                    let closed = match self.tp {
218                        Type::Client | Type::ClientClone => {
219                            self.is_server_dropped() || read.is_closed()
220                        }
221                        Type::Server | Type::ServerClone => self.is_client_dropped(),
222                    };
223                    if closed {
224                        Poll::Ready(())
225                    } else {
226                        *read.waker.0.lock().unwrap().borrow_mut() =
227                            Some(cx.waker().clone());
228                        drop(read);
229                        drop(guard);
230                        Poll::Pending
231                    }
232                } else {
233                    Poll::Ready(())
234                }
235            })
236            .await;
237        }
238        Ok(self.local.lock().unwrap().borrow_mut().buf.split().freeze())
239    }
240
241    pub fn poll_read_buf(
242        &self,
243        cx: &mut Context<'_>,
244        buf: &mut BytesVec,
245    ) -> Poll<io::Result<usize>> {
246        let guard = self.local.lock().unwrap();
247        let mut ch = guard.borrow_mut();
248        *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());
249
250        if !ch.buf.is_empty() {
251            let size = std::cmp::min(ch.buf.len(), buf.remaining_mut());
252            let b = ch.buf.split_to(size);
253            buf.put_slice(&b);
254            return Poll::Ready(Ok(size));
255        }
256
257        match mem::take(&mut ch.read) {
258            IoTestState::Ok => Poll::Pending,
259            IoTestState::Close => {
260                ch.read = IoTestState::Close;
261                Poll::Ready(Ok(0))
262            }
263            IoTestState::Pending => Poll::Pending,
264            IoTestState::Err(e) => Poll::Ready(Err(e)),
265        }
266    }
267
268    pub fn poll_write_buf(
269        &self,
270        cx: &mut Context<'_>,
271        buf: &[u8],
272    ) -> Poll<io::Result<usize>> {
273        let guard = self.remote.lock().unwrap();
274        let mut ch = guard.borrow_mut();
275
276        match mem::take(&mut ch.write) {
277            IoTestState::Ok => {
278                let cap = cmp::min(buf.len(), ch.buf_cap);
279                if cap > 0 {
280                    ch.buf.extend(&buf[..cap]);
281                    ch.buf_cap -= cap;
282                    ch.flags.remove(IoTestFlags::FLUSHED);
283                    ch.waker.wake();
284                    Poll::Ready(Ok(cap))
285                } else {
286                    *self
287                        .local
288                        .lock()
289                        .unwrap()
290                        .borrow_mut()
291                        .waker
292                        .0
293                        .lock()
294                        .unwrap()
295                        .borrow_mut() = Some(cx.waker().clone());
296                    Poll::Pending
297                }
298            }
299            IoTestState::Close => Poll::Ready(Ok(0)),
300            IoTestState::Pending => {
301                *self
302                    .local
303                    .lock()
304                    .unwrap()
305                    .borrow_mut()
306                    .waker
307                    .0
308                    .lock()
309                    .unwrap()
310                    .borrow_mut() = Some(cx.waker().clone());
311                Poll::Pending
312            }
313            IoTestState::Err(e) => Poll::Ready(Err(e)),
314        }
315    }
316}
317
318impl Clone for IoTest {
319    fn clone(&self) -> Self {
320        let tp = match self.tp {
321            Type::Server => Type::ServerClone,
322            Type::Client => Type::ClientClone,
323            val => val,
324        };
325
326        IoTest {
327            tp,
328            local: self.local.clone(),
329            remote: self.remote.clone(),
330            state: self.state.clone(),
331            peer_addr: self.peer_addr,
332        }
333    }
334}
335
336impl Drop for IoTest {
337    fn drop(&mut self) {
338        let mut state = *self.state.lock().unwrap().borrow();
339        match self.tp {
340            Type::Server => state.server_dropped = true,
341            Type::Client => state.client_dropped = true,
342            _ => (),
343        }
344        *self.state.lock().unwrap().borrow_mut() = state;
345
346        let guard = self.remote.lock().unwrap();
347        let mut remote = guard.borrow_mut();
348        remote.read = IoTestState::Close;
349        remote.waker.wake();
350        log::trace!("drop remote socket");
351    }
352}
353
354impl IoStream for IoTest {
355    fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
356        let io = Rc::new(self);
357
358        let mut rio = Read(io.clone());
359        let _ = ntex_util::spawn(async move {
360            read.handle(&mut rio).await;
361        });
362
363        let mut wio = Write(io.clone());
364        let _ = ntex_util::spawn(async move {
365            write.handle(&mut wio).await;
366        });
367
368        Some(Box::new(io))
369    }
370}
371
372impl Handle for Rc<IoTest> {
373    fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
374        if id == any::TypeId::of::<types::PeerAddr>() {
375            if let Some(addr) = self.peer_addr {
376                return Some(Box::new(types::PeerAddr(addr)));
377            }
378        }
379        None
380    }
381}
382
383/// Read io task
384struct Read(Rc<IoTest>);
385
386impl crate::AsyncRead for Read {
387    async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
388        // read data from socket
389        let result = poll_fn(|cx| self.0.poll_read_buf(cx, &mut buf)).await;
390        (buf, result)
391    }
392}
393
394/// Write
395struct Write(Rc<IoTest>);
396
397impl crate::AsyncWrite for Write {
398    async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
399        poll_fn(|cx| {
400            if let Some(mut b) = buf.take() {
401                let result = write_io(&self.0, &mut b, cx);
402                buf.set(b);
403                result
404            } else {
405                Poll::Ready(Ok(()))
406            }
407        })
408        .await
409    }
410
411    async fn flush(&mut self) -> io::Result<()> {
412        Ok(())
413    }
414
415    async fn shutdown(&mut self) -> io::Result<()> {
416        // shutdown WRITE side
417        self.0
418            .local
419            .lock()
420            .unwrap()
421            .borrow_mut()
422            .flags
423            .insert(IoTestFlags::CLOSED);
424        Ok(())
425    }
426}
427
428/// Flush write buffer to underlying I/O stream.
429pub(super) fn write_io(
430    io: &IoTest,
431    buf: &mut BytesVec,
432    cx: &mut Context<'_>,
433) -> Poll<io::Result<()>> {
434    let len = buf.len();
435
436    if len != 0 {
437        log::trace!("flushing framed transport: {len}");
438
439        let mut written = 0;
440        let result = loop {
441            break match io.poll_write_buf(cx, &buf[written..]) {
442                Poll::Ready(Ok(n)) => {
443                    if n == 0 {
444                        log::trace!("disconnected during flush, written {written}");
445                        Poll::Ready(Err(io::Error::new(
446                            io::ErrorKind::WriteZero,
447                            "failed to write frame to transport",
448                        )))
449                    } else {
450                        written += n;
451                        if written == len {
452                            buf.clear();
453                            Poll::Ready(Ok(()))
454                        } else {
455                            continue;
456                        }
457                    }
458                }
459                Poll::Pending => {
460                    // remove written data
461                    buf.advance(written);
462                    Poll::Pending
463                }
464                Poll::Ready(Err(e)) => {
465                    log::trace!("error during flush: {e}");
466                    Poll::Ready(Err(e))
467                }
468            };
469        };
470        log::trace!("flushed {written} bytes");
471        result
472    } else {
473        Poll::Ready(Ok(()))
474    }
475}
476
#[cfg(test)]
// `client.clone().tp` below clones only to inspect the clone's tag.
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use ntex_util::future::lazy;

    #[ntex::test]
    async fn basic() {
        let (client, server) = IoTest::create();
        // Each `.clone()` here is a temporary that is dropped immediately; like
        // every drop it sets its peer's read state to `Close`.
        assert_eq!(client.tp, Type::Client);
        assert_eq!(client.clone().tp, Type::ClientClone);
        assert_eq!(server.tp, Type::Server);
        assert_eq!(server.clone().tp, Type::ServerClone);
        assert!(format!("{server:?}").contains("IoTest"));
        assert!(format!("{:?}", AtomicWaker::default()).contains("AtomicWaker"));

        // `server.remote` is `client.local`. The dropped `server.clone()` above
        // left it in `Close`; `read_pending` overrides that, so the empty-buffer
        // poll returns Pending instead of EOF.
        server.read_pending();
        let mut buf = BytesVec::new();
        let res = lazy(|cx| client.poll_read_buf(cx, &mut buf)).await;
        assert!(res.is_pending());

        // This write is Pending because `buf_cap` is still 0 (the `Channel`
        // default), not because of `read_pending`, which only touches read state.
        server.read_pending();
        let res = lazy(|cx| server.poll_write_buf(cx, b"123")).await;
        assert!(res.is_pending());

        assert!(!server.is_client_dropped());
        drop(client);
        assert!(server.is_client_dropped());

        // Dropping the original server (not the clone) sets `server_dropped`.
        let server2 = server.clone();
        assert!(!server2.is_server_dropped());
        drop(server);
        assert!(server2.is_server_dropped());

        let res = lazy(|cx| server2.poll_write_buf(cx, b"123")).await;
        assert!(res.is_pending());

        // `PeerAddr` queries are answered by `Handle for Rc<IoTest>`.
        let (client, _) = IoTest::create();
        let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let client = crate::Io::new(client.set_peer_addr(addr));
        let item = client.query::<crate::types::PeerAddr>();
        assert!(format!("{item:?}").contains("QueryItem(127.0.0.1:8080)"));
    }
}
520}