use crate::arena::{Arena, Handle, UniqueArena};
use std::{num::NonZeroU32, ops};
/// Alignment of a type. Non-zero by construction; the offset arithmetic in
/// `Layouter::round_up` additionally expects it to be a power of two.
pub type Alignment = NonZeroU32;

/// Size and alignment computed for a single type.
///
/// `Eq` is derived alongside `Hash` (both fields are `Eq`) so that layouts can
/// be used as keys in hash-based collections.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct TypeLayout {
    /// Number of bytes occupied by a value of the type.
    pub size: u32,
    /// Required alignment of the offset at which a value of the type starts.
    pub alignment: Alignment,
}
/// Cache of the size and alignment of every type in a type arena.
///
/// Layouts are looked up by type handle through the `Index` impl and are
/// filled in incrementally by `Layouter::update`.
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Layouter {
/// One entry per laid-out type, stored at the position given by the type
/// handle's `index()`; entries are pushed in arena iteration order.
layouts: Vec<TypeLayout>,
}
/// Looks up the computed layout of a type by its handle.
///
/// Panics if `handle` has not been laid out yet, i.e. its index is beyond the
/// layouts produced so far by `Layouter::update`.
impl ops::Index<Handle<crate::Type>> for Layouter {
    type Output = TypeLayout;

    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
        let slot = handle.index();
        &self.layouts[slot]
    }
}
/// Error returned by `Layouter::update` when a type refers to another type
/// (an array base or a struct member) whose handle does not come before it
/// in the arena, so that type's layout is not available yet.
///
/// The wrapped handle is the offending referenced type.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Base type {0:?} is out of bounds")]
pub struct InvalidBaseType(pub Handle<crate::Type>);
impl Layouter {
pub fn clear(&mut self) {
self.layouts.clear();
}
pub fn round_up(alignment: Alignment, offset: u32) -> u32 {
match offset & (alignment.get() - 1) {
0 => offset,
other => offset + alignment.get() - other,
}
}
pub fn member_placement(
&self,
offset: u32,
ty: Handle<crate::Type>,
align: Option<Alignment>,
size: Option<NonZeroU32>,
) -> (ops::Range<u32>, Alignment) {
let layout = self.layouts[ty.index()];
let alignment = align.unwrap_or(layout.alignment);
let start = Self::round_up(alignment, offset);
let span = match size {
Some(size) => size.get(),
None => layout.size,
};
(start..start + span, alignment)
}
pub fn update(
&mut self,
types: &UniqueArena<crate::Type>,
constants: &Arena<crate::Constant>,
) -> Result<(), InvalidBaseType> {
use crate::TypeInner as Ti;
for (ty_handle, ty) in types.iter().skip(self.layouts.len()) {
let size = ty.inner.span(constants);
let layout = match ty.inner {
Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => TypeLayout {
size,
alignment: Alignment::new(width as u32).unwrap(),
},
Ti::Vector {
size: vec_size,
width,
..
} => TypeLayout {
size,
alignment: {
let count = if vec_size >= crate::VectorSize::Tri {
4
} else {
2
};
Alignment::new((count * width) as u32).unwrap()
},
},
Ti::Matrix {
columns: _,
rows,
width,
} => TypeLayout {
size,
alignment: {
let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 };
Alignment::new((count * width) as u32).unwrap()
},
},
Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
size,
alignment: Alignment::new(1).unwrap(),
},
Ti::Array {
base,
stride: _,
size: _,
} => TypeLayout {
size,
alignment: if base < ty_handle {
self[base].alignment
} else {
return Err(InvalidBaseType(base));
},
},
Ti::Struct { span, ref members } => {
let mut alignment = Alignment::new(1).unwrap();
for member in members {
alignment = if member.ty < ty_handle {
alignment.max(self[member.ty].alignment)
} else {
return Err(InvalidBaseType(member.ty));
};
}
TypeLayout {
size: span,
alignment,
}
}
Ti::Image { .. } | Ti::Sampler { .. } => TypeLayout {
size,
alignment: Alignment::new(1).unwrap(),
},
};
debug_assert!(ty.inner.span(constants) <= layout.size);
self.layouts.push(layout);
}
Ok(())
}
}