use crate::spec::types::OpSpec;
/// Measured resource cost of a single op.
///
/// Every field is a plain `u64` counter, so the type is `Copy`, `Eq`, and its
/// `Default` value has every metric set to zero.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CostMetrics {
    /// Size, in bytes.
    pub size_bytes: u64,
    /// Dispatch time, in nanoseconds.
    pub dispatch_time_ns: u64,
    /// Workgroup memory, in bytes.
    pub workgroup_memory_bytes: u64,
    /// Register-pressure figure; its unit is not defined in this module.
    pub register_pressure: u64,
    /// Number of barriers.
    pub barrier_count: u64,
}
impl CostMetrics {
pub const ZERO: Self = Self {
size_bytes: 0,
dispatch_time_ns: 0,
workgroup_memory_bytes: 0,
register_pressure: 0,
barrier_count: 0,
};
}
/// Per-metric upper limits for an op's [`CostMetrics`].
///
/// A `None` limit means the metric is unbounded and is never reported. A
/// `Some(max)` limit is violated only when the measured value is strictly
/// greater than `max`; a measurement equal to the limit passes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CostBudget {
    /// Limit for `CostMetrics::size_bytes`.
    pub size_bytes_max: Option<u64>,
    /// Limit for `CostMetrics::dispatch_time_ns`.
    pub dispatch_time_ns_max: Option<u64>,
    /// Limit for `CostMetrics::workgroup_memory_bytes`.
    pub workgroup_memory_bytes_max: Option<u64>,
    /// Limit for `CostMetrics::register_pressure`.
    pub register_pressure_max: Option<u64>,
    /// Limit for `CostMetrics::barrier_count`.
    pub barrier_count_max: Option<u64>,
}
impl CostBudget {
pub const INFINITE: Self = Self {
size_bytes_max: None,
dispatch_time_ns_max: None,
workgroup_memory_bytes_max: None,
register_pressure_max: None,
barrier_count_max: None,
};
#[must_use]
#[inline]
pub fn from_baseline_with_default_ratio(baseline: CostMetrics) -> Self {
Self {
size_bytes_max: Some(scaled(baseline.size_bytes, 105, 100)),
dispatch_time_ns_max: Some(scaled(baseline.dispatch_time_ns, 105, 100)),
workgroup_memory_bytes_max: Some(scaled(baseline.workgroup_memory_bytes, 105, 100)),
register_pressure_max: Some(scaled(baseline.register_pressure, 105, 100)),
barrier_count_max: Some(baseline.barrier_count.saturating_add(1)),
}
}
}
/// Scales `value` by the ratio `num / den`, rounding toward zero.
///
/// The product is computed in 128-bit arithmetic, which cannot overflow for
/// two `u64` operands. Clamping before dividing would shrink large inputs:
/// `u64::MAX` would scale to `u64::MAX / den`. A result that does not fit in
/// `u64` saturates to `u64::MAX`. A zero `den` is treated as `1` instead of
/// panicking.
fn scaled(value: u64, num: u64, den: u64) -> u64 {
    let product = u128::from(value) * u128::from(num);
    let quotient = product / u128::from(den.max(1));
    // The clamp guarantees the cast below cannot truncate.
    quotient.min(u128::from(u64::MAX)) as u64
}
/// A single budget violation: one metric of one op exceeded its limit.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CostFinding {
    /// Identifier of the offending op, rendered from its `OpSpec::id`.
    pub op_id: String,
    /// Which metric exceeded its limit.
    pub metric: CostMetric,
    /// The measured value of that metric.
    pub measured: u64,
    /// The limit the measured value exceeded.
    pub budget: u64,
}
/// Identifies one of the metrics tracked in `CostMetrics`.
///
/// Variant order is significant: it drives the derived `Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum CostMetric {
    SizeBytes,
    DispatchTimeNs,
    WorkgroupMemoryBytes,
    RegisterPressure,
    BarrierCount,
}

