Skip to main content

fast_steal/
task.rs

1extern crate alloc;
2use alloc::sync::Arc;
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
/// A half-open `start..end` range of `u64` kept in one shared `AtomicU128`.
///
/// `start` lives in the high 64 bits and `end` in the low 64 bits, so both
/// bounds are always read and written together by a single atomic operation.
/// Cloning copies the `Arc`, so every clone observes and mutates the same
/// underlying range.
#[derive(Debug, Clone)]
pub struct Task {
    /// Packed range: `(start << 64) | end`.
    state: Arc<AtomicU128>,
}
10
/// Error returned when an update would leave a range with `start > end`, or
/// when adding to one of its bounds overflows `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
19
20impl Task {
21    #[inline(always)]
22    fn pack(range: Range<u64>) -> u128 {
23        ((range.start as u128) << 64) | (range.end as u128)
24    }
25    #[inline(always)]
26    fn unpack(state: u128) -> Range<u64> {
27        (state >> 64) as u64..state as u64
28    }
29
30    pub fn new(range: Range<u64>) -> Self {
31        assert!(range.start <= range.end);
32        Self {
33            state: Arc::new(AtomicU128::new(Self::pack(range))),
34        }
35    }
36    pub fn get(&self) -> Range<u64> {
37        let state = self.state.load(Ordering::Acquire);
38        Self::unpack(state)
39    }
40    pub fn set(&self, range: Range<u64>) {
41        assert!(range.start <= range.end);
42        self.state.store(Self::pack(range), Ordering::Release);
43    }
44    pub fn start(&self) -> u64 {
45        (self.state.load(Ordering::Acquire) >> 64) as u64
46    }
47    /// 当 start > end 时返回 RangeError
48    pub fn set_start(&self, start: u64) -> Result<(), RangeError> {
49        let mut old_state = self.state.load(Ordering::Acquire);
50        loop {
51            let mut range = Self::unpack(old_state);
52            range.start = start;
53            if range.start > range.end {
54                break Err(RangeError);
55            }
56            let new_state = Self::pack(range);
57            if old_state == new_state {
58                return Ok(());
59            }
60            match self.state.compare_exchange_weak(
61                old_state,
62                new_state,
63                Ordering::AcqRel,
64                Ordering::Acquire,
65            ) {
66                Ok(_) => break Ok(()),
67                Err(x) => old_state = x,
68            }
69        }
70    }
71    /// 1. 当 start + value 溢出时返回 RangeError
72    /// 2. 当 start + value > end 时返回 RangeError
73    pub fn fetch_add_start(&self, value: u64) -> Result<u64, RangeError> {
74        let mut old_state = self.state.load(Ordering::Acquire);
75        loop {
76            let mut range = Self::unpack(old_state);
77            let old_start = range.start;
78            range.start = range.start.checked_add(value).ok_or(RangeError)?;
79            if range.start > range.end {
80                break Err(RangeError);
81            }
82            let new_state = Self::pack(range);
83            match self.state.compare_exchange_weak(
84                old_state,
85                new_state,
86                Ordering::AcqRel,
87                Ordering::Acquire,
88            ) {
89                Ok(_) => break Ok(old_start),
90                Err(x) => old_state = x,
91            }
92        }
93    }
94    pub fn end(&self) -> u64 {
95        self.state.load(Ordering::Acquire) as u64
96    }
97    /// 当 end < start 时返回 RangeError
98    pub fn set_end(&self, end: u64) -> Result<(), RangeError> {
99        let mut old_state = self.state.load(Ordering::Acquire);
100        loop {
101            let mut range = Self::unpack(old_state);
102            range.end = end;
103            if range.start > range.end {
104                break Err(RangeError);
105            }
106            let new_state = Self::pack(range);
107            if old_state == new_state {
108                return Ok(());
109            }
110            match self.state.compare_exchange_weak(
111                old_state,
112                new_state,
113                Ordering::AcqRel,
114                Ordering::Acquire,
115            ) {
116                Ok(_) => break Ok(()),
117                Err(x) => old_state = x,
118            }
119        }
120    }
121    /// 当 end + value 溢出时返回 RangeError
122    pub fn fetch_add_end(&self, value: u64) -> Result<u64, RangeError> {
123        let mut old_state = self.state.load(Ordering::Acquire);
124        loop {
125            let mut range = Self::unpack(old_state);
126            let old_end = range.end;
127            range.end = range.end.checked_add(value).ok_or(RangeError)?;
128            let new_state = Self::pack(range);
129            match self.state.compare_exchange_weak(
130                old_state,
131                new_state,
132                Ordering::AcqRel,
133                Ordering::Acquire,
134            ) {
135                Ok(_) => break Ok(old_end),
136                Err(x) => old_state = x,
137            }
138        }
139    }
140    pub fn remain(&self) -> u64 {
141        let range = self.get();
142        range.end.saturating_sub(range.start)
143    }
144    /// 1. 当 start > end 时返回 RangeError
145    /// 2. 当 remain < 2 时返回 None 并且不会修改自己
146    pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
147        let mut old_state = self.state.load(Ordering::Acquire);
148        loop {
149            let range = Self::unpack(old_state);
150            if range.start > range.end {
151                return Err(RangeError);
152            }
153            let mid = range.start + (range.end - range.start) / 2;
154            if mid == range.start {
155                return Ok(None);
156            }
157            let new_state = Self::pack(range.start..mid);
158            match self.state.compare_exchange_weak(
159                old_state,
160                new_state,
161                Ordering::AcqRel,
162                Ordering::Acquire,
163            ) {
164                Ok(_) => return Ok(Some(mid..range.end)),
165                Err(x) => old_state = x,
166            }
167        }
168    }
169
170    pub fn take(&self) -> Option<Range<u64>> {
171        let mut old_state = self.state.load(Ordering::Acquire);
172        loop {
173            let range = Self::unpack(old_state);
174            if range.start == range.end {
175                return None;
176            }
177            let new_state = Self::pack(range.start..range.start);
178            match self.state.compare_exchange_weak(
179                old_state,
180                new_state,
181                Ordering::AcqRel,
182                Ordering::Acquire,
183            ) {
184                Ok(_) => return Some(range),
185                Err(x) => old_state = x,
186            }
187        }
188    }
189}
190impl From<&Range<u64>> for Task {
191    fn from(value: &Range<u64>) -> Self {
192        Self::new(value.clone())
193    }
194}
195
196impl PartialEq for Task {
197    fn eq(&self, other: &Self) -> bool {
198        Arc::ptr_eq(&self.state, &other.state)
199    }
200}
201impl Eq for Task {}
202impl Ord for Task {
203    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
204        self.remain().cmp(&other.remain())
205    }
206}
207impl PartialOrd for Task {
208    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
209        Some(self.cmp(other))
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `(start, end)` through the public accessors, start first.
    fn bounds(task: &Task) -> (u64, u64) {
        let start = task.start();
        let end = task.end();
        (start, end)
    }

    /// A fresh task reports the bounds it was built from and their distance.
    #[test]
    fn test_new_task() {
        let fresh = Task::new(10..20);
        assert_eq!(bounds(&fresh), (10, 20));
        assert_eq!(fresh.remain(), 10);
    }

    /// Growing `end` first makes room for the later `start` update.
    #[test]
    fn test_setters() {
        let task = Task::new(0..0);
        task.set_end(14).unwrap();
        task.set_start(7).unwrap();
        assert_eq!(bounds(&task), (7, 14));
    }

    /// `remain` tracks changes made to the start bound.
    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
        task.set_start(20).unwrap();
        assert_eq!(task.remain(), 5);
    }

    /// Five items split into a kept lower part and a returned upper part.
    #[test]
    fn test_split_two() {
        let owner = Task::new(1..6); // 1, 2, 3, 4, 5
        let stolen = owner.split_two().unwrap().unwrap();
        assert_eq!(bounds(&owner), (1, 3));
        assert_eq!((stolen.start, stolen.end), (3, 6));
    }

    /// An empty task cannot be split and is left untouched.
    #[test]
    fn test_split_empty() {
        let owner = Task::new(1..1);
        let stolen = owner.split_two().unwrap();
        assert_eq!(bounds(&owner), (1, 1));
        assert_eq!(stolen, None);
    }

    /// A single-item task cannot be split and is left untouched.
    #[test]
    fn test_split_one() {
        let owner = Task::new(1..2);
        let stolen = owner.split_two().unwrap();
        assert_eq!(bounds(&owner), (1, 2));
        assert_eq!(stolen, None);
    }
}