krnl_core/
kernel.rs

#[doc(hidden)]
#[cfg(target_arch = "spirv")]
pub mod __private {
    // Support items that only exist when compiling for the `spirv` target.
    // NOTE(review): hidden from the docs and named `__private`, so presumably only generated
    // kernel code calls into this module -- confirm against the macro crate.
    use super::{ItemKernel, Kernel};
    use core::mem::size_of;

    /// Raw `u32` launch values consumed by [`KernelArgs::into_kernel`] to build a [`Kernel`].
    pub struct KernelArgs {
        pub global_id: u32,
        pub groups: u32,
        pub group_id: u32,
        pub subgroups: u32,
        pub subgroup_id: u32,
        // Not passed yet; see the matching TODO on `Kernel`.
        //pub subgroup_threads: u32,
        pub subgroup_thread_id: u32,
        pub threads: u32,
        pub thread_id: u32,
    }

    // NOTE(review): nothing visible in this file is marked deprecated; the allow is kept as-is.
    #[allow(deprecated)]
    impl KernelArgs {
        /// Converts the raw arguments into the [`Kernel`] handle.
        ///
        /// Every field is forwarded unchanged; `global_threads` is derived as
        /// `groups * threads`.
        ///
        /// # Safety
        ///
        /// The body performs no unsafe operation itself.
        /// NOTE(review): presumably `unsafe` exists so that only generated code, which knows
        /// the values describe the actual launch, constructs a `Kernel` -- confirm.
        #[inline]
        pub unsafe fn into_kernel(self) -> Kernel {
            Kernel {
                global_threads: self.groups * self.threads,
                global_id: self.global_id,
                groups: self.groups,
                group_id: self.group_id,
                subgroups: self.subgroups,
                subgroup_id: self.subgroup_id,
                //subgroup_threads: self.subgroup_threads,
                subgroup_thread_id: self.subgroup_thread_id,
                threads: self.threads,
                thread_id: self.thread_id,
            }
        }
    }

    /// Writes `1` into the first element of `data`.
    ///
    /// The write keeps `__krnl_kernel_data` in use so it is not optimized away; krnlc removes
    /// the write again afterwards.
    ///
    /// # Safety
    ///
    /// `data` must hold at least one element: the access is not bounds checked.
    #[inline]
    pub unsafe fn kernel_data(data: &mut [u32]) {
        use spirv_std::arch::IndexUnchecked;

        // SAFETY: the caller guarantees that index 0 exists.
        unsafe {
            *data.index_unchecked_mut(0) = 1;
        }
    }

    /// Stores a group buffer length in `data[index]` so that krnlc can read it back.
    ///
    /// The length may be a constant, a spec constant, or a spec const expr. A `len` of zero is
    /// stored as `1`. krnlc then changes the array from length 1 to the stored constant.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds for `data`: the access is not bounds checked.
    #[inline]
    pub unsafe fn group_buffer_len(data: &mut [u32], index: usize, len: usize) {
        use spirv_std::arch::IndexUnchecked;

        let stored_len = if len > 0 { len as u32 } else { 1 };
        // SAFETY: the caller guarantees that `index` is within `data`.
        unsafe {
            *data.index_unchecked_mut(index) = stored_len;
        }
    }

    /// Sets the first `len` elements of `buffer` to `T::default()`, split across the threads
    /// of the group.
    ///
    /// Each thread starts at `thread_id * stride` and advances by `threads * stride`, writing
    /// up to `stride` consecutive elements per pass. `stride` is 4 for 1-byte `T`, 2 for
    /// 2-byte `T` and 1 otherwise.
    /// NOTE(review): presumably chosen so every thread covers at least 4 bytes per pass --
    /// confirm.
    ///
    /// # Safety
    ///
    /// `buffer` is typed as a one-element array but is written, without bounds checks, at
    /// every index below `len`. The caller must guarantee that the storage behind `buffer`
    /// really holds `len` elements (see `group_buffer_len`, whose length krnlc uses to resize
    /// the array).
    #[inline]
    pub unsafe fn zero_group_buffer<T: Default + Copy>(
        kernel: &Kernel,
        buffer: &mut [T; 1],
        len: usize,
    ) {
        use spirv_std::arch::IndexUnchecked;

        // Number of consecutive elements each thread clears per pass.
        let stride = match size_of::<T>() {
            1 => 4,
            2 => 2,
            _ => 1,
        };
        // Distance between two passes of the same thread.
        let step = kernel.threads() * stride;

        let mut base = kernel.thread_id() * stride;
        if base >= step {
            return;
        }
        while base < len {
            // SAFETY: `base < len`, and the caller guarantees `len` elements of storage.
            unsafe {
                *buffer.index_unchecked_mut(base) = T::default();
            }
            if stride >= 2 && base + 1 < len {
                // SAFETY: `base + 1 < len` was just checked.
                unsafe {
                    *buffer.index_unchecked_mut(base + 1) = T::default();
                }
            }
            if stride == 4 {
                if base + 2 < len {
                    // SAFETY: `base + 2 < len` was just checked.
                    unsafe {
                        *buffer.index_unchecked_mut(base + 2) = T::default();
                    }
                }
                if base + 3 < len {
                    // SAFETY: `base + 3 < len` was just checked.
                    unsafe {
                        *buffer.index_unchecked_mut(base + 3) = T::default();
                    }
                }
            }
            base += step;
        }
    }

