use smallvec::SmallVec;
use vyre_foundation::optimizer::eqsat::{extract_best, EClassId, EGraph, ENodeLang};
use crate::autotune_store::AutotuneRecord;
use crate::device_profile::DeviceProfile;
use crate::extraction_cost::{device_aware_cost, NodeHints};
use crate::trace_jit_policy::{decide_trace_jit_speculation, TraceJitDecision, TraceJitInputs};
/// Per-device context threaded through e-graph extraction.
///
/// Bundles the static capability profile with optional runtime signals
/// (autotune results, trace-JIT counters) that bias node costs during
/// extraction. Cheap to copy: two references, a small `Copy` payload,
/// and a flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractionDevice<'a> {
    /// Capability/cost profile of the target device.
    pub profile: &'a DeviceProfile,
    /// Best-known autotune parameters recorded for this device, if any.
    pub autotune_record: Option<&'a AutotuneRecord>,
    /// Trace-JIT speculation counters, if trace-JIT is active for this device.
    pub trace_jit: Option<TraceJitInputs>,
    /// Whether the extraction target sits on a hot execution path.
    pub hot_path: bool,
}
impl<'a> ExtractionDevice<'a> {
    /// Creates a device context carrying only the static profile and the
    /// hot-path flag; runtime signals start out absent.
    #[must_use]
    pub const fn new(profile: &'a DeviceProfile, hot_path: bool) -> Self {
        Self {
            autotune_record: None,
            trace_jit: None,
            profile,
            hot_path,
        }
    }

    /// Builder step: attaches trace-JIT speculation counters.
    #[must_use]
    pub const fn with_trace_jit(mut self, counters: TraceJitInputs) -> Self {
        self.trace_jit = Some(counters);
        self
    }

