use vyre_foundation::optimizer::eqsat::ENodeLang;
use crate::device_profile::DeviceProfile;
/// Cost multiplier for nodes on a hot path (0.5x).
///
/// Float mirror of [`HOT_PATH_COST_SCALE_BPS`]; the cost function itself only
/// uses the basis-point value.
pub const HOT_PATH_COST_SCALE: f32 = 0.5;
/// Hot-path cost multiplier in basis points (10_000 bps = 1.0x).
pub const HOT_PATH_COST_SCALE_BPS: u32 = 5_000;
/// Cost multiplier for nodes on a cold path (1.5x).
///
/// Float mirror of [`COLD_PATH_COST_SCALE_BPS`]; the cost function itself only
/// uses the basis-point value.
pub const COLD_PATH_COST_SCALE: f32 = 1.5;
/// Cold-path cost multiplier in basis points (10_000 bps = 1.0x).
pub const COLD_PATH_COST_SCALE_BPS: u32 = 15_000;
/// Extra multiplier for fp16-eligible nodes on tensor-core + f16 devices (0.25x).
///
/// Float mirror of [`TENSOR_CORE_COST_SCALE_BPS`]; the cost function itself only
/// uses the basis-point value.
pub const TENSOR_CORE_COST_SCALE: f32 = 0.25;
/// Tensor-core cost multiplier in basis points (10_000 bps = 1.0x).
pub const TENSOR_CORE_COST_SCALE_BPS: u32 = 2_500;
/// Upper bound (4.0x) passed to the integer scaler when scaling a cost.
const MAX_SCALE_BPS: u32 = 40_000;

