1use std::net::SocketAddr;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use std::time::Duration;
25
26use bytes::Bytes;
27use http_body_util::{BodyExt, Full, Limited, combinators::BoxBody};
28use hyper::body::{Body, Frame, Incoming};
29use hyper::header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE};
30use hyper::service::service_fn;
31use hyper::{HeaderMap, Method, StatusCode};
32use hyper_util::rt::{TokioIo, TokioTimer};
33use tokio::net::TcpListener;
34use tokio::sync::{Semaphore, mpsc};
35use tokio::task::JoinHandle;
36use tokio_util::sync::CancellationToken;
37
38use crate::protocol::{EndMarker, Request, Response, ResponseOutcome, encode_line};
39use crate::server::{DispatchOutcome, Handler};
40
/// Largest request body accepted (1 MiB); anything bigger is answered with `413`.
const MAX_REQUEST_BODY_BYTES: usize = 1 << 20;

/// Capacity of the channel between a stream's producer task and the NDJSON
/// response body, counted in encoded lines.
const STREAM_CHANNEL_DEPTH: usize = 64;

/// Size of the per-listener semaphore that bounds concurrently served connections.
const MAX_CONCURRENT_CONNECTIONS: usize = 64;

/// Value passed to hyper's HTTP/1 `max_buf_size` (64 KiB).
const HTTP1_MAX_BUF_SIZE: usize = 64 << 10;

/// Value passed to hyper's HTTP/1 `header_read_timeout`.
const HTTP1_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(10);
68
/// Configuration for [`spawn_http_server`].
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Addresses to listen on; one accept loop is spawned per address.
    pub binds: Vec<SocketAddr>,
    /// When `Some`, every request must present `Authorization: Bearer <token>`
    /// matching this value; when `None` the endpoint is unauthenticated.
    pub bearer_token: Option<Arc<str>>,
}