    /// Builder step: attaches a recorded autotune result.
    #[must_use]
    pub const fn with_autotune_record(mut self, record: &'a AutotuneRecord) -> Self {
        self.autotune_record = Some(record);
        self
    }
}
/// Result of extracting the best node for one device from a saturated e-graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceExtraction<L> {
    /// Backend identifier copied from the device profile that won extraction.
    pub backend: &'static str,
    /// Hot-path flag the extraction was performed with.
    pub hot_path: bool,
    /// The cheapest equivalent node chosen for this device.
    pub node: L,
    /// Final biased cost of `node` under this device's cost model.
    pub cost: u64,
}
/// Extracts the cheapest equivalent node for `root` on a single device.
///
/// Costs start from `base_cost`, are adjusted for the device profile via
/// `device_aware_cost`, and finally biased by runtime context (autotune
/// record, trace-JIT counters) through `extraction_bias_bps`.
///
/// Returns `None` when `root` does not name a class in `egraph` or when
/// nothing can be extracted from the root class.
#[must_use]
pub fn extract_best_for_device<L, B, H>(
    egraph: &EGraph<L>,
    root: EClassId,
    device: ExtractionDevice<'_>,
    base_cost: B,
    hint_lookup: H,
) -> Option<DeviceExtraction<L>>
where
    L: ENodeLang,
    B: Fn(&L) -> u64,
    H: Fn(&L) -> NodeHints,
{
    // Guard: an id past the class table cannot be extracted.
    let root_in_range = (root.0 as usize) < egraph.class_count();
    if !root_in_range {
        return None;
    }
    let profile_cost = device_aware_cost(device.profile, device.hot_path, &base_cost, &hint_lookup);
    // Per-node cost: profile-adjusted base cost scaled by the context bias.
    let biased_cost = |node: &L| {
        let bias_bps = extraction_bias_bps(device, hint_lookup(node));
        apply_context_bias(profile_cost(node), bias_bps)
    };
    let (node, cost) = extract_best(egraph, root, biased_cost)?;
    Some(DeviceExtraction {
        backend: device.profile.backend,
        hot_path: device.hot_path,
        node,
        cost,
    })
}
/// Extracts one best variant per device from the same saturated e-graph.
///
/// Devices whose extraction fails (e.g. an out-of-range root) are simply
/// skipped, so the output may be shorter than the input. Results keep the
/// device iteration order.
#[must_use]
pub fn extract_best_for_devices<'a, L, B, H>(
    egraph: &EGraph<L>,
    root: EClassId,
    devices: impl IntoIterator<Item = ExtractionDevice<'a>>,
    base_cost: B,
    hint_lookup: H,
) -> SmallVec<[DeviceExtraction<L>; 4]>
where
    L: ENodeLang,
    B: Fn(&L) -> u64,
    H: Fn(&L) -> NodeHints,
{
    devices
        .into_iter()
        .filter_map(|device| {
            extract_best_for_device(egraph, root, device, &base_cost, &hint_lookup)
        })
        .collect()
}
/// Computes the context bias for a node in basis points (10_000 = neutral).
///
/// Each matching runtime signal multiplies the bias down by a fixed factor.
/// Factors are applied in a deterministic order because the integer division
/// rounds at every step. The result never reaches 0.
fn extraction_bias_bps(device: ExtractionDevice<'_>, hints: NodeHints) -> u32 {
    // Multiply `bps` by `factor_bps / 10_000` in integer arithmetic.
    let scale = |bps: u32, factor_bps: u32| bps.saturating_mul(factor_bps) / 10_000;
    let mut bps = 10_000u32;
    if let Some(record) = device.autotune_record {
        // Tuned unrolling favors compile-time-constant variants (0.8x).
        if hints.compile_time_constant && record.unroll > 1 {
            bps = scale(bps, 8_000);
        }
        // A non-trivial tuned tile favors fp16-eligible variants (0.95x).
        if hints.fp16_eligible && record.tile.iter().any(|dim| *dim > 1) {
            bps = scale(bps, 9_500);
        }
    }
    if hints.compile_time_constant {
        if let Some(counters) = device.trace_jit {
            // Profitable trace-JIT speculation favors constant variants (0.7x).
            let speculates = matches!(
                decide_trace_jit_speculation(counters),
                TraceJitDecision::Speculate { .. }
            );
            if speculates {
                bps = scale(bps, 7_000);
            }
        }
    }
    // Floor at 1 so a bias can never zero a cost outright.
    bps.max(1)
}
/// Scales `cost` by `bps` basis points (10_000 = unchanged).
///
/// A neutral or above-neutral bias returns the cost untouched. A discount
/// is applied in `u128` so the multiply cannot overflow, and the biased
/// result is floored at 1 so a discount never erases a cost entirely.
fn apply_context_bias(cost: u64, bps: u32) -> u64 {
    if bps >= 10_000 {
        // Neutral bias: nothing to do.
        return cost;
    }
    // Widen before multiplying; u64::MAX * 9_999 still fits in u128.
    let scaled = u128::from(cost).saturating_mul(u128::from(bps)) / 10_000u128;
    match u64::try_from(scaled) {
        // Floor the discounted value at 1.
        Ok(value) => value.max(1),
        // Unreachable in practice (scaled <= cost), kept for total safety.
        Err(_) => u64::MAX,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, EGraph};

    /// Minimal leaf-only toy language: three interchangeable "implementations"
    /// of the same logical operation, distinguished only by cost and hints.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Scalar,
        TensorCore,
        Specialized,
    }

    impl ENodeLang for Toy {
        fn children(&self) -> EChildren {
            // All variants are leaves: no child e-classes.
            EChildren::new()
        }
        fn with_children(&self, _children: &[EClassId]) -> Self {
            // Leaves ignore the (empty) child list.
            self.clone()
        }
    }

    /// Device-agnostic base cost per variant. Scalar is cheapest so any
    /// device-specific win must come from profile adjustment or bias.
    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Scalar => 10,
            Toy::TensorCore => 30,
            Toy::Specialized => 11,
        }
    }

    /// Hint table: TensorCore is fp16-eligible, Specialized is a
    /// compile-time-constant candidate, Scalar has no hints.
    fn hints(node: &Toy) -> NodeHints {
        match node {
            Toy::TensorCore => NodeHints {
                fp16_eligible: true,
                compile_time_constant: false,
            },
            Toy::Specialized => NodeHints {
                fp16_eligible: false,
                compile_time_constant: true,
            },
            Toy::Scalar => NodeHints::default(),
        }
    }

    /// Graph with Scalar and TensorCore merged into one equivalence class.
    fn equivalent_toy_graph() -> (EGraph<Toy>, EClassId) {
        let mut graph = EGraph::new();
        let scalar = graph.add(Toy::Scalar);
        let tensor = graph.add(Toy::TensorCore);
        graph.union(scalar, tensor);
        graph.rebuild();
        (graph, scalar)
    }

    /// Graph with Scalar and Specialized merged into one equivalence class.
    fn specialized_toy_graph() -> (EGraph<Toy>, EClassId) {
        let mut graph = EGraph::new();
        let scalar = graph.add(Toy::Scalar);
        let specialized = graph.add(Toy::Specialized);
        graph.union(scalar, specialized);
        graph.rebuild();
        (graph, scalar)
    }

    #[test]
    fn conservative_profile_extracts_scalar_variant() {
        let (graph, root) = equivalent_toy_graph();
        let profile = DeviceProfile::conservative("portable");
        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true),
            base_cost,
            hints,
        )
        .expect("equivalent toy graph must extract");
        assert_eq!(extracted.backend, "portable");
        // A conservative profile gets no fp16/tensor-core wins: Scalar stays cheapest.
        assert_eq!(extracted.node, Toy::Scalar);
        // Expected cost comes from the profile adjustment of base cost 10 on a
        // hot path — presumably a 0.5x hot-path factor; verify against
        // `device_aware_cost` if this changes.
        assert_eq!(extracted.cost, 5);
    }

    #[test]
    fn tensor_core_profile_extracts_fp16_variant() {
        let (graph, root) = equivalent_toy_graph();
        let mut profile = DeviceProfile::conservative("native");
        // Enable the hardware features that make TensorCore worthwhile.
        profile.supports_f16 = true;
        profile.supports_tensor_cores = true;
        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true),
            base_cost,
            hints,
        )
        .expect("equivalent toy graph must extract");
        assert_eq!(extracted.backend, "native");
        // With f16 + tensor cores, the fp16-eligible variant beats Scalar
        // despite its higher base cost (30 vs 10).
        assert_eq!(extracted.node, Toy::TensorCore);
        assert_eq!(extracted.cost, 4);
    }

    #[test]
    fn several_devices_extract_from_one_saturated_graph() {
        let (graph, root) = equivalent_toy_graph();
        let portable = DeviceProfile::conservative("portable");
        let mut native = DeviceProfile::conservative("native");
        native.supports_f16 = true;
        native.supports_tensor_cores = true;
        let variants = extract_best_for_devices(
            &graph,
            root,
            [
                ExtractionDevice::new(&portable, true),
                ExtractionDevice::new(&native, true),
            ],
            base_cost,
            hints,
        );
        // One variant per device, in device order: each picks its own winner.
        assert_eq!(variants.len(), 2);
        assert_eq!(variants[0].node, Toy::Scalar);
        assert_eq!(variants[1].node, Toy::TensorCore);
    }

    #[test]
    fn autotune_record_biases_compile_time_constant_variant() {
        let (graph, root) = specialized_toy_graph();
        let profile = DeviceProfile::conservative("native");
        // unroll > 1 triggers the 0.8x bias for compile-time-constant nodes;
        // tile is all zeros so the fp16 tile bias stays inactive.
        let record = AutotuneRecord {
            workgroup_size: [128, 1, 1],
            unroll: 4,
            tile: [0, 0, 0],
            recorded_at: String::new(),
        };
        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true).with_autotune_record(&record),
            base_cost,
            hints,
        )
        .expect("equivalent toy graph must extract");
        // The bias tips Specialized (base 11) below Scalar (base 10).
        assert_eq!(extracted.node, Toy::Specialized);
        assert_eq!(extracted.cost, 4);
    }

    #[test]
    fn trace_jit_biases_specialized_variant_when_speculation_pays() {
        let (graph, root) = specialized_toy_graph();
        let profile = DeviceProfile::conservative("native");
        // Counters chosen so speculation is clearly profitable: full
        // confidence, tiny speculation cost, huge miss cost.
        let counters = TraceJitInputs {
            shader_hit_count: 64,
            prediction_confidence_bps: 10_000,
            speculative_spec_cost_ns: 1,
            miss_cost_ns: 1_000_000,
        };
        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true).with_trace_jit(counters),
            base_cost,
            hints,
        )
        .expect("equivalent toy graph must extract");
        // The 0.7x speculation bias tips Specialized below Scalar.
        assert_eq!(extracted.node, Toy::Specialized);
        assert_eq!(extracted.cost, 4);
    }

    #[test]
    fn missing_root_returns_no_variant() {
        // An id past the (empty) class table must be rejected, not panic.
        let graph: EGraph<Toy> = EGraph::new();
        let profile = DeviceProfile::conservative("portable");
        let variants = extract_best_for_devices(
            &graph,
            EClassId(77),
            [ExtractionDevice::new(&profile, true)],
            base_cost,
            hints,
        );
        assert!(variants.is_empty());
    }
}