1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
6#[derive(Debug, Clone)]
7pub struct Task {
8 pub state: Arc<AtomicU128>,
9}
10#[derive(Debug, Clone)]
11pub struct WeakTask {
12 pub state: Weak<AtomicU128>,
13}
14
15impl WeakTask {
16 pub fn upgrade(&self) -> Option<Task> {
17 self.state.upgrade().map(|state| Task { state })
18 }
19 pub fn strong_count(&self) -> usize {
20 self.state.strong_count()
21 }
22 pub fn weak_count(&self) -> usize {
23 self.state.weak_count()
24 }
25}
26
/// Error returned when a `Task` range operation cannot be carried out.
///
/// `Task::safe_add_start` returns it when `start + bias` overflows or when no
/// forward progress is possible; `Task::split_two` returns it when the stored
/// range is inverted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeError;

impl fmt::Display for RangeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range invariant violated: start > end or overflow")
    }
}
35
36impl Task {
37 #[inline(always)]
38 fn pack(range: Range<u64>) -> u128 {
39 ((range.start as u128) << 64) | (range.end as u128)
40 }
41 #[inline(always)]
42 fn unpack(state: u128) -> Range<u64> {
43 (state >> 64) as u64..state as u64
44 }
45
46 pub fn new(range: Range<u64>) -> Self {
47 assert!(range.start <= range.end);
48 Self {
49 state: Arc::new(AtomicU128::new(Self::pack(range))),
50 }
51 }
52 pub fn get(&self) -> Range<u64> {
53 let state = self.state.load(Ordering::Acquire);
54 Self::unpack(state)
55 }
56 pub fn set(&self, range: Range<u64>) {
57 assert!(range.start <= range.end);
58 self.state.store(Self::pack(range), Ordering::Release);
59 }
60 pub fn start(&self) -> u64 {
61 (self.state.load(Ordering::Acquire) >> 64) as u64
62 }
63 pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
66 let new_start = start.checked_add(bias).ok_or(RangeError)?;
67 let mut old_state = self.state.load(Ordering::Acquire);
68 loop {
69 let mut range = Self::unpack(old_state);
70 let new_start = new_start.min(range.end);
71 if new_start <= range.start {
72 break Err(RangeError);
73 }
74 let span = range.start..new_start;
75 range.start = new_start;
76 let new_state = Self::pack(range);
77 match self.state.compare_exchange_weak(
78 old_state,
79 new_state,
80 Ordering::AcqRel,
81 Ordering::Acquire,
82 ) {
83 Ok(_) => break Ok(span),
84 Err(x) => old_state = x,
85 }
86 }
87 }
88 pub fn end(&self) -> u64 {
112 self.state.load(Ordering::Acquire) as u64
113 }
114 pub fn remain(&self) -> u64 {
115 let range = self.get();
116 range.end.saturating_sub(range.start)
117 }
118 pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
121 let mut old_state = self.state.load(Ordering::Acquire);
122 loop {
123 let range = Self::unpack(old_state);
124 if range.start > range.end {
125 return Err(RangeError);
126 }
127 let mid = range.start + (range.end - range.start) / 2;
128 if mid == range.start {
129 return Ok(None);
130 }
131 let new_state = Self::pack(range.start..mid);
132 match self.state.compare_exchange_weak(
133 old_state,
134 new_state,
135 Ordering::AcqRel,
136 Ordering::Acquire,
137 ) {
138 Ok(_) => return Ok(Some(mid..range.end)),
139 Err(x) => old_state = x,
140 }
141 }
142 }
143 pub fn take(&self) -> Option<Range<u64>> {
144 let mut old_state = self.state.load(Ordering::Acquire);
145 loop {
146 let range = Self::unpack(old_state);
147 if range.start == range.end {
148 return None;
149 }
150 let new_state = Self::pack(range.start..range.start);
151 match self.state.compare_exchange_weak(
152 old_state,
153 new_state,
154 Ordering::AcqRel,
155 Ordering::Acquire,
156 ) {
157 Ok(_) => return Some(range),
158 Err(x) => old_state = x,
159 }
160 }
161 }
162 pub fn downgrade(&self) -> WeakTask {
163 WeakTask {
164 state: Arc::downgrade(&self.state),
165 }
166 }
167}
168impl From<&Range<u64>> for Task {
169 fn from(value: &Range<u64>) -> Self {
170 Self::new(value.clone())
171 }
172}
173
174impl PartialEq for Task {
175 fn eq(&self, other: &Self) -> bool {
176 Arc::ptr_eq(&self.state, &other.state)
177 }
178}
179impl Eq for Task {}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6);
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_set_and_get() {
        let task = Task::new(0..1);
        task.set(4..9);
        assert_eq!(task.get(), 4..9);
        assert_eq!(task.remain(), 5);
    }

    // Struct-literal form avoids clippy's deny-by-default `reversed_empty_ranges`.
    #[test]
    #[should_panic]
    fn test_new_rejects_inverted_range() {
        let _task = Task::new(core::ops::Range { start: 5, end: 1 });
    }

    #[test]
    #[should_panic]
    fn test_set_rejects_inverted_range() {
        let task = Task::new(0..10);
        task.set(core::ops::Range { start: 9, end: 3 });
    }

    #[test]
    fn test_take() {
        let task = Task::new(3..8);
        assert_eq!(task.take(), Some(3..8));
        // The task collapses to `start..start`.
        assert_eq!(task.get(), 3..3);
        assert_eq!(task.remain(), 0);
        assert_eq!(task.take(), None);
    }

    #[test]
    fn test_safe_add_start() {
        let task = Task::new(0..10);
        assert_eq!(task.safe_add_start(0, 4), Ok(0..4));
        assert_eq!(task.start(), 4);
        // A stale `start` whose target does not pass the stored start makes no progress.
        assert_eq!(task.safe_add_start(0, 2), Err(RangeError));
        assert_eq!(task.get(), 4..10);
        // The target is clamped to `end`.
        assert_eq!(task.safe_add_start(4, 100), Ok(4..10));
        assert_eq!(task.remain(), 0);
        // Nothing left to consume.
        assert_eq!(task.safe_add_start(10, 1), Err(RangeError));
    }

    #[test]
    fn test_safe_add_start_overflow() {
        let task = Task::new(0..10);
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
        assert_eq!(task.get(), 0..10);
    }

    #[test]
    fn test_downgrade_and_upgrade() {
        let task = Task::new(0..5);
        let weak = task.downgrade();
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(weak.weak_count(), 1);
        assert_eq!(weak.upgrade(), Some(task.clone()));
        drop(task);
        assert_eq!(weak.strong_count(), 0);
        assert!(weak.upgrade().is_none());
    }

    #[test]
    fn test_eq_is_identity() {
        let a = Task::new(1..2);
        let b = Task::new(1..2);
        // Same range values, different allocations.
        assert_ne!(a, b);
        assert_eq!(a, a.clone());
    }

    #[test]
    fn test_from_range_ref() {
        let range = 2u64..7;
        let task = Task::from(&range);
        assert_eq!(task.get(), range);
    }
}