Skip to main content

taskers_control/
socket.rs

1use std::{future::Future, io, os::unix::net::UnixListener as StdUnixListener, path::Path};
2
3use serde_json::{from_slice, to_vec};
4use tokio::{
5    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
6    net::{UnixListener, UnixStream},
7};
8
9use crate::{
10    RequestFrame,
11    controller::InMemoryController,
12    protocol::{ControlCommand, ControlError, ControlResponse, ResponseFrame},
13};
14
15pub fn bind_socket(path: impl AsRef<Path>) -> io::Result<UnixListener> {
16    let path = path.as_ref();
17    if path.exists() {
18        std::fs::remove_file(path)?;
19    }
20    let listener = StdUnixListener::bind(path)?;
21    listener.set_nonblocking(true)?;
22    UnixListener::from_std(listener)
23}
24
25pub async fn serve<S>(
26    listener: UnixListener,
27    controller: InMemoryController,
28    shutdown: S,
29) -> io::Result<()>
30where
31    S: Future<Output = ()> + Send,
32{
33    serve_with_handler(
34        listener,
35        move |command| {
36            let controller = controller.clone();
37            async move {
38                controller
39                    .handle(command)
40                    .map_err(|error| ControlError::internal(error.to_string()))
41            }
42        },
43        shutdown,
44    )
45    .await
46}
47
48pub async fn serve_with_handler<S, H, F>(
49    listener: UnixListener,
50    handler: H,
51    shutdown: S,
52) -> io::Result<()>
53where
54    S: Future<Output = ()> + Send,
55    H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
56    F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
57{
58    tokio::pin!(shutdown);
59
60    loop {
61        tokio::select! {
62            _ = &mut shutdown => break,
63            accepted = listener.accept() => {
64                let (stream, _) = accepted?;
65                let handler = handler.clone();
66                tokio::spawn(async move {
67                    let _ = handle_connection_with_handler(stream, handler).await;
68                });
69            }
70        }
71    }
72
73    Ok(())
74}
75
76async fn handle_connection_with_handler<H, F>(stream: UnixStream, handler: H) -> io::Result<()>
77where
78    H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
79    F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
80{
81    let (read_half, mut write_half) = stream.into_split();
82    let mut reader = BufReader::new(read_half);
83    let mut line = String::new();
84    reader.read_line(&mut line).await?;
85
86    let request: RequestFrame = from_slice(line.trim_end().as_bytes()).map_err(invalid_data)?;
87    let result = handler(request.command).await;
88    let response = ResponseFrame {
89        request_id: request.request_id,
90        response: result,
91    };
92    let payload = to_vec(&response).map_err(invalid_data)?;
93    write_half.write_all(&payload).await?;
94    write_half.write_all(b"\n").await?;
95    write_half.flush().await?;
96
97    Ok(())
98}
99
/// Converts any displayable error into an [`io::Error`] of kind
/// [`io::ErrorKind::InvalidData`], keeping only its rendered message.
fn invalid_data(error: impl ToString) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    let message = error.to_string();
    io::Error::new(kind, message)
}
103
#[cfg(test)]
mod tests {
    use tempfile::tempdir;
    use tokio::sync::oneshot;

    use taskers_domain::AppModel;

    use crate::{
        client::ControlClient,
        controller::InMemoryController,
        protocol::{ControlCommand, ControlQuery, ControlResponse},
    };

    use super::{bind_socket, serve};

    /// End-to-end check over a real Unix socket. A `ControlClient` creates a workspace, then sees
    /// it reflected in a status query, and finally the server shuts down cleanly on request.
    #[tokio::test]
    async fn client_and_server_roundtrip() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");
        let controller = InMemoryController::new(AppModel::new("Main"));
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        let server = tokio::spawn(serve(listener, controller.clone(), async move {
            let _ = shutdown_rx.await;
        }));

        let client = ControlClient::new(&socket_path);
        let created = client
            .send(ControlCommand::CreateWorkspace {
                label: "Docs".into(),
            })
            .await
            .expect("create workspace request");
        assert!(matches!(
            created.response,
            Ok(ControlResponse::WorkspaceCreated { .. })
        ));

        let status = client
            .send(ControlCommand::QueryStatus {
                query: ControlQuery::All,
            })
            .await
            .expect("query request");
        match status.response {
            Ok(ControlResponse::Status { session }) => {
                // Presumably the workspace from `AppModel::new("Main")` plus "Docs" created above.
                assert_eq!(session.model.workspaces.len(), 2);
            }
            other => panic!("unexpected response: {other:?}"),
        }

        shutdown_tx.send(()).expect("shutdown");
        server.await.expect("server task").expect("serve cleanly");
    }
}