Skip to main content

walrus_socket/
server.rs

//! Unix domain socket server — accept loop and per-connection message handler.

use crate::codec;
use tokio::{
    net::{UnixListener, UnixStream},
    sync::{mpsc, oneshot},
};
use wcore::protocol::message::{client::ClientMessage, server::ServerMessage};
10/// Accept connections on the given `UnixListener` until shutdown is signalled.
11///
12/// Each connection is handled in a separate task. For each incoming
13/// `ClientMessage`, calls `on_message(msg, reply_tx)` where `reply_tx` is
14/// the per-connection sender for streaming `ServerMessage`s back.
15pub async fn accept_loop<F>(
16    listener: UnixListener,
17    on_message: F,
18    mut shutdown: oneshot::Receiver<()>,
19) where
20    F: Fn(ClientMessage, mpsc::UnboundedSender<ServerMessage>) + Clone + Send + 'static,
21{
22    loop {
23        tokio::select! {
24            result = listener.accept() => {
25                match result {
26                    Ok((stream, _addr)) => {
27                        let cb = on_message.clone();
28                        tokio::spawn(async move {
29                            let (mut reader, mut writer) = stream.into_split();
30                            let (tx, mut rx) = mpsc::unbounded_channel::<ServerMessage>();
31                            let send_task = tokio::spawn(async move {
32                                while let Some(msg) = rx.recv().await {
33                                    if let Err(e) = codec::write_message(&mut writer, &msg).await {
34                                        tracing::error!("failed to write message: {e}");
35                                        break;
36                                    }
37                                }
38                            });
39
40                            loop {
41                                let client_msg: ClientMessage = match codec::read_message(&mut reader).await {
42                                    Ok(msg) => msg,
43                                    Err(codec::FrameError::ConnectionClosed) => break,
44                                    Err(e) => { tracing::debug!("read error: {e}"); break; }
45                                };
46                                cb(client_msg, tx.clone());
47                            }
48
49                            drop(tx);
50                            let _ = send_task.await;
51                        });
52                    }
53                    Err(e) => tracing::error!("failed to accept connection: {e}"),
54                }
55            }
56            _ = &mut shutdown => {
57                tracing::info!("accept loop shutting down");
58                break;
59            }
60        }
61    }
62}