use std::fmt;
use std::time::Instant;
use crate::quantize::Q8_0Block;
#[cfg(feature = "cuda")]
mod fused;
#[cfg(feature = "cuda")]
pub use fused::{CoalescedDp4aBrick, FusedFfnBrick};
pub mod tracer;
pub use tracer::{BrickTracer, TraceComparison, TraceDiff, TraceEvent};
pub mod profiler;
pub use profiler::{BrickProfiler, ContractSeverity, OpStats, ProfileReport};
/// Per-token performance target, stored both as a latency and as a throughput.
///
/// The constructors keep the two rate fields reciprocal
/// (`tokens_per_sec == 1_000_000 / us_per_token`); the fields are public, so
/// nothing enforces that relationship after construction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TokenBudget {
    /// Latency target in microseconds per token.
    pub us_per_token: f64,
    /// Throughput target in tokens per second.
    pub tokens_per_sec: f64,
    /// Batch size attached to this budget; both constructors set it to 1.
    pub batch_size: usize,
}

impl TokenBudget {
    /// Creates a budget from a latency target in microseconds per token.
    ///
    /// The throughput is derived as `1_000_000.0 / us_per_token`. The input is
    /// not validated, so a zero latency produces a non-finite throughput.
    #[must_use]
    pub fn from_latency(us_per_token: f64) -> Self {
        let tokens_per_sec = 1_000_000.0 / us_per_token;
        Self {
            us_per_token,
            tokens_per_sec,
            batch_size: 1,
        }
    }

    /// Creates a budget from a throughput target in tokens per second.
    ///
    /// The latency is derived as `1_000_000.0 / tokens_per_sec`. The input is
    /// not validated, so a zero throughput produces a non-finite latency.
    #[must_use]
    pub fn from_throughput(tokens_per_sec: f64) -> Self {
        let us_per_token = 1_000_000.0 / tokens_per_sec;
        Self {
            us_per_token,
            tokens_per_sec,
            batch_size: 1,
        }
    }

    /// Returns a copy of this budget with `batch_size` replaced.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Returns `true` when the measured latency is at or below the target.
    ///
    /// A NaN on either side compares as "not met".
    #[must_use]
    pub fn is_met(&self, actual_us_per_token: f64) -> bool {
        self.us_per_token >= actual_us_per_token
    }

    /// Returns the ratio of measured latency to target latency.
    ///
    /// Values above 1.0 mean the measurement is slower than the budget.
    #[must_use]
    pub fn gap_factor(&self, actual_us_per_token: f64) -> f64 {
        let target_us = self.us_per_token;
        actual_us_per_token / target_us
    }
}

impl Default for TokenBudget {
    /// Defaults to a latency budget of 100 µs per token.
    fn default() -> Self {
        Self::from_latency(100.0)
    }
}
/// Output of a run together with its measured per-token performance.
#[derive(Debug, Clone)]
pub struct TokenResult<T> {
    /// Value produced by the run.
    pub output: T,
    /// Number of tokens the run processed, as reported by the caller.
    pub tokens_processed: usize,
    /// Measured latency in microseconds per token.
    pub us_per_token: f64,
    /// Measured throughput in tokens per second (0.0 when no positive latency was measured).
    pub tokens_per_sec: f64,
    /// Whether the measured latency satisfied the budget it was checked against.
    pub budget_met: bool,
}

impl<T: Default> Default for TokenResult<T> {
    /// An empty result: default output, zero tokens, zero rates, budget reported as met.
    fn default() -> Self {
        let output = T::default();
        Self {
            output,
            tokens_processed: 0,
            us_per_token: 0.0,
            tokens_per_sec: 0.0,
            budget_met: true,
        }
    }
}

impl<T> TokenResult<T> {
    /// Builds a result from a raw measurement and checks it against `budget`.
    ///
    /// * `output` - value produced by the run.
    /// * `tokens` - number of tokens processed; stored unchanged.
    /// * `elapsed_us` - total elapsed time in microseconds.
    /// * `budget` - target whose `is_met` decides `budget_met`.
    pub fn new(output: T, tokens: usize, elapsed_us: f64, budget: &TokenBudget) -> Self {
        // The divisor is clamped to 1 so a zero-token run never divides by zero.
        let divisor = tokens.max(1) as f64;
        let us_per_token = elapsed_us / divisor;

        // Only a strictly positive latency converts to a throughput; zero,
        // negative and NaN latencies all report 0.0 tokens/sec.
        let tokens_per_sec = Some(us_per_token)
            .filter(|us| *us > 0.0)
            .map_or(0.0, |us| 1_000_000.0 / us);

        let budget_met = budget.is_met(us_per_token);

        Self {
            output,
            tokens_processed: tokens,
            us_per_token,
            tokens_per_sec,
            budget_met,
        }
    }
}
/// Errors reported by brick assertions, budgets and computations.
#[derive(Debug)]
pub enum BrickError {
    /// A named assertion did not hold.
    AssertionFailed {
        /// Name of the assertion that failed.
        name: String,
        /// Human-readable description of what was expected.
        expected: String,
        /// Human-readable description of what was observed.
        actual: String,
    },
    /// A measured per-token latency was above its limit.
    BudgetExceeded {
        /// Allowed latency in microseconds per token.
        limit_us: f64,
        /// Measured latency in microseconds per token.
        actual_us: f64,
    },
    /// A computation failed; the payload is the message.
    ComputeError(String),
    /// An input was rejected; the payload is the message.
    InvalidInput(String),
}

