1use asupersync::types::CancelReason;
7use asupersync::{Budget, Cx, Outcome, RegionId, TaskId};
8use std::sync::Arc;
9
10use crate::dependency::{CleanupStack, DependencyCache, DependencyOverrides, ResolutionStack};
11
/// Default request-body size limit: `1024 * 1024` (1 MiB, assuming the limit
/// is interpreted as a byte count by its consumers).
pub const DEFAULT_MAX_BODY_SIZE: usize = 1024 * 1024;

/// Configuration for the maximum accepted request-body size.
///
/// Build one with [`BodyLimitConfig::new`], or use [`Default`], which yields
/// [`DEFAULT_MAX_BODY_SIZE`]. The type is `Copy` and comparable, so it can be
/// passed around by value and checked with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BodyLimitConfig {
    /// Maximum permitted body size.
    max_size: usize,
}

impl Default for BodyLimitConfig {
    /// Returns a config whose limit is [`DEFAULT_MAX_BODY_SIZE`].
    fn default() -> Self {
        Self::new(DEFAULT_MAX_BODY_SIZE)
    }
}

impl BodyLimitConfig {
    /// Creates a config with the given maximum body size.
    ///
    /// No validation is performed: any `usize`, including `0`, is stored as-is.
    /// This is a `const fn`, so limits can be declared in `const` items.
    #[must_use]
    pub const fn new(max_size: usize) -> Self {
        Self { max_size }
    }

    /// Returns the configured maximum body size.
    #[must_use]
    pub const fn max_size(&self) -> usize {
        self.max_size
    }
}
47
48#[derive(Debug, Clone)]
75pub struct RequestContext {
76 cx: Cx,
78 request_id: u64,
80 dependency_cache: Arc<DependencyCache>,
82 dependency_overrides: Arc<DependencyOverrides>,
84 resolution_stack: Arc<ResolutionStack>,
86 cleanup_stack: Arc<CleanupStack>,
88 body_limit: BodyLimitConfig,
90}
91
92impl RequestContext {
93 #[must_use]
99 pub fn new(cx: Cx, request_id: u64) -> Self {
100 Self {
101 cx,
102 request_id,
103 dependency_cache: Arc::new(DependencyCache::new()),
104 dependency_overrides: Arc::new(DependencyOverrides::new()),
105 resolution_stack: Arc::new(ResolutionStack::new()),
106 cleanup_stack: Arc::new(CleanupStack::new()),
107 body_limit: BodyLimitConfig::default(),
108 }
109 }
110
111 #[must_use]
116 pub fn with_body_limit(cx: Cx, request_id: u64, max_body_size: usize) -> Self {
117 Self {
118 cx,
119 request_id,
120 dependency_cache: Arc::new(DependencyCache::new()),
121 dependency_overrides: Arc::new(DependencyOverrides::new()),
122 resolution_stack: Arc::new(ResolutionStack::new()),
123 cleanup_stack: Arc::new(CleanupStack::new()),
124 body_limit: BodyLimitConfig::new(max_body_size),
125 }
126 }
127
128 #[must_use]
130 pub fn with_overrides(cx: Cx, request_id: u64, overrides: Arc<DependencyOverrides>) -> Self {
131 Self {
132 cx,
133 request_id,
134 dependency_cache: Arc::new(DependencyCache::new()),
135 dependency_overrides: overrides,
136 resolution_stack: Arc::new(ResolutionStack::new()),
137 cleanup_stack: Arc::new(CleanupStack::new()),
138 body_limit: BodyLimitConfig::default(),
139 }
140 }
141
142 #[must_use]
144 pub fn with_overrides_and_body_limit(
145 cx: Cx,
146 request_id: u64,
147 overrides: Arc<DependencyOverrides>,
148 max_body_size: usize,
149 ) -> Self {
150 Self {
151 cx,
152 request_id,
153 dependency_cache: Arc::new(DependencyCache::new()),
154 dependency_overrides: overrides,
155 resolution_stack: Arc::new(ResolutionStack::new()),
156 cleanup_stack: Arc::new(CleanupStack::new()),
157 body_limit: BodyLimitConfig::new(max_body_size),
158 }
159 }
160
161 #[must_use]
165 pub fn request_id(&self) -> u64 {
166 self.request_id
167 }
168
169 #[must_use]
171 pub fn dependency_cache(&self) -> &DependencyCache {
172 &self.dependency_cache
173 }
174
175 #[must_use]
177 pub fn dependency_overrides(&self) -> &DependencyOverrides {
178 &self.dependency_overrides
179 }
180
181 #[must_use]
183 pub fn resolution_stack(&self) -> &ResolutionStack {
184 &self.resolution_stack
185 }
186
187 #[must_use]
191 pub fn cleanup_stack(&self) -> &CleanupStack {
192 &self.cleanup_stack
193 }
194
195 #[must_use]
200 pub fn body_limit(&self) -> &BodyLimitConfig {
201 &self.body_limit
202 }
203
204 #[must_use]
208 pub fn max_body_size(&self) -> usize {
209 self.body_limit.max_size()
210 }
211
212 #[must_use]
218 pub fn region_id(&self) -> RegionId {
219 self.cx.region_id()
220 }
221
222 #[must_use]
224 pub fn task_id(&self) -> TaskId {
225 self.cx.task_id()
226 }
227
228 #[must_use]
234 pub fn budget(&self) -> Budget {
235 self.cx.budget()
236 }
237
238 #[must_use]
243 pub fn is_cancelled(&self) -> bool {
244 self.cx.is_cancel_requested()
245 }
246
247 pub fn checkpoint(&self) -> Result<(), CancelledError> {
269 self.cx.checkpoint().map_err(|_| CancelledError)
270 }
271
272 pub fn masked<F, R>(&self, f: F) -> R
288 where
289 F: FnOnce() -> R,
290 {
291 self.cx.masked(f)
292 }
293
294 pub fn trace(&self, message: &str) {
299 self.cx.trace(message);
300 }
301
302 #[must_use]
307 pub fn cx(&self) -> &Cx {
308 &self.cx
309 }
310}
311
/// Error returned by [`RequestContext::checkpoint`] when the request has been
/// cancelled. Displays as `"request cancelled"`.
///
/// Deliberately does NOT implement `Default`. The `IntoOutcome` impls in this
/// module rely on `CancelledError: !Default` to keep the blanket
/// `impl IntoOutcome<T, E> for Result<T, E>` from overlapping with the
/// `Result<T, CancelledError>` impl (which is bounded on `E: Default`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CancelledError;

