Skip to main content

ndjson_rpc/
http_server.rs

1//! HTTP-over-TCP transport for the same NDJSON frame shapes the Unix
2//! socket [`crate::server`] speaks.
3//!
4//! Plaintext HTTP/1.1 only — TLS termination is the operator's concern.
5//!
6//! Wire shape:
7//! - request: `POST /` with a JSON body matching [`Request`]; any other
8//!   method or path returns `405` / `404`.
9//! - one-shot reply: `200 OK` + `Content-Type: application/json` + a
10//!   single [`Response`] body.
11//! - streaming reply: `200 OK` + `Content-Type: application/x-ndjson` +
12//!   one JSON [`Response`] frame per chunk, terminated by an `End`
13//!   frame. The client cancels by closing the TCP connection.
14//!
15//! Auth: `Authorization: Bearer <token>`, constant-time compared
16//! against the configured token via the `subtle` crate. Boot
17//! validation (e.g. "anonymous access only on loopback") lives in the
18//! caller.
19
20use std::net::SocketAddr;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use std::time::Duration;
25
26use bytes::Bytes;
27use http_body_util::{BodyExt, Full, Limited, combinators::BoxBody};
28use hyper::body::{Body, Frame, Incoming};
29use hyper::header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE};
30use hyper::service::service_fn;
31use hyper::{HeaderMap, Method, StatusCode};
32use hyper_util::rt::{TokioIo, TokioTimer};
33use tokio::net::TcpListener;
34use tokio::sync::{Semaphore, mpsc};
35use tokio::task::JoinHandle;
36use tokio_util::sync::CancellationToken;
37
38use crate::protocol::{EndMarker, Request, Response, ResponseOutcome, encode_line};
39use crate::server::{DispatchOutcome, Handler};
40
/// Byte-size units used to spell out the limits below.
const KIB: usize = 1024;
const MIB: usize = 1024 * KIB;

/// Upper bound on the size of a request body. Management requests are
/// small (a verb plus an argument blob), so 1 MiB leaves plenty of
/// headroom while still letting us turn away pathological clients
/// before they pin RAM.
const MAX_REQUEST_BODY_BYTES: usize = MIB;

/// Capacity of the channel feeding a streaming response. Every slot
/// carries one NDJSON frame that has already been encoded. A slow
/// client exerts backpressure on its own: the TCP buffer fills, hyper
/// stops draining the body, the channel fills, and the producer's
/// `send` awaits.
const STREAM_CHANNEL_DEPTH: usize = 64;

/// Maximum number of live HTTP connections per listener. This is a
/// control plane rather than a data plane, so a handful of concurrent
/// operators is the realistic ceiling. Capping here means a connection
/// flood waits in the OS accept queue (deterministic, 503-like
/// backpressure) instead of making us spawn tasks without bound.
const MAX_CONCURRENT_CONNECTIONS: usize = 64;

/// Upper bound on the HTTP/1 read buffer used for the header section.
/// A management request carries one `Authorization` header plus the
/// framing for a small POST body; 64 KiB is generous and matches the
/// L1 listener cap.
const HTTP1_MAX_BUF_SIZE: usize = 64 * KIB;

