use std::fmt;
use std::marker::PhantomData;
use std::time::Instant;
use super::budget::TokenBudget;
use super::types::{
AssertionResult, Backend, BrickError, BrickVerification, ComputeAssertion, ComputeOp,
};
use super::TokenResult;
/// A compute operation bundled with the assertions, performance budget, and
/// backend selection that govern how it is run and verified.
///
/// Instances are configured through the chained builder methods on the
/// inherent `impl` and then executed with `run` or checked with `verify`.
pub struct ComputeBrick<Op: ComputeOp> {
/// The wrapped operation; `run` calls its `tokens` and `execute` methods.
op: Op,
/// Assertions registered through the `assert_*` builders. `verify` reports
/// failure when this list is empty.
assertions: Vec<ComputeAssertion>,
/// Performance budget that `run` compares the measured per-token latency against.
budget: TokenBudget,
/// Backend handed to the operation's `execute` call.
backend: Backend,
/// When `true`, `run` returns `BrickError::BudgetExceeded` if the budget is missed;
/// when `false`, the miss is only reported in the returned `TokenResult`.
enforce_budget: bool,
/// Zero-sized marker for the `Op` type parameter.
_phantom: PhantomData<Op>,
}
impl<Op: ComputeOp> ComputeBrick<Op> {
pub fn new(op: Op) -> Self {
Self {
op,
assertions: Vec::new(),
budget: TokenBudget::default(),
backend: Backend::Auto,
enforce_budget: false,
_phantom: PhantomData,
}
}
#[must_use]
pub fn assert_equiv(mut self, baseline: Backend) -> Self {
self.assertions.push(ComputeAssertion::equiv(baseline));
self
}
#[must_use]
pub fn assert_equiv_with_tolerance(mut self, baseline: Backend, tolerance: f64) -> Self {
self.assertions.push(ComputeAssertion::equiv_with_tolerance(baseline, tolerance));
self
}
#[must_use]
pub fn assert_bounds(mut self, min: f64, max: f64) -> Self {
self.assertions.push(ComputeAssertion::bounds(min, max));
self
}
#[must_use]
pub fn assert_finite(mut self) -> Self {
self.assertions.push(ComputeAssertion::finite());
self
}
#[must_use]
pub fn budget_tok_per_sec(mut self, tps: f64) -> Self {
self.budget = TokenBudget::from_throughput(tps);
self
}
#[must_use]
pub fn budget_us_per_tok(mut self, us: f64) -> Self {
self.budget = TokenBudget::from_latency(us);
self
}
#[must_use]
pub fn budget(mut self, budget: TokenBudget) -> Self {
self.budget = budget;
self
}
#[must_use]
pub fn backend(mut self, backend: Backend) -> Self {
self.backend = backend;
self
}
#[must_use]
pub fn enforce_budget(mut self, enforce: bool) -> Self {
self.enforce_budget = enforce;
self
}
pub fn name(&self) -> &'static str {
self.op.name()
}
pub fn get_budget(&self) -> TokenBudget {
self.budget
}
pub fn get_backend(&self) -> Backend {
self.backend
}
pub fn get_assertions(&self) -> &[ComputeAssertion] {
&self.assertions
}
pub fn run(&self, input: Op::Input) -> Result<TokenResult<Op::Output>, BrickError> {
let tokens = self.op.tokens(&input);
let start = Instant::now();
let output = self.op.execute(input, self.backend)?;
let elapsed_us = start.elapsed().as_secs_f64() * 1_000_000.0;
let us_per_token = if tokens > 0 { elapsed_us / tokens as f64 } else { elapsed_us };
let tokens_per_sec =
if elapsed_us > 0.0 { tokens as f64 * 1_000_000.0 / elapsed_us } else { f64::INFINITY };
let budget_met = self.budget.is_met(us_per_token);
let budget_utilization = self.budget.utilization(us_per_token);
if self.enforce_budget && !budget_met {
return Err(BrickError::BudgetExceeded {
limit_us: self.budget.us_per_token,
actual_us: us_per_token,
utilization: budget_utilization * 100.0,
});
}
Ok(TokenResult {
output,
tokens_processed: tokens,
us_per_token,
tokens_per_sec,
budget_met,
budget_utilization,
})
}
pub fn verify(&self) -> BrickVerification {
let start = Instant::now();
if self.assertions.is_empty() {
return BrickVerification {
passed: false,
assertion_results: vec![AssertionResult {
assertion: ComputeAssertion::Custom {
name: "popperian_falsifiability".to_string(),
},
passed: false,
error: Some(
"No assertions defined - violates Popperian falsifiability".to_string(),
),
}],
verification_us: start.elapsed().as_secs_f64() * 1_000_000.0,
};
}
let results: Vec<AssertionResult> = self
.assertions
.iter()
.map(|a| AssertionResult { assertion: a.clone(), passed: true, error: None })
.collect();
let passed = results.iter().all(|r| r.passed);
BrickVerification {
passed,
assertion_results: results,
verification_us: start.elapsed().as_secs_f64() * 1_000_000.0,
}
}
}
impl<Op: ComputeOp + Clone> Clone for ComputeBrick<Op> {
    /// Clones the wrapped operation and the assertion list, and copies the
    /// remaining configuration fields.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct forces this
        // impl to be revisited.
        let Self { op, assertions, budget, backend, enforce_budget, _phantom: _ } = self;
        Self {
            op: op.clone(),
            assertions: assertions.clone(),
            budget: *budget,
            backend: *backend,
            enforce_budget: *enforce_budget,
            _phantom: PhantomData,
        }
    }
}
impl<Op: ComputeOp> fmt::Debug for ComputeBrick<Op> {
    /// Formats a summary of the brick: the operation's name, the backend, the
    /// budget, the number of assertions, and the enforcement flag. The
    /// operation itself is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let op_name = self.op.name();
        let assertion_count = self.assertions.len();

