use actus::prelude::*;
use serde_json::{Value as JsonValue, json};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::oneshot;
/// Middleware whose `before` hook always answers with a canned JSON reply.
struct ShortCircuit;

#[async_trait]
impl Middleware for ShortCircuit {
    /// Ignores the request and returns `Outcome::Respond` carrying
    /// `{"short": "circuit"}`.
    async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
        let canned = reply::json(json!({ "short": "circuit" }));
        Ok(Outcome::Respond(canned))
    }
}
/// Middleware that copies the request's `x-trace-id` header onto the reply.
struct StampTraceId;

#[async_trait]
impl Middleware for StampTraceId {
    /// Adds an `X-Trace-Id` reply header when the request carried an
    /// `x-trace-id` header whose value converts with `to_str`; otherwise the
    /// reply is left untouched. Always returns `Ok(())`.
    async fn after(&self, request: &Request, response: &mut ReplyData) -> Result<(), WebError> {
        let trace_id = request
            .headers
            .get("x-trace-id")
            .and_then(|raw| raw.to_str().ok());

        if let Some(id) = trace_id {
            response.add_header("X-Trace-Id", id);
        }

        Ok(())
    }
}
/// Middleware that refuses any request carrying an `x-reject` header.
struct RejectIfHeader;

#[async_trait]
impl Middleware for RejectIfHeader {
    /// Returns `Err(WebError::Unauthorized)` when the request has an
    /// `x-reject` header (any value), and `Outcome::Continue` otherwise.
    async fn before(&self, request: &mut Request) -> Result<Outcome, WebError> {
        if request.headers.contains_key("x-reject") {
            Err(WebError::Unauthorized)
        } else {
            Ok(Outcome::Continue)
        }
    }
}
/// Middleware that rejects requests whose `rate_limit_class` equals `class`.
struct LimitClass {
    /// The class label that triggers the rejection.
    class: &'static str,
}

