/// Memory-layout rule that drives the size, alignment and stride
/// computations in this module.
///
/// `Std140` applies an extra round-up to 16 bytes (the alignment of a
/// `vec4`) to some alignments and strides that `Std430` leaves unrounded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum LayoutRule {
    // NOTE(review): the lint is suppressed presumably because `Std430` is not
    // constructed anywhere yet — confirm before removing the attribute.
    #[allow(dead_code)]
    Std430,
    Std140,
}
/// A node in the type tree whose memory layout is computed by this module.
#[derive(Debug, Clone)]
pub(super) struct Type {
    /// Opaque identifier for this type.
    // NOTE(review): the size/alignment helpers only ever read `kind`;
    // presumably `id` links back to the originating module's type id — confirm.
    pub id: u32,
    /// Structural shape used for size and alignment computation.
    pub kind: TypeKind,
}
/// Shape of a [`Type`], as far as buffer layout is concerned.
#[derive(Debug, Clone)]
pub(super) enum TypeKind {
    /// A scalar occupying `width_bytes` bytes; its size and base alignment
    /// are both `width_bytes`.
    Scalar { width_bytes: u32 },
    /// A vector of `count` `component`s. The alignment code only accepts
    /// counts of 2, 3 and 4.
    Vector { component: Box<Type>, count: u32 },
    /// A matrix of `cols` columns, each of type `column` (which must be a
    /// scalar or a vector of scalars).
    Matrix { column: Box<Type>, cols: u32 },
    /// An array of `len` elements of type `element`.
    Array { element: Box<Type>, len: u32 },
    /// A struct whose members are laid out in declaration order.
    Struct { members: Vec<Type> },
}
/// Result of [`layout_struct`]: per-member byte offsets plus the padded
/// total size and the alignment used for that padding.
pub(super) struct StructLayout {
    /// Byte offset of each member, in the same order as the input members.
    pub member_offsets: Vec<u32>,
    /// Total size in bytes, rounded up to a multiple of `align`.
    pub size: u32,
    /// Alignment, in bytes, that `size` was padded to.
    // NOTE(review): the lint is suppressed presumably because no caller reads
    // this field yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub align: u32,
}
/// Rounds `x` up to the nearest multiple of `a`.
///
/// `a` must be a non-zero power of two. The result is produced by masking off
/// the low bits, so any other alignment yields a wrong answer, and `a == 0`
/// underflows.
const fn round_up(x: u32, a: u32) -> u32 {
    // Overshoot by `a - 1`, then clear the bits below the alignment.
    let overshoot = x + a - 1;
    let high_bits = !(a - 1);
    overshoot & high_bits
}
/// Returns the base alignment, in bytes, of `t` under `rule`.
///
/// Alignment by kind:
/// - A scalar aligns to its width.
/// - A 2-component vector aligns to twice its component's alignment.
/// - 3- and 4-component vectors align to four times their component's alignment.
/// - A matrix aligns like its column type.
/// - An array aligns like its element type.
/// - A struct aligns like its most-aligned member, or 4 if it has no members.
///
/// Under [`LayoutRule::Std140`], matrices, arrays and structs have their
/// alignment rounded up to 16 bytes (a `vec4`). std140 stores a column-major
/// matrix exactly like an array of its column vectors, so matrices get the
/// same round-up as arrays. This keeps the alignment consistent with the
/// 16-byte column stride produced by `matrix_stride`.
///
/// # Panics
///
/// Panics if any vector reached while walking `t` (including a matrix column)
/// has a component count other than 2, 3 or 4.
pub(super) fn base_align(t: &TypeKind, rule: LayoutRule) -> u32 {
    let inner = match t {
        TypeKind::Scalar { width_bytes } => *width_bytes,
        TypeKind::Vector { component, count } => match count {
            2 => 2 * base_align(&component.kind, rule),
            3 | 4 => 4 * base_align(&component.kind, rule),
            n => panic!("Unsupported vector component count: {}", n),
        },
        TypeKind::Matrix { column, .. } => base_align(&column.kind, rule),
        TypeKind::Array { element, .. } => base_align(&element.kind, rule),
        TypeKind::Struct { members } => members
            .iter()
            .map(|m| base_align(&m.kind, rule))
            .max()
            .unwrap_or(4),
    };
    match (rule, t) {
        // std140 rules 4, 5 and 9: arrays, column-major matrices (laid out as
        // arrays of columns) and structs round their alignment up to a vec4.
        (
            LayoutRule::Std140,
            TypeKind::Matrix { .. } | TypeKind::Array { .. } | TypeKind::Struct { .. },
        ) => inner.max(16),
        _ => inner,
    }
}
/// Returns the number of bytes `t` occupies under `rule`.
///
/// Size by kind:
/// - A struct uses the padded size computed by [`layout_struct`].
/// - An array takes one [`array_stride`] per element.
/// - A matrix takes one [`matrix_stride`] per column.
/// - A vector is `count` tightly packed components, so a 3-component vector
///   is not padded here.
/// - A scalar is its width.
pub(super) fn size_of(t: &TypeKind, rule: LayoutRule) -> u32 {
    match t {
        TypeKind::Struct { members } => {
            let layout = layout_struct(members, rule);
            layout.size
        }
        TypeKind::Array { element, len } => {
            let per_element = array_stride(&element.kind, rule);
            per_element * len
        }
        TypeKind::Matrix { column, cols } => {
            let per_column = matrix_stride(
                column_vec_count(column),
                column_scalar_width(column),
                rule,
            );
            cols * per_column
        }
        TypeKind::Vector { component, count } => {
            let component_bytes = size_of(&component.kind, rule);
            count * component_bytes
        }
        TypeKind::Scalar { width_bytes } => *width_bytes,
    }
}
/// Returns the byte distance between consecutive elements of an array of
/// `elem` under `rule`.
///
/// The element size is padded up to the element's base alignment. Under
/// [`LayoutRule::Std140`] the stride is additionally raised to at least
/// 16 bytes.
pub(super) fn array_stride(elem: &TypeKind, rule: LayoutRule) -> u32 {
    // Alignment is computed before size, as in the original, so the panic
    // raised on a malformed type is the same.
    let elem_align = base_align(elem, rule);
    let padded = round_up(size_of(elem, rule), elem_align);
    if let LayoutRule::Std140 = rule {
        padded.max(16)
    } else {
        padded
    }
}
/// Returns the byte distance between consecutive columns of a matrix whose
/// columns hold `column_vec_count` scalars of `scalar_w` bytes each.
///
/// Each column is padded to its vector alignment: 1× the scalar width for one
/// component, 2× for two, and 4× for three or four. Under
/// [`LayoutRule::Std140`] the stride is raised to at least 16 bytes.
///
/// # Panics
///
/// Panics if `column_vec_count` is not in `1..=4`.
pub(super) fn matrix_stride(column_vec_count: u32, scalar_w: u32, rule: LayoutRule) -> u32 {
    let align_factor = match column_vec_count {
        1 => 1,
        2 => 2,
        3 | 4 => 4,
        n => panic!("Unsupported matrix column dimension: {}", n),
    };
    let column_align = align_factor * scalar_w;
    let column_bytes = column_vec_count * scalar_w;
    let stride = round_up(column_bytes, column_align);
    if rule == LayoutRule::Std140 {
        stride.max(16)
    } else {
        stride
    }
}
pub(super) fn layout_struct(members: &[Type], rule: LayoutRule) -> StructLayout {
let mut offset = 0u32;
let mut offsets = Vec::with_capacity(members.len());
let mut align = 4u32;
for m in members {
let a = base_align(&m.kind, rule);
offset = round_up(offset, a);
offsets.push(offset);
offset += size_of(&m.kind, rule);
align = align.max(a);
}
let size = round_up(offset, align);
StructLayout {
member_offsets: offsets,
size,
align,
}
}
/// Returns how many scalars one matrix column holds: the vector's component
/// count for a vector column, or 1 for a scalar column.
///
/// # Panics
///
/// Panics if `column` is neither a scalar nor a vector.
pub(super) fn column_vec_count(column: &Type) -> u32 {
    match column.kind {
        TypeKind::Scalar { .. } => 1,
        TypeKind::Vector { count, .. } => count,
        TypeKind::Matrix { .. } | TypeKind::Array { .. } | TypeKind::Struct { .. } => {
            panic!("Matrix column type must be a scalar or vector")
        }
    }
}
/// Returns the width in bytes of the scalar that makes up a matrix column.
///
/// For a scalar column that is the column itself; for a vector column it is
/// the vector's component.
///
/// # Panics
///
/// Panics if `column` is neither a scalar nor a vector, or if it is a vector
/// whose component is not a scalar.
pub(super) fn column_scalar_width(column: &Type) -> u32 {
    // Step from the column down to the type that must be the scalar leaf.
    let leaf = match &column.kind {
        TypeKind::Scalar { .. } => &column.kind,
        TypeKind::Vector { component, .. } => &component.kind,
        TypeKind::Matrix { .. } | TypeKind::Array { .. } | TypeKind::Struct { .. } => {
            panic!("Matrix column type must be a scalar or vector")
        }
    };
    match leaf {
        TypeKind::Scalar { width_bytes } => *width_bytes,
        _ => panic!("Matrix column vector component must be scalar"),
    }
}