1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
#[doc(hidden)]
#[cfg(target_arch = "spirv")]
pub mod __private {
    //! Glue used by generated kernel entry points on SPIR-V targets. Not public API.

    use super::{ItemKernel, Kernel};

    /// Raw dispatch values gathered by a kernel entry point, turned into a [`Kernel`]
    /// with [`KernelArgs::into_kernel`].
    pub struct KernelArgs {
        pub global_threads: u32,
        pub global_id: u32,
        pub groups: u32,
        pub group_id: u32,
        pub subgroups: u32,
        pub subgroup_id: u32,
        pub subgroup_threads: u32,
        pub subgroup_thread_id: u32,
        pub threads: u32,
        pub thread_id: u32,
    }

    // `Kernel`'s public fields carry `#[deprecated]`; building it here still has to name them.
    #[allow(deprecated)]
    impl KernelArgs {
        /// Copies every dispatch value, one to one, into a [`Kernel`].
        ///
        /// # Safety
        ///
        /// No value is validated here. NOTE(review): presumably the caller must pass the
        /// values of the actual dispatch (e.g. `thread_id < threads`); confirm against the
        /// generated entry points.
        #[inline]
        pub unsafe fn into_kernel(self) -> Kernel {
            Kernel {
                global_threads: self.global_threads,
                global_id: self.global_id,
                groups: self.groups,
                group_id: self.group_id,
                subgroups: self.subgroups,
                subgroup_id: self.subgroup_id,
                subgroup_threads: self.subgroup_threads,
                subgroup_thread_id: self.subgroup_thread_id,
                threads: self.threads,
                thread_id: self.thread_id,
            }
        }
    }

    /// Sets elements `0..len` of `buffer` to `T::default()`, splitting the work across the
    /// threads of the group.
    ///
    /// The calling thread writes indices `thread_id, thread_id + threads, ...` below `len`,
    /// so if every thread of the group calls this, all of `0..len` is covered.
    ///
    /// # Safety
    ///
    /// Writes go through `index_unchecked_mut`, so nothing is bounds-checked: the caller must
    /// guarantee that `buffer` really holds at least `len` elements. NOTE(review): the
    /// `[T; 1]` type looks like a placeholder for the real buffer length; confirm with the
    /// code generator. No synchronization is performed here.
    #[inline]
    pub unsafe fn zero_group_buffer<T: Default + Copy>(
        kernel: &Kernel,
        buffer: &mut [T; 1],
        len: usize,
    ) {
        use spirv_std::arch::IndexUnchecked;

        let stride = kernel.threads();
        let mut slot = kernel.thread_id();
        // Defensive: a thread id outside the group writes nothing.
        if slot >= stride {
            return;
        }
        while slot < len {
            // SAFETY: `slot < len`, and the caller guarantees `buffer` has at least `len` elements.
            unsafe {
                *buffer.index_unchecked_mut(slot) = T::default();
            }
            slot += stride;
        }
    }

    /// Raw item-dispatch values, turned into an [`ItemKernel`] with
    /// [`ItemKernelArgs::into_item_kernel`].
    pub struct ItemKernelArgs {
        pub items: u32,
        pub item_id: u32,
    }

    // `ItemKernel`'s public fields carry `#[deprecated]`; building it here still has to name them.
    #[allow(deprecated)]
    impl ItemKernelArgs {
        /// Copies the item count and item id into an [`ItemKernel`].
        ///
        /// # Safety
        ///
        /// No value is validated here. NOTE(review): presumably the caller must pass the
        /// values of the actual dispatch; confirm against the generated entry points.
        #[inline]
        pub unsafe fn into_item_kernel(self) -> ItemKernel {
            ItemKernel {
                items: self.items,
                item_id: self.item_id,
            }
        }
    }
}

/// Group, subgroup and thread indices of the current kernel invocation.
///
/// Read the values through the accessor methods; the public fields are deprecated.
#[non_exhaustive]
pub struct Kernel {
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with global_threads()")]
    pub global_threads: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with global_id()")]
    pub global_id: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with groups()")]
    pub groups: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with group_id()")]
    pub group_id: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with subgroups()")]
    pub subgroups: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with subgroup_id()")]
    pub subgroup_id: u32,
    // Stored but never read: the `subgroup_threads()` accessor is disabled (see the TODO
    // in `impl Kernel` below), so the field would otherwise trigger an unused warning.
    #[allow(unused)]
    subgroup_threads: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with subgroup_thread_id()")]
    pub subgroup_thread_id: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with threads()")]
    pub threads: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with thread_id()")]
    pub thread_id: u32,
}

// The accessors read the deprecated fields they replace.
#[allow(deprecated)]
impl Kernel {
    /// Widens a stored `u32` dispatch value to `usize` so it can be used for indexing.
    #[inline]
    fn widen(value: u32) -> usize {
        value as usize
    }

    /// The number of global threads.
    ///
    /// `global_threads = groups * threads`
    #[inline]
    pub fn global_threads(&self) -> usize {
        Self::widen(self.global_threads)
    }

    /// The global thread id.
    ///
    /// `global_id = group_id * threads + thread_id`
    #[inline]
    pub fn global_id(&self) -> usize {
        Self::widen(self.global_id)
    }

    /// The number of thread groups.
    #[inline]
    pub fn groups(&self) -> usize {
        Self::widen(self.groups)
    }

    /// The id of this thread's group.
    #[inline]
    pub fn group_id(&self) -> usize {
        Self::widen(self.group_id)
    }

    /// The number of subgroups per group.
    #[inline]
    pub fn subgroups(&self) -> usize {
        Self::widen(self.subgroups)
    }

    /// The id of this thread's subgroup.
    #[inline]
    pub fn subgroup_id(&self) -> usize {
        Self::widen(self.subgroup_id)
    }

    // TODO: expose `subgroup_threads()` (the number of threads per subgroup, i.e.
    // `Self::widen(self.subgroup_threads)`) once variable subgroup sizes are handled;
    // the Intel Mesa driver uses a variable subgroup size.
    // Fixed in https://github.com/charles-r-earp/krnl/tree/update-vulkano

    /// The id of this thread within its subgroup.
    #[inline]
    pub fn subgroup_thread_id(&self) -> usize {
        Self::widen(self.subgroup_thread_id)
    }

    /// The number of threads per group.
    #[inline]
    pub fn threads(&self) -> usize {
        Self::widen(self.threads)
    }

    /// The id of this thread within its group.
    #[inline]
    pub fn thread_id(&self) -> usize {
        Self::widen(self.thread_id)
    }
}

/// Item-level view of a kernel invocation: the total item count and this invocation's item.
///
/// Read the values through the accessor methods; the public fields are deprecated.
#[non_exhaustive]
pub struct ItemKernel {
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with items()")]
    pub items: u32,
    #[doc(hidden)]
    #[deprecated(since = "0.0.4", note = "replaced with item_id()")]
    pub item_id: u32,
}

// The accessors read the deprecated fields they replace.
#[allow(deprecated)]
impl ItemKernel {
    /// Widens a stored `u32` dispatch value to `usize` so it can be used for indexing.
    #[inline]
    fn widen(value: u32) -> usize {
        value as usize
    }

    /// The number of items.
    ///
    /// This will be the minimum length of buffers with `#[item]`.
    #[inline]
    pub fn items(&self) -> usize {
        Self::widen(self.items)
    }

    /// The id of the item handled by this invocation.
    #[inline]
    pub fn item_id(&self) -> usize {
        Self::widen(self.item_id)
    }
}