use std::io::{self, Write};
use std::sync::Arc;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use bytes::Bytes;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as AutoConnBuilder;
use hyper_util::service::TowerToHyperService;
use super::http_connection_limiter::HttpConnectionPermit;
use super::http_handler_metrics::{HttpRejectReason, HttpTransport};
use super::http_principal_limiter::{PrincipalCapExceeded, PrincipalInflightPermit};
use super::transport::{
find_header_end, json_error, parse_query_string, HttpRequest, HttpResponse, CORS_HEADER_PAIRS,
};
use super::RedDBServer;
/// Capacity of the channel that carries streamed body frames from the blocking
/// handler thread to hyper. When it is full, the handler's `blocking_send` waits.
const STREAM_CHANNEL_DEPTH: usize = 16;

/// Returns a fresh hyper-util `auto` connection builder that spawns its work on
/// the Tokio runtime.
///
/// HTTP/1 half-close support is switched on, so a connection is not torn down
/// merely because the client shut down its write side after sending a request.
fn connection_builder() -> AutoConnBuilder<TokioExecutor> {
    let executor = TokioExecutor::new();
    let mut builder = AutoConnBuilder::new(executor);
    builder.http1().half_close(true);
    builder
}
/// Router state cloned into every edge request handler.
#[derive(Clone)]
pub(super) struct EdgeState {
    /// Server handle that performs admission control and routing.
    pub(super) server: RedDBServer,
    /// Transport label forwarded to the HTTP metrics (e.g. `HttpTransport::Http`).
    pub(super) transport: HttpTransport,
}

