//! Utility functions shared by the rtbvh crate (rtbvh/utils.rs).

1use num::*;
2use std::ops::{Index, IndexMut};
3use std::sync::Arc;
4use std::{
5    ops::Range,
6    sync::atomic::{AtomicUsize, Ordering},
7};
8
9#[cfg(not(feature = "wasm_support"))]
10use rayon::prelude::*;
11
/// Returns the smallest exponent `e >= offset` such that `2^e >= bits`.
///
/// For `bits == 0`, `offset` is returned unchanged (matching the previous
/// behavior of the loop-based implementation).
pub fn round_up_log2(bits: u32, offset: u32) -> u32 {
    if bits == 0 {
        offset
    } else {
        // ceil(log2(bits)) without the shift the old `while (1 << offset) < bits`
        // loop performed, which overflowed (debug panic) once `bits > 2^31`
        // pushed `offset` to 32. `bits - 1` is safe here since `bits != 0`.
        let ceil_log2 = 32 - (bits - 1).leading_zeros();
        offset.max(ceil_log2)
    }
}
23
/// Reorders `primitives` into a new vector following `indices`
/// (`result[i] = primitives[indices[i]]`). Serial fallback for wasm builds.
#[allow(dead_code)]
#[cfg(feature = "wasm_support")]
pub fn shuffle_prims<T: Sized + Copy + Send + Sync>(primitives: &[T], indices: &[u32]) -> Vec<T> {
    // `T: Copy`, so plain indexing suffices — the previous `.clone()` and the
    // redundant `.into_iter()` on a range were both clippy lint targets.
    indices.iter().map(|&i| primitives[i as usize]).collect()
}
32
33#[allow(dead_code)]
34#[cfg(not(feature = "wasm_support"))]
35pub fn shuffle_prims<T: Sized + Copy + Send + Sync>(primitives: &[T], indices: &[u32]) -> Vec<T> {
36    (0..indices.len())
37        .into_par_iter()
38        .map(|i| primitives[indices[i] as usize])
39        .collect()
40}
41
42pub fn prefix_sum<T: Num + Sized + Copy>(first: &[T], count: usize, output: &mut [T]) -> T {
43    debug_assert!(first.len() >= count);
44    debug_assert!(output.len() >= count);
45
46    if count.is_zero() {
47        return first[0];
48    }
49
50    let mut sum: T = T::zero();
51    for i in 0..count {
52        sum = sum.add(first[i]);
53        output[i] = sum;
54    }
55
56    sum
57}
58
/// Clone-copies the range `[first, last)` into the range *ending* at
/// `d_last`, walking backwards (mirrors C++ `std::move_backward`).
/// Returns the start of the written destination range.
///
/// # Safety
/// `[first, last)` must be a valid, readable range of initialized `T`, and
/// the `last - first` elements ending at `d_last` must be valid for writes.
/// Destination slots are overwritten via `std::ptr::write`, so their previous
/// contents are not dropped.
pub unsafe fn move_backward<T: Sized + Clone>(
    first: *mut T,
    mut last: *mut T,
    mut d_last: *mut T,
) -> *mut T {
    // Walk both cursors down in lock-step until the source range is exhausted.
    while last != first {
        last = last.sub(1);
        d_last = d_last.sub(1);
        std::ptr::write(d_last, (*last).clone());
    }

    d_last
}
73
/// Partitions `slice[range]` in place so that all elements satisfying
/// `check` come first *within the range* (not stable).
/// Returns how many elements went left.
pub fn partition<T: Sized + Clone, B>(slice: &mut [T], range: Range<usize>, check: B) -> usize
where
    B: Fn(&T) -> bool,
{
    // The old assert only compared against the range *length*, which does not
    // guarantee that indexing `slice[range.end - 1]` is in bounds.
    debug_assert!(
        slice.len() >= range.end,
        "Slice was smaller ({}) than range end ({})",
        slice.len(),
        range.end
    );

    let start = range.start;
    let mut count: usize = 0;
    for i in range {
        if check(&slice[i]) {
            // Fix: swap relative to the range start. The previous code swapped
            // to absolute index `count`, which clobbered `slice[..count]`
            // whenever `range.start > 0`. (Unchanged for ranges starting at 0.)
            slice.swap(i, start + count);
            count += 1;
        }
    }

    count
}
97
/// Slice wrapper that hands out unsynchronized mutable access from a shared
/// reference, intended for threads that each work on disjoint index ranges
/// (see the `Send`/`Sync` impls further down).
///
/// NOTE(review): the derived `Clone` lets multiple copies alias the same
/// memory, and nothing in this type enforces disjoint access — soundness
/// rests entirely on the callers. The `Debug`/`Clone` derives also impose
/// `T: Debug`/`T: Clone` bounds the raw-pointer fields do not need.
#[derive(Debug, Clone)]
pub struct UnsafeSliceWrapper<'a, T: Sized> {
    // Raw pointer used for all mutation; aliases the memory behind `slice`.
    ptr: *mut T,
    // Shared view kept for length/read access and to anchor lifetime `'a`.
    slice: &'a [T],
}
103
104impl<'a, T> Index<usize> for UnsafeSliceWrapper<'a, T> {
105    type Output = T;
106
107    fn index(&self, index: usize) -> &Self::Output {
108        &self.slice[index]
109    }
110}
111
112impl<'a, T> IndexMut<usize> for UnsafeSliceWrapper<'a, T> {
113    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
114        self.get_mut(index).unwrap()
115    }
116}
117
#[allow(dead_code)]
impl<'a, T: Sized> UnsafeSliceWrapper<'a, T> {
    /// Wraps a mutable slice. The raw pointer and the stored shared slice
    /// alias the same memory; every method below that produces `&mut` data
    /// from `&self` relies on callers never accessing the same index
    /// concurrently — this invariant is not enforced anywhere in this type.
    pub fn new(array: &'a mut [T]) -> Self {
        Self {
            ptr: array.as_mut_ptr(),
            slice: array,
        }
    }

