Skip to main content

ndjson_rpc/
http_server.rs

1//! HTTP-over-TCP transport for the same NDJSON frame shapes the Unix
2//! socket [`crate::server`] speaks.
3//!
4//! Plaintext HTTP/1.1 only — TLS termination is the operator's concern.
5//!
6//! Wire shape:
7//! - request: `POST /` with a JSON body matching [`Request`]; any other
8//!   method or path returns `405` / `404`.
9//! - one-shot reply: `200 OK` + `Content-Type: application/json` + a
10//!   single [`Response`] body.
11//! - streaming reply: `200 OK` + `Content-Type: application/x-ndjson` +
12//!   one JSON [`Response`] frame per chunk, terminated by an `End`
13//!   frame. The client cancels by closing the TCP connection.
14//!
15//! Auth: `Authorization: Bearer <token>`, constant-time compared
16//! against the configured token via the `subtle` crate. Boot
17//! validation (e.g. "anonymous access only on loopback") lives in the
18//! caller.
19
20use std::net::SocketAddr;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use bytes::Bytes;
26use http_body_util::{BodyExt, Full, Limited, combinators::BoxBody};
27use hyper::body::{Body, Frame, Incoming};
28use hyper::header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE};
29use hyper::service::service_fn;
30use hyper::{HeaderMap, Method, StatusCode};
31use hyper_util::rt::TokioIo;
32use tokio::net::TcpListener;
33use tokio::sync::mpsc;
34use tokio::task::JoinHandle;
35use tokio_util::sync::CancellationToken;
36
37use crate::protocol::{EndMarker, Request, Response, ResponseOutcome, encode_line};
38use crate::server::{DispatchOutcome, Handler};
39
/// Upper bound, in bytes, on an accepted request body. Management
/// requests are small (a verb plus an argument blob), so 1 MiB leaves
/// ample headroom while still refusing a pathological client before it
/// can tie up memory.
const MAX_REQUEST_BODY_BYTES: usize = 1 << 20;