/// Catch-all axum handler: every request that no other route matched is handed
/// to `RedDBServer::handle_edge_request` together with the listener's transport.
async fn edge_fallback(
    State(EdgeState { server, transport }): State<EdgeState>,
    req: axum::extract::Request,
) -> Response {
    server.handle_edge_request(req, transport).await
}
impl RedDBServer {
fn build_edge_router(&self, transport: HttpTransport) -> axum::Router {
let mut router = axum::Router::new().fallback(edge_fallback);
if !self.websocket_allowed_origins().is_empty() {
router = router.route(
super::ws_edge::REDWIRE_WS_PATH,
axum::routing::get(super::ws_edge::redwire_ws_upgrade),
);
}
router.with_state(EdgeState {
server: self.clone(),
transport,
})
}
pub(crate) async fn serve_edge(
self,
listener: tokio::net::TcpListener,
transport: HttpTransport,
) -> io::Result<()> {
let router = self.build_edge_router(transport);
loop {
let (stream, _peer) = match listener.accept().await {
Ok(pair) => pair,
Err(err) => {
tracing::warn!(target: "reddb::http", error = %err, "accept failed");
continue;
}
};
let service = TowerToHyperService::new(router.clone());
tokio::spawn(async move {
let io = TokioIo::new(stream);
if let Err(err) = connection_builder().serve_connection(io, service).await {
tracing::debug!(target: "reddb::http", error = %err, "connection closed with error");
}
});
}
}
pub(crate) async fn serve_edge_tls(
self,
listener: tokio::net::TcpListener,
acceptor: tokio_rustls::TlsAcceptor,
transport: HttpTransport,
) -> io::Result<()> {
let router = self.build_edge_router(transport);
loop {
let (stream, _peer) = match listener.accept().await {
Ok(pair) => pair,
Err(err) => {
tracing::warn!(target: "reddb::http_tls", error = %err, "accept failed");
continue;
}
};
let acceptor = acceptor.clone();
let service = TowerToHyperService::new(router.clone());
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
let io = TokioIo::new(tls_stream);
if let Err(err) = connection_builder().serve_connection(io, service).await {
tracing::debug!(target: "reddb::http_tls", error = %err, "connection closed with error");
}
}
Err(err) => {
tracing::warn!(target: "reddb::http_tls", error = %err, "TLS handshake failed");
}
}
});
}
}
pub(crate) async fn serve_edge_on_std(
self,
listener: std::net::TcpListener,
transport: HttpTransport,
) -> io::Result<()> {
listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(listener)?;
self.serve_edge(listener, transport).await
}
pub(crate) async fn serve_edge_tls_on_std(
self,
listener: std::net::TcpListener,
acceptor: tokio_rustls::TlsAcceptor,
transport: HttpTransport,
) -> io::Result<()> {
listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(listener)?;
self.serve_edge_tls(listener, acceptor, transport).await
}
pub(crate) async fn serve_edge_one(self, stream: tokio::net::TcpStream) {
let service = TowerToHyperService::new(self.build_edge_router(HttpTransport::Http));
let io = TokioIo::new(stream);
if let Err(err) = connection_builder().serve_connection(io, service).await {
tracing::debug!(target: "reddb::http", error = %err, "connection closed with error");
}
}
async fn handle_edge_request(
&self,
req: axum::extract::Request,
transport: HttpTransport,
) -> Response {
let started = std::time::Instant::now();
let request = match read_edge_request(req, self.options.max_body_bytes).await {
Ok(request) => request,
Err(response) => return response,
};
let permit = match self.http_limiter.try_acquire() {
Some(permit) => permit,
None => {
self.http_metrics
.record_reject(transport, HttpRejectReason::CapExhausted);
return self.reject_capacity_response();
}
};
let principal = super::routing::principal_for(&request.headers);
let principal_permit = match self.principal_limiter.try_acquire(&principal) {
Ok(principal_permit) => principal_permit,
Err(err) => {
drop(permit);
self.http_metrics
.record_reject(transport, HttpRejectReason::PrincipalCapExhausted);
return self.reject_principal_response(&err);
}
};
let response = if self.is_streaming_request(&request) {
self.serve_streaming_request(request, permit, principal_permit)
.await
} else {
self.serve_buffered_request(request, permit, principal_permit, transport)
.await
};
self.http_metrics
.record_duration(transport, started.elapsed().as_secs_f64());
response
}
async fn serve_buffered_request(
&self,
request: HttpRequest,
permit: HttpConnectionPermit,
principal_permit: PrincipalInflightPermit,
transport: HttpTransport,
) -> Response {
let server = self.clone();
let join = tokio::task::spawn_blocking(move || {
let _permit = permit;
let _principal_permit = principal_permit;
let inject_ms = server
.slow_inject_ms
.load(std::sync::atomic::Ordering::Relaxed);
if inject_ms > 0 {
std::thread::sleep(std::time::Duration::from_millis(inject_ms));
}
server.route(request)
});
match tokio::time::timeout(self.handler_timeout, join).await {
Ok(Ok(response)) => buffered_response_to_axum(response),
Ok(Err(_join_err)) => internal_error_response(),
Err(_elapsed) => {
self.http_metrics
.record_reject(transport, HttpRejectReason::HandlerTimeout);
handler_timeout_response()
}
}
}
async fn serve_streaming_request(
&self,
request: HttpRequest,
permit: HttpConnectionPermit,
principal_permit: PrincipalInflightPermit,
) -> Response {
let (head_tx, head_rx) = oneshot::channel::<EdgeStreamResponse>();
let server = self.clone();
tokio::task::spawn_blocking(move || {
let _permit = permit;
let _principal_permit = principal_permit;
let mut sink = StreamSink::new(head_tx);
let _ = server.try_route_streaming(&request, &mut sink);
sink.finish();
});
match head_rx.await {
Ok(response) => stream_response_to_axum(response),
Err(_) => internal_error_response(),
}
}
fn reject_principal_response(&self, err: &PrincipalCapExceeded) -> Response {
let response =
super::routing::principal_inflight_refusal_response(err, self.retry_after_secs);
buffered_response_to_axum(response)
}
fn reject_capacity_response(&self) -> Response {
let body = format!(
"{{\"error\":\"server at capacity\",\"retry_after_secs\":{}}}",
self.retry_after_secs
);
let mut builder = Response::builder()
.status(503)
.header(http::header::CONTENT_TYPE, "application/json")
.header(http::header::RETRY_AFTER, self.retry_after_secs.to_string());
for (name, value) in CORS_HEADER_PAIRS {
builder = builder.header(name, value);
}
builder
.body(Body::from(body))
.unwrap_or_else(|_| internal_error_response())
}
}
async fn read_edge_request(
req: axum::extract::Request,
max_body_bytes: usize,
) -> Result<HttpRequest, Response> {
let (parts, body) = req.into_parts();
let method = parts.method.as_str().to_string();
let path = parts.uri.path().to_string();
let query = parts
.uri
.query()
.map(parse_query_string)
.unwrap_or_default();
let mut headers = std::collections::BTreeMap::new();
for (name, value) in parts.headers.iter() {
headers.insert(
name.as_str().to_ascii_lowercase(),
String::from_utf8_lossy(value.as_bytes()).trim().to_string(),
);
}
let body = match axum::body::to_bytes(body, max_body_bytes).await {
Ok(bytes) => bytes.to_vec(),
Err(_) => {
return Err(buffered_response_to_axum(json_error(
413,
"request body exceeds configured limit",
)))
}
};
Ok(HttpRequest {
method,
path,
query,
headers,
body,
})
}
fn buffered_response_to_axum(response: HttpResponse) -> Response {
let mut builder = Response::builder()
.status(response.status)
.header(http::header::CONTENT_TYPE, response.content_type);
for (name, value) in CORS_HEADER_PAIRS {
builder = builder.header(name, value);
}
for (name, value) in response.extra_headers {
builder = builder.header(name, value);
}
builder
.body(Body::from(response.body))
.unwrap_or_else(|_| internal_error_response())
}
/// Turns the head produced by a `StreamSink` into an axum response.
///
/// A buffered head carries its whole body. A streaming head's body is the
/// receiver end of the frame channel, adapted into a streaming `Body`.
fn stream_response_to_axum(response: EdgeStreamResponse) -> Response {
    let (status, headers, body) = match response {
        EdgeStreamResponse::Buffered {
            status,
            headers,
            body,
        } => (status, headers, Body::from(body)),
        EdgeStreamResponse::Streaming {
            status,
            headers,
            body,
        } => (status, headers, Body::from_stream(ReceiverStream::new(body))),
    };
    build_response(status, headers, body)
}
/// Assembles a response from a status code, ordered header pairs, and a body.
///
/// Falls back to the generic 500 response if any header name or value is
/// rejected by the builder.
fn build_response(status: u16, headers: Vec<(String, String)>, body: Body) -> Response {
    headers
        .into_iter()
        .fold(Response::builder().status(status), |builder, (name, value)| {
            builder.header(name, value)
        })
        .body(body)
        .unwrap_or_else(|_| internal_error_response())
}
/// Generic 500 JSON response used whenever another response cannot be built
/// or a handler fails.
///
/// All of its parts are static, so building it is not expected to fail.
fn internal_error_response() -> Response {
    const BODY: &str = "{\"ok\":false,\"error\":\"internal error\"}";
    let builder = Response::builder().status(500);
    builder
        .header(http::header::CONTENT_TYPE, "application/json")
        .body(Body::from(BODY))
        .expect("static 500 response is well-formed")
}
/// 503 JSON response returned when a buffered handler exceeds the server's
/// handler deadline.
///
/// Carries the shared CORS headers.
fn handler_timeout_response() -> Response {
    const BODY: &str = "{\"ok\":false,\"error\":\"handler deadline exceeded\"}";
    let mut response = Response::builder().status(503);
    response = response.header(http::header::CONTENT_TYPE, "application/json");
    for (name, value) in CORS_HEADER_PAIRS {
        response = response.header(name, value);
    }
    match response.body(Body::from(BODY)) {
        Ok(built) => built,
        Err(_) => internal_error_response(),
    }
}
enum EdgeStreamResponse {
Buffered {
status: u16,
headers: Vec<(String, String)>,
body: Vec<u8>,
},
Streaming {
status: u16,
headers: Vec<(String, String)>,
body: mpsc::Receiver<Result<Bytes, io::Error>>,
},
}
enum BodyFraming {
Chunked(ChunkDecoder),
CloseDelimited,
}
enum SinkState {
Head(Vec<u8>),
Buffering {
status: u16,
headers: Vec<(String, String)>,
remaining: usize,
body: Vec<u8>,
},
Streaming {
sender: mpsc::Sender<Result<Bytes, io::Error>>,
framing: BodyFraming,
},
Done,
}
struct StreamSink {
head: Option<oneshot::Sender<EdgeStreamResponse>>,
state: SinkState,
}
impl StreamSink {
    /// Creates a sink that will deliver the response head through `head` once
    /// the body framing of the handler's output is known.
    fn new(head: oneshot::Sender<EdgeStreamResponse>) -> Self {
        Self {
            head: Some(head),
            state: SinkState::Head(Vec::with_capacity(512)),
        }
    }

