use crate::ir::{Expr, Node, Program, UnOp};
use rustc_hash::FxHashMap;
use std::sync::{Arc, RwLock};
/// A lowering hint attached to a single expression by [`analyse_precision`].
///
/// Hints are stored in [`PrecisionHints`], keyed by the expression's digest.
#[derive(Debug, Clone, PartialEq)]
pub enum PrecisionHint {
    /// The expression is built only from `f32` literals and its largest literal
    /// magnitude passed the analysis' f16 range check.
    F16Eligible {
        /// Largest absolute literal value found in the expression.
        max_abs_operand: f32,
    },
    /// A unary transcendental applied directly to an `f32` literal that lies in
    /// the interval the analysis accepts for that function.
    /// NOTE(review): presumably consumed by a lowering that substitutes a
    /// polynomial approximation — confirm with the consumer.
    TranscendentalPolynomial {
        /// Which transcendental function the expression applies.
        op: TranscendentalOp,
        /// Absolute value of the literal argument.
        argument_bound: f32,
    },
}
/// Transcendental functions whose literal arguments the analysis can bound.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TranscendentalOp {
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Natural exponential.
    Exp,
    /// Logarithm; the IR's `UnOp::Log` is mapped onto this variant.
    Ln,
}
/// 32-byte key identifying an expression in `PrecisionHints`.
///
/// In this module it is produced by `digest_of`, which BLAKE3-hashes the
/// expression's `Debug` text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExprDigest(pub [u8; 32]);
/// Thread-safe table of precision hints keyed by expression digest.
///
/// The table lives behind an `Arc<RwLock<..>>`, so `clone()` produces a second
/// handle onto the *same* table rather than a copy of it.
#[derive(Debug, Default, Clone)]
pub struct PrecisionHints {
    inner: Arc<RwLock<FxHashMap<ExprDigest, PrecisionHint>>>,
}
impl PrecisionHints {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn record(&self, digest: ExprDigest, hint: PrecisionHint) {
self.inner
.write()
.expect("precision_hints rwlock poisoned")
.insert(digest, hint);
}
#[must_use]
pub fn lookup(&self, digest: ExprDigest) -> Option<PrecisionHint> {
self.inner
.read()
.expect("precision_hints rwlock poisoned")
.get(&digest)
.cloned()
}
#[must_use]
pub fn len(&self) -> usize {
self.inner
.read()
.expect("precision_hints rwlock poisoned")
.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Runs the precision analysis over every top-level node of `program`,
/// recording any hints it finds into `hints`.
///
/// Returns how many hints were recorded during this walk. Every recording is
/// counted, including one that replaces an entry with the same digest, so the
/// result need not equal `hints.len()`.
pub fn analyse_precision(program: &Program, hints: &PrecisionHints) -> usize {
    let mut recorded: usize = 0;
    for top_level in program.entry() {
        analyse_node(top_level, hints, &mut recorded);
    }
    recorded
}
/// Walks one statement: every expression it owns goes to `analyse_expr`, and
/// every nested statement list is walked recursively.
///
/// Statement kinds without an explicit arm are skipped entirely, including any
/// expressions they may carry.
fn analyse_node(node: &Node, hints: &PrecisionHints, count: &mut usize) {
    match node {
        // Pure containers: only the nested statements matter.
        Node::Block(body) => {
            for child in body {
                analyse_node(child, hints, count);
            }
        }
        Node::Region { body, .. } => {
            for child in body.iter() {
                analyse_node(child, hints, count);
            }
        }
        // Statements that carry expressions directly.
        Node::Let { value, .. } | Node::Assign { value, .. } => analyse_expr(value, hints, count),
        Node::Store { index, value, .. } => {
            analyse_expr(index, hints, count);
            analyse_expr(value, hints, count);
        }
        // Control flow: header expressions first, then the bodies in source order.
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            analyse_expr(cond, hints, count);
            for child in then {
                analyse_node(child, hints, count);
            }
            for child in otherwise {
                analyse_node(child, hints, count);
            }
        }
        Node::Loop { from, to, body, .. } => {
            analyse_expr(from, hints, count);
            analyse_expr(to, hints, count);
            for child in body {
                analyse_node(child, hints, count);
            }
        }
        _ => {}
    }
}
fn analyse_expr(expr: &Expr, hints: &PrecisionHints, count: &mut usize) {
if let Some(max_abs) = literal_only_fp_value_max(expr) {
if fits_f16_range(max_abs) {
let digest = digest_of(expr);
hints.record(
digest,
PrecisionHint::F16Eligible {
max_abs_operand: max_abs,
},
);
*count += 1;
}
}
if let Expr::UnOp { op, operand } = expr {
if let Some(transcendental) = transcendental_op(op) {
if let Some(literal) = literal_f32(operand) {
let bound = literal.abs();
let in_range = match transcendental {
TranscendentalOp::Sin | TranscendentalOp::Cos => {
bound <= std::f32::consts::FRAC_PI_4
}
TranscendentalOp::Exp => bound <= 1.0,
TranscendentalOp::Ln => literal >= 1.0 && literal <= 2.0,
};
if in_range {
let digest = digest_of(expr);
hints.record(
digest,
PrecisionHint::TranscendentalPolynomial {
op: transcendental,
argument_bound: bound,
},
);
*count += 1;
}
}
}
}
match expr {
Expr::Load { index, .. } => analyse_expr(index, hints, count),
Expr::BinOp { left, right, .. } => {
analyse_expr(left, hints, count);
analyse_expr(right, hints, count);
}
Expr::UnOp { operand, .. } => analyse_expr(operand, hints, count),
Expr::Call { args, .. } => {
for arg in args {
analyse_expr(arg, hints, count);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
analyse_expr(cond, hints, count);
analyse_expr(true_val, hints, count);
analyse_expr(false_val, hints, count);
}
Expr::Cast { value, .. } => analyse_expr(value, hints, count),
Expr::Fma { a, b, c } => {
analyse_expr(a, hints, count);
analyse_expr(b, hints, count);
analyse_expr(c, hints, count);
}
_ => {}
}
}
/// Returns the value of `expr` when it is an `f32` literal, otherwise `None`.
fn literal_f32(expr: &Expr) -> Option<f32> {
    match expr {
        Expr::LitF32(value) => Some(*value),
        _ => None,
    }
}
/// Returns the largest absolute literal value in `expr`, provided every leaf is
/// an `f32` literal reached through `BinOp`, `UnOp` or `Fma` nodes. Any other
/// node kind anywhere in the tree yields `None`.
///
/// A NaN leaf makes the result NaN. Plain `f32::max` returns the non-NaN
/// operand, so `NaN + 2.0` used to report 2.0 and slip past the finiteness
/// check in `fits_f16_range`, even though a lone NaN literal is rejected.
fn literal_only_fp_value_max(expr: &Expr) -> Option<f32> {
    match expr {
        Expr::LitF32(v) => Some(v.abs()),
        Expr::BinOp { left, right, .. } => {
            let l = literal_only_fp_value_max(left)?;
            let r = literal_only_fp_value_max(right)?;
            Some(nan_propagating_max(l, r))
        }
        Expr::UnOp { operand, .. } => literal_only_fp_value_max(operand),
        Expr::Fma { a, b, c } => {
            let a = literal_only_fp_value_max(a)?;
            let b = literal_only_fp_value_max(b)?;
            let c = literal_only_fp_value_max(c)?;
            Some(nan_propagating_max(nan_propagating_max(a, b), c))
        }
        _ => None,
    }
}

/// Like `f32::max`, but returns NaN if either input is NaN instead of
/// discarding it.
fn nan_propagating_max(a: f32, b: f32) -> f32 {
    if a.is_nan() || b.is_nan() {
        f32::NAN
    } else {
        a.max(b)
    }
}
/// Returns `true` when `value` is finite and its magnitude does not exceed the
/// largest finite IEEE 754 binary16 value, 65504.
///
/// The bound is inclusive because 65504 itself is exactly representable in
/// f16. The previous strict `<` wrongly rejected it.
fn fits_f16_range(value: f32) -> bool {
    // Largest finite half-precision value: (2 - 2^-10) * 2^15.
    const F16_MAX: f32 = 65_504.0;
    value.is_finite() && value.abs() <= F16_MAX
}
/// Maps an IR unary operator to the transcendental it computes. Returns `None`
/// for operators the analysis does not bound.
///
/// `UnOp::Log` maps to `TranscendentalOp::Ln`. NOTE(review): this assumes the
/// IR's `Log` is the natural logarithm; confirm against the IR definition.
fn transcendental_op(op: &UnOp) -> Option<TranscendentalOp> {
    let mapped = match op {
        UnOp::Sin => TranscendentalOp::Sin,
        UnOp::Cos => TranscendentalOp::Cos,
        UnOp::Exp => TranscendentalOp::Exp,
        UnOp::Log => TranscendentalOp::Ln,
        _ => return None,
    };
    Some(mapped)
}
fn digest_of(expr: &Expr) -> ExprDigest {
use blake3::Hasher;
let mut hasher = Hasher::new();
hasher.update(format!("{expr:?}").as_bytes());
let mut out = [0u8; 32];
out.copy_from_slice(hasher.finalize().as_bytes());
ExprDigest(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Single four-element read/write f32 storage buffer used by the test programs.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::F32).with_count(4)
    }

    /// Wraps `entry` in a single-workgroup program over `buf()`.
    fn program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Builds `op(operand)` for a literal operand.
    fn unary(op: UnOp, operand: f32) -> Expr {
        Expr::UnOp {
            op,
            operand: Box::new(Expr::f32(operand)),
        }
    }

    /// Builds `left op right` for literal operands.
    fn binary(op: BinOp, left: f32, right: f32) -> Expr {
        Expr::BinOp {
            op,
            left: Box::new(Expr::f32(left)),
            right: Box::new(Expr::f32(right)),
        }
    }

    /// Analyses a program whose only statement is `let x = value;`.
    fn analyse_let(value: Expr) -> PrecisionHints {
        let hints = PrecisionHints::new();
        analyse_precision(&program(vec![Node::let_bind("x", value)]), &hints);
        hints
    }

    #[test]
    fn empty_program_records_zero_hints() {
        let hints = PrecisionHints::new();
        let count = analyse_precision(&program(Vec::new()), &hints);
        assert_eq!(count, 0);
        assert!(hints.is_empty());
    }

    #[test]
    fn fp_literal_addition_is_f16_eligible() {
        let hints = analyse_let(binary(BinOp::Add, 1.5, 2.0));
        assert_eq!(
            hints.lookup(digest_of(&binary(BinOp::Add, 1.5, 2.0))),
            Some(PrecisionHint::F16Eligible {
                max_abs_operand: 2.0
            })
        );
        // Each literal leaf is also hinted under its own digest.
        assert_eq!(
            hints.lookup(digest_of(&Expr::f32(1.5))),
            Some(PrecisionHint::F16Eligible {
                max_abs_operand: 1.5
            })
        );
        assert!(hints.len() >= 3);
    }

    #[test]
    fn fp_literal_outside_f16_range_skips_g1() {
        let hints = analyse_let(binary(BinOp::Mul, 1e10, 2.0));
        assert_eq!(
            hints.lookup(digest_of(&Expr::f32(1e10))),
            None,
            "a 1e10 literal is outside the f16 range"
        );
        assert_eq!(
            hints.lookup(digest_of(&binary(BinOp::Mul, 1e10, 2.0))),
            None,
            "1e10 operand must reject F16 eligibility for the parent BinOp"
        );
        // The in-range sibling literal is still hinted on its own.
        assert_eq!(
            hints.lookup(digest_of(&Expr::f32(2.0))),
            Some(PrecisionHint::F16Eligible {
                max_abs_operand: 2.0
            })
        );
    }

    #[test]
    fn sin_in_quarter_pi_range_recorded() {
        let hints = analyse_let(unary(UnOp::Sin, 0.5));
        assert_eq!(
            hints.lookup(digest_of(&unary(UnOp::Sin, 0.5))),
            Some(PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.5,
            })
        );
    }