// Compile-time guard: every named scale must sit within the upper bound handed
// to the scaler, otherwise editing one of them could be silently clipped.
const _: () = assert!(
    HOT_PATH_COST_SCALE_BPS <= MAX_SCALE_BPS
        && COLD_PATH_COST_SCALE_BPS <= MAX_SCALE_BPS
        && TENSOR_CORE_COST_SCALE_BPS <= MAX_SCALE_BPS
);
/// Per-node facts consulted by [`device_aware_cost`] when pricing an e-node.
///
/// The default value (every flag `false`) requests no special treatment.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NodeHints {
    /// The node may run in half precision. On profiles reporting both tensor
    /// cores and f16 support, its cost gets the extra tensor-core factor.
    pub fp16_eligible: bool,
    /// Presumably marks a node whose value is fixed at compile time.
    /// NOTE(review): not read by `device_aware_cost` at present.
    pub compile_time_constant: bool,
}
/// Builds an extraction cost function that rescales `base_cost_fn` for a device.
///
/// Every node's base cost is multiplied by a path scale:
/// [`HOT_PATH_COST_SCALE_BPS`] when `hot` is true, otherwise
/// [`COLD_PATH_COST_SCALE_BPS`]. If `hint_lookup` reports the node as
/// `fp16_eligible` and `profile` supports both tensor cores and f16, the path
/// scale is further composed with [`TENSOR_CORE_COST_SCALE_BPS`]. All scaling
/// is done on integer basis points.
///
/// # Parameters
/// - `profile`: device capabilities; only `supports_tensor_cores` and
///   `supports_f16` are read, once, when the closure is built.
/// - `hot`: selects the hot-path or the cold-path multiplier.
/// - `base_cost_fn`: unscaled cost of a node.
/// - `hint_lookup`: per-node hints; called once per cost evaluation, after
///   `base_cost_fn`.
///
/// # Returns
/// A closure mapping a node to its scaled `u64` cost.
#[must_use]
pub fn device_aware_cost<L, B, H>(
    profile: &DeviceProfile,
    hot: bool,
    base_cost_fn: B,
    hint_lookup: H,
) -> impl Fn(&L) -> u64
where
    L: ENodeLang,
    B: Fn(&L) -> u64,
    H: Fn(&L) -> NodeHints,
{
    // Both factors depend only on the arguments, so resolve them up front
    // instead of once per node.
    let path_bps = match hot {
        true => HOT_PATH_COST_SCALE_BPS,
        false => COLD_PATH_COST_SCALE_BPS,
    };
    let has_fast_f16 = profile.supports_tensor_cores && profile.supports_f16;
    let fp16_bps = if has_fast_f16 {
        TENSOR_CORE_COST_SCALE_BPS
    } else {
        // The basis-point denominator is the neutral (1.0x) scale.
        crate::numeric::BASIS_POINTS_DENOMINATOR
    };
    move |node: &L| {
        let base = base_cost_fn(node);
        let effective_bps = if hint_lookup(node).fp16_eligible {
            compose_scale_basis_points(path_bps, fp16_bps)
        } else {
            path_bps
        };
        scale_cost_basis_points(base, effective_bps)
    }
}
/// Scales `base` by `scale_bps` basis points (10_000 bps = 1.0x) with the
/// crate's integer helper.
///
/// `MAX_SCALE_BPS` is passed as the upper bound on the scale, and `base` itself
/// as the fallback. This module's tests expect a 1_000_000 bps request on a
/// cost of 10 to yield 40, and a zero scale to return the base unchanged.
/// The two trailing strings are diagnostic labels, presumably used when the
/// helper reports a problem.
fn scale_cost_basis_points(base: u64, scale_bps: u32) -> u64 {
    let fallback = base;
    crate::numeric::scale_u64_by_basis_points_round_clamped(
        base,
        scale_bps,
        fallback,
        MAX_SCALE_BPS,
        "extraction cost",
        "driver",
    )
}
/// Combines two basis-point scales into a single scale.
///
/// The result is intended as their product (e.g. 0.5x and 0.25x → 0.125x).
/// The exact rounding is defined by `crate::numeric::compose_basis_points_u32`.
fn compose_scale_basis_points(left_bps: u32, right_bps: u32) -> u32 {
    let (context, component) = ("extraction cost scale composition", "driver");
    crate::numeric::compose_basis_points_u32(left_bps, right_bps, context, component)
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, ENodeLang};

    /// Minimal leaf-only language: a node's cost depends solely on its variant.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Const(u32),
        Heavy,
    }

    impl ENodeLang for Toy {
        fn children(&self) -> EChildren {
            EChildren::new()
        }
        fn with_children(&self, _children: &[vyre_foundation::optimizer::eqsat::EClassId]) -> Self {
            self.clone()
        }
    }

    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Const(_) => 1,
            Toy::Heavy => 100,
        }
    }

    fn no_hints(_: &Toy) -> NodeHints {
        NodeHints::default()
    }

    /// Marks every node fp16-eligible, so only the device profile decides
    /// whether the tensor-core factor applies.
    fn all_fp16_eligible(_: &Toy) -> NodeHints {
        NodeHints {
            fp16_eligible: true,
            compile_time_constant: false,
        }
    }

    #[test]
    fn cold_path_inflates_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, false, base_cost, no_hints);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, COLD_PATH_COST_SCALE_BPS)
        );
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, COLD_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn hot_path_shrinks_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, true, base_cost, no_hints);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn tensor_core_profile_scales_fp16_eligible_nodes() {
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = true;
        let mark_eligible = |node: &Toy| match node {
            Toy::Heavy => NodeHints {
                fp16_eligible: true,
                compile_time_constant: false,
            },
            _ => NodeHints::default(),
        };
        let cost = device_aware_cost(&profile, true, base_cost, mark_eligible);
        let expected = scale_cost_basis_points(
            100,
            compose_scale_basis_points(HOT_PATH_COST_SCALE_BPS, TENSOR_CORE_COST_SCALE_BPS),
        );
        assert_eq!(cost(&Toy::Heavy), expected);
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn cold_path_composes_with_tensor_core_scale() {
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = true;
        let cost = device_aware_cost(&profile, false, base_cost, all_fp16_eligible);
        let expected = scale_cost_basis_points(
            100,
            compose_scale_basis_points(COLD_PATH_COST_SCALE_BPS, TENSOR_CORE_COST_SCALE_BPS),
        );
        assert_eq!(cost(&Toy::Heavy), expected);
    }

    #[test]
    fn no_tensor_core_support_ignores_fp16_hint() {
        let profile = DeviceProfile::conservative("test");
        assert!(!profile.supports_tensor_cores);
        let cost = device_aware_cost(&profile, true, base_cost, all_fp16_eligible);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn tensor_cores_without_f16_ignore_fp16_hint() {
        // The tensor-core factor requires BOTH capabilities; pin the f16 arm.
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = false;
        let cost = device_aware_cost(&profile, true, base_cost, all_fp16_eligible);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn scale_cost_clamps_high_multiplier_basis_points() {
        assert_eq!(scale_cost_basis_points(10, 1_000_000), 40);
    }

    #[test]
    fn zero_basis_point_scale_preserves_invalid_scale_contract() {
        assert_eq!(scale_cost_basis_points(7, 0), 7);
    }

    #[test]
    fn float_scale_constants_mirror_basis_points() {
        let denominator = crate::numeric::BASIS_POINTS_DENOMINATOR as f32;
        let to_bps = |scale: f32| (scale * denominator) as u32;
        assert_eq!(to_bps(HOT_PATH_COST_SCALE), HOT_PATH_COST_SCALE_BPS);
        assert_eq!(to_bps(COLD_PATH_COST_SCALE), COLD_PATH_COST_SCALE_BPS);
        assert_eq!(to_bps(TENSOR_CORE_COST_SCALE), TENSOR_CORE_COST_SCALE_BPS);
    }

    #[test]
    fn extraction_cost_release_path_uses_integer_scaling() {
        let source = include_str!("extraction_cost.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: extraction-cost production source must precede tests");
        assert!(
            production.contains("scale_cost_basis_points")
                && production.contains("compose_scale_basis_points")
                && production.contains("crate::numeric::"),
            "Fix: extraction cost scaling must use deterministic integer basis-point arithmetic."
        );
        assert!(
            !production.contains("base as f32")
                && !production.contains("scaled.round()")
                && !production.contains("scale *= tensor_scale"),
            "Fix: extraction cost release path must not use lossy float scaling."
        );
    }

    #[test]
    fn deterministic_for_identical_profiles() {
        let p1 = DeviceProfile::conservative("a");
        let p2 = DeviceProfile::conservative("b");
        let c1 = device_aware_cost(&p1, false, base_cost, no_hints);
        let c2 = device_aware_cost(&p2, false, base_cost, no_hints);
        assert_eq!(c1(&Toy::Heavy), c2(&Toy::Heavy));
        assert_eq!(c1(&Toy::Const(7)), c2(&Toy::Const(7)));
    }
}