    /// Feeds a slice of raw handler output into the sink.
    ///
    /// While in `Head`, bytes are accumulated until `find_header_end` locates
    /// the end of the head. The head is then parsed by `begin_body`, and any
    /// bytes after the separator go to `consume_body`. In every other state the
    /// data goes straight to `consume_body`.
    ///
    /// # Errors
    /// Propagates head-parse errors from `begin_body` (the state stays `Head`)
    /// and forwarding errors from `consume_body`.
    fn consume(&mut self, data: &[u8]) -> io::Result<()> {
        let (head_bytes, leftover) = match &mut self.state {
            SinkState::Head(buffer) => {
                buffer.extend_from_slice(data);
                match find_header_end(buffer) {
                    // Skips 4 bytes after `pos`. Presumably `pos` is the offset of
                    // the "\r\n\r\n" separator; confirm against transport::find_header_end.
                    Some(pos) => (buffer[..pos].to_vec(), buffer[pos + 4..].to_vec()),
                    None => return Ok(()),
                }
            }
            _ => return self.consume_body(data),
        };
        self.begin_body(&head_bytes)?;
        self.consume_body(&leftover)
    }

    /// Parses the response head and chooses how the body will be delivered.
    ///
    /// - `Content-Length: 0`: releases an empty buffered head immediately and
    ///   moves to `Done`.
    /// - `Content-Length: n`: moves to `Buffering`; the head is held back until
    ///   `n` bytes arrive or `finish` runs.
    /// - chunked or close-delimited: creates a `STREAM_CHANNEL_DEPTH` channel,
    ///   releases a `Streaming` head at once, and forwards later bytes.
    ///
    /// Send failures on the oneshot (receiver gone) are ignored.
    fn begin_body(&mut self, head_bytes: &[u8]) -> io::Result<()> {
        let (status, headers, framing) = parse_response_head(head_bytes)?;
        match framing {
            HeadFraming::ContentLength(0) => {
                if let Some(head) = self.head.take() {
                    let _ = head.send(EdgeStreamResponse::Buffered {
                        status,
                        headers,
                        body: Vec::new(),
                    });
                }
                self.state = SinkState::Done;
            }
            HeadFraming::ContentLength(remaining) => {
                // NOTE(review): pre-allocates the full declared length; this trusts the
                // handler-supplied Content-Length.
                self.state = SinkState::Buffering {
                    status,
                    headers,
                    remaining,
                    body: Vec::with_capacity(remaining),
                };
            }
            HeadFraming::Chunked => {
                let (sender, body) = mpsc::channel(STREAM_CHANNEL_DEPTH);
                if let Some(head) = self.head.take() {
                    let _ = head.send(EdgeStreamResponse::Streaming {
                        status,
                        headers,
                        body,
                    });
                }
                self.state = SinkState::Streaming {
                    sender,
                    framing: BodyFraming::Chunked(ChunkDecoder::new()),
                };
            }
            HeadFraming::CloseDelimited => {
                let (sender, body) = mpsc::channel(STREAM_CHANNEL_DEPTH);
                if let Some(head) = self.head.take() {
                    let _ = head.send(EdgeStreamResponse::Streaming {
                        status,
                        headers,
                        body,
                    });
                }
                self.state = SinkState::Streaming {
                    sender,
                    framing: BodyFraming::CloseDelimited,
                };
            }
        }
        Ok(())
    }

