use crate::cx::cx::ObservabilityState;
use crate::observability::metrics::MetricsProvider;
use crate::observability::{LogCollector, ObservabilityConfig};
use crate::runtime::config::{LeakEscalation, ObligationLeakResponse};
use crate::runtime::io_driver::IoDriverHandle;
use crate::runtime::{BlockingPoolHandle, ObligationTable, RegionTable, TaskTable};
use crate::sync::ContendedMutex;
use crate::time::TimerDriverHandle;
use crate::trace::TraceBufferHandle;
use crate::trace::distributed::LogicalClockMode;
use crate::types::{CancelAttributionConfig, RegionId, TaskId, Time};
use crate::util::{ArenaIndex, EntropySource};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Observability configuration bundled with a log collector that is shared
/// across tasks: the collector is created once from the config and cloned
/// into each per-task [`ObservabilityState`].
#[derive(Debug, Clone)]
pub struct ShardedObservability {
    // Configuration used to build per-task observability state.
    config: ObservabilityConfig,
    // Collector created once from `config`; cloned per task (see `for_task`).
    collector: LogCollector,
}
impl ShardedObservability {
    /// Builds the bundle, creating the shared collector from `config` exactly
    /// once.
    #[must_use]
    pub fn new(config: ObservabilityConfig) -> Self {
        // Field order matters: the collector is derived from `config` before
        // `config` itself is moved into the struct.
        Self {
            collector: config.create_collector(),
            config,
        }
    }

    /// Produces the observability state for `task` inside `region`, handing it
    /// a clone of the shared collector.
    #[must_use]
    pub fn for_task(&self, region: RegionId, task: TaskId) -> ObservabilityState {
        let collector = self.collector.clone();
        ObservabilityState::new_with_config(region, task, &self.config, Some(collector))
    }

    /// Read access to the underlying configuration.
    #[must_use]
    pub fn config(&self) -> &ObservabilityConfig {
        &self.config
    }

    /// A clone of the shared log collector.
    #[must_use]
    pub fn collector(&self) -> LogCollector {
        self.collector.clone()
    }
}
/// Immutable runtime-wide configuration, shared by all shards via
/// `Arc<ShardedConfig>` (see [`ShardedState::config`]).
#[derive(Debug)]
pub struct ShardedConfig {
    /// Handle to the I/O driver, if one is installed.
    pub io_driver: Option<IoDriverHandle>,
    /// Handle to the timer driver, if one is installed.
    pub timer_driver: Option<TimerDriverHandle>,
    /// Which logical-clock scheme distributed tracing uses.
    pub logical_clock_mode: LogicalClockMode,
    /// Configuration for attributing cancellations to their cause.
    pub cancel_attribution: CancelAttributionConfig,
    /// Source of randomness for the runtime.
    pub entropy_source: Arc<dyn EntropySource>,
    /// Handle to the blocking thread pool, if one is configured.
    pub blocking_pool: Option<BlockingPoolHandle>,
    /// What to do when an obligation is detected as leaked.
    pub obligation_leak_response: ObligationLeakResponse,
    /// Optional escalation policy applied on repeated leaks.
    pub leak_escalation: Option<LeakEscalation>,
    /// Optional shared observability bundle for task-level logging.
    pub observability: Option<ShardedObservability>,
}
/// Central runtime state, split into independently-lockable shards
/// (tasks / regions / obligations) plus lock-free atomics for time, the leak
/// counter, and the root-region slot. Acquire multiple shards only through
/// [`ShardGuard`], which enforces the canonical lock order in debug builds.
pub struct ShardedState {
    /// Task table shard.
    pub tasks: ContendedMutex<TaskTable>,
    /// Region table shard.
    pub regions: ContendedMutex<RegionTable>,
    // Encoded root region: `ROOT_REGION_NONE` (0) when unset, otherwise
    // `encode_root_region(..)` (packed arena index+generation, offset by 1).
    root_region: AtomicU64,
    /// Obligation table shard.
    pub obligations: ContendedMutex<ObligationTable>,
    /// Number of obligation leaks observed so far.
    pub leak_count: AtomicU64,
    /// Handle to the trace buffer.
    pub trace: TraceBufferHandle,
    /// Metrics sink.
    pub metrics: Arc<dyn MetricsProvider>,
    /// Current virtual time in nanoseconds (written with `Release`, read with
    /// `Acquire`).
    pub now: AtomicU64,
    /// Shared immutable runtime configuration.
    pub config: Arc<ShardedConfig>,
}
impl std::fmt::Debug for ShardedState {
    /// Manual `Debug`: the mutex-guarded tables and the metrics trait object
    /// are rendered as placeholder strings (formatting them would require
    /// taking the locks / a `Debug` bound on the provider); the atomics are
    /// shown as point-in-time snapshot loads.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ShardedState")
            .field("tasks", &"<ContendedMutex<TaskTable>>")
            .field("regions", &"<ContendedMutex<RegionTable>>")
            // Decoded root region: `None` until `set_root_region` succeeds.
            .field("root_region", &self.root_region())
            .field("obligations", &"<ContendedMutex<ObligationTable>>")
            .field("leak_count", &self.leak_count.load(Ordering::Relaxed))
            .field("trace", &self.trace)
            .field("metrics", &"<dyn MetricsProvider>")
            .field("now", &self.now.load(Ordering::Relaxed))
            .field("config", &self.config)
            .finish()
    }
}
impl ShardedState {
    /// Creates a fresh state: empty tables, zeroed clock and leak counter,
    /// and no root region installed yet.
    #[must_use]
    pub fn new(
        trace: TraceBufferHandle,
        metrics: Arc<dyn MetricsProvider>,
        config: ShardedConfig,
    ) -> Self {
        Self {
            tasks: ContendedMutex::new("tasks", TaskTable::new()),
            regions: ContendedMutex::new("regions", RegionTable::new()),
            obligations: ContendedMutex::new("obligations", ObligationTable::new()),
            root_region: AtomicU64::new(ROOT_REGION_NONE),
            leak_count: AtomicU64::new(0),
            now: AtomicU64::new(0),
            config: Arc::new(config),
            trace,
            metrics,
        }
    }

