1extern crate alloc;
2use alloc::sync::Arc;
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
6#[derive(Debug, Clone)]
7pub struct Task {
8 state: Arc<AtomicU128>,
9}
10
/// Error returned when an update would leave a task with `start > end`, or when
/// adding to one of its bounds would overflow `u64`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed message; written directly without going through `format_args!`.
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
19
20impl Task {
21 #[inline(always)]
22 fn pack(range: Range<u64>) -> u128 {
23 ((range.start as u128) << 64) | (range.end as u128)
24 }
25 #[inline(always)]
26 fn unpack(state: u128) -> Range<u64> {
27 (state >> 64) as u64..state as u64
28 }
29
30 pub fn new(range: Range<u64>) -> Self {
31 assert!(range.start <= range.end);
32 Self {
33 state: Arc::new(AtomicU128::new(Self::pack(range))),
34 }
35 }
36 pub fn get(&self) -> Range<u64> {
37 let state = self.state.load(Ordering::Acquire);
38 Self::unpack(state)
39 }
40 pub fn set(&self, range: Range<u64>) {
41 assert!(range.start <= range.end);
42 self.state.store(Self::pack(range), Ordering::Release);
43 }
44 pub fn start(&self) -> u64 {
45 (self.state.load(Ordering::Acquire) >> 64) as u64
46 }
47 pub fn set_start(&self, start: u64) -> Result<(), RangeError> {
49 let mut old_state = self.state.load(Ordering::Acquire);
50 loop {
51 let mut range = Self::unpack(old_state);
52 range.start = start;
53 if range.start > range.end {
54 break Err(RangeError);
55 }
56 let new_state = Self::pack(range);
57 if old_state == new_state {
58 return Ok(());
59 }
60 match self.state.compare_exchange_weak(
61 old_state,
62 new_state,
63 Ordering::AcqRel,
64 Ordering::Acquire,
65 ) {
66 Ok(_) => break Ok(()),
67 Err(x) => old_state = x,
68 }
69 }
70 }
71 pub fn fetch_add_start(&self, value: u64) -> Result<u64, RangeError> {
74 let mut old_state = self.state.load(Ordering::Acquire);
75 loop {
76 let mut range = Self::unpack(old_state);
77 let old_start = range.start;
78 range.start = range.start.checked_add(value).ok_or(RangeError)?;
79 if range.start > range.end {
80 break Err(RangeError);
81 }
82 let new_state = Self::pack(range);
83 match self.state.compare_exchange_weak(
84 old_state,
85 new_state,
86 Ordering::AcqRel,
87 Ordering::Acquire,
88 ) {
89 Ok(_) => break Ok(old_start),
90 Err(x) => old_state = x,
91 }
92 }
93 }
94 pub fn end(&self) -> u64 {
95 self.state.load(Ordering::Acquire) as u64
96 }
97 pub fn set_end(&self, end: u64) -> Result<(), RangeError> {
99 let mut old_state = self.state.load(Ordering::Acquire);
100 loop {
101 let mut range = Self::unpack(old_state);
102 range.end = end;
103 if range.start > range.end {
104 break Err(RangeError);
105 }
106 let new_state = Self::pack(range);
107 if old_state == new_state {
108 return Ok(());
109 }
110 match self.state.compare_exchange_weak(
111 old_state,
112 new_state,
113 Ordering::AcqRel,
114 Ordering::Acquire,
115 ) {
116 Ok(_) => break Ok(()),
117 Err(x) => old_state = x,
118 }
119 }
120 }
121 pub fn fetch_add_end(&self, value: u64) -> Result<u64, RangeError> {
123 let mut old_state = self.state.load(Ordering::Acquire);
124 loop {
125 let mut range = Self::unpack(old_state);
126 let old_end = range.end;
127 range.end = range.end.checked_add(value).ok_or(RangeError)?;
128 let new_state = Self::pack(range);
129 match self.state.compare_exchange_weak(
130 old_state,
131 new_state,
132 Ordering::AcqRel,
133 Ordering::Acquire,
134 ) {
135 Ok(_) => break Ok(old_end),
136 Err(x) => old_state = x,
137 }
138 }
139 }
140 pub fn remain(&self) -> u64 {
141 let range = self.get();
142 range.end.saturating_sub(range.start)
143 }
144 pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
147 let mut old_state = self.state.load(Ordering::Acquire);
148 loop {
149 let range = Self::unpack(old_state);
150 if range.start > range.end {
151 return Err(RangeError);
152 }
153 let mid = range.start + (range.end - range.start) / 2;
154 if mid == range.start {
155 return Ok(None);
156 }
157 let new_state = Self::pack(range.start..mid);
158 match self.state.compare_exchange_weak(
159 old_state,
160 new_state,
161 Ordering::AcqRel,
162 Ordering::Acquire,
163 ) {
164 Ok(_) => return Ok(Some(mid..range.end)),
165 Err(x) => old_state = x,
166 }
167 }
168 }
169
170 pub fn take(&self) -> Option<Range<u64>> {
171 let mut old_state = self.state.load(Ordering::Acquire);
172 loop {
173 let range = Self::unpack(old_state);
174 if range.start == range.end {
175 return None;
176 }
177 let new_state = Self::pack(range.start..range.start);
178 match self.state.compare_exchange_weak(
179 old_state,
180 new_state,
181 Ordering::AcqRel,
182 Ordering::Acquire,
183 ) {
184 Ok(_) => return Some(range),
185 Err(x) => old_state = x,
186 }
187 }
188 }
189}
190impl From<&Range<u64>> for Task {
191 fn from(value: &Range<u64>) -> Self {
192 Self::new(value.clone())
193 }
194}
195
196impl PartialEq for Task {
197 fn eq(&self, other: &Self) -> bool {
198 Arc::ptr_eq(&self.state, &other.state)
199 }
200}
201impl Eq for Task {}
202impl Ord for Task {
203 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
204 self.remain().cmp(&other.remain())
205 }
206}
207impl PartialOrd for Task {
208 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
209 Some(self.cmp(other))
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    #[should_panic]
    fn test_new_rejects_inverted_range() {
        // Bound through variables so the inverted literal range is not linted.
        let (start, end) = (5, 4);
        let _ = Task::new(start..end);
    }

    #[test]
    fn test_set() {
        let task = Task::new(0..1);
        task.set(4..9);
        assert_eq!(task.get(), 4..9);
    }

    #[test]
    fn test_setters() {
        let task = Task::new(0..0);
        task.set_end(14).unwrap();
        task.set_start(7).unwrap();
        assert_eq!(task.start(), 7);
        assert_eq!(task.end(), 14);
    }

    #[test]
    fn test_setters_reject_inverted_range() {
        let task = Task::new(5..10);
        assert_eq!(task.set_start(11), Err(RangeError));
        assert_eq!(task.set_end(4), Err(RangeError));
        assert_eq!(task.get(), 5..10);
        // Writing the value that is already stored succeeds as a no-op.
        assert_eq!(task.set_start(5), Ok(()));
        assert_eq!(task.set_end(10), Ok(()));
        assert_eq!(task.get(), 5..10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
        task.set_start(20).unwrap();
        assert_eq!(task.remain(), 5);
    }

    #[test]
    fn test_fetch_add_start() {
        let task = Task::new(0..10);
        assert_eq!(task.fetch_add_start(3), Ok(0));
        assert_eq!(task.get(), 3..10);
        assert_eq!(task.fetch_add_start(7), Ok(3));
        assert_eq!(task.get(), 10..10);
        assert_eq!(task.fetch_add_start(1), Err(RangeError));
        assert_eq!(task.get(), 10..10);
    }

    #[test]
    fn test_fetch_add_start_overflow() {
        let task = Task::new(u64::MAX..u64::MAX);
        assert_eq!(task.fetch_add_start(1), Err(RangeError));
        assert_eq!(task.get(), u64::MAX..u64::MAX);
    }

    #[test]
    fn test_fetch_add_end() {
        let task = Task::new(5..5);
        assert_eq!(task.fetch_add_end(4), Ok(5));
        assert_eq!(task.get(), 5..9);

        let full = Task::new(0..u64::MAX);
        assert_eq!(full.fetch_add_end(1), Err(RangeError));
        assert_eq!(full.end(), u64::MAX);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6);
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_take() {
        let task = Task::new(3..8);
        assert_eq!(task.take(), Some(3..8));
        assert_eq!(task.get(), 3..3);
        assert_eq!(task.remain(), 0);
        assert_eq!(task.take(), None);
    }

    #[test]
    fn test_clone_shares_state() {
        let task = Task::new(0..10);
        let other = task.clone();
        other.set_start(4).unwrap();
        assert_eq!(task.start(), 4);
        assert_eq!(task, other);
        assert_ne!(task, Task::new(4..10));
    }

    #[test]
    fn test_from_range_ref() {
        let range = 2u64..7;
        let task = Task::from(&range);
        assert_eq!(task.get(), range);
    }

    #[test]
    fn test_ordering_by_remaining() {
        let big = Task::new(0..10);
        let small = Task::new(0..3);
        assert!(big > small);
        assert!(small < big);
        assert_eq!(big.cmp(&big.clone()), core::cmp::Ordering::Equal);
    }
}