    /// Routes body bytes according to the current state.
    ///
    /// The state is taken out with `mem::replace` (leaving `Done`) and only put
    /// back if processing should continue.
    /// - `Buffering`: bytes past the declared length are discarded. Reaching
    ///   the declared length releases the buffered head and leaves `Done`.
    /// - `Streaming`: on a forwarding error the state stays `Done`, which drops
    ///   the sender and ends the stream.
    fn consume_body(&mut self, data: &[u8]) -> io::Result<()> {
        match std::mem::replace(&mut self.state, SinkState::Done) {
            SinkState::Buffering {
                status,
                headers,
                mut remaining,
                mut body,
            } => {
                let take = remaining.min(data.len());
                body.extend_from_slice(&data[..take]);
                remaining -= take;
                if remaining == 0 {
                    if let Some(head) = self.head.take() {
                        let _ = head.send(EdgeStreamResponse::Buffered {
                            status,
                            headers,
                            body,
                        });
                    }
                } else {
                    self.state = SinkState::Buffering {
                        status,
                        headers,
                        remaining,
                        body,
                    };
                }
                Ok(())
            }
            SinkState::Streaming {
                sender,
                mut framing,
            } => {
                let result = forward_stream(&sender, &mut framing, data);
                if result.is_ok() {
                    self.state = SinkState::Streaming { sender, framing };
                }
                result
            }
            SinkState::Head(buffer) => {
                // Put the partial head back untouched. `consume` handles the `Head`
                // state itself, so this arm only runs if called directly.
                self.state = SinkState::Head(buffer);
                Ok(())
            }
            SinkState::Done => Ok(()),
        }
    }