/// How long a client gets to deliver the complete header section. A
/// slowloris client that connects and then trickles header bytes is
/// disconnected at this deadline instead of holding a hyper task open
/// indefinitely.
const HTTP1_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(10);
68
/// Settings for the HTTP transport: where to listen and how to
/// authenticate callers.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
	/// Addresses to listen on. An empty list means the HTTP transport
	/// is disabled, in which case the caller should skip
	/// [`spawn_http_server`] altogether.
	pub binds: Vec<SocketAddr>,
	/// Bearer token policy. `Some(token)` requires every request to
	/// present that token; `None` allows anonymous access, which is
	/// typically only safe on loopback. Enforcing that policy is the
	/// caller's responsibility.
	pub bearer_token: Option<Arc<str>>,
}
79
80#[derive(thiserror::Error, Debug)]
81pub enum HttpServerError {
82	#[error("ndjson-rpc http: bind {addr} failed: {source}")]
83	Bind { addr: SocketAddr, source: std::io::Error },
84}
85
86/// Spawn one accept loop per bind address. Returns the spawned task
87/// handles; each task runs until `cancel` fires or the listener errors
88/// fatally.
89///
90/// # Errors
91/// On the first bind failure, returns the error and aborts any tasks
92/// spawned for earlier (already-bound) addresses so the daemon does not
93/// end up serving a partial bind set.
94pub async fn spawn_http_server<H: Handler>(
95	cfg: HttpServerConfig,
96	handler: Arc<H>,
97	cancel: CancellationToken,
98) -> Result<Vec<JoinHandle<()>>, HttpServerError> {
99	let mut tasks: Vec<JoinHandle<()>> = Vec::with_capacity(cfg.binds.len());
100	for addr in &cfg.binds {
101		let listener = match TcpListener::bind(addr).await {
102			Ok(l) => l,
103			Err(source) => {
104				// Roll back any earlier successful binds. Cancellation
105				// is the contract; we honor it for partial-failure too.
106				for t in &tasks {
107					t.abort();
108				}
109				return Err(HttpServerError::Bind { addr: *addr, source });
110			}
111		};
112		let handler = Arc::clone(&handler);
113		let cancel = cancel.clone();
114		let token = cfg.bearer_token.clone();
115		let bind_addr = *addr;
116		tasks.push(tokio::spawn(async move {
117			run_accept_loop(listener, handler, token, cancel, bind_addr).await;
118		}));
119	}
120	Ok(tasks)
121}
122
123async fn run_accept_loop<H: Handler>(
124	listener: TcpListener,
125	handler: Arc<H>,
126	token: Option<Arc<str>>,
127	cancel: CancellationToken,
128	bind_addr: SocketAddr,
129) {
130	tracing::info!(%bind_addr, auth = if token.is_some() { "bearer" } else { "anonymous" }, "mgmt http listening");
131	let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNECTIONS));
132	loop {
133		// Reserve a connection slot BEFORE accept so a flood of
134		// connects backs up in the OS accept queue rather than
135		// causing us to spawn an unbounded number of hyper drivers
136		// that all then have to serialize behind the cap.
137		let permit = tokio::select! {
138			biased;
139			() = cancel.cancelled() => return,
140			p = Arc::clone(&semaphore).acquire_owned() => match p {
141				Ok(p) => p,
142				// Semaphore::close is the only way `acquire_owned`
143				// errors. We never close, so this is unreachable in
144				// practice; treat it as a hard stop to avoid a spin
145				// loop if the invariant ever breaks.
146				Err(_) => return,
147			},
148		};
149
150		tokio::select! {
151			biased;
152			() = cancel.cancelled() => return,
153			res = listener.accept() => {
154				let (stream, peer) = match res {
155					Ok(v) => v,
156					Err(e) => {
157						tracing::debug!(?e, %bind_addr, "mgmt http accept error");
158						drop(permit);
159						continue;
160					}
161				};
162				let handler = Arc::clone(&handler);
163				let token = token.clone();
164				tokio::spawn(async move {
165					let _permit = permit; // released when this task exits
166					let io = TokioIo::new(stream);
167					let svc = service_fn(move |req| {
168						let handler = Arc::clone(&handler);
169						let token = token.clone();
170						async move { handle_request(req, handler, token, peer).await }
171					});
172					let mut builder = hyper::server::conn::http1::Builder::new();
173					builder
174						.keep_alive(false)
175						.max_buf_size(HTTP1_MAX_BUF_SIZE)
176						.timer(TokioTimer::new())
177						.header_read_timeout(HTTP1_HEADER_READ_TIMEOUT);
178					if let Err(e) = builder.serve_connection(io, svc).await {
179						tracing::debug!(?e, %peer, "mgmt http connection ended");
180					}
181				});
182			}
183		}
184	}
185}
186
187type RespBody = BoxBody<Bytes, std::io::Error>;
188
189async fn handle_request<H: Handler>(
190	req: hyper::Request<Incoming>,
191	handler: Arc<H>,
192	token: Option<Arc<str>>,
193	_peer: SocketAddr,
194) -> Result<hyper::Response<RespBody>, std::convert::Infallible> {
195	// Method / path gating happens before auth so a misrouted client
196	// gets a deterministic 4xx instead of an auth failure that masks
197	// the real problem.
198	if req.uri().path() != "/" {
199		return Ok(simple_status(StatusCode::NOT_FOUND));
200	}
201	if req.method() != Method::POST {
202		return Ok(simple_status(StatusCode::METHOD_NOT_ALLOWED));
203	}
204	if let Some(expected) = &token
205		&& !verify_bearer(req.headers(), expected)
206	{
207		return Ok(unauthorized());
208	}
209	let body_bytes = match read_request_body(req.into_body()).await {
210		Ok(b) => b,
211		Err(BodyReadError::TooLarge) => {
212			return Ok(text_status(
213				StatusCode::PAYLOAD_TOO_LARGE,
214				"request body exceeds management transport limit",
215			));
216		}
217		Err(BodyReadError::Io(e)) => {
218			return Ok(text_status(StatusCode::BAD_REQUEST, &format!("body read failed: {e}")));
219		}
220	};
221	let request = match serde_json::from_slice::<Request>(&body_bytes) {
222		Ok(r) => r,
223		Err(e) => return Ok(text_status(StatusCode::BAD_REQUEST, &format!("json parse: {e}"))),
224	};
225	let id = request.id;
226	match handler.dispatch(request).await {
227		DispatchOutcome::OneShot(Ok(value)) => {
228			Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Result { result: value } }))
229		}
230		DispatchOutcome::OneShot(Err(error)) => {
231			Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Error { error } }))
232		}
233		DispatchOutcome::Stream(stream) => Ok(streaming_response(id, stream)),
234	}
235}
236
237/// Constant-time bearer-token check.
238///
239/// `subtle::ConstantTimeEq` runs in time independent of where the
240/// mismatch is, defeating timing-side-channel guesses against the
241/// token. A length mismatch short-circuits to `false` but still touches
242/// the expected slice once so the call shape stays uniform across
243/// equal- and unequal-length inputs.
244fn verify_bearer(headers: &HeaderMap, expected: &Arc<str>) -> bool {
245	use subtle::ConstantTimeEq;
246	let Some(value) = headers.get(AUTHORIZATION) else {
247		return false;
248	};
249	let Ok(s) = value.to_str() else { return false };
250	let Some(token) = s.strip_prefix("Bearer ") else { return false };
251	let exp = expected.as_bytes();
252	let got = token.as_bytes();
253	if exp.len() != got.len() {
254		// Touch the expected slice so the early-exit branch still does
255		// the same work as a length-equal compare (defence in depth —
256		// the length is recoverable from network framing anyway, but
257		// keep the codepath uniform).
258		let _ = exp.ct_eq(exp);
259		return false;
260	}
261	bool::from(exp.ct_eq(got))
262}
263
264enum BodyReadError {
265	TooLarge,
266	Io(String),
267}
268
269async fn read_request_body(body: Incoming) -> Result<Bytes, BodyReadError> {
270	let limited = Limited::new(body, MAX_REQUEST_BODY_BYTES);
271	match limited.collect().await {
272		Ok(c) => Ok(c.to_bytes()),
273		Err(e) => {
274			// `Limited` boxes the underlying error; we discriminate
275			// "too large" from "io" by downcasting to `LengthLimitError`.
276			if e.downcast_ref::<http_body_util::LengthLimitError>().is_some() {
277				Err(BodyReadError::TooLarge)
278			} else {
279				Err(BodyReadError::Io(e.to_string()))
280			}
281		}
282	}
283}
284
285fn oneshot_response(frame: &Response) -> hyper::Response<RespBody> {
286	let body_bytes = match serde_json::to_vec(frame) {
287		Ok(b) => Bytes::from(b),
288		Err(e) => {
289			tracing::error!(?e, "mgmt http oneshot encode failed");
290			return text_status(StatusCode::INTERNAL_SERVER_ERROR, "encode failed");
291		}
292	};
293	build_response(StatusCode::OK, "application/json", full_body(body_bytes))
294}
295
296fn streaming_response(
297	id: u64,
298	mut stream: Box<dyn crate::server::EventStream + Send>,
299) -> hyper::Response<RespBody> {
300	// Channel decouples the stream producer task from hyper's body
301	// poll loop. When hyper drops the body (client disconnect or
302	// connection error) the receiver drops, the next `tx.send` fails,
303	// and the producer task terminates — which drops `stream`,
304	// triggering the EventStream's own cleanup. That is the
305	// streaming-verb cancellation contract.
306	let (tx, rx) = mpsc::channel::<Bytes>(STREAM_CHANNEL_DEPTH);
307	tokio::spawn(async move {
308		loop {
309			let Some(event) = stream.next_event().await else {
310				let end = Response { id, outcome: ResponseOutcome::End { end: EndMarker::default() } };
311				if let Ok(bytes) = encode_line(&end) {
312					let _ = tx.send(Bytes::from(bytes)).await;
313				}
314				return;
315			};
316			let frame = Response { id, outcome: ResponseOutcome::Event { event } };
317			let bytes = match encode_line(&frame) {
318				Ok(b) => Bytes::from(b),
319				Err(e) => {
320					tracing::error!(?e, id, "mgmt http stream encode failed");
321					return;
322				}
323			};
324			if tx.send(bytes).await.is_err() {
325				// Client disconnected; drop the stream and exit.
326				return;
327			}
328		}
329	});
330	let body = ChannelBody { rx }.boxed();
331	build_response(StatusCode::OK, "application/x-ndjson", body)
332}
333
334struct ChannelBody {
335	rx: mpsc::Receiver<Bytes>,
336}
337
338impl Body for ChannelBody {
339	type Data = Bytes;
340	type Error = std::io::Error;
341
342	fn poll_frame(
343		mut self: Pin<&mut Self>,
344		cx: &mut Context<'_>,
345	) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
346		match self.rx.poll_recv(cx) {
347			Poll::Ready(Some(b)) => Poll::Ready(Some(Ok(Frame::data(b)))),
348			Poll::Ready(None) => Poll::Ready(None),
349			Poll::Pending => Poll::Pending,
350		}
351	}
352}
353
354fn build_response(
355	status: StatusCode,
356	content_type: &'static str,
357	body: RespBody,
358) -> hyper::Response<RespBody> {
359	let mut resp = hyper::Response::new(body);
360	*resp.status_mut() = status;
361	resp.headers_mut().insert(CONTENT_TYPE, content_type.parse().expect("static content type"));
362	resp
363}
364
365fn full_body(bytes: Bytes) -> RespBody {
366	Full::new(bytes).map_err(|never: std::convert::Infallible| match never {}).boxed()
367}
368
369fn simple_status(status: StatusCode) -> hyper::Response<RespBody> {
370	let mut resp = hyper::Response::new(full_body(Bytes::new()));
371	*resp.status_mut() = status;
372	resp
373}
374
375fn text_status(status: StatusCode, body: &str) -> hyper::Response<RespBody> {
376	let mut resp = hyper::Response::new(full_body(Bytes::copy_from_slice(body.as_bytes())));
377	*resp.status_mut() = status;
378	resp
379		.headers_mut()
380		.insert(CONTENT_TYPE, "text/plain; charset=utf-8".parse().expect("static content type"));
381	resp
382}
383
384fn unauthorized() -> hyper::Response<RespBody> {
385	let mut resp = simple_status(StatusCode::UNAUTHORIZED);
386	resp.headers_mut().insert(WWW_AUTHENTICATE, "Bearer".parse().expect("static auth scheme"));
387	resp
388}
389
#[cfg(test)]
mod tests {
	use super::*;