    /// Current virtual time (`Acquire` load of the nanosecond counter).
    #[inline]
    #[must_use]
    pub fn current_time(&self) -> Time {
        let nanos = self.now.load(Ordering::Acquire);
        Time::from_nanos(nanos)
    }

    /// Publishes a new virtual time (`Release` store).
    #[inline]
    pub fn set_time(&self, time: Time) {
        self.now.store(time.as_nanos(), Ordering::Release);
    }

    /// Bumps the leak counter and returns the post-increment value.
    #[inline]
    pub fn increment_leak_count(&self) -> u64 {
        let previous = self.leak_count.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }

    /// Current leak count (relaxed snapshot).
    #[inline]
    #[must_use]
    pub fn leak_count(&self) -> u64 {
        self.leak_count.load(Ordering::Relaxed)
    }

    /// The installed root region, or `None` if none has been set yet.
    #[inline]
    #[must_use]
    pub fn root_region(&self) -> Option<RegionId> {
        decode_root_region(self.root_region.load(Ordering::Acquire))
    }

    /// Installs the root region exactly once via compare-and-swap; returns
    /// `false` (leaving the existing value untouched) if a root was already
    /// present.
    pub fn set_root_region(&self, region: RegionId) -> bool {
        self.root_region
            .compare_exchange(
                ROOT_REGION_NONE,
                encode_root_region(region),
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
    }

    /// A clone of the trace buffer handle.
    #[inline]
    #[must_use]
    pub fn trace_handle(&self) -> TraceBufferHandle {
        self.trace.clone()
    }

    /// A shared handle to the metrics provider.
    #[inline]
    #[must_use]
    pub fn metrics_provider(&self) -> Arc<dyn MetricsProvider> {
        Arc::clone(&self.metrics)
    }

    /// Shared runtime configuration.
    #[inline]
    #[must_use]
    pub fn config(&self) -> &Arc<ShardedConfig> {
        &self.config
    }

    /// A clone of the I/O driver handle, if one is configured.
    #[inline]
    #[must_use]
    pub fn io_driver_handle(&self) -> Option<IoDriverHandle> {
        self.config.io_driver.clone()
    }

    /// A clone of the timer driver handle, if one is configured.
    #[inline]
    #[must_use]
    pub fn timer_driver_handle(&self) -> Option<TimerDriverHandle> {
        self.config.timer_driver.clone()
    }
}
/// Sentinel meaning "no root region installed" in the atomic slot.
const ROOT_REGION_NONE: u64 = 0;

/// Packs a region ID (arena index + generation) into a nonzero `u64` suitable
/// for the atomic root-region slot. The packed value is offset by 1 so that it
/// can never collide with [`ROOT_REGION_NONE`].
///
/// # Panics
///
/// Panics if the packed value is `u64::MAX` (the `+ 1` would wrap to the
/// sentinel).
#[inline]
fn encode_root_region(region: RegionId) -> u64 {
    let arena = region.arena_index();
    let packed = (u64::from(arena.generation()) << 32) | u64::from(arena.index());
    assert!(packed != u64::MAX, "region ID too large for atomic storage");
    packed + 1
}

/// Inverse of [`encode_root_region`]; returns `None` for the sentinel value.
#[inline]
fn decode_root_region(encoded: u64) -> Option<RegionId> {
    if encoded == ROOT_REGION_NONE {
        return None;
    }
    let packed = encoded - 1;
    let arena = ArenaIndex::new((packed & 0xFFFF_FFFF) as u32, (packed >> 32) as u32);
    Some(RegionId::from_arena(arena))
}
/// RAII guard over some subset of the three state shards, acquired in the
/// canonical order (regions, then tasks, then obligations) and released in
/// reverse order on drop.
pub struct ShardGuard<'a> {
    /// Shared runtime configuration (always available, no lock needed).
    pub config: &'a Arc<ShardedConfig>,
    /// Held regions-table lock, if this guard acquired it.
    pub regions: Option<crate::sync::ContendedMutexGuard<'a, RegionTable>>,
    /// Held task-table lock, if this guard acquired it.
    pub tasks: Option<crate::sync::ContendedMutexGuard<'a, TaskTable>>,
    /// Held obligation-table lock, if this guard acquired it.
    pub obligations: Option<crate::sync::ContendedMutexGuard<'a, ObligationTable>>,
    // Number of shard locks this guard holds; used by `Drop` to pop the
    // thread-local lock-order stack (debug builds only).
    #[cfg(debug_assertions)]
    debug_locks: usize,
}
impl<'a> ShardGuard<'a> {
#[must_use]
pub fn tasks_only(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Tasks);
let tasks = shards
.tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Tasks);
Self {
config: &shards.config,
regions: None,
tasks: Some(tasks),
obligations: None,
#[cfg(debug_assertions)]
debug_locks: 1,
}
}
#[must_use]
pub fn regions_only(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
Self {
config: &shards.config,
regions: Some(regions),
tasks: None,
obligations: None,
#[cfg(debug_assertions)]
debug_locks: 1,
}
}
#[must_use]
pub fn obligations_only(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Obligations);
let obligations = shards
.obligations
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Obligations);
Self {
config: &shards.config,
regions: None,
tasks: None,
obligations: Some(obligations),
#[cfg(debug_assertions)]
debug_locks: 1,
}
}
#[must_use]
pub fn for_task_completed(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Tasks);
let tasks = shards
.tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Tasks);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Obligations);
let obligations = shards
.obligations
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Obligations);
Self {
config: &shards.config,
regions: Some(regions),
tasks: Some(tasks),
obligations: Some(obligations),
#[cfg(debug_assertions)]
debug_locks: 3,
}
}
#[must_use]
pub fn for_cancel(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Tasks);
let tasks = shards
.tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Tasks);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Obligations);
let obligations = shards
.obligations
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Obligations);
Self {
config: &shards.config,
regions: Some(regions),
tasks: Some(tasks),
obligations: Some(obligations),
#[cfg(debug_assertions)]
debug_locks: 3,
}
}
#[must_use]
pub fn for_obligation(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Obligations);
let obligations = shards
.obligations
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Obligations);
Self {
config: &shards.config,
regions: Some(regions),
tasks: None,
obligations: Some(obligations),
#[cfg(debug_assertions)]
debug_locks: 2,
}
}
#[must_use]
pub fn for_obligation_resolve(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Tasks);
let tasks = shards
.tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Tasks);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Obligations);
let obligations = shards
.obligations
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Obligations);
Self {
config: &shards.config,
regions: Some(regions),
tasks: Some(tasks),
obligations: Some(obligations),
#[cfg(debug_assertions)]
debug_locks: 3,
}
}
#[must_use]
pub fn for_spawn(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Tasks);
let tasks = shards
.tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Tasks);
Self {
config: &shards.config,
regions: Some(regions),
tasks: Some(tasks),
obligations: None,
#[cfg(debug_assertions)]
debug_locks: 2,
}
}
#[must_use]
pub fn all(shards: &'a ShardedState) -> Self {
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Regions);
let regions = shards
.regions
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Regions);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Tasks);
let tasks = shards
.tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Tasks);
#[cfg(debug_assertions)]
lock_order::before_lock(LockShard::Obligations);
let obligations = shards
.obligations
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
#[cfg(debug_assertions)]
lock_order::after_lock(LockShard::Obligations);
Self {
config: &shards.config,
regions: Some(regions),
tasks: Some(tasks),
obligations: Some(obligations),
#[cfg(debug_assertions)]
debug_locks: 3,
}
}
}
impl Drop for ShardGuard<'_> {
    /// Releases the held shard locks in reverse acquisition order
    /// (obligations, tasks, regions), then pops the debug lock-order stack.
    fn drop(&mut self) {
        drop(self.obligations.take());
        drop(self.tasks.take());
        drop(self.regions.take());
        #[cfg(debug_assertions)]
        lock_order::unlock_n(self.debug_locks);
    }
}
/// Identifies one of the three lockable shards; exists only in debug builds,
/// where it is used to track and assert the canonical lock-acquisition order
/// (Regions, then Tasks, then Obligations).
#[cfg(debug_assertions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum LockShard {
    Regions,
    Tasks,
    Obligations,
}
#[cfg(debug_assertions)]
impl LockShard {
    /// Numeric position in the canonical acquisition order; shards must be
    /// acquired in strictly increasing `order()`.
    const fn order(self) -> u8 {
        match self {
            Self::Regions => 0,
            Self::Tasks => 1,
            Self::Obligations => 2,
        }
    }
    /// Human-readable label used in panic messages and test assertions.
    ///
    /// NOTE(review): the letter prefixes (B/A/C) do not follow the numeric
    /// order above — presumably historical shard names. The test module
    /// asserts these exact strings, so do not "fix" them in isolation.
    const fn label(self) -> &'static str {
        match self {
            Self::Regions => "B:Regions",
            Self::Tasks => "A:Tasks",
            Self::Obligations => "C:Obligations",
        }
    }
}
#[cfg(debug_assertions)]
impl std::fmt::Display for LockShard {
    /// Formats the shard as its debug label (e.g. `"B:Regions"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.label())
    }
}
#[cfg(debug_assertions)]
pub(crate) mod lock_order {
    //! Thread-local tracking of which shards the current thread holds, used
    //! to assert the canonical acquisition order in debug builds. Callers
    //! bracket every lock acquisition with `before_lock` / `after_lock` and
    //! pop with `unlock_n` on release.
    use super::LockShard;
    use std::cell::RefCell;
    thread_local! {
        // Stack of shards currently held by this thread, in acquisition order.
        static HELD: RefCell<Vec<LockShard>> = const { RefCell::new(Vec::new()) };
    }
    /// Asserts that acquiring `next` while holding the current stack respects
    /// the canonical order; call immediately before taking the lock.
    pub fn before_lock(next: LockShard) {
        HELD.with(|held| {
            let held = held.borrow();
            if let Some(last) = held.last() {
                // Strict `<` also rejects re-acquiring a shard already held.
                debug_assert!(
                    last.order() < next.order(),
                    "lock order violation: holding {} (order {}) then acquiring {} (order {}); \
                    canonical order is B:Regions(0) → A:Tasks(1) → C:Obligations(2)",
                    last.label(),
                    last.order(),
                    next.label(),
                    next.order(),
                );
            }
        });
    }
    /// Records a successful acquisition; call immediately after the lock is
    /// actually taken.
    pub fn after_lock(locked: LockShard) {
        HELD.with(|held| {
            held.borrow_mut().push(locked);
        });
    }
    /// Pops `count` entries off the stack. Uses `try_with` (ignoring failure)
    /// so it is safe to run during thread teardown, when the thread-local may
    /// already have been destroyed.
    pub fn unlock_n(count: usize) {
        let _ = HELD.try_with(|held| {
            let mut held = held.borrow_mut();
            for _ in 0..count {
                held.pop();
            }
        });
    }
    /// Test-only: number of shards the current thread holds.
    #[cfg(test)]
    pub fn held_count() -> usize {
        HELD.with(|held| held.borrow().len())
    }
    /// Test-only: labels of the held shards, in acquisition order.
    #[cfg(test)]
    pub fn held_labels() -> Vec<&'static str> {
        HELD.with(|held| held.borrow().iter().map(|s| s.label()).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::metrics::NoOpMetrics;
    use crate::trace::TraceBufferHandle;
    use crate::util::OsEntropy;

    /// Minimal config with every optional subsystem disabled.
    fn test_config() -> ShardedConfig {
        ShardedConfig {
            io_driver: None,
            timer_driver: None,
            logical_clock_mode: LogicalClockMode::Lamport,
            cancel_attribution: CancelAttributionConfig::default(),
            entropy_source: Arc::new(OsEntropy),
            blocking_pool: None,
            obligation_leak_response: ObligationLeakResponse::Log,
            leak_escalation: None,
            observability: None,
        }
    }

    // --- ShardedState basics: construction, root region, time, leak count ---

    #[test]
    fn sharded_state_creation() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        assert!(state.root_region().is_none());
        assert_eq!(state.current_time(), Time::ZERO);
        assert_eq!(state.leak_count(), 0);
    }