    /// Flushes whatever is pending once the handler has stopped writing.
    ///
    /// If still `Buffering`, the bytes collected so far are released as the
    /// buffered body. NOTE(review): this can be shorter than the declared
    /// Content-Length, and the truncation is not signalled.
    /// In any other state, dropping `self` drops an unsent head sender (the
    /// receiver then sees an error) and any stream sender (ending the stream).
    fn finish(mut self) {
        if let SinkState::Buffering {
            status,
            headers,
            body,
            ..
        } = std::mem::replace(&mut self.state, SinkState::Done)
        {
            if let Some(head) = self.head.take() {
                let _ = head.send(EdgeStreamResponse::Buffered {
                    status,
                    headers,
                    body,
                });
            }
        }
    }
}
/// Lets synchronous handlers write raw HTTP/1.1 output straight into the sink.
impl Write for StreamSink {
    /// Consumes all of `buf` or fails. Partial writes are never reported.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.consume(buf).map(|()| buf.len())
    }

    /// No-op: nothing is buffered on the write side beyond parser state.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Forwards one slice of streamed body bytes to the client channel.
///
/// Close-delimited bodies are sent verbatim as a single frame (nothing is sent
/// for an empty slice). Chunked bodies go through the decoder, and each
/// completed chunk is sent as its own frame.
///
/// # Errors
/// `BrokenPipe` as soon as a frame cannot be delivered.
fn forward_stream(
    sender: &mpsc::Sender<Result<Bytes, io::Error>>,
    framing: &mut BodyFraming,
    data: &[u8],
) -> io::Result<()> {
    let frames: Vec<Bytes> = match framing {
        BodyFraming::CloseDelimited if data.is_empty() => Vec::new(),
        BodyFraming::CloseDelimited => vec![Bytes::copy_from_slice(data)],
        BodyFraming::Chunked(decoder) => {
            let mut decoded = Vec::new();
            decoder.feed(data, &mut decoded);
            decoded
        }
    };
    frames
        .into_iter()
        .try_for_each(|frame| send_frame(sender, frame))
}

/// Pushes one frame onto the channel, blocking the calling thread while the
/// channel is full.
///
/// # Errors
/// `BrokenPipe` when the receiving side has been dropped.
fn send_frame(sender: &mpsc::Sender<Result<Bytes, io::Error>>, frame: Bytes) -> io::Result<()> {
    match sender.blocking_send(Ok(frame)) {
        Ok(()) => Ok(()),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "streaming client disconnected",
        )),
    }
}
/// How the handler delimits the body that follows a response head.
enum HeadFraming {
    /// `Content-Length: N`: exactly `N` body bytes follow.
    ContentLength(usize),
    /// `Transfer-Encoding: chunked`: the body is chunk-encoded.
    Chunked,
    /// Neither header present: the body runs until the handler stops writing.
    CloseDelimited,
}

