rama_net/test_utils/client/mock_connector.rs

1use std::{convert::Infallible, fmt, net::Ipv4Addr};
2
3use rama_core::{Context, Service};
4use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, duplex};
5
6use crate::{client::EstablishedClientConnection, stream::Socket};
7
/// In-memory stand-in for a real connector, so that clients and servers can be exercised
/// together in tests without opening actual network connections.
///
/// Each connection it establishes is backed by a tokio duplex pipe; the far end of that pipe
/// is served by a server obtained from `create_server`.
pub struct MockConnectorService<S> {
    /// Maximum buffer size passed to `duplex` for every pipe this connector creates.
    max_buffer_size: usize,
    /// Factory producing the server that handles the other end of each pipe.
    create_server: S,
}
14
15impl<S: fmt::Debug> fmt::Debug for MockConnectorService<S> {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        f.debug_struct("MockConnectorService")
18            .field("create_server", &self.create_server)
19            .finish()
20    }
21}
22
23impl<S> MockConnectorService<S> {
24    pub fn new(create_server: S) -> Self {
25        Self {
26            create_server,
27            max_buffer_size: 1024,
28        }
29    }
30
31    /// Set `max_buffer_size` that will be used when creating DuplexStream
32    pub fn set_max_buffer_size(&mut self, size: usize) -> &mut Self {
33        self.max_buffer_size = size;
34        self
35    }
36
37    /// [`MockConnectorService`] with `max_buffer_size` that will be used when creating DuplexStream
38    pub fn with_max_buffer_size(self, size: usize) -> Self {
39        Self {
40            max_buffer_size: size,
41            ..self
42        }
43    }
44}
45
46impl<State, S, Request, Error, Server> Service<State, Request> for MockConnectorService<S>
47where
48    S: Fn() -> Server + Send + Sync + 'static,
49    Server: Service<State, MockSocket, Error = Error>,
50    State: Clone + Send + Sync + 'static,
51    Request: Send + 'static,
52    Error: std::fmt::Debug + 'static,
53{
54    type Error = Infallible;
55    type Response = EstablishedClientConnection<MockSocket, State, Request>;
56
57    async fn serve(
58        &self,
59        ctx: Context<State>,
60        req: Request,
61    ) -> Result<Self::Response, Self::Error> {
62        let (client, server) = duplex(self.max_buffer_size);
63        let client_socket = MockSocket { stream: client };
64        let server_socket = MockSocket { stream: server };
65
66        let server = (self.create_server)();
67        let server_ctx = ctx.clone();
68
69        tokio::spawn(async move {
70            server.serve(server_ctx, server_socket).await.unwrap();
71        });
72
73        Ok(EstablishedClientConnection {
74            ctx,
75            req,
76            conn: client_socket,
77        })
78    }
79}
80
81#[derive(Debug)]
82pub struct MockSocket {
83    stream: DuplexStream,
84}
85
86impl AsyncRead for MockSocket {
87    fn poll_read(
88        mut self: std::pin::Pin<&mut Self>,
89        cx: &mut std::task::Context<'_>,
90        buf: &mut tokio::io::ReadBuf<'_>,
91    ) -> std::task::Poll<std::io::Result<()>> {
92        std::pin::Pin::new(&mut self.stream).poll_read(cx, buf)
93    }
94}
95
96impl AsyncWrite for MockSocket {
97    fn poll_write(
98        mut self: std::pin::Pin<&mut Self>,
99        cx: &mut std::task::Context<'_>,
100        buf: &[u8],
101    ) -> std::task::Poll<std::io::Result<usize>> {
102        std::pin::Pin::new(&mut self.stream).poll_write(cx, buf)
103    }
104
105    fn poll_flush(
106        mut self: std::pin::Pin<&mut Self>,
107        cx: &mut std::task::Context<'_>,
108    ) -> std::task::Poll<std::io::Result<()>> {
109        std::pin::Pin::new(&mut self.stream).poll_flush(cx)
110    }
111
112    fn poll_shutdown(
113        mut self: std::pin::Pin<&mut Self>,
114        cx: &mut std::task::Context<'_>,
115    ) -> std::task::Poll<std::io::Result<()>> {
116        std::pin::Pin::new(&mut self.stream).poll_shutdown(cx)
117    }
118
119    fn is_write_vectored(&self) -> bool {
120        self.stream.is_write_vectored()
121    }
122
123    fn poll_write_vectored(
124        mut self: std::pin::Pin<&mut Self>,
125        cx: &mut std::task::Context<'_>,
126        bufs: &[std::io::IoSlice<'_>],
127    ) -> std::task::Poll<Result<usize, std::io::Error>> {
128        std::pin::Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
129    }
130}
131
132impl Socket for MockSocket {
133    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
134        Ok(std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
135            Ipv4Addr::new(127, 0, 0, 1),
136            0,
137        )))
138    }
139
140    fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
141        Ok(std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
142            Ipv4Addr::new(127, 0, 0, 1),
143            0,
144        )))
145    }
146}