1use std::net::SocketAddr;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use bytes::Bytes;
26use http_body_util::{BodyExt, Full, Limited, combinators::BoxBody};
27use hyper::body::{Body, Frame, Incoming};
28use hyper::header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE};
29use hyper::service::service_fn;
30use hyper::{HeaderMap, Method, StatusCode};
31use hyper_util::rt::TokioIo;
32use tokio::net::TcpListener;
33use tokio::sync::mpsc;
34use tokio::task::JoinHandle;
35use tokio_util::sync::CancellationToken;
36
37use crate::protocol::{EndMarker, Request, Response, ResponseOutcome, encode_line};
38use crate::server::{DispatchOutcome, Handler};
39
/// Largest request body that will be buffered (1 MiB); anything bigger is
/// answered with `413 Payload Too Large`.
const MAX_REQUEST_BODY_BYTES: usize = 1 << 20;

/// Capacity of the channel between a streaming response's pump task and its
/// HTTP body. Once it is full, the pump waits until the client drains frames.
const STREAM_CHANNEL_DEPTH: usize = 64;
50
/// Configuration for [`spawn_http_server`].
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Addresses to listen on; one accept loop is spawned per address.
    pub binds: Vec<SocketAddr>,
    /// When `Some`, every request must carry `Authorization: Bearer <token>`
    /// matching this value; when `None`, requests are accepted anonymously.
    pub bearer_token: Option<Arc<str>>,
}