/// Parses a raw HTTP/1.1 response head (status line plus header lines, without
/// the terminating blank line) written by a streaming handler.
///
/// Returns the status code, the headers to forward, and the body framing.
/// `Transfer-Encoding: chunked` takes precedence over `Content-Length`; an
/// unparsable `Content-Length` is treated as absent.
///
/// The following are consumed here rather than forwarded, because hyper
/// re-frames the body and owns the client connection:
/// - the framing headers (`Content-Length`, `Transfer-Encoding`);
/// - the connection-specific fields of RFC 9110 §7.6.1 (`Connection`,
///   `Keep-Alive`, `Proxy-Connection`, `TE`, `Upgrade`);
/// - any field nominated by a `Connection` option.
///
/// # Errors
/// `InvalidData` when the status line is missing or has no numeric status code.
fn parse_response_head(head: &[u8]) -> io::Result<(u16, Vec<(String, String)>, HeadFraming)> {
    let text = String::from_utf8_lossy(head);
    let mut lines = text.split("\r\n");
    let status_line = lines
        .next()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing status line"))?;
    let status: u16 = status_line
        .split_whitespace()
        .nth(1)
        .and_then(|token| token.parse().ok())
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing status code"))?;
    let mut headers = Vec::new();
    // Lower-cased field names listed in `Connection`; these are hop-by-hop too.
    let mut connection_options: Vec<String> = Vec::new();
    let mut chunked = false;
    let mut content_length: Option<usize> = None;
    for line in lines {
        if line.is_empty() {
            continue;
        }
        let Some((name, value)) = line.split_once(':') else {
            continue;
        };
        let name = name.trim();
        let value = value.trim();
        match name.to_ascii_lowercase().as_str() {
            "transfer-encoding" => {
                if value.to_ascii_lowercase().contains("chunked") {
                    chunked = true;
                }
            }
            "content-length" => content_length = value.parse().ok(),
            "connection" => {
                connection_options.extend(
                    value
                        .split(',')
                        .map(str::trim)
                        .filter(|option| !option.is_empty())
                        .map(str::to_ascii_lowercase),
                );
            }
            "keep-alive" | "proxy-connection" | "te" | "upgrade" => {}
            _ => headers.push((name.to_string(), value.to_string())),
        }
    }
    if !connection_options.is_empty() {
        headers.retain(|(name, _)| {
            !connection_options
                .iter()
                .any(|option| name.eq_ignore_ascii_case(option))
        });
    }
    let framing = if chunked {
        HeadFraming::Chunked
    } else if let Some(length) = content_length {
        HeadFraming::ContentLength(length)
    } else {
        HeadFraming::CloseDelimited
    };
    Ok((status, headers, framing))
}
/// Incremental decoder for `Transfer-Encoding: chunked` bodies that emits one
/// frame per complete chunk, tolerating arbitrary write boundaries.
struct ChunkDecoder {
    /// Current position in the chunk grammar.
    state: ChunkState,
    /// Partial chunk-size line carried across `feed` calls.
    size_line: Vec<u8>,
    /// Payload bytes still expected for the current chunk.
    remaining: usize,
    /// Payload collected so far for the current chunk.
    payload: Vec<u8>,
}

/// Position of a [`ChunkDecoder`] within the chunked encoding.
enum ChunkState {
    /// Reading a hex chunk-size line (extensions after `;` are ignored).
    Size,
    /// Reading chunk payload bytes.
    Data,
    /// Skipping this many bytes of the CRLF that follows a chunk's payload.
    TrailingCrlf(usize),
    /// Terminal zero-size chunk seen; further input is ignored.
    Done,
}

impl ChunkDecoder {
    /// Creates a decoder expecting the first chunk-size line.
    fn new() -> Self {
        ChunkDecoder {
            state: ChunkState::Size,
            size_line: Vec::new(),
            remaining: 0,
            payload: Vec::new(),
        }
    }

    /// Consumes `data`, pushing every chunk completed by it onto `frames`.
    ///
    /// An unparsable chunk size is treated as zero, i.e. as the end of the body.
    fn feed(&mut self, mut data: &[u8], frames: &mut Vec<Bytes>) {
        while !data.is_empty() {
            data = match self.state {
                ChunkState::Size => self.read_size_line(data),
                ChunkState::Data => self.read_payload(data, frames),
                ChunkState::TrailingCrlf(pending) => self.skip_chunk_terminator(data, pending),
                ChunkState::Done => &data[data.len()..],
            };
        }
    }

