fast-steal 6.5.2

特别快的多线程任务调度器,支持超细颗粒度的任务窃取
Documentation
extern crate alloc;
use alloc::sync::{Arc, Weak};
use core::{fmt, ops::Range, sync::atomic::Ordering};
use portable_atomic::AtomicU128;

/// Handle to a `u64` range stored in a single 128-bit atomic word.
///
/// The state lives behind an [`Arc`], so cloning a `Task` yields another handle to
/// the same range rather than an independent copy.
#[derive(Debug, Clone)]
pub struct Task {
    /// Packed range: `start` in the high 64 bits, `end` in the low 64 bits.
    pub state: Arc<AtomicU128>,
}
/// Non-owning counterpart of [`Task`]: refers to the same atomic state through a
/// [`Weak`] pointer and must be upgraded before the range can be accessed.
#[derive(Debug, Clone)]
pub struct WeakTask {
    /// Weak reference to the packed range state shared with the originating [`Task`].
    pub state: Weak<AtomicU128>,
}

impl WeakTask {
    /// Tries to obtain a strong [`Task`] handle.
    ///
    /// Returns `None` when the underlying state can no longer be upgraded.
    #[must_use]
    pub fn upgrade(&self) -> Option<Task> {
        let state = Weak::upgrade(&self.state)?;
        Some(Task { state })
    }

    /// Number of strong references currently pointing at the shared state.
    #[must_use]
    pub fn strong_count(&self) -> usize {
        Weak::strong_count(&self.state)
    }

    /// Number of weak references to the shared state, as reported by [`Weak::weak_count`].
    #[must_use]
    pub fn weak_count(&self) -> usize {
        Weak::weak_count(&self.state)
    }
}

/// Unit error returned by [`Task::safe_add_start`] and [`Task::split_two`] when the
/// requested operation cannot be applied to the current range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}

