/// Verifies that a brick declares a usable contract.
///
/// Two conditions are checked, in this order:
///
/// 1. `brick.assertions()` returns at least one assertion.
/// 2. `brick.budget()` has strictly positive `us_per_token` and `tokens_per_sec`.
///    A NaN value in either field is rejected as well.
///
/// # Errors
///
/// * [`BrickError::AssertionFailed`] when the brick declares no assertions. The
///   `name` field is `"<brick name>/CB-BUDGET"`.
/// * [`BrickError::BudgetExceeded`] when either budget figure is zero, negative
///   or NaN. The payload always carries `limit_us: 0.0` and the budget's
///   `us_per_token` as `actual_us`, even when `tokens_per_sec` is the offending
///   field.
pub fn validate_brick_contract(
    brick: &dyn ComputeBrick<Output = Vec<f32>>,
) -> Result<(), BrickError> {
    if brick.assertions().is_empty() {
        return Err(BrickError::AssertionFailed {
            name: format!("{}/CB-BUDGET", brick.name()),
            expected: "at least 1 assertion".to_string(),
            actual: "0 assertions".to_string(),
        });
    }

    let budget = brick.budget();
    // NaN compares false against every value, so `<= 0.0` on its own would let
    // a NaN budget slip through as "valid"; reject it explicitly.
    let budget_is_unusable = budget.us_per_token.is_nan()
        || budget.tokens_per_sec.is_nan()
        || budget.us_per_token <= 0.0
        || budget.tokens_per_sec <= 0.0;
    if budget_is_unusable {
        return Err(BrickError::BudgetExceeded {
            limit_us: 0.0,
            actual_us: budget.us_per_token,
        });
    }

    Ok(())
}
/// RMS-normalization brick.
///
/// `run` divides each input element by `sqrt(mean(x^2) + eps)` and multiplies
/// it by the matching entry of `weight`.
#[derive(Debug)]
pub struct RmsNormBrick {
/// Per-element scale factors; `run` rejects inputs whose length differs from this.
pub weight: Vec<f32>,
/// Added to the mean of squares before the square root is taken.
pub eps: f32,
/// Budget passed to `TokenResult::new` in `run`; `new` initializes it with
/// `TokenBudget::from_latency(1.5)`.
budget: TokenBudget,
}
impl RmsNormBrick {
    /// Creates a brick from per-element `weight` factors and the `eps` term.
    ///
    /// The budget starts as `TokenBudget::from_latency(1.5)`; use
    /// [`Self::with_budget`] to replace it.
    pub fn new(weight: Vec<f32>, eps: f32) -> Self {
        Self {
            weight,
            eps,
            budget: TokenBudget::from_latency(1.5),
        }
    }

    /// Returns the brick with its budget replaced by `budget`.
    #[must_use]
    pub fn with_budget(mut self, budget: TokenBudget) -> Self {
        self.budget = budget;
        self
    }

    /// Normalizes `input` and checks the result against the brick's assertions.
    ///
    /// Each output element is `(x / rms) * w`, where
    /// `rms = sqrt(mean(x^2) + eps)` and `w` is the matching entry of `weight`.
    /// The computation is timed and reported as a single token via
    /// `TokenResult::new`.
    ///
    /// # Errors
    ///
    /// * [`BrickError::InvalidInput`] when `input.len()` differs from
    ///   `weight.len()`.
    /// * Any error returned by an assertion's `check_f32` on the output.
    pub fn run(&self, input: &[f32]) -> Result<TokenResult<Vec<f32>>, BrickError> {
        if input.len() != self.weight.len() {
            return Err(BrickError::InvalidInput(format!(
                "Input len {} != weight len {}",
                input.len(),
                self.weight.len()
            )));
        }

        let start = Instant::now();
        let mean_square = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
        let rms = (mean_square + self.eps).sqrt();
        let output: Vec<f32> = input
            .iter()
            .zip(&self.weight)
            .map(|(x, w)| (x / rms) * w)
            .collect();
        // `as_micros()` truncates to whole microseconds, which discards most of
        // the signal when the budget is itself fractional (1.5 by default).
        // Converting from fractional seconds keeps sub-microsecond resolution.
        let elapsed_us = start.elapsed().as_secs_f64() * 1_000_000.0;

        let result = TokenResult::new(output, 1, elapsed_us, &self.budget);
        for assertion in self.assertions() {
            assertion.check_f32(&result.output, result.budget_met)?;
        }
        Ok(result)
    }
}
impl ComputeBrick for RmsNormBrick {
    type Output = Vec<f32>;

