Skip to main content

fast_steal/
task.rs

1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
6#[derive(Debug, Clone)]
7pub struct Task {
8    pub state: Arc<AtomicU128>,
9}
10#[derive(Debug, Clone)]
11pub struct WeakTask {
12    pub state: Weak<AtomicU128>,
13}
14
15impl WeakTask {
16    #[must_use]
17    pub fn upgrade(&self) -> Option<Task> {
18        self.state.upgrade().map(|state| Task { state })
19    }
20    #[must_use]
21    pub fn strong_count(&self) -> usize {
22        self.state.strong_count()
23    }
24    #[must_use]
25    pub fn weak_count(&self) -> usize {
26        self.state.weak_count()
27    }
28}
29
/// Error returned when an operation would break the `start <= end` range
/// invariant, or when computing a new bound overflows `u64`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
38
39impl Task {
40    #[allow(clippy::inline_always)]
41    #[inline(always)]
42    const fn pack(range: Range<u64>) -> u128 {
43        ((range.start as u128) << 64) | (range.end as u128)
44    }
45    #[allow(clippy::inline_always)]
46    #[inline(always)]
47    const fn unpack(state: u128) -> Range<u64> {
48        #[allow(clippy::cast_possible_truncation)]
49        let end = state as u64;
50        (state >> 64) as u64..end
51    }
52
53    /// # Panics
54    /// 当 range.start > range.end
55    #[must_use]
56    pub fn new(range: Range<u64>) -> Self {
57        assert!(range.start <= range.end);
58        Self {
59            state: Arc::new(AtomicU128::new(Self::pack(range))),
60        }
61    }
62    #[must_use]
63    pub fn get(&self) -> Range<u64> {
64        let state = self.state.load(Ordering::Acquire);
65        Self::unpack(state)
66    }
67    #[must_use]
68    pub fn start(&self) -> u64 {
69        (self.state.load(Ordering::Acquire) >> 64) as u64
70    }
71    /// # Errors
72    /// 当 `start` + `bias` <= `old_start` 时返回 [`RangeError`]
73    /// 否则返回 `old_start..new_start.min(end)`
74    pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
75        let new_start = start.checked_add(bias).ok_or(RangeError)?;
76        let mut old_state = self.state.load(Ordering::Acquire);
77        loop {
78            let mut range = Self::unpack(old_state);
79            let new_start = new_start.min(range.end);
80            if new_start <= range.start {
81                break Err(RangeError);
82            }
83            let span = range.start..new_start;
84            range.start = new_start;
85            let new_state = Self::pack(range);
86            match self.state.compare_exchange_weak(
87                old_state,
88                new_state,
89                Ordering::AcqRel,
90                Ordering::Acquire,
91            ) {
92                Ok(_) => break Ok(span),
93                Err(x) => old_state = x,
94            }
95        }
96    }
97    #[must_use]
98    pub fn end(&self) -> u64 {
99        let state = self.state.load(Ordering::Acquire);
100        #[allow(clippy::cast_possible_truncation)]
101        let end = state as u64;
102        end
103    }
104    #[must_use]
105    pub fn remain(&self) -> u64 {
106        let range = self.get();
107        range.end.saturating_sub(range.start)
108    }
109    /// # Errors
110    /// 1. 当 start > end 时返回 [`RangeError`]
111    /// 2. 当 remain < 2 时返回 None 并且不会修改自己
112    pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
113        let mut old_state = self.state.load(Ordering::Acquire);
114        loop {
115            let range = Self::unpack(old_state);
116            if range.start > range.end {
117                return Err(RangeError);
118            }
119            let mid = range.start + (range.end - range.start) / 2;
120            if mid == range.start {
121                return Ok(None);
122            }
123            let new_state = Self::pack(range.start..mid);
124            match self.state.compare_exchange_weak(
125                old_state,
126                new_state,
127                Ordering::AcqRel,
128                Ordering::Acquire,
129            ) {
130                Ok(_) => return Ok(Some(mid..range.end)),
131                Err(x) => old_state = x,
132            }
133        }
134    }
135    #[must_use]
136    pub fn take(&self) -> Option<Range<u64>> {
137        let mut old_state = self.state.load(Ordering::Acquire);
138        loop {
139            let range = Self::unpack(old_state);
140            if range.start == range.end {
141                return None;
142            }
143            let new_state = Self::pack(range.start..range.start);
144            match self.state.compare_exchange_weak(
145                old_state,
146                new_state,
147                Ordering::AcqRel,
148                Ordering::Acquire,
149            ) {
150                Ok(_) => return Some(range),
151                Err(x) => old_state = x,
152            }
153        }
154    }
155    #[must_use]
156    pub fn downgrade(&self) -> WeakTask {
157        WeakTask {
158            state: Arc::downgrade(&self.state),
159        }
160    }
161    #[must_use]
162    pub fn strong_count(&self) -> usize {
163        Arc::strong_count(&self.state)
164    }
165    #[must_use]
166    pub fn weak_count(&self) -> usize {
167        Arc::weak_count(&self.state)
168    }
169}
170impl From<Range<u64>> for Task {
171    fn from(value: Range<u64>) -> Self {
172        Self::new(value)
173    }
174}
175
176impl PartialEq for Task {
177    fn eq(&self, other: &Self) -> bool {
178        Arc::ptr_eq(&self.state, &other.state)
179    }
180}
181impl Eq for Task {}
182
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// A freshly built task reports the bounds and length it was given.
    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!((task.start(), task.end()), (10, 20));
        assert_eq!(task.remain(), 10);
    }

    /// `remain` is the distance between the two bounds.
    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    /// Splitting five units keeps the lower two and hands out the upper three.
    #[test]
    fn test_split_two() {
        // Covers the units 1, 2, 3, 4, 5.
        let task = Task::new(1..6);
        let stolen = task.split_two().unwrap().unwrap();
        assert_eq!((task.start(), task.end()), (1, 3));
        assert_eq!(stolen, 3..6);
    }

    /// An empty task cannot be split and stays untouched.
    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let outcome = task.split_two().unwrap();
        assert_eq!((task.start(), task.end()), (1, 1));
        assert!(outcome.is_none());
    }

    /// A task with a single unit cannot be split and stays untouched.
    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let outcome = task.split_two().unwrap();
        assert_eq!((task.start(), task.end()), (1, 2));
        assert!(outcome.is_none());
    }
}