use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
use super::InnerShape;
/// Parameters describing a block layout.
///
/// `derive(Builder)` generates `LayoutConfigBuilder` (see `LayoutConfig::builder`); fields
/// carrying `#[builder(default = ...)]` may be omitted when building. The `#[validate(...)]`
/// attributes declare per-field rules for the `validator` crate's `Validate` derive.
/// `required_bytes` multiplies every numeric field except `alignment`.
#[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize, PartialEq, Eq)]
pub struct LayoutConfig {
/// Number of blocks; validated to be at least 1.
#[validate(range(min = 1))]
pub num_blocks: usize,
/// Number of layers; validated to be at least 1.
#[validate(range(min = 1))]
pub num_layers: usize,
/// Size of the outer dimension; validated to be 1 or 2.
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
/// Page size; validated to be at least 1.
// NOTE(review): the unit (tokens, elements, ...) is not visible in this file - confirm.
#[validate(range(min = 1))]
pub page_size: usize,
/// Size of the inner dimension; validated to be at least 1.
#[validate(range(min = 1))]
pub inner_dim: usize,
/// Alignment; must be a power of two (checked by `validate_power_of_2`).
/// The builder defaults it to 1.
#[validate(custom(function = "validate_power_of_2"))]
#[builder(default = "1")]
pub alignment: usize,
/// Width of one element in bytes; must be 2, 4 or 8 (checked by
/// `validate_dtype_width_bytes`). The builder defaults it to 2.
#[validate(custom(function = "validate_dtype_width_bytes"))]
#[builder(default = "2")]
pub dtype_width_bytes: usize,
/// Inner shape descriptor; the builder defaults it to `InnerShape::Unknown`.
#[builder(default = "InnerShape::Unknown")]
pub inner_shape: InnerShape,
}
impl LayoutConfig {
pub fn builder() -> LayoutConfigBuilder {
LayoutConfigBuilder::default()
}
pub fn required_bytes(&self) -> usize {
self.num_blocks
.saturating_mul(self.num_layers)
.saturating_mul(self.outer_dim)
.saturating_mul(self.page_size)
.saturating_mul(self.inner_dim)
.saturating_mul(self.dtype_width_bytes)
}
}
/// Marks whether the block dimension is the first or the second dimension.
///
/// NOTE(review): no use of this enum is visible in this file; presumably it describes
/// the position of the block index within a layout's dimensions - confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockDimension {
/// The block dimension is the first dimension.
BlockIsFirstDim,
/// The block dimension is the second dimension.
BlockIsSecondDim,
}
pub fn validate_power_of_2(alignment: usize) -> Result<(), ValidationError> {
if !alignment.is_power_of_two() {
return Err(validator::ValidationError::new(
"alignment_must_be_power_of_2",
));
}
Ok(())
}
pub fn validate_dtype_width_bytes(dtype_width_bytes: usize) -> Result<(), ValidationError> {
if !dtype_width_bytes.is_power_of_two() || !(2..=8).contains(&dtype_width_bytes) {
return Err(validator::ValidationError::new(
"dtype_width_bytes_must_be_power_of_two_and_less_than_8_bytes",
));
}
Ok(())
}