    #[test]
    fn root_region_set_once() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let first = RegionId::from_arena(ArenaIndex::new(1, 0));
        let second = RegionId::from_arena(ArenaIndex::new(2, 0));
        // First install succeeds; the second CAS must fail and leave `first`.
        assert!(state.set_root_region(first));
        assert_eq!(state.root_region(), Some(first));
        assert!(!state.set_root_region(second));
        assert_eq!(state.root_region(), Some(first));
    }

    #[test]
    fn time_operations() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        state.set_time(Time::from_nanos(12345));
        assert_eq!(state.current_time(), Time::from_nanos(12345));
    }

    #[test]
    fn leak_count_increment() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        // increment returns the post-increment value.
        assert_eq!(state.increment_leak_count(), 1);
        assert_eq!(state.increment_leak_count(), 2);
        assert_eq!(state.leak_count(), 2);
    }

    // --- ShardGuard constructors: each acquires exactly the advertised set ---

    #[test]
    fn tasks_only_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::tasks_only(&state);
        assert!(guard.tasks.is_some());
        assert!(guard.regions.is_none());
        assert!(guard.obligations.is_none());
    }

    #[test]
    fn for_task_completed_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::for_task_completed(&state);
        assert!(guard.tasks.is_some());
        assert!(guard.regions.is_some());
        assert!(guard.obligations.is_some());
    }

    #[test]
    fn regions_only_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::regions_only(&state);
        assert!(guard.regions.is_some());
        assert!(guard.tasks.is_none());
        assert!(guard.obligations.is_none());
    }

    #[test]
    fn obligations_only_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::obligations_only(&state);
        assert!(guard.obligations.is_some());
        assert!(guard.regions.is_none());
        assert!(guard.tasks.is_none());
    }

    #[test]
    fn for_cancel_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::for_cancel(&state);
        assert!(guard.regions.is_some());
        assert!(guard.tasks.is_some());
        assert!(guard.obligations.is_some());
    }

    #[test]
    fn for_obligation_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::for_obligation(&state);
        assert!(guard.regions.is_some());
        assert!(guard.tasks.is_none());
        assert!(guard.obligations.is_some());
    }

    #[test]
    fn for_obligation_resolve_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::for_obligation_resolve(&state);
        assert!(guard.regions.is_some());
        assert!(guard.tasks.is_some());
        assert!(guard.obligations.is_some());
    }

    #[test]
    fn for_spawn_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::for_spawn(&state);
        assert!(guard.regions.is_some());
        assert!(guard.tasks.is_some());
        assert!(guard.obligations.is_none());
    }

    #[test]
    fn all_guard() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        let guard = ShardGuard::all(&state);
        assert!(guard.regions.is_some());
        assert!(guard.tasks.is_some());
        assert!(guard.obligations.is_some());
    }

    // --- Debug-build lock-order bookkeeping per guard constructor ---

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_held_count_tracks_acquisitions() {
        assert_eq!(lock_order::held_count(), 0);
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _guard = ShardGuard::for_task_completed(&state);
            assert_eq!(lock_order::held_count(), 3);
            assert_eq!(
                lock_order::held_labels(),
                vec!["B:Regions", "A:Tasks", "C:Obligations"]
            );
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_single_shard_tracking() {
        assert_eq!(lock_order::held_count(), 0);
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _guard = ShardGuard::tasks_only(&state);
            assert_eq!(lock_order::held_count(), 1);
            assert_eq!(lock_order::held_labels(), vec!["A:Tasks"]);
        }
        assert_eq!(lock_order::held_count(), 0);
        {
            let _guard = ShardGuard::regions_only(&state);
            assert_eq!(lock_order::held_count(), 1);
            assert_eq!(lock_order::held_labels(), vec!["B:Regions"]);
        }
        assert_eq!(lock_order::held_count(), 0);
        {
            let _guard = ShardGuard::obligations_only(&state);
            assert_eq!(lock_order::held_count(), 1);
            assert_eq!(lock_order::held_labels(), vec!["C:Obligations"]);
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_spawn_guard_tracking() {
        assert_eq!(lock_order::held_count(), 0);
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _guard = ShardGuard::for_spawn(&state);
            assert_eq!(lock_order::held_count(), 2);
            assert_eq!(lock_order::held_labels(), vec!["B:Regions", "A:Tasks"]);
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_cancel_guard_tracking() {
        assert_eq!(lock_order::held_count(), 0);
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _guard = ShardGuard::for_cancel(&state);
            assert_eq!(lock_order::held_count(), 3);
            assert_eq!(
                lock_order::held_labels(),
                vec!["B:Regions", "A:Tasks", "C:Obligations"]
            );
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_obligation_guard_tracking() {
        assert_eq!(lock_order::held_count(), 0);
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _guard = ShardGuard::for_obligation(&state);
            assert_eq!(lock_order::held_count(), 2);
            assert_eq!(
                lock_order::held_labels(),
                vec!["B:Regions", "C:Obligations"]
            );
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_obligation_resolve_guard_tracking() {
        assert_eq!(lock_order::held_count(), 0);
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _guard = ShardGuard::for_obligation_resolve(&state);
            assert_eq!(lock_order::held_count(), 3);
            assert_eq!(
                lock_order::held_labels(),
                vec!["B:Regions", "A:Tasks", "C:Obligations"]
            );
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    // --- Lock-order checker: every out-of-order acquisition must panic ---

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "lock order violation")]
    fn lock_order_violation_tasks_before_regions() {
        lock_order::before_lock(LockShard::Tasks);
        lock_order::after_lock(LockShard::Tasks);
        lock_order::before_lock(LockShard::Regions);
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "lock order violation")]
    fn lock_order_violation_obligations_before_tasks() {
        lock_order::before_lock(LockShard::Obligations);
        lock_order::after_lock(LockShard::Obligations);
        lock_order::before_lock(LockShard::Tasks);
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "lock order violation")]
    fn lock_order_violation_obligations_before_regions() {
        lock_order::before_lock(LockShard::Obligations);
        lock_order::after_lock(LockShard::Obligations);
        lock_order::before_lock(LockShard::Regions);
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "lock order violation")]
    fn lock_order_violation_duplicate_shard() {
        lock_order::before_lock(LockShard::Tasks);
        lock_order::after_lock(LockShard::Tasks);
        lock_order::before_lock(LockShard::Tasks);
    }

    // --- Lock-order checker: canonical sequences are accepted ---

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_valid_full_sequence() {
        lock_order::before_lock(LockShard::Regions);
        lock_order::after_lock(LockShard::Regions);
        lock_order::before_lock(LockShard::Tasks);
        lock_order::after_lock(LockShard::Tasks);
        lock_order::before_lock(LockShard::Obligations);
        lock_order::after_lock(LockShard::Obligations);
        assert_eq!(lock_order::held_count(), 3);
        assert_eq!(
            lock_order::held_labels(),
            vec!["B:Regions", "A:Tasks", "C:Obligations"]
        );
        lock_order::unlock_n(3);
        assert_eq!(lock_order::held_count(), 0);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_unlock_then_reacquire() {
        lock_order::before_lock(LockShard::Obligations);
        lock_order::after_lock(LockShard::Obligations);
        assert_eq!(lock_order::held_count(), 1);
        lock_order::unlock_n(1);
        assert_eq!(lock_order::held_count(), 0);
        // After a full release, acquiring an "earlier" shard is legal again.
        lock_order::before_lock(LockShard::Regions);
        lock_order::after_lock(LockShard::Regions);
        assert_eq!(lock_order::held_count(), 1);
        assert_eq!(lock_order::held_labels(), vec!["B:Regions"]);
        lock_order::unlock_n(1);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_partial_sequence_regions_tasks() {
        lock_order::before_lock(LockShard::Regions);
        lock_order::after_lock(LockShard::Regions);
        lock_order::before_lock(LockShard::Tasks);
        lock_order::after_lock(LockShard::Tasks);
        assert_eq!(lock_order::held_count(), 2);
        lock_order::unlock_n(2);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_partial_sequence_regions_obligations() {
        lock_order::before_lock(LockShard::Regions);
        lock_order::after_lock(LockShard::Regions);
        lock_order::before_lock(LockShard::Obligations);
        lock_order::after_lock(LockShard::Obligations);
        assert_eq!(lock_order::held_count(), 2);
        lock_order::unlock_n(2);
    }

    #[cfg(debug_assertions)]
    #[test]
    fn lock_order_partial_sequence_tasks_obligations() {
        lock_order::before_lock(LockShard::Tasks);
        lock_order::after_lock(LockShard::Tasks);
        lock_order::before_lock(LockShard::Obligations);
        lock_order::after_lock(LockShard::Obligations);
        assert_eq!(lock_order::held_count(), 2);
        lock_order::unlock_n(2);
    }

    // --- Metamorphic helpers and tests ---

    /// Canonical rank of a shard (mirrors `LockShard::order`).
    #[cfg(debug_assertions)]
    fn lock_rank(shard: LockShard) -> usize {
        match shard {
            LockShard::Regions => 0,
            LockShard::Tasks => 1,
            LockShard::Obligations => 2,
        }
    }

    /// Sorts labels into canonical order and drops duplicates.
    #[cfg(debug_assertions)]
    fn canonicalize_labels(mut labels: Vec<&'static str>) -> Vec<&'static str> {
        labels.sort_by_key(|label| match *label {
            "B:Regions" => 0,
            "A:Tasks" => 1,
            "C:Obligations" => 2,
            other => panic!("unexpected shard label: {other}"),
        });
        labels.dedup();
        labels
    }

    /// Snapshots the held-shard labels while `guard` is alive, then verifies
    /// the drop fully cleans up the thread-local stack.
    #[cfg(debug_assertions)]
    fn capture_labels(guard: ShardGuard<'_>) -> Vec<&'static str> {
        let labels = lock_order::held_labels();
        drop(guard);
        assert_eq!(lock_order::held_count(), 0);
        labels
    }

    #[cfg(debug_assertions)]
    #[test]
    fn metamorphic_lock_order_accepts_only_canonical_permutations() {
        use std::panic::{AssertUnwindSafe, catch_unwind};
        // Every permutation of 1-3 shards; only strictly increasing rank
        // sequences may succeed, everything else must panic.
        let sequences = [
            vec![LockShard::Regions],
            vec![LockShard::Tasks],
            vec![LockShard::Obligations],
            vec![LockShard::Regions, LockShard::Tasks],
            vec![LockShard::Tasks, LockShard::Regions],
            vec![LockShard::Regions, LockShard::Obligations],
            vec![LockShard::Obligations, LockShard::Regions],
            vec![LockShard::Tasks, LockShard::Obligations],
            vec![LockShard::Obligations, LockShard::Tasks],
            vec![LockShard::Regions, LockShard::Tasks, LockShard::Obligations],
            vec![LockShard::Regions, LockShard::Obligations, LockShard::Tasks],
            vec![LockShard::Tasks, LockShard::Regions, LockShard::Obligations],
            vec![LockShard::Tasks, LockShard::Obligations, LockShard::Regions],
            vec![LockShard::Obligations, LockShard::Regions, LockShard::Tasks],
            vec![LockShard::Obligations, LockShard::Tasks, LockShard::Regions],
        ];
        for sequence in sequences {
            let expected_ok = sequence
                .windows(2)
                .all(|pair| lock_rank(pair[0]) < lock_rank(pair[1]));
            let expected_labels: Vec<_> = sequence.iter().map(|shard| shard.label()).collect();
            let result = catch_unwind(AssertUnwindSafe(|| {
                for shard in &sequence {
                    lock_order::before_lock(*shard);
                    lock_order::after_lock(*shard);
                }
                let labels = lock_order::held_labels();
                lock_order::unlock_n(sequence.len());
                labels
            }));
            assert_eq!(
                result.is_ok(),
                expected_ok,
                "canonical lock-order expectation disagreed for {:?}",
                expected_labels
            );
            if let Ok(labels) = result {
                assert_eq!(labels, expected_labels);
            }
            // A panicking sequence may leave entries behind; clean up so the
            // next iteration starts from an empty stack.
            let leaked = lock_order::held_count();
            if leaked > 0 {
                lock_order::unlock_n(leaked);
            }
        }
    }

    #[cfg(debug_assertions)]
    #[test]
    fn metamorphic_guard_unions_match_canonical_supersets() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        // Each multi-shard constructor must hold exactly the canonical union
        // of the corresponding single-shard constructors.
        let regions = capture_labels(ShardGuard::regions_only(&state));
        let tasks = capture_labels(ShardGuard::tasks_only(&state));
        let obligations = capture_labels(ShardGuard::obligations_only(&state));
        let spawn_union = canonicalize_labels([regions.clone(), tasks.clone()].concat());
        let obligation_union = canonicalize_labels([regions.clone(), obligations.clone()].concat());
        let full_union = canonicalize_labels([regions, tasks, obligations].concat());
        assert_eq!(capture_labels(ShardGuard::for_spawn(&state)), spawn_union);
        assert_eq!(
            capture_labels(ShardGuard::for_obligation(&state)),
            obligation_union
        );
        assert_eq!(capture_labels(ShardGuard::for_cancel(&state)), full_union);
        assert_eq!(
            capture_labels(ShardGuard::for_task_completed(&state)),
            full_union
        );
        assert_eq!(
            capture_labels(ShardGuard::for_obligation_resolve(&state)),
            full_union
        );
        assert_eq!(capture_labels(ShardGuard::all(&state)), full_union);
    }

    // --- Concurrency smoke tests: guards must never deadlock ---

    #[test]
    fn concurrent_guard_access_no_deadlock() {
        use std::sync::Barrier;
        use std::thread;
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = Arc::new(ShardedState::new(trace, metrics, test_config()));
        let barrier = Arc::new(Barrier::new(4));
        let iterations = 100;
        let handles: Vec<_> = (0..4)
            .map(|thread_id| {
                let state = Arc::clone(&state);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..iterations {
                        match thread_id % 4 {
                            0 => {
                                let _g = ShardGuard::tasks_only(&state);
                            }
                            1 => {
                                let _g = ShardGuard::for_spawn(&state);
                            }
                            2 => {
                                let _g = ShardGuard::for_obligation(&state);
                            }
                            3 => {
                                let _g = ShardGuard::for_task_completed(&state);
                            }
                            _ => unreachable!(),
                        }
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }

    #[test]
    fn concurrent_mixed_guards_no_deadlock() {
        use std::sync::Barrier;
        use std::thread;
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = Arc::new(ShardedState::new(trace, metrics, test_config()));
        let barrier = Arc::new(Barrier::new(4));
        let iterations = 50;
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let state = Arc::clone(&state);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    // Cycle through every guard constructor on every thread.
                    for i in 0..iterations {
                        match i % 8 {
                            0 => {
                                let _g = ShardGuard::tasks_only(&state);
                            }
                            1 => {
                                let _g = ShardGuard::regions_only(&state);
                            }
                            2 => {
                                let _g = ShardGuard::obligations_only(&state);
                            }
                            3 => {
                                let _g = ShardGuard::for_spawn(&state);
                            }
                            4 => {
                                let _g = ShardGuard::for_cancel(&state);
                            }
                            5 => {
                                let _g = ShardGuard::for_obligation(&state);
                            }
                            6 => {
                                let _g = ShardGuard::for_obligation_resolve(&state);
                            }
                            7 => {
                                let _g = ShardGuard::all(&state);
                            }
                            _ => unreachable!(),
                        }
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }

    #[cfg(debug_assertions)]
    #[test]
    fn guard_drop_cleans_up_lock_order_state() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        assert_eq!(lock_order::held_count(), 0);
        {
            let _g = ShardGuard::all(&state);
            assert_eq!(lock_order::held_count(), 3);
        }
        assert_eq!(lock_order::held_count(), 0);
        {
            let _g = ShardGuard::obligations_only(&state);
            assert_eq!(lock_order::held_count(), 1);
        }
        assert_eq!(lock_order::held_count(), 0);
        {
            let _g = ShardGuard::regions_only(&state);
            assert_eq!(lock_order::held_count(), 1);
        }
        assert_eq!(lock_order::held_count(), 0);
    }

    // --- Root-region encoding roundtrips and overflow guard ---

    #[test]
    fn root_region_encoding_roundtrip_zero() {
        let region = RegionId::from_arena(ArenaIndex::new(0, 0));
        let encoded = encode_root_region(region);
        assert_ne!(encoded, ROOT_REGION_NONE, "encoded must differ from NONE");
        let decoded = decode_root_region(encoded);
        assert_eq!(decoded, Some(region));
    }

    #[test]
    fn root_region_encoding_roundtrip_large() {
        // Largest encodable value: one below the packed u64::MAX limit.
        let region = RegionId::from_arena(ArenaIndex::new(u32::MAX, u32::MAX - 1));
        let encoded = encode_root_region(region);
        let decoded = decode_root_region(encoded);
        assert_eq!(decoded, Some(region));
    }

    #[test]
    #[should_panic(expected = "region ID too large")]
    fn root_region_encoding_max_panics() {
        // (u32::MAX, u32::MAX) packs to u64::MAX, which the encoder rejects.
        let region = RegionId::from_arena(ArenaIndex::new(u32::MAX, u32::MAX));
        let _ = encode_root_region(region);
    }

    #[test]
    fn guard_drop_releases_in_reverse_order() {
        let trace = TraceBufferHandle::new(1024);
        let metrics: Arc<dyn MetricsProvider> = Arc::new(NoOpMetrics);
        let state = ShardedState::new(trace, metrics, test_config());
        {
            let _g = ShardGuard::all(&state);
        }
        // If the previous guard failed to release, this re-acquire would hang.
        let g = ShardGuard::obligations_only(&state);
        assert!(g.obligations.is_some());
    }
}