use asupersync::types::CancelReason;
use asupersync::{Budget, Cx, Outcome, RegionId, TaskId};
use std::sync::Arc;
use crate::dependency::{CleanupStack, DependencyCache, DependencyOverrides, ResolutionStack};
/// Default maximum accepted request-body size: 1 MiB.
pub const DEFAULT_MAX_BODY_SIZE: usize = 1024 * 1024;

/// Per-request limit on the size of an incoming request body, in bytes.
#[derive(Debug, Clone, Copy)]
pub struct BodyLimitConfig {
    // Maximum number of body bytes accepted for a request.
    max_size: usize,
}
impl Default for BodyLimitConfig {
fn default() -> Self {
Self {
max_size: DEFAULT_MAX_BODY_SIZE,
}
}
}
impl BodyLimitConfig {
#[must_use]
pub fn new(max_size: usize) -> Self {
Self { max_size }
}
#[must_use]
pub fn max_size(self) -> usize {
self.max_size
}
}
/// Per-request state bundle: the structured-concurrency context plus
/// request-scoped dependency-resolution bookkeeping and the body-size limit.
///
/// Cloning is cheap: the dependency structures are shared via [`Arc`], so a
/// clone refers to the same request-scoped state.
#[derive(Debug, Clone)]
pub struct RequestContext {
    // Task context used for cancellation checks, budget, and tracing.
    cx: Cx,
    // Identifier assigned to this request by the caller of `new`.
    request_id: u64,
    // Request-scoped cache of resolved dependency values.
    dependency_cache: Arc<DependencyCache>,
    // Dependency overrides for this request — presumably consulted before
    // normal resolution (e.g. in tests); confirm against the resolver.
    dependency_overrides: Arc<DependencyOverrides>,
    // Tracks in-flight resolutions; used for cycle detection.
    resolution_stack: Arc<ResolutionStack>,
    // Async cleanup callbacks run when the request is torn down.
    cleanup_stack: Arc<CleanupStack>,
    // Body-size limit applied to this request.
    body_limit: BodyLimitConfig,
}
impl RequestContext {
#[must_use]
pub fn new(cx: Cx, request_id: u64) -> Self {
Self {
cx,
request_id,
dependency_cache: Arc::new(DependencyCache::new()),
dependency_overrides: Arc::new(DependencyOverrides::new()),
resolution_stack: Arc::new(ResolutionStack::new()),
cleanup_stack: Arc::new(CleanupStack::new()),
body_limit: BodyLimitConfig::default(),
}
}
#[must_use]
pub fn with_body_limit(cx: Cx, request_id: u64, max_body_size: usize) -> Self {
Self {
cx,
request_id,
dependency_cache: Arc::new(DependencyCache::new()),
dependency_overrides: Arc::new(DependencyOverrides::new()),
resolution_stack: Arc::new(ResolutionStack::new()),
cleanup_stack: Arc::new(CleanupStack::new()),
body_limit: BodyLimitConfig::new(max_body_size),
}
}
#[must_use]
pub fn with_overrides(cx: Cx, request_id: u64, overrides: Arc<DependencyOverrides>) -> Self {
Self {
cx,
request_id,
dependency_cache: Arc::new(DependencyCache::new()),
dependency_overrides: overrides,
resolution_stack: Arc::new(ResolutionStack::new()),
cleanup_stack: Arc::new(CleanupStack::new()),
body_limit: BodyLimitConfig::default(),
}
}
#[must_use]
pub fn with_overrides_and_body_limit(
cx: Cx,
request_id: u64,
overrides: Arc<DependencyOverrides>,
max_body_size: usize,
) -> Self {
Self {
cx,
request_id,
dependency_cache: Arc::new(DependencyCache::new()),
dependency_overrides: overrides,
resolution_stack: Arc::new(ResolutionStack::new()),
cleanup_stack: Arc::new(CleanupStack::new()),
body_limit: BodyLimitConfig::new(max_body_size),
}
}
#[must_use]
pub fn request_id(&self) -> u64 {
self.request_id
}
#[must_use]
pub fn dependency_cache(&self) -> &DependencyCache {
&self.dependency_cache
}
#[must_use]
pub fn dependency_overrides(&self) -> &DependencyOverrides {
&self.dependency_overrides
}
#[must_use]
pub fn resolution_stack(&self) -> &ResolutionStack {
&self.resolution_stack
}
#[must_use]
pub fn cleanup_stack(&self) -> &CleanupStack {
&self.cleanup_stack
}
#[must_use]
pub fn body_limit(&self) -> &BodyLimitConfig {
&self.body_limit
}
#[must_use]
pub fn max_body_size(&self) -> usize {
self.body_limit.max_size()
}
#[must_use]
pub fn region_id(&self) -> RegionId {
self.cx.region_id()
}
#[must_use]
pub fn task_id(&self) -> TaskId {
self.cx.task_id()
}
#[must_use]
pub fn budget(&self) -> Budget {
self.cx.budget()
}
#[must_use]
pub fn is_cancelled(&self) -> bool {
self.cx.is_cancel_requested()
}
pub fn checkpoint(&self) -> Result<(), CancelledError> {
self.cx.checkpoint().map_err(|_| CancelledError)
}
pub fn masked<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
self.cx.masked(f)
}
pub fn trace(&self, message: &str) {
self.cx.trace(message);
}
#[must_use]
pub fn cx(&self) -> &Cx {
&self.cx
}
}
/// Error returned by [`RequestContext::checkpoint`] when cancellation has
/// been requested for the request's task.
#[derive(Debug, Clone, Copy)]
pub struct CancelledError;
impl std::fmt::Display for CancelledError {
    /// Renders the fixed user-facing message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("request cancelled")
    }
}
// Marker impl: `CancelledError` carries no underlying source error.
impl std::error::Error for CancelledError {}
/// Conversion of a fallible computation result into an [`Outcome`].
pub trait IntoOutcome<T, E> {
    fn into_outcome(self) -> Outcome<T, E>;
}