impl CostMetric {
    /// Returns the metric's stable snake_case name.
    ///
    /// The name matches the corresponding field name in `CostMetrics`.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            CostMetric::BarrierCount => "barrier_count",
            CostMetric::RegisterPressure => "register_pressure",
            CostMetric::WorkgroupMemoryBytes => "workgroup_memory_bytes",
            CostMetric::DispatchTimeNs => "dispatch_time_ns",
            CostMetric::SizeBytes => "size_bytes",
        }
    }
}
impl std::fmt::Display for CostFinding {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}: {} = {} exceeds budget {}. Fix: optimize the lowering or relax the per-op budget in spec.toml (with justification).",
self.op_id,
self.metric.name(),
self.measured,
self.budget
)
}
}
/// Compares one op's measurements against its budget.
///
/// Returns one [`CostFinding`] per metric whose measured value is strictly
/// greater than its limit. Metrics whose limit is `None` are skipped, and a
/// value equal to its limit passes. Findings appear in the fixed order:
/// size, dispatch time, workgroup memory, register pressure, barrier count.
#[must_use]
#[inline]
pub fn check_one(op: &OpSpec, measured: CostMetrics, budget: CostBudget) -> Vec<CostFinding> {
    let limits = [
        (CostMetric::SizeBytes, measured.size_bytes, budget.size_bytes_max),
        (CostMetric::DispatchTimeNs, measured.dispatch_time_ns, budget.dispatch_time_ns_max),
        (
            CostMetric::WorkgroupMemoryBytes,
            measured.workgroup_memory_bytes,
            budget.workgroup_memory_bytes_max,
        ),
        (CostMetric::RegisterPressure, measured.register_pressure, budget.register_pressure_max),
        (CostMetric::BarrierCount, measured.barrier_count, budget.barrier_count_max),
    ];

    limits
        .iter()
        .filter_map(|&(metric, value, limit)| {
            // An absent limit means "unbounded": nothing to report.
            let limit = limit?;
            (value > limit).then(|| CostFinding {
                op_id: op.id.to_string(),
                metric,
                measured: value,
                budget: limit,
            })
        })
        .collect()
}
/// Checks every observed op and concatenates the resulting findings.
///
/// `ops` and `observations` are paired by index. A `None` observation means
/// the op was not measured, and it is skipped. If the two slices differ in
/// length, `Iterator::zip` stops at the shorter one and the extra entries are
/// silently ignored.
/// NOTE(review): callers presumably always pass equal lengths; confirm.
#[must_use]
#[inline]
pub fn run(ops: &[OpSpec], observations: &[Option<(CostMetrics, CostBudget)>]) -> Vec<CostFinding> {
    ops.iter()
        .zip(observations.iter())
        .filter_map(|(op, observation)| {
            observation
                .as_ref()
                .map(|&(measured, budget)| check_one(op, measured, budget))
        })
        .flatten()
        .collect()
}
/// Enforcement gate for per-op cost certificates, exposed as [`REGISTERED`].
///
/// NOTE(review): `run` always hands an empty message list to
/// `finding_result`, so this gate forwards no cost findings itself. The
/// actual checks live in the free functions `check_one`/`run`. Wiring them
/// to `EnforceCtx` is presumably still TODO; confirm.
pub struct CostCertificateEnforcer;

impl CostCertificateEnforcer {
    /// Shared value for the gate's `id` and `name`.
    const GATE_ID: &'static str = "cost_certificate";
}

impl crate::enforce::EnforceGate for CostCertificateEnforcer {
    fn id(&self) -> &'static str {
        Self::GATE_ID
    }

    fn name(&self) -> &'static str {
        Self::GATE_ID
    }

    fn run(&self, _ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        // No messages are produced yet; the context is intentionally unused.
        crate::enforce::finding_result(self.id(), Vec::new())
    }
}

