#![allow(missing_docs)]
use crate::runtime::scheduler::worker::WorkerId;
use crate::types::TaskId;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Counters maintained by [`WorkStealingChecker`] and returned by value from
/// `WorkStealingChecker::stats`.
#[derive(Debug, Default, Clone)]
pub struct StealingStats {
/// Number of steal attempts for which a `StealTracker` was handed out.
pub steal_attempts: u64,
/// Number of steals that were reported successful and passed the ownership check.
pub successful_steals: u64,
/// Number of times a steal failure was recorded.
pub failed_steals: u64,
/// Number of recorded `ViolationType::MultipleOwners` violations.
pub ownership_violations: u64,
/// Number of recorded `ViolationType::DoubleExecution` violations.
pub double_execution_violations: u64,
/// Number of recorded `ViolationType::LostWork` violations.
pub lost_work_violations: u64,
/// Smoothed steal latency in microseconds: the first successful steal sets
/// it directly, later ones blend in as `(3 * previous + sample) / 4`.
pub avg_steal_latency_us: u64,
}
/// Ownership state the checker records for each tracked task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OwnershipState {
/// The task belongs to the given worker (set when the task is queued and
/// again when a steal completes successfully).
Owned(WorkerId),
/// A steal from `from` to `to` has started and has not yet been resolved.
Stealing { from: WorkerId, to: WorkerId },
/// Execution of the task was reported complete.
Completed,
/// The task was cancelled.
/// NOTE(review): nothing in this file constructs this variant — confirm
/// whether it is set elsewhere.
Cancelled,
}
/// An invariant violation recorded by [`WorkStealingChecker`].
#[derive(Debug, Clone)]
pub enum ViolationType {
/// The recorded ownership state disagrees with the worker that is trying
/// to steal from, or execute, the task.
MultipleOwners {
/// Task whose ownership is in dispute.
task_id: TaskId,
/// Workers involved in the conflicting claim.
owners: Vec<WorkerId>,
},
/// Execution of a task was started while it was already marked as executing.
DoubleExecution {
/// Task that was started twice.
task_id: TaskId,
/// Worker already recorded as executing the task.
first_worker: WorkerId,
/// Worker that attempted the second start.
second_worker: WorkerId,
},
/// A steal was reported successful but the task was not in the matching
/// `OwnershipState::Stealing` state.
LostWork {
/// Task affected by the inconsistent steal.
task_id: TaskId,
/// Worker the steal claimed to take the task from.
last_owner: WorkerId,
},
/// A successful steal took longer than the checker's latency threshold.
SlowSteal {
/// Task that was stolen.
task_id: TaskId,
/// Worker the task was taken from.
from_worker: WorkerId,
/// Worker that took the task.
to_worker: WorkerId,
/// Measured time between steal start and reported success.
duration: Duration,
},
/// A task was dequeued in an unexpected order.
/// NOTE(review): nothing in this file constructs this variant — confirm
/// whether it is produced elsewhere.
OrderingViolation {
/// Task observed out of order.
task_id: TaskId,
/// Description of the order that was expected.
expected_order: String,
/// Description of the order that was observed.
actual_order: String,
},
}
/// Records task ownership, execution and steal events reported by scheduler
/// workers and collects any invariant violations it detects.
#[derive(Debug)]
pub struct WorkStealingChecker {
/// Last known ownership state of every tracked task.
task_owners: Arc<RwLock<HashMap<TaskId, OwnershipState>>>,
/// Tasks currently marked as executing, mapped to the executing worker.
executing_tasks: Arc<RwLock<HashMap<TaskId, WorkerId>>>,
/// Every violation recorded since construction or the last `reset`.
violations: Arc<RwLock<Vec<ViolationType>>>,
/// Aggregate counters exposed through `stats`.
stats: Arc<RwLock<StealingStats>>,
/// Source of the sequence numbers stored in `task_sequences`; incremented
/// once per queued task and set back to zero by `reset`.
sequence_counter: AtomicU64,
/// Sequence number assigned to each task when it was queued; the entry is
/// removed when the task completes.
task_sequences: Arc<RwLock<HashMap<TaskId, u64>>>,
/// When `false`, the public tracking methods return before touching any state.
enabled: bool,
}
impl Default for WorkStealingChecker {
fn default() -> Self {
Self::new()
}
}
impl WorkStealingChecker {
#[must_use]
pub fn new() -> Self {
Self {
task_owners: Arc::new(RwLock::new(HashMap::new())),
executing_tasks: Arc::new(RwLock::new(HashMap::new())),
violations: Arc::new(RwLock::new(Vec::new())),
stats: Arc::new(RwLock::new(StealingStats::default())),
sequence_counter: AtomicU64::new(0),
task_sequences: Arc::new(RwLock::new(HashMap::new())),
enabled: true,
}
}
#[must_use]
pub fn disabled() -> Self {
Self {
task_owners: Arc::new(RwLock::new(HashMap::new())),
executing_tasks: Arc::new(RwLock::new(HashMap::new())),
violations: Arc::new(RwLock::new(Vec::new())),
stats: Arc::new(RwLock::new(StealingStats::default())),
sequence_counter: AtomicU64::new(0),
task_sequences: Arc::new(RwLock::new(HashMap::new())),
enabled: false,
}
}
pub fn track_task_queued(&self, task_id: TaskId, worker_id: WorkerId) {
if !self.enabled {
return;
}
let sequence = self.sequence_counter.fetch_add(1, Ordering::Relaxed);
{
let mut owners = self.task_owners.write();
owners.insert(task_id, OwnershipState::Owned(worker_id));
}
{
let mut sequences = self.task_sequences.write();
sequences.insert(task_id, sequence);
}
}
pub fn track_steal_start(
&self,
task_id: TaskId,
from_worker: WorkerId,
to_worker: WorkerId,
) -> Option<StealTracker<'_>> {
if !self.enabled {
return None;
}
{
let mut owners = self.task_owners.write();
if let Some(state) = owners.get_mut(&task_id) {
match state {
OwnershipState::Owned(owner) if *owner == from_worker => {
*state = OwnershipState::Stealing {
from: from_worker,
to: to_worker,
};
}
_ => {
self.record_violation(ViolationType::MultipleOwners {
task_id,
owners: vec![from_worker, to_worker],
});
return None;
}
}
}
}
{
let mut stats = self.stats.write();
stats.steal_attempts += 1;
}
Some(StealTracker {
task_id,
from_worker,
to_worker,
start_time: Instant::now(),
checker: self,
})
}
fn track_steal_success(
&self,
task_id: TaskId,
from_worker: WorkerId,
to_worker: WorkerId,
duration: Duration,
) {
{
let mut owners = self.task_owners.write();
if let Some(state) = owners.get_mut(&task_id) {
match state {
OwnershipState::Stealing { from, to }
if *from == from_worker && *to == to_worker =>
{
*state = OwnershipState::Owned(to_worker);
}
_ => {
self.record_violation(ViolationType::LostWork {
task_id,
last_owner: from_worker,
});
return;
}
}
}
}
{
let mut stats = self.stats.write();
stats.successful_steals += 1;
let duration_us = duration.as_micros() as u64;
if stats.successful_steals == 1 {
stats.avg_steal_latency_us = duration_us;
} else {
stats.avg_steal_latency_us = (stats.avg_steal_latency_us * 3 + duration_us) / 4;
}
}
if duration > Duration::from_millis(1) {
self.record_violation(ViolationType::SlowSteal {
task_id,
from_worker,
to_worker,
duration,
});
}
}
fn track_steal_failure(&self, _task_id: TaskId, _from_worker: WorkerId, _to_worker: WorkerId) {
let mut stats = self.stats.write();
stats.failed_steals += 1;
}
pub fn track_task_execution_start(&self, task_id: TaskId, worker_id: WorkerId) {
if !self.enabled {
return;
}
{
let mut executing = self.executing_tasks.write();
if let Some(&existing_worker) = executing.get(&task_id) {
self.record_violation(ViolationType::DoubleExecution {
task_id,
first_worker: existing_worker,
second_worker: worker_id,
});
return;
}
executing.insert(task_id, worker_id);
}
{
let owners = self.task_owners.read();
if let Some(state) = owners.get(&task_id) {
match state {
OwnershipState::Owned(owner) if *owner != worker_id => {
self.record_violation(ViolationType::MultipleOwners {
task_id,
owners: vec![*owner, worker_id],
});
}
OwnershipState::Stealing { .. } => {
self.record_violation(ViolationType::MultipleOwners {
task_id,
owners: vec![worker_id],
});
}
_ => {
}
}
}
}
}
pub fn track_task_execution_complete(&self, task_id: TaskId, _worker_id: WorkerId) {
if !self.enabled {
return;
}
{
let mut executing = self.executing_tasks.write();
executing.remove(&task_id);
}
{
let mut owners = self.task_owners.write();
owners.insert(task_id, OwnershipState::Completed);
}
{
let mut sequences = self.task_sequences.write();
sequences.remove(&task_id);
}
}
fn record_violation(&self, violation: ViolationType) {
let mut violations = self.violations.write();
violations.push(violation);
let mut stats = self.stats.write();
match violations.last().unwrap() {
ViolationType::MultipleOwners { .. } => stats.ownership_violations += 1,
ViolationType::DoubleExecution { .. } => stats.double_execution_violations += 1,
ViolationType::LostWork { .. } => stats.lost_work_violations += 1,
ViolationType::SlowSteal { .. } => {} ViolationType::OrderingViolation { .. } => {} }
}
#[must_use]
pub fn stats(&self) -> StealingStats {
self.stats.read().clone()
}
#[must_use]
pub fn violations(&self) -> Vec<ViolationType> {
self.violations.read().clone()
}
pub fn reset(&self) {
self.violations.write().clear();
*self.stats.write() = StealingStats::default();
self.task_owners.write().clear();
self.executing_tasks.write().clear();
self.task_sequences.write().clear();
self.sequence_counter.store(0, Ordering::Relaxed);
}
pub fn validate_ordering(&self, task_id: TaskId, is_owner: bool, _worker_id: WorkerId) {
if !self.enabled {
return;
}
let sequences = self.task_sequences.read();
if let Some(&_task_sequence) = sequences.get(&task_id) {
let _expected_order = if is_owner { "LIFO" } else { "FIFO" };
drop(sequences);
}
}
#[must_use]
pub fn has_violations(&self) -> bool {
!self.violations.read().is_empty()
}
#[must_use]
pub fn violation_count(&self) -> usize {
self.violations.read().len()
}
}
/// Handle for one in-flight steal, returned by
/// `WorkStealingChecker::track_steal_start`.
///
/// Call `success` or `failure` to report the outcome. The `Drop` impl reports
/// a steal failure to the checker.
pub struct StealTracker<'a> {
/// Task being stolen.
task_id: TaskId,
/// Worker the task is being taken from.
from_worker: WorkerId,
/// Worker attempting to take the task.
to_worker: WorkerId,
/// When the tracker was created; used to measure steal latency.
start_time: Instant,
/// Checker that receives the outcome.
checker: &'a WorkStealingChecker,
}
impl StealTracker<'_> {
    /// Reports that the steal succeeded, using the time elapsed since the
    /// tracker was created as the steal latency.
    pub fn success(self) {
        let duration = self.start_time.elapsed();
        self.checker
            .track_steal_success(self.task_id, self.from_worker, self.to_worker, duration);
        // The outcome is already reported. Skip the destructor, which would
        // otherwise also count this steal as failed. Every field is `Copy` or
        // a shared reference, so forgetting `self` leaks nothing.
        std::mem::forget(self);
    }

    /// Reports that the steal failed.
    pub fn failure(self) {
        self.checker
            .track_steal_failure(self.task_id, self.from_worker, self.to_worker);
        // Skip the destructor so the failure is not counted a second time.
        std::mem::forget(self);
    }
}
/// Running the destructor reports the tracked steal as failed.
impl Drop for StealTracker<'_> {
    /// Forwards the tracked task and worker ids to the checker as a failure.
    fn drop(&mut self) {
        let checker = self.checker;
        let (task, victim, thief) = (self.task_id, self.from_worker, self.to_worker);
        checker.track_steal_failure(task, victim, thief);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Queue, start and complete on the owning worker must be violation-free.
    #[test]
    fn test_basic_ownership_tracking() {
        let checker = WorkStealingChecker::new();
        let worker1 = 1;
        let task = TaskId::new_for_test(100, 0);

        checker.track_task_queued(task, worker1);
        checker.track_task_execution_start(task, worker1);
        assert!(!checker.has_violations());

        checker.track_task_execution_complete(task, worker1);
        assert!(!checker.has_violations());
    }

    /// A steal from the owner followed by execution on the thief is valid.
    #[test]
    fn test_steal_operation() {
        let checker = WorkStealingChecker::new();
        let worker1 = 1;
        let worker2 = 2;
        let task = TaskId::new_for_test(100, 0);

        checker.track_task_queued(task, worker1);
        // Fail loudly if no tracker is returned instead of silently skipping
        // the success report.
        checker
            .track_steal_start(task, worker1, worker2)
            .expect("a steal from the owning worker should be tracked")
            .success();

        checker.track_task_execution_start(task, worker2);
        checker.track_task_execution_complete(task, worker2);

        let stats = checker.stats();
        assert_eq!(stats.steal_attempts, 1);
        assert_eq!(stats.successful_steals, 1);
        assert!(!checker.has_violations());
    }

    /// Starting a task that is already executing is a double execution.
    #[test]
    fn test_double_execution_detection() {
        let checker = WorkStealingChecker::new();
        let worker1 = 1;
        let worker2 = 2;
        let task = TaskId::new_for_test(100, 0);

        checker.track_task_queued(task, worker1);
        checker.track_task_execution_start(task, worker1);
        checker.track_task_execution_start(task, worker2);

        assert!(checker.has_violations());
        let violations = checker.violations();
        assert_eq!(violations.len(), 1);
        // `matches!` only yields a bool; it has to be asserted to check anything.
        assert!(matches!(
            violations[0],
            ViolationType::DoubleExecution { .. }
        ));
        assert_eq!(checker.stats().double_execution_violations, 1);
    }

    /// Stealing from a worker that does not own the task is rejected.
    #[test]
    fn test_ownership_violation_detection() {
        let checker = WorkStealingChecker::new();
        let worker1 = 1;
        let worker2 = 2;
        let worker3 = 3;
        let task = TaskId::new_for_test(100, 0);

        checker.track_task_queued(task, worker1);
        let tracker = checker.track_steal_start(task, worker2, worker3);

        assert!(tracker.is_none());
        assert!(checker.has_violations());
        assert_eq!(checker.stats().ownership_violations, 1);
    }
}