1extern crate alloc;
2use alloc::sync::{Arc, Weak};
3use core::{fmt, ops::Range, sync::atomic::Ordering};
4use portable_atomic::AtomicU128;
5
6#[derive(Debug, Clone)]
7pub struct Task {
8 pub state: Arc<AtomicU128>,
9}
10#[derive(Debug, Clone)]
11pub struct WeakTask {
12 pub state: Weak<AtomicU128>,
13}
14
15impl WeakTask {
16 pub fn upgrade(&self) -> Option<Task> {
17 self.state.upgrade().map(|state| Task { state })
18 }
19 pub fn strong_count(&self) -> usize {
20 self.state.strong_count()
21 }
22 pub fn weak_count(&self) -> usize {
23 self.state.weak_count()
24 }
25}
26
/// Error returned when a [`Task`] range operation cannot be carried out.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RangeError;
29
30impl fmt::Display for RangeError {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 write!(f, "Range invariant violated: start > end or overflow")
33 }
34}
35
36impl Task {
37 #[inline(always)]
38 fn pack(range: Range<u64>) -> u128 {
39 ((range.start as u128) << 64) | (range.end as u128)
40 }
41 #[inline(always)]
42 fn unpack(state: u128) -> Range<u64> {
43 (state >> 64) as u64..state as u64
44 }
45
46 pub fn new(range: Range<u64>) -> Self {
47 assert!(range.start <= range.end);
48 Self {
49 state: Arc::new(AtomicU128::new(Self::pack(range))),
50 }
51 }
52 pub fn get(&self) -> Range<u64> {
53 let state = self.state.load(Ordering::Acquire);
54 Self::unpack(state)
55 }
56 pub fn set(&self, range: Range<u64>) {
57 assert!(range.start <= range.end);
58 self.state.store(Self::pack(range), Ordering::Release);
59 }
60 pub fn start(&self) -> u64 {
61 (self.state.load(Ordering::Acquire) >> 64) as u64
62 }
63 pub fn safe_add_start(&self, start: u64, bias: u64) -> Result<Range<u64>, RangeError> {
66 let new_start = start.checked_add(bias).ok_or(RangeError)?;
67 let mut old_state = self.state.load(Ordering::Acquire);
68 loop {
69 let mut range = Self::unpack(old_state);
70 let new_start = new_start.min(range.end);
71 if new_start <= range.start {
72 break Err(RangeError);
73 }
74 let span = range.start..new_start;
75 range.start = new_start;
76 let new_state = Self::pack(range);
77 match self.state.compare_exchange_weak(
78 old_state,
79 new_state,
80 Ordering::AcqRel,
81 Ordering::Acquire,
82 ) {
83 Ok(_) => break Ok(span),
84 Err(x) => old_state = x,
85 }
86 }
87 }
88
89 pub fn end(&self) -> u64 {
90 self.state.load(Ordering::Acquire) as u64
91 }
92 pub fn remain(&self) -> u64 {
93 let range = self.get();
94 range.end.saturating_sub(range.start)
95 }
96 pub fn split_two(&self) -> Result<Option<Range<u64>>, RangeError> {
99 let mut old_state = self.state.load(Ordering::Acquire);
100 loop {
101 let range = Self::unpack(old_state);
102 if range.start > range.end {
103 return Err(RangeError);
104 }
105 let mid = range.start + (range.end - range.start) / 2;
106 if mid == range.start {
107 return Ok(None);
108 }
109 let new_state = Self::pack(range.start..mid);
110 match self.state.compare_exchange_weak(
111 old_state,
112 new_state,
113 Ordering::AcqRel,
114 Ordering::Acquire,
115 ) {
116 Ok(_) => return Ok(Some(mid..range.end)),
117 Err(x) => old_state = x,
118 }
119 }
120 }
121 pub fn take(&self) -> Option<Range<u64>> {
122 let mut old_state = self.state.load(Ordering::Acquire);
123 loop {
124 let range = Self::unpack(old_state);
125 if range.start == range.end {
126 return None;
127 }
128 let new_state = Self::pack(range.start..range.start);
129 match self.state.compare_exchange_weak(
130 old_state,
131 new_state,
132 Ordering::AcqRel,
133 Ordering::Acquire,
134 ) {
135 Ok(_) => return Some(range),
136 Err(x) => old_state = x,
137 }
138 }
139 }
140 pub fn downgrade(&self) -> WeakTask {
141 WeakTask {
142 state: Arc::downgrade(&self.state),
143 }
144 }
145 pub fn strong_count(&self) -> usize {
146 Arc::strong_count(&self.state)
147 }
148 pub fn weak_count(&self) -> usize {
149 Arc::weak_count(&self.state)
150 }
151}
152impl From<&Range<u64>> for Task {
153 fn from(value: &Range<u64>) -> Self {
154 Self::new(value.clone())
155 }
156}
157
158impl PartialEq for Task {
159 fn eq(&self, other: &Self) -> bool {
160 Arc::ptr_eq(&self.state, &other.state)
161 }
162}
163impl Eq for Task {}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_task() {
        let task = Task::new(10..20);
        assert_eq!(task.start(), 10);
        assert_eq!(task.end(), 20);
        assert_eq!(task.remain(), 10);
    }

    #[test]
    fn test_remain() {
        let task = Task::new(10..25);
        assert_eq!(task.remain(), 15);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1..6);
        let range = task.split_two().unwrap().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(range.start, 3);
        assert_eq!(range.end, 6);
    }

    #[test]
    fn test_split_empty() {
        let task = Task::new(1..1);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 1);
        assert_eq!(range, None);
    }

    #[test]
    fn test_split_one() {
        let task = Task::new(1..2);
        let range = task.split_two().unwrap();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 2);
        assert_eq!(range, None);
    }

    #[test]
    fn test_set_and_get() {
        let task = Task::new(0..0);
        task.set(2..9);
        assert_eq!(task.get(), 2..9);
        assert_eq!(task.remain(), 7);
    }

    #[test]
    #[should_panic]
    fn test_new_rejects_inverted_range() {
        let _ = Task::new(Range { start: 5, end: 1 });
    }

    #[test]
    fn test_safe_add_start() {
        let task = Task::new(0..10);
        assert_eq!(task.safe_add_start(0, 4), Ok(0..4));
        assert_eq!(task.get(), 4..10);
        // A target beyond `end` is clamped to `end`.
        assert_eq!(task.safe_add_start(4, 100), Ok(4..10));
        assert_eq!(task.get(), 10..10);
        // No forward progress is possible on an exhausted task.
        assert_eq!(task.safe_add_start(10, 1), Err(RangeError));
    }

    #[test]
    fn test_safe_add_start_errors() {
        let task = Task::new(5..10);
        // `start + bias` overflows u64.
        assert_eq!(task.safe_add_start(u64::MAX, 1), Err(RangeError));
        // Target lies behind the current start.
        assert_eq!(task.safe_add_start(0, 3), Err(RangeError));
        // Failed calls leave the range untouched.
        assert_eq!(task.get(), 5..10);
    }

    #[test]
    fn test_take() {
        let task = Task::new(3..8);
        assert_eq!(task.take(), Some(3..8));
        assert_eq!(task.get(), 3..3);
        assert_eq!(task.take(), None);
    }

    #[test]
    fn test_clone_shares_state() {
        let a = Task::new(0..10);
        let b = a.clone();
        b.set(3..5);
        assert_eq!(a.get(), 3..5);
        assert_eq!(a, b);
        // Equality is identity-based: equal values in a separate task are not equal.
        assert_ne!(a, Task::new(3..5));
    }

    #[test]
    fn test_downgrade_upgrade() {
        let task = Task::new(0..4);
        let weak = task.downgrade();
        assert_eq!(task.strong_count(), 1);
        assert_eq!(task.weak_count(), 1);
        let upgraded = weak.upgrade().expect("task is still alive");
        assert_eq!(upgraded, task);
        assert_eq!(weak.strong_count(), 2);
        drop(upgraded);
        drop(task);
        assert!(weak.upgrade().is_none());
    }

    #[test]
    fn test_from_range_ref() {
        let range: Range<u64> = 2..7;
        let task = Task::from(&range);
        assert_eq!(task.get(), 2..7);
    }
}