Skip to main content

atomr_core/io/
manager.rs

//! `TcpManager` / `UdpManager` actor-style wrappers.
//!
//! akka.net's `IO.Tcp.Manager` is an actor that mediates `Bind`/`Connect`
//! commands and dispatches per-connection child actors. Our equivalent is
//! a small task-based state machine driven by mpsc channels: callers send
//! commands through a handle and get an [`IoEvent`] stream of inbound
//! connections, received bytes, disconnects and (for UDP) datagrams.
//! Only inbound TCP is supported — there is no outbound `Connect` command.
7
8use std::collections::HashMap;
9use std::io;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::{TcpListener, TcpStream, UdpSocket};
16use tokio::sync::{mpsc, Mutex};
17
/// Identifier of one accepted TCP connection.
///
/// The manager hands it out in the `Connected` event; callers use it to
/// address `Send` / `Close` commands at that connection.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnId(pub u64);
21
22/// Events emitted by [`TcpManager`] / [`UdpManager`].
23#[derive(Debug)]
24pub enum IoEvent {
25    Connected { id: ConnId, peer: SocketAddr },
26    Received { id: ConnId, bytes: Vec<u8> },
27    Closed { id: ConnId },
28    Bound { addr: SocketAddr },
29    Datagram { from: SocketAddr, bytes: Vec<u8> },
30    Error { reason: String },
31}
32
33/// Commands sent into the [`TcpManager`].
34#[derive(Debug)]
35pub enum TcpCommand {
36    Bind { addr: SocketAddr },
37    Send { id: ConnId, bytes: Vec<u8> },
38    Close { id: ConnId },
39    Shutdown,
40}
41
42type Conns = Arc<Mutex<HashMap<ConnId, mpsc::UnboundedSender<Vec<u8>>>>>;
43
44/// Actor-style TCP manager. Drop the handle (or call [`Self::shutdown`])
45/// to stop it.
46pub struct TcpManager {
47    cmd: mpsc::UnboundedSender<TcpCommand>,
48}
49
50impl TcpManager {
51    /// Spawn the manager and return the command handle + event stream.
52    pub fn spawn() -> (Self, mpsc::UnboundedReceiver<IoEvent>) {
53        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
54        let (evt_tx, evt_rx) = mpsc::unbounded_channel();
55        let conns: Conns = Arc::new(Mutex::new(HashMap::new()));
56        tokio::spawn(run_tcp(cmd_rx, evt_tx, conns));
57        (Self { cmd: cmd_tx }, evt_rx)
58    }
59
60    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
61        self.cmd
62            .send(TcpCommand::Bind { addr })
63            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
64    }
65    pub fn send_bytes(&self, id: ConnId, bytes: Vec<u8>) -> io::Result<()> {
66        self.cmd
67            .send(TcpCommand::Send { id, bytes })
68            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
69    }
70    pub fn close(&self, id: ConnId) -> io::Result<()> {
71        self.cmd
72            .send(TcpCommand::Close { id })
73            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
74    }
75    pub fn shutdown(&self) {
76        let _ = self.cmd.send(TcpCommand::Shutdown);
77    }
78}
79
80static SEQ: AtomicU64 = AtomicU64::new(1);
81
82async fn run_tcp(
83    mut cmd: mpsc::UnboundedReceiver<TcpCommand>,
84    evt: mpsc::UnboundedSender<IoEvent>,
85    conns: Conns,
86) {
87    while let Some(c) = cmd.recv().await {
88        match c {
89            TcpCommand::Bind { addr } => {
90                let evt_tx = evt.clone();
91                let conns = conns.clone();
92                tokio::spawn(async move {
93                    let listener = match TcpListener::bind(addr).await {
94                        Ok(l) => l,
95                        Err(e) => {
96                            let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
97                            return;
98                        }
99                    };
100                    let bound = listener.local_addr().unwrap_or(addr);
101                    let _ = evt_tx.send(IoEvent::Bound { addr: bound });
102                    loop {
103                        let stream = match listener.accept().await {
104                            Ok((s, _)) => s,
105                            Err(e) => {
106                                let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
107                                break;
108                            }
109                        };
110                        let peer = stream.peer_addr().unwrap_or(bound);
111                        let id = ConnId(SEQ.fetch_add(1, Ordering::Relaxed));
112                        let _ = evt_tx.send(IoEvent::Connected { id, peer });
113                        spawn_conn(id, stream, evt_tx.clone(), conns.clone()).await;
114                    }
115                });
116            }
117            TcpCommand::Send { id, bytes } => {
118                let g = conns.lock().await;
119                if let Some(tx) = g.get(&id) {
120                    let _ = tx.send(bytes);
121                }
122            }
123            TcpCommand::Close { id } => {
124                conns.lock().await.remove(&id);
125            }
126            TcpCommand::Shutdown => break,
127        }
128    }
129}
130
131async fn spawn_conn(id: ConnId, stream: TcpStream, evt: mpsc::UnboundedSender<IoEvent>, conns: Conns) {
132    let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
133    conns.lock().await.insert(id, write_tx);
134    let (mut rh, mut wh) = stream.into_split();
135    tokio::spawn(async move {
136        while let Some(bytes) = write_rx.recv().await {
137            if wh.write_all(&bytes).await.is_err() {
138                break;
139            }
140        }
141        let _ = wh.shutdown().await;
142    });
143    let evt2 = evt.clone();
144    tokio::spawn(async move {
145        let mut buf = vec![0u8; 8 * 1024];
146        loop {
147            match rh.read(&mut buf).await {
148                Ok(0) | Err(_) => {
149                    let _ = evt2.send(IoEvent::Closed { id });
150                    break;
151                }
152                Ok(n) => {
153                    let _ = evt2.send(IoEvent::Received { id, bytes: buf[..n].to_vec() });
154                }
155            }
156        }
157    });
158}
159
/// Requests processed by a [`UdpManager`]'s sender task.
#[derive(Debug)]
pub enum UdpCommand {
    /// Transmit `bytes` as a single datagram to `to`.
    Send { to: SocketAddr, bytes: Vec<u8> },
    /// Stop processing commands.
    Shutdown,
}
165
166/// Actor-style UDP manager bound to a single socket.
167pub struct UdpManager {
168    cmd: mpsc::UnboundedSender<UdpCommand>,
169    local: SocketAddr,
170}
171
172impl UdpManager {
173    pub async fn bind(addr: SocketAddr) -> io::Result<(Self, mpsc::UnboundedReceiver<IoEvent>)> {
174        let socket = UdpSocket::bind(addr).await?;
175        let local = socket.local_addr()?;
176        let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
177        let (evt_tx, evt_rx) = mpsc::unbounded_channel();
178        let socket = Arc::new(socket);
179        let s_recv = socket.clone();
180        let etx = evt_tx.clone();
181        tokio::spawn(async move {
182            let mut buf = vec![0u8; 64 * 1024];
183            loop {
184                match s_recv.recv_from(&mut buf).await {
185                    Ok((n, from)) => {
186                        let _ = etx.send(IoEvent::Datagram { from, bytes: buf[..n].to_vec() });
187                    }
188                    Err(e) => {
189                        let _ = etx.send(IoEvent::Error { reason: e.to_string() });
190                        break;
191                    }
192                }
193            }
194        });
195        let s_send = socket.clone();
196        tokio::spawn(async move {
197            while let Some(c) = cmd_rx.recv().await {
198                match c {
199                    UdpCommand::Send { to, bytes } => {
200                        if let Err(e) = s_send.send_to(&bytes, to).await {
201                            let _ = evt_tx.send(IoEvent::Error { reason: e.to_string() });
202                        }
203                    }
204                    UdpCommand::Shutdown => break,
205                }
206            }
207        });
208        Ok((Self { cmd: cmd_tx, local }, evt_rx))
209    }
210
211    pub fn local_addr(&self) -> SocketAddr {
212        self.local
213    }
214
215    pub fn send_to(&self, to: SocketAddr, bytes: Vec<u8>) -> io::Result<()> {
216        self.cmd
217            .send(UdpCommand::Send { to, bytes })
218            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "manager stopped"))
219    }
220    pub fn shutdown(&self) {
221        let _ = self.cmd.send(UdpCommand::Shutdown);
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Deadline applied to every awaited event.
    const WAIT: Duration = Duration::from_millis(500);

    /// Returns the next event. Panics if none arrives within [`WAIT`] or if
    /// the stream has closed.
    async fn next_event(events: &mut mpsc::UnboundedReceiver<IoEvent>) -> IoEvent {
        tokio::time::timeout(WAIT, events.recv()).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn udp_manager_round_trip() {
        let loopback: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (a, mut a_events) = UdpManager::bind(loopback).await.unwrap();
        let (b, _b_events) = UdpManager::bind(loopback).await.unwrap();

        b.send_to(a.local_addr(), b"hi".to_vec()).unwrap();
        match next_event(&mut a_events).await {
            IoEvent::Datagram { bytes, .. } => assert_eq!(bytes, b"hi"),
            other => panic!("unexpected event: {other:?}"),
        }

        a.shutdown();
        b.shutdown();
    }

    #[tokio::test]
    async fn tcp_manager_accept_and_echo() {
        let (mgr, mut events) = TcpManager::spawn();
        mgr.bind("127.0.0.1:0".parse().unwrap()).unwrap();

        let bound = match next_event(&mut events).await {
            IoEvent::Bound { addr } => addr,
            other => panic!("expected Bound, got {other:?}"),
        };

        let mut client = TcpStream::connect(bound).await.unwrap();
        let id = match next_event(&mut events).await {
            IoEvent::Connected { id, .. } => id,
            other => panic!("expected Connected, got {other:?}"),
        };

        client.write_all(b"ping").await.unwrap();
        match next_event(&mut events).await {
            IoEvent::Received { bytes, .. } => assert_eq!(bytes, b"ping"),
            other => panic!("expected Received, got {other:?}"),
        }

        mgr.send_bytes(id, b"pong".to_vec()).unwrap();
        let mut reply = [0u8; 4];
        client.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"pong");

        mgr.shutdown();
    }
}