use alloc::borrow::Cow;
use core::mem;
#[cfg_attr(not(any(feature = "spirv", doc)), expect(unused_imports))]
use crate::ShaderSource;
#[cfg(doc)]
use crate::Device;
/// The SPIR-V magic number: the first word of a SPIR-V module, written here as
/// its big-endian byte sequence `07 23 02 03`.
const SPIRV_MAGIC_NUMBER: u32 = u32::from_be_bytes([0x07, 0x23, 0x02, 0x03]);
/// Wraps a binary SPIR-V module in a [`ShaderSource::SpirV`].
///
/// The bytes are converted to native-endian 32-bit words by [`make_spirv_raw`].
/// When no conversion is needed, the words borrow from `data`.
///
/// # Panics
///
/// Panics under the same conditions as [`make_spirv_raw`]:
/// - `data` is empty;
/// - `data` does not start with the SPIR-V magic number in either byte order;
/// - the length of `data` is not a multiple of 4.
#[cfg(feature = "spirv")]
pub fn make_spirv(data: &[u8]) -> ShaderSource<'_> {
    let words = make_spirv_raw(data);
    ShaderSource::SpirV(words)
}
/// Validates that `bytes` looks like a SPIR-V binary and reports its byte order.
///
/// Returns `false` when the first word, read in the host's native byte order, is
/// the SPIR-V magic number. Returns `true` when that word is the byte-swapped
/// magic number, meaning every word must be byte-swapped to be read natively.
///
/// # Panics
///
/// Panics, reporting the caller's location, in these cases:
/// - `bytes` is empty;
/// - `bytes` does not begin with the magic number in either byte order. This
///   includes inputs shorter than 4 bytes;
/// - the length of `bytes` is not a multiple of 4.
///
/// The magic-number check runs before the length check.
#[track_caller]
const fn assert_has_spirv_magic_number_and_length(bytes: &[u8]) -> bool {
    if bytes.is_empty() {
        panic!("byte slice is empty, not SPIR-V");
    }

    // First word in native byte order, or `None` when fewer than 4 bytes exist.
    let leading_word: Option<u32> = match *bytes {
        [a, b, c, d, ..] => Some(u32::from_ne_bytes([a, b, c, d])),
        _ => None,
    };

    let needs_byte_swap = match leading_word {
        Some(word) if word == SPIRV_MAGIC_NUMBER => false,
        Some(word) if word == SPIRV_MAGIC_NUMBER.swap_bytes() => true,
        _ => panic!(
            "byte slice does not start with SPIR-V magic number. \
             Make sure you are using a binary SPIR-V file."
        ),
    };

    assert!(
        bytes.len().is_multiple_of(mem::size_of::<u32>()),
        "SPIR-V data must be a multiple of 4 bytes long"
    );
    needs_byte_swap
}
/// Converts a binary SPIR-V module into 32-bit words in the host's native byte order.
///
/// The byte order of `bytes` is detected from the SPIR-V magic number at the start.
/// If it differs from the host's byte order, every word is byte-swapped.
///
/// The result borrows `bytes` only when it is already aligned for `u32` and needs
/// no byte swap. Otherwise the words are copied into a new `Vec<u32>`.
/// [`make_spirv`] wraps this result in a [`ShaderSource`].
///
/// # Panics
///
/// Panics in any of these cases:
/// - `bytes` is empty;
/// - `bytes` does not start with the SPIR-V magic number in either byte order;
/// - the length of `bytes` is not a multiple of 4.
#[cfg_attr(not(feature = "spirv"), expect(rustdoc::broken_intra_doc_links))]
pub fn make_spirv_raw(bytes: &[u8]) -> Cow<'_, [u32]> {
    let swap_needed = assert_has_spirv_magic_number_and_length(bytes);

    // Reinterpret in place when the slice is suitably aligned for `u32`; otherwise
    // copy the bytes into an owned (and therefore aligned) `Vec<u32>`.
    let mut words: Cow<'_, [u32]> = if let Ok(aligned) = bytemuck::try_cast_slice(bytes) {
        Cow::Borrowed(aligned)
    } else {
        Cow::Owned(bytemuck::pod_collect_to_vec(bytes))
    };

    if swap_needed {
        // `to_mut` clones a borrowed slice into an owned `Vec` before mutating it.
        words
            .to_mut()
            .iter_mut()
            .for_each(|word| *word = word.swap_bytes());
    }

    assert!(
        words[0] == SPIRV_MAGIC_NUMBER,
        "can't happen: wrong magic number after swap_bytes"
    );
    words
}
/// Converts a binary SPIR-V module, given as a fixed-size byte array, into an
/// array of 32-bit words in the host's native byte order.
///
/// Being a `const fn`, this can run at compile time when used in a const context.
/// The byte order of `bytes` is detected from the SPIR-V magic number in the
/// first word. If it is the opposite of the host's byte order, every word is
/// byte-swapped. `OUT` must equal `IN / 4`.
///
/// # Panics
///
/// Panics, or fails const evaluation, in any of these cases:
/// - `bytes` is empty;
/// - `bytes` does not start with the SPIR-V magic number in either byte order;
/// - `IN` is not a multiple of 4;
/// - `OUT * 4 != IN`.
#[doc(hidden)]
pub const fn make_spirv_const<const IN: usize, const OUT: usize>(bytes: [u8; IN]) -> [u32; OUT] {
    let needs_byte_swap = assert_has_spirv_magic_number_and_length(&bytes);
    assert!(
        mem::size_of_val(&bytes) == mem::size_of::<[u32; OUT]>(),
        "output word count `OUT` must be exactly the input byte count `IN` divided by 4"
    );
    // SAFETY: `transmute_copy` is sound here for three reasons:
    // - the assertion above guarantees `[u8; IN]` and `[u32; OUT]` have the same size;
    // - every bit pattern is a valid `u32`;
    // - `transmute_copy` reads correctly even when the destination type has a
    //   stricter alignment than the source.
    // A single bulk copy (rather than a per-word loop) also keeps const evaluation
    // cheap for large modules when no byte swap is needed.
    let mut words: [u32; OUT] = unsafe { mem::transmute_copy(&bytes) };
    if needs_byte_swap {
        // `for` loops are not allowed in `const fn`, hence the manual index loop.
        let mut idx = 0;
        while idx < words.len() {
            words[idx] = words[idx].swap_bytes();
            idx += 1;
        }
    }
    assert!(
        words[0] == SPIRV_MAGIC_NUMBER,
        "can't happen: wrong magic number after swap_bytes"
    );
    words
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Copies `input` to each of the four possible alignments modulo 4.
    /// At each alignment, it checks that both `make_spirv_raw` and
    /// `make_spirv_const` produce `expected`.
    fn test_success_with_misalignments<const IN: usize, const OUT: usize>(
        input: &[u8; IN],
        expected: [u32; OUT],
    ) {
        // Four spare bytes so that every offset in 0..4 still fits the whole input.
        let mut buffer = vec![0; input.len() + 4];
        for offset in 0..4 {
            let misaligned_slice: &mut [u8; IN] =
                (&mut buffer[offset..][..input.len()]).try_into().unwrap();
            misaligned_slice.copy_from_slice(input);
            assert_eq!(*make_spirv_raw(misaligned_slice), expected);
            assert_eq!(make_spirv_const(*misaligned_slice), expected);
        }
    }

    #[test]
    fn success_be() {
        let input = b"\x07\x23\x02\x03\xF1\xF2\xF3\xF4";
        let expected: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0xF1F2F3F4];
        test_success_with_misalignments(input, expected);
    }

    #[test]
    fn success_le() {
        let input = b"\x03\x02\x23\x07\xF1\xF2\xF3\xF4";
        let expected: [u32; 2] = [SPIRV_MAGIC_NUMBER, 0xF4F3F2F1];
        test_success_with_misalignments(input, expected);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn nonconst_le_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn nonconst_be_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn const_le_fail() {
        let _: [u32; 1] = make_spirv_const([0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
    }

    #[should_panic = "multiple of 4"]
    #[test]
    fn const_be_fail() {
        let _: [u32; 1] = make_spirv_const([0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
    }

    #[should_panic = "byte slice is empty, not SPIR-V"]
    #[test]
    fn make_spirv_empty() {
        let _: [u32; 0] = make_spirv_const([]);
    }

    #[should_panic = "byte slice is empty, not SPIR-V"]
    #[test]
    fn nonconst_empty_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[]);
    }

    // Fewer than 4 bytes cannot hold the magic number at all.
    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn nonconst_too_short_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0x03, 0x02, 0x23]);
    }

    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn const_too_short_fail() {
        let _: [u32; 0] = make_spirv_const([0x03, 0x02, 0x23]);
    }

    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn nonconst_bad_magic_fail() {
        let _: Cow<'_, [u32]> = make_spirv_raw(&[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    // The magic-number check runs before the length check, so a bad magic number
    // is reported even when the length is also invalid.
    #[should_panic = "does not start with SPIR-V magic number"]
    #[test]
    fn const_bad_magic_takes_precedence_over_length() {
        let _: [u32; 1] = make_spirv_const([0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00]);
    }
}