1use num::*;
2use std::ops::{Index, IndexMut};
3use std::sync::Arc;
4use std::{
5 ops::Range,
6 sync::atomic::{AtomicUsize, Ordering},
7};
8
9#[cfg(not(feature = "wasm_support"))]
10use rayon::prelude::*;
11
/// Returns the smallest exponent `e >= offset` such that `2^e >= bits`,
/// or `offset` unchanged when `bits == 0`.
///
/// Replaces the original shift loop, which overflowed `1 << offset`
/// (debug panic / silent wrap in release) for `bits > 2^31` or large
/// starting offsets; this closed form is exact for the full `u32` range.
pub fn round_up_log2(bits: u32, offset: u32) -> u32 {
    if bits == 0 {
        offset
    } else {
        // ceil(log2(bits)): 0 for bits == 1, else 32 - leading_zeros(bits - 1).
        let ceil_log2 = 32 - (bits - 1).leading_zeros();
        offset.max(ceil_log2)
    }
}
23
/// Gathers `primitives` into a new vector ordered by `indices`
/// (`out[i] = primitives[indices[i]]`). Serial fallback for wasm builds,
/// where no rayon thread pool is available.
///
/// Fixes clippy lints in the original: the redundant `.into_iter()` on a
/// range and a `.clone()` on a `Copy` type; iterating `indices` directly
/// also drops the extra index indirection.
#[allow(dead_code)]
#[cfg(feature = "wasm_support")]
pub fn shuffle_prims<T: Sized + Copy + Send + Sync>(primitives: &[T], indices: &[u32]) -> Vec<T> {
    indices
        .iter()
        .map(|&idx| primitives[idx as usize])
        .collect()
}
32
33#[allow(dead_code)]
34#[cfg(not(feature = "wasm_support"))]
35pub fn shuffle_prims<T: Sized + Copy + Send + Sync>(primitives: &[T], indices: &[u32]) -> Vec<T> {
36 (0..indices.len())
37 .into_par_iter()
38 .map(|i| primitives[indices[i] as usize])
39 .collect()
40}
41
42pub fn prefix_sum<T: Num + Sized + Copy>(first: &[T], count: usize, output: &mut [T]) -> T {
43 debug_assert!(first.len() >= count);
44 debug_assert!(output.len() >= count);
45
46 if count.is_zero() {
47 return first[0];
48 }
49
50 let mut sum: T = T::zero();
51 for i in 0..count {
52 sum = sum.add(first[i]);
53 output[i] = sum;
54 }
55
56 sum
57}
58
/// Copies the range `[first, last)` into the range *ending* at `d_last`,
/// walking backwards — the counterpart of C++ `std::move_backward`, except
/// each element is `clone`d rather than moved. Returns a pointer to the
/// start of the destination range (`d_last - (last - first)`).
///
/// # Safety
/// - `[first, last)` must be a valid, readable range of initialized `T`.
/// - The `last - first` slots ending at `d_last` must be valid for writes.
/// - Backward iteration makes overlap safe only when the destination ends
///   at or past `last`.
/// - NOTE(review): `ptr::write` does not drop the destination's previous
///   contents — if `T` owns resources and the destination held initialized
///   values, they leak. Confirm callers only target trivially-droppable or
///   uninitialized slots.
pub unsafe fn move_backward<T: Sized + Clone>(
    first: *mut T,
    mut last: *mut T,
    mut d_last: *mut T,
) -> *mut T {
    while first != last {
        // Step both cursors back one element *before* copying, so the highest
        // source element lands at d_last - 1 first.
        d_last = d_last.sub(1);
        last = last.sub(1);

        std::ptr::write(d_last, (*last).clone());
    }

    d_last
}
73
/// Partitions `slice[range]` in place so every element satisfying `check`
/// precedes every element that does not (an unstable partition — relative
/// order is not preserved). Returns the number of matching elements, so the
/// split point within the slice is `range.start + returned`.
pub fn partition<T: Sized + Clone, B>(slice: &mut [T], range: Range<usize>, check: B) -> usize
where
    B: Fn(&T) -> bool,
{
    // The loop indexes up to range.end, so the whole range must lie inside
    // the slice. (The original only compared lengths, which let an offset
    // range pass the assert and still index out of bounds.)
    debug_assert!(
        range.end <= slice.len(),
        "Range end ({}) exceeds slice length ({})",
        range.end,
        slice.len()
    );

    let start = range.start;
    let mut count: usize = 0;
    for i in range {
        if check(&slice[i]) {
            // Swap matches to the front OF THE RANGE. The original swapped to
            // index `count` (front of the whole slice), which scrambles
            // elements outside the range whenever range.start > 0.
            slice.swap(i, start + count);
            count += 1;
        }
    }

    count
}
97
/// A shared-slice wrapper that hands out mutable access through `&self` via
/// a raw pointer, so multiple workers can write disjoint regions of a single
/// slice in parallel.
///
/// NOTE(review): this deliberately bypasses the borrow checker — soundness
/// relies entirely on callers never aliasing the same index mutably across
/// threads; nothing in the type enforces that.
#[derive(Debug, Clone)]
pub struct UnsafeSliceWrapper<'a, T: Sized> {
    // Raw base pointer used for all mutation; aliases `slice`.
    ptr: *mut T,
    // Shared view kept for length queries and read-only indexing.
    slice: &'a [T],
}
103
104impl<'a, T> Index<usize> for UnsafeSliceWrapper<'a, T> {
105 type Output = T;
106
107 fn index(&self, index: usize) -> &Self::Output {
108 &self.slice[index]
109 }
110}
111
112impl<'a, T> IndexMut<usize> for UnsafeSliceWrapper<'a, T> {
113 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
114 self.get_mut(index).unwrap()
115 }
116}
117
#[allow(dead_code)]
impl<'a, T: Sized> UnsafeSliceWrapper<'a, T> {
    /// Wraps a mutable slice, capturing its base pointer for later
    /// borrow-checker-free mutation.
    pub fn new(array: &'a mut [T]) -> Self {
        Self {
            ptr: array.as_mut_ptr(),
            slice: array,
        }
    }