        let mut summary = f.debug_struct("ComputeBrick");
        summary.field("name", &op_name);
        summary.field("backend", &self.backend);
        summary.field("budget", &self.budget);
        summary.field("assertions", &assertion_count);
        summary.field("enforce_budget", &self.enforce_budget);
        summary.finish()
    }
}
/// A collection of named throughput budgets, used to find the
/// slowest brick and the resulting throughput ceiling.
#[derive(Debug, Default)]
pub struct BrickLayer {
    /// `(brick name, throughput budget)` pairs in insertion order.
    bricks: Vec<(String, f64)>,
}

impl BrickLayer {
    /// Creates an empty layer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds an entry taken from `brick`: its name and its budget's
    /// `tokens_per_sec` value.
    #[must_use]
    pub fn with_brick<Op: ComputeOp>(mut self, brick: &ComputeBrick<Op>) -> Self {
        self.bricks.push((brick.name().to_string(), brick.budget.tokens_per_sec));
        self
    }

    /// Adds an entry with an explicit `name` and throughput budget.
    #[must_use]
    pub fn with_named(mut self, name: &str, budget_tok_per_sec: f64) -> Self {
        self.bricks.push((name.to_string(), budget_tok_per_sec));
        self
    }

    /// Returns the smallest throughput budget in the layer.
    ///
    /// Returns `f64::INFINITY` for an empty layer. NaN budgets are ignored,
    /// because `f64::min` returns the non-NaN operand.
    pub fn throughput_ceiling(&self) -> f64 {
        self.bricks.iter().map(|(_, tps)| *tps).fold(f64::INFINITY, f64::min)
    }

    /// Returns the name of the brick with the smallest throughput budget.
    ///
    /// Returns `None` when the layer is empty or holds only NaN budgets. When
    /// several bricks tie for the minimum, the first one added is returned.
    pub fn bottleneck(&self) -> Option<&str> {
        self.bricks
            .iter()
            // Skip NaN budgets so the result agrees with `throughput_ceiling`;
            // otherwise a NaN entry ahead of the true minimum compares as
            // "equal" to everything and is reported as the bottleneck.
            .filter(|(_, tps)| !tps.is_nan())
            // With NaN removed `partial_cmp` always yields `Some`; the
            // fallback is kept only as a defensive default.
            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(name, _)| name.as_str())
    }

    /// Returns all entries in insertion order.
    pub fn bricks(&self) -> &[(String, f64)] {
        &self.bricks
    }
}