impl fmt::Display for BrickError {
    /// Writes a one-line, human-readable message for each variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ComputeError(msg) => write!(f, "Compute error: {msg}"),
            Self::InvalidInput(msg) => write!(f, "Invalid input: {msg}"),
            Self::BudgetExceeded {
                limit_us,
                actual_us,
            } => write!(
                f,
                "Budget exceeded: {limit_us:.1}µs/tok limit, {actual_us:.1}µs/tok actual"
            ),
            Self::AssertionFailed {
                name,
                expected,
                actual,
            } => write!(
                f,
                "Assertion failed: {name} - expected {expected}, got {actual}"
            ),
        }
    }
}

impl std::error::Error for BrickError {}
/// A named, described check that can be applied to a brick's output.
#[derive(Debug, Clone)]
pub struct BrickAssertion {
    /// Short identifier, reported in `BrickError::AssertionFailed::name`.
    pub name: String,
    /// Human-readable statement of what the assertion requires.
    pub description: String,
    /// The concrete check to perform.
    pub kind: AssertionKind,
}

/// The kinds of checks a [`BrickAssertion`] can carry.
#[derive(Debug, Clone)]
pub enum AssertionKind {
    /// Output must match a scalar baseline within `tolerance`.
    /// Not evaluated by [`BrickAssertion::check_f32`].
    EquivScalar {
        /// Allowed deviation from the baseline.
        tolerance: f64,
    },
    /// Output must contain no NaN values.
    NoNaN,
    /// Output must contain no infinite values.
    NoInf,
    /// Every output value must lie in the closed interval `[min, max]`.
    Bounds {
        /// Inclusive lower bound.
        min: f64,
        /// Inclusive upper bound.
        max: f64,
    },
    /// The performance budget must have been met.
    BudgetMet,
    /// A custom check identified by name.
    /// Not evaluated by [`BrickAssertion::check_f32`].
    Custom {
        /// Name of the custom check.
        check_name: String,
    },
}

impl BrickAssertion {
    /// Assertion that the output matches a scalar baseline within `tolerance`.
    pub fn equiv_scalar(tolerance: f64) -> Self {
        Self {
            name: "equiv_scalar".to_string(),
            description: format!("Output matches scalar baseline within {tolerance}"),
            kind: AssertionKind::EquivScalar { tolerance },
        }
    }

    /// Assertion that the output contains no NaN values.
    pub fn no_nan() -> Self {
        Self {
            name: "no_nan".to_string(),
            description: "Output contains no NaN values".to_string(),
            kind: AssertionKind::NoNaN,
        }
    }

    /// Assertion that the output contains no infinite values.
    pub fn no_inf() -> Self {
        Self {
            name: "no_inf".to_string(),
            description: "Output contains no Inf values".to_string(),
            kind: AssertionKind::NoInf,
        }
    }

    /// Assertion that every output value lies in `[min, max]` (inclusive).
    pub fn bounds(min: f64, max: f64) -> Self {
        Self {
            name: "bounds".to_string(),
            description: format!("Output values in [{min}, {max}]"),
            kind: AssertionKind::Bounds { min, max },
        }
    }

    /// Assertion that the performance budget is met.
    pub fn budget_met() -> Self {
        Self {
            name: "budget_met".to_string(),
            description: "Performance budget is met".to_string(),
            kind: AssertionKind::BudgetMet,
        }
    }

