1use std::collections::HashMap;
9use std::io;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::{TcpListener, TcpStream, UdpSocket};
16use tokio::sync::{mpsc, Mutex};
17
/// Identifier of a TCP connection tracked by [`TcpManager`].
///
/// Handed out in [`IoEvent::Connected`] and used to address
/// [`TcpCommand::Send`] and [`TcpCommand::Close`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnId(pub u64);
21
22#[derive(Debug)]
24pub enum IoEvent {
25 Connected { id: ConnId, peer: SocketAddr },
26 Received { id: ConnId, bytes: Vec<u8> },
27 Closed { id: ConnId },
28 Bound { addr: SocketAddr },
29 Datagram { from: SocketAddr, bytes: Vec<u8> },
30 Error { reason: String },
31}
32
33#[derive(Debug)]
35pub enum TcpCommand {
36 Bind {
39 addr: SocketAddr,
40 },
41 Connect {
45 addr: SocketAddr,
46 },
47 Send {
48 id: ConnId,
49 bytes: Vec<u8>,
50 },
51 Close {
52 id: ConnId,
53 },
54 Shutdown,
55}
56
57type Conns = Arc<Mutex<HashMap<ConnId, mpsc::UnboundedSender<Vec<u8>>>>>;
58
59pub struct TcpManager {
62 cmd: mpsc::UnboundedSender<TcpCommand>,
63}
64
65impl TcpManager {
66 pub fn spawn() -> (Self, mpsc::UnboundedReceiver<IoEvent>) {
68 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
69 let (evt_tx, evt_rx) = mpsc::unbounded_channel();
70 let conns: Conns = Arc::new(Mutex::new(HashMap::new()));
71 tokio::spawn(run_tcp(cmd_rx, evt_tx, conns));
72 (Self { cmd: cmd_tx }, evt_rx)
73 }
74
75 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
76 self.cmd
77 .send(TcpCommand::Bind { addr })
78 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
79 }
80 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
84 self.cmd
85 .send(TcpCommand::Connect { addr })
86 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
87 }
88 pub fn send_bytes(&self, id: ConnId, bytes: Vec<u8>) -> io::Result<()> {
89 self.cmd
90 .send(TcpCommand::Send { id, bytes })
91 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
92 }
93 pub fn close(&self, id: ConnId) -> io::Result<()> {
94 self.cmd
95 .send(TcpCommand::Close { id })
96 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
97 }
98 pub fn shutdown(&self) {
99 let _ = self.cmd.send(TcpCommand::Shutdown);
100 }
101}
102
/// Counter that hands out [`ConnId`] values, starting at 1 and advancing by
/// one for each new TCP connection, so ids are unique within the process.
static SEQ: AtomicU64 = AtomicU64::new(1);
104
105async fn run_tcp(
106 mut cmd: mpsc::UnboundedReceiver<TcpCommand>,
107 evt: mpsc::UnboundedSender<IoEvent>,
108 conns: Conns,
109) {
110 while let Some(c) = cmd.recv().await {
111 match c {
112 TcpCommand::Bind { addr } => {
113 let evt_tx = evt.clone();
114 let conns = conns.clone();
115 tokio::spawn(async move {
116 let listener = match TcpListener::bind(addr).await {
117 Ok(l) => l,
118 Err(e) => {
119 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
120 return;
121 }
122 };
123 let bound = listener.local_addr().unwrap_or(addr);
124 let _ = evt_tx.send(IoEvent::Bound { addr: bound });
125 loop {
126 let stream = match listener.accept().await {
127 Ok((s, _)) => s,
128 Err(e) => {
129 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
130 break;
131 }
132 };
133 let peer = stream.peer_addr().unwrap_or(bound);
134 let id = ConnId(SEQ.fetch_add(1, Ordering::Relaxed));
135 let _ = evt_tx.send(IoEvent::Connected { id, peer });
136 spawn_conn(id, stream, evt_tx.clone(), conns.clone()).await;
137 }
138 });
139 }
140 TcpCommand::Connect { addr } => {
141 let evt_tx = evt.clone();
142 let conns = conns.clone();
143 tokio::spawn(async move {
144 let stream = match TcpStream::connect(addr).await {
145 Ok(s) => s,
146 Err(e) => {
147 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
148 return;
149 }
150 };
151 let peer = stream.peer_addr().unwrap_or(addr);
152 let id = ConnId(SEQ.fetch_add(1, Ordering::Relaxed));
153 let _ = evt_tx.send(IoEvent::Connected { id, peer });
154 spawn_conn(id, stream, evt_tx, conns).await;
155 });
156 }
157 TcpCommand::Send { id, bytes } => {
158 let g = conns.lock().await;
159 if let Some(tx) = g.get(&id) {
160 let _ = tx.send(bytes);
161 }
162 }
163 TcpCommand::Close { id } => {
164 conns.lock().await.remove(&id);
165 }
166 TcpCommand::Shutdown => break,
167 }
168 }
169}
170
171async fn spawn_conn(id: ConnId, stream: TcpStream, evt: mpsc::UnboundedSender<IoEvent>, conns: Conns) {
172 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
173 conns.lock().await.insert(id, write_tx);
174 let (mut rh, mut wh) = stream.into_split();
175 tokio::spawn(async move {
176 while let Some(bytes) = write_rx.recv().await {
177 if wh.write_all(&bytes).await.is_err() {
178 break;
179 }
180 }
181 let _ = wh.shutdown().await;
182 });
183 let evt2 = evt.clone();
184 tokio::spawn(async move {
185 let mut buf = vec![0u8; 8 * 1024];
186 loop {
187 match rh.read(&mut buf).await {
188 Ok(0) | Err(_) => {
189 let _ = evt2.send(IoEvent::Closed { id });
190 break;
191 }
192 Ok(n) => {
193 let _ = evt2.send(IoEvent::Received { id, bytes: buf[..n].to_vec() });
194 }
195 }
196 }
197 });
198}
199
/// Requests processed by a [`UdpManager`]'s background send task.
#[derive(Debug)]
pub enum UdpCommand {
    /// Transmit `bytes` as one datagram to `to`.
    Send {
        to: SocketAddr,
        bytes: Vec<u8>,
    },
    /// Stop processing commands.
    Shutdown,
}
205
206pub struct UdpManager {
208 cmd: mpsc::UnboundedSender<UdpCommand>,
209 local: SocketAddr,
210}
211
212impl UdpManager {
213 pub async fn bind(addr: SocketAddr) -> io::Result<(Self, mpsc::UnboundedReceiver<IoEvent>)> {
214 let socket = UdpSocket::bind(addr).await?;
215 let local = socket.local_addr()?;
216 let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
217 let (evt_tx, evt_rx) = mpsc::unbounded_channel();
218 let socket = Arc::new(socket);
219 let s_recv = socket.clone();
220 let etx = evt_tx.clone();
221 tokio::spawn(async move {
222 let mut buf = vec![0u8; 64 * 1024];
223 loop {
224 match s_recv.recv_from(&mut buf).await {
225 Ok((n, from)) => {
226 let _ = etx.send(IoEvent::Datagram { from, bytes: buf[..n].to_vec() });
227 }
228 Err(e) => {
229 let _ = etx.send(IoEvent::Error { reason: e.to_string() });
230 break;
231 }
232 }
233 }
234 });
235 let s_send = socket.clone();
236 tokio::spawn(async move {
237 while let Some(c) = cmd_rx.recv().await {
238 match c {
239 UdpCommand::Send { to, bytes } => {
240 if let Err(e) = s_send.send_to(&bytes, to).await {
241 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
242 }
243 }
244 UdpCommand::Shutdown => break,
245 }
246 }
247 });
248 Ok((Self { cmd: cmd_tx, local }, evt_rx))
249 }
250
251 pub fn local_addr(&self) -> SocketAddr {
252 self.local
253 }
254
255 pub fn send_to(&self, to: SocketAddr, bytes: Vec<u8>) -> io::Result<()> {
256 self.cmd
257 .send(UdpCommand::Send { to, bytes })
258 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
259 }
260 pub fn shutdown(&self) {
261 let _ = self.cmd.send(UdpCommand::Shutdown);
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for any single event.
    const EVENT_TIMEOUT: Duration = Duration::from_millis(500);

    /// An ephemeral port on the IPv4 loopback interface.
    fn loopback() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    /// Waits for the next event, panicking on timeout or if the channel closed.
    async fn next_event(events: &mut mpsc::UnboundedReceiver<IoEvent>) -> IoEvent {
        tokio::time::timeout(EVENT_TIMEOUT, events.recv()).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn udp_manager_round_trip() {
        let (receiver, mut receiver_events) = UdpManager::bind(loopback()).await.unwrap();
        let (sender, _sender_events) = UdpManager::bind(loopback()).await.unwrap();
        sender.send_to(receiver.local_addr(), b"hi".to_vec()).unwrap();
        match next_event(&mut receiver_events).await {
            IoEvent::Datagram { bytes, .. } => assert_eq!(bytes, b"hi"),
            other => panic!("unexpected event: {other:?}"),
        }
        receiver.shutdown();
        sender.shutdown();
    }

    #[tokio::test]
    async fn tcp_manager_accept_and_echo() {
        let (mgr, mut events) = TcpManager::spawn();
        mgr.bind(loopback()).unwrap();

        let bound = match next_event(&mut events).await {
            IoEvent::Bound { addr } => addr,
            other => panic!("expected Bound, got {other:?}"),
        };

        let mut client = TcpStream::connect(bound).await.unwrap();
        let id = match next_event(&mut events).await {
            IoEvent::Connected { id, .. } => id,
            other => panic!("expected Connected, got {other:?}"),
        };

        client.write_all(b"ping").await.unwrap();
        match next_event(&mut events).await {
            IoEvent::Received { bytes, .. } => assert_eq!(bytes, b"ping"),
            other => panic!("expected Received, got {other:?}"),
        }

        mgr.send_bytes(id, b"pong".to_vec()).unwrap();
        let mut reply = [0u8; 4];
        client.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"pong");
        mgr.shutdown();
    }
}