Skip to main content

fast_steal/
task.rs

1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
6#[derive(Debug, Clone)]
7pub struct Task {
8    pub state: Arc<AtomicU128>,
9}
10#[derive(Debug, Clone)]
11pub struct WeakTask {
12    pub state: Weak<AtomicU128>,
13}
14
15impl WeakTask {
16    pub fn upgrade(&self) -> Option<Task> {
17        self.state.upgrade().map(|state| Task { state })
18    }
19    pub fn strong_count(&self) -> usize {
20        self.state.strong_count()
21    }
22    pub fn weak_count(&self) -> usize {
23        self.state.weak_count()
24    }
25}
26
/// Error returned by [`Task`] range operations when the requested update is
/// rejected (for example an overflowing bound or a start that cannot advance).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
35
36impl Task {
37    #[inline(always)]
38    fn pack(range: Range<u64>) -> u128 {
39        ((range.start as u128) << 64) | (range.end as u128)
40    }
41    #[inline(always)]
42    fn unpack(state: u128) -> Range<u64> {
43        (state >> 64) as u64..state as u64
44    }
45
46    pub fn new(range: Range<u64>) -> Self {
47        assert!(range.start <= range.end);
48        Self {
49            state: Arc::new(AtomicU128::new(Self::pack(range))),
50        }
51    }
52    pub fn get(&self) -> Range<u64> {
53        let state = self.state.load(Ordering::Acquire);
54        Self::unpack(state)
55    }
56    pub fn set(&self, range: Range<u64>) {
57        assert!(range.start <= range.end);
58        self.state.store(Self::pack(range), Ordering::Release);
59    }
60    pub fn start(&self) -> u64 {
61        (self.state.load(Ordering::Acquire) >> 64) as u64
62    }
63    /// 当 start + bias <= old_start 时返回 RangeError
64    /// 否则返回 old_start..new_start.min(end)
65    pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
66        let new_start = start.checked_add(bias).ok_or(RangeError)?;
67        let mut old_state = self.state.load(Ordering::Acquire);
68        loop {
69            let mut range = Self::unpack(old_state);
70            let new_start = new_start.min(range.end);
71            if new_start <= range.start {
72                break Err(RangeError);
73            }
74            let span = range.start..new_start;
75            range.start = new_start;
76            let new_state = Self::pack(range);
77            match self.state.compare_exchange_weak(
78                old_state,
79                new_state,
80                Ordering::AcqRel,
81                Ordering::Acquire,
82            ) {
83                Ok(_) => break Ok(span),
84                Err(x) => old_state = x,
85            }
86        }
87    }
88
89    pub fn end(&self) -> u64 {
90        self.state.load(Ordering::Acquire) as u64
91    }
92    pub fn remain(&self) -> u64 {
93        let range = self.get();
94        range.end.saturating_sub(range.start)
95    }
96    /// 1. 当 start > end 时返回 RangeError
97    /// 2. 当 remain < 2 时返回 None 并且不会修改自己
98    pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
99        let mut old_state = self.state.load(Ordering::Acquire);
100        loop {
101            let range = Self::unpack(old_state);
102            if range.start > range.end {
103                return Err(RangeError);
104            }
105            let mid = range.start + (range.end - range.start) / 2;
106            if mid == range.start {
107                return Ok(None);
108            }
109            let new_state = Self::pack(range.start..mid);
110            match self.state.compare_exchange_weak(
111                old_state,
112                new_state,
113                Ordering::AcqRel,
114                Ordering::Acquire,
115            ) {
116                Ok(_) => return Ok(Some(mid..range.end)),
117                Err(x) => old_state = x,
118            }
119        }
120    }
121    pub fn take(&self) -> Option<Range<u64>> {
122        let mut old_state = self.state.load(Ordering::Acquire);
123        loop {
124            let range = Self::unpack(old_state);
125            if range.start == range.end {
126                return None;
127            }
128            let new_state = Self::pack(range.start..range.start);
129            match self.state.compare_exchange_weak(
130                old_state,
131                new_state,
132                Ordering::AcqRel,
133                Ordering::Acquire,
134            ) {
135                Ok(_) => return Some(range),
136                Err(x) => old_state = x,
137            }
138        }
139    }
140    pub fn downgrade(&self) -> WeakTask {
141        WeakTask {
142            state: Arc::downgrade(&self.state),
143        }
144    }
145    pub fn strong_count(&self) -> usize {
146        Arc::strong_count(&self.state)
147    }
148    pub fn weak_count(&self) -> usize {
149        Arc::weak_count(&self.state)
150    }
151}
152impl From<&Range<u64>> for Task {
153    fn from(value: &Range<u64>) -> Self {
154        Self::new(value.clone())
155    }
156}
157
158impl PartialEq for Task {
159    fn eq(&self, other: &Self) -> bool {
160        Arc::ptr_eq(&self.state, &other.state)
161    }
162}
163impl Eq for Task {}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6); // 1, 2, 3, 4, 5
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_repeatedly() {
        let task = Task::new(1..6);
        assert_eq!(task.split_two(), Ok(Some(3..6)));
        assert_eq!(task.split_two(), Ok(Some(2..3)));
        assert_eq!(task.split_two(), Ok(None));
        assert_eq!(task.get(), 1..2);
    }

    #[test]
    fn test_safe_add_start() {
        let task = Task::new(10..20);
        // Advance by 3 from the current start.
        assert_eq!(task.safe_add_start(10, 3), Ok(10..13));
        assert_eq!(task.get(), 13..20);
        // Target not beyond the stored start: rejected, state unchanged.
        assert_eq!(task.safe_add_start(10, 3), Err(RangeError));
        assert_eq!(task.get(), 13..20);
        // Target beyond the end is clamped to the end.
        assert_eq!(task.safe_add_start(13, 100), Ok(13..20));
        assert_eq!(task.remain(), 0);
        // Exhausted task cannot advance.
        assert_eq!(task.safe_add_start(20, 1), Err(RangeError));
    }

    #[test]
    fn test_safe_add_start_overflow() {
        let task = Task::new(0..10);
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
        assert_eq!(task.get(), 0..10);
    }

    #[test]
    fn test_take() {
        let task = Task::new(5..9);
        assert_eq!(task.take(), Some(5..9));
        assert_eq!(task.get(), 5..5);
        assert_eq!(task.take(), None);
        assert_eq!(task.split_two(), Ok(None));
    }

    #[test]
    fn test_set() {
        let task = Task::new(0..1);
        task.set(4..9);
        assert_eq!(task.get(), 4..9);
        assert_eq!(task.remain(), 5);
    }

    #[test]
    #[should_panic]
    fn test_set_reversed_panics() {
        let (start, end) = (9u64, 4u64);
        Task::new(0..1).set(start..end);
    }

    #[test]
    fn test_identity_equality() {
        let a = Task::new(0..5);
        let b = a.clone();
        let c = Task::new(0..5);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_downgrade_upgrade() {
        let task = Task::new(0..4);
        let weak = task.downgrade();
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(task.weak_count(), 1);
        assert_eq!(weak.upgrade(), Some(task.clone()));
        drop(task);
        assert!(weak.upgrade().is_none());
    }
}