#![allow(missing_docs)]
use crate::runtime::scheduler::worker::WorkerId;
use crate::types::TaskId;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Counters describing steal activity and the violations recorded by the
/// work-stealing checker.
#[derive(Clone, Debug, Default)]
pub struct StealingStats {
    /// Steal attempts accepted by `track_steal_start`; rejected attempts are
    /// not counted.
    pub steal_attempts: u64,
    /// Steals that were confirmed as successful.
    pub successful_steals: u64,
    /// Steals that were reported as failed, including trackers dropped
    /// without an explicit outcome.
    pub failed_steals: u64,
    /// Number of recorded `ViolationType::MultipleOwners` violations.
    pub ownership_violations: u64,
    /// Number of recorded `ViolationType::DoubleExecution` violations.
    pub double_execution_violations: u64,
    /// Number of recorded `ViolationType::LostWork` violations.
    pub lost_work_violations: u64,
    /// Smoothed steal latency in microseconds: the first successful steal
    /// sets the value, later ones are blended as `(old * 3 + new) / 4`.
    pub avg_steal_latency_us: u64,
}
/// Ownership state the checker keeps for each tracked task.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OwnershipState {
    /// The task currently belongs to the given worker.
    Owned(WorkerId),
    /// A steal has started and has not been resolved yet.
    Stealing {
        /// Worker the task is being taken from.
        from: WorkerId,
        /// Worker attempting to take the task.
        to: WorkerId,
    },
    /// Execution of the task was reported as complete.
    Completed,
    /// NOTE(review): never assigned anywhere in this file; presumably meant
    /// for tasks cancelled before running — confirm against callers.
    Cancelled,
}
/// A problem detected by the work-stealing checker.
#[derive(Clone, Debug)]
pub enum ViolationType {
    /// More than one worker laid claim to the same task (for example a steal
    /// from a non-owner, or execution started by a non-owner).
    MultipleOwners {
        /// Task the conflicting claims refer to.
        task_id: TaskId,
        /// Workers involved in the conflict.
        owners: Vec<WorkerId>,
    },
    /// A task was started while already executing, or its completion was
    /// reported by a worker other than the one executing it.
    DoubleExecution {
        /// Task that was executed twice.
        task_id: TaskId,
        /// Worker already recorded as executing the task.
        first_worker: WorkerId,
        /// Worker that issued the conflicting start or completion.
        second_worker: WorkerId,
    },
    /// An operation referred to a task the checker does not know about, or
    /// whose tracked state does not match the operation.
    LostWork {
        /// Task that could not be accounted for.
        task_id: TaskId,
        /// Worker named by the operation that triggered the report.
        last_owner: WorkerId,
    },
    /// A successful steal took longer than one millisecond.
    SlowSteal {
        /// Task that was stolen.
        task_id: TaskId,
        /// Worker the task was taken from.
        from_worker: WorkerId,
        /// Worker that took the task.
        to_worker: WorkerId,
        /// Measured time between steal start and success.
        duration: Duration,
    },
    /// NOTE(review): never constructed in this file; presumably reserved for
    /// queue-order (LIFO/FIFO) checks — confirm intended use.
    OrderingViolation {
        /// Task observed out of order.
        task_id: TaskId,
        /// Description of the order that was expected.
        expected_order: String,
        /// Description of the order that was observed.
        actual_order: String,
    },
}
/// Runtime checker that tracks task ownership across workers and records
/// violations such as conflicting owners, double execution and lost work.
#[derive(Debug)]
pub struct WorkStealingChecker {
    /// Ownership state per tracked task.
    task_owners: Arc<RwLock<HashMap<TaskId, OwnershipState>>>,
    /// Tasks currently executing, mapped to the worker running them.
    executing_tasks: Arc<RwLock<HashMap<TaskId, WorkerId>>>,
    /// Every violation recorded so far, in recording order.
    violations: Arc<RwLock<Vec<ViolationType>>>,
    /// Aggregate counters.
    stats: Arc<RwLock<StealingStats>>,
    /// Source of the sequence numbers handed out when tasks are queued.
    sequence_counter: AtomicU64,
    /// Sequence number assigned to each queued task; the entry is removed
    /// when the task completes.
    task_sequences: Arc<RwLock<HashMap<TaskId, u64>>>,
    /// When `false`, the public tracking methods return without doing work.
    enabled: bool,
}
impl Default for WorkStealingChecker {
    /// Builds an enabled checker, the same configuration `new()` produces.
    fn default() -> Self {
        Self::with_enabled(true)
    }
}
impl WorkStealingChecker {
fn with_enabled(enabled: bool) -> Self {
Self {
task_owners: Arc::new(RwLock::new(HashMap::new())),
executing_tasks: Arc::new(RwLock::new(HashMap::new())),
violations: Arc::new(RwLock::new(Vec::new())),
stats: Arc::new(RwLock::new(StealingStats::default())),
sequence_counter: AtomicU64::new(0),
task_sequences: Arc::new(RwLock::new(HashMap::new())),
enabled,
}
}
/// Creates a checker with tracking enabled.
#[must_use]
pub fn new() -> Self {
    const ENABLED: bool = true;
    Self::with_enabled(ENABLED)
}
/// Creates a checker whose public tracking methods return immediately
/// without recording anything.
#[must_use]
pub fn disabled() -> Self {
    const ENABLED: bool = false;
    Self::with_enabled(ENABLED)
}
/// Registers `task_id` as queued on `worker_id`.
///
/// The task's ownership becomes `Owned(worker_id)`, replacing any state
/// previously tracked for it, and the next sequence number is stored for it.
/// Does nothing when the checker is disabled.
pub fn track_task_queued(&self, task_id: TaskId, worker_id: WorkerId) {
    if self.enabled {
        let queued_at = self.sequence_counter.fetch_add(1, Ordering::Relaxed);
        // Each write guard is a temporary released at the end of its statement.
        self.task_owners
            .write()
            .insert(task_id, OwnershipState::Owned(worker_id));
        self.task_sequences.write().insert(task_id, queued_at);
    }
}
/// Begins tracking a steal of `task_id` from `from_worker` by `to_worker`.
///
/// Returns a [`StealTracker`] that must be resolved with `success()` or
/// `failure()`; dropping it unresolved is treated as a failure.
///
/// Returns `None` when the checker is disabled or the attempt is rejected.
/// A rejection records a violation and is not counted in `steal_attempts`:
/// - the task is currently executing, or its tracked state is not
///   `Owned(from_worker)`: `MultipleOwners`;
/// - the task is unknown: `LostWork`.
pub fn track_steal_start(
    &self,
    task_id: TaskId,
    from_worker: WorkerId,
    to_worker: WorkerId,
) -> Option<StealTracker<'_>> {
    if !self.enabled {
        return None;
    }