impl<T, E> IntoOutcome<T, E> for Result<T, E> {
    fn into_outcome(self) -> Outcome<T, E> {
        match self {
            Ok(v) => Outcome::Ok(v),
            Err(e) => Outcome::Err(e),
        }
    }
}

/// Conversion of a cancellation-aware result into an [`Outcome`], mapping
/// [`CancelledError`] to [`Outcome::Cancelled`] rather than [`Outcome::Err`].
///
/// This is a separate trait with a distinct method name because a blanket
/// `IntoOutcome` impl for `Result<T, CancelledError>` overlaps the generic
/// `Result<T, E>` impl above (E0119: coherence instantiates the blanket impl
/// with `E = CancelledError` and does not use where-clauses such as
/// `E: Default` to prove disjointness), and reusing the `into_outcome` name
/// would make call sites ambiguous when both traits are in scope.
pub trait IntoCancelledOutcome<T, E> {
    fn into_cancelled_outcome(self) -> Outcome<T, E>;
}

impl<T, E> IntoCancelledOutcome<T, E> for Result<T, CancelledError> {
    fn into_cancelled_outcome(self) -> Outcome<T, E> {
        match self {
            Ok(v) => Outcome::Ok(v),
            Err(CancelledError) => Outcome::Cancelled(CancelReason::user("request cancelled")),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cancelled_error_display() {
        let err = CancelledError;
        assert_eq!(format!("{err}"), "request cancelled");
    }

    #[test]
    fn checkpoint_returns_error_when_cancel_requested() {
        let cx = Cx::for_testing();
        let ctx = RequestContext::new(cx, 1);
        ctx.cx().set_cancel_requested(true);
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn masked_defers_cancellation_at_checkpoint() {
        let cx = Cx::for_testing();
        let ctx = RequestContext::new(cx, 1);
        ctx.cx().set_cancel_requested(true);
        // Inside the masked section the pending cancellation must not be
        // observed; the first checkpoint after leaving it must see it.
        let result = ctx.masked(|| ctx.checkpoint());
        assert!(result.is_ok());
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn body_limit_config_default() {
        let config = BodyLimitConfig::default();
        assert_eq!(config.max_size(), DEFAULT_MAX_BODY_SIZE);
        assert_eq!(config.max_size(), 1024 * 1024);
    }

    #[test]
    fn body_limit_config_custom() {
        let config = BodyLimitConfig::new(512 * 1024);
        assert_eq!(config.max_size(), 512 * 1024);
    }

    #[test]
    fn request_context_default_body_limit() {
        let cx = Cx::for_testing();
        let ctx = RequestContext::new(cx, 1);
        assert_eq!(ctx.max_body_size(), DEFAULT_MAX_BODY_SIZE);
        assert_eq!(ctx.body_limit().max_size(), DEFAULT_MAX_BODY_SIZE);
    }

    #[test]
    fn request_context_custom_body_limit() {
        let cx = Cx::for_testing();
        let ctx = RequestContext::with_body_limit(cx, 1, 2 * 1024 * 1024);
        assert_eq!(ctx.max_body_size(), 2 * 1024 * 1024);
    }

    #[test]
    fn request_context_with_overrides_has_default_limit() {
        let cx = Cx::for_testing();
        let overrides = Arc::new(DependencyOverrides::new());
        let ctx = RequestContext::with_overrides(cx, 1, overrides);
        assert_eq!(ctx.max_body_size(), DEFAULT_MAX_BODY_SIZE);
    }

    #[test]
    fn request_context_with_overrides_and_custom_limit() {
        let cx = Cx::for_testing();
        let overrides = Arc::new(DependencyOverrides::new());
        let ctx = RequestContext::with_overrides_and_body_limit(cx, 1, overrides, 4 * 1024 * 1024);
        assert_eq!(ctx.max_body_size(), 4 * 1024 * 1024);
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn request_id_isolation_unique_per_context() {
        let cx1 = Cx::for_testing();
        let cx2 = Cx::for_testing();
        let cx3 = Cx::for_testing();
        let ctx1 = RequestContext::new(cx1, 100);
        let ctx2 = RequestContext::new(cx2, 200);
        let ctx3 = RequestContext::new(cx3, 300);
        assert_eq!(ctx1.request_id(), 100);
        assert_eq!(ctx2.request_id(), 200);
        assert_eq!(ctx3.request_id(), 300);
        assert_ne!(ctx1.request_id(), ctx2.request_id());
        assert_ne!(ctx2.request_id(), ctx3.request_id());
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn dependency_cache_isolation_per_request() {
        let cx1 = Cx::for_testing();
        let cx2 = Cx::for_testing();
        let ctx1 = RequestContext::new(cx1, 1);
        let ctx2 = RequestContext::new(cx2, 2);
        ctx1.dependency_cache().insert::<i32>(42);
        let value1 = ctx1.dependency_cache().get::<i32>();
        let value2 = ctx2.dependency_cache().get::<i32>();
        assert!(value1.is_some(), "ctx1 should have cached value");
        assert_eq!(value1.unwrap(), 42);
        assert!(value2.is_none(), "ctx2 should NOT have ctx1's cached value");
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn cleanup_stack_isolation_per_request() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let cleanup_counter1 = Arc::new(AtomicUsize::new(0));
        let cleanup_counter2 = Arc::new(AtomicUsize::new(0));
        let cx1 = Cx::for_testing();
        let cx2 = Cx::for_testing();
        let ctx1 = RequestContext::new(cx1, 1);
        let ctx2 = RequestContext::new(cx2, 2);
        {
            let counter = cleanup_counter1.clone();
            ctx1.cleanup_stack().push(Box::new(move || {
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }));
        }
        {
            let counter = cleanup_counter2.clone();
            ctx2.cleanup_stack().push(Box::new(move || {
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }));
        }
        futures_executor::block_on(ctx1.cleanup_stack().run_cleanups());
        assert_eq!(
            cleanup_counter1.load(Ordering::SeqCst),
            1,
            "ctx1 cleanup should have run"
        );
        assert_eq!(
            cleanup_counter2.load(Ordering::SeqCst),
            0,
            "ctx2 cleanup should NOT have run"
        );
        futures_executor::block_on(ctx2.cleanup_stack().run_cleanups());
        assert_eq!(
            cleanup_counter2.load(Ordering::SeqCst),
            1,
            "ctx2 cleanup should have run"
        );
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn cx_cancellation_isolation_per_request() {
        let cx1 = Cx::for_testing();
        let cx2 = Cx::for_testing();
        let cx3 = Cx::for_testing();
        let ctx1 = RequestContext::new(cx1, 1);
        let ctx2 = RequestContext::new(cx2, 2);
        let ctx3 = RequestContext::new(cx3, 3);
        assert!(ctx1.checkpoint().is_ok(), "ctx1 should not be cancelled");
        assert!(ctx2.checkpoint().is_ok(), "ctx2 should not be cancelled");
        assert!(ctx3.checkpoint().is_ok(), "ctx3 should not be cancelled");
        ctx2.cx().set_cancel_requested(true);
        assert!(
            ctx1.checkpoint().is_ok(),
            "ctx1 should still not be cancelled"
        );
        assert!(ctx2.checkpoint().is_err(), "ctx2 should be cancelled");
        assert!(
            ctx3.checkpoint().is_ok(),
            "ctx3 should still not be cancelled"
        );
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn body_limit_isolation_per_request() {
        let cx1 = Cx::for_testing();
        let cx2 = Cx::for_testing();
        let ctx1 = RequestContext::with_body_limit(cx1, 1, 1024);
        let ctx2 = RequestContext::with_body_limit(cx2, 2, 1024 * 1024);
        assert_eq!(ctx1.max_body_size(), 1024);
        assert_eq!(ctx2.max_body_size(), 1024 * 1024);
        assert_ne!(ctx1.max_body_size(), ctx2.max_body_size());
    }

    #[test]
    fn concurrent_requests_fully_isolated() {
        use std::thread;
        const NUM_REQUESTS: usize = 100;
        let results = Arc::new(parking_lot::Mutex::new(Vec::with_capacity(NUM_REQUESTS)));
        let handles: Vec<_> = (0..NUM_REQUESTS)
            .map(|i| {
                let results = results.clone();
                thread::spawn(move || {
                    let cx = Cx::for_testing();
                    let request_id = (i + 1) as u64 * 1000;
                    let ctx = RequestContext::new(cx, request_id);
                    ctx.dependency_cache().insert::<u64>(request_id);
                    let cached = ctx.dependency_cache().get::<u64>();
                    let retrieved = cached.unwrap_or(0);
                    results.lock().push((request_id, retrieved));
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        let results = results.lock();
        assert_eq!(results.len(), NUM_REQUESTS);
        for (request_id, retrieved) in results.iter() {
            assert_eq!(
                request_id, retrieved,
                "Request {request_id} should retrieve its own cached value, not another request's"
            );
        }
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn resolution_stack_isolation_per_request() {
        use crate::dependency::DependencyScope;
        let cx1 = Cx::for_testing();
        let cx2 = Cx::for_testing();
        let ctx1 = RequestContext::new(cx1, 1);
        let ctx2 = RequestContext::new(cx2, 2);
        ctx1.resolution_stack()
            .push::<i32>("i32", DependencyScope::Request);
        let cycle1 = ctx1.resolution_stack().check_cycle::<i32>("i32");
        assert!(cycle1.is_some(), "ctx1 should detect cycle for i32");
        let cycle2 = ctx2.resolution_stack().check_cycle::<i32>("i32");
        assert!(
            cycle2.is_none(),
            "ctx2 should NOT see ctx1's resolution stack"
        );
        ctx2.resolution_stack()
            .push::<i32>("i32", DependencyScope::Request);
        // One frame on each stack, independently.
        assert_eq!(ctx2.resolution_stack().depth(), 1);
        assert_eq!(ctx1.resolution_stack().depth(), 1);
        ctx1.resolution_stack().pop();
        ctx2.resolution_stack().pop();
        assert!(ctx1.resolution_stack().is_empty());
        assert!(ctx2.resolution_stack().is_empty());
    }
}