taskers_control/
socket.rs1use std::{future::Future, io, os::unix::net::UnixListener as StdUnixListener, path::Path};
2
3use serde_json::{from_slice, to_vec};
4use tokio::{
5 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
6 net::{UnixListener, UnixStream},
7};
8
9use crate::{
10 RequestFrame,
11 controller::InMemoryController,
12 protocol::{ControlCommand, ControlError, ControlResponse, ResponseFrame},
13};
14
15pub fn bind_socket(path: impl AsRef<Path>) -> io::Result<UnixListener> {
16 let path = path.as_ref();
17 if path.exists() {
18 std::fs::remove_file(path)?;
19 }
20 let listener = StdUnixListener::bind(path)?;
21 listener.set_nonblocking(true)?;
22 UnixListener::from_std(listener)
23}
24
25pub async fn serve<S>(
26 listener: UnixListener,
27 controller: InMemoryController,
28 shutdown: S,
29) -> io::Result<()>
30where
31 S: Future<Output = ()> + Send,
32{
33 serve_with_handler(
34 listener,
35 move |command| {
36 let controller = controller.clone();
37 async move {
38 controller
39 .handle(command)
40 .map_err(|error| ControlError::internal(error.to_string()))
41 }
42 },
43 shutdown,
44 )
45 .await
46}
47
48pub async fn serve_with_handler<S, H, F>(
49 listener: UnixListener,
50 handler: H,
51 shutdown: S,
52) -> io::Result<()>
53where
54 S: Future<Output = ()> + Send,
55 H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
56 F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
57{
58 tokio::pin!(shutdown);
59
60 loop {
61 tokio::select! {
62 _ = &mut shutdown => break,
63 accepted = listener.accept() => {
64 let (stream, _) = accepted?;
65 let handler = handler.clone();
66 tokio::spawn(async move {
67 let _ = handle_connection_with_handler(stream, handler).await;
68 });
69 }
70 }
71 }
72
73 Ok(())
74}
75
76async fn handle_connection_with_handler<H, F>(stream: UnixStream, handler: H) -> io::Result<()>
77where
78 H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
79 F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
80{
81 let (read_half, mut write_half) = stream.into_split();
82 let mut reader = BufReader::new(read_half);
83 let mut line = String::new();
84 reader.read_line(&mut line).await?;
85
86 let request: RequestFrame = from_slice(line.trim_end().as_bytes()).map_err(invalid_data)?;
87 let result = handler(request.command).await;
88 let response = ResponseFrame {
89 request_id: request.request_id,
90 response: result,
91 };
92 let payload = to_vec(&response).map_err(invalid_data)?;
93 write_half.write_all(&payload).await?;
94 write_half.write_all(b"\n").await?;
95 write_half.flush().await?;
96
97 Ok(())
98}
99
/// Wraps any printable error as an [`io::Error`] of kind `InvalidData`.
///
/// This lets JSON (de)serialization failures flow through the `io::Result`
/// plumbing of the socket code. Only the error's string form is kept.
fn invalid_data(error: impl ToString) -> io::Error {
    let message: String = error.to_string();
    io::Error::new(io::ErrorKind::InvalidData, message)
}
103
#[cfg(test)]
mod tests {
    use tempfile::tempdir;
    use tokio::sync::oneshot;

    use taskers_domain::AppModel;

    use crate::{
        client::ControlClient,
        controller::InMemoryController,
        protocol::{ControlCommand, ControlQuery, ControlResponse},
    };

    use super::{bind_socket, serve};

    /// End-to-end check: a real `ControlClient` talks to `serve` over a
    /// temporary Unix socket. It creates a workspace, queries status, and then
    /// shuts the server down cleanly.
    #[tokio::test]
    async fn client_and_server_roundtrip() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");
        let controller = InMemoryController::new(AppModel::new("Main"));
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        let server = tokio::spawn(serve(listener, controller.clone(), async move {
            let _ = shutdown_rx.await;
        }));

        let client = ControlClient::new(&socket_path);
        let created = client
            .send(ControlCommand::CreateWorkspace {
                label: "Docs".into(),
            })
            .await
            .expect("create workspace request");
        assert!(matches!(
            created.response,
            Ok(ControlResponse::WorkspaceCreated { .. })
        ));

        let status = client
            .send(ControlCommand::QueryStatus {
                query: ControlQuery::All,
            })
            .await
            .expect("query request");
        match status.response {
            Ok(ControlResponse::Status { session }) => {
                // Presumably `AppModel::new("Main")` starts with one workspace,
                // plus the "Docs" workspace created above.
                assert_eq!(session.model.workspaces.len(), 2);
            }
            other => panic!("unexpected response: {other:?}"),
        }

        shutdown_tx.send(()).expect("shutdown");
        server.await.expect("server task").expect("serve cleanly");
    }
}