use crate::{
Dialect,
shared::{Elem, Item},
};
/// Extra helper routines that can be emitted into generated shader source.
///
/// NOTE(review): each variant's payload appears to correspond to the arguments
/// of the matching `format_*` function in this file; the code that performs
/// that dispatch is not visible here, so confirm against the caller.
// `NoExtension` repeats the enum's own name, which is presumably what trips this lint.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Default, PartialEq)]
pub enum Extension<D: Dialect> {
/// Error-function helper. Presumably the (input, output) element types taken by
/// `format_erf` — TODO confirm the order against the caller.
Erf(Elem<D>, Elem<D>),
/// Find-first-set helper for one element type (see `format_ffs`).
Ffs(Elem<D>),
/// High-half multiplication helper for one element type (see `format_mulhi`).
MulHi(Elem<D>),
/// Clamped `tanh` helper for one item type (see `format_safe_tanh`).
SafeTanh(Item<D>),
/// No helper at all; this is the value produced by `Default`.
#[default]
NoExtension,
}
/// Writes the source of an `erf` helper function to `f`.
///
/// The emitted function takes one `input_elem` argument and is declared to
/// return `out_elem`. Internally it evaluates the Abramowitz and Stegun
/// polynomial approximation in `float` and restores the sign of the input at
/// the end.
///
/// # Errors
///
/// Returns any error reported by the underlying formatter.
pub fn format_erf<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    input_elem: &Elem<D>,
    out_elem: &Elem<D>,
) -> core::fmt::Result {
    // Only the signature depends on the element types; everything after it is
    // fixed text and therefore needs no format-string brace escaping.
    const APPROXIMATION_BODY: &str = "const float a1 = 0.254829592f;
const float a2 = -0.284496736f;
const float a3 = 1.421413741f;
const float a4 = -1.453152027f;
const float a5 = 1.061405429f;
const float p = 0.3275911f;
float sign = (x >= 0.0f) ? 1.0f : -1.0f;
x = fabs(x);
float t = 1.0f / (1.0f + p * x);
float y = 1.0f - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp(-x * x);
return sign * y;
}
";

    writeln!(
        f,
        "\n// Abramowitz and Stegun approximation for erf(x)\ninline {out_elem} erf({input_elem} x) {{"
    )?;
    f.write_str(APPROXIMATION_BODY)
}
/// Writes the source of a "find first set" helper for `input_elem` to `f`.
///
/// The emitted function returns the 1-based position of the least significant
/// set bit of its argument, or `0` when the argument is zero:
///
/// * 32-bit integers get `__ffs`,
/// * 64-bit integers get `__ffsll`.
///
/// Every emitted helper is self-contained. The signed variants reinterpret
/// their argument as unsigned and compute the result directly instead of
/// forwarding to the unsigned overload: that overload is only emitted for the
/// unsigned element type, and when it is absent (or declared later) the
/// forwarding call would resolve to the signed helper itself.
///
/// Nothing is written for any other element type.
///
/// # Errors
///
/// Returns any error reported by the underlying formatter.
pub fn format_ffs<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    input_elem: &Elem<D>,
) -> core::fmt::Result {
    match input_elem {
        // `ux & -ux` isolates the lowest set bit; `clz` of that single bit
        // gives its distance from the top, hence `width - clz` is 1-based.
        Elem::I32 => write!(
            f,
            "
int __ffs(int x) {{
uint ux = static_cast<uint>(x);
return ux == 0 ? 0 : static_cast<int>(32 - clz(ux & -ux));
}}
"
        ),
        Elem::U32 => write!(
            f,
            "
uint __ffs(uint x) {{
return x == 0 ? 0 : 32 - clz(x & -x);
}}
"
        ),
        Elem::I64 => write!(
            f,
            "
int __ffsll(long x) {{
ulong ux = static_cast<ulong>(x);
return ux == 0 ? 0 : static_cast<int>(64 - clz(ux & -ux));
}}
"
        ),
        Elem::U64 => write!(
            f,
            "
uint __ffsll(ulong x) {{
return x == 0 ? 0 : 64 - clz(x & -x);
}}
"
        ),
        // Other element types get no helper.
        _ => Ok(()),
    }
}
/// Writes the source of a "multiply, keep the high half" helper for
/// `out_elem` to `f`.
///
/// * `I32` emits `__mulhi`, `U32` emits `__umulhi`: both widen to 64 bits,
///   multiply, and return the upper 32 bits.
/// * `U64` emits `__umul64hi`, built from four 32x32 partial products.
/// * `I64` emits `__mul64hi`. It computes the unsigned high word of the
///   operands' bit patterns and then applies the two's complement correction
///   (`high -= ub` when `a < 0`, `high -= ua` when `b < 0`). Negating the high
///   word of `|a| * |b|` would be incorrect whenever the low word is non-zero
///   (for example `-1 * 1` must yield `-1`, not `0`), and negating the minimum
///   value overflows. The helper is self-contained so it does not rely on
///   `__umul64hi` having been emitted as well.
/// * Any other element type emits an `#error` directive.
///
/// # Errors
///
/// Returns any error reported by the underlying formatter.
pub fn format_mulhi<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    out_elem: &Elem<D>,
) -> core::fmt::Result {
    match out_elem {
        Elem::I32 => write!(
            f,
            "
int32_t __mulhi(int32_t a, int32_t b) {{
int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
return static_cast<int32_t>(product >> 32);
}}
"
        ),
        Elem::U32 => write!(
            f,
            "
uint32_t __umulhi(uint32_t a, uint32_t b) {{
uint64_t product = static_cast<uint64_t>(a) * static_cast<uint64_t>(b);
return static_cast<uint32_t>(product >> 32);
}}
"
        ),
        Elem::I64 => write!(
            f,
            "
int64_t __mul64hi(int64_t a, int64_t b) {{
// Reinterpret the operands as unsigned two's complement bit patterns
uint64_t ua = static_cast<uint64_t>(a);
uint64_t ub = static_cast<uint64_t>(b);
// Split the operands into high and low 32-bit parts
uint64_t a_lo = static_cast<uint32_t>(ua);
uint64_t a_hi = ua >> 32;
uint64_t b_lo = static_cast<uint32_t>(ub);
uint64_t b_hi = ub >> 32;
// Perform partial multiplications
uint64_t p0 = a_lo * b_lo;
uint64_t p1 = a_lo * b_hi;
uint64_t p2 = a_hi * b_lo;
uint64_t p3 = a_hi * b_hi;
// Combine the results into the unsigned high word
uint64_t mid = (p0 >> 32) + (p1 & 0xFFFFFFFF) + (p2 & 0xFFFFFFFF);
uint64_t high = p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);
// A negative operand equals its bit pattern minus 2^64, so remove the
// other operand from the high word once per negative operand
high -= (a < 0) ? ub : 0;
high -= (b < 0) ? ua : 0;
return static_cast<int64_t>(high);
}}
"
        ),
        Elem::U64 => write!(
            f,
            "
uint64_t __umul64hi(uint64_t a, uint64_t b) {{
// Split the operands into high and low 32-bit parts
uint64_t a_lo = static_cast<uint32_t>(a);
uint64_t a_hi = a >> 32;
uint64_t b_lo = static_cast<uint32_t>(b);
uint64_t b_hi = b >> 32;
// Perform partial multiplications
uint64_t p0 = a_lo * b_lo;
uint64_t p1 = a_lo * b_hi;
uint64_t p2 = a_hi * b_lo;
uint64_t p3 = a_hi * b_hi;
// Combine the results
uint64_t mid = (p0 >> 32) + (p1 & 0xFFFFFFFF) + (p2 & 0xFFFFFFFF);
uint64_t high = p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);
return high;
}}
"
        ),
        // Surface unsupported element types when the generated source is compiled.
        _ => writeln!(f, "#error HiMul only supports 32 and 64 bit ints"),
    }
}
/// Writes the source of the `safe_tanh` helpers for `item` to `f`.
///
/// Two functions are emitted:
///
/// * `safe_tanh_scalar`, operating on the item's element type, which returns
///   `1.0` for inputs above `43.0` and `tanh(x)` otherwise;
/// * `safe_tanh`, operating on the item type itself. With a vectorization of
///   `1` it simply forwards to the scalar helper; otherwise it builds a new
///   item by applying the scalar helper to each `i_<n>` component in order.
///
/// # Errors
///
/// Returns any error reported by the underlying formatter.
pub fn format_safe_tanh<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    item: &Item<D>,
) -> core::fmt::Result {
    let elem = item.elem();

    // Scalar helper: the clamp that the vector wrapper below relies on.
    write!(
        f,
        "
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
inline {elem} safe_tanh_scalar({elem} x) {{
if (x > 43.0) {{
return 1.0;
}} else {{
return tanh(x);
}}
}}
"
    )?;

    // Wrapper over the full item type.
    writeln!(f, "inline {item} safe_tanh({item} x) {{")?;
    match item.vectorization {
        1 => writeln!(f, " return safe_tanh_scalar(x);")?,
        lanes => {
            write!(f, " return {item} {{ ")?;
            for lane in 0..lanes {
                // Separator goes before every component except the first.
                if lane > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "safe_tanh_scalar(x.i_{lane})")?;
            }
            writeln!(f, " }};")?;
        }
    }
    writeln!(f, "}}")
}