1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
use std::time::{Duration, Instant};
use async_io::Timer;
use async_task::Task;
use async_trait::async_trait;
use futures_lite::FutureExt as _;
use futures_util::{AsyncReadExt, AsyncWriteExt};
use rand::{Rng, RngCore};
use sillad::{Pipe, dialer::Dialer, listener::Listener};
/// Wraps an underlying dialer with a connection quality test.
///
/// After the inner dialer connects, `ping_count` echo round trips with random
/// payloads are performed on the new pipe. A zero-length terminator is then sent
/// and the pipe is returned to the caller. The remote side must speak the same
/// protocol (as implemented by [`ConnTestListener`]).
pub struct ConnTestDialer<D: Dialer> {
    /// Dialer used to establish the raw connection.
    pub inner: D,
    /// Number of echo round trips to perform before returning the pipe.
    pub ping_count: usize,
}

/// Builds one ping frame: a big-endian `u16` length in `1..1000` followed by
/// that many random bytes.
fn random_ping_frame() -> Vec<u8> {
    // `ThreadRng` is not `Send`, so it must never live across an `.await` in the
    // `async_trait` future; generating the frame here keeps it out.
    let mut rng = rand::rng();
    let len = rng.random_range(1..1000u16);
    let mut frame = vec![0u8; 2 + usize::from(len)];
    frame[..2].copy_from_slice(&len.to_be_bytes());
    rng.fill_bytes(&mut frame[2..]);
    frame
}

#[async_trait]
impl<D: Dialer> Dialer for ConnTestDialer<D> {
    type P = D::P;

    /// Dials `inner`, then runs `ping_count` echo rounds on the new pipe.
    ///
    /// Each round sends a big-endian `u16` length followed by that many random
    /// bytes and expects the same bytes echoed back. Afterwards a zero length
    /// prefix marks the end of testing and the pipe is returned for normal use.
    ///
    /// # Errors
    /// Propagates errors from the inner dialer and from pipe I/O, and returns an
    /// `InvalidData` error if an echo differs from the payload that was sent.
    async fn dial(&self) -> std::io::Result<Self::P> {
        let mut pipe = self.inner.dial().await?;
        for index in 0..self.ping_count {
            let start = Instant::now();
            let frame = random_ping_frame();
            let payload = &frame[2..];
            // Flush so a pipe that buffers writes actually transmits the ping
            // before we block waiting for its echo.
            pipe.write_all(&frame).await?;
            pipe.flush().await?;
            let mut echo = vec![0u8; payload.len()];
            pipe.read_exact(&mut echo).await?;
            // Verify before logging so "ping completed" only reports rounds
            // that actually succeeded.
            if echo.as_slice() != payload {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "ping returned incorrect data",
                ));
            }
            tracing::debug!(
                elapsed = debug(start.elapsed()),
                total_count = self.ping_count,
                index,
                remote_addr = debug(pipe.remote_addr()),
                "ping completed"
            );
        }
        // Termination message: a 0 length tells the listener that testing is
        // over. Flush it so the listener can release the connection before the
        // caller starts waiting on application data.
        pipe.write_all(&[0u8; 2]).await?;
        pipe.flush().await?;
        Ok(pipe)
    }
}
/// Wraps an underlying listener with a connection quality test.
pub struct ConnTestListener<L: Listener> {
recv_conn: tachyonix::Receiver<L::P>,
_task: Task<()>,
}
impl<L: Listener> ConnTestListener<L> {
pub fn new(mut listener: L) -> Self {
// Create a channel for passing successfully tested connections.
let (send_conn, recv_conn) = tachyonix::channel(1);
// Spawn a background task that loops over accepted connections.
let task = smolscale::spawn(async move {
loop {
// Accept a new connection from the underlying listener.
let mut conn = match listener.accept().await {
Ok(c) => c,
Err(e) => {
tracing::warn!("Failed to accept connection: {:?}", e);
async_io::Timer::after(Duration::from_secs(1)).await;
continue;
}
};
let send_conn = send_conn.clone();
// For each accepted connection, spawn a task to perform the ping test.
smolscale::spawn::<std::io::Result<()>>(async move {
let inner = async {
loop {
let mut size_buf = [0u8; 2];
conn.read_exact(&mut size_buf).await?;
let size = u16::from_be_bytes(size_buf);
// A zero size means the client has finished pinging.
if size == 0 {
let _ = send_conn.send(conn).await;
return Ok(());
}
let mut payload = vec![0u8; size as usize];
conn.read_exact(&mut payload).await?;
conn.write_all(&payload).await?;
}
};
inner
.or(async {
Timer::after(Duration::from_secs(30)).await;
Ok(())
})
.await
})
.detach();
}
});
Self {
recv_conn,
_task: task,
}
}
}
#[async_trait]
impl<L: Listener> Listener for ConnTestListener<L> {
    type P = L::P;

