Skip to main content

fast_steal/
task.rs

1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
/// Shared handle to a half-open `u64` range stored in a single atomic word.
///
/// The `start..end` pair is packed into one `u128`: `start` occupies the
/// upper 64 bits and `end` the lower 64 bits. Cloning a `Task` clones the
/// inner [`Arc`], so every clone observes and updates the same range.
#[derive(Debug, Clone)]
pub struct Task {
    /// Reference-counted atomic cell holding the packed range.
    pub state: Arc<AtomicU128>,
}
/// Non-owning counterpart of [`Task`]: holds a [`Weak`] reference to the same
/// atomic state, so it does not keep that state alive on its own.
#[derive(Debug, Clone)]
pub struct WeakTask {
    /// Weak reference to the packed range cell.
    pub state: Weak<AtomicU128>,
}
14
15impl WeakTask {
16    pub fn upgrade(&self) -> Option<Task> {
17        self.state.upgrade().map(|state| Task { state })
18    }
19    pub fn strong_count(&self) -> usize {
20        self.state.strong_count()
21    }
22    pub fn weak_count(&self) -> usize {
23        self.state.weak_count()
24    }
25}
26
/// Error reported when an operation would violate the `start <= end`
/// invariant of a range, or when computing a bound overflows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
35
36impl Task {
37    #[inline(always)]
38    fn pack(range: Range<u64>) -> u128 {
39        ((range.start as u128) << 64) | (range.end as u128)
40    }
41    #[inline(always)]
42    fn unpack(state: u128) -> Range<u64> {
43        (state >> 64) as u64..state as u64
44    }
45
46    pub fn new(range: Range<u64>) -> Self {
47        assert!(range.start <= range.end);
48        Self {
49            state: Arc::new(AtomicU128::new(Self::pack(range))),
50        }
51    }
52    pub fn get(&self) -> Range<u64> {
53        let state = self.state.load(Ordering::Acquire);
54        Self::unpack(state)
55    }
56    pub fn set(&self, range: Range<u64>) {
57        assert!(range.start <= range.end);
58        self.state.store(Self::pack(range), Ordering::Release);
59    }
60    pub fn start(&self) -> u64 {
61        (self.state.load(Ordering::Acquire) >> 64) as u64
62    }
63    /// 当 start + bias <= old_start 时返回 RangeError
64    /// 否则返回 old_start..new_start.min(end)
65    pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
66        let new_start = start.checked_add(bias).ok_or(RangeError)?;
67        let mut old_state = self.state.load(Ordering::Acquire);
68        loop {
69            let mut range = Self::unpack(old_state);
70            let new_start = new_start.min(range.end);
71            if new_start <= range.start {
72                break Err(RangeError);
73            }
74            let span = range.start..new_start;
75            range.start = new_start;
76            let new_state = Self::pack(range);
77            match self.state.compare_exchange_weak(
78                old_state,
79                new_state,
80                Ordering::AcqRel,
81                Ordering::Acquire,
82            ) {
83                Ok(_) => break Ok(span),
84                Err(x) => old_state = x,
85            }
86        }
87    }
88    // /// 1. 当 start + value 溢出时返回 RangeError
89    //
90    // pub fn fetch_add_start(&self, value: u64) -> Result<u64, RangeError> {
91    //     let mut old_state = self.state.load(Ordering::Acquire);
92    //     loop {
93    //         let mut range = Self::unpack(old_state);
94    //         let old_start = range.start;
95    //         range.start = range.start.checked_add(value).ok_or(RangeError)?;
96    //         if range.start > range.end {
97    //             break Err(RangeError);
98    //         }
99    //         let new_state = Self::pack(range);
100    //         match self.state.compare_exchange_weak(
101    //             old_state,
102    //             new_state,
103    //             Ordering::AcqRel,
104    //             Ordering::Acquire,
105    //         ) {
106    //             Ok(_) => break Ok(old_start),
107    //             Err(x) => old_state = x,
108    //         }
109    //     }
110    // }
111    pub fn end(&self) -> u64 {
112        self.state.load(Ordering::Acquire) as u64
113    }
114    pub fn remain(&self) -> u64 {
115        let range = self.get();
116        range.end.saturating_sub(range.start)
117    }
118    /// 1. 当 start > end 时返回 RangeError
119    /// 2. 当 remain < 2 时返回 None 并且不会修改自己
120    pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
121        let mut old_state = self.state.load(Ordering::Acquire);
122        loop {
123            let range = Self::unpack(old_state);
124            if range.start > range.end {
125                return Err(RangeError);
126            }
127            let mid = range.start + (range.end - range.start) / 2;
128            if mid == range.start {
129                return Ok(None);
130            }
131            let new_state = Self::pack(range.start..mid);
132            match self.state.compare_exchange_weak(
133                old_state,
134                new_state,
135                Ordering::AcqRel,
136                Ordering::Acquire,
137            ) {
138                Ok(_) => return Ok(Some(mid..range.end)),
139                Err(x) => old_state = x,
140            }
141        }
142    }
143    pub fn take(&self) -> Option<Range<u64>> {
144        let mut old_state = self.state.load(Ordering::Acquire);
145        loop {
146            let range = Self::unpack(old_state);
147            if range.start == range.end {
148                return None;
149            }
150            let new_state = Self::pack(range.start..range.start);
151            match self.state.compare_exchange_weak(
152                old_state,
153                new_state,
154                Ordering::AcqRel,
155                Ordering::Acquire,
156            ) {
157                Ok(_) => return Some(range),
158                Err(x) => old_state = x,
159            }
160        }
161    }
162    pub fn downgrade(&self) -> WeakTask {
163        WeakTask {
164            state: Arc::downgrade(&self.state),
165        }
166    }
167}
168impl From<&Range<u64>> for Task {
169    fn from(value: &Range<u64>) -> Self {
170        Self::new(value.clone())
171    }
172}
173
174impl PartialEq for Task {
175    fn eq(&self, other: &Self) -> bool {
176        Arc::ptr_eq(&self.state, &other.state)
177    }
178}
179impl Eq for Task {}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh task reports the bounds it was created with.
    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    /// `remain` is the distance between the two bounds.
    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    /// Splitting keeps the lower half and hands out the upper half.
    #[test]
    fn test_split_two() {
        let task = Task::new(1..6); // 1, 2, 3, 4, 5
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    /// An empty task cannot be split and is left unchanged.
    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    /// A single remaining item cannot be split and is left unchanged.
    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    /// `set` replaces both bounds and `get` reads them back.
    #[test]
    fn test_get_and_set() {
        let task = Task::new(0..4);
        assert_eq!(task.get(), 0..4);
        task.set(7..9);
        assert_eq!(task.get(), 7..9);
        assert_eq!(task.remain(), 2);
    }

    /// Constructing a task from an inverted range trips the assertion in `new`.
    #[test]
    #[should_panic]
    fn test_new_rejects_inverted_range() {
        let (start, end) = (5u64, 4u64);
        let _ = Task::new(start..end);
    }

    /// `safe_add_start` advances, clamps to the end, and rejects no-op or
    /// overflowing requests without touching the task.
    #[test]
    fn test_safe_add_start() {
        let task = Task::new(10..20);
        assert_eq!(task.safe_add_start(10, 5), Ok(10..15));
        assert_eq!(task.get(), 15..20);

        // Target is not beyond the current start.
        assert_eq!(task.safe_add_start(10, 5), Err(RangeError));
        assert_eq!(task.get(), 15..20);

        // Target past the end is clamped to the end.
        assert_eq!(task.safe_add_start(15, 100), Ok(15..20));
        assert_eq!(task.get(), 20..20);

        // Nothing left to advance over.
        assert_eq!(task.safe_add_start(20, 1), Err(RangeError));

        // `start + bias` overflows.
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
        assert_eq!(task.get(), 20..20);
    }

    /// `take` hands out the whole remainder exactly once.
    #[test]
    fn test_take() {
        let task = Task::new(3..8);
        assert_eq!(task.take(), Some(3..8));
        assert_eq!(task.get(), 3..3);
        assert_eq!(task.remain(), 0);
        assert_eq!(task.take(), None);
    }

    /// Clones share one state; equality is identity, not range value.
    #[test]
    fn test_clone_shares_state() {
        let task = Task::new(0..10);
        let other = task.clone();
        assert_eq!(task, other);
        other.set(2..4);
        assert_eq!(task.get(), 2..4);
        assert_ne!(task, Task::new(2..4));
    }

    /// A weak handle upgrades while a strong handle exists and fails afterwards.
    #[test]
    fn test_weak_task() {
        let task = Task::new(0..10);
        let weak = task.downgrade();
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(weak.upgrade(), Some(task.clone()));
        drop(task);
        assert_eq!(weak.strong_count(), 0);
        assert!(weak.upgrade().is_none());
    }
}