impl std::fmt::Display for CancelledError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("request cancelled")
    }
}

/// `CancelledError` has no underlying cause, so `source()` is `None`.
impl std::error::Error for CancelledError {}
327
328pub trait IntoOutcome<T, E> {
333 fn into_outcome(self) -> Outcome<T, E>;
335}
336
337impl<T, E> IntoOutcome<T, E> for Result<T, E> {
338 fn into_outcome(self) -> Outcome<T, E> {
339 match self {
340 Ok(v) => Outcome::Ok(v),
341 Err(e) => Outcome::Err(e),
342 }
343 }
344}
345
346impl<T, E> IntoOutcome<T, E> for Result<T, CancelledError>
347where
348 E: Default,
349{
350 fn into_outcome(self) -> Outcome<T, E> {
351 match self {
352 Ok(v) => Outcome::Ok(v),
353 Err(CancelledError) => Outcome::Cancelled(CancelReason::user("request cancelled")),
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cancelled_error_display() {
        assert_eq!(CancelledError.to_string(), "request cancelled");
    }

    #[test]
    fn checkpoint_returns_error_when_cancel_requested() {
        let ctx = RequestContext::new(Cx::for_testing(), 1);
        assert!(!ctx.is_cancelled());

        ctx.cx().set_cancel_requested(true);

        assert!(ctx.is_cancelled());
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn masked_defers_cancellation_at_checkpoint() {
        let ctx = RequestContext::new(Cx::for_testing(), 1);
        ctx.cx().set_cancel_requested(true);

        // While masked, the pending cancellation must not surface...
        let result = ctx.masked(|| ctx.checkpoint());
        assert!(result.is_ok());
        // ...but it must surface again once the mask is lifted.
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn body_limit_config_default() {
        let config = BodyLimitConfig::default();
        assert_eq!(config.max_size(), DEFAULT_MAX_BODY_SIZE);
        assert_eq!(config.max_size(), 1024 * 1024);
    }

    #[test]
    fn body_limit_config_custom() {
        let config = BodyLimitConfig::new(512 * 1024);
        assert_eq!(config.max_size(), 512 * 1024);
    }

    #[test]
    fn request_context_default_body_limit() {
        let ctx = RequestContext::new(Cx::for_testing(), 1);
        assert_eq!(ctx.max_body_size(), DEFAULT_MAX_BODY_SIZE);
        assert_eq!(ctx.body_limit().max_size(), DEFAULT_MAX_BODY_SIZE);
    }

    #[test]
    fn request_context_custom_body_limit() {
        let ctx = RequestContext::with_body_limit(Cx::for_testing(), 1, 2 * 1024 * 1024);
        assert_eq!(ctx.max_body_size(), 2 * 1024 * 1024);
    }

    #[test]
    fn request_context_with_overrides_has_default_limit() {
        let cx = Cx::for_testing();
        let overrides = Arc::new(DependencyOverrides::new());
        let ctx = RequestContext::with_overrides(cx, 1, overrides);
        assert_eq!(ctx.max_body_size(), DEFAULT_MAX_BODY_SIZE);
    }

    #[test]
    fn request_context_with_overrides_and_custom_limit() {
        let cx = Cx::for_testing();
        let overrides = Arc::new(DependencyOverrides::new());
        let ctx = RequestContext::with_overrides_and_body_limit(cx, 1, overrides, 4 * 1024 * 1024);
        assert_eq!(ctx.max_body_size(), 4 * 1024 * 1024);
    }

    #[test]
    // ctx1/ctx2/ctx3 are intentionally parallel names for parallel fixtures.
    #[allow(clippy::similar_names)]
    fn request_id_isolation_unique_per_context() {
        let ctx1 = RequestContext::new(Cx::for_testing(), 100);
        let ctx2 = RequestContext::new(Cx::for_testing(), 200);
        let ctx3 = RequestContext::new(Cx::for_testing(), 300);

        assert_eq!(ctx1.request_id(), 100);
        assert_eq!(ctx2.request_id(), 200);
        assert_eq!(ctx3.request_id(), 300);

        assert_ne!(ctx1.request_id(), ctx2.request_id());
        assert_ne!(ctx1.request_id(), ctx3.request_id());
        assert_ne!(ctx2.request_id(), ctx3.request_id());
    }

    #[test]
    // ctx1/ctx2 are intentionally parallel names for parallel fixtures.
    #[allow(clippy::similar_names)]
    fn dependency_cache_isolation_per_request() {
        let ctx1 = RequestContext::new(Cx::for_testing(), 1);
        let ctx2 = RequestContext::new(Cx::for_testing(), 2);

        ctx1.dependency_cache().insert::<i32>(42);

        assert_eq!(
            ctx1.dependency_cache().get::<i32>(),
            Some(42),
            "ctx1 should have cached value"
        );
        assert!(
            ctx2.dependency_cache().get::<i32>().is_none(),
            "ctx2 should NOT have ctx1's cached value"
        );
    }

    #[test]
    fn cloned_context_shares_request_state() {
        let ctx = RequestContext::new(Cx::for_testing(), 9);
        let clone = ctx.clone();

        // The cache lives behind an Arc, so a write through the clone is
        // visible through the original.
        clone.dependency_cache().insert::<i32>(7);

        assert_eq!(clone.request_id(), ctx.request_id());
        assert_eq!(ctx.dependency_cache().get::<i32>(), Some(7));
    }

    #[test]
    // Paired counter/context names are intentionally parallel.
    #[allow(clippy::similar_names)]
    fn cleanup_stack_isolation_per_request() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let cleanup_counter1 = Arc::new(AtomicUsize::new(0));
        let cleanup_counter2 = Arc::new(AtomicUsize::new(0));

        let ctx1 = RequestContext::new(Cx::for_testing(), 1);
        let ctx2 = RequestContext::new(Cx::for_testing(), 2);

        {
            let counter = Arc::clone(&cleanup_counter1);
            ctx1.cleanup_stack().push(Box::new(move || {
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }));
        }

        {
            let counter = Arc::clone(&cleanup_counter2);
            ctx2.cleanup_stack().push(Box::new(move || {
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }));
        }

        futures_executor::block_on(ctx1.cleanup_stack().run_cleanups());

        assert_eq!(
            cleanup_counter1.load(Ordering::SeqCst),
            1,
            "ctx1 cleanup should have run"
        );
        assert_eq!(
            cleanup_counter2.load(Ordering::SeqCst),
            0,
            "ctx2 cleanup should NOT have run"
        );

        futures_executor::block_on(ctx2.cleanup_stack().run_cleanups());

        assert_eq!(
            cleanup_counter2.load(Ordering::SeqCst),
            1,
            "ctx2 cleanup should have run"
        );
    }

    #[test]
    // ctx1/ctx2/ctx3 are intentionally parallel names for parallel fixtures.
    #[allow(clippy::similar_names)]
    fn cx_cancellation_isolation_per_request() {
        let ctx1 = RequestContext::new(Cx::for_testing(), 1);
        let ctx2 = RequestContext::new(Cx::for_testing(), 2);
        let ctx3 = RequestContext::new(Cx::for_testing(), 3);

        assert!(ctx1.checkpoint().is_ok(), "ctx1 should not be cancelled");
        assert!(ctx2.checkpoint().is_ok(), "ctx2 should not be cancelled");
        assert!(ctx3.checkpoint().is_ok(), "ctx3 should not be cancelled");

        ctx2.cx().set_cancel_requested(true);

        assert!(
            ctx1.checkpoint().is_ok(),
            "ctx1 should still not be cancelled"
        );
        assert!(ctx2.checkpoint().is_err(), "ctx2 should be cancelled");
        assert!(
            ctx3.checkpoint().is_ok(),
            "ctx3 should still not be cancelled"
        );
    }

    #[test]
    // ctx1/ctx2 are intentionally parallel names for parallel fixtures.
    #[allow(clippy::similar_names)]
    fn body_limit_isolation_per_request() {
        let ctx1 = RequestContext::with_body_limit(Cx::for_testing(), 1, 1024);
        let ctx2 = RequestContext::with_body_limit(Cx::for_testing(), 2, 1024 * 1024);

        assert_eq!(ctx1.max_body_size(), 1024);
        assert_eq!(ctx2.max_body_size(), 1024 * 1024);
        assert_ne!(ctx1.max_body_size(), ctx2.max_body_size());
    }

    #[test]
    fn concurrent_requests_fully_isolated() {
        use std::collections::HashSet;
        use std::thread;

        const NUM_REQUESTS: usize = 100;
        let results = Arc::new(parking_lot::Mutex::new(Vec::with_capacity(NUM_REQUESTS)));

        let handles: Vec<_> = (0..NUM_REQUESTS)
            .map(|i| {
                let results = Arc::clone(&results);
                thread::spawn(move || {
                    let cx = Cx::for_testing();
                    let request_id =
                        u64::try_from(i + 1).expect("request index fits in u64") * 1000;
                    let ctx = RequestContext::new(cx, request_id);

                    ctx.dependency_cache().insert::<u64>(request_id);
                    let retrieved = ctx.dependency_cache().get::<u64>().unwrap_or(0);

                    results.lock().push((request_id, retrieved));
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        let results = results.lock();
        assert_eq!(results.len(), NUM_REQUESTS);

        for (request_id, retrieved) in results.iter() {
            assert_eq!(
                request_id, retrieved,
                "Request {request_id} should retrieve its own cached value, not another request's"
            );
        }

        let distinct: HashSet<u64> = results.iter().map(|&(id, _)| id).collect();
        assert_eq!(distinct.len(), NUM_REQUESTS, "request ids must be unique");
    }

    #[test]
    // ctx1/ctx2 are intentionally parallel names for parallel fixtures.
    #[allow(clippy::similar_names)]
    fn resolution_stack_isolation_per_request() {
        use crate::dependency::DependencyScope;

        let ctx1 = RequestContext::new(Cx::for_testing(), 1);
        let ctx2 = RequestContext::new(Cx::for_testing(), 2);

        ctx1.resolution_stack()
            .push::<i32>("i32", DependencyScope::Request);

        let cycle1 = ctx1.resolution_stack().check_cycle::<i32>("i32");
        assert!(cycle1.is_some(), "ctx1 should detect cycle for i32");

        let cycle2 = ctx2.resolution_stack().check_cycle::<i32>("i32");
        assert!(
            cycle2.is_none(),
            "ctx2 should NOT see ctx1's resolution stack"
        );

        ctx2.resolution_stack()
            .push::<i32>("i32", DependencyScope::Request);

        assert_eq!(ctx1.resolution_stack().depth(), 1);
        assert_eq!(ctx2.resolution_stack().depth(), 1);

        ctx1.resolution_stack().pop();
        ctx2.resolution_stack().pop();
        assert!(ctx1.resolution_stack().is_empty());
        assert!(ctx2.resolution_stack().is_empty());
    }

    #[test]
    fn into_outcome_passes_plain_results_through() {
        let ok: Result<i32, String> = Ok(5);
        assert!(matches!(ok.into_outcome(), Outcome::Ok(5)));

        let err: Result<i32, String> = Err("boom".to_owned());
        assert!(matches!(err.into_outcome(), Outcome::Err(e) if e == "boom"));
    }

    #[test]
    fn into_outcome_maps_cancelled_error_to_cancelled() {
        let ok: Result<i32, CancelledError> = Ok(5);
        let outcome: Outcome<i32, String> = ok.into_outcome();
        assert!(matches!(outcome, Outcome::Ok(5)));

        let cancelled: Result<i32, CancelledError> = Err(CancelledError);
        let outcome: Outcome<i32, String> = cancelled.into_outcome();
        assert!(matches!(outcome, Outcome::Cancelled(_)));
    }
}