1use std::os::unix::fs::PermissionsExt;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::task::JoinHandle;
14use tokio_util::sync::CancellationToken;
15
16use crate::protocol::{
17 EndMarker, Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line,
18};
19
20#[async_trait]
24pub trait Handler: Send + Sync + 'static {
25 async fn dispatch(&self, req: Request) -> DispatchOutcome;
29}
30
31pub enum DispatchOutcome {
36 OneShot(Result<serde_json::Value, WireError>),
38 Stream(Box<dyn EventStream + Send>),
41}
42
43#[async_trait]
46pub trait EventStream: Send {
47 async fn next_event(&mut self) -> Option<serde_json::Value>;
50}
51
52pub async fn spawn_unix_server<H: Handler>(
66 socket_path: &Path,
67 handler: Arc<H>,
68 cancel: CancellationToken,
69) -> std::io::Result<JoinHandle<()>> {
70 let _ = std::fs::remove_file(socket_path);
73 let listener = UnixListener::bind(socket_path)?;
74
75 let perms = std::fs::Permissions::from_mode(0o600);
76 std::fs::set_permissions(socket_path, perms)?;
77
78 let socket_path: PathBuf = socket_path.to_path_buf();
79 let handle = tokio::spawn(async move {
80 loop {
81 tokio::select! {
82 biased;
83 () = cancel.cancelled() => {
84 let _ = std::fs::remove_file(&socket_path);
85 return;
86 }
87 accepted = listener.accept() => {
88 let stream: UnixStream = match accepted {
89 Ok((s, _)) => s,
90 Err(e) => {
91 tracing::warn!(?e, "mgmt accept failed");
92 continue;
93 }
94 };
95 let h = Arc::clone(&handler);
96 tokio::spawn(async move {
97 let (read, write) = stream.into_split();
98 handle_conn(read, write, h).await;
99 });
100 }
101 }
102 }
103 });
104 Ok(handle)
105}
106
107pub(crate) async fn handle_conn<R, W, H>(read: R, mut write: W, handler: Arc<H>)
112where
113 R: AsyncRead + Unpin,
114 W: AsyncWrite + Unpin,
115 H: Handler,
116{
117 let mut lines = BufReader::new(read).lines();
118 loop {
119 let line = match lines.next_line().await {
120 Ok(Some(l)) => l,
121 Ok(None) => return,
122 Err(e) => {
123 tracing::debug!(?e, "mgmt read failed");
124 return;
125 }
126 };
127 if line.is_empty() {
128 continue;
129 }
130 match serde_json::from_str::<Request>(&line) {
131 Ok(req) => {
132 let id = req.id;
133 match handler.dispatch(req).await {
134 DispatchOutcome::OneShot(Ok(value)) => {
135 let frame = Response { id, outcome: ResponseOutcome::Result { result: value } };
136 if write_frame(&mut write, &frame).await.is_err() {
137 return;
138 }
139 }
140 DispatchOutcome::OneShot(Err(error)) => {
141 let frame = Response { id, outcome: ResponseOutcome::Error { error } };
142 if write_frame(&mut write, &frame).await.is_err() {
143 return;
144 }
145 }
146 DispatchOutcome::Stream(mut stream) => {
147 loop {
153 let Some(event) = stream.next_event().await else {
154 let end =
155 Response { id, outcome: ResponseOutcome::End { end: EndMarker::default() } };
156 let _ = write_frame(&mut write, &end).await;
157 return;
158 };
159 let frame = Response { id, outcome: ResponseOutcome::Event { event } };
160 if write_frame(&mut write, &frame).await.is_err() {
161 return;
162 }
163 }
164 }
165 }
166 }
167 Err(e) => {
168 let frame = Response {
169 id: 0,
172 outcome: ResponseOutcome::Error {
173 error: WireError { kind: WireErrorKind::BadArgs, message: format!("parse: {e}") },
174 },
175 };
176 if write_frame(&mut write, &frame).await.is_err() {
177 return;
178 }
179 }
180 }
181 }
182}
183
184async fn write_frame<W: AsyncWrite + Unpin>(
188 write: &mut W,
189 frame: &Response,
190) -> Result<(), std::io::Error> {
191 let bytes = match encode_line(frame) {
192 Ok(b) => b,
193 Err(e) => {
194 tracing::error!(?e, "mgmt response encode failed");
195 return Err(std::io::Error::other(e));
196 }
197 };
198 write.write_all(&bytes).await
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Handler double that records the most recent verb it was asked to
    /// dispatch and answers a small fixed set of verbs.
    struct StubHandler {
        last_verb: Mutex<Option<String>>,
    }

    /// Builds a fresh stub handler with no recorded verb.
    fn stub() -> Arc<StubHandler> {
        Arc::new(StubHandler { last_verb: Mutex::new(None) })
    }

    #[async_trait]
    impl Handler for StubHandler {
        async fn dispatch(&self, req: Request) -> DispatchOutcome {
            *self.last_verb.lock().unwrap() = Some(req.verb.clone());
            match req.verb.as_str() {
                "ping" => DispatchOutcome::OneShot(Ok(serde_json::json!({ "pong": true }))),
                "echo" => DispatchOutcome::OneShot(Ok(req.args)),
                "stream2" => DispatchOutcome::Stream(Box::new(MockStream::with_two_events())),
                _ => DispatchOutcome::OneShot(Err(WireError {
                    kind: WireErrorKind::UnknownVerb,
                    message: format!("unknown {}", req.verb),
                })),
            }
        }
    }

    /// Event stream backed by a vector. Events are taken from the back, so
    /// they are emitted in reverse insertion order.
    struct MockStream {
        remaining: Vec<serde_json::Value>,
    }

    impl MockStream {
        /// Stream holding `{"n": 1}` and `{"n": 2}`; it yields `n = 2` first.
        fn with_two_events() -> Self {
            let remaining = vec![serde_json::json!({ "n": 1 }), serde_json::json!({ "n": 2 })];
            Self { remaining }
        }
    }

    #[async_trait]
    impl EventStream for MockStream {
        async fn next_event(&mut self) -> Option<serde_json::Value> {
            self.remaining.pop()
        }
    }

    /// Feeds `requests` to `handle_conn` over in-memory pipes and returns
    /// every byte the server wrote back.
    ///
    /// The request side is closed after writing so the server sees end of
    /// input. The responses are read only after the server task has
    /// finished, so they have to fit into the 8192-byte pipe buffer.
    async fn drive(handler: Arc<StubHandler>, requests: &str) -> Vec<u8> {
        let (server_in, mut client_out) = tokio::io::duplex(8192);
        let (server_out, mut client_in) = tokio::io::duplex(8192);
        let server = tokio::spawn(handle_conn(server_in, server_out, handler));

        client_out.write_all(requests.as_bytes()).await.expect("write requests");
        drop(client_out);
        server.await.expect("server task");

        let mut received = Vec::new();
        tokio::io::AsyncReadExt::read_to_end(&mut client_in, &mut received)
            .await
            .expect("read responses");
        received
    }

    /// Decodes the server output into one `Response` per non-empty line.
    fn parse_responses(bytes: &[u8]) -> Vec<Response> {
        let text = std::str::from_utf8(bytes).expect("utf8");
        let mut decoded: Vec<Response> = Vec::new();
        for line in text.lines() {
            if line.is_empty() {
                continue;
            }
            decoded.push(serde_json::from_str(line).expect("parse response"));
        }
        decoded
    }

    #[tokio::test]
    async fn server_stub_dispatches_known_verb_and_writes_result_line() {
        let handler = stub();
        let ping = Request { id: 11, verb: "ping".to_string(), args: serde_json::Value::Null };
        let mut raw = serde_json::to_string(&ping).unwrap();
        raw.push('\n');

        let replies = parse_responses(&drive(Arc::clone(&handler), &raw).await);

        assert_eq!(replies.len(), 1);
        let reply = &replies[0];
        assert_eq!(reply.id, 11);
        match &reply.outcome {
            ResponseOutcome::Result { result } => assert_eq!(result["pong"], true),
            other => panic!("unexpected outcome: {other:?}"),
        }
        let recorded = handler.last_verb.lock().unwrap();
        assert_eq!(recorded.as_deref(), Some("ping"));
    }

    #[tokio::test]
    async fn server_stub_writes_error_for_unknown_verb() {
        let unknown = Request { id: 5, verb: "wat".to_string(), args: serde_json::Value::Null };
        let mut raw = serde_json::to_string(&unknown).unwrap();
        raw.push('\n');

        let replies = parse_responses(&drive(stub(), &raw).await);

        assert_eq!(replies.len(), 1);
        let reply = &replies[0];
        assert_eq!(reply.id, 5);
        match &reply.outcome {
            ResponseOutcome::Error { error } => {
                assert_eq!(error.kind, WireErrorKind::UnknownVerb);
                assert!(error.message.contains("wat"));
            }
            other => panic!("expected error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn server_stub_writes_bad_args_error_for_unparseable_request() {
        let replies = parse_responses(&drive(stub(), "this is not json\n").await);

        assert_eq!(replies.len(), 1);
        let reply = &replies[0];
        // A request that cannot be parsed has no id to echo; the server uses 0.
        assert_eq!(reply.id, 0);
        match &reply.outcome {
            ResponseOutcome::Error { error } => assert_eq!(error.kind, WireErrorKind::BadArgs),
            other => panic!("expected error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn server_dispatches_streaming_verb_writes_event_then_end() {
        let start = Request { id: 99, verb: "stream2".to_string(), args: serde_json::Value::Null };
        let mut raw = serde_json::to_string(&start).unwrap();
        raw.push('\n');

        let frames = parse_responses(&drive(stub(), &raw).await);

        assert_eq!(frames.len(), 3, "two events plus a terminating End frame");
        for frame in &frames {
            assert_eq!(frame.id, 99, "streaming frames echo the request id");
        }
        assert!(matches!(frames[0].outcome, ResponseOutcome::Event { .. }));
        assert!(matches!(frames[1].outcome, ResponseOutcome::Event { .. }));
        assert!(matches!(frames[2].outcome, ResponseOutcome::End { .. }));
        // The mock stream pops from the back of its vector, hence 2 before 1.
        if let ResponseOutcome::Event { event } = &frames[0].outcome {
            assert_eq!(event["n"], 2);
        }
        if let ResponseOutcome::Event { event } = &frames[1].outcome {
            assert_eq!(event["n"], 1);
        }
    }

    #[tokio::test]
    async fn server_stub_handles_multiple_requests_serial_per_connection() {
        let ping = Request { id: 1, verb: "ping".into(), args: serde_json::Value::Null };
        let echo = Request { id: 2, verb: "echo".into(), args: serde_json::json!({"x": 1}) };
        let nope = Request { id: 3, verb: "nope".into(), args: serde_json::Value::Null };
        let r1 = serde_json::to_string(&ping).unwrap();
        let r2 = serde_json::to_string(&echo).unwrap();
        let r3 = serde_json::to_string(&nope).unwrap();
        // The empty line between the second and third request must be ignored.
        let raw = format!("{r1}\n{r2}\n\n{r3}\n");

        let replies = parse_responses(&drive(stub(), &raw).await);

        assert_eq!(replies.len(), 3, "blank line is skipped, not echoed back");
        assert_eq!(replies[0].id, 1);
        assert_eq!(replies[1].id, 2);
        assert_eq!(replies[2].id, 3);
        assert!(matches!(replies[0].outcome, ResponseOutcome::Result { .. }));
        assert!(matches!(replies[1].outcome, ResponseOutcome::Result { .. }));
        assert!(matches!(replies[2].outcome, ResponseOutcome::Error { .. }));
    }
}