    /// Number of elements in the wrapped slice.
    pub fn len(&self) -> usize {
        self.slice.len()
    }

    /// Shared reference to element `idx`, read through the raw pointer.
    /// Bounds are only checked in debug builds.
    pub fn get(&self, idx: usize) -> Option<&'a T> {
        debug_assert!(idx < self.slice.len());
        unsafe { self.ptr.add(idx).as_ref() }
    }

    /// Mutable reference to element `idx`, obtained from `&self`.
    /// Bounds are only checked in debug builds; the caller must guarantee
    /// exclusive access to this index (see `new`).
    pub fn get_mut(&self, idx: usize) -> Option<&'a mut T> {
        debug_assert!(idx < self.slice.len());
        unsafe { self.ptr.add(idx).as_mut() }
    }

    /// Writes `val` into slot `idx`. Uses `std::ptr::write`, so the previous
    /// value is overwritten without being dropped.
    pub fn set(&self, idx: usize, val: T) {
        debug_assert!(idx < self.slice.len());
        unsafe {
            std::ptr::write(self.ptr.add(idx), val);
        }
    }

    /// Read-only view of the entire wrapped slice.
    pub fn as_slice(&self) -> &[T] {
        self.slice
    }

    /// Mutable view of the entire wrapped slice, produced from `&self`.
    /// Sound only while no other view of the same elements is alive.
    #[allow(clippy::mut_from_ref)]
    pub fn as_mut(&self) -> &mut [T] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len()) }
    }

    /// Raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr as *const T
    }

    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&self) -> *mut T {
        self.ptr
    }

    /// Swaps elements `a` and `b` in place (bounds checked in debug only).
    pub fn swap(&self, a: usize, b: usize) {
        debug_assert!(a < self.slice.len());
        debug_assert!(b < self.slice.len());
        self.as_mut().swap(a, b);
    }

    /// Mutable sub-slice covering `[start, end)`. Note that `start < end` is
    /// required, so an empty range is rejected (debug builds only).
    #[allow(clippy::mut_from_ref)]
    pub fn range(&self, start: usize, end: usize) -> &mut [T] {
        debug_assert!(start < end, "start: {}, end: {}", start, end);
        debug_assert!(
            end <= self.len(),
            "start: {}, end: {}, len: {}",
            start,
            end,
            self.len()
        );
        unsafe { std::slice::from_raw_parts_mut(self.ptr.add(start), end - start) }
    }
}
184
// SAFETY(review): these impls assert that sharing the wrapper across threads
// is sound even though it holds a raw `*mut T`. That only holds if concurrent
// users never touch the same index, and nothing here enforces it. Note also
// that there are no `T: Send`/`T: Sync` bounds — TODO confirm every call site
// uses thread-safe element types and disjoint index ranges.
unsafe impl<'a, T> Send for UnsafeSliceWrapper<'a, T> {}

