Skip to main content

ndjson_rpc/
server.rs

1//! Unix-socket accept loop + per-connection line-delimited JSON
2//! dispatch to verb handlers. The HTTP-over-TCP transport
3//! ([`crate::http_server`]) speaks the same frame shapes, so dispatch
4//! logic is shared.
5
6use std::os::unix::fs::PermissionsExt;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::task::JoinHandle;
14use tokio_util::sync::CancellationToken;
15
16use crate::protocol::{
17	EndMarker, Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line,
18};
19
/// Upper bound, in bytes, on a single NDJSON request line (terminator
/// excluded). Sized to match the 1 MiB body cap of the HTTP transport so
/// both mgmt-plane transports refuse input of the same magnitude.
/// Legitimate verb payloads are far smaller (well under a kilobyte), so a
/// line reaching this limit is treated as broken framing or an attempt to
/// make the server buffer an unbounded amount of data.
pub const MAX_NDJSON_LINE_BYTES: usize = 1 << 20;
26
27/// Server-side dispatcher. Callers implement this against their own
28/// application state and pass an `Arc<H>` to [`spawn_unix_server`] or
29/// [`crate::spawn_http_server`].
30#[async_trait]
31pub trait Handler: Send + Sync + 'static {
32	/// Dispatch a parsed request to either a one-shot result or a
33	/// streaming event source. The server frames whichever outcome the
34	/// handler returns.
35	async fn dispatch(&self, req: Request) -> DispatchOutcome;
36}
37
/// What [`Handler::dispatch`] returns.
///
/// One-shot verbs (`ping`, `stats`, ...) produce a single result or error
/// frame; streaming verbs (`tail_flow`) produce a sequence of `Event`
/// frames that the server terminates with an `End` frame.
pub enum DispatchOutcome {
	/// One-shot reply: `Ok(value)` is written as a `Result` frame and
	/// `Err(error)` as an `Error` frame, both echoing the request id.
	OneShot(Result<serde_json::Value, WireError>),
	/// Streaming reply: the server repeatedly awaits
	/// [`EventStream::next_event`], writing each `Some` payload as an
	/// `Event` frame; `None` ends the stream with an `End` frame. Once a
	/// stream starts, no further requests are read on that connection.
	Stream(Box<dyn EventStream + Send>),
}
49
/// Event source backing a streaming verb.
///
/// The server keeps awaiting `next_event` until it yields `None`, writing
/// an event frame fails (e.g. the client went away), or shutdown is
/// requested. On shutdown the in-flight `next_event` future is dropped
/// before it completes, and the stream itself is dropped afterwards.
#[async_trait]
pub trait EventStream: Send {
	/// `Some(event)` = next payload, written as an `Event { event }` frame.
	/// `None` = normal termination; the server then writes an `End` frame.
	async fn next_event(&mut self) -> Option<serde_json::Value>;
}
58
59/// RAII guard that tightens the process's file-mode-creation mask
60/// (`umask`) for the duration of a scope and restores the prior value
61/// on drop. Used around `UnixListener::bind` so a permissive operator
62/// umask cannot widen the perms of the freshly-created socket file.
63///
64/// `umask(2)` is process-global, so any concurrent file creation in
65/// other tasks while this guard is alive will also see the tightened
66/// mask. For mgmt-socket bind this window is sub-millisecond and we
67/// hold no other I/O off the critical path.
68struct UmaskRestore {
69	prev: libc::mode_t,
70}
71
72impl UmaskRestore {
73	#[allow(unsafe_code)] // libc::umask is FFI; thread-safe POSIX call with no preconditions.
74	fn tighten(mask: libc::mode_t) -> Self {
75		// SAFETY: `umask` is a thread-safe POSIX call with no
76		// preconditions. The return value is the previous mask.
77		let prev = unsafe { libc::umask(mask) };
78		Self { prev }
79	}
80}
81
82impl Drop for UmaskRestore {
83	#[allow(unsafe_code)] // libc::umask is FFI; see `tighten`.
84	fn drop(&mut self) {
85		// SAFETY: see `tighten`. Restoration is best-effort: there is
86		// nothing useful to do if the kernel rejects the value (it
87		// cannot — `umask` accepts any `mode_t`).
88		unsafe {
89			libc::umask(self.prev);
90		}
91	}
92}
93
94/// Bind a Unix socket and serve mgmt requests until `cancel` fires.
95///
96/// On bind, an existing socket file at `socket_path` is unlinked first
97/// — operators are responsible for ensuring no other `vaned` is using
98/// the path. The socket file's mode is set to `0600`: mgmt access is
99/// gated by file-system permissions only, no in-band auth.
100///
101/// On cancellation, the bound socket file is removed before the task
102/// returns so a subsequent `vaned` boot can re-bind cleanly.
103///
104/// # Errors
105/// Bind / chmod / remove-stale-file failures bubble up as
106/// [`std::io::Error`].
107pub async fn spawn_unix_server<H: Handler>(
108	socket_path: &Path,
109	handler: Arc<H>,
110	cancel: CancellationToken,
111) -> std::io::Result<JoinHandle<()>> {
112	// Unlink any stale socket file before bind. systemd-style socket
113	// activation is not supported this round.
114	let _ = std::fs::remove_file(socket_path);
115
116	// Tighten the inherited umask BEFORE bind so the kernel creates
117	// the socket file with restrictive perms (0o660 modulo umask =
118	// 0o600). Restore the previous umask via RAII regardless of bind
119	// outcome so the daemon's other I/O paths see their original
120	// settings.
121	let _umask_restore = UmaskRestore::tighten(0o117);
122
123	let listener = UnixListener::bind(socket_path)?;
124
125	// Belt-and-suspenders: fchmod the socket to 0600 explicitly. The
126	// umask path covers the bind-side race; this covers operators
127	// running with permissive umasks (`077`) where the kernel would
128	// have created the socket more permissively than we want.
129	let perms = std::fs::Permissions::from_mode(0o600);
130	std::fs::set_permissions(socket_path, perms)?;
131
132	// Best-effort: warn when the socket's parent directory is more
133	// permissive than the operator probably intends. A 0o755 parent
134	// dir means any local user can `stat` the socket; a 0o777 parent
135	// can unlink it. Both are footguns on multi-tenant hosts.
136	if let Some(parent) = socket_path.parent()
137		&& let Ok(meta) = std::fs::metadata(parent)
138	{
139		let mode = meta.permissions().mode() & 0o777;
140		if mode != 0o700 && mode != 0o770 {
141			tracing::warn!(
142				dir = %parent.display(),
143				mode = format!("{:#o}", mode),
144				"mgmt socket parent dir is broader than 0700/0770; restrict perms or move the socket",
145			);
146		}
147	}
148
149	let socket_path: PathBuf = socket_path.to_path_buf();
150	let handle = tokio::spawn(async move {
151		loop {
152			tokio::select! {
153				biased;
154				() = cancel.cancelled() => {
155					let _ = std::fs::remove_file(&socket_path);
156					return;
157				}
158				accepted = listener.accept() => {
159					let stream: UnixStream = match accepted {
160						Ok((s, _)) => s,
161						Err(e) => {
162							tracing::warn!(?e, "mgmt accept failed");
163							continue;
164						}
165					};
166					let h = Arc::clone(&handler);
167					// Each per-connection driver gets a child token so
168					// shutdown drives every in-flight verb / stream to
169					// exit cleanly instead of leaving them blocked on
170					// the read side of the socket.
171					let conn_cancel = cancel.child_token();
172					tokio::spawn(async move {
173						let (read, write) = stream.into_split();
174						handle_conn(read, write, h, conn_cancel).await;
175					});
176				}
177			}
178		}
179	});
180	Ok(handle)
181}
182
183/// Read a single NDJSON line with a hard byte cap. Returns `Ok(None)`
184/// on clean EOF; `Ok(Some(_))` with a populated buffer when a line
185/// terminator is seen; and the dedicated [`std::io::ErrorKind::FileTooLarge`]
186/// when the cap is exceeded before a newline arrives.
187async fn read_line_bounded<R>(
188	reader: &mut BufReader<R>,
189	buf: &mut String,
190	cap: usize,
191) -> std::io::Result<Option<()>>
192where
193	R: AsyncRead + Unpin,
194{
195	buf.clear();
196	let start_len = buf.len();
197	loop {
198		let prev_len = buf.len();
199		let n = reader.read_line(buf).await?;
200		if n == 0 {
201			// Clean EOF; return None if nothing buffered, else propagate
202			// whatever the peer flushed without a trailing newline.
203			return if buf.len() == start_len { Ok(None) } else { Ok(Some(())) };
204		}
205		// Strip the trailing newline so callers don't need to.
206		if buf.ends_with('\n') {
207			buf.pop();
208			if buf.ends_with('\r') {
209				buf.pop();
210			}
211			// Cap-check on the post-strip length so a single huge
212			// read that includes the terminator still fails closed.
213			if buf.len() > cap {
214				return Err(std::io::Error::new(
215					std::io::ErrorKind::InvalidData,
216					format!("ndjson line exceeded {cap}-byte cap"),
217				));
218			}
219			return Ok(Some(()));
220		}
221		// No newline read yet — bail if we'd exceed the per-line cap
222		// before the peer flushes a terminator.
223		if buf.len() > cap {
224			return Err(std::io::Error::new(
225				std::io::ErrorKind::InvalidData,
226				format!("ndjson line exceeded {cap}-byte cap"),
227			));
228		}
229		// Keep going — `read_line` can chunk on internal-buffer
230		// boundaries even when more bytes are inbound.
231		if buf.len() == prev_len + n && n == 0 {
232			return Ok(Some(()));
233		}
234	}
235}
236
237/// Generic request loop, abstract over the read/write halves so unit
238/// tests can drive it with `tokio::io::duplex` instead of a real Unix
239/// socket. Production callers always pass the halves of a
240/// [`tokio::net::UnixStream`].
241pub(crate) async fn handle_conn<R, W, H>(
242	read: R,
243	mut write: W,
244	handler: Arc<H>,
245	cancel: CancellationToken,
246) where
247	R: AsyncRead + Unpin,
248	W: AsyncWrite + Unpin,
249	H: Handler,
250{
251	let mut reader = BufReader::new(read);
252	let mut line = String::new();
253	loop {
254		// Select against the per-connection cancel token so a server-
255		// wide shutdown drives every blocked read off the socket
256		// instead of leaving the driver parked on `read_line`.
257		let read_outcome = tokio::select! {
258			biased;
259			() = cancel.cancelled() => return,
260			res = read_line_bounded(&mut reader, &mut line, MAX_NDJSON_LINE_BYTES) => res,
261		};
262		match read_outcome {
263			Ok(None) => return,
264			Ok(Some(())) => {}
265			Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
266				// Oversized line: write a structured error and close —
267				// don't keep reading on a session whose framing is
268				// already off the rails.
269				let frame = Response {
270					id: 0,
271					outcome: ResponseOutcome::Error {
272						error: WireError::new(WireErrorKind::BadArgs, format!("line too long: {e}")),
273					},
274				};
275				let _ = write_frame(&mut write, &frame).await;
276				return;
277			}
278			Err(e) => {
279				tracing::debug!(?e, "mgmt read failed");
280				return;
281			}
282		}
283		if line.is_empty() {
284			continue;
285		}
286		match serde_json::from_str::<Request>(&line) {
287			Ok(req) => {
288				let id = req.id;
289				match handler.dispatch(req).await {
290					DispatchOutcome::OneShot(Ok(value)) => {
291						let frame = Response { id, outcome: ResponseOutcome::Result { result: value } };
292						if write_frame(&mut write, &frame).await.is_err() {
293							return;
294						}
295					}
296					DispatchOutcome::OneShot(Err(error)) => {
297						let frame = Response { id, outcome: ResponseOutcome::Error { error } };
298						if write_frame(&mut write, &frame).await.is_err() {
299							return;
300						}
301					}
302					DispatchOutcome::Stream(mut stream) => {
303						// Streaming verbs consume the connection — once we
304						// start streaming we don't read more requests on
305						// this socket. Cancel-on-shutdown drives every
306						// `next_event` off so a daemon-wide stop trip
307						// flushes an `End` frame and unblocks the client.
308						loop {
309							tokio::select! {
310								biased;
311								() = cancel.cancelled() => {
312									let end = Response {
313										id,
314										outcome: ResponseOutcome::End { end: EndMarker::default() },
315									};
316									let _ = write_frame(&mut write, &end).await;
317									return;
318								}
319								maybe = stream.next_event() => {
320									let Some(event) = maybe else {
321										let end = Response {
322											id,
323											outcome: ResponseOutcome::End { end: EndMarker::default() },
324										};
325										let _ = write_frame(&mut write, &end).await;
326										return;
327									};
328									let frame = Response { id, outcome: ResponseOutcome::Event { event } };
329									if write_frame(&mut write, &frame).await.is_err() {
330										return;
331									}
332								}
333							}
334						}
335					}
336				}
337			}
338			Err(e) => {
339				let frame = Response {
340					// id is unknown when the frame fails to parse — `0` is
341					// the documented sentinel for "no correlation possible".
342					id: 0,
343					outcome: ResponseOutcome::Error {
344						error: WireError::new(WireErrorKind::BadArgs, format!("parse: {e}")),
345					},
346				};
347				if write_frame(&mut write, &frame).await.is_err() {
348					return;
349				}
350			}
351		}
352	}
353}
354
355/// Encode a response and write it as one NDJSON line. Wraps the two
356/// fallible sub-steps (encode → write) so the streaming loop has a
357/// single error path.
358async fn write_frame<W: AsyncWrite + Unpin>(
359	write: &mut W,
360	frame: &Response,
361) -> Result<(), std::io::Error> {
362	let bytes = match encode_line(frame) {
363		Ok(b) => b,
364		Err(e) => {
365			tracing::error!(?e, "mgmt response encode failed");
366			return Err(std::io::Error::other(e));
367		}
368	};
369	write.write_all(&bytes).await
370}
371
#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::Mutex;

	struct StubHandler {
		// Records the last verb seen, for assertions.
		last_verb: Mutex<Option<String>>,
	}

	#[async_trait]
	impl Handler for StubHandler {
		async fn dispatch(&self, req: Request) -> DispatchOutcome {
			*self.last_verb.lock().unwrap() = Some(req.verb.clone());
			let result: Result<serde_json::Value, WireError> = match req.verb.as_str() {
				"ping" => Ok(serde_json::json!({ "pong": true })),
				"echo" => Ok(req.args),
				"stream2" => {
					return DispatchOutcome::Stream(Box::new(MockStream::with_two_events()));
				}
				_ => Err(WireError::new(WireErrorKind::UnknownVerb, format!("unknown {}", req.verb))),
			};
			DispatchOutcome::OneShot(result)
		}
	}

	/// Trivial event stream: emits `{"n": 1}` then `{"n": 2}`, then
	/// terminates with `None` — the smallest possible streaming verb.
	struct MockStream {
		// Pending events stored last-first, because `Vec::pop` takes from
		// the back.
		remaining: Vec<serde_json::Value>,
	}

	impl MockStream {
		fn with_two_events() -> Self {
			// Reversed so the client observes the natural order n=1, n=2.
			Self { remaining: vec![serde_json::json!({ "n": 2 }), serde_json::json!({ "n": 1 })] }
		}
	}

	#[async_trait]
	impl EventStream for MockStream {
		async fn next_event(&mut self) -> Option<serde_json::Value> {
			self.remaining.pop()
		}
	}

	/// Pump one or more request lines through `handle_conn` against a
	/// stub handler, returning the response bytes the server wrote.
	async fn drive(handler: Arc<StubHandler>, requests: &str) -> Vec<u8> {
		// Client writes requests on `c2s_w` → server reads on `c2s_r`.
		// Server writes responses on `s2c_w` → client reads on `s2c_r`.
		let (c2s_r, mut c2s_w) = tokio::io::duplex(8192);
		let (s2c_w, mut s2c_r) = tokio::io::duplex(8192);
		let server_task = tokio::spawn(handle_conn(c2s_r, s2c_w, handler, CancellationToken::new()));
		match c2s_w.write_all(requests.as_bytes()).await {
			Ok(()) => {}
			// The server may legitimately hang up before consuming all input
			// (e.g. after rejecting an oversized line).
			Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {}
			Err(e) => panic!("write requests: {e}"),
		}
		// Closing the write half makes `read_line_bounded` report EOF
		// (`Ok(None)`) on the server side so the task completes cleanly.
		drop(c2s_w);
		server_task.await.expect("server task");
		// Read everything the server wrote; the server's write half was
		// dropped with the task, so this terminates at EOF.
		let mut buf = Vec::new();
		tokio::io::AsyncReadExt::read_to_end(&mut s2c_r, &mut buf).await.expect("read responses");
		buf
	}

	fn parse_responses(bytes: &[u8]) -> Vec<Response> {
		std::str::from_utf8(bytes)
			.expect("utf8")
			.lines()
			.filter(|l| !l.is_empty())
			.map(|l| serde_json::from_str(l).expect("parse response"))
			.collect()
	}

	#[tokio::test]
	async fn server_stub_dispatches_known_verb_and_writes_result_line() {
		let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
		let req = Request { id: 11, verb: "ping".to_string(), args: serde_json::Value::Null };
		let raw = serde_json::to_string(&req).unwrap() + "\n";
		let bytes = drive(Arc::clone(&handler), &raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		assert_eq!(responses[0].id, 11);
		match &responses[0].outcome {
			ResponseOutcome::Result { result } => assert_eq!(result["pong"], true),
			other => panic!("unexpected outcome: {other:?}"),
		}
		assert_eq!(handler.last_verb.lock().unwrap().as_deref(), Some("ping"));
	}

	#[tokio::test]
	async fn server_stub_writes_error_for_unknown_verb() {
		let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
		let req = Request { id: 5, verb: "wat".to_string(), args: serde_json::Value::Null };
		let raw = serde_json::to_string(&req).unwrap() + "\n";
		let bytes = drive(handler, &raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		assert_eq!(responses[0].id, 5);
		match &responses[0].outcome {
			ResponseOutcome::Error { error } => {
				assert_eq!(error.kind, WireErrorKind::UnknownVerb);
				assert!(error.message.contains("wat"));
			}
			other => panic!("expected error, got {other:?}"),
		}
	}

	#[tokio::test]
	async fn server_stub_writes_bad_args_error_for_unparseable_request() {
		let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
		let raw = "this is not json\n";
		let bytes = drive(handler, raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		// id must be the documented 0 sentinel — there's no parsed id to echo.
		assert_eq!(responses[0].id, 0);
		match &responses[0].outcome {
			ResponseOutcome::Error { error } => assert_eq!(error.kind, WireErrorKind::BadArgs),
			other => panic!("expected error, got {other:?}"),
		}
	}

	#[tokio::test]
	async fn server_dispatches_streaming_verb_writes_event_then_end() {
		let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
		let req = Request { id: 99, verb: "stream2".to_string(), args: serde_json::Value::Null };
		let raw = serde_json::to_string(&req).unwrap() + "\n";
		let bytes = drive(handler, &raw).await;
		let responses = parse_responses(&bytes);
		// 2 events + 1 end = 3 frames, all carrying id=99.
		assert_eq!(responses.len(), 3, "two events plus a terminating End frame");
		for r in &responses {
			assert_eq!(r.id, 99, "streaming frames echo the request id");
		}
		// Exact event payloads, in emission order. A non-Event frame here
		// fails loudly instead of silently skipping the payload check.
		let payloads: Vec<&serde_json::Value> = responses[..2]
			.iter()
			.map(|r| match &r.outcome {
				ResponseOutcome::Event { event } => event,
				other => panic!("expected event, got {other:?}"),
			})
			.collect();
		assert_eq!(payloads[0]["n"], 1);
		assert_eq!(payloads[1]["n"], 2);
		assert!(matches!(responses[2].outcome, ResponseOutcome::End { .. }));
	}

	#[tokio::test]
	async fn server_rejects_line_exceeding_cap_with_bad_args() {
		let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
		// Synthesise a line longer than MAX_NDJSON_LINE_BYTES. The
		// server must abort the connection with a `BadArgs` frame and not
		// let the request reach the dispatcher.
		let huge_line = format!(
			"{{\"id\":1,\"verb\":\"x\",\"args\":\"{}\"}}\n",
			"A".repeat(MAX_NDJSON_LINE_BYTES + 1)
		);
		let bytes = drive(Arc::clone(&handler), &huge_line).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		match &responses[0].outcome {
			ResponseOutcome::Error { error } => {
				assert_eq!(error.kind, WireErrorKind::BadArgs);
				assert!(error.message.contains("line too long"), "{}", error.message);
			}
			other => panic!("expected BadArgs error, got {other:?}"),
		}
		// Dispatcher never saw the request: handler's last_verb still None.
		assert!(handler.last_verb.lock().unwrap().is_none());
	}

	#[tokio::test]
	async fn server_stub_handles_multiple_requests_serial_per_connection() {
		let handler = Arc::new(StubHandler { last_verb: Mutex::new(None) });
		let r1 =
			serde_json::to_string(&Request { id: 1, verb: "ping".into(), args: serde_json::Value::Null })
				.unwrap();
		let r2 = serde_json::to_string(&Request {
			id: 2,
			verb: "echo".into(),
			args: serde_json::json!({"x": 1}),
		})
		.unwrap();
		let r3 =
			serde_json::to_string(&Request { id: 3, verb: "nope".into(), args: serde_json::Value::Null })
				.unwrap();
		let raw = format!("{r1}\n{r2}\n\n{r3}\n");
		let bytes = drive(handler, &raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 3, "blank line is skipped, not echoed back");
		assert_eq!(responses[0].id, 1);
		assert_eq!(responses[1].id, 2);
		assert_eq!(responses[2].id, 3);
		assert!(matches!(responses[0].outcome, ResponseOutcome::Result { .. }));
		assert!(matches!(responses[1].outcome, ResponseOutcome::Result { .. }));
		assert!(matches!(responses[2].outcome, ResponseOutcome::Error { .. }));
	}
}