impl Task {
    /// Packs `range` into one 128-bit word: `start` in the high 64 bits, `end` in
    /// the low 64 bits.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    const fn pack(range: Range<u64>) -> u128 {
        ((range.start as u128) << 64) | (range.end as u128)
    }

    /// Inverse of [`Self::pack`]: splits a 128-bit word back into `start..end`.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    const fn unpack(state: u128) -> Range<u64> {
        // Truncating to the low 64 bits is intentional; that half holds `end`.
        #[allow(clippy::cast_possible_truncation)]
        let end = state as u64;
        (state >> 64) as u64..end
    }

    /// Creates a task covering `range`.
    ///
    /// # Panics
    /// When `range.start > range.end`.
    #[must_use]
    pub fn new(range: Range<u64>) -> Self {
        assert!(range.start <= range.end);
        Self {
            state: Arc::new(AtomicU128::new(Self::pack(range))),
        }
    }

    /// Returns the current `start..end`, read with a single atomic load.
    #[must_use]
    pub fn get(&self) -> Range<u64> {
        let state = self.state.load(Ordering::Acquire);
        Self::unpack(state)
    }

    /// Returns the current start of the range.
    #[must_use]
    pub fn start(&self) -> u64 {
        self.get().start
    }

    /// Moves the start of the range forward to `start + bias`, clamped to the
    /// current end, and returns the span that was skipped over.
    ///
    /// # Errors
    /// Returns [`RangeError`] when `start + bias` overflows `u64`, or when the
    /// clamped target is not strictly greater than the current start (there is
    /// nothing to advance). Otherwise returns `old_start..new_start.min(end)`.
    pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
        let new_start = start.checked_add(bias).ok_or(RangeError)?;
        let mut old_state = self.state.load(Ordering::Acquire);
        loop {
            let mut range = Self::unpack(old_state);
            // Never move past the end as it is in this snapshot of the state.
            let new_start = new_start.min(range.end);
            if new_start <= range.start {
                break Err(RangeError);
            }
            let span = range.start..new_start;
            range.start = new_start;
            let new_state = Self::pack(range);
            match self.state.compare_exchange_weak(
                old_state,
                new_state,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break Ok(span),
                // Lost the race or failed spuriously: retry against the observed value.
                Err(x) => old_state = x,
            }
        }
    }

    /// Returns the current end of the range.
    #[must_use]
    pub fn end(&self) -> u64 {
        self.get().end
    }

    /// Returns the number of items left (`end - start`), or 0 if the state is
    /// inverted (`start > end`).
    #[must_use]
    pub fn remain(&self) -> u64 {
        let range = self.get();
        range.end.saturating_sub(range.start)
    }

    /// Splits the range in half: this task keeps `start..mid` and the upper half
    /// `mid..end` is returned.
    ///
    /// Returns `Ok(None)` without modifying the task when fewer than 2 items remain.
    ///
    /// # Errors
    /// Returns [`RangeError`] when the stored state has `start > end`.
    pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
        let mut old_state = self.state.load(Ordering::Acquire);
        loop {
            let range = Self::unpack(old_state);
            if range.start > range.end {
                return Err(RangeError);
            }
            // Cannot overflow: `start <= end` was checked just above.
            let mid = range.start + (range.end - range.start) / 2;
            if mid == range.start {
                return Ok(None);
            }
            let new_state = Self::pack(range.start..mid);
            match self.state.compare_exchange_weak(
                old_state,
                new_state,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Ok(Some(mid..range.end)),
                Err(x) => old_state = x,
            }
        }
    }

    /// Takes everything that is left: the task becomes `start..start` and the
    /// previous `start..end` is returned.
    ///
    /// Returns `None` when nothing remains. An inverted state (`start > end`) is
    /// treated as empty, matching [`Self::remain`], so an invalid range is never
    /// handed out.
    #[must_use]
    pub fn take(&self) -> Option<Range<u64>> {
        let mut old_state = self.state.load(Ordering::Acquire);
        loop {
            let range = Self::unpack(old_state);
            if range.start >= range.end {
                return None;
            }
            let new_state = Self::pack(range.start..range.start);
            match self.state.compare_exchange_weak(
                old_state,
                new_state,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Some(range),
                Err(x) => old_state = x,
            }
        }
    }

    /// Creates a [`WeakTask`] that refers to the same state without keeping it alive.
    #[must_use]
    pub fn downgrade(&self) -> WeakTask {
        WeakTask {
            state: Arc::downgrade(&self.state),
        }
    }

    /// Number of strong handles sharing this task's state.
    #[must_use]
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.state)
    }

    /// Number of weak handles referring to this task's state.
    #[must_use]
    pub fn weak_count(&self) -> usize {
        Arc::weak_count(&self.state)
    }
}
impl From<Range<u64>> for Task {
    fn from(value: Range<u64>) -> Self {
        Self::new(value)
    }
}

/// Identity comparison: two handles are equal when they point at the same shared
/// state allocation, regardless of the range stored in it.
impl PartialEq for Task {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (Arc::as_ptr(&self.state), Arc::as_ptr(&other.state));
        lhs == rhs
    }
}
// Pointer identity is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for Task {}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// A freshly built task reports the bounds and length of its range.
    #[test]
    fn test_new_task() {
        let fresh = Task::new(10..20);
        assert_eq!((fresh.start(), fresh.end()), (10, 20));
        assert_eq!(fresh.remain(), 10);
    }

    /// `remain` is the distance between start and end.
    #[test]
    fn test_remain() {
        assert_eq!(Task::new(10..25).remain(), 15);
    }

    /// Splitting five items keeps the lower two and hands out the upper three.
    #[test]
    fn test_split_two() {
        // Items 1, 2, 3, 4, 5.
        let owner = Task::new(1..6);
        let stolen = owner.split_two().unwrap().unwrap();
        assert_eq!((owner.start(), owner.end()), (1, 3));
        assert_eq!((stolen.start, stolen.end), (3, 6));
    }

    /// An empty task cannot be split and stays untouched.
    #[test]
    fn test_split_empty() {
        let owner = Task::new(1..1);
        let outcome = owner.split_two().unwrap();
        assert_eq!((owner.start(), owner.end()), (1, 1));
        assert_eq!(outcome, None);
    }

    /// A single remaining item cannot be split and stays untouched.
    #[test]
    fn test_split_one() {
        let owner = Task::new(1..2);
        let outcome = owner.split_two().unwrap();
        assert_eq!((owner.start(), owner.end()), (1, 2));
        assert_eq!(outcome, None);
    }
}