// Hand-written so that logging the config (`?cfg`) never prints the secret:
// a derived `Debug` would render the bearer token verbatim.
impl std::fmt::Debug for HttpServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpServerConfig")
            .field("binds", &self.binds)
            .field("bearer_token", &self.bearer_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
61
62#[derive(thiserror::Error, Debug)]
63pub enum HttpServerError {
64 #[error("ndjson-rpc http: bind {addr} failed: {source}")]
65 Bind { addr: SocketAddr, source: std::io::Error },
66}
67
68pub async fn spawn_http_server<H: Handler>(
77 cfg: HttpServerConfig,
78 handler: Arc<H>,
79 cancel: CancellationToken,
80) -> Result<Vec<JoinHandle<()>>, HttpServerError> {
81 let mut tasks: Vec<JoinHandle<()>> = Vec::with_capacity(cfg.binds.len());
82 for addr in &cfg.binds {
83 let listener = match TcpListener::bind(addr).await {
84 Ok(l) => l,
85 Err(source) => {
86 for t in &tasks {
89 t.abort();
90 }
91 return Err(HttpServerError::Bind { addr: *addr, source });
92 }
93 };
94 let handler = Arc::clone(&handler);
95 let cancel = cancel.clone();
96 let token = cfg.bearer_token.clone();
97 let bind_addr = *addr;
98 tasks.push(tokio::spawn(async move {
99 run_accept_loop(listener, handler, token, cancel, bind_addr).await;
100 }));
101 }
102 Ok(tasks)
103}
104
105async fn run_accept_loop<H: Handler>(
106 listener: TcpListener,
107 handler: Arc<H>,
108 token: Option<Arc<str>>,
109 cancel: CancellationToken,
110 bind_addr: SocketAddr,
111) {
112 tracing::info!(%bind_addr, auth = if token.is_some() { "bearer" } else { "anonymous" }, "mgmt http listening");
113 loop {
114 tokio::select! {
115 biased;
116 () = cancel.cancelled() => return,
117 res = listener.accept() => {
118 let (stream, peer) = match res {
119 Ok(v) => v,
120 Err(e) => {
121 tracing::debug!(?e, %bind_addr, "mgmt http accept error");
122 continue;
123 }
124 };
125 let handler = Arc::clone(&handler);
126 let token = token.clone();
127 tokio::spawn(async move {
128 let io = TokioIo::new(stream);
129 let svc = service_fn(move |req| {
130 let handler = Arc::clone(&handler);
131 let token = token.clone();
132 async move { handle_request(req, handler, token, peer).await }
133 });
134 if let Err(e) = hyper::server::conn::http1::Builder::new()
135 .serve_connection(io, svc)
136 .await
137 {
138 tracing::debug!(?e, %peer, "mgmt http connection ended");
139 }
140 });
141 }
142 }
143 }
144}
145
/// Boxed body type used for every response this module builds: the fixed
/// `Full` bodies from `full_body` and the channel-fed `ChannelBody` stream.
type RespBody = BoxBody<Bytes, std::io::Error>;
147
148async fn handle_request<H: Handler>(
149 req: hyper::Request<Incoming>,
150 handler: Arc<H>,
151 token: Option<Arc<str>>,
152 _peer: SocketAddr,
153) -> Result<hyper::Response<RespBody>, std::convert::Infallible> {
154 if req.uri().path() != "/" {
158 return Ok(simple_status(StatusCode::NOT_FOUND));
159 }
160 if req.method() != Method::POST {
161 return Ok(simple_status(StatusCode::METHOD_NOT_ALLOWED));
162 }
163 if let Some(expected) = &token
164 && !verify_bearer(req.headers(), expected)
165 {
166 return Ok(unauthorized());
167 }
168 let body_bytes = match read_request_body(req.into_body()).await {
169 Ok(b) => b,
170 Err(BodyReadError::TooLarge) => {
171 return Ok(text_status(
172 StatusCode::PAYLOAD_TOO_LARGE,
173 "request body exceeds management transport limit",
174 ));
175 }
176 Err(BodyReadError::Io(e)) => {
177 return Ok(text_status(StatusCode::BAD_REQUEST, &format!("body read failed: {e}")));
178 }
179 };
180 let request = match serde_json::from_slice::<Request>(&body_bytes) {
181 Ok(r) => r,
182 Err(e) => return Ok(text_status(StatusCode::BAD_REQUEST, &format!("json parse: {e}"))),
183 };
184 let id = request.id;
185 match handler.dispatch(request).await {
186 DispatchOutcome::OneShot(Ok(value)) => {
187 Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Result { result: value } }))
188 }
189 DispatchOutcome::OneShot(Err(error)) => {
190 Ok(oneshot_response(&Response { id, outcome: ResponseOutcome::Error { error } }))
191 }
192 DispatchOutcome::Stream(stream) => Ok(streaming_response(id, stream)),
193 }
194}
195
196fn verify_bearer(headers: &HeaderMap, expected: &Arc<str>) -> bool {
204 use subtle::ConstantTimeEq;
205 let Some(value) = headers.get(AUTHORIZATION) else {
206 return false;
207 };
208 let Ok(s) = value.to_str() else { return false };
209 let Some(token) = s.strip_prefix("Bearer ") else { return false };
210 let exp = expected.as_bytes();
211 let got = token.as_bytes();
212 if exp.len() != got.len() {
213 let _ = exp.ct_eq(exp);
218 return false;
219 }
220 bool::from(exp.ct_eq(got))
221}
222
/// Reasons `read_request_body` could not produce the request body.
enum BodyReadError {
    /// The body exceeded `MAX_REQUEST_BODY_BYTES` (mapped to 413 by `handle_request`).
    TooLarge,
    /// Any other body error, carried as its display text (mapped to 400).
    Io(String),
}
227
228async fn read_request_body(body: Incoming) -> Result<Bytes, BodyReadError> {
229 let limited = Limited::new(body, MAX_REQUEST_BODY_BYTES);
230 match limited.collect().await {
231 Ok(c) => Ok(c.to_bytes()),
232 Err(e) => {
233 if e.downcast_ref::<http_body_util::LengthLimitError>().is_some() {
236 Err(BodyReadError::TooLarge)
237 } else {
238 Err(BodyReadError::Io(e.to_string()))
239 }
240 }
241 }
242}
243
244fn oneshot_response(frame: &Response) -> hyper::Response<RespBody> {
245 let body_bytes = match serde_json::to_vec(frame) {
246 Ok(b) => Bytes::from(b),
247 Err(e) => {
248 tracing::error!(?e, "mgmt http oneshot encode failed");
249 return text_status(StatusCode::INTERNAL_SERVER_ERROR, "encode failed");
250 }
251 };
252 build_response(StatusCode::OK, "application/json", full_body(body_bytes))
253}
254
255fn streaming_response(
256 id: u64,
257 mut stream: Box<dyn crate::server::EventStream + Send>,
258) -> hyper::Response<RespBody> {
259 let (tx, rx) = mpsc::channel::<Bytes>(STREAM_CHANNEL_DEPTH);
266 tokio::spawn(async move {
267 loop {
268 let Some(event) = stream.next_event().await else {
269 let end = Response { id, outcome: ResponseOutcome::End { end: EndMarker::default() } };
270 if let Ok(bytes) = encode_line(&end) {
271 let _ = tx.send(Bytes::from(bytes)).await;
272 }
273 return;
274 };
275 let frame = Response { id, outcome: ResponseOutcome::Event { event } };
276 let bytes = match encode_line(&frame) {
277 Ok(b) => Bytes::from(b),
278 Err(e) => {
279 tracing::error!(?e, id, "mgmt http stream encode failed");
280 return;
281 }
282 };
283 if tx.send(bytes).await.is_err() {
284 return;
286 }
287 }
288 });
289 let body = ChannelBody { rx }.boxed();
290 build_response(StatusCode::OK, "application/x-ndjson", body)
291}
292
293struct ChannelBody {
294 rx: mpsc::Receiver<Bytes>,
295}
296
297impl Body for ChannelBody {
298 type Data = Bytes;
299 type Error = std::io::Error;
300
301 fn poll_frame(
302 mut self: Pin<&mut Self>,
303 cx: &mut Context<'_>,
304 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
305 match self.rx.poll_recv(cx) {
306 Poll::Ready(Some(b)) => Poll::Ready(Some(Ok(Frame::data(b)))),
307 Poll::Ready(None) => Poll::Ready(None),
308 Poll::Pending => Poll::Pending,
309 }
310 }
311}
312
313fn build_response(
314 status: StatusCode,
315 content_type: &'static str,
316 body: RespBody,
317) -> hyper::Response<RespBody> {
318 let mut resp = hyper::Response::new(body);
319 *resp.status_mut() = status;
320 resp.headers_mut().insert(CONTENT_TYPE, content_type.parse().expect("static content type"));
321 resp
322}
323
324fn full_body(bytes: Bytes) -> RespBody {
325 Full::new(bytes).map_err(|never: std::convert::Infallible| match never {}).boxed()
326}
327
328fn simple_status(status: StatusCode) -> hyper::Response<RespBody> {
329 let mut resp = hyper::Response::new(full_body(Bytes::new()));
330 *resp.status_mut() = status;
331 resp
332}
333
334fn text_status(status: StatusCode, body: &str) -> hyper::Response<RespBody> {
335 let mut resp = hyper::Response::new(full_body(Bytes::copy_from_slice(body.as_bytes())));
336 *resp.status_mut() = status;
337 resp
338 .headers_mut()
339 .insert(CONTENT_TYPE, "text/plain; charset=utf-8".parse().expect("static content type"));
340 resp
341}
342
343fn unauthorized() -> hyper::Response<RespBody> {
344 let mut resp = simple_status(StatusCode::UNAUTHORIZED);
345 resp.headers_mut().insert(WWW_AUTHENTICATE, "Bearer".parse().expect("static auth scheme"));
346 resp
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// The secret every test configures as the expected bearer token.
    const EXPECTED: &str = "s3cret";

    /// Runs `verify_bearer` against [`EXPECTED`], sending `authorization` as
    /// the `Authorization` header when it is `Some`.
    fn check(authorization: Option<&str>) -> bool {
        let expected: Arc<str> = EXPECTED.into();
        let mut headers = HeaderMap::new();
        if let Some(raw) = authorization {
            headers.insert(AUTHORIZATION, raw.parse().expect("valid header"));
        }
        verify_bearer(&headers, &expected)
    }

    #[test]
    fn verify_bearer_accepts_correct_token() {
        assert!(check(Some("Bearer s3cret")));
    }

    #[test]
    fn verify_bearer_rejects_wrong_token() {
        // Same length as the real token, so the constant-time path is exercised.
        assert!(!check(Some("Bearer wrongx")));
    }

    #[test]
    fn verify_bearer_rejects_missing_header() {
        assert!(!check(None));
    }

    #[test]
    fn verify_bearer_rejects_non_bearer_scheme() {
        assert!(!check(Some("Basic dXNlcjpwYXNz")));
    }

    #[test]
    fn verify_bearer_rejects_length_mismatch_without_panic() {
        for candidate in ["Bearer s3", "Bearer s3cretextra"] {
            assert!(!check(Some(candidate)), "{candidate:?} must be rejected");
        }
    }

    #[test]
    fn verify_bearer_rejects_empty_token_value() {
        assert!(!check(Some("Bearer ")));
    }
}