Skip to main content

fast_steal/
task.rs

1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
6#[derive(Debug, Clone)]
7pub struct Task {
8    pub state: Arc<AtomicU128>,
9}
10#[derive(Debug, Clone)]
11pub struct WeakTask {
12    pub state: Weak<AtomicU128>,
13}
14
15impl WeakTask {
16    #[must_use]
17    pub fn upgrade(&self) -> Option<Task> {
18        self.state.upgrade().map(|state| Task { state })
19    }
20    #[must_use]
21    pub fn strong_count(&self) -> usize {
22        self.state.strong_count()
23    }
24    #[must_use]
25    pub fn weak_count(&self) -> usize {
26        self.state.weak_count()
27    }
28}
29
/// Error returned by [`Task`] range operations that cannot be carried out; see each
/// method's `# Errors` section for the exact conditions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
38
39impl Task {
40    #[allow(clippy::inline_always)]
41    #[inline(always)]
42    const fn pack(range: Range<u64>) -> u128 {
43        ((range.start as u128) << 64) | (range.end as u128)
44    }
45    #[allow(clippy::inline_always)]
46    #[inline(always)]
47    const fn unpack(state: u128) -> Range<u64> {
48        #[allow(clippy::cast_possible_truncation)]
49        let end = state as u64;
50        (state >> 64) as u64..end
51    }
52
53    /// # Panics
54    /// 当 range.start > range.end
55    #[must_use]
56    pub fn new(range: Range<u64>) -> Self {
57        assert!(range.start <= range.end);
58        Self {
59            state: Arc::new(AtomicU128::new(Self::pack(range))),
60        }
61    }
62    #[must_use]
63    pub fn get(&self) -> Range<u64> {
64        let state = self.state.load(Ordering::Acquire);
65        Self::unpack(state)
66    }
67    /// # Panics
68    /// 当 range.start > range.end
69    pub fn set(&self, range: Range<u64>) {
70        assert!(range.start <= range.end);
71        self.state.store(Self::pack(range), Ordering::Release);
72    }
73    #[must_use]
74    pub fn start(&self) -> u64 {
75        (self.state.load(Ordering::Acquire) >> 64) as u64
76    }
77    /// # Errors
78    /// 当 `start` + `bias` <= `old_start` 时返回 [`RangeError`]
79    /// 否则返回 `old_start..new_start.min(end)`
80    pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
81        let new_start = start.checked_add(bias).ok_or(RangeError)?;
82        let mut old_state = self.state.load(Ordering::Acquire);
83        loop {
84            let mut range = Self::unpack(old_state);
85            let new_start = new_start.min(range.end);
86            if new_start <= range.start {
87                break Err(RangeError);
88            }
89            let span = range.start..new_start;
90            range.start = new_start;
91            let new_state = Self::pack(range);
92            match self.state.compare_exchange_weak(
93                old_state,
94                new_state,
95                Ordering::AcqRel,
96                Ordering::Acquire,
97            ) {
98                Ok(_) => break Ok(span),
99                Err(x) => old_state = x,
100            }
101        }
102    }
103    #[must_use]
104    pub fn end(&self) -> u64 {
105        let state = self.state.load(Ordering::Acquire);
106        #[allow(clippy::cast_possible_truncation)]
107        let end = state as u64;
108        end
109    }
110    #[must_use]
111    pub fn remain(&self) -> u64 {
112        let range = self.get();
113        range.end.saturating_sub(range.start)
114    }
115    /// # Errors
116    /// 1. 当 start > end 时返回 [`RangeError`]
117    /// 2. 当 remain < 2 时返回 None 并且不会修改自己
118    pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
119        let mut old_state = self.state.load(Ordering::Acquire);
120        loop {
121            let range = Self::unpack(old_state);
122            if range.start > range.end {
123                return Err(RangeError);
124            }
125            let mid = range.start + (range.end - range.start) / 2;
126            if mid == range.start {
127                return Ok(None);
128            }
129            let new_state = Self::pack(range.start..mid);
130            match self.state.compare_exchange_weak(
131                old_state,
132                new_state,
133                Ordering::AcqRel,
134                Ordering::Acquire,
135            ) {
136                Ok(_) => return Ok(Some(mid..range.end)),
137                Err(x) => old_state = x,
138            }
139        }
140    }
141    #[must_use]
142    pub fn take(&self) -> Option<Range<u64>> {
143        let mut old_state = self.state.load(Ordering::Acquire);
144        loop {
145            let range = Self::unpack(old_state);
146            if range.start == range.end {
147                return None;
148            }
149            let new_state = Self::pack(range.start..range.start);
150            match self.state.compare_exchange_weak(
151                old_state,
152                new_state,
153                Ordering::AcqRel,
154                Ordering::Acquire,
155            ) {
156                Ok(_) => return Some(range),
157                Err(x) => old_state = x,
158            }
159        }
160    }
161    #[must_use]
162    pub fn downgrade(&self) -> WeakTask {
163        WeakTask {
164            state: Arc::downgrade(&self.state),
165        }
166    }
167    #[must_use]
168    pub fn strong_count(&self) -> usize {
169        Arc::strong_count(&self.state)
170    }
171    #[must_use]
172    pub fn weak_count(&self) -> usize {
173        Arc::weak_count(&self.state)
174    }
175}
176impl From<Range<u64>> for Task {
177    fn from(value: Range<u64>) -> Self {
178        Self::new(value)
179    }
180}
181
182impl PartialEq for Task {
183    fn eq(&self, other: &Self) -> bool {
184        Arc::ptr_eq(&self.state, &other.state)
185    }
186}
187impl Eq for Task {}
188
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6); // 1, 2, 3, 4, 5
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_set_get() {
        let task = Task::new(0..1);
        task.set(4..9);
        assert_eq!(task.get(), 4..9);
        assert_eq!(task.remain(), 5);
    }

    #[test]
    #[should_panic]
    fn test_set_inverted_panics() {
        let task = Task::new(10..20);
        // Built from runtime values so the reversed range is not a compile-time literal.
        task.set(task.end()..task.start());
    }

    #[test]
    fn test_take() {
        let task = Task::new(3..8);
        assert_eq!(task.take(), Some(3..8));
        // The task is left empty at its old start.
        assert_eq!(task.get(), 3..3);
        assert_eq!(task.take(), None);
    }

    #[test]
    fn test_safe_add_start() {
        let task = Task::new(0..10);
        assert_eq!(task.safe_add_start(0, 4), Ok(0..4));
        assert_eq!(task.start(), 4);
        // Advancing past the end is clamped to the end.
        assert_eq!(task.safe_add_start(4, 100), Ok(4..10));
        assert_eq!(task.start(), 10);
        // Nothing left to claim.
        assert_eq!(task.safe_add_start(10, 1), Err(RangeError));
        // `start + bias` overflows.
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
        assert_eq!(task.get(), 10..10);
    }

    #[test]
    fn test_safe_add_start_after_split() {
        let task = Task::new(0..10);
        assert_eq!(task.split_two().unwrap(), Some(5..10));
        // The claim is clamped to the end left behind by the split.
        assert_eq!(task.safe_add_start(0, 8), Ok(0..5));
        assert_eq!(task.remain(), 0);
    }

    #[test]
    fn test_from_and_identity() {
        let task: Task = (3..9).into();
        assert_eq!(task.get(), 3..9);
        let shared = task.clone();
        assert_eq!(task, shared);
        // Same range, different allocation: not equal.
        assert_ne!(task, Task::new(3..9));
    }

    #[test]
    fn test_downgrade_upgrade() {
        let task = Task::new(0..4);
        let weak = task.downgrade();
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(weak.weak_count(), 1);
        assert_eq!(task.weak_count(), 1);
        let upgraded = weak.upgrade().unwrap();
        assert_eq!(upgraded, task);
        assert_eq!(task.strong_count(), 2);
        drop(upgraded);
        drop(task);
        assert!(weak.upgrade().is_none());
        assert_eq!(weak.strong_count(), 0);
    }
}