Skip to main content

ndjson_rpc/
server.rs

1//! Unix-socket accept loop + per-connection line-delimited JSON
2//! dispatch to verb handlers. The HTTP-over-TCP transport
3//! ([`crate::http_server`]) speaks the same frame shapes, so dispatch
4//! logic is shared.
5
6use std::os::unix::fs::PermissionsExt;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::task::JoinHandle;
14use tokio_util::sync::CancellationToken;
15
16use crate::protocol::{
17	EndMarker, Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line,
18};
19
20/// Server-side dispatcher. Callers implement this against their own
21/// application state and pass an `Arc<H>` to [`spawn_unix_server`] or
22/// [`crate::spawn_http_server`].
23#[async_trait]
24pub trait Handler: Send + Sync + 'static {
25	/// Dispatch a parsed request to either a one-shot result or a
26	/// streaming event source. The server frames whichever outcome the
27	/// handler returns.
28	async fn dispatch(&self, req: Request) -> DispatchOutcome;
29}
30
31/// What `dispatch` returns. One-shot verbs (`ping`, `stats`, ...)
32/// produce a single result/error frame; streaming verbs
33/// (`tail_flow`) produce a sequence of `Event` frames terminated
34/// by an `End` frame.
35pub enum DispatchOutcome {
36	/// One-shot reply: a single JSON value or a structured error.
37	OneShot(Result<serde_json::Value, WireError>),
38	/// Streaming reply: each call to `next_event` yields the next
39	/// `Event` payload, or `None` to terminate with an `End` frame.
40	Stream(Box<dyn EventStream + Send>),
41}
42
43/// A streaming event source. The server polls `next_event` until the
44/// client disconnects or the stream returns `None`.
45#[async_trait]
46pub trait EventStream: Send {
47	/// `Some(event)` = next event payload to write as `Event { event }`.
48	/// `None` = stream terminated normally; the server writes `End`.
49	async fn next_event(&mut self) -> Option<serde_json::Value>;
50}
51
52/// Bind a Unix socket and serve mgmt requests until `cancel` fires.
53///
54/// On bind, an existing socket file at `socket_path` is unlinked first
55/// — operators are responsible for ensuring no other `vaned` is using
56/// the path. The socket file's mode is set to `0600`: mgmt access is
57/// gated by file-system permissions only, no in-band auth.
58///
59/// On cancellation, the bound socket file is removed before the task
60/// returns so a subsequent `vaned` boot can re-bind cleanly.
61///
62/// # Errors
63/// Bind / chmod / remove-stale-file failures bubble up as
64/// [`std::io::Error`].
65pub async fn spawn_unix_server<H: Handler>(
66	socket_path: &Path,
67	handler: Arc<H>,
68	cancel: CancellationToken,
69) -> std::io::Result<JoinHandle<()>> {
70	// Unlink any stale socket file before bind. systemd-style socket
71	// activation is not supported this round.
72	let _ = std::fs::remove_file(socket_path);
73	let listener = UnixListener::bind(socket_path)?;
74
75	let perms = std::fs::Permissions::from_mode(0o600);
76	std::fs::set_permissions(socket_path, perms)?;
77
78	let socket_path: PathBuf = socket_path.to_path_buf();
79	let handle = tokio::spawn(async move {
80		loop {
81			tokio::select! {
82				biased;
83				() = cancel.cancelled() => {
84					let _ = std::fs::remove_file(&socket_path);
85					return;
86				}
87				accepted = listener.accept() => {
88					let stream: UnixStream = match accepted {
89						Ok((s, _)) => s,
90						Err(e) => {
91							tracing::warn!(?e, "mgmt accept failed");
92							continue;
93						}
94					};
95					let h = Arc::clone(&handler);
96					tokio::spawn(async move {
97						let (read, write) = stream.into_split();
98						handle_conn(read, write, h).await;
99					});
100				}
101			}
102		}
103	});
104	Ok(handle)
105}
106
107/// Generic request loop, abstract over the read/write halves so unit
108/// tests can drive it with `tokio::io::duplex` instead of a real Unix
109/// socket. Production callers always pass the halves of a
110/// [`tokio::net::UnixStream`].
111pub(crate) async fn handle_conn<R, W, H>(read: R, mut write: W, handler: Arc<H>)
112where
113	R: AsyncRead + Unpin,
114	W: AsyncWrite + Unpin,
115	H: Handler,
116{
117	let mut lines = BufReader::new(read).lines();
118	loop {
119		let line = match lines.next_line().await {
120			Ok(Some(l)) => l,
121			Ok(None) => return,
122			Err(e) => {
123				tracing::debug!(?e, "mgmt read failed");
124				return;
125			}
126		};
127		if line.is_empty() {
128			continue;
129		}
130		match serde_json::from_str::<Request>(&line) {
131			Ok(req) => {
132				let id = req.id;
133				match handler.dispatch(req).await {
134					DispatchOutcome::OneShot(Ok(value)) => {
135						let frame = Response { id, outcome: ResponseOutcome::Result { result: value } };
136						if write_frame(&mut write, &frame).await.is_err() {
137							return;
138						}
139					}
140					DispatchOutcome::OneShot(Err(error)) => {
141						let frame = Response { id, outcome: ResponseOutcome::Error { error } };
142						if write_frame(&mut write, &frame).await.is_err() {
143							return;
144						}
145					}
146					DispatchOutcome::Stream(mut stream) => {
147						// Streaming verbs consume the connection — once we
148						// start streaming we don't read more requests on
149						// this socket. Client disconnects by closing the
150						// socket; server detects that via write failure
151						// or the read-side seeing EOF.
152						loop {
153							let Some(event) = stream.next_event().await else {
154								let end =
155									Response { id, outcome: ResponseOutcome::End { end: EndMarker::default() } };
156								let _ = write_frame(&mut write, &end).await;
157								return;
158							};
159							let frame = Response { id, outcome: ResponseOutcome::Event { event } };
160							if write_frame(&mut write, &frame).await.is_err() {
161								return;
162							}
163						}
164					}
165				}
166			}
167			Err(e) => {
168				let frame = Response {
169					// id is unknown when the frame fails to parse — `0` is
170					// the documented sentinel for "no correlation possible".
171					id: 0,
172					outcome: ResponseOutcome::Error {
173						error: WireError { kind: WireErrorKind::BadArgs, message: format!("parse: {e}") },
174					},
175				};
176				if write_frame(&mut write, &frame).await.is_err() {
177					return;
178				}
179			}
180		}
181	}
182}
183
184/// Encode a response and write it as one NDJSON line. Wraps the two
185/// fallible sub-steps (encode → write) so the streaming loop has a
186/// single error path.
187async fn write_frame<W: AsyncWrite + Unpin>(
188	write: &mut W,
189	frame: &Response,
190) -> Result<(), std::io::Error> {
191	let bytes = match encode_line(frame) {
192		Ok(b) => b,
193		Err(e) => {
194			tracing::error!(?e, "mgmt response encode failed");
195			return Err(std::io::Error::other(e));
196		}
197	};
198	write.write_all(&bytes).await
199}
200
#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::Mutex;

	struct StubHandler {
		// Records the last verb seen, for assertions.
		last_verb: Mutex<Option<String>>,
	}

	impl StubHandler {
		fn new() -> Self {
			Self { last_verb: Mutex::new(None) }
		}
	}

	#[async_trait]
	impl Handler for StubHandler {
		async fn dispatch(&self, req: Request) -> DispatchOutcome {
			*self.last_verb.lock().unwrap() = Some(req.verb.clone());
			let result: Result<serde_json::Value, WireError> = match req.verb.as_str() {
				"ping" => Ok(serde_json::json!({ "pong": true })),
				"echo" => Ok(req.args),
				"stream2" => {
					return DispatchOutcome::Stream(Box::new(MockStream::with_two_events()));
				}
				_ => Err(WireError {
					kind: WireErrorKind::UnknownVerb,
					message: format!("unknown {}", req.verb),
				}),
			};
			DispatchOutcome::OneShot(result)
		}
	}

	/// Trivial event stream: emits `{"n":1}` then `{"n":2}`, then
	/// terminates with `None` — the smallest possible streaming verb.
	struct MockStream {
		// Stored back-to-front because `Vec::pop` takes from the end.
		remaining: Vec<serde_json::Value>,
	}

	impl MockStream {
		fn with_two_events() -> Self {
			// Queued in reverse so `pop` hands out n=1 first, then n=2,
			// matching the natural reading order on the wire.
			Self { remaining: vec![serde_json::json!({ "n": 2 }), serde_json::json!({ "n": 1 })] }
		}
	}

	#[async_trait]
	impl EventStream for MockStream {
		async fn next_event(&mut self) -> Option<serde_json::Value> {
			self.remaining.pop()
		}
	}

	/// Pump one or more request lines through `handle_conn` against a
	/// stub handler, returning the response bytes the server wrote.
	async fn drive(handler: Arc<StubHandler>, requests: &str) -> Vec<u8> {
		// Client writes requests on `c2s_w` → server reads on `c2s_r`.
		// Server writes responses on `s2c_w` → client reads on `s2c_r`.
		let (c2s_r, mut c2s_w) = tokio::io::duplex(8192);
		let (s2c_w, mut s2c_r) = tokio::io::duplex(8192);
		let server_task = tokio::spawn(handle_conn(c2s_r, s2c_w, handler));
		c2s_w.write_all(requests.as_bytes()).await.expect("write requests");
		// Closing the write half makes `next_line` return None on the
		// server side so the task completes cleanly.
		drop(c2s_w);
		// Drain responses before joining the server task: the pipe has a
		// fixed capacity, so joining first could deadlock if the server
		// filled it. `read_to_end` finishes once the task drops `s2c_w`.
		let mut buf = Vec::new();
		tokio::io::AsyncReadExt::read_to_end(&mut s2c_r, &mut buf).await.expect("read responses");
		server_task.await.expect("server task");
		buf
	}

	fn parse_responses(bytes: &[u8]) -> Vec<Response> {
		std::str::from_utf8(bytes)
			.expect("utf8")
			.lines()
			.filter(|l| !l.is_empty())
			.map(|l| serde_json::from_str(l).expect("parse response"))
			.collect()
	}

	#[tokio::test]
	async fn server_stub_dispatches_known_verb_and_writes_result_line() {
		let handler = Arc::new(StubHandler::new());
		let req = Request { id: 11, verb: "ping".to_string(), args: serde_json::Value::Null };
		let raw = serde_json::to_string(&req).unwrap() + "\n";
		let bytes = drive(Arc::clone(&handler), &raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		assert_eq!(responses[0].id, 11);
		match &responses[0].outcome {
			ResponseOutcome::Result { result } => assert_eq!(result["pong"], true),
			other => panic!("unexpected outcome: {other:?}"),
		}
		assert_eq!(handler.last_verb.lock().unwrap().as_deref(), Some("ping"));
	}

	#[tokio::test]
	async fn server_stub_writes_error_for_unknown_verb() {
		let handler = Arc::new(StubHandler::new());
		let req = Request { id: 5, verb: "wat".to_string(), args: serde_json::Value::Null };
		let raw = serde_json::to_string(&req).unwrap() + "\n";
		let bytes = drive(handler, &raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		assert_eq!(responses[0].id, 5);
		match &responses[0].outcome {
			ResponseOutcome::Error { error } => {
				assert_eq!(error.kind, WireErrorKind::UnknownVerb);
				assert!(error.message.contains("wat"));
			}
			other => panic!("expected error, got {other:?}"),
		}
	}

	#[tokio::test]
	async fn server_stub_writes_bad_args_error_for_unparseable_request() {
		let handler = Arc::new(StubHandler::new());
		let raw = "this is not json\n";
		let bytes = drive(handler, raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 1);
		// id must be the documented 0 sentinel — there's no parsed id to echo.
		assert_eq!(responses[0].id, 0);
		match &responses[0].outcome {
			ResponseOutcome::Error { error } => assert_eq!(error.kind, WireErrorKind::BadArgs),
			other => panic!("expected error, got {other:?}"),
		}
	}

	#[tokio::test]
	async fn server_dispatches_streaming_verb_writes_event_then_end() {
		let handler = Arc::new(StubHandler::new());
		let req = Request { id: 99, verb: "stream2".to_string(), args: serde_json::Value::Null };
		let raw = serde_json::to_string(&req).unwrap() + "\n";
		let bytes = drive(handler, &raw).await;
		let responses = parse_responses(&bytes);
		// 2 events + 1 end = 3 frames, all carrying id=99.
		assert_eq!(responses.len(), 3, "two events plus a terminating End frame");
		for r in &responses {
			assert_eq!(r.id, 99, "streaming frames echo the request id");
		}
		// Exact event payloads in order; a non-Event frame fails loudly.
		let payloads: Vec<&serde_json::Value> = responses[..2]
			.iter()
			.map(|r| match &r.outcome {
				ResponseOutcome::Event { event } => event,
				other => panic!("expected event, got {other:?}"),
			})
			.collect();
		assert_eq!(payloads[0]["n"], 1);
		assert_eq!(payloads[1]["n"], 2);
		assert!(matches!(responses[2].outcome, ResponseOutcome::End { .. }));
	}

	#[tokio::test]
	async fn server_stub_handles_multiple_requests_serial_per_connection() {
		let handler = Arc::new(StubHandler::new());
		let r1 =
			serde_json::to_string(&Request { id: 1, verb: "ping".into(), args: serde_json::Value::Null })
				.unwrap();
		let r2 = serde_json::to_string(&Request {
			id: 2,
			verb: "echo".into(),
			args: serde_json::json!({"x": 1}),
		})
		.unwrap();
		let r3 =
			serde_json::to_string(&Request { id: 3, verb: "nope".into(), args: serde_json::Value::Null })
				.unwrap();
		let raw = format!("{r1}\n{r2}\n\n{r3}\n");
		let bytes = drive(handler, &raw).await;
		let responses = parse_responses(&bytes);
		assert_eq!(responses.len(), 3, "blank line is skipped, not echoed back");
		assert_eq!(responses[0].id, 1);
		assert_eq!(responses[1].id, 2);
		assert_eq!(responses[2].id, 3);
		assert!(matches!(responses[0].outcome, ResponseOutcome::Result { .. }));
		assert!(matches!(responses[1].outcome, ResponseOutcome::Result { .. }));
		assert!(matches!(responses[2].outcome, ResponseOutcome::Error { .. }));
	}
}