    /// Number of elements in the wrapped slice.
    pub fn len(&self) -> usize {
        self.slice.len()
    }

    /// Shared reference to element `idx`.
    ///
    /// Bounds are checked only by `debug_assert`; in release builds an
    /// out-of-range index is undefined behavior. `None` is returned only if
    /// the computed pointer is null (practically never for a live slice).
    pub fn get(&self, idx: usize) -> Option<&'a T> {
        debug_assert!(idx < self.slice.len());
        unsafe { self.ptr.add(idx).as_ref() }
    }

    /// Mutable reference to element `idx`, obtained through `&self`.
    ///
    /// NOTE(review): this can mint aliasing `&mut T`s for the same index
    /// from different callers — soundness depends on call sites using
    /// disjoint indices. Same debug-only bounds checking as `get`.
    pub fn get_mut(&self, idx: usize) -> Option<&'a mut T> {
        debug_assert!(idx < self.slice.len());
        unsafe { self.ptr.add(idx).as_mut() }
    }

    /// Overwrites element `idx` with `val`.
    ///
    /// Uses `ptr::write`, so the previous value is NOT dropped — fine for
    /// `Copy` data; leaks resources if `T` owns any.
    pub fn set(&self, idx: usize, val: T) {
        debug_assert!(idx < self.slice.len());
        unsafe {
            std::ptr::write(self.ptr.add(idx), val);
        }
    }

    /// Read-only view of the whole underlying slice.
    pub fn as_slice(&self) -> &[T] {
        self.slice
    }

    /// Mutable view of the whole underlying slice, from `&self`.
    /// Same aliasing caveat as `get_mut`.
    #[allow(clippy::mut_from_ref)]
    pub fn as_mut(&self) -> &mut [T] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len()) }
    }

    /// Raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr as *const T
    }

    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&self) -> *mut T {
        self.ptr
    }

    /// Swaps elements `a` and `b` in place (bounds checked in debug only).
    pub fn swap(&self, a: usize, b: usize) {
        debug_assert!(a < self.slice.len());
        debug_assert!(b < self.slice.len());
        self.as_mut().swap(a, b);
    }

    /// Mutable subslice covering `start..end`.
    ///
    /// Debug asserts require `start < end` (so empty ranges are rejected)
    /// and `end <= len()`; release builds skip both checks.
    #[allow(clippy::mut_from_ref)]
    pub fn range(&self, start: usize, end: usize) -> &mut [T] {
        debug_assert!(start < end, "start: {}, end: {}", start, end);
        debug_assert!(
            end <= self.len(),
            "start: {}, end: {}, len: {}",
            start,
            end,
            self.len()
        );
        unsafe { std::slice::from_raw_parts_mut(self.ptr.add(start), end - start) }
    }
}
184
185unsafe impl<'a, T> Send for UnsafeSliceWrapper<'a, T> {}
186
187unsafe impl<'a, T> Sync for UnsafeSliceWrapper<'a, T> {}
188
/// Spawns recursive divide-and-conquer `Task`s across scoped threads,
/// shipping large work items to new threads while keeping small ones local.
pub struct TaskSpawner {
    // Incremented once per `run_task` invocation on any thread.
    // NOTE(review): never decremented anywhere in this file — confirm
    // whether this is meant as a lifetime total or an in-flight gauge.
    pub threads_in_flight: Arc<AtomicUsize>,
    config: TaskConfig,
}
193
/// Tuning knobs controlling how `TaskSpawner` distributes work.
#[derive(Debug, Copy, Clone)]
struct TaskConfig {
    // Tasks whose work_size() is at or below this stay on the current
    // thread instead of being spawned onto a new one.
    pub work_size_threshold: usize,
    // Upper bound on Task::depth(), checked by debug_assert in run_task.
    pub max_depth: usize,
    // NOTE(review): not read anywhere in this file — presumably consumed by
    // Task implementations; verify before removing.
    pub max_leaf_size: usize,
}
200
/// A unit of divisible work consumed by `TaskSpawner`.
pub trait Task: Sized + Send + Sync {
    /// Executes this task; returns `Some((a, b))` when the work split into
    /// two subtasks, or `None` when it completed as a leaf.
    fn run(self) -> Option<(Self, Self)>;
    /// Approximate cost, used to decide whether to spawn a thread for it.
    fn work_size(&self) -> usize;
    /// Recursion depth, checked against `TaskConfig::max_depth` (debug only).
    fn depth(&self) -> usize;
}
206
207#[allow(dead_code)]
208impl TaskSpawner {
209 pub fn new() -> Self {
210 Self {
211 config: TaskConfig {
212 work_size_threshold: 1024,
213 max_depth: 64,
214 max_leaf_size: 16,
215 },
216 threads_in_flight: Arc::new(AtomicUsize::new(0)),
217 }
218 }
219
220 pub fn with_work_size_threshold(mut self, threshold: usize) -> Self {
221 self.config.work_size_threshold = threshold;
222 self
223 }
224
225 pub fn with_max_depth(mut self, depth: usize) -> Self {
226 self.config.max_depth = depth;
227 self
228 }
229
230 pub fn with_max_leaf_size(mut self, max_leaf_size: usize) -> Self {
231 self.config.max_leaf_size = max_leaf_size;
232 self
233 }
234
235 pub fn run<T: Task>(&self, first_task: T) {
236 let thread_count = self.threads_in_flight.clone();
237 crossbeam::scope(move |s| {
238 Self::run_task(first_task, self.config, thread_count, s);
239 })
240 .unwrap();
241 }
242
243 fn run_task<'a, T: Task + Sized + 'a>(
244 task: T,
245 config: TaskConfig,
246 thread_count: Arc<AtomicUsize>,
247 scope: &crossbeam::thread::Scope<'a>,
248 ) {
249 thread_count.fetch_add(1, Ordering::SeqCst);
250 let mut sub_tasks = Vec::new();
251
252 let mut stack: Vec<T> = vec![task];
253
254 while !stack.is_empty() {
255 let work_item = stack.pop().unwrap();
256 debug_assert!(work_item.depth() <= config.max_depth);
257
258 if let Some((mut task_a, mut task_b)) = work_item.run() {
259 if task_a.work_size() < task_b.work_size() {
260 std::mem::swap(&mut task_a, &mut task_b);
262 }
263
264 stack.push(task_b);
265
266 let task_a = task_a;
268
269 if task_a.work_size() <= config.work_size_threshold {
271 stack.push(task_a);
272 continue;
273 }
274
275 let count = thread_count.clone();
277 sub_tasks.push(scope.spawn(move |s| {
278 Self::run_task(task_a, config, count, s);
279 }));
280 }
281 }
282
283 while !sub_tasks.is_empty() {
285 let r = sub_tasks.pop().unwrap();
286 r.join().unwrap();
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use crate::utils::*;

    // Inclusive prefix sums of [1..=6] must be [1, 3, 6, 10, 15, 21];
    // kept as one test per type to exercise the generic instantiation.

    #[test]
    fn prefix_sum_u32_works() {
        let input: [u32; 6] = [1, 2, 3, 4, 5, 6];
        let expected: [u32; 6] = [1, 3, 6, 10, 15, 21];
        let mut storage = vec![0u32; 6];

        prefix_sum(&input, 6, storage.as_mut_slice());
        assert_eq!(storage, expected);
    }

    #[test]
    fn prefix_sum_usize_works() {
        let input: [usize; 6] = [1, 2, 3, 4, 5, 6];
        let expected: [usize; 6] = [1, 3, 6, 10, 15, 21];
        let mut storage = vec![0usize; 6];

        prefix_sum(&input, 6, storage.as_mut_slice());
        assert_eq!(storage, expected);
    }

    #[test]
    fn prefix_sum_i32_works() {
        let input: [i32; 6] = [1, 2, 3, 4, 5, 6];
        let expected: [i32; 6] = [1, 3, 6, 10, 15, 21];
        let mut storage = vec![0i32; 6];

        prefix_sum(&input, 6, storage.as_mut_slice());
        assert_eq!(storage, expected);
    }

    #[test]
    fn prefix_sum_zero() {
        let input: [u32; 6] = [1, 2, 3, 4, 5, 6];
        let mut storage = vec![0u32; 6];
        // A zero-length scan reports the first input element.
        assert_eq!(prefix_sum(&input, 0, storage.as_mut()), input[0])
    }

    #[test]
    fn test_move_backwards() {
        let mut src: [u32; 3] = [0, 1, 2];
        let mut dest: [u32; 3] = [0; 3];

        assert_eq!(src, [0, 1, 2]);
        assert_eq!(dest, [0; 3]);

        unsafe {
            move_backward(
                src.as_mut_ptr(),
                src.as_mut_ptr().add(src.len()),
                dest.as_mut_ptr().add(dest.len()),
            );
        }

        assert_eq!(src, dest);
    }
}