#[async_trait]
impl Middleware for LimitClass {
    /// Returns `WebError::TooManyRequests` with a 7-second duration when the
    /// request's `rate_limit_class` is `Some(self.class)`; any other request
    /// continues down the chain.
    async fn before(&self, request: &mut Request) -> Result<Outcome, WebError> {
        let is_limited = request.rate_limit_class == Some(self.class);

        if is_limited {
            Err(WebError::TooManyRequests(Some(Duration::from_secs(7))))
        } else {
            Ok(Outcome::Continue)
        }
    }
}
/// Controller type whose impl is declared with `max_body_bytes = 16`.
struct Tight;
/// Controller type whose impl is declared with `rate_limit = "vip"`.
struct Classified;
// Controller tagged with `rate_limit = "vip"`. The rate-limit test in this
// file expects middleware to observe that label as
// `request.rate_limit_class == Some("vip")` for requests routed here.
#[controller(rate_limit = "vip")]
impl Classified {
// Single route: GET on the controller's root path dispatches to `hello`.
routes! {
GET "" => hello(),
}
// Replies with the JSON object `{"ok": true}`.
pub async fn hello(&self) -> Reply {
reply!(json!({ "ok": true }))
}
}
// Controller declared with `max_body_bytes = 16`. The body-limit tests in this
// file expect a 64-byte POST here to be answered with 413 even when the
// server-wide limit is larger, and a 2-byte JSON body to be accepted.
#[controller(max_body_bytes = 16)]
impl Tight {
// Single route: POST on the controller's root path, with the request data
// bound to `data` as a `JsonValue`.
routes! {
POST "" => post(data: JsonValue),
}
// Echoes the received JSON value back as the reply.
pub async fn post(&self, data: JsonValue) -> Reply {
reply!(data)
}
}
/// General-purpose controller used by most tests in this file; declared with
/// a bare `#[controller]` (no per-controller options).
struct Dummy;
#[controller]
impl Dummy {
// Three routes relative to the controller's mount point:
//   GET  ""      -> hello
//   POST "post"  -> post, with the request data bound as a `JsonValue`
//   GET  "slow"  -> slow
routes! {
GET "" => hello(),
POST "post" => post(data: JsonValue),
GET "slow" => slow(),
}
// Replies with the JSON object `{"hello": "world"}`.
pub async fn hello(&self) -> Reply {
reply!(json!({ "hello": "world" }))
}
// Echoes the received JSON value back as the reply.
pub async fn post(&self, data: JsonValue) -> Reply {
reply!(data)
}
// Sleeps for 60 seconds before replying, so the timeout and drain tests can
// count on the handler still being in flight when they act.
pub async fn slow(&self) -> Reply {
tokio::time::sleep(Duration::from_secs(60)).await;
reply!(json!({ "unreachable": true }))
}
}
// Application route table: maps the path segments "svc", "tight" and
// "classified" to the controllers above (the tests request `/svc`, `/tight`
// and `/classified` accordingly).
// NOTE(review): `init()` used by the server helpers below is not defined in
// this file — presumably generated by this macro or exported by the prelude;
// confirm against the framework docs.
app_routes! {
routes {
"svc" => Dummy,
"tight" => Tight,
"classified" => Classified,
}
}
/// Starts a test server on a free loopback port and waits until it accepts
/// TCP connections.
///
/// `f` receives the freshly constructed `Server` and returns the configured
/// server to run. The returned sender triggers shutdown: the server's
/// shutdown future completes once a value is sent on it or it is dropped.
///
/// # Panics
///
/// Panics if no loopback port can be reserved, or if the server does not
/// accept a connection within the polling window (100 attempts, 20 ms apart).
async fn spawn<F>(f: F) -> (SocketAddr, oneshot::Sender<()>)
where
    F: FnOnce(Server) -> Server + Send + 'static,
{
    // Ask the OS for a free port, then release it so the server can bind it.
    // The temporary listener is dropped at the end of this statement, so there
    // is a short window in which another process could claim the port.
    let port = std::net::TcpListener::bind("127.0.0.1:0")
        .unwrap()
        .local_addr()
        .unwrap()
        .port();
    let addr = SocketAddr::from(([127, 0, 0, 1], port));

    let (tx, rx) = oneshot::channel::<()>();

    // The join handle is discarded, so a panic inside the server task is not
    // propagated to the caller; the readiness check below is what surfaces a
    // server that failed to start.
    tokio::spawn(async move {
        let server = f(Server::new(init().await.unwrap()));
        server
            .run_with_shutdown_on(addr, async move {
                // Resolves on an explicit send and also when the sender is dropped.
                let _ = rx.await;
            })
            .await
            .unwrap();
    });

    // Poll until the listener is up. Previously the loop fell through silently
    // after the last attempt and callers failed later with an unrelated
    // "connection refused"; fail here with a message that names the cause.
    let mut ready = false;
    for _ in 0..100 {
        if tokio::net::TcpStream::connect(addr).await.is_ok() {
            ready = true;
            break;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
    assert!(
        ready,
        "test server on {addr} never started accepting connections",
    );

    (addr, tx)
}
/// Sends `raw` verbatim to `addr` over a fresh TCP connection and parses the
/// reply into `(status, headers, body)`.
///
/// The whole response is read until the peer closes the connection, so `raw`
/// should ask for `Connection: close`. The body is returned exactly as
/// received after the blank line; no transfer-decoding is applied.
///
/// Header lines that have no `:` or that do not form a valid header
/// name/value are skipped. Repeated headers are appended, not overwritten.
///
/// # Panics
///
/// Panics on connection or I/O failure, if the header section is not UTF-8,
/// or if the status line does not contain a numeric status code.
async fn http(addr: SocketAddr, raw: &str) -> (u16, http::HeaderMap, Vec<u8>) {
    let mut stream = tokio::net::TcpStream::connect(addr)
        .await
        .expect("connect to test server");
    stream
        .write_all(raw.as_bytes())
        .await
        .expect("write request");
    let mut buf = Vec::new();
    stream
        .read_to_end(&mut buf)
        .await
        .expect("read response until close");

    // Locate the blank line that ends the header section; without one the
    // whole buffer is treated as headers and the body is empty.
    let split = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(buf.len());
    let header_part =
        std::str::from_utf8(&buf[..split]).expect("response header section is UTF-8");
    let body = buf
        .get(split + 4..)
        .map(<[u8]>::to_vec)
        .unwrap_or_default();

    let mut lines = header_part.split("\r\n");
    let status: u16 = lines
        .next()
        .expect("response has a status line")
        .split_whitespace()
        .nth(1)
        .expect("status line has a status code")
        .parse()
        .expect("status code is numeric");

    let mut headers = http::HeaderMap::new();
    for line in lines {
        // Whitespace around the field value is optional in HTTP/1.1, so split
        // on the colon alone and trim spaces/tabs rather than requiring the
        // exact ": " separator.
        if let Some((n, v)) = line.split_once(':')
            && let (Ok(n), Ok(v)) = (
                http::HeaderName::from_bytes(n.as_bytes()),
                http::HeaderValue::from_str(v.trim_matches(|c| c == ' ' || c == '\t')),
            )
        {
            headers.append(n, v);
        }
    }

    (status, headers, body)
}
#[tokio::test]
async fn after_hook_runs_on_short_circuit_sees_request_and_can_stamp_headers() {
let (addr, stop) = spawn(|s| {
s.with_middleware(StampTraceId)
.with_middleware(ShortCircuit)
})
.await;
let req =
"GET /svc HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Trace-Id: abc123\r\n\r\n";
let (status, headers, body) = http(addr, req).await;
assert_eq!(status, 200);
let body: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(body, json!({ "short": "circuit" }));
assert_eq!(
headers.get("X-Trace-Id").and_then(|v| v.to_str().ok()),
Some("abc123"),
);
let _ = stop.send(());
}
/// Posts a 64-byte body against a 16-byte server cap and checks that the 413
/// reply carries both the CORS allow-origin header and the trace header
/// stamped by the `after` chain.
#[tokio::test]
async fn pre_parse_413_carries_cors_headers_and_runs_after_chain() {
    let (addr, stop) = spawn(|server| {
        server
            .with_middleware(StampTraceId)
            .with_max_body_bytes(16)
            .with_cors(CorsLayer::permissive())
    })
    .await;

    let oversized = "x".repeat(64);
    let raw_request = format!(
        "POST /svc HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nOrigin: https://app.example.com\r\nX-Trace-Id: trace-413\r\nContent-Type: application/octet-stream\r\nContent-Length: {}\r\n\r\n{}",
        oversized.len(),
        oversized,
    );
    let (status, headers, _) = http(addr, &raw_request).await;

    assert_eq!(status, 413);

    let allow_origin = headers
        .get("access-control-allow-origin")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        allow_origin,
        Some("https://app.example.com"),
        "a pre-parse 413 should still carry CORS headers — the browser needs them to read the body",
    );

    let trace = headers
        .get("X-Trace-Id")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        trace,
        Some("trace-413"),
        "the after-chain should run on a pre-parse 413 (the request skeleton exists by then)",
    );

    let _ = stop.send(());
}
/// Requests an unrouted path and checks that the 404 reply still carries the
/// trace header stamped by `StampTraceId`.
#[tokio::test]
async fn after_runs_on_router_404() {
    let (addr, stop) = spawn(|server| server.with_middleware(StampTraceId)).await;

    let raw_request = "GET /missing/path HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Trace-Id: trace-404\r\n\r\n";
    let (status, headers, _) = http(addr, raw_request).await;

    assert_eq!(status, 404);

    let trace = headers
        .get("X-Trace-Id")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        trace,
        Some("trace-404"),
        "the after-chain should run on a router 404",
    );

    let _ = stop.send(());
}
/// Sends a DELETE to a GET-only route and checks that the 405 reply keeps its
/// `Allow: GET` header and also carries the stamped trace header.
#[tokio::test]
async fn after_runs_on_router_405_and_allow_header_survives() {
    let (addr, stop) = spawn(|server| server.with_middleware(StampTraceId)).await;

    let raw_request = "DELETE /svc HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Trace-Id: trace-405\r\n\r\n";
    let (status, headers, _) = http(addr, raw_request).await;

    assert_eq!(status, 405);

    let allow = headers.get("Allow").and_then(|value| value.to_str().ok());
    assert_eq!(
        allow,
        Some("GET"),
        "the 405 should still carry the framework-emitted Allow header",
    );

    let trace = headers
        .get("X-Trace-Id")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        trace,
        Some("trace-405"),
        "the after-chain should run on a 405",
    );

    let _ = stop.send(());
}
/// Triggers `RejectIfHeader`'s error path and checks that the 401 reply still
/// carries the trace header stamped by `StampTraceId`.
#[tokio::test]
async fn after_runs_on_middleware_err_short_circuit() {
    let (addr, stop) = spawn(|server| {
        server
            .with_middleware(StampTraceId)
            .with_middleware(RejectIfHeader)
    })
    .await;

    let raw_request = "GET /svc HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Reject: yes\r\nX-Trace-Id: trace-401\r\n\r\n";
    let (status, headers, _) = http(addr, raw_request).await;

    assert_eq!(status, 401);

    let trace = headers
        .get("X-Trace-Id")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        trace,
        Some("trace-401"),
        "the after-chain should run on a middleware-Err short-circuit",
    );

    let _ = stop.send(());
}
/// Hits the 60-second `slow` route with a 100 ms request timeout and checks
/// for a prompt 504 that does not carry the `after`-chain trace header.
#[tokio::test]
async fn request_timeout_yields_504_and_skips_the_after_chain() {
    use std::time::Instant;

    let (addr, stop) = spawn(|server| {
        server
            .with_middleware(StampTraceId)
            .with_request_timeout(Duration::from_millis(100))
    })
    .await;

    let raw_request = "GET /svc/slow HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Trace-Id: trace-504\r\n\r\n";

    let started = Instant::now();
    let (status, headers, _) = http(addr, raw_request).await;
    let elapsed = started.elapsed();

    assert_eq!(status, 504);
    assert!(
        elapsed < Duration::from_secs(1),
        "504 should come back fast (got {elapsed:?})",
    );

    let trace = headers.get("X-Trace-Id");
    assert!(
        trace.is_none(),
        "after-chain should not run on a timeout-504 (got {:?})",
        trace,
    );

    let _ = stop.send(());
}
/// With a 1 MiB server-wide cap, posts the same 64-byte non-JSON body to
/// `/tight` (expects 413 from the controller's 16-byte cap) and to
/// `/svc/post` (expects 400, i.e. the body got past the size check).
#[tokio::test]
async fn controller_max_body_attribute_overrides_server_default() {
    let (addr, stop) = spawn(|server| server.with_max_body_bytes(1024 * 1024)).await;
    let payload = "x".repeat(64);

    // Controller with its own 16-byte cap.
    let tight_request = format!(
        "POST /tight HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
        payload.len(),
        payload,
    );
    let (tight_status, _, _) = http(addr, &tight_request).await;
    assert_eq!(
        tight_status, 413,
        "Tight has max_body_bytes=16; a 64-byte body must be rejected",
    );

    // Controller without a cap of its own.
    let dummy_request = format!(
        "POST /svc/post HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
        payload.len(),
        payload,
    );
    let (dummy_status, _, _) = http(addr, &dummy_request).await;
    assert_eq!(
        dummy_status, 400,
        "no controller cap on Dummy → server default lets the body through, \
         then to_params rejects as malformed JSON",
    );

    let _ = stop.send(());
}
#[tokio::test]
async fn controller_max_body_allows_small_bodies() {
let (addr, stop) = spawn(|s| s).await;
let body = "{}"; let req = format!(
"POST /tight HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body,
);
let (status, _, body_bytes) = http(addr, req.as_str()).await;
assert_eq!(status, 200);
let resp: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(resp, serde_json::json!({}));
let _ = stop.send(());
}
/// Posts a 1024-byte body to an unrouted path on a server capped at 16 bytes
/// and checks that the reply is 404 rather than a body-size rejection.
#[tokio::test]
async fn nonexistent_path_404s_without_buffering_the_body() {
    let (addr, stop) = spawn(|server| server.with_max_body_bytes(16)).await;

    let payload = "x".repeat(1024);
    let raw_request = format!(
        "POST /no-such-route HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nContent-Type: application/octet-stream\r\nContent-Length: {}\r\n\r\n{}",
        payload.len(),
        payload,
    );
    let (status, _, _) = http(addr, &raw_request).await;

    assert_eq!(
        status, 404,
        "no matching controller → 404 *without* attempting to buffer the body",
    );

    let _ = stop.send(());
}
/// Configures a 32-byte per-request cap with a 16-byte in-flight budget and
/// checks that a POST is refused with 503 and a numeric `Retry-After`.
#[tokio::test]
async fn inflight_body_budget_returns_503_when_exhausted() {
    let (addr, stop) = spawn(|server| {
        server
            .with_max_body_bytes(32)
            .with_max_inflight_body_bytes(16)
    })
    .await;

    let payload = "abcdefgh";
    let raw_request = format!(
        "POST /svc/post HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
        payload.len(),
        payload,
    );
    let (status, headers, _) = http(addr, &raw_request).await;

    assert_eq!(
        status, 503,
        "per-request cap (32 B) exceeds the inflight budget (16 B); request refused",
    );

    let retry_after = headers
        .get("Retry-After")
        .and_then(|value| value.to_str().ok())
        .expect("Retry-After header present");
    let as_seconds = retry_after.parse::<u64>();
    assert!(as_seconds.is_ok(), "Retry-After is delta-seconds");

    let _ = stop.send(());
}
#[tokio::test]
async fn drain_deadline_is_honored_on_shutdown() {
use std::time::Instant;
use tokio::sync::oneshot;
let port = std::net::TcpListener::bind("127.0.0.1:0")
.unwrap()
.local_addr()
.unwrap()
.port();
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let (stop_tx, stop_rx) = oneshot::channel::<()>();
let server_task = tokio::spawn(async move {
Server::new(init().await.unwrap())
.with_drain_deadline(Duration::from_millis(200))
.run_with_shutdown_on(addr, async move {
let _ = stop_rx.await;
})
.await
.unwrap();
});
for _ in 0..100 {
if tokio::net::TcpStream::connect(addr).await.is_ok() {
break;
}
tokio::time::sleep(Duration::from_millis(20)).await;
}
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
stream
.write_all(b"GET /svc/slow HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
let start = Instant::now();
let _ = stop_tx.send(());
let result = tokio::time::timeout(Duration::from_secs(5), server_task).await;
let elapsed = start.elapsed();
assert!(result.is_ok(), "server task didn't complete inside 5 s");
assert!(
elapsed < Duration::from_secs(2),
"drain took {elapsed:?} — expected ~200 ms, well under the legacy 30 s default",
);
drop(stream);
}
/// With a generous 5-second request timeout configured, checks that a fast
/// request gets a 200 and the stamped trace header.
#[tokio::test]
async fn requests_under_timeout_succeed_normally() {
    let (addr, stop) = spawn(|server| {
        server
            .with_middleware(StampTraceId)
            .with_request_timeout(Duration::from_secs(5))
    })
    .await;

    let raw_request =
        "GET /svc HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Trace-Id: trace-ok\r\n\r\n";
    let (status, headers, _) = http(addr, raw_request).await;

    assert_eq!(status, 200);

    let trace = headers
        .get("X-Trace-Id")
        .and_then(|value| value.to_str().ok());
    assert_eq!(trace, Some("trace-ok"));

    let _ = stop.send(());
}
/// Posts invalid JSON to `/svc/post` and checks that the 400 reply still
/// carries the trace header stamped by `StampTraceId`.
#[tokio::test]
async fn after_runs_on_malformed_body_400() {
    let (addr, stop) = spawn(|server| server.with_middleware(StampTraceId)).await;

    let broken_json = "{ this is not valid json";
    let raw_request = format!(
        "POST /svc/post HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nX-Trace-Id: trace-400\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
        broken_json.len(),
        broken_json,
    );
    let (status, headers, _) = http(addr, &raw_request).await;

    assert_eq!(status, 400);

    let trace = headers
        .get("X-Trace-Id")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        trace,
        Some("trace-400"),
        "the after-chain should run on a 400 from to_params",
    );

    let _ = stop.send(());
}
/// Installs `LimitClass { class: "vip" }` and checks that `/classified` is
/// answered with 429 and `Retry-After: 7`, while `/svc` still gets a 200.
#[tokio::test]
async fn rate_limit_class_label_reaches_middleware() {
    let limiter = LimitClass { class: "vip" };
    let (addr, stop) = spawn(|server| server.with_middleware(limiter)).await;

    // Controller tagged with the "vip" class.
    let classified_request =
        "GET /classified HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n";
    let (classified_status, classified_headers, _) = http(addr, classified_request).await;
    assert_eq!(
        classified_status, 429,
        "the 'vip'-classed controller is rate-limited"
    );
    let retry_after = classified_headers
        .get("Retry-After")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        retry_after,
        Some("7"),
        "TooManyRequests(Some(7s)) finalizes to `Retry-After: 7`",
    );

    // Controller without any class.
    let plain_request = "GET /svc HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n";
    let (plain_status, _, _) = http(addr, plain_request).await;
    assert_eq!(
        plain_status, 200,
        "a controller with no rate-limit class is not limited"
    );

    let _ = stop.send(());
}