    /// Evaluates this assertion against an `f32` output slice.
    ///
    /// * `output` - values to inspect for the `NoNaN`, `NoInf` and `Bounds` kinds.
    /// * `budget_met` - result of the budget comparison, used by the `BudgetMet` kind.
    ///
    /// `EquivScalar` and `Custom` assertions are not evaluated here and always
    /// return `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns `BrickError::AssertionFailed` describing the first offending
    /// element (or the exceeded budget) when the assertion does not hold.
    pub fn check_f32(&self, output: &[f32], budget_met: bool) -> Result<(), BrickError> {
        let fail = |expected: String, actual: String| BrickError::AssertionFailed {
            name: self.name.clone(),
            expected,
            actual,
        };

        match &self.kind {
            AssertionKind::NoNaN => match output.iter().position(|x| x.is_nan()) {
                Some(idx) => Err(fail("no NaN".to_string(), format!("NaN at index {idx}"))),
                None => Ok(()),
            },
            AssertionKind::NoInf => match output.iter().position(|x| x.is_infinite()) {
                Some(idx) => Err(fail("no Inf".to_string(), format!("Inf at index {idx}"))),
                None => Ok(()),
            },
            AssertionKind::Bounds { min, max } => {
                let offender = output.iter().enumerate().find(|&(_, &val)| {
                    let v = f64::from(val);
                    // NaN compares false against both bounds, so it has to be
                    // rejected explicitly: a NaN is never inside [min, max].
                    v.is_nan() || v < *min || v > *max
                });
                match offender {
                    Some((idx, val)) => Err(fail(
                        format!("value in [{min}, {max}]"),
                        format!("value {val} at index {idx}"),
                    )),
                    None => Ok(()),
                }
            },
            AssertionKind::BudgetMet => {
                if budget_met {
                    Ok(())
                } else {
                    Err(fail(
                        "budget met".to_string(),
                        "budget exceeded".to_string(),
                    ))
                }
            },
            // This method receives neither a baseline output nor a custom
            // check, so these kinds are not evaluated here.
            AssertionKind::EquivScalar { .. } | AssertionKind::Custom { .. } => Ok(()),
        }
    }
}
/// Aggregated outcome of verifying a brick.
#[derive(Debug, Clone)]
pub struct BrickVerification {
    /// `true` while no recorded check has failed.
    pub is_valid: bool,
    /// Recorded checks as `(name, passed, message)` tuples, in insertion order.
    pub results: Vec<(String, bool, String)>,
}

impl BrickVerification {
    /// Creates a passing verification with no recorded checks.
    pub fn pass() -> Self {
        Self {
            is_valid: true,
            results: Vec::new(),
        }
    }

    /// Creates a failed verification holding one failed check with the given reason.
    pub fn fail(name: &str, reason: &str) -> Self {
        let entry = (name.to_owned(), false, reason.to_owned());
        Self {
            is_valid: false,
            results: vec![entry],
        }
    }

    /// Records a check result.
    ///
    /// A failed check makes the whole verification invalid; a passing check
    /// never restores validity.
    pub fn add(&mut self, name: &str, passed: bool, message: &str) {
        let entry = (name.to_owned(), passed, message.to_owned());
        self.results.push(entry);
        self.is_valid &= passed;
    }
}
/// A unit of computation that declares its own name, budget and assertions.
pub trait ComputeBrick: Send + Sync {
    /// Type produced by the brick.
    type Output;

    /// Static name of the brick, used when reporting verification failures.
    fn name(&self) -> &'static str;

    /// Per-token performance budget the brick is expected to meet.
    fn budget(&self) -> TokenBudget;

    /// Assertions the brick declares about its output.
    fn assertions(&self) -> Vec<BrickAssertion>;

    /// Checks that the brick is well-formed before it is run.
    ///
    /// The verification fails when:
    /// * no assertions are declared,
    /// * the budget latency is NaN (such a budget can never be met, because
    ///   every comparison against NaN is false), or
    /// * the budget latency is zero or negative.
    ///
    /// Otherwise a passing verification with no recorded checks is returned.
    fn verify(&self) -> BrickVerification {
        if self.assertions().is_empty() {
            return BrickVerification::fail(
                self.name(),
                "No assertions defined (Popper violation)",
            );
        }

        let us_per_token = self.budget().us_per_token;
        // `NaN <= 0.0` is false, so NaN must be tested for explicitly or it
        // would slip through the zero/negative check below.
        if us_per_token.is_nan() {
            return BrickVerification::fail(self.name(), "Budget is NaN");
        }
        if us_per_token <= 0.0 {
            return BrickVerification::fail(self.name(), "Zero or negative budget");
        }

        BrickVerification::pass()
    }

    /// Returns `true` when [`ComputeBrick::verify`] reports a valid brick.
    fn can_run(&self) -> bool {
        self.verify().is_valid
    }
}
// The remainder of this module is pulled in from sibling files. `include!`
// pastes each file's contents at this point, so the included items are part of
// this module and see the `use` declarations at the top of this file.
// NOTE(review): imports that look unused in this file (e.g. `Instant`,
// `Q8_0Block`) are presumably consumed by the included files — confirm before
// removing any of them.
include!("brick_impls.rs");
include!("mod_tile_flash_attention.rs");
include!("mod_per_activation_quant.rs");
include!("graph.rs");