krnl_core/
buffer.rs

1use crate::scalar::Scalar;
2#[cfg(not(target_arch = "spirv"))]
3use core::marker::PhantomData;
4use core::ops::Index;
5#[cfg(target_arch = "spirv")]
6use core::{arch::asm, mem::MaybeUninit};
7#[cfg(target_arch = "spirv")]
8use spirv_std::arch::IndexUnchecked;
9
/// Unsafe counterpart of [`Index`].
///
/// Like [`Index`], implementations perform checked indexing, but the caller must
/// guarantee that no mutable reference is ever aliased.
pub trait UnsafeIndex<Idx> {
    /// The type produced by indexing.
    type Output;

    /// Returns a shared reference to the element at `index`.
    ///
    /// # Safety
    ///
    /// The returned reference must not be aliased by a mutable borrow, i.e. by a
    /// call to `.unsafe_index_mut()` with the same index.
    unsafe fn unsafe_index(&self, index: Idx) -> &Self::Output;

    /// Returns a mutable reference to the element at `index`.
    ///
    /// # Safety
    ///
    /// The returned reference must not be aliased by any other borrow, i.e. by a
    /// call to `.unsafe_index()` or `.unsafe_index_mut()` with the same index.
    // Handing out `&mut` from `&self` is the purpose of this trait.
    #[allow(clippy::mut_from_ref)]
    unsafe fn unsafe_index_mut(&self, index: Idx) -> &mut Self::Output;
}

/// SPIR-V-only extension giving unchecked mutable element access through a
/// shared reference.
#[cfg(target_arch = "spirv")]
trait IndexUncheckedMutExt<T> {
    /// Returns a mutable reference to element `index` without a bounds check.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds (the access chain below is emitted as
    /// `OpInBoundsAccessChain`), and the returned reference must not alias any
    /// other borrow of the same element.
    unsafe fn index_unchecked_mut_ext(&self, index: usize) -> &mut T;
}

#[cfg(target_arch = "spirv")]
impl<T, const N: usize> IndexUncheckedMutExt<T> for [T; N] {
    #[inline]
    unsafe fn index_unchecked_mut_ext(&self, index: usize) -> &mut T {
        // The inline SPIR-V computes the element pointer and stores it into `slot`.
        let mut slot = MaybeUninit::uninit();
        unsafe {
            asm!(
                "%val_ptr = OpInBoundsAccessChain _ {array_ptr} {index}",
                "OpStore {output} %val_ptr",
                array_ptr = in(reg) self,
                index = in(reg) index,
                output = in(reg) slot.as_mut_ptr(),
            );
            slot.assume_init()
        }
    }
}

/// Private home of [`Sealed`](sealed::Sealed); since this module is not public,
/// the trait cannot be named (and so not implemented) outside this module.
mod sealed {
    /// Supertrait used to seal the representation traits.
    pub trait Sealed {}
}
use sealed::Sealed;

