use std::fmt;
use std::marker::PhantomData;
use std::rc::Rc;
use crate::cx::{Cx, cap};
use crate::error::Error;
use crate::web::extract::Request;
use crate::web::response::{Response, StatusCode};
pub struct RequestRegion<'a> {
cx: &'a Cx,
request: Request,
}
impl<'a> RequestRegion<'a> {
#[must_use]
pub fn new(cx: &'a Cx, request: Request) -> Self {
Self { cx, request }
}
#[inline]
pub fn run<F>(self, handler: F) -> RegionOutcome
where
F: FnOnce(&RequestContext<'_>) -> Response,
{
let _cx_guard = Cx::set_current(Some(self.cx.clone()));
let ctx = RequestContext {
cx: self.cx,
request: &self.request,
_not_send_sync: PhantomData,
};
if self.cx.checkpoint().is_err() {
return RegionOutcome::Cancelled;
}
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| handler(&ctx)));
match result {
Ok(response) => {
if self.cx.checkpoint().is_err() {
RegionOutcome::Cancelled
} else {
RegionOutcome::Ok(response)
}
}
Err(panic_payload) => {
let message = extract_panic_message(&panic_payload);
RegionOutcome::Panicked(message)
}
}
}
#[inline]
#[allow(clippy::result_large_err)]
pub fn run_sync<F>(self, handler: F) -> RegionOutcome
where
F: FnOnce(&RequestContext<'_>) -> Result<Response, Error>,
{
let _cx_guard = Cx::set_current(Some(self.cx.clone()));
let ctx = RequestContext {
cx: self.cx,
request: &self.request,
_not_send_sync: PhantomData,
};
if self.cx.checkpoint().is_err() {
return RegionOutcome::Cancelled;
}
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| handler(&ctx)));
match result {
Ok(Ok(response)) => {
if self.cx.checkpoint().is_err() {
RegionOutcome::Cancelled
} else {
RegionOutcome::Ok(response)
}
}
Ok(Err(err)) => {
if self.cx.checkpoint().is_err() {
RegionOutcome::Cancelled
} else {
RegionOutcome::Error(err)
}
}
Err(panic_payload) => {
let message = extract_panic_message(&panic_payload);
RegionOutcome::Panicked(message)
}
}
}
#[must_use]
pub fn request(&self) -> &Request {
&self.request
}
#[must_use]
pub fn cx(&self) -> &Cx {
self.cx
}
}
/// Borrowed view of the request handed to a handler running in a
/// [`RequestRegion`].
///
/// The `PhantomData<Rc<()>>` marker makes the context neither `Send` nor
/// `Sync`, so handlers cannot ship it to another thread.
pub struct RequestContext<'a> {
    /// Capability context of the enclosing region.
    cx: &'a Cx,
    /// The request being served.
    request: &'a Request,
    /// Opts out of `Send`/`Sync` (`Rc` implements neither).
    _not_send_sync: PhantomData<Rc<()>>,
}

impl RequestContext<'_> {
    // ----- request data -------------------------------------------------

    /// Returns the full request.
    #[inline]
    #[must_use]
    pub fn request(&self) -> &Request {
        self.request
    }

    /// Returns the request method as stored on the request.
    #[inline]
    #[must_use]
    pub fn method(&self) -> &str {
        &self.request.method
    }

    /// Returns the request path as stored on the request.
    #[inline]
    #[must_use]
    pub fn path(&self) -> &str {
        &self.request.path
    }

    /// Looks up the path parameter `name`; `None` if it was not captured.
    #[inline]
    #[must_use]
    pub fn path_param(&self, name: &str) -> Option<&str> {
        let params = &self.request.path_params;
        params.get(name).map(String::as_str)
    }

    /// Looks up header `name` by delegating to `Request::header`.
    #[inline]
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        self.request.header(name)
    }

    // ----- capability context -------------------------------------------

    /// Returns the region's full capability context.
    #[inline]
    #[must_use]
    pub fn cx(&self) -> &Cx {
        self.cx
    }

    /// Returns a copy of the context restricted to no capabilities.
    #[inline]
    #[must_use]
    pub fn cx_readonly(&self) -> Cx<cap::None> {
        self.cx.restrict::<cap::None>()
    }

    /// Returns a copy of the context restricted to the capability set `Caps`.
    #[inline]
    #[must_use]
    pub fn cx_narrow<Caps>(&self) -> Cx<Caps>
    where
        Caps: cap::SubsetOf<cap::All>,
    {
        self.cx.restrict::<Caps>()
    }
}
/// How a handler run inside a [`RequestRegion`] ended.
#[derive(Debug)]
pub enum RegionOutcome {
    /// The handler returned a response and no cancellation was observed.
    Ok(Response),
    /// A fallible handler returned `Err` and no cancellation was observed.
    Error(Error),
    /// Cancellation was observed before the handler started or after it returned.
    Cancelled,
    /// The handler panicked; carries the message extracted from the payload.
    Panicked(String),
}