    /// Identifier reported for this brick.
    fn name(&self) -> &'static str {
        "rms_norm"
    }

    /// Copy of the budget currently configured on the brick.
    fn budget(&self) -> TokenBudget {
        let Self { budget, .. } = self;
        *budget
    }

    /// Checks applied to the output: no NaN, no infinity, budget met.
    fn assertions(&self) -> Vec<BrickAssertion> {
        Vec::from([
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
        ])
    }
}
/// Configuration for the brick named `"qkv_proj"`.
///
/// NOTE(review): presumably describes a fused query/key/value projection; no
/// compute routine is visible in this part of the file — confirm.
#[derive(Debug)]
pub struct QkvBrick {
/// Presumably the width of the projection input — TODO confirm against callers.
pub hidden_dim: usize,
/// First term of `total_out_dim` (presumably the query width — confirm).
pub q_dim: usize,
/// Second term of `total_out_dim` (presumably the key width — confirm).
pub k_dim: usize,
/// Third term of `total_out_dim` (presumably the value width — confirm).
pub v_dim: usize,
/// Budget returned by `budget()`; `new` initializes it with
/// `TokenBudget::from_latency(6.0)`.
budget: TokenBudget,
/// `false` after `new`; switched to `true` by `with_bias`.
pub has_bias: bool,
}
impl QkvBrick {
    /// Builds a brick from its four dimensions.
    ///
    /// The budget starts as `TokenBudget::from_latency(6.0)` and `has_bias`
    /// starts as `false`.
    pub fn new(hidden_dim: usize, q_dim: usize, k_dim: usize, v_dim: usize) -> Self {
        let default_budget = TokenBudget::from_latency(6.0);
        Self {
            hidden_dim,
            q_dim,
            k_dim,
            v_dim,
            budget: default_budget,
            has_bias: false,
        }
    }

    /// Returns the brick with its budget replaced by `budget`.
    #[must_use]
    pub fn with_budget(self, budget: TokenBudget) -> Self {
        Self { budget, ..self }
    }

    /// Returns the brick with `has_bias` set to `true`.
    #[must_use]
    pub fn with_bias(self) -> Self {
        Self {
            has_bias: true,
            ..self
        }
    }

    /// Sum of `q_dim`, `k_dim` and `v_dim`.
    pub fn total_out_dim(&self) -> usize {
        let &Self {
            q_dim,
            k_dim,
            v_dim,
            ..
        } = self;
        q_dim + k_dim + v_dim
    }
}
impl ComputeBrick for QkvBrick {
    /// Three `f32` vectors (presumably query, key, value in that order — confirm).
    type Output = (Vec<f32>, Vec<f32>, Vec<f32>);

    /// Identifier reported for this brick.
    fn name(&self) -> &'static str {
        "qkv_proj"
    }

    /// Copy of the budget currently configured on the brick.
    fn budget(&self) -> TokenBudget {
        let Self { budget, .. } = self;
        *budget
    }

    /// Checks declared for the output: no NaN, no infinity, budget met.
    fn assertions(&self) -> Vec<BrickAssertion> {
        Vec::from([
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
        ])
    }
}
/// Configuration for the brick named `"rope"`.
///
/// NOTE(review): presumably rotary position embedding; no compute routine is
/// visible in this part of the file — confirm.
#[derive(Debug)]
pub struct RopeBrick {
/// Presumably the per-head vector width — TODO confirm.
pub head_dim: usize,
/// Presumably the number of heads the embedding is applied to — TODO confirm.
pub num_heads: usize,
/// Presumably the frequency base of the embedding — TODO confirm.
pub theta: f32,
/// Variant selector; the meaning of individual values is not visible here.
pub rope_type: u32,
/// Budget returned by `budget()`; `new` initializes it with
/// `TokenBudget::from_latency(1.0)`.
budget: TokenBudget,
}
impl RopeBrick {
    /// Builds a brick from its four configuration values.
    ///
    /// The budget starts as `TokenBudget::from_latency(1.0)`.
    pub fn new(head_dim: usize, num_heads: usize, theta: f32, rope_type: u32) -> Self {
        let default_budget = TokenBudget::from_latency(1.0);
        Self {
            head_dim,
            num_heads,
            theta,
            rope_type,
            budget: default_budget,
        }
    }

