use crate::ir::{BinOp, Expr, UnOp};
const MAX_SHIFT_ADD_CHAIN_COST: u32 = 4;
pub(super) fn horner_quadratic_u32(expr: &Expr) -> Option<Expr> {
let mut terms = Vec::with_capacity(4);
collect_add_terms(expr, &mut terms);
if terms.len() != 3 {
return None;
}
let mut constant: Option<u32> = None;
let mut linear: Option<(IdentRef<'_>, u32)> = None;
let mut quadratic: Option<(IdentRef<'_>, u32)> = None;
for term in terms {
if let Expr::LitU32(value) = term {
if constant.replace(*value).is_some() {
return None;
}
continue;
}
if let Some((var, coeff)) = linear_u32_term(term) {
if linear.replace((var, coeff)).is_some() {
return None;
}
continue;
}
if let Some((var, coeff)) = quadratic_u32_term(term) {
if quadratic.replace((var, coeff)).is_some() {
return None;
}
continue;
}
return None;
}
let (linear_var, b) = linear?;
let (quadratic_var, a) = quadratic?;
if linear_var.0 != quadratic_var.0 {
return None;
}
let x = Expr::var(linear_var.0.as_str());
let ax_plus_b = Expr::add(Expr::mul(Expr::u32(a), x.clone()), Expr::u32(b));
Some(Expr::add(Expr::mul(ax_plus_b, x), Expr::u32(constant?)))
}
#[derive(Clone, Copy)]
struct IdentRef<'a>(&'a crate::ir::Ident);
fn collect_add_terms<'a>(expr: &'a Expr, out: &mut Vec<&'a Expr>) {
if let Expr::BinOp {
op: BinOp::Add,
left,
right,
} = expr
{
collect_add_terms(left, out);
collect_add_terms(right, out);
} else {
out.push(expr);
}
}
fn linear_u32_term(expr: &Expr) -> Option<(IdentRef<'_>, u32)> {
if let Expr::Var(name) = expr {
return Some((IdentRef(name), 1));
}
let mut factors = Vec::with_capacity(3);
collect_mul_factors(expr, &mut factors);
if factors.len() != 2 {
return None;
}
let mut coeff = None;
let mut var = None;
for factor in factors {
match factor {
Expr::LitU32(value) => coeff = Some(*value),
Expr::Var(name) => var = Some(IdentRef(name)),
_ => return None,
}
}
Some((var?, coeff?))
}
fn quadratic_u32_term(expr: &Expr) -> Option<(IdentRef<'_>, u32)> {
let mut factors = Vec::with_capacity(4);
collect_mul_factors(expr, &mut factors);
if !(factors.len() == 2 || factors.len() == 3) {
return None;
}
let mut coeff = 1u32;
let mut vars = Vec::with_capacity(2);
for factor in factors {
match factor {
Expr::LitU32(value) => coeff = *value,
Expr::Var(name) => vars.push(name),
_ => return None,
}
}
if vars.len() == 2 && vars[0] == vars[1] {
Some((IdentRef(vars[0]), coeff))
} else {
None
}
}
fn collect_mul_factors<'a>(expr: &'a Expr, out: &mut Vec<&'a Expr>) {
if let Expr::BinOp {
op: BinOp::Mul,
left,
right,
} = expr
{
collect_mul_factors(left, out);
collect_mul_factors(right, out);
} else {
out.push(expr);
}
}
pub(super) fn shift_add_decompose(x: &Expr, constant: &Expr) -> Option<Expr> {
let c = positive_u32_constant(constant)?;
if c <= 1 || c.is_power_of_two() {
return None;
}
if let Some(chain) = shift_add_chain(x, c) {
return Some(chain);
}
for hi in (1u32..=16).rev() {
let high = 1u32 << hi;
if high > c {
continue;
}
let remainder = c - high;
if remainder == 0 {
continue;
}
if remainder.is_power_of_two() {
let lo = remainder.trailing_zeros();
let lo_term = if lo == 0 {
x.clone()
} else {
Expr::shl(x.clone(), Expr::u32(lo))
};
return Some(Expr::add(Expr::shl(x.clone(), Expr::u32(hi)), lo_term));
}
}
for hi in (2u32..=16).rev() {
let high = 1u32 << hi;
if high <= c {
continue;
}
let deficit = high - c;
if deficit.is_power_of_two() && deficit < high {
let lo = deficit.trailing_zeros();
let lo_term = if lo == 0 {
x.clone()
} else {
Expr::shl(x.clone(), Expr::u32(lo))
};
return Some(Expr::sub(Expr::shl(x.clone(), Expr::u32(hi)), lo_term));
}
}
None
}
pub(super) fn shift_add_chain(x: &Expr, c: u32) -> Option<Expr> {
if c <= 1 || c.is_power_of_two() || operand_duplication_cost(x) > 1 {
return None;
}
let terms = naf_terms(c);
if terms.len() < 2 {
return None;
}
let cost = shift_add_cost(&terms);
if cost > MAX_SHIFT_ADD_CHAIN_COST {
return None;
}
let mut terms = terms;
terms.sort_unstable_by(|a, b| b.shift.cmp(&a.shift));
let mut iter = terms.into_iter();
let first = iter.next()?;
let mut acc = if first.sign > 0 {
shifted_term(x, first.shift)
} else {
Expr::sub(Expr::u32(0), shifted_term(x, first.shift))
};
for term in iter {
let rhs = shifted_term(x, term.shift);
acc = if term.sign > 0 {
Expr::add(acc, rhs)
} else {
Expr::sub(acc, rhs)
};
}
Some(acc)
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct SignedShiftTerm {
shift: u32,
sign: i8,
}
fn naf_terms(mut n: u32) -> Vec<SignedShiftTerm> {
let mut shift = 0u32;
let mut terms = Vec::with_capacity(n.count_ones() as usize + 1);
while n > 0 {
if n & 1 == 0 {
n >>= 1;
shift += 1;
continue;
}
let sign = if n & 3 == 1 || n == 1 { 1 } else { -1 };
terms.push(SignedShiftTerm { shift, sign });
if sign > 0 {
n -= 1;
} else {
n = n.wrapping_add(1);
}
n >>= 1;
shift += 1;
}
terms
}
fn shift_add_cost(terms: &[SignedShiftTerm]) -> u32 {
let shifts = terms.iter().filter(|term| term.shift != 0).count() as u32;
let combines = terms.len().saturating_sub(1) as u32;
shifts + combines
}
fn shifted_term(x: &Expr, shift: u32) -> Expr {
if shift == 0 {
x.clone()
} else {
Expr::shl(x.clone(), Expr::u32(shift))
}
}
fn operand_duplication_cost(expr: &Expr) -> u32 {
match expr {
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_)
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::SubgroupLocalId
| Expr::SubgroupSize => 0,
Expr::Load { .. } | Expr::BufLen { .. } => 1,
_ => 2,
}
}
fn positive_u32_constant(expr: &Expr) -> Option<u32> {
match expr {
Expr::LitU32(v) => Some(*v),
Expr::LitI32(v) if *v > 0 => Some(*v as u32),
_ => None,
}
}
pub(super) fn reciprocal_constant_fold(left: &Expr, right: &Expr) -> Option<Expr> {
if let (Expr::LitF32(one), Expr::LitF32(c)) = (left, right) {
if *one == 1.0 && *c != 0.0 && c.is_finite() {
return Some(Expr::f32(1.0 / c));
}
}
None
}
pub(super) fn synthesize_fma_add(left: &Expr, right: &Expr) -> Option<Expr> {
if let Some((a, b)) = mul_terms(left) {
if is_float_expr(right) {
return Some(Expr::fma(a, b, right.clone()));
}
}
if let Some((a, b)) = mul_terms(right) {
if is_float_expr(left) {
return Some(Expr::fma(a, b, left.clone()));
}
}
if let Some((a, b)) = negated_mul_terms(left) {
if is_float_expr(right) {
return Some(Expr::fma(Expr::negate(a), b, right.clone()));
}
}
if let Some((a, b)) = negated_mul_terms(right) {
if is_float_expr(left) {
return Some(Expr::fma(Expr::negate(a), b, left.clone()));
}
}
None
}
pub(super) fn synthesize_fma_sub(left: &Expr, right: &Expr) -> Option<Expr> {
if let Some((a, b)) = mul_terms(left) {
if is_float_expr(right) {
return Some(Expr::fma(a, b, Expr::negate(right.clone())));
}
}
if let Some((a, b)) = mul_terms(right) {
if is_float_expr(left) {
return Some(Expr::fma(Expr::negate(a), b, left.clone()));
}
}
None
}
pub(super) fn power_of_two_shift(expr: &Expr) -> Option<u32> {
match expr {
Expr::LitU32(value) if value.is_power_of_two() => Some(value.trailing_zeros()),
Expr::LitI32(value) if *value > 0 && (*value as u32).is_power_of_two() => {
Some(value.trailing_zeros())
}
_ => None,
}
}
fn mul_terms(expr: &Expr) -> Option<(Expr, Expr)> {
match expr {
Expr::BinOp {
op: BinOp::Mul,
left,
right,
} => Some((left.as_ref().clone(), right.as_ref().clone())),
_ => None,
}
}
fn negated_mul_terms(expr: &Expr) -> Option<(Expr, Expr)> {
match expr {
Expr::UnOp {
op: UnOp::Negate,
operand,
} => mul_terms(operand),
_ => None,
}
}
fn is_float_expr(expr: &Expr) -> bool {
match expr {
Expr::LitF32(_) | Expr::Fma { .. } => true,
Expr::BinOp { left, .. } => is_float_expr(left),
Expr::UnOp { operand, .. } => is_float_expr(operand),
Expr::Cast { target, .. } => {
matches!(target, crate::ir::DataType::F32 | crate::ir::DataType::F64)
}
_ => false,
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct DivMagic {
pub magic: u32,
pub shift: u32,
pub needs_fixup: bool,
}
pub(super) fn compute_div_magic(d: u32) -> DivMagic {
debug_assert!(
d >= 2 && !d.is_power_of_two(),
"d must be >= 2 and not a power of 2"
);
let mut needs_fixup = false;
let nc = u32::MAX - (d.wrapping_neg() % d);
let mut p: u32 = 31;
let mut q1 = 0x8000_0000u32 / nc;
let mut r1 = 0x8000_0000u32 - q1 * nc;
let mut q2 = 0x7FFF_FFFFu32 / d;
let mut r2 = 0x7FFF_FFFFu32 - q2 * d;
loop {
p += 1;
if r1 >= nc - r1 {
q1 = q1.wrapping_shl(1).wrapping_add(1);
r1 = r1.wrapping_shl(1).wrapping_sub(nc);
} else {
q1 = q1.wrapping_shl(1);
r1 = r1.wrapping_shl(1);
}
if r2.wrapping_add(1) >= d.wrapping_sub(r2) {
if q2 >= 0x7FFF_FFFFu32 {
needs_fixup = true;
}
q2 = q2.wrapping_shl(1).wrapping_add(1);
r2 = r2.wrapping_shl(1).wrapping_add(1).wrapping_sub(d);
} else {
if q2 >= 0x8000_0000u32 {
needs_fixup = true;
}
q2 = q2.wrapping_shl(1);
r2 = r2.wrapping_shl(1).wrapping_add(1);
}
let delta = d.wrapping_sub(1).wrapping_sub(r2);
if !(p < 64 && (q1 < delta || (q1 == delta && r1 == 0))) {
break;
}
}
DivMagic {
magic: q2.wrapping_add(1),
shift: p - 32,
needs_fixup,
}
}
pub(super) fn granlund_montgomery_div(dividend: &Expr, d: u32) -> Option<Expr> {
if d <= 1 || d.is_power_of_two() {
return None;
}
let magic = compute_div_magic(d);
let n = dividend.clone();
if !magic.needs_fixup {
let hi = Expr::mulhi(n, Expr::u32(magic.magic));
if magic.shift == 0 {
Some(hi)
} else {
Some(Expr::shr(hi, Expr::u32(magic.shift)))
}
} else {
let t = Expr::mulhi(n.clone(), Expr::u32(magic.magic));
let n_minus_t = Expr::sub(n, t.clone());
let half = Expr::shr(n_minus_t, Expr::u32(1));
let sum = Expr::add(t, half);
let shift = magic.shift.saturating_sub(1);
if shift == 0 {
Some(sum)
} else {
Some(Expr::shr(sum, Expr::u32(shift)))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn granlund_montgomery_correctness() {
for d in 2u32..=1000 {
if d.is_power_of_two() {
continue;
}
let magic = compute_div_magic(d);
let test_values: Vec<u32> = vec![
0,
1,
2,
d - 1,
d,
d + 1,
d * 2,
d * 3,
1000,
65535,
65536,
u32::MAX / d * d, u32::MAX / d * d - 1, u32::MAX,
u32::MAX - 1,
];
for &n in &test_values {
let expected = n / d;
let actual = apply_magic(n, &magic);
assert_eq!(
actual, expected,
"n={n}, d={d}, magic={}, shift={}, fixup={}: got {actual}, expected {expected}",
magic.magic, magic.shift, magic.needs_fixup
);
}
}
}
#[test]
fn granlund_montgomery_extreme_divisors() {
for d in [3, 7, 10, 127, 255, 1000, 65535, 0x7FFF_FFFF, u32::MAX - 1] {
if d.is_power_of_two() || d <= 1 {
continue;
}
let magic = compute_div_magic(d);
for &n in &[0u32, 1, d - 1, d, d + 1, u32::MAX] {
assert_eq!(apply_magic(n, &magic), n / d, "n={n}, d={d}");
}
}
}
fn apply_magic(n: u32, magic: &DivMagic) -> u32 {
let hi = ((n as u64).wrapping_mul(magic.magic as u64) >> 32) as u32;
if !magic.needs_fixup {
hi >> magic.shift
} else {
let t = hi;
let half = (n.wrapping_sub(t)) >> 1;
(t.wrapping_add(half)) >> magic.shift.saturating_sub(1)
}
}
}