56/// Base trait for [`BufferBase`] representation.
57#[allow(clippy::len_without_is_empty)]
58pub trait DataBase: Sealed {
59    /// The numerical type of the buffer.
60    type Elem: Scalar;
61    #[doc(hidden)]
62    fn len(&self) -> usize;
63}
64
65/// Marker trait for immutable access.
66///
67/// See [`Slice`].
68pub trait Data: DataBase + Index<usize, Output = Self::Elem> {}
69/// Marker trait for unsafe access.
70///
71/// See [`UnsafeSlice`].
72pub trait UnsafeData: DataBase + UnsafeIndex<usize, Output = Self::Elem> {}
73
74/// [`Slice`] representation.
75#[derive(Clone, Copy)]
76pub struct SliceRepr<'a, T> {
77    #[cfg(not(target_arch = "spirv"))]
78    inner: &'a [T],
79    #[cfg(target_arch = "spirv")]
80    inner: &'a [T; 1],
81    #[cfg(target_arch = "spirv")]
82    offset: usize,
83    #[cfg(target_arch = "spirv")]
84    len: usize,
85}
86
87impl<T> Sealed for SliceRepr<'_, T> {}
88
89impl<T: Scalar> DataBase for SliceRepr<'_, T> {
90    type Elem = T;
91    #[cfg(not(target_arch = "spirv"))]
92    #[inline]
93    fn len(&self) -> usize {
94        self.inner.len()
95    }
96    #[cfg(target_arch = "spirv")]
97    #[inline]
98    fn len(&self) -> usize {
99        self.len
100    }
101}
102
103impl<T: Scalar> Index<usize> for SliceRepr<'_, T> {
104    type Output = T;
105    #[inline]
106    fn index(&self, index: usize) -> &Self::Output {
107        #[cfg(target_arch = "spirv")]
108        if index < self.len {
109            unsafe { self.inner.index_unchecked(self.offset + index) }
110        } else {
111            let len = self.len;
112            panic!("index out of bounds: the len is {index} but the index is {len}")
113        }
114        #[cfg(not(target_arch = "spirv"))]
115        self.inner.index(index)
116    }
117}
118
119impl<T: Scalar> Data for SliceRepr<'_, T> {}
120
121/// [`UnsafeSlice`] representation.
122#[derive(Clone, Copy)]
123pub struct UnsafeSliceRepr<'a, T> {
124    #[cfg(not(target_arch = "spirv"))]
125    ptr: *mut T,
126    #[cfg(target_arch = "spirv")]
127    #[allow(unused)]
128    inner: &'a [T; 1],
129    #[cfg(target_arch = "spirv")]
130    #[allow(unused)]
131    offset: usize,
132    len: usize,
133    #[cfg(not(target_arch = "spirv"))]
134    _m: PhantomData<&'a ()>,
135}
136
137impl<T> Sealed for UnsafeSliceRepr<'_, T> {}
138
139impl<T: Scalar> DataBase for UnsafeSliceRepr<'_, T> {
140    type Elem = T;
141    #[inline]
142    fn len(&self) -> usize {
143        self.len
144    }
145}
146
147impl<T: Scalar> UnsafeIndex<usize> for UnsafeSliceRepr<'_, T> {
148    type Output = T;
149    #[inline]
150    unsafe fn unsafe_index(&self, index: usize) -> &Self::Output {
151        if index < self.len {
152            #[cfg(target_arch = "spirv")]
153            unsafe {
154                self.inner.index_unchecked(self.offset + index)
155            }
156            #[cfg(not(target_arch = "spirv"))]
157            unsafe {
158                &*self.ptr.add(index)
159            }
160        } else {
161            let len = self.len;
162            panic!("index out of bounds: the len is {index} but the index is {len}")
163        }
164    }
165    #[inline]
166    unsafe fn unsafe_index_mut(&self, index: usize) -> &mut Self::Output {
167        if index < self.len {
168            #[cfg(target_arch = "spirv")]
169            unsafe {
170                self.inner.index_unchecked_mut_ext(self.offset + index)
171            }
172            #[cfg(not(target_arch = "spirv"))]
173            unsafe {
174                &mut *self.ptr.add(index)
175            }
176        } else {
177            let len = self.len();
178            panic!("index out of bounds: the len is {index} but the index is {len}")
179        }
180    }
181}
182
183impl<T: Scalar> UnsafeData for UnsafeSliceRepr<'_, T> {}
184
185unsafe impl<T: Send> Send for UnsafeSliceRepr<'_, T> {}
186unsafe impl<T: Sync> Sync for UnsafeSliceRepr<'_, T> {}
187
188/// A buffer.
189///
190/// [`Slice`] implements [`Index`] and [`UnsafeSlice`] implements [`UnsafeIndex`].
191#[derive(Clone, Copy)]
192pub struct BufferBase<S> {
193    data: S,
194}
195
196/// [`Slice`] implements [`Index`].
197///
198/// See [`BufferBase`].
199pub type Slice<'a, T> = BufferBase<SliceRepr<'a, T>>;
200/// [`UnsafeSlice`] implements [`UnsafeIndex`].
201///
202/// See [`BufferBase`].
203pub type UnsafeSlice<'a, T> = BufferBase<UnsafeSliceRepr<'a, T>>;
204
205impl<S: DataBase> BufferBase<S> {
206    /// The length of the buffer.
207    #[inline]
208    pub fn len(&self) -> usize {
209        self.data.len()
210    }
211    /// Whether the buffer is empty.
212    #[inline]
213    pub fn is_empty(&self) -> bool {
214        self.len() == 0
215    }
216}
217
218impl<S: Data> Index<usize> for BufferBase<S> {
219    type Output = S::Elem;
220    #[inline]
221    fn index(&self, index: usize) -> &Self::Output {
222        self.data.index(index)
223    }
224}
225
226impl<S: UnsafeData> UnsafeIndex<usize> for BufferBase<S> {
227    type Output = S::Elem;
228    /// # Safety
229    /// The caller must ensure that the returned reference is not aliased by a mutable borrow, ie by a call to `.unsafe_index_mut()` with the same index.
230    #[inline]
231    unsafe fn unsafe_index(&self, index: usize) -> &Self::Output {
232        unsafe { self.data.unsafe_index(index) }
233    }
234    /// # Safety
235    /// The caller must ensure that the returned reference is not aliased by another borrow, ie by a call to `.unsafe_index()` or `.unsafe_index_mut()` with the same index.
236    #[inline]
237    unsafe fn unsafe_index_mut(&self, index: usize) -> &mut Self::Output {
238        unsafe { self.data.unsafe_index_mut(index) }
239    }
240}
241
242impl<'a, T: Scalar> Slice<'a, T> {
243    // For kernel macro.
244    #[doc(hidden)]
245    #[cfg(target_arch = "spirv")]
246    #[inline]
247    pub unsafe fn from_raw_parts(inner: &'a [T; 1], offset: usize, len: usize) -> Self {
248        let data = SliceRepr { inner, offset, len };
249        Self { data }
250    }
251    /// A pointer to the buffer's data.
252    #[cfg(not(target_arch = "spirv"))]
253    #[inline]
254    pub fn as_ptr(&self) -> *const T {
255        self.data.inner.as_ptr()
256    }
257}
258
259impl<'a, T: Scalar> UnsafeSlice<'a, T> {
260    // For kernel macro.
261    #[doc(hidden)]
262    #[cfg(target_arch = "spirv")]
263    #[inline]
264    pub unsafe fn from_unsafe_raw_parts(inner: &'a [T; 1], offset: usize, len: usize) -> Self {
265        let data = UnsafeSliceRepr {
266            inner: &*inner,
267            offset,
268            len,
269        };
270        Self { data }
271    }
272    /// A mutable pointer to the buffer's data.
273    #[cfg(not(target_arch = "spirv"))]
274    #[inline]
275    pub fn as_mut_ptr(&self) -> *mut T {
276        self.data.ptr
277    }
278}
279
280#[cfg(not(target_arch = "spirv"))]
281impl<'a, T: Scalar> From<&'a [T]> for Slice<'a, T> {
282    #[inline]
283    fn from(slice: &'a [T]) -> Self {
284        let data = SliceRepr { inner: slice };
285        Self { data }
286    }
287}
288
289#[cfg(not(target_arch = "spirv"))]
290impl<'a, T: Scalar> From<Slice<'a, T>> for &'a [T] {
291    #[inline]
292    fn from(slice: Slice<'a, T>) -> &'a [T] {
293        slice.data.inner
294    }
295}
296
297#[cfg(not(target_arch = "spirv"))]
298impl<'a, T: Scalar> From<&'a mut [T]> for UnsafeSlice<'a, T> {
299    #[inline]
300    fn from(slice: &'a mut [T]) -> Self {
301        let data = UnsafeSliceRepr {
302            ptr: slice.as_mut_ptr(),
303            len: slice.len(),
304            _m: PhantomData,
305        };
306        Self { data }
307    }
308}