Skip to main content

atomr_core/io/
manager.rs

//! `TcpManager` / `UdpManager` actor-style wrappers.
//!
//! In actor-based I/O layers, the I/O manager is an actor that mediates `Bind`/`Connect`
//! commands and dispatches per-connection child actors. Our equivalent is
//! a small state machine driven by mpsc channels — callers get an
//! [`IoEvent`] stream of inbound connections, received bytes, and disconnects.
7
8use std::collections::HashMap;
9use std::io;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::{TcpListener, TcpStream, UdpSocket};
16use tokio::sync::{mpsc, Mutex};
17
/// Opaque identifier for one TCP connection (accepted or dialed) managed by a
/// [`TcpManager`]. It is a plain `u64` newtype, cheap to copy and usable as a map key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnId(pub u64);
21
22/// Events emitted by [`TcpManager`] / [`UdpManager`].
23#[derive(Debug)]
24pub enum IoEvent {
25    Connected { id: ConnId, peer: SocketAddr },
26    Received { id: ConnId, bytes: Vec<u8> },
27    Closed { id: ConnId },
28    Bound { addr: SocketAddr },
29    Datagram { from: SocketAddr, bytes: Vec<u8> },
30    Error { reason: String },
31}
32
33/// Commands sent into the [`TcpManager`].
34#[derive(Debug)]
35pub enum TcpCommand {
36    /// Listen on `addr`. The kernel-assigned port flows back as
37    /// `IoEvent::Bound { addr }`.
38    Bind {
39        addr: SocketAddr,
40    },
41    /// Initiate an outbound connection. On success a
42    /// `IoEvent::Connected { id, peer }` is published; subsequent
43    /// reads / writes use the same `ConnId` API as inbound.
44    Connect {
45        addr: SocketAddr,
46    },
47    Send {
48        id: ConnId,
49        bytes: Vec<u8>,
50    },
51    Close {
52        id: ConnId,
53    },
54    Shutdown,
55}
56
57type Conns = Arc<Mutex<HashMap<ConnId, mpsc::UnboundedSender<Vec<u8>>>>>;
58
59/// Actor-style TCP manager. Drop the handle (or call [`Self::shutdown`])
60/// to stop it.
61pub struct TcpManager {
62    cmd: mpsc::UnboundedSender<TcpCommand>,
63}
64
65impl TcpManager {
66    /// Spawn the manager and return the command handle + event stream.
67    pub fn spawn() -> (Self, mpsc::UnboundedReceiver<IoEvent>) {
68        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
69        let (evt_tx, evt_rx) = mpsc::unbounded_channel();
70        let conns: Conns = Arc::new(Mutex::new(HashMap::new()));
71        tokio::spawn(run_tcp(cmd_rx, evt_tx, conns));
72        (Self { cmd: cmd_tx }, evt_rx)
73    }
74
75    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
76        self.cmd
77            .send(TcpCommand::Bind { addr })
78            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
79    }
80    /// Initiate an outbound connection. On success the manager
81    /// publishes `IoEvent::Connected { id, peer }`; on failure it
82    /// publishes `IoEvent::Error { reason }`.
83    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
84        self.cmd
85            .send(TcpCommand::Connect { addr })
86            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
87    }
88    pub fn send_bytes(&self, id: ConnId, bytes: Vec<u8>) -> io::Result<()> {
89        self.cmd
90            .send(TcpCommand::Send { id, bytes })
91            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
92    }
93    pub fn close(&self, id: ConnId) -> io::Result<()> {
94        self.cmd
95            .send(TcpCommand::Close { id })
96            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
97    }
98    pub fn shutdown(&self) {
99        let _ = self.cmd.send(TcpCommand::Shutdown);
100    }
101}
102
103static SEQ: AtomicU64 = AtomicU64::new(1);
104
105async fn run_tcp(
106    mut cmd: mpsc::UnboundedReceiver<TcpCommand>,
107    evt: mpsc::UnboundedSender<IoEvent>,
108    conns: Conns,
109) {
110    while let Some(c) = cmd.recv().await {
111        match c {
112            TcpCommand::Bind { addr } => {
113                let evt_tx = evt.clone();
114                let conns = conns.clone();
115                tokio::spawn(async move {
116                    let listener = match TcpListener::bind(addr).await {
117                        Ok(l) => l,
118                        Err(e) => {
119                            let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
120                            return;
121                        }
122                    };
123                    let bound = listener.local_addr().unwrap_or(addr);
124                    let _ = evt_tx.send(IoEvent::Bound { addr: bound });
125                    loop {
126                        let stream = match listener.accept().await {
127                            Ok((s, _)) => s,
128                            Err(e) => {
129                                let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
130                                break;
131                            }
132                        };
133                        let peer = stream.peer_addr().unwrap_or(bound);
134                        let id = ConnId(SEQ.fetch_add(1, Ordering::Relaxed));
135                        let _ = evt_tx.send(IoEvent::Connected { id, peer });
136                        spawn_conn(id, stream, evt_tx.clone(), conns.clone()).await;
137                    }
138                });
139            }
140            TcpCommand::Connect { addr } => {
141                let evt_tx = evt.clone();
142                let conns = conns.clone();
143                tokio::spawn(async move {
144                    let stream = match TcpStream::connect(addr).await {
145                        Ok(s) => s,
146                        Err(e) => {
147                            let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
148                            return;
149                        }
150                    };
151                    let peer = stream.peer_addr().unwrap_or(addr);
152                    let id = ConnId(SEQ.fetch_add(1, Ordering::Relaxed));
153                    let _ = evt_tx.send(IoEvent::Connected { id, peer });
154                    spawn_conn(id, stream, evt_tx, conns).await;
155                });
156            }
157            TcpCommand::Send { id, bytes } => {
158                let g = conns.lock().await;
159                if let Some(tx) = g.get(&id) {
160                    let _ = tx.send(bytes);
161                }
162            }
163            TcpCommand::Close { id } => {
164                conns.lock().await.remove(&id);
165            }
166            TcpCommand::Shutdown => break,
167        }
168    }
169}
170
171async fn spawn_conn(id: ConnId, stream: TcpStream, evt: mpsc::UnboundedSender<IoEvent>, conns: Conns) {
172    let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
173    conns.lock().await.insert(id, write_tx);
174    let (mut rh, mut wh) = stream.into_split();
175    tokio::spawn(async move {
176        while let Some(bytes) = write_rx.recv().await {
177            if wh.write_all(&bytes).await.is_err() {
178                break;
179            }
180        }
181        let _ = wh.shutdown().await;
182    });
183    let evt2 = evt.clone();
184    tokio::spawn(async move {
185        let mut buf = vec![0u8; 8 * 1024];
186        loop {
187            match rh.read(&mut buf).await {
188                Ok(0) | Err(_) => {
189                    let _ = evt2.send(IoEvent::Closed { id });
190                    break;
191                }
192                Ok(n) => {
193                    let _ = evt2.send(IoEvent::Received { id, bytes: buf[..n].to_vec() });
194                }
195            }
196        }
197    });
198}
199
/// Requests understood by the [`UdpManager`] send loop.
#[derive(Debug)]
pub enum UdpCommand {
    /// Transmit `bytes` as a single datagram to `to`.
    Send {
        to: SocketAddr,
        bytes: Vec<u8>,
    },
    /// Stop the send loop.
    Shutdown,
}
205
206/// Actor-style UDP manager bound to a single socket.
207pub struct UdpManager {
208    cmd: mpsc::UnboundedSender<UdpCommand>,
209    local: SocketAddr,
210}
211
212impl UdpManager {
213    pub async fn bind(addr: SocketAddr) -> io::Result<(Self, mpsc::UnboundedReceiver<IoEvent>)> {
214        let socket = UdpSocket::bind(addr).await?;
215        let local = socket.local_addr()?;
216        let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
217        let (evt_tx, evt_rx) = mpsc::unbounded_channel();
218        let socket = Arc::new(socket);
219        let s_recv = socket.clone();
220        let etx = evt_tx.clone();
221        tokio::spawn(async move {
222            let mut buf = vec![0u8; 64 * 1024];
223            loop {
224                match s_recv.recv_from(&mut buf).await {
225                    Ok((n, from)) => {
226                        let _ = etx.send(IoEvent::Datagram { from, bytes: buf[..n].to_vec() });
227                    }
228                    Err(e) => {
229                        let _ = etx.send(IoEvent::Error { reason: e.to_string() });
230                        break;
231                    }
232                }
233            }
234        });
235        let s_send = socket.clone();
236        tokio::spawn(async move {
237            while let Some(c) = cmd_rx.recv().await {
238                match c {
239                    UdpCommand::Send { to, bytes } => {
240                        if let Err(e) = s_send.send_to(&bytes, to).await {
241                            let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
242                        }
243                    }
244                    UdpCommand::Shutdown => break,
245                }
246            }
247        });
248        Ok((Self { cmd: cmd_tx, local }, evt_rx))
249    }
250
251    pub fn local_addr(&self) -> SocketAddr {
252        self.local
253    }
254
255    pub fn send_to(&self, to: SocketAddr, bytes: Vec<u8>) -> io::Result<()> {
256        self.cmd
257            .send(UdpCommand::Send { to, bytes })
258            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
259    }
260    pub fn shutdown(&self) {
261        let _ = self.cmd.send(UdpCommand::Shutdown);
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long each test waits for a single event.
    const WAIT: std::time::Duration = std::time::Duration::from_millis(500);

    /// Wait up to `WAIT` for the next event; panic on timeout or a closed stream.
    async fn next_event(rx: &mut mpsc::UnboundedReceiver<IoEvent>) -> IoEvent {
        tokio::time::timeout(WAIT, rx.recv()).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn udp_manager_round_trip() {
        let any_port: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (receiver, mut inbox) = UdpManager::bind(any_port).await.unwrap();
        let (sender, _sender_events) = UdpManager::bind(any_port).await.unwrap();

        sender.send_to(receiver.local_addr(), b"hi".to_vec()).unwrap();
        match next_event(&mut inbox).await {
            IoEvent::Datagram { bytes, .. } => assert_eq!(bytes, b"hi"),
            other => panic!("unexpected event: {other:?}"),
        }

        receiver.shutdown();
        sender.shutdown();
    }

    #[tokio::test]
    async fn tcp_manager_accept_and_echo() {
        let (mgr, mut events) = TcpManager::spawn();
        mgr.bind("127.0.0.1:0".parse().unwrap()).unwrap();

        let bound = match next_event(&mut events).await {
            IoEvent::Bound { addr } => addr,
            other => panic!("expected Bound, got {other:?}"),
        };

        let mut client = TcpStream::connect(bound).await.unwrap();
        let id = match next_event(&mut events).await {
            IoEvent::Connected { id, .. } => id,
            other => panic!("expected Connected, got {other:?}"),
        };

        client.write_all(b"ping").await.unwrap();
        match next_event(&mut events).await {
            IoEvent::Received { bytes, .. } => assert_eq!(bytes, b"ping"),
            other => panic!("expected Received, got {other:?}"),
        }

        mgr.send_bytes(id, b"pong".to_vec()).unwrap();
        let mut reply = [0u8; 4];
        client.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"pong");

        mgr.shutdown();
    }
}