/// Capacity of the per-stream frame channel. Every slot carries one
/// NDJSON line that has already been encoded. A slow reader throttles
/// the producer end to end: the TCP send buffer fills, hyper stops
/// pulling from the body, the channel fills, and the producer's `send`
/// waits.
const STREAM_CHANNEL_DEPTH: usize = 64;
50
/// Settings for the HTTP management transport.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
	/// Socket addresses to listen on, one accept loop each. An empty
	/// list means the HTTP transport is off, and [`spawn_http_server`]
	/// is not expected to be called at all.
	pub binds: Vec<SocketAddr>,
	/// Bearer token clients must present. `None` disables
	/// authentication entirely; deciding when that is acceptable (e.g.
	/// loopback-only binds) is left to the caller.
	pub bearer_token: Option<Arc<str>>,
}
61
62#[derive(thiserror::Error, Debug)]
63pub enum HttpServerError {
64	#[error("ndjson-rpc http: bind {addr} failed: {source}")]
65	Bind { addr: SocketAddr, source: std::io::Error },
66}
67
68/// Spawn one accept loop per bind address. Returns the spawned task
69/// handles; each task runs until `cancel` fires or the listener errors
70/// fatally.
71///
72/// # Errors
73/// On the first bind failure, returns the error and aborts any tasks
74/// spawned for earlier (already-bound) addresses so the daemon does not
75/// end up serving a partial bind set.
76pub async fn spawn_http_server<H: Handler>(
77	cfg: HttpServerConfig,
78	handler: Arc<H>,
79	cancel: CancellationToken,
80) -> Result<Vec<JoinHandle<()>>, HttpServerError> {
81	let mut tasks: Vec<JoinHandle<()>> = Vec::with_capacity(cfg.binds.len());
82	for addr in &cfg.binds {
83		let listener = match TcpListener::bind(addr).await {
84			Ok(l) => l,
85			Err(source) => {
86				// Roll back any earlier successful binds. Cancellation
87				// is the contract; we honor it for partial-failure too.
88				for t in &tasks {
89					t.abort();
90				}
91				return Err(HttpServerError::Bind { addr: *addr, source });
92			}
93		};
94		let handler = Arc::clone(&handler);
95		let cancel = cancel.clone();
96		let token = cfg.bearer_token.clone();
97		let bind_addr = *addr;
98		tasks.push(tokio::spawn(async move {
99			run_accept_loop(listener, handler, token, cancel, bind_addr).await;
100		}));
101	}
102	Ok(tasks)
103}
104
105async fn run_accept_loop<H: Handler>(
106	listener: TcpListener,
107	handler: Arc<H>,
108	token: Option<Arc<str>>,
109	cancel: CancellationToken,
110	bind_addr: SocketAddr,
111) {
112	tracing::info!(%bind_addr, auth = if token.is_some() { "bearer" } else { "anonymous" }, "mgmt http listening");
113	loop {
114		tokio::select! {
115			biased;
116			() = cancel.cancelled() => return,
117			res = listener.accept() => {
118				let (stream, peer) = match res {
119					Ok(v) => v,
120					Err(e) => {
121						tracing::debug!(?e, %bind_addr, "mgmt http accept error");
122						continue;
123					}
124				};
125				let handler = Arc::clone(&handler);
126				let token = token.clone();
127				tokio::spawn(async move {
128					let io = TokioIo::new(stream);
129					let svc = service_fn(move |req| {
130						let handler = Arc::clone(&handler);
131						let token = token.clone();
132						async move { handle_request(req, handler, token, peer).await }
133					});
134					if let Err(e) = hyper::server::conn::http1::Builder::new()
135						.serve_connection(io, svc)
136						.await
137					{
138						tracing::debug!(?e, %peer, "mgmt http connection ended");
139					}
140				});
141			}
142		}
143	}
144}
145
146type RespBody = BoxBody<Bytes, std::io::Error>;
147
148async fn handle_request<H: Handler>(
149	req: hyper::Request<Incoming>,
150	handler: Arc<H>,
151	token: Option<Arc<str>>,
152	_peer: SocketAddr,
153) -> Result<hyper::Response<RespBody>, std::convert::Infallible> {
154	// Method / path gating happens before auth so a misrouted client
155	// gets a deterministic 4xx instead of an auth failure that masks
156	// the real problem.
157	if req.uri().path() != "/" {
158		return Ok(simple_status(StatusCode::NOT_FOUND));
159	}
160	if req.method() != Method::POST {
161		return Ok(simple_status(StatusCode::METHOD_NOT_ALLOWED));
162	}
163	if let Some(expected) = &token
164		&& !verify_bearer(req.headers(), expected)
165	{
166		return Ok(unauthorized());
167	}
168	let body_bytes = match read_request_body(req.into_body()).await {
169		Ok(b) => b,
170		Err(BodyReadError::TooLarge) => {
171			return Ok(text_status(
172				StatusCode::PAYLOAD_TOO_LARGE,
173				"request body exceeds management transport limit",
174			));
175		}
176		Err(BodyReadError::Io(e)) => {
177			return Ok(text_status(StatusCode::BAD_REQUEST, &format!("body read failed: {e}")));
178		}
179	};
180	let request = match serde_json::from_slice::<Request>(&body_bytes) {
181		Ok(r) => r,
182		Err(e) => return Ok(text_status(StatusCode::BAD_REQUEST, &format!("json parse: {e}"))),
183	};
184	let id = request.id;
185	match handler.dispatch(request).await {
186		DispatchOutcome::OneShot(Ok(value)) => {
187			Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Result { result: value } }))
188		}
189		DispatchOutcome::OneShot(Err(error)) => {
190			Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Error { error } }))
191		}
192		DispatchOutcome::Stream(stream) => Ok(streaming_response(id, stream)),
193	}
194}
195
196/// Constant-time bearer-token check.
197///
198/// `subtle::ConstantTimeEq` runs in time independent of where the
199/// mismatch is, defeating timing-side-channel guesses against the
200/// token. A length mismatch short-circuits to `false` but still touches
201/// the expected slice once so the call shape stays uniform across
202/// equal- and unequal-length inputs.
203fn verify_bearer(headers: &HeaderMap, expected: &Arc<str>) -> bool {
204	use subtle::ConstantTimeEq;
205	let Some(value) = headers.get(AUTHORIZATION) else {
206		return false;
207	};
208	let Ok(s) = value.to_str() else { return false };
209	let Some(token) = s.strip_prefix("Bearer ") else { return false };
210	let exp = expected.as_bytes();
211	let got = token.as_bytes();
212	if exp.len() != got.len() {
213		// Touch the expected slice so the early-exit branch still does
214		// the same work as a length-equal compare (defence in depth —
215		// the length is recoverable from network framing anyway, but
216		// keep the codepath uniform).
217		let _ = exp.ct_eq(exp);
218		return false;
219	}
220	bool::from(exp.ct_eq(got))
221}
222
/// Why a request body could not be read in full.
enum BodyReadError {
	/// More than `MAX_REQUEST_BODY_BYTES` arrived.
	TooLarge,
	/// Any other read failure, carried as its rendered message.
	Io(String),
}
227
228async fn read_request_body(body: Incoming) -> Result<Bytes, BodyReadError> {
229	let limited = Limited::new(body, MAX_REQUEST_BODY_BYTES);
230	match limited.collect().await {
231		Ok(c) => Ok(c.to_bytes()),
232		Err(e) => {
233			// `Limited` boxes the underlying error; we discriminate
234			// "too large" from "io" by downcasting to `LengthLimitError`.
235			if e.downcast_ref::<http_body_util::LengthLimitError>().is_some() {
236				Err(BodyReadError::TooLarge)
237			} else {
238				Err(BodyReadError::Io(e.to_string()))
239			}
240		}
241	}
242}
243
244fn oneshot_response(frame: &Response) -> hyper::Response<RespBody> {
245	let body_bytes = match serde_json::to_vec(frame) {
246		Ok(b) => Bytes::from(b),
247		Err(e) => {
248			tracing::error!(?e, "mgmt http oneshot encode failed");
249			return text_status(StatusCode::INTERNAL_SERVER_ERROR, "encode failed");
250		}
251	};
252	build_response(StatusCode::OK, "application/json", full_body(body_bytes))
253}
254
255fn streaming_response(
256	id: u64,
257	mut stream: Box<dyn crate::server::EventStream + Send>,
258) -> hyper::Response<RespBody> {
259	// Channel decouples the stream producer task from hyper's body
260	// poll loop. When hyper drops the body (client disconnect or
261	// connection error) the receiver drops, the next `tx.send` fails,
262	// and the producer task terminates — which drops `stream`,
263	// triggering the EventStream's own cleanup. That is the
264	// streaming-verb cancellation contract.
265	let (tx, rx) = mpsc::channel::<Bytes>(STREAM_CHANNEL_DEPTH);
266	tokio::spawn(async move {
267		loop {
268			let Some(event) = stream.next_event().await else {
269				let end = Response { id, outcome: ResponseOutcome::End { end: EndMarker::default() } };
270				if let Ok(bytes) = encode_line(&end) {
271					let _ = tx.send(Bytes::from(bytes)).await;
272				}
273				return;
274			};
275			let frame = Response { id, outcome: ResponseOutcome::Event { event } };
276			let bytes = match encode_line(&frame) {
277				Ok(b) => Bytes::from(b),
278				Err(e) => {
279					tracing::error!(?e, id, "mgmt http stream encode failed");
280					return;
281				}
282			};
283			if tx.send(bytes).await.is_err() {
284				// Client disconnected; drop the stream and exit.
285				return;
286			}
287		}
288	});
289	let body = ChannelBody { rx }.boxed();
290	build_response(StatusCode::OK, "application/x-ndjson", body)
291}
292
293struct ChannelBody {
294	rx: mpsc::Receiver<Bytes>,
295}
296
297impl Body for ChannelBody {
298	type Data = Bytes;
299	type Error = std::io::Error;
300
301	fn poll_frame(
302		mut self: Pin<&mut Self>,
303		cx: &mut Context<'_>,
304	) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
305		match self.rx.poll_recv(cx) {
306			Poll::Ready(Some(b)) => Poll::Ready(Some(Ok(Frame::data(b)))),
307			Poll::Ready(None) => Poll::Ready(None),
308			Poll::Pending => Poll::Pending,
309		}
310	}
311}
312
313fn build_response(
314	status: StatusCode,
315	content_type: &'static str,
316	body: RespBody,
317) -> hyper::Response<RespBody> {
318	let mut resp = hyper::Response::new(body);
319	*resp.status_mut() = status;
320	resp.headers_mut().insert(CONTENT_TYPE, content_type.parse().expect("static content type"));
321	resp
322}
323
324fn full_body(bytes: Bytes) -> RespBody {
325	Full::new(bytes).map_err(|never: std::convert::Infallible| match never {}).boxed()
326}
327
328fn simple_status(status: StatusCode) -> hyper::Response<RespBody> {
329	let mut resp = hyper::Response::new(full_body(Bytes::new()));
330	*resp.status_mut() = status;
331	resp
332}
333
334fn text_status(status: StatusCode, body: &str) -> hyper::Response<RespBody> {
335	let mut resp = hyper::Response::new(full_body(Bytes::copy_from_slice(body.as_bytes())));
336	*resp.status_mut() = status;
337	resp
338		.headers_mut()
339		.insert(CONTENT_TYPE, "text/plain; charset=utf-8".parse().expect("static content type"));
340	resp
341}
342
343fn unauthorized() -> hyper::Response<RespBody> {
344	let mut resp = simple_status(StatusCode::UNAUTHORIZED);
345	resp.headers_mut().insert(WWW_AUTHENTICATE, "Bearer".parse().expect("static auth scheme"));
346	resp
347}
348
#[cfg(test)]
mod tests {
	use super::*;