    /// Raw `u32` values consumed by [`ItemKernelArgs::into_item_kernel`] to build an
    /// [`ItemKernel`].
    pub struct ItemKernelArgs {
        pub items: u32,
        pub item_id: u32,
    }

    // NOTE(review): nothing visible in this file is marked deprecated; the allow is kept as-is.
    #[allow(deprecated)]
    impl ItemKernelArgs {
        /// Converts the raw arguments into the [`ItemKernel`] handle, forwarding both fields
        /// unchanged.
        ///
        /// # Safety
        ///
        /// The body performs no unsafe operation itself.
        /// NOTE(review): presumably `unsafe` restricts construction to generated code --
        /// confirm.
        #[inline]
        pub unsafe fn into_item_kernel(self) -> ItemKernel {
            ItemKernel {
                items: self.items,
                item_id: self.item_id,
            }
        }
    }
}
133
/// Thread, subgroup and group counts and ids made available to a kernel.
///
/// Values are stored as `u32` and handed out as `usize` by the accessors.
pub struct Kernel {
    global_threads: u32,
    global_id: u32,
    groups: u32,
    group_id: u32,
    subgroups: u32,
    subgroup_id: u32,
    // Not stored yet; see the TODO in the impl below.
    //subgroup_threads: u32,
    subgroup_thread_id: u32,
    threads: u32,
    thread_id: u32,
}

impl Kernel {
    /// Widens a stored `u32` to the `usize` returned by every accessor.
    #[inline]
    fn widen(value: u32) -> usize {
        value as usize
    }

    /// Total thread count across all groups.
    ///
    /// `global_threads = groups * threads`
    #[inline]
    pub fn global_threads(&self) -> usize {
        Self::widen(self.global_threads)
    }

    /// Id of this thread among all global threads.
    ///
    /// `global_id = group_id * threads + thread_id`
    #[inline]
    pub fn global_id(&self) -> usize {
        Self::widen(self.global_id)
    }

    /// How many thread groups there are.
    #[inline]
    pub fn groups(&self) -> usize {
        Self::widen(self.groups)
    }

    /// Id of the group this thread belongs to.
    #[inline]
    pub fn group_id(&self) -> usize {
        Self::widen(self.group_id)
    }

    /// How many subgroups each group has.
    #[inline]
    pub fn subgroups(&self) -> usize {
        Self::widen(self.subgroups)
    }

    /// Id of the subgroup this thread belongs to.
    #[inline]
    pub fn subgroup_id(&self) -> usize {
        Self::widen(self.subgroup_id)
    }

    // TODO: Potentially implement via subgroup ballot / reduce operation
    /*
    /// How many threads each subgroup has.
    #[inline]
    pub fn subgroup_threads(&self) -> usize {
        Self::widen(self.subgroup_threads)
    }
    */

    /// Id of this thread within its subgroup.
    #[inline]
    pub fn subgroup_thread_id(&self) -> usize {
        Self::widen(self.subgroup_thread_id)
    }

    /// How many threads each group has.
    #[inline]
    pub fn threads(&self) -> usize {
        Self::widen(self.threads)
    }

    /// Id of this thread within its group.
    #[inline]
    pub fn thread_id(&self) -> usize {
        Self::widen(self.thread_id)
    }
}
206
/// Item count and item id made available to an item kernel.
///
/// Values are stored as `u32` and handed out as `usize` by the accessors.
pub struct ItemKernel {
    items: u32,
    item_id: u32,
}

impl ItemKernel {
    /// Widens a stored `u32` to the `usize` returned by every accessor.
    #[inline]
    fn widen(value: u32) -> usize {
        value as usize
    }

    /// How many items there are.
    ///
    /// This will be the minimum length of buffers with `#[item]`.
    #[inline]
    pub fn items(&self) -> usize {
        Self::widen(self.items)
    }

    /// Id of the current item.
    #[inline]
    pub fn item_id(&self) -> usize {
        Self::widen(self.item_id)
    }
}