	/// Token every test configures as the expected secret.
	const SECRET: &str = "s3cret";

	/// Build a header map from `(name, value)` pairs.
	fn header_map(values: &[(hyper::header::HeaderName, &str)]) -> HeaderMap {
		let mut map = HeaderMap::new();
		for (name, raw) in values {
			let parsed = raw.parse().expect("valid header");
			map.insert(name.clone(), parsed);
		}
		map
	}

	/// Run `verify_bearer` against [`SECRET`] with a single
	/// `Authorization` header carrying `authorization`.
	fn accepts(authorization: &str) -> bool {
		let expected: Arc<str> = SECRET.into();
		verify_bearer(&header_map(&[(AUTHORIZATION, authorization)]), &expected)
	}

	#[test]
	fn verify_bearer_accepts_correct_token() {
		assert!(accepts("Bearer s3cret"));
	}

	#[test]
	fn verify_bearer_rejects_wrong_token() {
		assert!(!accepts("Bearer wrongx"));
	}

	#[test]
	fn verify_bearer_rejects_missing_header() {
		let expected: Arc<str> = SECRET.into();
		assert!(!verify_bearer(&HeaderMap::new(), &expected));
	}

	#[test]
	fn verify_bearer_rejects_non_bearer_scheme() {
		assert!(!accepts("Basic dXNlcjpwYXNz"));
	}

	#[test]
	fn verify_bearer_rejects_length_mismatch_without_panic() {
		// Both a shorter and a longer candidate that share a prefix with
		// the real token must be rejected, and neither may panic.
		for candidate in ["Bearer s3", "Bearer s3cretextra"] {
			assert!(!accepts(candidate));
		}
	}

	#[test]
	fn verify_bearer_rejects_empty_token_value() {
		assert!(!accepts("Bearer "));
	}
}