1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
/// A shareable, atomically-updatable work range `[start, end)`.
///
/// The range is packed into a single `AtomicU128` — `start` in the high
/// 64 bits, `end` in the low 64 bits — so reads and updates are lock-free.
/// Cloning a `Task` shares the same underlying state (see `PartialEq`,
/// which compares by allocation identity).
#[derive(Debug, Clone)]
pub struct Task {
    // Packed range: high 64 bits = start, low 64 bits = end.
    pub state: Arc<AtomicU128>,
}
/// Non-owning handle to a [`Task`]'s shared state.
///
/// Does not keep the state alive; call [`WeakTask::upgrade`] to recover a
/// usable `Task` while at least one strong handle still exists.
#[derive(Debug, Clone)]
pub struct WeakTask {
    // Weak counterpart of `Task::state`.
    pub state: Weak<AtomicU128>,
}
14
15impl WeakTask {
16 #[must_use]
17 pub fn upgrade(&self) -> Option<Task> {
18 self.state.upgrade().map(|state| Task { state })
19 }
20 #[must_use]
21 pub fn strong_count(&self) -> usize {
22 self.state.strong_count()
23 }
24 #[must_use]
25 pub fn weak_count(&self) -> usize {
26 self.state.weak_count()
27 }
28}
29
/// Error returned when a range operation would violate the
/// `start <= end` invariant or overflow `u64` arithmetic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;
32
33impl fmt::Display for RangeError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 write!(f, "Range invariant violated: start > end or overflow")
36 }
37}
38
39impl Task {
40 #[allow(clippy::inline_always)]
41 #[inline(always)]
42 const fn pack(range: Range<u64>) -> u128 {
43 ((range.start as u128) << 64) | (range.end as u128)
44 }
45 #[allow(clippy::inline_always)]
46 #[inline(always)]
47 const fn unpack(state: u128) -> Range<u64> {
48 #[allow(clippy::cast_possible_truncation)]
49 let end = state as u64;
50 (state >> 64) as u64..end
51 }
52
53 #[must_use]
56 pub fn new(range: Range<u64>) -> Self {
57 assert!(range.start <= range.end);
58 Self {
59 state: Arc::new(AtomicU128::new(Self::pack(range))),
60 }
61 }
62 #[must_use]
63 pub fn get(&self) -> Range<u64> {
64 let state = self.state.load(Ordering::Acquire);
65 Self::unpack(state)
66 }
67 pub fn set(&self, range: Range<u64>) {
70 assert!(range.start <= range.end);
71 self.state.store(Self::pack(range), Ordering::Release);
72 }
73 #[must_use]
74 pub fn start(&self) -> u64 {
75 (self.state.load(Ordering::Acquire) >> 64) as u64
76 }
77 pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
81 let new_start = start.checked_add(bias).ok_or(RangeError)?;
82 let mut old_state = self.state.load(Ordering::Acquire);
83 loop {
84 let mut range = Self::unpack(old_state);
85 let new_start = new_start.min(range.end);
86 if new_start <= range.start {
87 break Err(RangeError);
88 }
89 let span = range.start..new_start;
90 range.start = new_start;
91 let new_state = Self::pack(range);
92 match self.state.compare_exchange_weak(
93 old_state,
94 new_state,
95 Ordering::AcqRel,
96 Ordering::Acquire,
97 ) {
98 Ok(_) => break Ok(span),
99 Err(x) => old_state = x,
100 }
101 }
102 }
103 #[must_use]
104 pub fn end(&self) -> u64 {
105 let state = self.state.load(Ordering::Acquire);
106 #[allow(clippy::cast_possible_truncation)]
107 let end = state as u64;
108 end
109 }
110 #[must_use]
111 pub fn remain(&self) -> u64 {
112 let range = self.get();
113 range.end.saturating_sub(range.start)
114 }
115 pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
119 let mut old_state = self.state.load(Ordering::Acquire);
120 loop {
121 let range = Self::unpack(old_state);
122 if range.start > range.end {
123 return Err(RangeError);
124 }
125 let mid = range.start + (range.end - range.start) / 2;
126 if mid == range.start {
127 return Ok(None);
128 }
129 let new_state = Self::pack(range.start..mid);
130 match self.state.compare_exchange_weak(
131 old_state,
132 new_state,
133 Ordering::AcqRel,
134 Ordering::Acquire,
135 ) {
136 Ok(_) => return Ok(Some(mid..range.end)),
137 Err(x) => old_state = x,
138 }
139 }
140 }
141 #[must_use]
142 pub fn take(&self) -> Option<Range<u64>> {
143 let mut old_state = self.state.load(Ordering::Acquire);
144 loop {
145 let range = Self::unpack(old_state);
146 if range.start == range.end {
147 return None;
148 }
149 let new_state = Self::pack(range.start..range.start);
150 match self.state.compare_exchange_weak(
151 old_state,
152 new_state,
153 Ordering::AcqRel,
154 Ordering::Acquire,
155 ) {
156 Ok(_) => return Some(range),
157 Err(x) => old_state = x,
158 }
159 }
160 }
161 #[must_use]
162 pub fn downgrade(&self) -> WeakTask {
163 WeakTask {
164 state: Arc::downgrade(&self.state),
165 }
166 }
167 #[must_use]
168 pub fn strong_count(&self) -> usize {
169 Arc::strong_count(&self.state)
170 }
171 #[must_use]
172 pub fn weak_count(&self) -> usize {
173 Arc::weak_count(&self.state)
174 }
175}
176impl From<Range<u64>> for Task {
177 fn from(value: Range<u64>) -> Self {
178 Self::new(value)
179 }
180}
181
// Identity-based equality: two `Task`s are equal only when they share the
// same underlying allocation (i.e. clones of one another), NOT when their
// ranges merely happen to hold the same values.
impl PartialEq for Task {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.state, &other.state)
    }
}
impl Eq for Task {}
188
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    #[test]
    fn test_set_get() {
        let task = Task::new(0..0);
        task.set(2..9);
        assert_eq!(task.get(), 2..9);
        assert_eq!(task.remain(), 7);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6);
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_take() {
        let task = Task::new(3..7);
        assert_eq!(task.take(), Some(3..7));
        assert_eq!(task.remain(), 0);
        assert_eq!(task.take(), None);
    }

    #[test]
    fn test_safe_add_start() {
        let task = Task::new(0..10);
        assert_eq!(task.safe_add_start(0, 5), Ok(0..5));
        assert_eq!(task.get(), 5..10);
        // Advance is clamped to the end of the range.
        assert_eq!(task.safe_add_start(5, 100), Ok(5..10));
        // No forward progress once the range is empty.
        assert_eq!(task.safe_add_start(0, 1), Err(RangeError));
        // Overflowing start + bias is rejected.
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
    }
}