    #[test]
    fn sin_negative_argument_uses_magnitude() {
        let hints = analyse_let(unary(UnOp::Sin, -0.5));
        assert_eq!(
            hints.lookup(digest_of(&unary(UnOp::Sin, -0.5))),
            Some(PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.5,
            })
        );
    }

    #[test]
    fn sin_outside_quarter_pi_range_skips_g5() {
        let hints = analyse_let(unary(UnOp::Sin, 2.0));
        // Not a transcendental candidate, but still an all-literal f16 candidate.
        assert_eq!(
            hints.lookup(digest_of(&unary(UnOp::Sin, 2.0))),
            Some(PrecisionHint::F16Eligible {
                max_abs_operand: 2.0
            })
        );
    }

    #[test]
    fn exp_within_unit_range_recorded() {
        let hints = analyse_let(unary(UnOp::Exp, 0.5));
        assert_eq!(
            hints.lookup(digest_of(&unary(UnOp::Exp, 0.5))),
            Some(PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Exp,
                argument_bound: 0.5,
            })
        );
    }

    #[test]
    fn ln_within_one_to_two_range_recorded() {
        let hints = analyse_let(unary(UnOp::Log, 1.5));
        assert_eq!(
            hints.lookup(digest_of(&unary(UnOp::Log, 1.5))),
            Some(PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Ln,
                argument_bound: 1.5,
            })
        );
    }

    #[test]
    fn ln_below_one_skips_g5() {
        let hints = analyse_let(unary(UnOp::Log, 0.5));
        assert!(!matches!(
            hints.lookup(digest_of(&unary(UnOp::Log, 0.5))),
            Some(PrecisionHint::TranscendentalPolynomial { .. })
        ));
    }

    #[test]
    fn sin_non_literal_skips_g5() {
        let sin_theta = || Expr::UnOp {
            op: UnOp::Sin,
            operand: Box::new(Expr::var("theta")),
        };
        let hints = analyse_let(sin_theta());
        // A non-literal operand disqualifies both the transcendental and f16 hints.
        assert_eq!(hints.lookup(digest_of(&sin_theta())), None);
    }

    #[test]
    fn record_replaces_existing_hint() {
        let hints = PrecisionHints::new();
        let digest = ExprDigest([1; 32]);
        hints.record(
            digest,
            PrecisionHint::F16Eligible {
                max_abs_operand: 1.0,
            },
        );
        let replacement = PrecisionHint::TranscendentalPolynomial {
            op: TranscendentalOp::Exp,
            argument_bound: 0.25,
        };
        hints.record(digest, replacement.clone());
        assert_eq!(hints.len(), 1);
        assert_eq!(hints.lookup(digest), Some(replacement));
    }

    #[test]
    fn clones_share_storage() {
        let original = PrecisionHints::new();
        let handle = original.clone();
        handle.record(
            ExprDigest([2; 32]),
            PrecisionHint::F16Eligible {
                max_abs_operand: 3.0,
            },
        );
        assert_eq!(original.len(), 1);
        assert!(!original.is_empty());
    }

    #[test]
    fn fits_f16_range_rejects_non_finite_and_large_values() {
        assert!(fits_f16_range(1.0));
        assert!(!fits_f16_range(1e10));
        assert!(!fits_f16_range(f32::INFINITY));
        assert!(!fits_f16_range(f32::NAN));
    }

    #[test]
    fn hints_are_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<PrecisionHints>();
        assert_sync::<PrecisionHints>();
    }
}