use vyre_foundation::ir::{BinOp, Expr};
use vyre_foundation::optimizer::passes::algebraic::precision_hint::{
PrecisionHint, TranscendentalOp,
};
/// Capability flags describing what a lowering backend supports.
///
/// A reference to this struct is handed to [`LoweringStrategy::can_apply`] and
/// [`select_precision_lowering`], which gate their decisions on it. The derived
/// `Default` yields `false` for every flag and `0` for `max_native_int_width`.
#[derive(Debug, Clone, Default)]
pub struct BackendCapabilities {
/// NOTE(review): presumably set when the backend has a native multiply-high
/// operation; not read by the non-test code visible here — confirm.
pub has_mul_high: bool,
/// NOTE(review): the name suggests FP32 and INT32 work can be issued
/// together; not read by the code visible here — confirm.
pub has_dual_issue_fp32_int32: bool,
/// NOTE(review): the name suggests integer tensor-core support; not read by
/// the code visible here — confirm.
pub has_tensor_core_int: bool,
/// When `true`, [`select_precision_lowering`] answers an `F16Eligible` hint
/// with [`PrecisionLoweringPlan::NativeF16`].
pub has_native_f16: bool,
/// NOTE(review): the name suggests warp-shuffle support; not read by the
/// code visible here — confirm.
pub has_warp_shuffle: bool,
/// NOTE(review): the name suggests shared-memory support; not read by the
/// code visible here — confirm.
pub has_shared_memory: bool,
/// When `true`, [`select_precision_lowering`] answers a
/// `TranscendentalPolynomial` hint with
/// [`PrecisionLoweringPlan::PolynomialTranscendental`].
pub has_transcendental_polynomial_emit: bool,
/// NOTE(review): presumably the widest natively supported integer, in bits;
/// not read by the code visible here — confirm the unit.
pub max_native_int_width: u32,
}
/// Value returned by [`LoweringStrategy::lower`].
#[derive(Debug, Clone)]
pub enum LoweredExpr {
/// The strategy produced an expression, carried in the payload.
Expr(Expr),
/// No expression is returned.
/// NOTE(review): presumably the strategy already emitted the lowered form
/// itself — confirm against the backends that return this variant.
Emitted,
}
/// One way of lowering a binary operation, selectable via [`select_strategy`].
///
/// Implementors must be `Send`, `Sync` and `Debug`.
pub trait LoweringStrategy: Send + Sync + std::fmt::Debug {
/// Identifier of the strategy as a string slice.
fn name(&self) -> &str;
/// Reports whether this strategy can handle `op` on a backend with
/// capabilities `caps`. [`select_strategy`] ignores strategies that return
/// `false`.
fn can_apply(&self, caps: &BackendCapabilities, op: &BinOp) -> bool;
/// Rank used by [`select_strategy`]: among applicable strategies the one
/// with the largest value is chosen.
fn priority(&self) -> u32;
/// Lowers `op` applied to the operands `left` and `right`.
fn lower(&self, op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr;
}
/// Picks the applicable strategy with the greatest priority.
///
/// Each entry of `strategies` is asked, in slice order, whether it can handle
/// `op` on a backend with capabilities `caps`; entries that decline are
/// ignored. Among the remaining ones the entry reporting the largest
/// [`LoweringStrategy::priority`] wins, and when several share that value the
/// one latest in the slice is returned.
///
/// Returns `None` when no entry applies, which includes an empty slice.
pub fn select_strategy<'a>(
    strategies: &'a [Box<dyn LoweringStrategy>],
    caps: &BackendCapabilities,
    op: &BinOp,
) -> Option<&'a dyn LoweringStrategy> {
    // Best candidate so far, paired with its priority so that `priority()` is
    // queried exactly once per applicable strategy.
    let mut winner = None;
    for boxed in strategies {
        if !boxed.can_apply(caps, op) {
            continue;
        }
        let rank = boxed.priority();
        // `>=` lets a later entry displace an earlier one of equal priority.
        let displaces = match winner {
            Some((best_rank, _)) => rank >= best_rank,
            None => true,
        };
        if displaces {
            winner = Some((rank, boxed));
        }
    }
    winner.map(|(_, boxed)| boxed.as_ref())
}
/// Plan returned by [`select_precision_lowering`] for a precision hint.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PrecisionLoweringPlan {
/// Fallback returned whenever the hint/capability combination selects
/// neither of the other plans.
DefaultF32,
/// Returned for an `F16Eligible` hint when the backend reports
/// `has_native_f16`.
NativeF16 {
/// Copied unchanged from the hint's `max_abs_operand`.
max_abs_operand: f32,
},
/// Returned for a `TranscendentalPolynomial` hint when the backend reports
/// `has_transcendental_polynomial_emit`.
PolynomialTranscendental {
/// Transcendental operation, copied from the hint.
op: TranscendentalOp,
/// Copied unchanged from the hint's `argument_bound`.
argument_bound: f32,
/// Polynomial degree derived from `op` and `argument_bound` by
/// `polynomial_degree_for`.
degree: u8,
},
}
/// Turns a precision hint into a lowering plan the backend can honour.
///
/// * An `F16Eligible` hint becomes [`PrecisionLoweringPlan::NativeF16`] when
///   `caps.has_native_f16` is set; `max_abs_operand` is carried over as is.
/// * A `TranscendentalPolynomial` hint becomes
///   [`PrecisionLoweringPlan::PolynomialTranscendental`] when
///   `caps.has_transcendental_polynomial_emit` is set; the degree comes from
///   `polynomial_degree_for`.
/// * Every other combination yields [`PrecisionLoweringPlan::DefaultF32`].
#[must_use]
pub fn select_precision_lowering(
    caps: &BackendCapabilities,
    hint: &PrecisionHint,
) -> PrecisionLoweringPlan {
    if caps.has_native_f16 {
        if let PrecisionHint::F16Eligible { max_abs_operand } = *hint {
            return PrecisionLoweringPlan::NativeF16 { max_abs_operand };
        }
    }

    if caps.has_transcendental_polynomial_emit {
        if let PrecisionHint::TranscendentalPolynomial { op, argument_bound } = *hint {
            let degree = polynomial_degree_for(op, argument_bound);
            return PrecisionLoweringPlan::PolynomialTranscendental {
                op,
                argument_bound,
                degree,
            };
        }
    }

    // Either the hint is of another kind or the backend lacks the capability.
    PrecisionLoweringPlan::DefaultF32
}
/// Chooses the polynomial degree for `op` given a bound on its argument.
///
/// `Sin` and `Cos` use a lower degree (3 and 4) when `argument_bound` is at
/// most 0.25 and a higher one (5 and 6) otherwise. `Exp` and `Ln` always use
/// degree 5. A NaN bound fails the comparison and therefore takes the higher
/// degree.
fn polynomial_degree_for(op: TranscendentalOp, argument_bound: f32) -> u8 {
    /// Largest bound (inclusive) that still gets the lower degree.
    const LOW_DEGREE_BOUND: f32 = 0.25;

    let within_low_bound = argument_bound <= LOW_DEGREE_BOUND;
    match (op, within_low_bound) {
        (TranscendentalOp::Sin, true) => 3,
        (TranscendentalOp::Sin, false) => 5,
        (TranscendentalOp::Cos, true) => 4,
        (TranscendentalOp::Cos, false) => 6,
        (TranscendentalOp::Exp | TranscendentalOp::Ln, _) => 5,
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for strategy selection and precision-plan selection.
    use super::*;

    /// Mock strategy that requires `has_mul_high`, only accepts
    /// `BinOp::MulHigh`, and outranks `MockFallbackStrategy`.
    #[derive(Debug)]
    struct MockNativeStrategy;

    impl LoweringStrategy for MockNativeStrategy {
        fn name(&self) -> &str {
            "mock-native"
        }

        fn can_apply(&self, caps: &BackendCapabilities, op: &BinOp) -> bool {
            caps.has_mul_high && matches!(op, BinOp::MulHigh)
        }

        fn priority(&self) -> u32 {
            100
        }

        fn lower(&self, _op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr {
            LoweredExpr::Expr(Expr::mulhi(left.clone(), right.clone()))
        }
    }

    /// Mock strategy that accepts `BinOp::MulHigh` on any backend with a low
    /// priority.
    #[derive(Debug)]
    struct MockFallbackStrategy;

    impl LoweringStrategy for MockFallbackStrategy {
        fn name(&self) -> &str {
            "mock-fallback"
        }

        fn can_apply(&self, _caps: &BackendCapabilities, op: &BinOp) -> bool {
            matches!(op, BinOp::MulHigh)
        }

        fn priority(&self) -> u32 {
            10
        }

        fn lower(&self, _op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr {
            LoweredExpr::Expr(Expr::mul(left.clone(), right.clone()))
        }
    }

    /// Both mocks, with the fallback listed first so that slice order alone
    /// cannot explain the native strategy being chosen.
    fn mock_strategies() -> Vec<Box<dyn LoweringStrategy>> {
        vec![Box::new(MockFallbackStrategy), Box::new(MockNativeStrategy)]
    }

    /// Capabilities with only `has_transcendental_polynomial_emit` enabled.
    fn polynomial_caps() -> BackendCapabilities {
        BackendCapabilities {
            has_transcendental_polynomial_emit: true,
            ..Default::default()
        }
    }

    /// Degree chosen through the public entry point for `op` and
    /// `argument_bound` on a backend that supports polynomial emission.
    fn degree_for(op: TranscendentalOp, argument_bound: f32) -> u8 {
        let plan = select_precision_lowering(
            &polynomial_caps(),
            &PrecisionHint::TranscendentalPolynomial { op, argument_bound },
        );
        match plan {
            PrecisionLoweringPlan::PolynomialTranscendental { degree, .. } => degree,
            other => panic!("expected a polynomial plan, got {:?}", other),
        }
    }

    #[test]
    fn selects_highest_priority() {
        let strategies = mock_strategies();
        let caps = BackendCapabilities {
            has_mul_high: true,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::MulHigh);
        assert_eq!(selected.unwrap().name(), "mock-native");
    }

    #[test]
    fn falls_back_when_native_unavailable() {
        let strategies = mock_strategies();
        let caps = BackendCapabilities {
            has_mul_high: false,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::MulHigh);
        assert_eq!(selected.unwrap().name(), "mock-fallback");
    }

    #[test]
    fn returns_none_for_unsupported_op() {
        let strategies: Vec<Box<dyn LoweringStrategy>> = vec![Box::new(MockNativeStrategy)];
        let caps = BackendCapabilities {
            has_mul_high: true,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::Add);
        assert!(selected.is_none());
    }

    #[test]
    fn returns_none_for_empty_strategy_list() {
        let strategies: Vec<Box<dyn LoweringStrategy>> = Vec::new();
        let caps = BackendCapabilities {
            has_mul_high: true,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::MulHigh);
        assert!(selected.is_none());
    }

    #[test]
    fn precision_hint_selects_native_f16_when_supported() {
        let caps = BackendCapabilities {
            has_native_f16: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::F16Eligible {
                max_abs_operand: 4.0,
            },
        );
        assert_eq!(
            plan,
            PrecisionLoweringPlan::NativeF16 {
                max_abs_operand: 4.0
            }
        );
    }

    #[test]
    fn precision_hint_keeps_f32_without_native_f16() {
        let plan = select_precision_lowering(
            &BackendCapabilities::default(),
            &PrecisionHint::F16Eligible {
                max_abs_operand: 4.0,
            },
        );
        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn transcendental_hint_selects_polynomial_when_supported() {
        let plan = select_precision_lowering(
            &polynomial_caps(),
            &PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.2,
            },
        );
        assert_eq!(
            plan,
            PrecisionLoweringPlan::PolynomialTranscendental {
                op: TranscendentalOp::Sin,
                argument_bound: 0.2,
                degree: 3,
            }
        );
    }

    #[test]
    fn transcendental_hint_uses_higher_degree_for_wider_sin_range() {
        let plan = select_precision_lowering(
            &polynomial_caps(),
            &PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.75,
            },
        );
        assert_eq!(
            plan,
            PrecisionLoweringPlan::PolynomialTranscendental {
                op: TranscendentalOp::Sin,
                argument_bound: 0.75,
                degree: 5,
            }
        );
    }

    #[test]
    fn transcendental_hint_keeps_f32_without_polynomial_emit() {
        // Native f16 support alone must not unlock the polynomial plan.
        let caps = BackendCapabilities {
            has_native_f16: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.2,
            },
        );
        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn f16_hint_keeps_f32_when_only_polynomial_emit_is_supported() {
        let plan = select_precision_lowering(
            &polynomial_caps(),
            &PrecisionHint::F16Eligible {
                max_abs_operand: 4.0,
            },
        );
        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn sin_and_cos_low_degree_bound_is_inclusive() {
        assert_eq!(degree_for(TranscendentalOp::Sin, 0.25), 3);
        assert_eq!(degree_for(TranscendentalOp::Cos, 0.25), 4);
    }

    #[test]
    fn cos_degree_tracks_argument_bound() {
        assert_eq!(degree_for(TranscendentalOp::Cos, 0.2), 4);
        assert_eq!(degree_for(TranscendentalOp::Cos, 0.75), 6);
    }

    #[test]
    fn exp_and_ln_use_fixed_degree() {
        for bound in [0.1_f32, 0.25, 10.0] {
            assert_eq!(degree_for(TranscendentalOp::Exp, bound), 5);
            assert_eq!(degree_for(TranscendentalOp::Ln, bound), 5);
        }
    }
}