1use std::os::unix::fs::PermissionsExt;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::task::JoinHandle;
14use tokio_util::sync::CancellationToken;
15
16use crate::protocol::{
17 EndMarker, Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line,
18};
19
/// Largest accepted NDJSON request line, in bytes (1 MiB), not counting the
/// `\n` terminator or a `\r` immediately before it.
pub const MAX_NDJSON_LINE_BYTES: usize = 1 << 20;
26
/// Dispatches parsed management requests on behalf of the socket server.
///
/// The server shares a single handler between all connections through an
/// `Arc` and invokes it from spawned tasks, hence the `Send + Sync + 'static`
/// bounds.
#[async_trait]
pub trait Handler: Send + Sync + 'static {
    /// Handles one parsed request.
    ///
    /// On a given connection, requests are dispatched one at a time, in the
    /// order they were received: the next line is not read until this call
    /// (and any resulting stream) has finished.
    async fn dispatch(&self, req: Request) -> DispatchOutcome;
}
37
/// What a [`Handler`] produces for one request.
pub enum DispatchOutcome {
    /// A single reply. `Ok` is written as a `Result` frame and `Err` as an
    /// `Error` frame, both tagged with the request id. The connection then
    /// goes on to read the next request.
    OneShot(Result<serde_json::Value, WireError>),
    /// A sequence of events. Each event is written as an `Event` frame tagged
    /// with the request id, followed by an `End` frame once the stream is
    /// exhausted (or the connection is cancelled). The connection stops being
    /// served after the stream finishes.
    Stream(Box<dyn EventStream + Send>),
}
49
/// Source of events for a [`DispatchOutcome::Stream`] response.
#[async_trait]
pub trait EventStream: Send {
    /// Returns the next event, or `None` once the stream is exhausted.
    ///
    /// After `None` the server writes the `End` frame and does not call this
    /// again. A pending call may be dropped mid-flight if the connection's
    /// cancellation token fires.
    async fn next_event(&mut self) -> Option<serde_json::Value>;
}
58
59struct UmaskRestore {
69 prev: libc::mode_t,
70}
71
72impl UmaskRestore {
73 #[allow(unsafe_code)] fn tighten(mask: libc::mode_t) -> Self {
75 let prev = unsafe { libc::umask(mask) };
78 Self { prev }
79 }
80}
81
82impl Drop for UmaskRestore {
83 #[allow(unsafe_code)] fn drop(&mut self) {
85 unsafe {
89 libc::umask(self.prev);
90 }
91 }
92}
93
94pub async fn spawn_unix_server<H: Handler>(
108 socket_path: &Path,
109 handler: Arc<H>,
110 cancel: CancellationToken,
111) -> std::io::Result<JoinHandle<()>> {
112 let _ = std::fs::remove_file(socket_path);
115
116 let _umask_restore = UmaskRestore::tighten(0o117);
122
123 let listener = UnixListener::bind(socket_path)?;
124
125 let perms = std::fs::Permissions::from_mode(0o600);
130 std::fs::set_permissions(socket_path, perms)?;
131
132 if let Some(parent) = socket_path.parent()
137 && let Ok(meta) = std::fs::metadata(parent)
138 {
139 let mode = meta.permissions().mode() & 0o777;
140 if mode != 0o700 && mode != 0o770 {
141 tracing::warn!(
142 dir = %parent.display(),
143 mode = format!("{:#o}", mode),
144 "mgmt socket parent dir is broader than 0700/0770; restrict perms or move the socket",
145 );
146 }
147 }
148
149 let socket_path: PathBuf = socket_path.to_path_buf();
150 let handle = tokio::spawn(async move {
151 loop {
152 tokio::select! {
153 biased;
154 () = cancel.cancelled() => {
155 let _ = std::fs::remove_file(&socket_path);
156 return;
157 }
158 accepted = listener.accept() => {
159 let stream: UnixStream = match accepted {
160 Ok((s, _)) => s,
161 Err(e) => {
162 tracing::warn!(?e, "mgmt accept failed");
163 continue;
164 }
165 };
166 let h = Arc::clone(&handler);
167 let conn_cancel = cancel.child_token();
172 tokio::spawn(async move {
173 let (read, write) = stream.into_split();
174 handle_conn(read, write, h, conn_cancel).await;
175 });
176 }
177 }
178 }
179 });
180 Ok(handle)
181}
182
183async fn read_line_bounded<R>(
188 reader: &mut BufReader<R>,
189 buf: &mut String,
190 cap: usize,
191) -> std::io::Result<Option<()>>
192where
193 R: AsyncRead + Unpin,
194{
195 buf.clear();
196 let start_len = buf.len();
197 loop {
198 let prev_len = buf.len();
199 let n = reader.read_line(buf).await?;
200 if n == 0 {
201 return if buf.len() == start_len { Ok(None) } else { Ok(Some(())) };
204 }
205 if buf.ends_with('\n') {
207 buf.pop();
208 if buf.ends_with('\r') {
209 buf.pop();
210 }
211 if buf.len() > cap {
214 return Err(std::io::Error::new(
215 std::io::ErrorKind::InvalidData,
216 format!("ndjson line exceeded {cap}-byte cap"),
217 ));
218 }
219 return Ok(Some(()));
220 }
221 if buf.len() > cap {
224 return Err(std::io::Error::new(
225 std::io::ErrorKind::InvalidData,
226 format!("ndjson line exceeded {cap}-byte cap"),
227 ));
228 }
229 if buf.len() == prev_len + n && n == 0 {
232 return Ok(Some(()));
233 }
234 }
235}
236
237pub(crate) async fn handle_conn<R, W, H>(
242 read: R,
243 mut write: W,
244 handler: Arc<H>,
245 cancel: CancellationToken,
246) where
247 R: AsyncRead + Unpin,
248 W: AsyncWrite + Unpin,
249 H: Handler,
250{
251 let mut reader = BufReader::new(read);
252 let mut line = String::new();
253 loop {
254 let read_outcome = tokio::select! {
258 biased;
259 () = cancel.cancelled() => return,
260 res = read_line_bounded(&mut reader, &mut line, MAX_NDJSON_LINE_BYTES) => res,
261 };
262 match read_outcome {
263 Ok(None) => return,
264 Ok(Some(())) => {}
265 Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
266 let frame = Response {
270 id: 0,
271 outcome: ResponseOutcome::Error {
272 error: WireError::new(WireErrorKind::BadArgs, format!("line too long: {e}")),
273 },
274 };
275 let _ = write_frame(&mut write, &frame).await;
276 return;
277 }
278 Err(e) => {
279 tracing::debug!(?e, "mgmt read failed");
280 return;
281 }
282 }
283 if line.is_empty() {
284 continue;
285 }
286 match serde_json::from_str::<Request>(&line) {
287 Ok(req) => {
288 let id = req.id;
289 match handler.dispatch(req).await {
290 DispatchOutcome::OneShot(Ok(value)) => {
291 let frame = Response { id, outcome: ResponseOutcome::Result { result: value } };
292 if write_frame(&mut write, &frame).await.is_err() {
293 return;
294 }
295 }
296 DispatchOutcome::OneShot(Err(error)) => {
297 let frame = Response { id, outcome: ResponseOutcome::Error { error } };
298 if write_frame(&mut write, &frame).await.is_err() {
299 return;
300 }
301 }
302 DispatchOutcome::Stream(mut stream) => {
303 loop {
309 tokio::select! {
310 biased;
311 () = cancel.cancelled() => {
312 let end = Response {
313 id,
314 outcome: ResponseOutcome::End { end: EndMarker::default() },
315 };
316 let _ = write_frame(&mut write, &end).await;
317 return;
318 }
319 maybe = stream.next_event() => {
320 let Some(event) = maybe else {
321 let end = Response {
322 id,
323 outcome: ResponseOutcome::End { end: EndMarker::default() },
324 };
325 let _ = write_frame(&mut write, &end).await;
326 return;
327 };
328 let frame = Response { id, outcome: ResponseOutcome::Event { event } };
329 if write_frame(&mut write, &frame).await.is_err() {
330 return;
331 }
332 }
333 }
334 }
335 }
336 }
337 }
338 Err(e) => {
339 let frame = Response {
340 id: 0,
343 outcome: ResponseOutcome::Error {
344 error: WireError::new(WireErrorKind::BadArgs, format!("parse: {e}")),
345 },
346 };
347 if write_frame(&mut write, &frame).await.is_err() {
348 return;
349 }
350 }
351 }
352 }
353}
354
355async fn write_frame<W: AsyncWrite + Unpin>(
359 write: &mut W,
360 frame: &Response,
361) -> Result<(), std::io::Error> {
362 let bytes = match encode_line(frame) {
363 Ok(b) => b,
364 Err(e) => {
365 tracing::error!(?e, "mgmt response encode failed");
366 return Err(std::io::Error::other(e));
367 }
368 };
369 write.write_all(&bytes).await
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Handler stub: records the last verb it saw and answers a few fixed verbs
    /// ("ping", "echo", "stream2"); anything else yields `UnknownVerb`.
    struct StubHandler {
        last_verb: Mutex<Option<String>>,
    }

    #[async_trait]
    impl Handler for StubHandler {
        async fn dispatch(&self, req: Request) -> DispatchOutcome {
            *self.last_verb.lock().unwrap() = Some(req.verb.clone());
            let result: Result<serde_json::Value, WireError> = match req.verb.as_str() {
                "ping" => Ok(serde_json::json!({ "pong": true })),
                "echo" => Ok(req.args),
                "stream2" => {
                    return DispatchOutcome::Stream(Box::new(MockStream::with_two_events()));
                }
                _ => Err(WireError::new(WireErrorKind::UnknownVerb, format!("unknown {}", req.verb))),
            };
            DispatchOutcome::OneShot(result)
        }
    }

    /// Event stream that yields its queued events, then ends.
    struct MockStream {
        // Events are popped from the back, so they come out in reverse order.
        remaining: Vec<serde_json::Value>,
    }

    impl MockStream {
        fn with_two_events() -> Self {
            Self { remaining: vec![serde_json::json!({ "n": 1 }), serde_json::json!({ "n": 2 })] }
        }
    }

    #[async_trait]
    impl EventStream for MockStream {
        async fn next_event(&mut self) -> Option<serde_json::Value> {
            self.remaining.pop()
        }
    }

    /// Runs `handle_conn` over in-memory duplex pipes and returns every byte
    /// the server wrote.
    ///
    /// Steps: write `requests`, close the client writer (the server then sees
    /// EOF), wait for the server task to finish (which drops its writer), and
    /// read the responses to end.
    ///
    /// NOTE: the responses are read only after the server task has finished,
    /// so they must fit in the 8192-byte server-to-client pipe; presumably a
    /// larger output would block the server and hang this helper.
    async fn drive(handler: Arc<StubHandler>, requests: &str) -> Vec<u8> {
        let (c2s_r, mut c2s_w) = tokio::io::duplex(8192);
        let (s2c_w, mut s2c_r) = tokio::io::duplex(8192);
        let req = requests.to_string();
        let server_task = tokio::spawn(handle_conn(c2s_r, s2c_w, handler, CancellationToken::new()));
        c2s_w.write_all(req.as_bytes()).await.expect("write requests");
        drop(c2s_w);
        server_task.await.expect("server task");
        let mut buf = Vec::new();
        tokio::io::AsyncReadExt::read_to_end(&mut s2c_r, &mut buf).await.expect("read responses");
        buf
    }

    /// Parses each non-empty line of `bytes` as a `Response`, panicking on
    /// invalid UTF-8 or JSON.
    fn parse_responses(bytes: &[u8]) -> Vec<Response> {
        std::str::from_utf8(bytes)
            .expect("utf8")
            .lines()
            .filter(|l| !l.is_empty())
            .map(|l| serde_json::from_str(l).expect("parse response"))
            .collect()
    }

    // A known one-shot verb produces one `Result` frame echoing the request id.
    #[tokio::test]
    async fn server_stub_dispatches_known_verb_and_writes_result_line() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let req = Request { id: 11, verb: "ping".to_string(), args: serde_json::Value::Null };
        let raw = serde_json::to_string(&req).unwrap() + "\n";
        let bytes = drive(Arc::clone(&handler), &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, 11);
        match &responses[0].outcome {
            ResponseOutcome::Result { result } => assert_eq!(result["pong"], true),
            other => panic!("unexpected outcome: {other:?}"),
        }
        assert_eq!(handler.last_verb.lock().unwrap().as_deref(), Some("ping"));
    }

    // A handler error is forwarded as an `Error` frame carrying the request id.
    #[tokio::test]
    async fn server_stub_writes_error_for_unknown_verb() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let req = Request { id: 5, verb: "wat".to_string(), args: serde_json::Value::Null };
        let raw = serde_json::to_string(&req).unwrap() + "\n";
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, 5);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => {
                assert_eq!(error.kind, WireErrorKind::UnknownVerb);
                assert!(error.message.contains("wat"));
            }
            other => panic!("expected error, got {other:?}"),
        }
    }

    // An unparseable line gets a `BadArgs` frame with id 0 (the id is unknown).
    #[tokio::test]
    async fn server_stub_writes_bad_args_error_for_unparseable_request() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let raw = "this is not json\n";
        let bytes = drive(handler, raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, 0);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => assert_eq!(error.kind, WireErrorKind::BadArgs),
            other => panic!("expected error, got {other:?}"),
        }
    }

    // A streaming verb yields one `Event` frame per event, then an `End` frame.
    #[tokio::test]
    async fn server_dispatches_streaming_verb_writes_event_then_end() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let req = Request { id: 99, verb: "stream2".to_string(), args: serde_json::Value::Null };
        let raw = serde_json::to_string(&req).unwrap() + "\n";
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 3, "two events plus a terminating End frame");
        for r in &responses {
            assert_eq!(r.id, 99, "streaming frames echo the request id");
        }
        assert!(matches!(responses[0].outcome, ResponseOutcome::Event { .. }));
        assert!(matches!(responses[1].outcome, ResponseOutcome::Event { .. }));
        assert!(matches!(responses[2].outcome, ResponseOutcome::End { .. }));
        // MockStream pops from the back, so n=2 is emitted before n=1.
        if let ResponseOutcome::Event { event } = &responses[0].outcome {
            assert_eq!(event["n"], 2);
        }
        if let ResponseOutcome::Event { event } = &responses[1].outcome {
            assert_eq!(event["n"], 1);
        }
    }

    // A line one byte over the cap is rejected with `BadArgs` ("line too long")
    // and never reaches the handler.
    #[tokio::test]
    async fn server_rejects_line_exceeding_cap_with_bad_args() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let huge_line = format!(
            "{{\"id\":1,\"verb\":\"x\",\"args\":\"{}\"}}\n",
            "A".repeat(MAX_NDJSON_LINE_BYTES + 1)
        );
        let bytes = drive(handler.clone(), &huge_line).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 1);
        match &responses[0].outcome {
            ResponseOutcome::Error { error } => {
                assert_eq!(error.kind, WireErrorKind::BadArgs);
                assert!(error.message.contains("line too long"), "{}", error.message);
            }
            other => panic!("expected BadArgs error, got {other:?}"),
        }
        assert!(handler.last_verb.lock().unwrap().is_none());
    }

    // Several requests on one connection are answered in order; blank lines
    // produce no response.
    #[tokio::test]
    async fn server_stub_handles_multiple_requests_serial_per_connection() {
        let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
        let r1 =
            serde_json::to_string(&Request { id: 1, verb: "ping".into(), args: serde_json::Value::Null })
                .unwrap();
        let r2 = serde_json::to_string(&Request {
            id: 2,
            verb: "echo".into(),
            args: serde_json::json!({"x": 1}),
        })
        .unwrap();
        let r3 =
            serde_json::to_string(&Request { id: 3, verb: "nope".into(), args: serde_json::Value::Null })
                .unwrap();
        let raw = format!("{r1}\n{r2}\n\n{r3}\n");
        let bytes = drive(handler, &raw).await;
        let responses = parse_responses(&bytes);
        assert_eq!(responses.len(), 3, "blank line is skipped, not echoed back");
        assert_eq!(responses[0].id, 1);
        assert_eq!(responses[1].id, 2);
        assert_eq!(responses[2].id, 3);
        assert!(matches!(responses[0].outcome, ResponseOutcome::Result { .. }));
        assert!(matches!(responses[1].outcome, ResponseOutcome::Result { .. }));
        assert!(matches!(responses[2].outcome, ResponseOutcome::Error { .. }));
    }
}