1use std::collections::HashMap;
9use std::io;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::{TcpListener, TcpStream, UdpSocket};
16use tokio::sync::{mpsc, Mutex};
17
/// Opaque identifier for one accepted TCP connection.
///
/// The inner `u64` is public so callers can log or key on it; two ids are
/// equal exactly when their inner values are equal.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnId(pub u64);
21
22#[derive(Debug)]
24pub enum IoEvent {
25 Connected { id: ConnId, peer: SocketAddr },
26 Received { id: ConnId, bytes: Vec<u8> },
27 Closed { id: ConnId },
28 Bound { addr: SocketAddr },
29 Datagram { from: SocketAddr, bytes: Vec<u8> },
30 Error { reason: String },
31}
32
33#[derive(Debug)]
35pub enum TcpCommand {
36 Bind { addr: SocketAddr },
37 Send { id: ConnId, bytes: Vec<u8> },
38 Close { id: ConnId },
39 Shutdown,
40}
41
42type Conns = Arc<Mutex<HashMap<ConnId, mpsc::UnboundedSender<Vec<u8>>>>>;
43
44pub struct TcpManager {
47 cmd: mpsc::UnboundedSender<TcpCommand>,
48}
49
50impl TcpManager {
51 pub fn spawn() -> (Self, mpsc::UnboundedReceiver<IoEvent>) {
53 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
54 let (evt_tx, evt_rx) = mpsc::unbounded_channel();
55 let conns: Conns = Arc::new(Mutex::new(HashMap::new()));
56 tokio::spawn(run_tcp(cmd_rx, evt_tx, conns));
57 (Self { cmd: cmd_tx }, evt_rx)
58 }
59
60 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
61 self.cmd
62 .send(TcpCommand::Bind { addr })
63 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
64 }
65 pub fn send_bytes(&self, id: ConnId, bytes: Vec<u8>) -> io::Result<()> {
66 self.cmd
67 .send(TcpCommand::Send { id, bytes })
68 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
69 }
70 pub fn close(&self, id: ConnId) -> io::Result<()> {
71 self.cmd
72 .send(TcpCommand::Close { id })
73 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
74 }
75 pub fn shutdown(&self) {
76 let _ = self.cmd.send(TcpCommand::Shutdown);
77 }
78}
79
/// Process-wide counter from which fresh [`ConnId`] values are drawn.
/// The first id handed out is 1.
static SEQ: AtomicU64 = AtomicU64::new(1);
81
82async fn run_tcp(
83 mut cmd: mpsc::UnboundedReceiver<TcpCommand>,
84 evt: mpsc::UnboundedSender<IoEvent>,
85 conns: Conns,
86) {
87 while let Some(c) = cmd.recv().await {
88 match c {
89 TcpCommand::Bind { addr } => {
90 let evt_tx = evt.clone();
91 let conns = conns.clone();
92 tokio::spawn(async move {
93 let listener = match TcpListener::bind(addr).await {
94 Ok(l) => l,
95 Err(e) => {
96 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
97 return;
98 }
99 };
100 let bound = listener.local_addr().unwrap_or(addr);
101 let _ = evt_tx.send(IoEvent::Bound { addr: bound });
102 loop {
103 let stream = match listener.accept().await {
104 Ok((s, _)) => s,
105 Err(e) => {
106 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
107 break;
108 }
109 };
110 let peer = stream.peer_addr().unwrap_or(bound);
111 let id = ConnId(SEQ.fetch_add(1, Ordering::Relaxed));
112 let _ = evt_tx.send(IoEvent::Connected { id, peer });
113 spawn_conn(id, stream, evt_tx.clone(), conns.clone()).await;
114 }
115 });
116 }
117 TcpCommand::Send { id, bytes } => {
118 let g = conns.lock().await;
119 if let Some(tx) = g.get(&id) {
120 let _ = tx.send(bytes);
121 }
122 }
123 TcpCommand::Close { id } => {
124 conns.lock().await.remove(&id);
125 }
126 TcpCommand::Shutdown => break,
127 }
128 }
129}
130
131async fn spawn_conn(id: ConnId, stream: TcpStream, evt: mpsc::UnboundedSender<IoEvent>, conns: Conns) {
132 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
133 conns.lock().await.insert(id, write_tx);
134 let (mut rh, mut wh) = stream.into_split();
135 tokio::spawn(async move {
136 while let Some(bytes) = write_rx.recv().await {
137 if wh.write_all(&bytes).await.is_err() {
138 break;
139 }
140 }
141 let _ = wh.shutdown().await;
142 });
143 let evt2 = evt.clone();
144 tokio::spawn(async move {
145 let mut buf = vec![0u8; 8 * 1024];
146 loop {
147 match rh.read(&mut buf).await {
148 Ok(0) | Err(_) => {
149 let _ = evt2.send(IoEvent::Closed { id });
150 break;
151 }
152 Ok(n) => {
153 let _ = evt2.send(IoEvent::Received { id, bytes: buf[..n].to_vec() });
154 }
155 }
156 }
157 });
158}
159
/// Request forwarded from a [`UdpManager`] handle to its background task.
#[derive(Debug)]
pub enum UdpCommand {
    /// Transmit `bytes` as a single datagram to `to`.
    Send { to: SocketAddr, bytes: Vec<u8> },
    /// Ask the manager's background work to stop.
    Shutdown,
}
165
166pub struct UdpManager {
168 cmd: mpsc::UnboundedSender<UdpCommand>,
169 local: SocketAddr,
170}
171
172impl UdpManager {
173 pub async fn bind(addr: SocketAddr) -> io::Result<(Self, mpsc::UnboundedReceiver<IoEvent>)> {
174 let socket = UdpSocket::bind(addr).await?;
175 let local = socket.local_addr()?;
176 let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
177 let (evt_tx, evt_rx) = mpsc::unbounded_channel();
178 let socket = Arc::new(socket);
179 let s_recv = socket.clone();
180 let etx = evt_tx.clone();
181 tokio::spawn(async move {
182 let mut buf = vec![0u8; 64 * 1024];
183 loop {
184 match s_recv.recv_from(&mut buf).await {
185 Ok((n, from)) => {
186 let _ = etx.send(IoEvent::Datagram { from, bytes: buf[..n].to_vec() });
187 }
188 Err(e) => {
189 let _ = etx.send(IoEvent::Error { reason: e.to_string() });
190 break;
191 }
192 }
193 }
194 });
195 let s_send = socket.clone();
196 tokio::spawn(async move {
197 while let Some(c) = cmd_rx.recv().await {
198 match c {
199 UdpCommand::Send { to, bytes } => {
200 if let Err(e) = s_send.send_to(&bytes, to).await {
201 let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
202 }
203 }
204 UdpCommand::Shutdown => break,
205 }
206 }
207 });
208 Ok((Self { cmd: cmd_tx, local }, evt_rx))
209 }
210
211 pub fn local_addr(&self) -> SocketAddr {
212 self.local
213 }
214
215 pub fn send_to(&self, to: SocketAddr, bytes: Vec<u8>) -> io::Result<()> {
216 self.cmd
217 .send(UdpCommand::Send { to, bytes })
218 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
219 }
220 pub fn shutdown(&self) {
221 let _ = self.cmd.send(UdpCommand::Shutdown);
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Longest time any single expected event may take to arrive.
    const EVENT_TIMEOUT: Duration = Duration::from_millis(500);

    /// Waits for the next event, panicking on timeout or a closed channel.
    async fn next_event(rx: &mut mpsc::UnboundedReceiver<IoEvent>) -> IoEvent {
        tokio::time::timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .unwrap()
            .unwrap()
    }

    /// Loopback address with an OS-assigned port.
    fn loopback() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    #[tokio::test]
    async fn udp_manager_round_trip() {
        let (a, mut a_rx) = UdpManager::bind(loopback()).await.unwrap();
        let (b, _b_rx) = UdpManager::bind(loopback()).await.unwrap();
        b.send_to(a.local_addr(), b"hi".to_vec()).unwrap();
        match next_event(&mut a_rx).await {
            IoEvent::Datagram { bytes, .. } => assert_eq!(bytes, b"hi"),
            other => panic!("unexpected event: {other:?}"),
        }
        a.shutdown();
        b.shutdown();
    }

    #[tokio::test]
    async fn tcp_manager_accept_and_echo() {
        let (mgr, mut events) = TcpManager::spawn();
        mgr.bind(loopback()).unwrap();

        let bound = match next_event(&mut events).await {
            IoEvent::Bound { addr } => addr,
            other => panic!("expected Bound, got {other:?}"),
        };

        let mut client = TcpStream::connect(bound).await.unwrap();
        let id = match next_event(&mut events).await {
            IoEvent::Connected { id, .. } => id,
            other => panic!("expected Connected, got {other:?}"),
        };

        client.write_all(b"ping").await.unwrap();
        match next_event(&mut events).await {
            IoEvent::Received { bytes, .. } => assert_eq!(bytes, b"ping"),
            other => panic!("expected Received, got {other:?}"),
        }

        mgr.send_bytes(id, b"pong".to_vec()).unwrap();
        let mut reply = [0u8; 4];
        client.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"pong");

        mgr.shutdown();
    }
}