impl RegionOutcome {
    /// `true` for [`RegionOutcome::Ok`].
    #[must_use]
    pub const fn is_ok(&self) -> bool {
        matches!(self, Self::Ok(_))
    }

    /// `true` for [`RegionOutcome::Panicked`].
    #[must_use]
    pub const fn is_panicked(&self) -> bool {
        matches!(self, Self::Panicked(_))
    }

    /// `true` for [`RegionOutcome::Cancelled`].
    #[must_use]
    pub const fn is_cancelled(&self) -> bool {
        matches!(self, Self::Cancelled)
    }

    /// `true` for [`RegionOutcome::Error`].
    #[must_use]
    pub const fn is_error(&self) -> bool {
        matches!(self, Self::Error(_))
    }

    /// Converts the outcome into the HTTP response to send.
    ///
    /// `Ok` passes its response through unchanged. `Error` and `Panicked` both
    /// become a generic 500 whose body never includes the error or panic
    /// details. `Cancelled` becomes a 499 "Client Closed Request".
    #[inline]
    #[must_use]
    pub fn into_response(self) -> Response {
        let (status, body) = match self {
            Self::Ok(resp) => return resp,
            Self::Error(_) | Self::Panicked(_) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                b"Internal Server Error".to_vec(),
            ),
            Self::Cancelled => (
                StatusCode::CLIENT_CLOSED_REQUEST,
                b"Client Closed Request: request cancelled".to_vec(),
            ),
        };
        Response::new(status, body)
    }
}

impl fmt::Display for RegionOutcome {
    /// Short diagnostic form: `Ok(<status code>)`, `Error(<err>)`,
    /// `Cancelled`, or `Panicked(<message>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Ok(resp) => {
                let code = resp.status.as_u16();
                write!(f, "Ok({code})")
            }
            Self::Error(err) => write!(f, "Error({err})"),
            Self::Cancelled => f.write_str("Cancelled"),
            Self::Panicked(msg) => write!(f, "Panicked({msg})"),
        }
    }
}
/// A reusable handler whose every invocation runs in a fresh [`RequestRegion`].
///
/// `call` always yields a `Response`. Panics, handler errors and cancellation
/// are mapped to responses by [`RegionOutcome::into_response`].
pub struct IsolatedHandler<F> {
    /// The wrapped handler; borrowed for each call.
    handler: F,
}

impl<F> IsolatedHandler<F>
where
    F: Fn(&RequestContext<'_>) -> Response + Send + Sync + 'static,
{
    /// Wraps `handler` for isolated execution.
    #[must_use]
    pub fn new(handler: F) -> Self {
        Self { handler }
    }

    /// Serves `request` under `cx` in a new region and returns the response.
    #[inline]
    pub fn call(&self, cx: &Cx, request: Request) -> Response {
        let outcome = RequestRegion::new(cx, request).run(&self.handler);
        outcome.into_response()
    }
}
/// Converts a panic payload (as returned by `std::panic::catch_unwind`) into
/// a readable message.
///
/// A `&'static str` or `String` payload is returned verbatim. Any other
/// payload type yields `"unknown panic"`.
///
/// The parameter is deliberately `&Box<..>` rather than `&(dyn Any + Send)`.
/// With the latter, passing `&Box<dyn Any + Send>` would unsize-coerce the
/// box itself into the trait object, and every downcast below would fail.
fn extract_panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    if let Some(literal) = payload.downcast_ref::<&str>() {
        return (*literal).to_owned();
    }
    match payload.downcast_ref::<String>() {
        Some(formatted) => formatted.clone(),
        None => String::from("unknown panic"),
    }
}
#[cfg(test)]
// NOTE(review): mirrors the allow on `run_sync`; presumably for the large `Error`
// type in handler results — confirm it is still needed.
#[allow(clippy::result_large_err)]
mod tests {
    use super::*;
    use crate::cx::Cx;
    use crate::web::extract::Request;
    use crate::web::response::StatusCode;

