1use crate::sealed::Sealed;
2use spirv_cross_sys::{spvc_constant, spvc_specialization_constant, TypeId};
3use std::mem::MaybeUninit;
4use std::ops::{Index, IndexMut};
5use std::slice;
6
7use crate::error::{SpirvCrossError, ToContextError};
8use crate::handle::{ConstantId, Handle};
9use crate::iter::impl_iterator;
10use crate::{error, Compiler, PhantomCompiler};
11use spirv_cross_sys as sys;
12
13mod gfx_maths;
14mod glam;
15mod half;
16
/// A scalar type that can be read from or written to a single cell of a SPIR-V
/// constant through the `spirv_cross_sys` constant accessors.
///
/// This trait is sealed. Within this crate it is implemented for the primitive
/// integer and floating-point types and for `bool`, among others.
pub trait ConstantScalar: Default + Sealed + Copy {
    /// Read the scalar stored in cell `(column, row)` of `constant`.
    ///
    /// # Safety
    /// `constant` must be a valid constant handle, and `(column, row)` must be in
    /// bounds for it. The safe callers in this module check this with
    /// `bounds_check_constant` first.
    #[doc(hidden)]
    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self;

    /// Write `value` into cell `(column, row)` of `constant`.
    ///
    /// # Safety
    /// `constant` must be a valid constant handle, and `(column, row)` must be in
    /// bounds for it.
    #[doc(hidden)]
    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self);
}
25
/// Implements `Sealed` and `ConstantScalar` for a primitive type. It forwards `get`
/// and `set` to the named pair of `spirv_cross_sys` accessor functions.
///
/// Invocation shape: `impl_spvc_constant!(getter_fn setter_fn primitive_type);`
macro_rules! impl_spvc_constant {
    ($getter:ident $setter:ident $scalar:ty) => {
        impl Sealed for $scalar {}

        impl ConstantScalar for $scalar {
            unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
                // Cast the accessor's return value to the implementing primitive.
                unsafe { ::spirv_cross_sys::$getter(constant, column, row) as Self }
            }

            unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
                unsafe { ::spirv_cross_sys::$setter(constant, column, row, value) }
            }
        }
    };
}
40
/// Implements `ConstantValue` for a single-column vector type whose lanes are the
/// named fields.
///
/// Invocation shape: `impl_vec_constant!(VecType [scalar; N] for [field_a, field_b, ...]);`
///
/// The vector type must implement `From<[scalar; N]>`, because `from_array` builds
/// the value from its single column.
// NOTE(review): this is presumably invoked only from the optional math-library
// integration modules, so it can be unused in some feature configurations. That is
// why the lint is silenced. Confirm.
#[allow(unused_macros)]
macro_rules! impl_vec_constant {
    ($vector:ty [$scalar:ty; $lanes:literal] for [$($field:ident),*]) => {
        impl $crate::sealed::Sealed for $vector {}
        impl $crate::reflect::constants::ConstantValue for $vector {
            // A vector is stored as exactly one column containing `$lanes` rows.
            const COLUMNS: usize = 1;
            const VECSIZE: usize = $lanes;
            type BaseArrayType = [$scalar; $lanes];
            type ArrayType = [[$scalar; $lanes]; 1];
            type BaseType = $scalar;

            fn from_array(value: Self::ArrayType) -> Self {
                let [column] = value;
                column.into()
            }

            fn to_array(value: Self) -> Self::ArrayType {
                [[$(value.$field),*]]
            }
        }
    };
}
62
63impl_spvc_constant!(spvc_constant_get_scalar_i8 spvc_constant_set_scalar_i8 i8);
64impl_spvc_constant!(spvc_constant_get_scalar_i16 spvc_constant_set_scalar_i16 i16);
65impl_spvc_constant!(spvc_constant_get_scalar_i32 spvc_constant_set_scalar_i32 i32);
66impl_spvc_constant!(spvc_constant_get_scalar_i64 spvc_constant_set_scalar_i64 i64);
67
68impl_spvc_constant!(spvc_constant_get_scalar_u8 spvc_constant_set_scalar_u8 u8);
69impl_spvc_constant!(spvc_constant_get_scalar_u16 spvc_constant_set_scalar_u16 u16);
70impl_spvc_constant!(spvc_constant_get_scalar_u32 spvc_constant_set_scalar_u32 u32);
71impl_spvc_constant!(spvc_constant_get_scalar_u64 spvc_constant_set_scalar_u64 u64);
72
73impl_spvc_constant!(spvc_constant_get_scalar_fp32 spvc_constant_set_scalar_fp32 f32);
74impl_spvc_constant!(spvc_constant_get_scalar_fp64 spvc_constant_set_scalar_fp64 f64);
75
76impl Sealed for bool {}
78impl ConstantScalar for bool {
79 unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
80 unsafe { sys::spvc_constant_get_scalar_u8(constant, column, row) != 0 }
81 }
82
83 unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
84 sys::spvc_constant_set_scalar_u8(constant, column, row, if value { 1 } else { 0 });
85 }
86}
87
/// A specialization constant reported by the compiler.
#[derive(Debug, Clone)]
pub struct SpecializationConstant {
    /// Handle to the constant's id, created by the owning compiler.
    pub id: Handle<ConstantId>,
    /// The `constant_id` field of the raw `spvc_specialization_constant`.
    /// NOTE(review): presumably the SPIR-V `SpecId` decoration value. Confirm.
    pub constant_id: u32,
}
96
/// The specialization constants backing the workgroup size, as returned by
/// [`Compiler::work_group_size_specialization_constants`].
#[derive(Debug, Clone)]
pub struct WorkgroupSizeSpecializationConstants {
    /// Constant for the X dimension. `None` when the reported id is zero.
    pub x: Option<SpecializationConstant>,
    /// Constant for the Y dimension. `None` when the reported id is zero.
    pub y: Option<SpecializationConstant>,
    /// Constant for the Z dimension. `None` when the reported id is zero.
    pub z: Option<SpecializationConstant>,
    /// Handle to the id returned directly by
    /// `spvc_compiler_get_work_group_size_specialization_constants`. `None` when that
    /// id is zero.
    /// NOTE(review): presumably the `WorkgroupSize` builtin constant itself. Confirm.
    pub builtin_workgroup_size_handle: Option<Handle<ConstantId>>,
}
109
/// Iterator over the specialization constants returned by
/// [`Compiler::specialization_constants`].
///
/// Each raw `spvc_specialization_constant` is turned into a
/// [`SpecializationConstant`]. Its id is wrapped in a handle created through the
/// stored `PhantomCompiler`.
pub struct SpecializationConstantIter<'a>(
    PhantomCompiler,
    slice::Iter<'a, spvc_specialization_constant>,
);