/// The gate instance exposed for registration.
pub const REGISTERED: CostCertificateEnforcer = CostCertificateEnforcer;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::types::conform::Strictness;
    use crate::spec::types::{DataType, OpSignature};
    use crate::spec::AlgebraicLaw;
    use vyre_spec::Category;

    /// Builds a minimal valid `OpSpec` that findings are attributed to.
    fn op() -> OpSpec {
        OpSpec::builder("test.cost.sample")
            .signature(OpSignature {
                inputs: vec![DataType::U32, DataType::U32],
                output: DataType::U32,
            })
            .cpu_fn(|i| i.to_vec())
            .wgsl_fn(|| "fn main() {}".to_string())
            .category(Category::A {
                composition_of: vec!["test.cost.sample"],
            })
            .laws(vec![AlgebraicLaw::Bounded {
                lo: 0,
                hi: u32::MAX,
            }])
            .strictness(Strictness::Strict)
            .version(1)
            .build()
            .expect("the test OpSpec fixture must build")
    }

    /// Budget shared by the per-metric tests.
    fn sample_budget() -> CostBudget {
        CostBudget {
            size_bytes_max: Some(200),
            dispatch_time_ns_max: Some(300),
            workgroup_memory_bytes_max: Some(400),
            register_pressure_max: Some(20),
            barrier_count_max: Some(3),
        }
    }

    /// A measurement sitting exactly on every limit of `sample_budget`.
    fn at_budget() -> CostMetrics {
        CostMetrics {
            size_bytes: 200,
            dispatch_time_ns: 300,
            workgroup_memory_bytes: 400,
            register_pressure: 20,
            barrier_count: 3,
        }
    }

    #[test]
    fn within_budget_produces_no_findings() {
        let measured = CostMetrics {
            size_bytes: 100,
            dispatch_time_ns: 100,
            workgroup_memory_bytes: 100,
            register_pressure: 10,
            barrier_count: 2,
        };
        let budget = CostBudget {
            size_bytes_max: Some(200),
            dispatch_time_ns_max: Some(200),
            workgroup_memory_bytes_max: Some(200),
            register_pressure_max: Some(20),
            barrier_count_max: Some(4),
        };
        let findings = check_one(&op(), measured, budget);
        assert!(findings.is_empty(), "{findings:?}");
    }

    #[test]
    fn measurement_equal_to_budget_passes() {
        let findings = check_one(&op(), at_budget(), sample_budget());
        assert!(findings.is_empty(), "{findings:?}");
    }

    #[test]
    fn each_metric_can_fire_independently() {
        // Exceed exactly one metric at a time; only that metric may be reported.
        let base = at_budget();
        let cases = [
            (
                CostMetric::SizeBytes,
                CostMetrics {
                    size_bytes: base.size_bytes + 1,
                    ..base
                },
            ),
            (
                CostMetric::DispatchTimeNs,
                CostMetrics {
                    dispatch_time_ns: base.dispatch_time_ns + 1,
                    ..base
                },
            ),
            (
                CostMetric::WorkgroupMemoryBytes,
                CostMetrics {
                    workgroup_memory_bytes: base.workgroup_memory_bytes + 1,
                    ..base
                },
            ),
            (
                CostMetric::RegisterPressure,
                CostMetrics {
                    register_pressure: base.register_pressure + 1,
                    ..base
                },
            ),
            (
                CostMetric::BarrierCount,
                CostMetrics {
                    barrier_count: base.barrier_count + 1,
                    ..base
                },
            ),
        ];
        for (expected, measured) in cases {
            let findings = check_one(&op(), measured, sample_budget());
            assert_eq!(findings.len(), 1, "{expected:?}: {findings:?}");
            assert_eq!(findings[0].metric, expected, "{findings:?}");
        }
    }

    #[test]
    fn all_metrics_can_fire_together() {
        let measured = CostMetrics {
            size_bytes: 300,
            dispatch_time_ns: 400,
            workgroup_memory_bytes: 500,
            register_pressure: 30,
            barrier_count: 5,
        };
        let findings = check_one(&op(), measured, sample_budget());
        let metrics: std::collections::BTreeSet<_> =
            findings.iter().map(|finding| finding.metric).collect();
        assert_eq!(metrics.len(), 5, "{findings:?}");
    }

    #[test]
    fn none_budget_field_is_ignored() {
        let measured = CostMetrics {
            size_bytes: u64::MAX,
            dispatch_time_ns: u64::MAX,
            workgroup_memory_bytes: u64::MAX,
            register_pressure: u64::MAX,
            barrier_count: u64::MAX,
        };
        let budget = CostBudget::INFINITE;
        let findings = check_one(&op(), measured, budget);
        assert!(findings.is_empty(), "{findings:?}");
    }

    #[test]
    fn enforce_skips_ops_with_no_observation() {
        let ops = [op()];
        let findings = run(&ops, &[None]);
        assert!(findings.is_empty());
    }

    #[test]
    fn enforce_reports_over_budget_for_observed_op() {
        let ops = [op()];
        let measured = CostMetrics {
            size_bytes: 1000,
            ..CostMetrics::ZERO
        };
        let budget = CostBudget {
            size_bytes_max: Some(500),
            ..CostBudget::INFINITE
        };
        let findings = run(&ops, &[Some((measured, budget))]);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].metric, CostMetric::SizeBytes);
    }

    #[test]
    fn from_baseline_with_default_ratio_applies_5_percent_slack() {
        let baseline = CostMetrics {
            size_bytes: 100,
            dispatch_time_ns: 200,
            workgroup_memory_bytes: 300,
            register_pressure: 10,
            barrier_count: 2,
        };
        let budget = CostBudget::from_baseline_with_default_ratio(baseline);
        assert_eq!(budget.size_bytes_max, Some(105));
        assert_eq!(budget.dispatch_time_ns_max, Some(210));
        assert_eq!(budget.workgroup_memory_bytes_max, Some(315));
        // 10 * 105 / 100 rounds down, so a small baseline gets no slack.
        assert_eq!(budget.register_pressure_max, Some(10));
        assert_eq!(budget.barrier_count_max, Some(3));
    }

    #[test]
    fn display_finding_is_actionable() {
        let finding = CostFinding {
            op_id: "test.op".to_string(),
            metric: CostMetric::SizeBytes,
            measured: 1000,
            budget: 500,
        };
        let rendered = format!("{finding}");
        assert!(rendered.contains("test.op"));
        assert!(rendered.contains("size_bytes"));
        assert!(rendered.contains("1000"));
        assert!(rendered.contains("500"));
        assert!(rendered.contains("Fix:"));
    }

    #[test]
    fn cost_metric_names_are_unique() {
        let names: std::collections::BTreeSet<_> = [
            CostMetric::SizeBytes,
            CostMetric::DispatchTimeNs,
            CostMetric::WorkgroupMemoryBytes,
            CostMetric::RegisterPressure,
            CostMetric::BarrierCount,
        ]
        .into_iter()
        .map(|metric| metric.name())
        .collect();
        assert_eq!(names.len(), 5);
    }
}