use vyre_foundation::optimizer::eqsat::ENodeLang;
use crate::device_profile::DeviceProfile;
/// Multiplier `device_aware_cost` applies to base costs when its `hot` flag is
/// `true`. Being below 1.0, it lowers the reported cost of every node.
pub const HOT_PATH_COST_SCALE: f32 = 0.5;
/// Multiplier `device_aware_cost` applies to base costs when its `hot` flag is
/// `false`. Being above 1.0, it raises the reported cost of every node.
pub const COLD_PATH_COST_SCALE: f32 = 1.5;
/// Extra multiplier `device_aware_cost` stacks on top of the hot/cold scale for
/// nodes hinted as `fp16_eligible`, and only when the device profile reports
/// both `supports_tensor_cores` and `supports_f16`.
pub const TENSOR_CORE_COST_SCALE: f32 = 0.25;
/// Per-node hints returned by the `hint_lookup` callback of `device_aware_cost`.
///
/// The derived `Default` yields a value with every flag set to `false`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NodeHints {
/// When `true`, `device_aware_cost` additionally multiplies the node's scale
/// by `TENSOR_CORE_COST_SCALE`, provided the device profile supports both
/// tensor cores and f16; on other profiles the flag has no effect.
pub fp16_eligible: bool,
/// NOTE(review): not read by any code visible in this file; presumably marks
/// nodes whose value is known at compile time — confirm against callers.
pub compile_time_constant: bool,
}
/// Wraps `base_cost_fn` in a cost function that is scaled for the target device
/// and for how hot the surrounding code path is.
///
/// * `profile` - device capabilities; only `supports_tensor_cores` and
///   `supports_f16` are read, and only while building the closure.
/// * `hot` - selects `HOT_PATH_COST_SCALE` when `true`, otherwise
///   `COLD_PATH_COST_SCALE`.
/// * `base_cost_fn` - the unscaled cost of a node.
/// * `hint_lookup` - per-node hints; only `fp16_eligible` is consulted here.
///
/// The returned closure calls `base_cost_fn` and then `hint_lookup` for each
/// node. Nodes hinted as `fp16_eligible` are additionally multiplied by
/// `TENSOR_CORE_COST_SCALE` when the profile supports both tensor cores and
/// f16. The final multiplication and rounding are delegated to `scale_cost`.
#[must_use]
pub fn device_aware_cost<L, B, H>(
    profile: &DeviceProfile,
    hot: bool,
    base_cost_fn: B,
    hint_lookup: H,
) -> impl Fn(&L) -> u64
where
    L: ENodeLang,
    B: Fn(&L) -> u64,
    H: Fn(&L) -> NodeHints,
{
    // Both candidate factors are fixed once the profile and path temperature
    // are known, so resolve them up front rather than per node.
    let tensor_path_available = profile.supports_tensor_cores && profile.supports_f16;
    let plain_scale = if hot {
        HOT_PATH_COST_SCALE
    } else {
        COLD_PATH_COST_SCALE
    };
    let fp16_scale = if tensor_path_available {
        plain_scale * TENSOR_CORE_COST_SCALE
    } else {
        plain_scale * 1.0
    };

    move |node: &L| {
        let base = base_cost_fn(node);
        let factor = if hint_lookup(node).fp16_eligible {
            fp16_scale
        } else {
            plain_scale
        };
        scale_cost(base, factor)
    }
}
/// Multiplies `base` by `scale` and rounds the product to the nearest integer
/// cost (halves round away from zero).
///
/// * A `scale` that is NaN, infinite, zero, or negative is ignored and `base`
///   is returned unchanged.
/// * `scale` is capped at 4.0 so a runaway multiplier cannot dominate costs.
/// * A product beyond `u64::MAX` saturates to `u64::MAX`.
///
/// The arithmetic is carried out in `f64`: an `f32` has only a 24-bit
/// significand, so it would silently round any base cost above 2^24 even when
/// the scale is exactly 1.0. `f64` keeps integers exact up to 2^53.
fn scale_cost(base: u64, scale: f32) -> u64 {
    if !scale.is_finite() || scale <= 0.0 {
        return base;
    }
    let factor = f64::from(scale.min(4.0));
    // `base as f64` is at most ~1.8e19 and `factor` is in (0, 4], so the
    // product is always finite and non-negative; no further guards are needed.
    let scaled = (base as f64 * factor).round();
    // Float-to-integer `as` casts saturate at the target type's bounds.
    scaled as u64
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, EClassId, ENodeLang};

    /// Tiny two-variant language used to drive the cost function; every node
    /// reports `EChildren::new()` as its children.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Const(u32),
        Heavy,
    }

    impl ENodeLang for Toy {
        fn children(&self) -> EChildren {
            EChildren::new()
        }

        fn with_children(&self, _children: &[EClassId]) -> Self {
            self.clone()
        }
    }

    /// Unscaled cost: `Heavy` is 100, any `Const` is 1.
    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Heavy => 100,
            Toy::Const(_) => 1,
        }
    }

    /// Hint lookup that never flags anything.
    fn no_hints(_: &Toy) -> NodeHints {
        NodeHints {
            fp16_eligible: false,
            compile_time_constant: false,
        }
    }

    /// Expected cost for `base` under `scale`, computed the way the assertions
    /// reason about it: multiply in `f32`, round, convert to `u64`.
    fn rounded(base: f32, scale: f32) -> u64 {
        (base * scale).round() as u64
    }

    #[test]
    fn cold_path_inflates_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, false, base_cost, no_hints);

        assert_eq!(cost(&Toy::Heavy), rounded(100.0, COLD_PATH_COST_SCALE));
        assert_eq!(cost(&Toy::Const(0)), rounded(1.0, COLD_PATH_COST_SCALE));
    }

    #[test]
    fn hot_path_shrinks_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, true, base_cost, no_hints);

        assert_eq!(cost(&Toy::Heavy), rounded(100.0, HOT_PATH_COST_SCALE));
        assert_eq!(cost(&Toy::Const(0)), rounded(1.0, HOT_PATH_COST_SCALE));
    }

    #[test]
    fn tensor_core_profile_scales_fp16_eligible_nodes() {
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = true;

        // Only `Heavy` is flagged, so `Const` must keep the plain hot-path scale.
        let heavy_only = |node: &Toy| NodeHints {
            fp16_eligible: matches!(node, Toy::Heavy),
            compile_time_constant: false,
        };
        let cost = device_aware_cost(&profile, true, base_cost, heavy_only);

        assert_eq!(
            cost(&Toy::Heavy),
            rounded(100.0, HOT_PATH_COST_SCALE * TENSOR_CORE_COST_SCALE)
        );
        assert_eq!(cost(&Toy::Const(0)), rounded(1.0, HOT_PATH_COST_SCALE));
    }

    #[test]
    fn no_tensor_core_support_ignores_fp16_hint() {
        let profile = DeviceProfile::conservative("test");
        assert!(!profile.supports_tensor_cores);

        let always_eligible = |_: &Toy| NodeHints {
            fp16_eligible: true,
            ..NodeHints::default()
        };
        let cost = device_aware_cost(&profile, true, base_cost, always_eligible);

        assert_eq!(cost(&Toy::Heavy), rounded(100.0, HOT_PATH_COST_SCALE));
    }

    #[test]
    fn scale_cost_clamps_high_multiplier() {
        assert_eq!(scale_cost(10, 100.0), 40);
    }

    #[test]
    fn scale_cost_falls_back_on_nan_or_negative() {
        for unusable in [f32::NAN, -1.0, 0.0] {
            assert_eq!(scale_cost(7, unusable), 7);
        }
    }

    #[test]
    fn deterministic_for_identical_profiles() {
        let first_profile = DeviceProfile::conservative("a");
        let second_profile = DeviceProfile::conservative("b");
        let first = device_aware_cost(&first_profile, false, base_cost, no_hints);
        let second = device_aware_cost(&second_profile, false, base_cost, no_hints);

        assert_eq!(first(&Toy::Heavy), second(&Toy::Heavy));
        assert_eq!(first(&Toy::Const(7)), second(&Toy::Const(7)));
    }
}