// Hand-written instead of derived so the shared secret cannot leak into logs
// through `{:?}`; only its presence is shown.
impl std::fmt::Debug for HttpServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpServerConfig")
            .field("binds", &self.binds)
            .field("bearer_token", &self.bearer_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
79
80#[derive(thiserror::Error, Debug)]
81pub enum HttpServerError {
82 #[error("ndjson-rpc http: bind {addr} failed: {source}")]
83 Bind { addr: SocketAddr, source: std::io::Error },
84}
85
86pub async fn spawn_http_server<H: Handler>(
95 cfg: HttpServerConfig,
96 handler: Arc<H>,
97 cancel: CancellationToken,
98) -> Result<Vec<JoinHandle<()>>, HttpServerError> {
99 let mut tasks: Vec<JoinHandle<()>> = Vec::with_capacity(cfg.binds.len());
100 for addr in &cfg.binds {
101 let listener = match TcpListener::bind(addr).await {
102 Ok(l) => l,
103 Err(source) => {
104 for t in &tasks {
107 t.abort();
108 }
109 return Err(HttpServerError::Bind { addr: *addr, source });
110 }
111 };
112 let handler = Arc::clone(&handler);
113 let cancel = cancel.clone();
114 let token = cfg.bearer_token.clone();
115 let bind_addr = *addr;
116 tasks.push(tokio::spawn(async move {
117 run_accept_loop(listener, handler, token, cancel, bind_addr).await;
118 }));
119 }
120 Ok(tasks)
121}
122
123async fn run_accept_loop<H: Handler>(
124 listener: TcpListener,
125 handler: Arc<H>,
126 token: Option<Arc<str>>,
127 cancel: CancellationToken,
128 bind_addr: SocketAddr,
129) {
130 tracing::info!(%bind_addr, auth = if token.is_some() { "bearer" } else { "anonymous" }, "mgmt http listening");
131 let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNECTIONS));
132 loop {
133 let permit = tokio::select! {
138 biased;
139 () = cancel.cancelled() => return,
140 p = Arc::clone(&semaphore).acquire_owned() => match p {
141 Ok(p) => p,
142 Err(_) => return,
147 },
148 };
149
150 tokio::select! {
151 biased;
152 () = cancel.cancelled() => return,
153 res = listener.accept() => {
154 let (stream, peer) = match res {
155 Ok(v) => v,
156 Err(e) => {
157 tracing::debug!(?e, %bind_addr, "mgmt http accept error");
158 drop(permit);
159 continue;
160 }
161 };
162 let handler = Arc::clone(&handler);
163 let token = token.clone();
164 tokio::spawn(async move {
165 let _permit = permit; let io = TokioIo::new(stream);
167 let svc = service_fn(move |req| {
168 let handler = Arc::clone(&handler);
169 let token = token.clone();
170 async move { handle_request(req, handler, token, peer).await }
171 });
172 let mut builder = hyper::server::conn::http1::Builder::new();
173 builder
174 .keep_alive(false)
175 .max_buf_size(HTTP1_MAX_BUF_SIZE)
176 .timer(TokioTimer::new())
177 .header_read_timeout(HTTP1_HEADER_READ_TIMEOUT);
178 if let Err(e) = builder.serve_connection(io, svc).await {
179 tracing::debug!(?e, %peer, "mgmt http connection ended");
180 }
181 });
182 }
183 }
184 }
185}
186
187type RespBody = BoxBody<Bytes, std::io::Error>;
188
189async fn handle_request<H: Handler>(
190 req: hyper::Request<Incoming>,
191 handler: Arc<H>,
192 token: Option<Arc<str>>,
193 _peer: SocketAddr,
194) -> Result<hyper::Response<RespBody>, std::convert::Infallible> {
195 if req.uri().path() != "/" {
199 return Ok(simple_status(StatusCode::NOT_FOUND));
200 }
201 if req.method() != Method::POST {
202 return Ok(simple_status(StatusCode::METHOD_NOT_ALLOWED));
203 }
204 if let Some(expected) = &token
205 && !verify_bearer(req.headers(), expected)
206 {
207 return Ok(unauthorized());
208 }
209 let body_bytes = match read_request_body(req.into_body()).await {
210 Ok(b) => b,
211 Err(BodyReadError::TooLarge) => {
212 return Ok(text_status(
213 StatusCode::PAYLOAD_TOO_LARGE,
214 "request body exceeds management transport limit",
215 ));
216 }
217 Err(BodyReadError::Io(e)) => {
218 return Ok(text_status(StatusCode::BAD_REQUEST, &format!("body read failed: {e}")));
219 }
220 };
221 let request = match serde_json::from_slice::<Request>(&body_bytes) {
222 Ok(r) => r,
223 Err(e) => return Ok(text_status(StatusCode::BAD_REQUEST, &format!("json parse: {e}"))),
224 };
225 let id = request.id;
226 match handler.dispatch(request).await {
227 DispatchOutcome::OneShot(Ok(value)) => {
228 Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Result { result: value } }))
229 }
230 DispatchOutcome::OneShot(Err(error)) => {
231 Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Error { error } }))
232 }
233 DispatchOutcome::Stream(stream) => Ok(streaming_response(id, stream)),
234 }
235}
236
237fn verify_bearer(headers: &HeaderMap, expected: &Arc<str>) -> bool {
245 use subtle::ConstantTimeEq;
246 let Some(value) = headers.get(AUTHORIZATION) else {
247 return false;
248 };
249 let Ok(s) = value.to_str() else { return false };
250 let Some(token) = s.strip_prefix("Bearer ") else { return false };
251 let exp = expected.as_bytes();
252 let got = token.as_bytes();
253 if exp.len() != got.len() {
254 let _ = exp.ct_eq(exp);
259 return false;
260 }
261 bool::from(exp.ct_eq(got))
262}
263
/// Reasons the request body could not be collected.
enum BodyReadError {
    /// The body grew past `MAX_REQUEST_BODY_BYTES`.
    TooLarge,
    /// Any other receive failure, already rendered as text for the client.
    Io(String),
}
268
269async fn read_request_body(body: Incoming) -> Result<Bytes, BodyReadError> {
270 let limited = Limited::new(body, MAX_REQUEST_BODY_BYTES);
271 match limited.collect().await {
272 Ok(c) => Ok(c.to_bytes()),
273 Err(e) => {
274 if e.downcast_ref::<http_body_util::LengthLimitError>().is_some() {
277 Err(BodyReadError::TooLarge)
278 } else {
279 Err(BodyReadError::Io(e.to_string()))
280 }
281 }
282 }
283}
284
285fn oneshot_response(frame: &Response) -> hyper::Response<RespBody> {
286 let body_bytes = match serde_json::to_vec(frame) {
287 Ok(b) => Bytes::from(b),
288 Err(e) => {
289 tracing::error!(?e, "mgmt http oneshot encode failed");
290 return text_status(StatusCode::INTERNAL_SERVER_ERROR, "encode failed");
291 }
292 };
293 build_response(StatusCode::OK, "application/json", full_body(body_bytes))
294}
295
296fn streaming_response(
297 id: u64,
298 mut stream: Box<dyn crate::server::EventStream + Send>,
299) -> hyper::Response<RespBody> {
300 let (tx, rx) = mpsc::channel::<Bytes>(STREAM_CHANNEL_DEPTH);
307 tokio::spawn(async move {
308 loop {
309 let Some(event) = stream.next_event().await else {
310 let end = Response { id, outcome: ResponseOutcome::End { end: EndMarker::default() } };
311 if let Ok(bytes) = encode_line(&end) {
312 let _ = tx.send(Bytes::from(bytes)).await;
313 }
314 return;
315 };
316 let frame = Response { id, outcome: ResponseOutcome::Event { event } };
317 let bytes = match encode_line(&frame) {
318 Ok(b) => Bytes::from(b),
319 Err(e) => {
320 tracing::error!(?e, id, "mgmt http stream encode failed");
321 return;
322 }
323 };
324 if tx.send(bytes).await.is_err() {
325 return;
327 }
328 }
329 });
330 let body = ChannelBody { rx }.boxed();
331 build_response(StatusCode::OK, "application/x-ndjson", body)
332}
333
334struct ChannelBody {
335 rx: mpsc::Receiver<Bytes>,
336}
337
338impl Body for ChannelBody {
339 type Data = Bytes;
340 type Error = std::io::Error;
341
342 fn poll_frame(
343 mut self: Pin<&mut Self>,
344 cx: &mut Context<'_>,
345 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
346 match self.rx.poll_recv(cx) {
347 Poll::Ready(Some(b)) => Poll::Ready(Some(Ok(Frame::data(b)))),
348 Poll::Ready(None) => Poll::Ready(None),
349 Poll::Pending => Poll::Pending,
350 }
351 }
352}
353
354fn build_response(
355 status: StatusCode,
356 content_type: &'static str,
357 body: RespBody,
358) -> hyper::Response<RespBody> {
359 let mut resp = hyper::Response::new(body);
360 *resp.status_mut() = status;
361 resp.headers_mut().insert(CONTENT_TYPE, content_type.parse().expect("static content type"));
362 resp
363}
364
365fn full_body(bytes: Bytes) -> RespBody {
366 Full::new(bytes).map_err(|never: std::convert::Infallible| match never {}).boxed()
367}
368
369fn simple_status(status: StatusCode) -> hyper::Response<RespBody> {
370 let mut resp = hyper::Response::new(full_body(Bytes::new()));
371 *resp.status_mut() = status;
372 resp
373}
374
375fn text_status(status: StatusCode, body: &str) -> hyper::Response<RespBody> {
376 let mut resp = hyper::Response::new(full_body(Bytes::copy_from_slice(body.as_bytes())));
377 *resp.status_mut() = status;
378 resp
379 .headers_mut()
380 .insert(CONTENT_TYPE, "text/plain; charset=utf-8".parse().expect("static content type"));
381 resp
382}
383
384fn unauthorized() -> hyper::Response<RespBody> {
385 let mut resp = simple_status(StatusCode::UNAUTHORIZED);
386 resp.headers_mut().insert(WWW_AUTHENTICATE, "Bearer".parse().expect("static auth scheme"));
387 resp
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret every test configures as the expected bearer token.
    const EXPECTED: &str = "s3cret";

    /// Runs `verify_bearer` against `EXPECTED`, optionally with the given raw
    /// `Authorization` header value.
    fn check(authorization: Option<&str>) -> bool {
        let expected: Arc<str> = Arc::from(EXPECTED);
        let mut headers = HeaderMap::new();
        if let Some(raw) = authorization {
            headers.insert(AUTHORIZATION, raw.parse().expect("valid header"));
        }
        verify_bearer(&headers, &expected)
    }

    #[test]
    fn verify_bearer_accepts_correct_token() {
        assert!(check(Some("Bearer s3cret")));
    }

    #[test]
    fn verify_bearer_rejects_wrong_token() {
        // Same length as the real token, so the byte comparison decides.
        assert!(!check(Some("Bearer wrongx")));
    }

    #[test]
    fn verify_bearer_rejects_missing_header() {
        assert!(!check(None));
    }

    #[test]
    fn verify_bearer_rejects_non_bearer_scheme() {
        assert!(!check(Some("Basic dXNlcjpwYXNz")));
    }

    #[test]
    fn verify_bearer_rejects_length_mismatch_without_panic() {
        // One token shorter and one longer than the expected secret.
        for raw in ["Bearer s3", "Bearer s3cretextra"] {
            assert!(!check(Some(raw)), "{raw} must be rejected");
        }
    }

    #[test]
    fn verify_bearer_rejects_empty_token_value() {
        assert!(!check(Some("Bearer ")));
    }
}