Skip to main content

khal_std/
index.rs

/// Element access whose bounds checks can be compiled out on GPU targets.
///
/// With the `unsafe_remove_boundchecks` feature disabled, standard
/// bounds-checked indexing is used. With the feature enabled, the methods use
/// unchecked indexing for performance on SPIR-V and CUDA targets.
///
/// # Safety considerations
///
/// Every method here is a safe `fn`, yet with `unsafe_remove_boundchecks`
/// enabled nothing validates `id`. Turning the feature on therefore moves the
/// "index is in bounds" obligation onto every call site.
pub trait MaybeIndexUnchecked<T> {
    /// Returns a shared reference to the element at index `id`.
    fn at(&self, id: usize) -> &T;
    /// Returns a mutable reference to the element at index `id`.
    fn at_mut(&mut self, id: usize) -> &mut T;
    /// Returns a copy of the element at index `id`.
    fn read(&self, id: usize) -> T;
    /// Stores `data` into the element at index `id`, replacing the old value.
    fn write(&mut self, id: usize, data: T);
}
16
17impl<T: Copy> MaybeIndexUnchecked<T> for [T] {
18    #[inline(always)]
19    fn at(&self, id: usize) -> &T {
20        #[cfg(all(feature = "unsafe_remove_boundchecks", target_arch = "nvptx64"))]
21        return unsafe { self.get_unchecked(id) };
22        #[cfg(all(feature = "unsafe_remove_boundchecks", not(target_arch = "nvptx64")))]
23        return unsafe {
24            use spirv_std::arch::IndexUnchecked;
25            self.index_unchecked(id)
26        };
27        #[cfg(not(feature = "unsafe_remove_boundchecks"))]
28        return &self[id];
29    }
30
31    #[inline(always)]
32    fn at_mut(&mut self, id: usize) -> &mut T {
33        #[cfg(all(feature = "unsafe_remove_boundchecks", target_arch = "nvptx64"))]
34        return unsafe { self.get_unchecked_mut(id) };
35        #[cfg(all(feature = "unsafe_remove_boundchecks", not(target_arch = "nvptx64")))]
36        return unsafe {
37            use spirv_std::arch::IndexUnchecked;
38            self.index_unchecked_mut(id)
39        };
40        #[cfg(not(feature = "unsafe_remove_boundchecks"))]
41        return &mut self[id];
42    }
43
44    #[inline(always)]
45    fn read(&self, id: usize) -> T {
46        *self.at(id)
47    }
48
49    #[inline(always)]
50    fn write(&mut self, id: usize, data: T) {
51        *self.at_mut(id) = data;
52    }
53}
54
55impl<T: Copy, const N: usize> MaybeIndexUnchecked<T> for [T; N] {
56    #[inline(always)]
57    fn at(&self, id: usize) -> &T {
58        #[cfg(all(feature = "unsafe_remove_boundchecks", target_arch = "nvptx64"))]
59        return unsafe { self.get_unchecked(id) };
60        #[cfg(all(feature = "unsafe_remove_boundchecks", not(target_arch = "nvptx64")))]
61        return unsafe {
62            use spirv_std::arch::IndexUnchecked;
63            self.index_unchecked(id)
64        };
65        #[cfg(not(feature = "unsafe_remove_boundchecks"))]
66        return &self[id];
67    }
68
69    #[inline(always)]
70    fn at_mut(&mut self, id: usize) -> &mut T {
71        #[cfg(all(feature = "unsafe_remove_boundchecks", target_arch = "nvptx64"))]
72        return unsafe { self.get_unchecked_mut(id) };
73        #[cfg(all(feature = "unsafe_remove_boundchecks", not(target_arch = "nvptx64")))]
74        return unsafe {
75            use spirv_std::arch::IndexUnchecked;
76            self.index_unchecked_mut(id)
77        };
78        #[cfg(not(feature = "unsafe_remove_boundchecks"))]
79        return &mut self[id];
80    }
81
82    #[inline(always)]
83    fn read(&self, id: usize) -> T {
84        *self.at(id)
85    }
86
87    #[inline(always)]
88    fn write(&mut self, id: usize, data: T) {
89        *self.at_mut(id) = data;
90    }
91}