    /// Accumulates the size line. Once its `\n` arrives, parses the size and
    /// moves to `Data`, or to `Done` for a zero size. Returns the unread input.
    fn read_size_line<'a>(&mut self, input: &'a [u8]) -> &'a [u8] {
        let Some(newline) = input.iter().position(|&b| b == b'\n') else {
            self.size_line.extend_from_slice(input);
            return &[];
        };
        self.size_line.extend_from_slice(&input[..newline]);
        let size = {
            let line = String::from_utf8_lossy(&self.size_line);
            let hex = line.trim().split(';').next().unwrap_or("").trim();
            usize::from_str_radix(hex, 16).unwrap_or(0)
        };
        self.size_line.clear();
        if size == 0 {
            self.state = ChunkState::Done;
        } else {
            self.remaining = size;
            self.payload.clear();
            self.payload.reserve(size);
            self.state = ChunkState::Data;
        }
        &input[newline + 1..]
    }

    /// Copies payload bytes. When the chunk is complete, emits it as one frame
    /// and moves to the trailing CRLF. Returns the unread input.
    fn read_payload<'a>(&mut self, input: &'a [u8], frames: &mut Vec<Bytes>) -> &'a [u8] {
        let (chunk_part, rest) = input.split_at(self.remaining.min(input.len()));
        self.payload.extend_from_slice(chunk_part);
        self.remaining -= chunk_part.len();
        if self.remaining == 0 {
            frames.push(Bytes::from(std::mem::take(&mut self.payload)));
            self.state = ChunkState::TrailingCrlf(2);
        }
        rest
    }

    /// Skips up to `pending` terminator bytes without inspecting them.
    /// Returns the unread input.
    fn skip_chunk_terminator<'a>(&mut self, input: &'a [u8], pending: usize) -> &'a [u8] {
        let skipped = pending.min(input.len());
        let left = pending - skipped;
        self.state = if left == 0 {
            ChunkState::Size
        } else {
            ChunkState::TrailingCrlf(left)
        };
        &input[skipped..]
    }
}
/// Builds the multi-threaded Tokio runtime for the edge server, with all
/// drivers enabled and Tokio's default worker count.
///
/// # Errors
/// Returns the I/O error Tokio reports if the runtime cannot be created.
pub(crate) fn build_edge_runtime() -> io::Result<tokio::runtime::Runtime> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    builder.build()
}

/// Builds a multi-threaded Tokio runtime limited to two worker threads, with
/// all drivers enabled, for running the edge server in the background.
///
/// # Errors
/// Returns the I/O error Tokio reports if the runtime cannot be created.
pub(crate) fn build_background_edge_runtime() -> io::Result<tokio::runtime::Runtime> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.worker_threads(2).enable_all();
    builder.build()
}