    /// Yields the next connection that has completed the ping test.
    ///
    /// # Errors
    /// Returns a `BrokenPipe` error if the channel of tested connections has
    /// closed, i.e. nothing is left to produce them.
    async fn accept(&mut self) -> std::io::Result<Self::P> {
        match self.recv_conn.recv().await {
            Ok(tested) => Ok(tested),
            Err(_) => Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "background task is done",
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::{AsyncReadExt, AsyncWriteExt};
    use sillad::tcp::{TcpDialer, TcpListener};
    use smolscale::spawn;
    use std::io;
    use std::net::SocketAddr;

    /// Returns a `ConnTestDialer` that dials `dest_addr` over TCP and performs
    /// `ping_count` ping rounds before handing out the connection.
    fn ping_dialer(dest_addr: SocketAddr, ping_count: usize) -> ConnTestDialer<TcpDialer> {
        ConnTestDialer {
            inner: TcpDialer { dest_addr },
            ping_count,
        }
    }

    /// Connects a `ConnTestDialer` (with `ping_count` rounds) to a TCP listener
    /// wrapped in `ConnTestListener`. Then checks that a message sent after the
    /// ping phase is echoed back unchanged by the server.
    fn check_echo_after_pings(ping_count: usize) -> io::Result<()> {
        async_io::block_on(async {
            // Bind a TCP listener to an ephemeral port on localhost.
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let tcp_listener = TcpListener::bind(addr).await?;
            let local_addr = tcp_listener.local_addr().await;
            let mut conn_test_listener = ConnTestListener::new(tcp_listener);

            // Server: accept the connection that passed the ping test and echo
            // everything until the client closes it.
            let server_handle = spawn(async move {
                let mut conn = conn_test_listener.accept().await?;
                let mut buf = [0u8; 1024];
                loop {
                    let n = conn.read(&mut buf).await?;
                    if n == 0 {
                        break; // Connection closed.
                    }
                    conn.write_all(&buf[..n]).await?;
                }
                Ok::<(), io::Error>(())
            });

            // Client: dialing performs the ping test internally.
            let mut client_pipe = ping_dialer(local_addr, ping_count).dial().await?;
            let test_message = b"hello, unit test!";
            client_pipe.write_all(test_message).await?;
            let mut buf = vec![0u8; test_message.len()];
            client_pipe.read_exact(&mut buf).await?;
            assert_eq!(
                &buf, test_message,
                "the echoed message should match the sent message"
            );

            // Closing the client ends the server's echo loop.
            drop(client_pipe);
            server_handle.await?;
            Ok(())
        })
    }

    /// Several ping rounds succeed and the connection then carries ordinary traffic.
    #[test]
    fn test_successful_ping() -> io::Result<()> {
        check_echo_after_pings(3)
    }

    /// With zero ping rounds only the terminator is exchanged; the connection
    /// must still reach the server and carry traffic.
    #[test]
    fn test_zero_pings() -> io::Result<()> {
        check_echo_after_pings(0)
    }

    /// A server that deliberately corrupts the ping echo must make
    /// `ConnTestDialer::dial` fail with an error.
    #[test]
    fn test_failed_ping() -> io::Result<()> {
        async_io::block_on(async {
            // Bind a TCP listener to an ephemeral port.
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let mut tcp_listener = TcpListener::bind(addr).await?;
            let local_addr = tcp_listener.local_addr().await;

            // Raw TCP server (no ConnTestListener) that speaks the ping
            // protocol but bumps the first byte of every echo.
            let server_handle = spawn(async move {
                let mut conn = tcp_listener.accept().await?;
                loop {
                    let mut size_buf = [0u8; 2];
                    if conn.read_exact(&mut size_buf).await.is_err() {
                        break; // Connection closed.
                    }
                    let size = u16::from_be_bytes(size_buf);
                    if size == 0 {
                        break; // Termination message.
                    }
                    let mut payload = vec![0u8; size as usize];
                    conn.read_exact(&mut payload).await?;
                    if let Some(first) = payload.first_mut() {
                        *first = first.wrapping_add(1);
                    }
                    conn.write_all(&payload).await?;
                }
                Ok::<(), io::Error>(())
            });

            let result = ping_dialer(local_addr, 3).dial().await;
            assert!(
                result.is_err(),
                "dialing should fail due to corrupted ping echoes"
            );
            // The server may exit with an I/O error once the client hangs up;
            // that outcome is irrelevant to this test.
            let _ = server_handle.await;
            Ok(())
        })
    }
}