    let rejection = {
        // The executing-set read guard stays held while ownership is
        // updated, so an execution cannot start in between.
        let running = self.executing_tasks.read();
        if running.contains_key(&task_id) {
            Some(ViolationType::MultipleOwners {
                task_id,
                owners: vec![from_worker, to_worker],
            })
        } else {
            let mut ownership = self.task_owners.write();
            match ownership.get_mut(&task_id) {
                None => Some(ViolationType::LostWork {
                    task_id,
                    last_owner: from_worker,
                }),
                Some(slot) => {
                    let owned_by_victim = matches!(
                        &*slot,
                        OwnershipState::Owned(owner) if *owner == from_worker
                    );
                    if owned_by_victim {
                        *slot = OwnershipState::Stealing {
                            from: from_worker,
                            to: to_worker,
                        };
                        None
                    } else {
                        Some(ViolationType::MultipleOwners {
                            task_id,
                            owners: vec![from_worker, to_worker],
                        })
                    }
                }
            }
        }
    };

    match rejection {
        Some(violation) => {
            self.record_violation(violation);
            None
        }
        None => {
            self.stats.write().steal_attempts += 1;
            Some(StealTracker {
                task_id,
                from_worker,
                to_worker,
                start_time: Instant::now(),
                checker: self,
                completed: false,
            })
        }
    }
}
/// Finalises a steal of `task_id` from `from_worker` to `to_worker` that
/// took `duration`.
///
/// If the task is tracked as `Stealing { from_worker, to_worker }` it becomes
/// `Owned(to_worker)`, `successful_steals` is incremented and the smoothed
/// latency is updated (first sample as-is, afterwards `(old * 3 + new) / 4`).
/// A steal slower than one millisecond additionally records `SlowSteal`.
///
/// If the task is unknown or in any other state, a `LostWork` violation is
/// recorded and the statistics are left untouched.
fn track_steal_success(
    &self,
    task_id: TaskId,
    from_worker: WorkerId,
    to_worker: WorkerId,
    duration: Duration,
) {
    /// Successful steals slower than this are reported as `SlowSteal`.
    const SLOW_STEAL_THRESHOLD: Duration = Duration::from_millis(1);

    // Decide the outcome under the ownership lock, but release it before
    // touching the stats/violation locks so the guards are never nested.
    let transferred = {
        let mut owners = self.task_owners.write();
        match owners.get_mut(&task_id) {
            Some(state) => {
                let in_flight = matches!(
                    &*state,
                    OwnershipState::Stealing { from, to }
                        if *from == from_worker && *to == to_worker
                );
                if in_flight {
                    *state = OwnershipState::Owned(to_worker);
                }
                in_flight
            }
            None => false,
        }
    };

    if !transferred {
        self.record_violation(ViolationType::LostWork {
            task_id,
            last_owner: from_worker,
        });
        return;
    }

    {
        let mut stats = self.stats.write();
        stats.successful_steals += 1;
        // `as_micros` yields a u128; clamp first so the cast is lossless.
        let duration_us = duration.as_micros().min(u128::from(u64::MAX)) as u64;
        stats.avg_steal_latency_us = if stats.successful_steals == 1 {
            duration_us
        } else {
            // Blend in u128 so `old * 3 + new` cannot overflow. The quotient
            // is at most u64::MAX, so the cast back is lossless.
            let blended =
                (u128::from(stats.avg_steal_latency_us) * 3 + u128::from(duration_us)) / 4;
            blended as u64
        };
    }

    if duration > SLOW_STEAL_THRESHOLD {
        self.record_violation(ViolationType::SlowSteal {
            task_id,
            from_worker,
            to_worker,
            duration,
        });
    }
}
/// Rolls back a steal of `task_id` that `to_worker` attempted against
/// `from_worker`.
///
/// If the task is tracked as `Stealing { from_worker, to_worker }` ownership
/// returns to `Owned(from_worker)` and `failed_steals` is incremented.
/// If the task is unknown or in any other state, a `LostWork` violation is
/// recorded instead and `failed_steals` is left untouched.
fn track_steal_failure(&self, task_id: TaskId, from_worker: WorkerId, to_worker: WorkerId) {
    // Decide the outcome under the ownership lock, but release it before
    // touching the stats/violation locks so the guards are never nested.
    let restored = {
        let mut owners = self.task_owners.write();
        match owners.get_mut(&task_id) {
            Some(state) => {
                let in_flight = matches!(
                    &*state,
                    OwnershipState::Stealing { from, to }
                        if *from == from_worker && *to == to_worker
                );
                if in_flight {
                    *state = OwnershipState::Owned(from_worker);
                }
                in_flight
            }
            None => false,
        }
    };

    if restored {
        self.stats.write().failed_steals += 1;
    } else {
        self.record_violation(ViolationType::LostWork {
            task_id,
            last_owner: from_worker,
        });
    }
}
/// Records that `worker_id` is starting to execute `task_id`.
///
/// The task is added to the executing set only when it is not already
/// executing and is tracked as `Owned(worker_id)`. Otherwise the executing
/// set is left unchanged and one violation is recorded:
/// - already executing: `DoubleExecution`;
/// - owned by another worker, or mid-steal: `MultipleOwners`;
/// - completed, cancelled or unknown: `LostWork`.
///
/// Does nothing when the checker is disabled.
pub fn track_task_execution_start(&self, task_id: TaskId, worker_id: WorkerId) {
    if !self.enabled {
        return;
    }

    let violation = {
        // The executing-set write guard stays held while ownership is
        // inspected, so check and insert happen as one step.
        let mut running = self.executing_tasks.write();
        match running.get(&task_id).copied() {
            Some(first_worker) => Some(ViolationType::DoubleExecution {
                task_id,
                first_worker,
                second_worker: worker_id,
            }),
            None => {
                let ownership = self.task_owners.read();
                match ownership.get(&task_id) {
                    Some(OwnershipState::Owned(owner)) => {
                        if *owner == worker_id {
                            running.insert(task_id, worker_id);
                            None
                        } else {
                            Some(ViolationType::MultipleOwners {
                                task_id,
                                owners: vec![*owner, worker_id],
                            })
                        }
                    }
                    Some(OwnershipState::Stealing { from, to }) => {
                        Some(ViolationType::MultipleOwners {
                            task_id,
                            owners: vec![*from, *to, worker_id],
                        })
                    }
                    Some(OwnershipState::Completed | OwnershipState::Cancelled) | None => {
                        Some(ViolationType::LostWork {
                            task_id,
                            last_owner: worker_id,
                        })
                    }
                }
            }
        }
    };

    if let Some(found) = violation {
        self.record_violation(found);
    }
}
/// Records that `worker_id` finished executing `task_id`.
///
/// When `worker_id` is the worker recorded as executing the task, the task
/// leaves the executing set, its ownership becomes `Completed` and its
/// sequence number is forgotten. Otherwise nothing is changed and one
/// violation is recorded:
/// - another worker is executing the task: `DoubleExecution`;
/// - the task is not executing at all: `LostWork`.
///
/// Does nothing when the checker is disabled.
pub fn track_task_execution_complete(&self, task_id: TaskId, worker_id: WorkerId) {
    if !self.enabled {
        return;
    }

    let violation = {
        let mut running = self.executing_tasks.write();
        let active = running.get(&task_id).copied();
        if active == Some(worker_id) {
            running.remove(&task_id);
            None
        } else if let Some(first_worker) = active {
            Some(ViolationType::DoubleExecution {
                task_id,
                first_worker,
                second_worker: worker_id,
            })
        } else {
            Some(ViolationType::LostWork {
                task_id,
                last_owner: worker_id,
            })
        }
    };

    match violation {
        Some(found) => self.record_violation(found),
        None => {
            self.task_owners
                .write()
                .insert(task_id, OwnershipState::Completed);
            self.task_sequences.write().remove(&task_id);
        }
    }
}
/// Bumps the counter matching `violation` (if it has one) and appends the
/// violation to the recorded list.
fn record_violation(&self, violation: ViolationType) {
    {
        let mut counters = self.stats.write();
        match &violation {
            ViolationType::MultipleOwners { .. } => counters.ownership_violations += 1,
            ViolationType::DoubleExecution { .. } => counters.double_execution_violations += 1,
            ViolationType::LostWork { .. } => counters.lost_work_violations += 1,
            // These kinds have no dedicated counter in `StealingStats`.
            ViolationType::SlowSteal { .. } | ViolationType::OrderingViolation { .. } => {}
        }
        // The stats guard is released here, before the violations lock is taken.
    }
    self.violations.write().push(violation);
}
#[must_use]
pub fn stats(&self) -> StealingStats {
self.stats.read().clone()
}
/// Returns a copy of every violation recorded so far, in recording order.
#[must_use]
pub fn violations(&self) -> Vec<ViolationType> {
    let recorded = self.violations.read();
    recorded.to_vec()
}
/// Clears all recorded violations, statistics and tracking state, and
/// restarts sequence numbering at zero.
///
/// Each piece of state is reset under its own lock, one after another, so
/// the reset as a whole is not a single atomic step. Unlike the tracking
/// methods, this runs regardless of the `enabled` flag.
pub fn reset(&self) {
    {
        let mut recorded = self.violations.write();
        recorded.clear();
    }
    {
        let mut counters = self.stats.write();
        *counters = StealingStats::default();
    }
    {
        let mut ownership = self.task_owners.write();
        ownership.clear();
    }
    {
        let mut running = self.executing_tasks.write();
        running.clear();
    }
    {
        let mut sequences = self.task_sequences.write();
        sequences.clear();
    }
    self.sequence_counter.store(0, Ordering::Relaxed);
}
/// Placeholder for queue-order validation of `task_id`.
///
/// Currently this only looks up whether the task has a sequence number and
/// derives the expected order label ("LIFO" when `is_owner`, otherwise
/// "FIFO"); the label is discarded and no violation is ever recorded.
/// Does nothing when the checker is disabled.
///
/// NOTE(review): presumably meant to record
/// `ViolationType::OrderingViolation` once the comparison is implemented —
/// confirm the intended semantics.
pub fn validate_ordering(&self, task_id: TaskId, is_owner: bool, _worker_id: WorkerId) {
    if !self.enabled {
        return;
    }

    let sequences = self.task_sequences.read();
    let is_tracked = sequences.contains_key(&task_id);
    drop(sequences);

    if is_tracked {
        let _expected_order = if is_owner { "LIFO" } else { "FIFO" };
    }
}
/// Returns `true` when at least one violation has been recorded.
#[must_use]
pub fn has_violations(&self) -> bool {
    let recorded = self.violations.read();
    !recorded.is_empty()
}
/// Returns how many violations have been recorded.
#[must_use]
pub fn violation_count(&self) -> usize {
    let recorded = self.violations.read();
    recorded.len()
}
}
/// Handle for a steal that has started but whose outcome is not yet known.
///
/// Resolve it with `success()` or `failure()`. Dropping it without calling
/// either reports the steal as failed.
#[derive(Debug)]
pub struct StealTracker<'a> {
    /// Task being stolen.
    task_id: TaskId,
    /// Worker the task is being taken from.
    from_worker: WorkerId,
    /// Worker attempting to take the task.
    to_worker: WorkerId,
    /// When the steal started; used to measure steal latency on success.
    start_time: Instant,
    /// Checker that receives the outcome.
    checker: &'a WorkStealingChecker,
    /// Set once an outcome has been reported, so `Drop` stays silent.
    completed: bool,
}
impl StealTracker<'_> {
    /// Reports the steal as successful, passing the time elapsed since the
    /// steal started to the checker.
    pub fn success(mut self) {
        let duration = self.start_time.elapsed();
        // Mark the tracker resolved *before* calling into the checker: if
        // that call panics, `Drop` must not report a second, spurious
        // failure for the same steal while unwinding.
        self.completed = true;
        self.checker
            .track_steal_success(self.task_id, self.from_worker, self.to_worker, duration);
    }

    /// Reports the steal as failed so the checker can restore ownership to
    /// the original worker.
    pub fn failure(mut self) {
        // Same ordering as in `success`: resolve first so `Drop` never
        // reports the failure twice.
        self.completed = true;
        self.checker
            .track_steal_failure(self.task_id, self.from_worker, self.to_worker);
    }
}
impl Drop for StealTracker<'_> {
    /// Reports the steal as failed when the tracker is dropped without an
    /// outcome having been reported.
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        let checker = self.checker;
        checker.track_steal_failure(self.task_id, self.from_worker, self.to_worker);
    }
}
// Unit tests for the work-stealing checker. Each test drives the checker
// through a short sequence of tracking calls and asserts on the recorded
// statistics, violations and (where relevant) the internal tracking maps.
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
// Queue -> start -> complete on the owning worker records no violation.
#[test]
fn test_basic_ownership_tracking() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let _worker2 = 2;
let task = TaskId::new_for_test(100, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_start(task, worker1);
assert!(!checker.has_violations());
checker.track_task_execution_complete(task, worker1);
assert!(!checker.has_violations());
}
// A successful steal moves ownership to the thief, which can then run the
// task without any violation.
#[test]
fn test_steal_operation() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(100, 0);
checker.track_task_queued(task, worker1);
if let Some(tracker) = checker.track_steal_start(task, worker1, worker2) {
tracker.success();
}
checker.track_task_execution_start(task, worker2);
checker.track_task_execution_complete(task, worker2);
let stats = checker.stats();
assert_eq!(stats.successful_steals, 1);
assert_eq!(stats.failed_steals, 0);
assert!(!checker.has_violations());
}
// An explicitly failed steal hands ownership back to the original worker.
#[test]
fn test_failed_steal_restores_original_owner() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(101, 0);
checker.track_task_queued(task, worker1);
if let Some(tracker) = checker.track_steal_start(task, worker1, worker2) {
tracker.failure();
}
checker.track_task_execution_start(task, worker1);
checker.track_task_execution_complete(task, worker1);
let stats = checker.stats();
assert_eq!(stats.failed_steals, 1);
assert_eq!(stats.successful_steals, 0);
assert!(!checker.has_violations());
}
// Dropping an unresolved tracker counts as exactly one failed steal and
// leaves the task runnable by its original owner.
#[test]
fn test_dropped_tracker_counts_single_failure() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(102, 0);
checker.track_task_queued(task, worker1);
let tracker = checker
.track_steal_start(task, worker1, worker2)
.expect("steal tracker should be created");
drop(tracker);
let stats = checker.stats();
assert_eq!(stats.failed_steals, 1);
assert_eq!(stats.successful_steals, 0);
checker.track_task_execution_start(task, worker1);
checker.track_task_execution_complete(task, worker1);
assert!(!checker.has_violations());
}
// Starting a task that is already executing records DoubleExecution.
#[test]
fn test_double_execution_detection() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(100, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_start(task, worker1);
checker.track_task_execution_start(task, worker2);
assert!(checker.has_violations());
let violations = checker.violations();
assert_eq!(violations.len(), 1);
assert!(matches!(
violations[0],
ViolationType::DoubleExecution { .. }
));
}
// Stealing from a worker that does not own the task is rejected and
// recorded as a violation.
#[test]
fn test_ownership_violation_detection() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let worker3 = 3;
let task = TaskId::new_for_test(100, 0);
checker.track_task_queued(task, worker1);
let tracker = checker.track_steal_start(task, worker2, worker3);
assert!(tracker.is_none());
assert!(checker.has_violations());
}
// Stealing a task that was never queued is rejected as LostWork and is not
// counted as a steal attempt.
#[test]
fn test_steal_start_rejects_unknown_task() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(103, 0);
let tracker = checker.track_steal_start(task, worker1, worker2);
assert!(tracker.is_none());
let stats = checker.stats();
assert_eq!(stats.steal_attempts, 0);
assert_eq!(stats.successful_steals, 0);
assert_eq!(stats.failed_steals, 0);
assert_eq!(stats.lost_work_violations, 1);
let violations = checker.violations();
assert_eq!(violations.len(), 1);
assert!(matches!(
violations[0],
ViolationType::LostWork {
task_id,
last_owner
} if task_id == task && last_owner == worker1
));
}
// Stealing a task that is currently executing is rejected as
// MultipleOwners and leaves both tracking maps untouched.
#[test]
fn test_steal_start_rejects_currently_executing_task() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(104, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_start(task, worker1);
let tracker = checker.track_steal_start(task, worker1, worker2);
assert!(tracker.is_none());
let stats = checker.stats();
assert_eq!(stats.steal_attempts, 0);
assert_eq!(stats.successful_steals, 0);
assert_eq!(stats.failed_steals, 0);
assert_eq!(stats.ownership_violations, 1);
assert_eq!(stats.lost_work_violations, 0);
assert_eq!(checker.executing_tasks.read().get(&task), Some(&worker1));
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Owned(worker1))
);
let violations = checker.violations();
assert_eq!(violations.len(), 1);
assert!(matches!(
violations[0],
ViolationType::MultipleOwners { ref owners, task_id }
if task_id == task && owners == &vec![worker1, worker2]
));
}
// Reporting success for a task with no tracked state records LostWork and
// does not count a successful steal.
#[test]
fn test_steal_success_without_tracked_owner_is_not_counted() {
let checker = WorkStealingChecker::new();
let task = TaskId::new_for_test(105, 0);
checker.track_steal_success(task, 1, 2, Duration::ZERO);
let stats = checker.stats();
assert_eq!(stats.successful_steals, 0);
assert_eq!(stats.lost_work_violations, 1);
}
// Starting an unknown task records LostWork and adds nothing to the
// tracking maps.
#[test]
fn test_execution_of_unknown_task_detects_lost_work() {
let checker = WorkStealingChecker::new();
let task = TaskId::new_for_test(106, 0);
checker.track_task_execution_start(task, 7);
let stats = checker.stats();
assert_eq!(stats.double_execution_violations, 0);
assert_eq!(stats.lost_work_violations, 1);
let violations = checker.violations();
assert_eq!(violations.len(), 1);
assert!(matches!(
violations[0],
ViolationType::LostWork {
task_id,
last_owner
} if task_id == task && last_owner == 7
));
assert_eq!(checker.executing_tasks.read().get(&task), None);
assert_eq!(checker.task_owners.read().get(&task), None);
}
// A start from a non-owner is recorded as MultipleOwners but must not
// register the task as executing; the real owner can still run it.
#[test]
fn test_wrong_owner_execution_start_does_not_poison_active_execution() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(109, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_start(task, worker2);
let stats = checker.stats();
assert_eq!(stats.ownership_violations, 1);
assert_eq!(stats.double_execution_violations, 0);
assert_eq!(checker.executing_tasks.read().get(&task), None);
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Owned(worker1))
);
checker.track_task_execution_start(task, worker1);
checker.track_task_execution_complete(task, worker1);
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Completed)
);
}
// Restarting an already completed task records LostWork and leaves the
// task marked Completed and not executing.
#[test]
fn test_completed_task_execution_start_does_not_poison_active_execution() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let task = TaskId::new_for_test(110, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_start(task, worker1);
checker.track_task_execution_complete(task, worker1);
assert!(!checker.has_violations());
checker.track_task_execution_start(task, worker1);
let stats = checker.stats();
assert_eq!(stats.lost_work_violations, 1);
assert_eq!(stats.double_execution_violations, 0);
assert_eq!(checker.executing_tasks.read().get(&task), None);
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Completed)
);
}
// A completion reported by the wrong worker records DoubleExecution and
// keeps the real execution active until its own worker completes it.
#[test]
fn test_completion_from_wrong_worker_preserves_active_execution() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let worker2 = 2;
let task = TaskId::new_for_test(107, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_start(task, worker1);
checker.track_task_execution_complete(task, worker2);
let stats = checker.stats();
assert_eq!(stats.double_execution_violations, 1);
assert_eq!(stats.lost_work_violations, 0);
assert_eq!(checker.executing_tasks.read().get(&task), Some(&worker1));
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Owned(worker1))
);
checker.track_task_execution_complete(task, worker1);
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Completed)
);
}
// Completing a task that never started records LostWork and leaves its
// ownership unchanged.
#[test]
fn test_completion_without_execution_start_detects_lost_work() {
let checker = WorkStealingChecker::new();
let worker1 = 1;
let task = TaskId::new_for_test(108, 0);
checker.track_task_queued(task, worker1);
checker.track_task_execution_complete(task, worker1);
let stats = checker.stats();
assert_eq!(stats.lost_work_violations, 1);
assert_eq!(stats.double_execution_violations, 0);
assert_eq!(
checker.task_owners.read().get(&task),
Some(&OwnershipState::Owned(worker1))
);
}
}