    /// Returns the brick with its budget replaced by `budget`.
    #[must_use]
    pub fn with_budget(self, budget: TokenBudget) -> Self {
        Self { budget, ..self }
    }
}
impl ComputeBrick for RopeBrick {
    type Output = Vec<f32>;

    /// Identifier reported for this brick.
    fn name(&self) -> &'static str {
        "rope"
    }

    /// Copy of the budget currently configured on the brick.
    fn budget(&self) -> TokenBudget {
        let Self { budget, .. } = self;
        *budget
    }

    /// Checks declared for the output: no NaN, no infinity, budget met.
    fn assertions(&self) -> Vec<BrickAssertion> {
        Vec::from([
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
        ])
    }
}
/// Configuration for the brick named `"attention"`.
///
/// NOTE(review): no compute routine is visible in this part of the file.
#[derive(Debug)]
pub struct AttentionBrick {
/// Numerator of `group_size` (presumably the query head count — confirm).
pub num_heads: usize,
/// Divisor of `group_size`, where a value of 0 is treated as 1
/// (presumably the key/value head count — confirm).
pub num_kv_heads: usize,
/// Presumably the per-head vector width — TODO confirm.
pub head_dim: usize,
/// Budget returned by `budget()`; `new` initializes it with
/// `TokenBudget::from_latency(10.0)`.
budget: TokenBudget,
}
impl AttentionBrick {
    /// Builds a brick from its head counts and head width.
    ///
    /// The budget starts as `TokenBudget::from_latency(10.0)`.
    pub fn new(num_heads: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let default_budget = TokenBudget::from_latency(10.0);
        Self {
            num_heads,
            num_kv_heads,
            head_dim,
            budget: default_budget,
        }
    }

    /// Returns the brick with its budget replaced by `budget`.
    #[must_use]
    pub fn with_budget(self, budget: TokenBudget) -> Self {
        Self { budget, ..self }
    }

    /// Integer quotient of `num_heads` by `num_kv_heads`.
    ///
    /// A `num_kv_heads` of zero is treated as one, so the division never panics.
    pub fn group_size(&self) -> usize {
        let kv_heads = std::cmp::max(self.num_kv_heads, 1);
        self.num_heads / kv_heads
    }
}
impl ComputeBrick for AttentionBrick {
    type Output = Vec<f32>;

    /// Identifier reported for this brick.
    fn name(&self) -> &'static str {
        "attention"
    }

    /// Copy of the budget currently configured on the brick.
    fn budget(&self) -> TokenBudget {
        let Self { budget, .. } = self;
        *budget
    }

    /// Checks declared for the output: no NaN, no infinity, budget met, and
    /// a bounds assertion built from `-100.0` and `100.0`.
    fn assertions(&self) -> Vec<BrickAssertion> {
        Vec::from([
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
            BrickAssertion::bounds(-100.0, 100.0),
        ])
    }
}
/// Configuration for a tiled attention brick.
///
/// NOTE(review): neither the constructor nor any use of these fields is visible
/// in this part of the file, so the field notes below are assumptions to verify.
#[derive(Debug, Clone)]
pub struct FlashAttentionBrick {
/// Presumably the query head count — TODO confirm.
pub num_heads: usize,
/// Presumably the key/value head count — TODO confirm.
pub num_kv_heads: usize,
/// Presumably the per-head vector width — TODO confirm.
pub head_dim: usize,
/// Presumably the number of positions processed per tile — TODO confirm.
pub tile_size: usize,
/// Budget associated with this brick; its default is set outside this view.
budget: TokenBudget,
/// Presumably toggles an incrementally-computed softmax — TODO confirm.
pub use_online_softmax: bool,
}