// Map each raw entry to a `SpecializationConstant`, keeping `constant_id` as-is.
impl_iterator!(SpecializationConstantIter<'_>: SpecializationConstant as map |s, o: &spvc_specialization_constant| {
    SpecializationConstant {
        id: s.0.create_handle(o.id),
        constant_id: o.constant_id,
    }
} for [1]);
122
/// Iterator over the sub-constants of a constant, as returned by
/// [`Compiler::specialization_sub_constants`].
///
/// Each raw `ConstantId` is wrapped in a handle created through the stored
/// `PhantomCompiler`.
pub struct SpecializationSubConstantIter<'a>(PhantomCompiler, slice::Iter<'a, ConstantId>);

// Map each raw id to a compiler-tagged `Handle<ConstantId>`.
impl_iterator!(SpecializationSubConstantIter<'_>: Handle<ConstantId> as map |s, o: &ConstantId| {
    s.0.create_handle(*o)
} for [1]);
130
131impl<T> Compiler<T> {
133 unsafe fn bounds_check_constant(
135 handle: spvc_constant,
136 column: u32,
137 row: u32,
138 ) -> error::Result<()> {
139 if column >= 4 || row >= 4 {
141 return Err(SpirvCrossError::IndexOutOfBounds { row, column });
142 }
143
144 let vecsize = sys::spvc_rs_constant_get_vecsize(handle);
145 let colsize = sys::spvc_rs_constant_get_matrix_colsize(handle);
146
147 if column >= colsize || row >= vecsize {
148 return Err(SpirvCrossError::IndexOutOfBounds { row, column });
149 }
150
151 Ok(())
152 }
153
154 pub fn set_specialization_constant_scalar<S: ConstantScalar>(
166 &mut self,
167 handle: Handle<ConstantId>,
168 column: u32,
169 row: u32,
170 value: S,
171 ) -> error::Result<()> {
172 let constant = self.yield_id(handle)?;
173 unsafe {
174 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
176 Self::bounds_check_constant(handle, column, row)?;
177 S::set(handle, column, row, value)
178 }
179 Ok(())
180 }
181
182 pub fn specialization_constant_scalar<S: ConstantScalar>(
195 &self,
196 handle: Handle<ConstantId>,
197 column: u32,
198 row: u32,
199 ) -> error::Result<S> {
200 let constant = self.yield_id(handle)?;
201 unsafe {
202 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
204 Self::bounds_check_constant(handle, column, row)?;
205
206 Ok(S::get(handle, column, row))
207 }
208 }
209
210 pub fn specialization_constants(&self) -> error::Result<SpecializationConstantIter<'static>> {
212 unsafe {
213 let mut constants = std::ptr::null();
214 let mut size = 0;
215 sys::spvc_compiler_get_specialization_constants(
216 self.ptr.as_ptr(),
217 &mut constants,
218 &mut size,
219 )
220 .ok(self)?;
221
222 let slice = slice::from_raw_parts(constants, size);
225 Ok(SpecializationConstantIter(self.phantom(), slice.iter()))
226 }
227 }
228
229 pub fn specialization_sub_constants(
231 &self,
232 constant: Handle<ConstantId>,
233 ) -> error::Result<SpecializationSubConstantIter<'_>> {
234 let id = self.yield_id(constant)?;
235 unsafe {
236 let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
237 let mut constants = std::ptr::null();
238 let mut size = 0;
239 sys::spvc_constant_get_subconstants(constant, &mut constants, &mut size);
240
241 Ok(SpecializationSubConstantIter(
242 self.phantom(),
243 slice::from_raw_parts(constants, size).iter(),
244 ))
245 }
246 }
247
248 pub fn work_group_size_specialization_constants(&self) -> WorkgroupSizeSpecializationConstants {
269 unsafe {
270 let mut x = MaybeUninit::zeroed();
271 let mut y = MaybeUninit::zeroed();
272 let mut z = MaybeUninit::zeroed();
273
274 let constant = sys::spvc_compiler_get_work_group_size_specialization_constants(
275 self.ptr.as_ptr(),
276 x.as_mut_ptr(),
277 y.as_mut_ptr(),
278 z.as_mut_ptr(),
279 );
280
281 let constant = self.create_handle_if_not_zero(constant);
282
283 let x = x.assume_init();
284 let y = y.assume_init();
285 let z = z.assume_init();
286
287 let x = self
288 .create_handle_if_not_zero(x.id)
289 .map(|id| SpecializationConstant {
290 id,
291 constant_id: x.constant_id,
292 });
293
294 let y = self
295 .create_handle_if_not_zero(y.id)
296 .map(|id| SpecializationConstant {
297 id,
298 constant_id: y.constant_id,
299 });
300
301 let z = self
302 .create_handle_if_not_zero(z.id)
303 .map(|id| SpecializationConstant {
304 id,
305 constant_id: z.constant_id,
306 });
307
308 WorkgroupSizeSpecializationConstants {
309 x,
310 y,
311 z,
312 builtin_workgroup_size_handle: constant,
313 }
314 }
315 }
316
317 pub fn specialization_constant_type(
319 &self,
320 constant: Handle<ConstantId>,
321 ) -> error::Result<Handle<TypeId>> {
322 let constant = self.yield_id(constant)?;
323 let type_id = unsafe {
324 let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
326 self.create_handle(sys::spvc_constant_get_type(constant))
327 };
328
329 Ok(type_id)
330 }
331}
332
/// A value that can be read from or written to a specialization constant as a whole.
///
/// Every scalar is a `ConstantValue`, and vector types can be added with
/// `impl_vec_constant!`. Values move in and out through `ArrayType`, indexed as
/// `array[column][row]`, which covers `COLUMNS` columns of `VECSIZE` rows each.
pub trait ConstantValue: Sealed + Sized {
    /// Number of columns the value occupies (1 for scalars and vectors).
    #[doc(hidden)]
    const COLUMNS: usize;
    /// Number of rows (vector components) in each column.
    #[doc(hidden)]
    const VECSIZE: usize;
    /// One column of scalars, indexed by row.
    #[doc(hidden)]
    type BaseArrayType: Default + Index<usize, Output = Self::BaseType> + IndexMut<usize>;
    /// All columns of the value, indexed by column.
    #[doc(hidden)]
    type ArrayType: Default + Index<usize, Output = Self::BaseArrayType> + IndexMut<usize>;
    /// The scalar type stored in each cell.
    #[doc(hidden)]
    type BaseType: ConstantScalar;

