use super::VertexInputRate;
use crate::{buffer::BufferContents, format::Format};
use std::collections::HashMap;
#[cfg(feature = "macros")]
pub use vulkano_macros::Vertex;
/// Describes how the fields of a vertex type map onto named vertex members.
///
/// Each method returns a [`VertexBufferDescription`]. That value holds the type's members
/// (keyed by attribute name) together with a stride and a [`VertexInputRate`]. Implementations
/// can be generated with the `Vertex` derive macro, which is re-exported from `vulkano_macros`
/// when the `macros` feature is enabled.
///
/// # Safety
///
/// NOTE(review): the exact contract is not spelled out in this file. Presumably implementors
/// must guarantee that the returned descriptions accurately describe the in-memory layout of
/// `Self`: every member's `offset`, `format` and `num_elements`, plus the `stride`. Code that
/// reads vertex data based on these descriptions would otherwise misinterpret memory.
/// TODO confirm against the consumers of `VertexBufferDescription`.
pub unsafe trait Vertex: BufferContents + Sized {
/// Returns the description of `Self` for per-vertex input (presumably with `input_rate`
/// set to `VertexInputRate::Vertex`).
fn per_vertex() -> VertexBufferDescription;
/// Returns the description of `Self` for per-instance input (presumably equivalent to
/// `per_instance_with_divisor(1)` — TODO confirm against the derive macro).
fn per_instance() -> VertexBufferDescription;
/// Returns the description of `Self` for per-instance input using the given `divisor`.
fn per_instance_with_divisor(divisor: u32) -> VertexBufferDescription;
}
/// The layout of one vertex buffer: its named members, its stride and its input rate.
#[derive(Clone, Debug)]
pub struct VertexBufferDescription {
/// Member info keyed by attribute name. A field exposed under several `#[name(...)]`
/// aliases appears once under each alias.
pub members: HashMap<String, VertexMemberInfo>,
/// Distance between consecutive elements of the buffer (presumably in bytes — TODO confirm).
pub stride: u32,
/// Rate at which the buffer is stepped through: `Vertex`, or `Instance { divisor }`.
pub input_rate: VertexInputRate,
}
impl VertexBufferDescription {
    /// Consumes the description and returns it with `input_rate` set to
    /// `VertexInputRate::Vertex`.
    ///
    /// `members` and `stride` are carried over unchanged.
    #[inline]
    pub fn per_vertex(self) -> VertexBufferDescription {
        Self {
            input_rate: VertexInputRate::Vertex,
            ..self
        }
    }

    /// Shorthand for [`per_instance_with_divisor`](Self::per_instance_with_divisor) with a
    /// divisor of `1`.
    #[inline]
    pub fn per_instance(self) -> VertexBufferDescription {
        Self::per_instance_with_divisor(self, 1)
    }

    /// Consumes the description and returns it with `input_rate` set to
    /// `VertexInputRate::Instance { divisor }`.
    ///
    /// `members` and `stride` are carried over unchanged.
    #[inline]
    pub fn per_instance_with_divisor(self, divisor: u32) -> VertexBufferDescription {
        Self {
            input_rate: VertexInputRate::Instance { divisor },
            ..self
        }
    }
}
/// Location and format of a single member within a vertex.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VertexMemberInfo {
/// Offset of the member from the start of the vertex (presumably in bytes — TODO confirm).
pub offset: usize,
/// Format of each element of the member.
pub format: Format,
/// Number of consecutive `format`-sized elements the member spans. For example, a
/// `[f32; 16]` field with format `R32G32B32A32_SFLOAT` has 4 elements, and a `u8` field
/// with format `R8_UNORM` has 1.
pub num_elements: u32,
}
impl VertexMemberInfo {
    /// Returns how many components the member's format actually has.
    ///
    /// Counts the entries of `self.format.components()` that are greater than zero. Components
    /// a format does not use are reported as `0` there and are therefore skipped.
    #[inline]
    pub fn num_components(&self) -> u32 {
        let widths = self.format.components();
        let present = widths.iter().filter(|width| **width > 0).count();
        present as u32
    }
}
#[cfg(test)]
mod tests {
    use crate::format::Format;
    use crate::pipeline::graphics::vertex_input::Vertex;
    use bytemuck::{Pod, Zeroable};

    /// A field given several `#[name(...)]` aliases must be described once under each alias.
    /// Every entry must carry the same format and element count.
    #[test]
    fn derive_vertex_multiple_names() {
        #[repr(C)]
        #[derive(Clone, Copy, Debug, Default, Zeroable, Pod, Vertex)]
        struct TestVertex {
            #[name("b", "c")]
            #[format(R32G32B32A32_SFLOAT)]
            a: [f32; 16],
        }

        let description = TestVertex::per_vertex();
        for alias in ["b", "c"] {
            let member = description
                .members
                .get(alias)
                .unwrap_or_else(|| panic!("no member described under alias `{}`", alias));
            assert_eq!(member.format, Format::R32G32B32A32_SFLOAT);
            // 16 floats split into 4-float `R32G32B32A32_SFLOAT` elements.
            assert_eq!(member.num_elements, 4);
        }
    }

    /// A single `u8` field with an explicit `R8_UNORM` format maps to exactly one element.
    #[test]
    fn derive_vertex_format() {
        #[repr(C)]
        #[derive(Clone, Copy, Debug, Default, Zeroable, Pod, Vertex)]
        struct TestVertex {
            #[format(R8_UNORM)]
            unorm: u8,
        }

        let description = TestVertex::per_instance();
        let member = description
            .members
            .get("unorm")
            .expect("`unorm` field should be described");
        assert_eq!(member.format, Format::R8_UNORM);
        assert_eq!(member.num_elements, 1);
    }
}