1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
/// A shareable, lock-free work range.
///
/// The remaining `start..end` range is packed into a single `AtomicU128`
/// (start in the high 64 bits, end in the low 64 bits) so it can be read
/// and updated atomically without locks. Cloning a `Task` is cheap and
/// shares the same underlying state.
#[derive(Debug, Clone)]
pub struct Task {
    // Shared packed range; see `Task::pack` / `Task::unpack` for the layout.
    pub state: Arc<AtomicU128>,
}
/// A non-owning handle to a [`Task`]'s shared state.
///
/// Does not keep the state alive; use [`WeakTask::upgrade`] to recover a
/// [`Task`] while at least one strong handle still exists.
#[derive(Debug, Clone)]
pub struct WeakTask {
    // Weak counterpart of `Task::state`.
    pub state: Weak<AtomicU128>,
}
14
15impl WeakTask {
16 #[must_use]
17 pub fn upgrade(&self) -> Option<Task> {
18 self.state.upgrade().map(|state| Task { state })
19 }
20 #[must_use]
21 pub fn strong_count(&self) -> usize {
22 self.state.strong_count()
23 }
24 #[must_use]
25 pub fn weak_count(&self) -> usize {
26 self.state.weak_count()
27 }
28}
29
/// Error returned when a range operation would violate the `start <= end`
/// invariant or overflow `u64` arithmetic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;
32
33impl fmt::Display for RangeError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 write!(f, "Range invariant violated: start > end or overflow")
36 }
37}
38
39impl Task {
40 #[allow(clippy::inline_always)]
41 #[inline(always)]
42 const fn pack(range: Range<u64>) -> u128 {
43 ((range.start as u128) << 64) | (range.end as u128)
44 }
45 #[allow(clippy::inline_always)]
46 #[inline(always)]
47 const fn unpack(state: u128) -> Range<u64> {
48 #[allow(clippy::cast_possible_truncation)]
49 let end = state as u64;
50 (state >> 64) as u64..end
51 }
52
53 #[must_use]
56 pub fn new(range: Range<u64>) -> Self {
57 assert!(range.start <= range.end);
58 Self {
59 state: Arc::new(AtomicU128::new(Self::pack(range))),
60 }
61 }
62 #[must_use]
63 pub fn get(&self) -> Range<u64> {
64 let state = self.state.load(Ordering::Acquire);
65 Self::unpack(state)
66 }
67 #[must_use]
68 pub fn start(&self) -> u64 {
69 (self.state.load(Ordering::Acquire) >> 64) as u64
70 }
71 pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
75 let new_start = start.checked_add(bias).ok_or(RangeError)?;
76 let mut old_state = self.state.load(Ordering::Acquire);
77 loop {
78 let mut range = Self::unpack(old_state);
79 let new_start = new_start.min(range.end);
80 if new_start <= range.start {
81 break Err(RangeError);
82 }
83 let span = range.start..new_start;
84 range.start = new_start;
85 let new_state = Self::pack(range);
86 match self.state.compare_exchange_weak(
87 old_state,
88 new_state,
89 Ordering::AcqRel,
90 Ordering::Acquire,
91 ) {
92 Ok(_) => break Ok(span),
93 Err(x) => old_state = x,
94 }
95 }
96 }
97 #[must_use]
98 pub fn end(&self) -> u64 {
99 let state = self.state.load(Ordering::Acquire);
100 #[allow(clippy::cast_possible_truncation)]
101 let end = state as u64;
102 end
103 }
104 #[must_use]
105 pub fn remain(&self) -> u64 {
106 let range = self.get();
107 range.end.saturating_sub(range.start)
108 }
109 pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
113 let mut old_state = self.state.load(Ordering::Acquire);
114 loop {
115 let range = Self::unpack(old_state);
116 if range.start > range.end {
117 return Err(RangeError);
118 }
119 let mid = range.start + (range.end - range.start) / 2;
120 if mid == range.start {
121 return Ok(None);
122 }
123 let new_state = Self::pack(range.start..mid);
124 match self.state.compare_exchange_weak(
125 old_state,
126 new_state,
127 Ordering::AcqRel,
128 Ordering::Acquire,
129 ) {
130 Ok(_) => return Ok(Some(mid..range.end)),
131 Err(x) => old_state = x,
132 }
133 }
134 }
135 #[must_use]
136 pub fn take(&self) -> Option<Range<u64>> {
137 let mut old_state = self.state.load(Ordering::Acquire);
138 loop {
139 let range = Self::unpack(old_state);
140 if range.start == range.end {
141 return None;
142 }
143 let new_state = Self::pack(range.start..range.start);
144 match self.state.compare_exchange_weak(
145 old_state,
146 new_state,
147 Ordering::AcqRel,
148 Ordering::Acquire,
149 ) {
150 Ok(_) => return Some(range),
151 Err(x) => old_state = x,
152 }
153 }
154 }
155 #[must_use]
156 pub fn downgrade(&self) -> WeakTask {
157 WeakTask {
158 state: Arc::downgrade(&self.state),
159 }
160 }
161 #[must_use]
162 pub fn strong_count(&self) -> usize {
163 Arc::strong_count(&self.state)
164 }
165 #[must_use]
166 pub fn weak_count(&self) -> usize {
167 Arc::weak_count(&self.state)
168 }
169}
170impl From<Range<u64>> for Task {
171 fn from(value: Range<u64>) -> Self {
172 Self::new(value)
173 }
174}
175
176impl PartialEq for Task {
177 fn eq(&self, other: &Self) -> bool {
178 Arc::ptr_eq(&self.state, &other.state)
179 }
180}
181impl Eq for Task {}
182
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6);
        let range = task.split_two().unwrap().unwrap();
        // The task keeps the first half; the second half is handed out.
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_take() {
        let task = Task::new(3..9);
        assert_eq!(task.take(), Some(3..9));
        // After take, the range is empty and further takes yield None.
        assert_eq!(task.remain(), 0);
        assert_eq!(task.take(), None);
    }

    #[test]
    fn test_safe_add_start() {
        let task = Task::new(10..20);
        // Consumes [10, 15).
        assert_eq!(task.safe_add_start(10, 5), Ok(10..15));
        // A start that does not advance past the current start is an error.
        assert_eq!(task.safe_add_start(10, 5), Err(RangeError));
        // The new start is clamped to the end of the range.
        assert_eq!(task.safe_add_start(15, 100), Ok(15..20));
        assert_eq!(task.remain(), 0);
        // u64 overflow is reported, not wrapped.
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
    }

    #[test]
    fn test_eq_is_identity() {
        let a = Task::new(0..10);
        let b = a.clone();
        let c = Task::new(0..10);
        // Clones share state and compare equal; equal ranges do not.
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_weak_upgrade() {
        let task = Task::new(0..4);
        let weak = task.downgrade();
        assert_eq!(weak.strong_count(), 1);
        let upgraded = weak.upgrade().unwrap();
        assert_eq!(upgraded, task);
        drop(task);
        drop(upgraded);
        // All strong handles gone: upgrade must fail.
        assert!(weak.upgrade().is_none());
    }
}
229}