    fn test_cx() -> Cx {
        Cx::for_testing()
    }

    fn test_request(method: &str, path: &str) -> Request {
        Request::new(method, path)
    }

    #[test]
    fn run_success() {
        let cx = test_cx();
        let req = test_request("GET", "/hello");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|ctx| {
            assert_eq!(ctx.method(), "GET");
            assert_eq!(ctx.path(), "/hello");
            Response::new(StatusCode::OK, b"ok".to_vec())
        });
        assert!(outcome.is_ok());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[test]
    fn run_panic_isolation() {
        let cx = test_cx();
        let req = test_request("GET", "/panic");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|_ctx| {
            panic!("handler bug");
        });
        assert!(outcome.is_panicked());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn run_panic_string_message_preserved() {
        let cx = test_cx();
        let req = test_request("GET", "/");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|_ctx| {
            panic!("something broke");
        });
        if let RegionOutcome::Panicked(msg) = &outcome {
            assert!(msg.contains("something broke"), "msg: {msg}");
        } else {
            panic!("expected Panicked outcome");
        }
    }

    #[test]
    fn run_cancelled_before_handler_returns_499() {
        let cx = test_cx();
        cx.set_cancel_requested(true);
        let req = test_request("GET", "/cancel");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|_ctx| {
            panic!("should not reach handler");
        });
        assert!(outcome.is_cancelled());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::CLIENT_CLOSED_REQUEST);
        assert_eq!(
            resp.body.as_ref(),
            b"Client Closed Request: request cancelled"
        );
    }

    #[test]
    fn run_cancelled_during_handler_returns_499() {
        let cx = test_cx();
        let req = test_request("GET", "/cancel-during");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|ctx| {
            ctx.cx().set_cancel_requested(true);
            Response::new(StatusCode::OK, b"ok".to_vec())
        });
        assert!(outcome.is_cancelled());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::CLIENT_CLOSED_REQUEST);
        assert_eq!(
            resp.body.as_ref(),
            b"Client Closed Request: request cancelled"
        );
    }

    #[test]
    fn run_installs_current_cx_for_handler_body() {
        let cx = test_cx();
        let req = test_request("GET", "/current");
        let expected_task = cx.task_id();
        let expected_region = cx.region_id();
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|_ctx| {
            let current = Cx::current().expect("request region should install CURRENT_CX");
            assert_eq!(current.task_id(), expected_task);
            assert_eq!(current.region_id(), expected_region);
            Response::empty(StatusCode::OK)
        });
        assert!(outcome.is_ok());
        assert!(
            Cx::current().is_none(),
            "request region must restore the prior CURRENT_CX after the handler returns"
        );
    }

    #[test]
    fn run_sync_success() {
        let cx = test_cx();
        let req = test_request("POST", "/data");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run_sync(|ctx| {
            assert_eq!(ctx.method(), "POST");
            Ok(Response::new(StatusCode::CREATED, b"created".to_vec()))
        });
        assert!(outcome.is_ok());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::CREATED);
    }

    #[test]
    fn run_sync_error() {
        let cx = test_cx();
        let req = test_request("GET", "/err");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run_sync(|_ctx| Err(Error::new(crate::error::ErrorKind::Internal)));
        assert!(outcome.is_error());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(resp.body.as_ref(), b"Internal Server Error");
    }

    #[test]
    fn run_sync_panic() {
        let cx = test_cx();
        let req = test_request("GET", "/");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run_sync(|_ctx| -> Result<Response, Error> {
            panic!("boom");
        });
        assert!(outcome.is_panicked());
    }

    #[test]
    fn run_sync_cancelled_during_handler_returns_499() {
        let cx = test_cx();
        let req = test_request("GET", "/cancel-during");
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run_sync(|ctx| {
            ctx.cx().set_cancel_requested(true);
            Ok(Response::new(StatusCode::OK, b"ok".to_vec()))
        });
        assert!(outcome.is_cancelled());
        let resp = outcome.into_response();
        assert_eq!(resp.status, StatusCode::CLIENT_CLOSED_REQUEST);
        assert_eq!(
            resp.body.as_ref(),
            b"Client Closed Request: request cancelled"
        );
    }

    #[test]
    fn run_sync_installs_current_cx_for_handler_body() {
        let cx = test_cx();
        let req = test_request("POST", "/current");
        let expected_task = cx.task_id();
        let expected_region = cx.region_id();
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run_sync(|_ctx| {
            let current = Cx::current().expect("request region should install CURRENT_CX");
            assert_eq!(current.task_id(), expected_task);
            assert_eq!(current.region_id(), expected_region);
            Ok(Response::empty(StatusCode::OK))
        });
        assert!(outcome.is_ok());
        assert!(
            Cx::current().is_none(),
            "request region must restore the prior CURRENT_CX after sync handlers return"
        );
    }

    #[test]
    fn context_accessors() {
        let cx = test_cx();
        let mut req = test_request("DELETE", "/users/99");
        req.headers
            .insert("authorization".to_string(), "Bearer token".to_string());
        let mut params = std::collections::HashMap::new();
        params.insert("id".to_string(), "99".to_string());
        req.path_params = params;
        let region = RequestRegion::new(&cx, req);
        let outcome = region.run(|ctx| {
            assert_eq!(ctx.method(), "DELETE");
            assert_eq!(ctx.path(), "/users/99");
            assert_eq!(ctx.path_param("id"), Some("99"));
            assert_eq!(ctx.path_param("missing"), None);
            assert_eq!(ctx.header("Authorization"), Some("Bearer token"));
            assert_eq!(ctx.header("authorization"), Some("Bearer token"));
            assert_eq!(ctx.header("Missing"), None);
            let _readonly = ctx.cx_readonly();
            let _narrow = ctx.cx_narrow::<cap::CapSet<true, true, false, false, false>>();
            Response::empty(StatusCode::NO_CONTENT)
        });
        assert!(outcome.is_ok());
    }

    #[test]
    fn isolated_handler_success() {
        let handler = IsolatedHandler::new(|ctx| {
            let name = ctx.path_param("name").unwrap_or("world");
            Response::new(StatusCode::OK, format!("Hello, {name}!").into_bytes())
        });
        let cx = test_cx();
        let mut req = test_request("GET", "/greet/alice");
        let mut params = std::collections::HashMap::new();
        params.insert("name".to_string(), "alice".to_string());
        req.path_params = params;
        let resp = handler.call(&cx, req);
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"Hello, alice!");
    }

    #[test]
    fn isolated_handler_panic_returns_500() {
        let handler = IsolatedHandler::new(|_ctx| {
            panic!("handler crash");
        });
        let cx = test_cx();
        let req = test_request("GET", "/");
        let resp = handler.call(&cx, req);
        assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(resp.body.as_ref(), b"Internal Server Error");
    }

    #[test]
    fn panicked_response_does_not_leak_panic_message() {
        let resp = RegionOutcome::Panicked("secret panic details".to_string()).into_response();
        assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(resp.body.as_ref(), b"Internal Server Error");
    }

    #[test]
    fn isolated_handler_cancelled_returns_499() {
        let handler = IsolatedHandler::new(|_ctx| {
            panic!("should not run");
        });
        let cx = test_cx();
        cx.set_cancel_requested(true);
        let req = test_request("GET", "/");
        let resp = handler.call(&cx, req);
        assert_eq!(resp.status, StatusCode::CLIENT_CLOSED_REQUEST);
        assert_eq!(
            resp.body.as_ref(),
            b"Client Closed Request: request cancelled"
        );
    }

    #[test]
    fn region_outcome_display() {
        let ok = RegionOutcome::Ok(Response::empty(StatusCode::OK));
        assert!(ok.to_string().contains("200"));
        let cancelled = RegionOutcome::Cancelled;
        assert_eq!(cancelled.to_string(), "Cancelled");
        let panicked = RegionOutcome::Panicked("oof".to_string());
        assert!(panicked.to_string().contains("oof"));
    }

    #[test]
    fn panic_message_from_str() {
        let msg = extract_panic_message(&(Box::new("oops") as Box<dyn std::any::Any + Send>));
        assert_eq!(msg, "oops");
    }

    #[test]
    fn panic_message_from_string() {
        let msg = extract_panic_message(
            &(Box::new("owned msg".to_string()) as Box<dyn std::any::Any + Send>),
        );
        assert_eq!(msg, "owned msg");
    }

    #[test]
    fn panic_message_unknown_type() {
        let msg = extract_panic_message(&(Box::new(42i32) as Box<dyn std::any::Any + Send>));
        assert_eq!(msg, "unknown panic");
    }

    mod metamorphic_tests {
        use super::*;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
        use std::time::{Duration, Instant};

        /// Upper bound on how long a test thread waits for a cancellation signal.
        /// Generous so that a loaded CI machine cannot cause spurious failures; in
        /// practice the waits end within a few milliseconds.
        const CANCEL_WAIT: Duration = Duration::from_secs(5);

        /// Polls `cx` until cancellation is requested or `CANCEL_WAIT` elapses.
        /// Returns `true` if cancellation was observed.
        fn wait_for_cancel(cx: &Cx) -> bool {
            let deadline = Instant::now() + CANCEL_WAIT;
            while Instant::now() < deadline {
                if cx.is_cancel_requested() {
                    return true;
                }
                std::thread::sleep(Duration::from_millis(1));
            }
            cx.is_cancel_requested()
        }

        #[test]
        fn mr_disconnect_triggers_cancel_within_one_tick() {
            let cx = test_cx();
            let req = test_request("GET", "/long-running");
            let region = RequestRegion::new(&cx, req);
            let cancel_observed = Arc::new(AtomicBool::new(false));
            let cancel_observed_clone = Arc::clone(&cancel_observed);
            let cx_clone = cx.clone();
            // Simulates a client disconnect arriving while the handler is running.
            let cancel_thread = std::thread::spawn(move || {
                std::thread::sleep(Duration::from_millis(1));
                cx_clone.set_cancel_requested(true);
            });
            let outcome = region.run(|ctx| {
                // Wait with a deadline rather than a fixed ~10 ms budget: the fixed
                // budget raced the cancel thread and failed whenever that thread
                // was scheduled late.
                if wait_for_cancel(ctx.cx()) {
                    cancel_observed_clone.store(true, Ordering::SeqCst);
                    return Response::new(
                        StatusCode::CLIENT_CLOSED_REQUEST,
                        b"cancelled".to_vec(),
                    );
                }
                Response::new(StatusCode::OK, b"completed".to_vec())
            });
            cancel_thread.join().expect("cancel thread panicked");
            assert!(
                cancel_observed.load(Ordering::SeqCst),
                "Client disconnect should be observable from inside the handler"
            );
            assert!(
                outcome.is_cancelled(),
                "Client disconnect should mark the region outcome as cancelled"
            );
        }

        #[test]
        fn mr_downstream_futures_receive_cancellation() {
            let cx = test_cx();
            let req = test_request("GET", "/spawn-tasks");
            let region = RequestRegion::new(&cx, req);
            let task_cancelled = Arc::new(AtomicBool::new(false));
            let task_cancelled_clone = Arc::clone(&task_cancelled);
            let outcome = region.run(|ctx| {
                std::thread::scope(|s| {
                    let task_ctx = ctx.cx().clone();
                    s.spawn(move || {
                        if wait_for_cancel(&task_ctx) {
                            task_cancelled_clone.store(true, Ordering::SeqCst);
                        }
                    });
                    std::thread::sleep(Duration::from_millis(5));
                    ctx.cx().set_cancel_requested(true);
                    // The scope joins the spawned task before returning.
                });
                Response::new(StatusCode::OK, b"ok".to_vec())
            });
            // Asserted separately: the old `a || b` form was vacuous, because the
            // outcome is always cancelled once the handler requests cancellation.
            assert!(
                task_cancelled.load(Ordering::SeqCst),
                "Spawned tasks should receive cancellation signal"
            );
            assert!(
                outcome.is_cancelled(),
                "Cancellation requested inside the handler should cancel the region"
            );
        }

        #[test]
        fn mr_no_obligation_leaks_after_disconnect() {
            let cx = test_cx();
            let req = test_request("POST", "/transaction");
            let region = RequestRegion::new(&cx, req);
            let obligation_cleaned = Arc::new(AtomicBool::new(false));
            let obligation_cleaned_clone = Arc::clone(&obligation_cleaned);
            let outcome = region.run(|ctx| {
                struct MockObligation {
                    cleaned: Arc<AtomicBool>,
                }
                impl Drop for MockObligation {
                    fn drop(&mut self) {
                        self.cleaned.store(true, Ordering::SeqCst);
                    }
                }
                let _obligation = MockObligation {
                    cleaned: obligation_cleaned_clone,
                };
                std::thread::sleep(Duration::from_millis(1));
                ctx.cx().set_cancel_requested(true);
                if ctx.cx().checkpoint().is_err() {
                    return Response::new(StatusCode::CLIENT_CLOSED_REQUEST, b"cancelled".to_vec());
                }
                Response::new(StatusCode::OK, b"committed".to_vec())
            });
            // The handler's locals were dropped when it returned, so no extra wait
            // is needed before checking.
            assert!(
                obligation_cleaned.load(Ordering::SeqCst),
                "Obligations must be cleaned up when request is cancelled"
            );
            assert!(
                outcome.is_cancelled(),
                "Request cancelled mid-handler should be reported as cancelled"
            );
        }

        #[test]
        fn mr_partial_response_flushed_atomically() {
            let cx = test_cx();
            let req = test_request("GET", "/streaming");
            let region = RequestRegion::new(&cx, req);
            let response_complete = Arc::new(AtomicBool::new(false));
            let response_complete_clone = Arc::clone(&response_complete);
            let cancel_cx = cx.clone();
            let cancel_thread = std::thread::spawn(move || {
                std::thread::sleep(Duration::from_millis(5));
                cancel_cx.set_cancel_requested(true);
            });
            let outcome = region.run(|ctx| {
                let mut response_data = Vec::new();
                for i in 0..10 {
                    if ctx.cx().is_cancel_requested() {
                        return Response::new(
                            StatusCode::CLIENT_CLOSED_REQUEST,
                            b"cancelled".to_vec(),
                        );
                    }
                    response_data.push(b'a' + (i % 26) as u8);
                    std::thread::sleep(Duration::from_millis(1));
                }
                response_complete_clone.store(true, Ordering::SeqCst);
                Response::new(StatusCode::OK, response_data)
            });
            cancel_thread.join().expect("cancel thread panicked");
            // The cancel thread races the handler, so either outcome is legal. The
            // invariant under test is that a partial body is never returned as `Ok`.
            // `Cancelled` may legitimately follow a *complete* build, when the cancel
            // lands between the handler returning and the region's post-handler
            // checkpoint, so `response_complete` is not asserted in that arm.
            match outcome {
                RegionOutcome::Ok(resp) => {
                    assert!(
                        response_complete.load(Ordering::SeqCst),
                        "Complete response should only be returned if fully built"
                    );
                    assert_eq!(
                        resp.body.as_ref(),
                        b"abcdefghij",
                        "Ok response must carry the fully built body"
                    );
                }
                RegionOutcome::Cancelled => {}
                other => panic!("Unexpected outcome: {other:?}"),
            }
        }

        #[test]
        fn mr_reconnect_request_id_deduplicated() {
            let cx = test_cx();
            let request_counter = Arc::new(AtomicU32::new(0));
            let mut req1 = test_request("POST", "/idempotent");
            req1.headers
                .insert("x-request-id".to_string(), "req-123".to_string());
            req1.headers
                .insert("x-idempotency-key".to_string(), "key-123".to_string());
            let region1 = RequestRegion::new(&cx, req1);
            let counter_clone1 = Arc::clone(&request_counter);
            let outcome1 = region1.run(|ctx| {
                let request_id = ctx.header("x-request-id").unwrap_or("none");
                let idempotency_key = ctx.header("x-idempotency-key").unwrap_or("none");
                if request_id == "req-123" && idempotency_key == "key-123" {
                    counter_clone1.fetch_add(1, Ordering::SeqCst);
                    Response::new(StatusCode::CREATED, b"resource created".to_vec())
                } else {
                    Response::new(StatusCode::BAD_REQUEST, b"missing headers".to_vec())
                }
            });
            let mut req2 = test_request("POST", "/idempotent");
            req2.headers
                .insert("x-request-id".to_string(), "req-123".to_string());
            req2.headers
                .insert("x-idempotency-key".to_string(), "key-123".to_string());
            let region2 = RequestRegion::new(&cx, req2);
            let counter_clone2 = Arc::clone(&request_counter);
            let outcome2 = region2.run(|ctx| {
                let request_id = ctx.header("x-request-id").unwrap_or("none");
                let idempotency_key = ctx.header("x-idempotency-key").unwrap_or("none");
                let current_count = counter_clone2.load(Ordering::SeqCst);
                if request_id == "req-123" && idempotency_key == "key-123" && current_count > 0 {
                    Response::new(StatusCode::CREATED, b"resource created".to_vec())
                } else if current_count == 0 {
                    counter_clone2.fetch_add(1, Ordering::SeqCst);
                    Response::new(StatusCode::CREATED, b"resource created".to_vec())
                } else {
                    Response::new(StatusCode::BAD_REQUEST, b"invalid state".to_vec())
                }
            });
            assert!(outcome1.is_ok(), "First request should succeed");
            assert!(
                outcome2.is_ok(),
                "Second request (reconnect) should succeed"
            );
            let final_count = request_counter.load(Ordering::SeqCst);
            assert_eq!(
                final_count, 1,
                "Idempotent operation should only execute once despite multiple requests"
            );
        }

        #[test]
        fn mr_composite_disconnect_concurrent_operations() {
            let cx = test_cx();
            let req = test_request("POST", "/complex");
            let region = RequestRegion::new(&cx, req);
            let task_count = Arc::new(AtomicU32::new(0));
            let cleanup_count = Arc::new(AtomicU32::new(0));
            let task_count_clone = Arc::clone(&task_count);
            let cleanup_count_clone = Arc::clone(&cleanup_count);
            let outcome = region.run(|ctx| {
                std::thread::scope(|s| {
                    let mut handles = Vec::new();
                    for _ in 0..3 {
                        let task_ctx = ctx.cx().clone();
                        let task_counter = Arc::clone(&task_count_clone);
                        let cleanup_counter = Arc::clone(&cleanup_count_clone);
                        handles.push(s.spawn(move || {
                            task_counter.fetch_add(1, Ordering::SeqCst);
                            let _cleanup = CleanupGuard {
                                counter: cleanup_counter,
                            };
                            for _ in 0..20 {
                                if task_ctx.is_cancel_requested() {
                                    return;
                                }
                                std::thread::sleep(Duration::from_micros(100));
                            }
                        }));
                    }
                    std::thread::sleep(Duration::from_millis(2));
                    ctx.cx().set_cancel_requested(true);
                    // Joining (rather than sleeping) guarantees every task has
                    // finished, and surfaces task panics instead of discarding them.
                    for handle in handles {
                        handle.join().expect("task thread panicked");
                    }
                });
                Response::new(StatusCode::CLIENT_CLOSED_REQUEST, b"cancelled".to_vec())
            });
            assert_eq!(
                task_count.load(Ordering::SeqCst),
                3,
                "All spawned tasks should have started"
            );
            assert_eq!(
                cleanup_count.load(Ordering::SeqCst),
                3,
                "All tasks should have performed cleanup"
            );
            assert!(
                outcome.is_cancelled(),
                "Request should be marked as cancelled"
            );
        }

        /// Increments `counter` when dropped; used to prove task cleanup ran.
        struct CleanupGuard {
            counter: Arc<AtomicU32>,
        }

        impl Drop for CleanupGuard {
            fn drop(&mut self) {
                self.counter.fetch_add(1, Ordering::SeqCst);
            }
        }
    }
}