	/// Run `verify_bearer` against the fixed token `s3cret`. `auth` is
	/// the `Authorization` header value; `None` leaves the header out.
	fn check(auth: Option<&str>) -> bool {
		let expected: Arc<str> = Arc::from("s3cret");
		let mut headers = HeaderMap::new();
		if let Some(value) = auth {
			headers.insert(AUTHORIZATION, value.parse().expect("valid header"));
		}
		verify_bearer(&headers, &expected)
	}

	#[test]
	fn verify_bearer_accepts_correct_token() {
		assert!(check(Some("Bearer s3cret")));
	}

	#[test]
	fn verify_bearer_rejects_wrong_token() {
		assert!(!check(Some("Bearer wrongx")));
	}

	#[test]
	fn verify_bearer_rejects_missing_header() {
		assert!(!check(None));
	}

	#[test]
	fn verify_bearer_rejects_non_bearer_scheme() {
		assert!(!check(Some("Basic dXNlcjpwYXNz")));
	}

	#[test]
	fn verify_bearer_rejects_length_mismatch_without_panic() {
		// Both shorter and longer candidates must be refused without a
		// panic, whether or not they share a prefix with the real token.
		assert!(!check(Some("Bearer s3")));
		assert!(!check(Some("Bearer s3cretextra")));
	}

	#[test]
	fn verify_bearer_rejects_empty_token_value() {
		assert!(!check(Some("Bearer ")));
	}
}