/// Wraps a rustls server configuration in a Tokio TLS acceptor.
pub(crate) fn tls_acceptor(config: Arc<rustls::ServerConfig>) -> tokio_rustls::TlsAcceptor {
    config.into()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects every frame already queued on `rx` without waiting.
    ///
    /// Panics if any queued item is an `Err`.
    fn drain(rx: &mut mpsc::Receiver<Result<Bytes, io::Error>>) -> Vec<Vec<u8>> {
        std::iter::from_fn(|| rx.try_recv().ok())
            .map(|item| item.expect("frame ok").to_vec())
            .collect()
    }

    /// Lossily decodes frames to strings so assertions read naturally.
    fn frames_as_text(frames: &[Bytes]) -> Vec<String> {
        frames
            .iter()
            .map(|frame| String::from_utf8_lossy(frame).into_owned())
            .collect()
    }

    #[test]
    fn chunk_decoder_emits_one_frame_per_chunk() {
        let pieces: [&[u8]; 2] = [b"a\r\n{\"row\":1}\n\r\n", b"3\r\nend\r\n0\r\n\r\n"];
        let mut decoder = ChunkDecoder::new();
        let mut frames = Vec::new();
        for piece in pieces {
            decoder.feed(piece, &mut frames);
        }
        assert_eq!(
            frames_as_text(&frames),
            vec!["{\"row\":1}\n".to_string(), "end".to_string()]
        );
    }

    #[test]
    fn chunk_decoder_tolerates_split_writes() {
        let raw: &[u8] = b"a\r\n{\"row\":1}\n\r\n3\r\nend\r\n0\r\n\r\n";
        let mut decoder = ChunkDecoder::new();
        let mut frames = Vec::new();
        for single in raw.chunks(1) {
            decoder.feed(single, &mut frames);
        }
        assert_eq!(
            frames_as_text(&frames),
            vec!["{\"row\":1}\n".to_string(), "end".to_string()]
        );
    }

    #[test]
    fn parse_head_strips_hop_by_hop_and_detects_chunked() {
        let head = b"HTTP/1.1 200 OK\r\nContent-Type: application/x-ndjson\r\nTransfer-Encoding: chunked\r\nConnection: close\r\nAccess-Control-Allow-Origin: *";
        let (status, headers, framing) = parse_response_head(head).expect("parse");
        let mentions = |wanted: &str| {
            headers
                .iter()
                .any(|(name, _)| name.eq_ignore_ascii_case(wanted))
        };
        assert_eq!(status, 200);
        assert!(matches!(framing, HeadFraming::Chunked));
        assert!(headers
            .iter()
            .any(|(name, value)| name == "Content-Type" && value == "application/x-ndjson"));
        assert!(headers
            .iter()
            .any(|(name, _)| name == "Access-Control-Allow-Origin"));
        assert!(!mentions("connection"));
        assert!(!mentions("transfer-encoding"));
    }

    #[tokio::test]
    async fn sink_routes_content_length_refusal_as_buffered() {
        let payload = b"{\"ok\":false,\"code\":\"x\"}";
        let head = format!(
            "HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            payload.len()
        );
        let (tx, rx) = oneshot::channel();
        let mut sink = StreamSink::new(tx);
        sink.write_all(head.as_bytes()).unwrap();
        sink.write_all(payload).unwrap();
        sink.finish();
        let EdgeStreamResponse::Buffered { status, body, .. } = rx.await.expect("head") else {
            panic!("refusal must be buffered, not streamed")
        };
        assert_eq!(status, 400);
        assert_eq!(body, payload);
    }

    #[tokio::test]
    async fn sink_streams_chunked_body_frames() {
        let (tx, rx) = oneshot::channel();
        let writer = std::thread::spawn(move || {
            let pieces: [&[u8]; 4] = [
                b"HTTP/1.1 200 OK\r\nContent-Type: application/x-ndjson\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n",
                b"5\r\nhello\r\n",
                b"5\r\nworld\r\n",
                b"0\r\n\r\n",
            ];
            let mut sink = StreamSink::new(tx);
            for piece in pieces {
                sink.write_all(piece).unwrap();
            }
            sink.finish();
        });
        let response = rx.await.expect("head");
        writer.join().unwrap();
        let EdgeStreamResponse::Streaming {
            status, mut body, ..
        } = response
        else {
            panic!("chunked body must stream")
        };
        assert_eq!(status, 200);
        assert_eq!(drain(&mut body), vec![b"hello".to_vec(), b"world".to_vec()]);
    }

    #[tokio::test]
    async fn sink_streams_close_delimited_sse_body() {
        let (tx, rx) = oneshot::channel();
        let writer = std::thread::spawn(move || {
            let pieces: [&[u8]; 3] = [
                b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nConnection: close\r\n\r\n",
                b"data: one\n\n",
                b"data: two\n\n",
            ];
            let mut sink = StreamSink::new(tx);
            for piece in pieces {
                sink.write_all(piece).unwrap();
            }
            sink.finish();
        });
        let response = rx.await.expect("head");
        writer.join().unwrap();
        let EdgeStreamResponse::Streaming { mut body, .. } = response else {
            panic!("SSE must stream")
        };
        let joined: Vec<u8> = drain(&mut body).concat();
        assert_eq!(joined, b"data: one\n\ndata: two\n\n");
    }
}