1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
use {
hal::{device::ShaderError, pso},
spirv_cross::spirv,
};
/// A `HashMap` that uses the Fx hasher instead of the default SipHash.
///
/// Fx is fast but not DoS-resistant, so this alias is presumably meant for
/// keys that are not attacker-controlled — confirm at call sites.
pub type FastHashMap<K, V> =
    std::collections::HashMap<K, V, std::hash::BuildHasherDefault<fxhash::FxHasher>>;
/// Applies the pipeline specialization constants in `specialization` to a
/// SPIRV-Cross AST.
///
/// For each specialization constant reported by `ast`, the entry of
/// `specialization.constants` whose `id` matches the constant's `constant_id`
/// is looked up. Its bytes `specialization.data[range]` are decoded as a
/// little-endian integer and passed to `set_scalar_constant`. Constants with no
/// matching entry are left untouched.
///
/// # Errors
///
/// Returns `ShaderError::CompilationFailed` if SPIRV-Cross fails to list or set
/// the specialization constants, or if a constant's byte range does not lie
/// within `specialization.data`.
pub fn spirv_cross_specialize_ast<T>(
    ast: &mut spirv::Ast<T>,
    specialization: &pso::Specialization,
) -> Result<(), ShaderError>
where
    T: spirv::Target,
    spirv::Ast<T>: spirv::Compile<T> + spirv::Parse<T>,
{
    // Single conversion from SPIRV-Cross errors to the HAL shader error type,
    // shared by both fallible SPIRV-Cross calls below.
    fn to_shader_error(err: spirv_cross::ErrorCode) -> ShaderError {
        ShaderError::CompilationFailed(match err {
            spirv_cross::ErrorCode::CompilationError(msg) => msg,
            spirv_cross::ErrorCode::Unhandled => {
                "Unexpected specialization constant error".into()
            }
        })
    }

    let spec_constants = ast
        .get_specialization_constants()
        .map_err(to_shader_error)?;

    for spec_constant in spec_constants {
        let constant = match specialization
            .constants
            .iter()
            .find(|c| c.id == spec_constant.constant_id)
        {
            Some(constant) => constant,
            // No override supplied for this constant: leave it as-is.
            None => continue,
        };

        let start = constant.range.start as usize;
        let end = constant.range.end as usize;
        // Checked slicing: a malformed range (out of bounds or start > end)
        // becomes an error instead of an index panic.
        let bytes = specialization.data.get(start..end).ok_or_else(|| {
            ShaderError::CompilationFailed(format!(
                "Specialization constant {} has data range {}..{} outside of the {} provided data bytes",
                constant.id,
                start,
                end,
                specialization.data.len()
            ))
        })?;

        // Little-endian decode: bytes[0] lands in the least-significant byte.
        // If the range is longer than 8 bytes, only the first 8 bytes survive
        // the shifts.
        let value = bytes
            .iter()
            .rev()
            .fold(0u64, |acc, &b| (acc << 8) | u64::from(b));

        ast.set_scalar_constant(spec_constant.id, value)
            .map_err(to_shader_error)?;
    }
    Ok(())
}