Skip to main content

walrus_socket/
server.rs

//! Unix domain socket server — accept loop and per-connection message handler.
2
3use crate::codec;
4use protocol::message::{client::ClientMessage, server::ServerMessage};
5use tokio::{
6    net::{
7        UnixListener,
8        unix::{OwnedReadHalf, OwnedWriteHalf},
9    },
10    sync::{mpsc, oneshot},
11};
12
13/// Accept connections on the given `UnixListener` until shutdown is signalled.
14///
15/// Each connection is handled in a separate task. For each incoming
16/// `ClientMessage`, calls `on_message(msg, reply_tx)` where `reply_tx` is
17/// the per-connection sender for streaming `ServerMessage`s back. The caller
18/// controls dispatch routing (DD#11).
19pub async fn accept_loop<F>(
20    listener: UnixListener,
21    on_message: F,
22    mut shutdown: oneshot::Receiver<()>,
23) where
24    F: Fn(ClientMessage, mpsc::UnboundedSender<ServerMessage>) + Clone + Send + 'static,
25{
26    loop {
27        tokio::select! {
28            result = listener.accept() => {
29                match result {
30                    Ok((stream, _addr)) => {
31                        let cb = on_message.clone();
32                        tokio::spawn(async move {
33                            handle_connection(stream, cb).await;
34                        });
35                    }
36                    Err(e) => {
37                        tracing::error!("failed to accept connection: {e}");
38                    }
39                }
40            }
41            _ = &mut shutdown => {
42                tracing::info!("accept loop shutting down");
43                break;
44            }
45        }
46    }
47}
48
49/// Handle an established Unix domain socket connection.
50async fn handle_connection<F>(stream: tokio::net::UnixStream, on_message: F)
51where
52    F: Fn(ClientMessage, mpsc::UnboundedSender<ServerMessage>),
53{
54    let (reader, writer) = stream.into_split();
55    let (tx, rx) = mpsc::unbounded_channel::<ServerMessage>();
56
57    // Sender task: forward ServerMessages to the socket.
58    let send_task = tokio::spawn(sender_loop(writer, rx));
59
60    // Receiver loop: process incoming ClientMessages.
61    receiver_loop(reader, tx, on_message).await;
62
63    // Clean up — dropping tx already happened in receiver_loop on exit,
64    // which causes sender_loop to end.
65    let _ = send_task.await;
66}
67
68/// Reads messages from the mpsc channel and writes them to the socket.
69async fn sender_loop(mut writer: OwnedWriteHalf, mut rx: mpsc::UnboundedReceiver<ServerMessage>) {
70    while let Some(msg) = rx.recv().await {
71        if let Err(e) = codec::write_message(&mut writer, &msg).await {
72            tracing::error!("failed to write message: {e}");
73            break;
74        }
75    }
76}
77
78/// Reads client messages from the socket and dispatches via callback.
79async fn receiver_loop<F>(
80    mut reader: OwnedReadHalf,
81    tx: mpsc::UnboundedSender<ServerMessage>,
82    on_message: F,
83) where
84    F: Fn(ClientMessage, mpsc::UnboundedSender<ServerMessage>),
85{
86    loop {
87        let client_msg: ClientMessage = match codec::read_message(&mut reader).await {
88            Ok(msg) => msg,
89            Err(codec::FrameError::ConnectionClosed) => break,
90            Err(e) => {
91                tracing::debug!("read error: {e}");
92                break;
93            }
94        };
95
96        on_message(client_msg, tx.clone());
97    }
98}