rama_net/test_utils/client/
mock_connector.rs1use std::{convert::Infallible, fmt, net::Ipv4Addr};
2
3use rama_core::{Context, Service};
4use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, duplex};
5
6use crate::{client::EstablishedClientConnection, stream::Socket};
7
/// Connector service that "connects" to an in-process mock server.
///
/// For every request it calls `create_server` to build a fresh server, links it to the
/// client through an in-memory [`duplex`] pipe, spawns the server on the server half and
/// returns the client half as an established connection.
pub struct MockConnectorService<S> {
    /// Factory invoked once per connection to build the server end.
    create_server: S,
    /// Value passed to [`duplex`] as its maximum buffer size.
    max_buffer_size: usize,
}

impl<S: fmt::Debug> fmt::Debug for MockConnectorService<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MockConnectorService")
            .field("create_server", &self.create_server)
            .field("max_buffer_size", &self.max_buffer_size)
            .finish()
    }
}

impl<S> MockConnectorService<S> {
    /// Creates a connector that builds a new server with `create_server` per connection.
    ///
    /// The duplex max buffer size defaults to `1024`; override it with
    /// [`Self::set_max_buffer_size`] or [`Self::with_max_buffer_size`].
    #[must_use]
    pub fn new(create_server: S) -> Self {
        Self {
            create_server,
            max_buffer_size: 1024,
        }
    }

    /// Sets the max buffer size used for the duplex pipe of subsequent connections.
    pub fn set_max_buffer_size(&mut self, size: usize) -> &mut Self {
        self.max_buffer_size = size;
        self
    }

    /// Builder-style variant of [`Self::set_max_buffer_size`].
    #[must_use]
    pub fn with_max_buffer_size(mut self, size: usize) -> Self {
        self.max_buffer_size = size;
        self
    }
}
45
46impl<State, S, Request, Error, Server> Service<State, Request> for MockConnectorService<S>
47where
48 S: Fn() -> Server + Send + Sync + 'static,
49 Server: Service<State, MockSocket, Error = Error>,
50 State: Clone + Send + Sync + 'static,
51 Request: Send + 'static,
52 Error: std::fmt::Debug + 'static,
53{
54 type Error = Infallible;
55 type Response = EstablishedClientConnection<MockSocket, State, Request>;
56
57 async fn serve(
58 &self,
59 ctx: Context<State>,
60 req: Request,
61 ) -> Result<Self::Response, Self::Error> {
62 let (client, server) = duplex(self.max_buffer_size);
63 let client_socket = MockSocket { stream: client };
64 let server_socket = MockSocket { stream: server };
65
66 let server = (self.create_server)();
67 let server_ctx = ctx.clone();
68
69 tokio::spawn(async move {
70 server.serve(server_ctx, server_socket).await.unwrap();
71 });
72
73 Ok(EstablishedClientConnection {
74 ctx,
75 req,
76 conn: client_socket,
77 })
78 }
79}
80
81#[derive(Debug)]
82pub struct MockSocket {
83 stream: DuplexStream,
84}
85
86impl AsyncRead for MockSocket {
87 fn poll_read(
88 mut self: std::pin::Pin<&mut Self>,
89 cx: &mut std::task::Context<'_>,
90 buf: &mut tokio::io::ReadBuf<'_>,
91 ) -> std::task::Poll<std::io::Result<()>> {
92 std::pin::Pin::new(&mut self.stream).poll_read(cx, buf)
93 }
94}
95
96impl AsyncWrite for MockSocket {
97 fn poll_write(
98 mut self: std::pin::Pin<&mut Self>,
99 cx: &mut std::task::Context<'_>,
100 buf: &[u8],
101 ) -> std::task::Poll<std::io::Result<usize>> {
102 std::pin::Pin::new(&mut self.stream).poll_write(cx, buf)
103 }
104
105 fn poll_flush(
106 mut self: std::pin::Pin<&mut Self>,
107 cx: &mut std::task::Context<'_>,
108 ) -> std::task::Poll<std::io::Result<()>> {
109 std::pin::Pin::new(&mut self.stream).poll_flush(cx)
110 }
111
112 fn poll_shutdown(
113 mut self: std::pin::Pin<&mut Self>,
114 cx: &mut std::task::Context<'_>,
115 ) -> std::task::Poll<std::io::Result<()>> {
116 std::pin::Pin::new(&mut self.stream).poll_shutdown(cx)
117 }
118
119 fn is_write_vectored(&self) -> bool {
120 self.stream.is_write_vectored()
121 }
122
123 fn poll_write_vectored(
124 mut self: std::pin::Pin<&mut Self>,
125 cx: &mut std::task::Context<'_>,
126 bufs: &[std::io::IoSlice<'_>],
127 ) -> std::task::Poll<Result<usize, std::io::Error>> {
128 std::pin::Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
129 }
130}
131
132impl Socket for MockSocket {
133 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
134 Ok(std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
135 Ipv4Addr::new(127, 0, 0, 1),
136 0,
137 )))
138 }
139
140 fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
141 Ok(std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
142 Ipv4Addr::new(127, 0, 0, 1),
143 0,
144 )))
145 }
146}