    /// Build the value from its cells, indexed as `value[column][row]`.
    #[doc(hidden)]
    fn from_array(value: Self::ArrayType) -> Self;

    /// Break the value into its cells, indexed as `result[column][row]`.
    #[doc(hidden)]
    fn to_array(value: Self) -> Self::ArrayType;
}
355
356impl<T: ConstantScalar> ConstantValue for T {
357 const COLUMNS: usize = 1;
358 const VECSIZE: usize = 1;
359 type BaseArrayType = [T; 1];
360 type ArrayType = [[T; 1]; 1];
361 type BaseType = T;
362
363 fn from_array(value: Self::ArrayType) -> Self {
364 value[0][0]
365 }
366
367 fn to_array(value: Self) -> Self::ArrayType {
368 [[value]]
369 }
370}
371
372impl<T> Compiler<T> {
373 pub fn specialization_constant_value<S: ConstantValue>(
384 &self,
385 handle: Handle<ConstantId>,
386 ) -> error::Result<S> {
387 let constant = self.yield_id(handle)?;
388 unsafe {
389 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
391 let mut output = S::ArrayType::default();
393
394 Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
396
397 for column in 0..S::COLUMNS {
398 for row in 0..S::VECSIZE {
399 let value = S::BaseType::get(handle, column as u32, row as u32);
400 output[column][row] = value;
401 }
402 }
403 Ok(S::from_array(output))
404 }
405 }
406
407 pub fn set_specialization_constant_value<S: ConstantValue>(
417 &mut self,
418 handle: Handle<ConstantId>,
419 value: S,
420 ) -> error::Result<()> {
421 let constant = self.yield_id(handle)?;
422 unsafe {
423 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
425
426 Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
428
429 let value = S::to_array(value);
430 for column in 0..S::COLUMNS {
431 for row in 0..S::VECSIZE {
432 S::BaseType::set(handle, column as u32, row as u32, value[column][row]);
433 }
434 }
435 }
436 Ok(())
437 }
438}
439
440#[allow(unused_imports)]
441#[allow(clippy::needless_pub_self)]
442pub(self) use impl_vec_constant;