unsafe impl<'a, T> Sync for UnsafeSliceWrapper<'a, T> {}
188
/// Runs recursively-splitting [`Task`]s on crossbeam scoped threads.
pub struct TaskSpawner {
    // Incremented once per `run_task` invocation and never decremented in
    // this file, so it counts tasks started rather than currently-live
    // threads — TODO confirm that is the intended meaning of the name.
    pub threads_in_flight: Arc<AtomicUsize>,
    config: TaskConfig,
}
193
/// Tuning knobs for `TaskSpawner`.
#[derive(Debug, Copy, Clone)]
struct TaskConfig {
    // Tasks whose `work_size()` is at or below this stay on the current
    // thread instead of being spawned onto a new one.
    pub work_size_threshold: usize,
    // Upper bound on `Task::depth()`, checked via `debug_assert!` only.
    pub max_depth: usize,
    // NOTE(review): stored and settable via the builder but not read anywhere
    // in this file; presumably consumed elsewhere — verify.
    pub max_leaf_size: usize,
}
200
/// A unit of splittable work consumed by `TaskSpawner`.
pub trait Task: Sized + Send + Sync {
    /// Performs this task. Returns `Some((a, b))` when the work split into
    /// two sub-tasks, or `None` when it completed as a leaf.
    fn run(self) -> Option<(Self, Self)>;
    /// Relative amount of remaining work; used to decide which half of a
    /// split is pushed to a new thread.
    fn work_size(&self) -> usize;
    /// Current recursion depth, compared against `TaskConfig::max_depth`.
    fn depth(&self) -> usize;
}
206
207#[allow(dead_code)]
208impl TaskSpawner {
209    pub fn new() -> Self {
210        Self {
211            config: TaskConfig {
212                work_size_threshold: 1024,
213                max_depth: 64,
214                max_leaf_size: 16,
215            },
216            threads_in_flight: Arc::new(AtomicUsize::new(0)),
217        }
218    }
219
220    pub fn with_work_size_threshold(mut self, threshold: usize) -> Self {
221        self.config.work_size_threshold = threshold;
222        self
223    }
224
225    pub fn with_max_depth(mut self, depth: usize) -> Self {
226        self.config.max_depth = depth;
227        self
228    }
229
230    pub fn with_max_leaf_size(mut self, max_leaf_size: usize) -> Self {
231        self.config.max_leaf_size = max_leaf_size;
232        self
233    }
234
235    pub fn run<T: Task>(&self, first_task: T) {
236        let thread_count = self.threads_in_flight.clone();
237        crossbeam::scope(move |s| {
238            Self::run_task(first_task, self.config, thread_count, s);
239        })
240        .unwrap();
241    }
242
243    fn run_task<'a, T: Task + Sized + 'a>(
244        task: T,
245        config: TaskConfig,
246        thread_count: Arc<AtomicUsize>,
247        scope: &crossbeam::thread::Scope<'a>,
248    ) {
249        thread_count.fetch_add(1, Ordering::SeqCst);
250        let mut sub_tasks = Vec::new();
251
252        let mut stack: Vec<T> = vec![task];
253
254        while !stack.is_empty() {
255            let work_item = stack.pop().unwrap();
256            debug_assert!(work_item.depth() <= config.max_depth);
257
258            if let Some((mut task_a, mut task_b)) = work_item.run() {
259                if task_a.work_size() < task_b.work_size() {
260                    // Push more work to new thread
261                    std::mem::swap(&mut task_a, &mut task_b);
262                }
263
264                stack.push(task_b);
265
266                // Remove mutability
267                let task_a = task_a;
268
269                // If threshold is not met, push to stack instead of spawning new thread
270                if task_a.work_size() <= config.work_size_threshold {
271                    stack.push(task_a);
272                    continue;
273                }
274
275                // Spawn new thread
276                let count = thread_count.clone();
277                sub_tasks.push(scope.spawn(move |s| {
278                    Self::run_task(task_a, config, count, s);
279                }));
280            }
281        }
282
283        // Join thread handles
284        while !sub_tasks.is_empty() {
285            let r = sub_tasks.pop().unwrap();
286            r.join().unwrap();
287        }
288    }
289}
290
#[cfg(test)]
mod tests {
    use crate::utils::*;

    #[test]
    fn prefix_sum_u32_works() {
        let input: [u32; 6] = [1, 2, 3, 4, 5, 6];
        let expected: [u32; 6] = [1, 3, 6, 10, 15, 21];
        let mut storage = vec![0u32; 6];

        prefix_sum(&input, 6, storage.as_mut_slice());
        assert_eq!(storage, expected);
    }

    #[test]
    fn prefix_sum_usize_works() {
        let input: [usize; 6] = [1, 2, 3, 4, 5, 6];
        let expected: [usize; 6] = [1, 3, 6, 10, 15, 21];
        let mut storage = vec![0usize; 6];

        prefix_sum(&input, 6, storage.as_mut_slice());
        assert_eq!(storage, expected);
    }

    #[test]
    fn prefix_sum_i32_works() {
        let input: [i32; 6] = [1, 2, 3, 4, 5, 6];
        let expected: [i32; 6] = [1, 3, 6, 10, 15, 21];
        let mut storage = vec![0i32; 6];

        prefix_sum(&input, 6, storage.as_mut_slice());
        assert_eq!(storage, expected);
    }

    #[test]
    fn prefix_sum_zero() {
        // With count == 0, prefix_sum returns the first input element.
        let input: [u32; 6] = [1, 2, 3, 4, 5, 6];
        let mut storage = vec![0u32; 6];
        assert_eq!(input[0], prefix_sum(&input, 0, storage.as_mut()))
    }

    #[test]
    fn test_move_backwards() {
        let mut src: [u32; 3] = [0, 1, 2];
        let mut dest: [u32; 3] = [0; 3];

        assert_eq!(src, [0, 1, 2]);
        assert_eq!(dest, [0; 3]);

        unsafe {
            move_backward(
                src.as_mut_ptr(),
                src.as_mut_ptr().add(src.len()),
                dest.as_mut_ptr().add(dest.len()),
            );
        }

        assert_eq!(src, dest);
    }
}