use std::sync::Arc;
use morok_dtype::{DType, ScalarDType};
use crate::types::ConstValue;
use crate::uop::UOp;
/// Returns `dtype` with any pointer wrapping removed.
///
/// A `DType::Ptr` is unwrapped recursively through its `base` until a
/// non-pointer dtype is reached; every other dtype is returned as a clone.
pub fn scalar_dtype(dtype: &DType) -> DType {
    if let DType::Ptr { base, .. } = dtype {
        // Pointers may be nested, so keep peeling until the base is not a pointer.
        return scalar_dtype(base);
    }
    dtype.clone()
}
/// Returns `d` with its dtype reduced to the one produced by [`scalar_dtype`].
///
/// When the dtype of `d` already equals its scalar form, a clone of the handle
/// is returned; otherwise a cast to the scalar dtype is returned.
pub fn ensure_scalar(d: &Arc<UOp>) -> Arc<UOp> {
    let current = d.dtype();
    let target = scalar_dtype(&current);
    if current == target {
        // Nothing to strip: hand back the same node.
        return Arc::clone(d);
    }
    d.cast(target)
}
/// Number of stored fraction (mantissa) bits for a floating-point dtype.
///
/// The dtype is first reduced with [`scalar_dtype`]. Returns 10 for `Float16`,
/// 23 for `Float32` and 52 for `Float64`.
///
/// # Panics
///
/// Panics for any dtype whose scalar kind is not one of the three above.
pub fn mantissa_bits(dtype: &DType) -> i64 {
    match scalar_dtype(dtype).scalar() {
        Some(ScalarDType::Float16) => 10,
        Some(ScalarDType::Float32) => 23,
        Some(ScalarDType::Float64) => 52,
        _ => panic!("mantissa_bits: unsupported dtype {:?}", dtype),
    }
}
/// Exponent bias for a floating-point dtype.
///
/// The dtype is first reduced with [`scalar_dtype`]. Returns 15 for `Float16`,
/// 127 for `Float32` and 1023 for `Float64`.
///
/// # Panics
///
/// Panics for any dtype whose scalar kind is not one of the three above.
pub fn exponent_bias(dtype: &DType) -> i64 {
    match scalar_dtype(dtype).scalar() {
        Some(ScalarDType::Float16) => 15,
        Some(ScalarDType::Float32) => 127,
        Some(ScalarDType::Float64) => 1023,
        _ => panic!("exponent_bias: unsupported dtype {:?}", dtype),
    }
}
/// All-ones mask covering the exponent field of a floating-point dtype.
///
/// The mask is right-aligned: 31 (5 bits) for `Float16`, 255 (8 bits) for
/// `Float32` and 2047 (11 bits) for `Float64`. The dtype is first reduced with
/// [`scalar_dtype`].
///
/// # Panics
///
/// Panics for any dtype whose scalar kind is not one of the three above.
pub fn exponent_mask(dtype: &DType) -> i64 {
    match scalar_dtype(dtype).scalar() {
        Some(ScalarDType::Float16) => 31,
        Some(ScalarDType::Float32) => 255,
        Some(ScalarDType::Float64) => 2047,
        _ => panic!("exponent_mask: unsupported dtype {:?}", dtype),
    }
}
/// Maps a floating-point dtype to the correspondingly sized integer dtype.
///
/// `Float16 -> Int16`, `Float32 -> Int32`, `Float64 -> Int64`. The input is
/// first reduced with [`scalar_dtype`].
///
/// # Panics
///
/// Panics for any dtype whose scalar kind is not one of the three floats above.
pub fn float_to_int_dtype(dtype: &DType) -> DType {
    match scalar_dtype(dtype).scalar() {
        Some(ScalarDType::Float16) => DType::Int16,
        Some(ScalarDType::Float32) => DType::Int32,
        Some(ScalarDType::Float64) => DType::Int64,
        _ => panic!("float_to_int_dtype: unsupported dtype {:?}", dtype),
    }
}
/// Maps an integer dtype to the correspondingly sized floating-point dtype.
///
/// `Int16 -> Float16`, `Int32 -> Float32`, `Int64 -> Float64`. The input is
/// first reduced with [`scalar_dtype`].
///
/// # Panics
///
/// Panics for any dtype whose scalar kind is not one of the three integers above.
pub fn int_to_float_dtype(dtype: &DType) -> DType {
    match scalar_dtype(dtype).scalar() {
        Some(ScalarDType::Int16) => DType::Float16,
        Some(ScalarDType::Int32) => DType::Float32,
        Some(ScalarDType::Int64) => DType::Float64,
        _ => panic!("int_to_float_dtype: unsupported dtype {:?}", dtype),
    }
}
/// Builds `x >> y` for a compile-time shift amount `y`.
///
/// A shift by zero returns `x` unchanged (no node is created). Otherwise the
/// shift amount is materialised as an integer constant carrying the dtype of `x`.
///
/// # Panics
///
/// Panics if the underlying shift operation reports an error.
pub fn shr(x: &Arc<UOp>, y: i64) -> Arc<UOp> {
    match y {
        0 => Arc::clone(x),
        amount => {
            let amount_uop = UOp::const_(x.dtype(), ConstValue::Int(amount));
            x.try_shr_op(&amount_uop).expect("shr: shift failed")
        }
    }
}
/// Builds `x << y` for a compile-time shift amount `y`.
///
/// A shift by zero returns `x` unchanged (no node is created). Otherwise the
/// shift amount is materialised as an integer constant carrying the dtype of `x`.
///
/// # Panics
///
/// Panics if the underlying shift operation reports an error.
pub fn shl(x: &Arc<UOp>, y: i64) -> Arc<UOp> {
    match y {
        0 => Arc::clone(x),
        amount => {
            let amount_uop = UOp::const_(x.dtype(), ConstValue::Int(amount));
            x.try_shl_op(&amount_uop).expect("shl: shift failed")
        }
    }
}
/// Builds `x & mask`, with `mask` materialised as an integer constant that
/// carries the dtype of `x`.
///
/// # Panics
///
/// Panics if the underlying AND operation reports an error.
pub fn and_const(x: &Arc<UOp>, mask: i64) -> Arc<UOp> {
    x.try_and_op(&UOp::const_(x.dtype(), ConstValue::Int(mask)))
        .expect("and_const: failed")
}
/// Evaluates a polynomial in `x` using Horner's scheme.
///
/// `coeffs` is ordered from the highest-degree coefficient to the constant
/// term. The accumulator starts as the constant `0.0` and each step builds
/// `acc * x + coeff`, so the first step produces `0.0 * x + coeffs[0]`.
/// Constants are created with [`float_const`] using the dtype of `x`.
///
/// # Panics
///
/// Panics if `coeffs` is empty, or if a multiply/add operation reports an error.
pub fn poly_n(x: &Arc<UOp>, coeffs: &[f64]) -> Arc<UOp> {
    assert!(!coeffs.is_empty(), "poly_n: need at least one coefficient");
    let dtype = x.dtype();
    coeffs.iter().fold(float_const(&dtype, 0.0), |acc, &coeff| {
        let term = float_const(&dtype, coeff);
        acc.try_mul(x)
            .expect("poly_n: mul failed")
            .try_add(&term)
            .expect("poly_n: add failed")
    })
}
/// Creates a floating-point constant node holding `value`.
///
/// The node's dtype is `dtype` after reduction with [`scalar_dtype`].
pub fn float_const(dtype: &DType, value: f64) -> Arc<UOp> {
    UOp::const_(scalar_dtype(dtype), ConstValue::Float(value))
}
/// Creates an integer constant node holding `value`.
///
/// The node's dtype is `dtype` after reduction with [`scalar_dtype`].
pub fn int_const(dtype: &DType, value: i64) -> Arc<UOp> {
    UOp::const_(scalar_dtype(dtype), ConstValue::Int(value))
}
/// Creates a constant node of dtype `DType::Bool` holding `value`.
pub fn bool_const(value: bool) -> Arc<UOp> {
    let literal = ConstValue::Bool(value);
    UOp::const_(DType::Bool, literal)
}
/// Converts a float node to the integer dtype given by [`float_to_int_dtype`]
/// after nudging it by half a unit away from zero.
///
/// Builds `d + (d < 0 ? -0.5 : 0.5)` and casts the sum to the integer dtype.
/// NOTE(review): this yields round-half-away-from-zero only if the cast
/// truncates toward zero — confirm against the cast semantics of the backend.
///
/// # Panics
///
/// Panics if the dtype of `d` is not a supported float, or if any of the
/// compare/select/add operations reports an error.
pub fn rintk(d: &Arc<UOp>) -> Arc<UOp> {
    let float_dtype = d.dtype();
    let target = float_to_int_dtype(&float_dtype);

    let zero = float_const(&float_dtype, 0.0);
    let plus_half = float_const(&float_dtype, 0.5);
    let minus_half = float_const(&float_dtype, -0.5);

    // Pick the offset whose sign matches the input.
    let negative = d.try_cmplt(&zero).expect("rintk: cmplt failed");
    let offset = UOp::try_where(negative, minus_half, plus_half).expect("rintk: where failed");

    d.try_add(&offset).expect("rintk: add failed").cast(target)
}
/// Builds the float whose bit pattern is `(q + bias) << mantissa_bits`.
///
/// `q` is cast to the integer dtype matching `float_dtype` unless it already
/// has that dtype. The biased value is shifted into the exponent field and the
/// result is bitcast to `float_dtype`. No range check on `q` is performed here;
/// presumably callers keep `q + bias` within the exponent field — verify at
/// call sites.
///
/// # Panics
///
/// Panics if `float_dtype` is not a supported float, or if the add/shift
/// operations report an error.
pub fn pow2if(q: &Arc<UOp>, float_dtype: &DType) -> Arc<UOp> {
    let bits_dtype = float_to_int_dtype(float_dtype);
    let bias = exponent_bias(float_dtype);
    let frac_width = mantissa_bits(float_dtype);
    let bias_uop = int_const(&bits_dtype, bias);

    let exponent = if q.dtype() != bits_dtype {
        q.cast(bits_dtype)
    } else {
        Arc::clone(q)
    };

    let biased = exponent.try_add(&bias_uop).expect("pow2if: add failed");
    shl(&biased, frac_width).bitcast(float_dtype.clone())
}
/// Scales `d` by two power-of-two factors whose exponents sum to `e`.
///
/// The exponent is split into `e >> 1` and `e - (e >> 1)`; each part is turned
/// into a float with [`pow2if`] and `d` is multiplied by them in turn.
/// Presumably the split keeps each factor representable when `2^e` alone would
/// not be — confirm against callers.
///
/// # Panics
///
/// Panics if the dtype of `d` is not a supported float, or if any
/// subtract/multiply operation reports an error.
pub fn ldexp2k(d: &Arc<UOp>, e: &Arc<UOp>) -> Arc<UOp> {
    let dtype = d.dtype();

    let first_exp = shr(e, 1);
    let second_exp = e.try_sub(&first_exp).expect("ldexp2k: sub failed");

    let first_scale = pow2if(&first_exp, &dtype);
    let second_scale = pow2if(&second_exp, &dtype);

    d.try_mul(&first_scale)
        .expect("ldexp2k: mul1 failed")
        .try_mul(&second_scale)
        .expect("ldexp2k: mul2 failed")
}
/// Adds `e` directly to the exponent field of `d`.
///
/// `d` is bitcast to the matching integer dtype, `e` is cast to that dtype and
/// shifted left by the mantissa width, the two are added as integers, and the
/// sum is bitcast back to the float dtype of `d`. No overflow or special-value
/// handling is performed in this function.
///
/// # Panics
///
/// Panics if the dtype of `d` is not a supported float, or if the shift/add
/// operations report an error.
pub fn ldexp3k(d: &Arc<UOp>, e: &Arc<UOp>) -> Arc<UOp> {
    let dtype = d.dtype();
    let bits_dtype = float_to_int_dtype(&dtype);
    let frac_width = mantissa_bits(&dtype);

    let raw = d.bitcast(bits_dtype.clone());
    // Align the exponent delta with the exponent field of the raw bits.
    let delta = shl(&e.cast(bits_dtype), frac_width);

    raw.try_add(&delta).expect("ldexp3k: add failed").bitcast(dtype)
}
/// Extracts the unbiased exponent of `d` as an integer node.
///
/// Computes `((bits(d) >> mantissa_bits) & exponent_mask) - exponent_bias`,
/// where `bits(d)` is `d` bitcast to the matching integer dtype. Zero,
/// subnormal, infinite and NaN inputs get no special handling here.
///
/// # Panics
///
/// Panics if the dtype of `d` is not a supported float, or if the
/// shift/and/subtract operations report an error.
pub fn ilogb2k(d: &Arc<UOp>) -> Arc<UOp> {
    let dtype = d.dtype();
    let bits_dtype = float_to_int_dtype(&dtype);
    let (frac_width, field_mask, bias) =
        (mantissa_bits(&dtype), exponent_mask(&dtype), exponent_bias(&dtype));

    let raw = d.bitcast(bits_dtype.clone());
    // Move the exponent field to the low bits, then drop the sign bit above it.
    let biased_exp = and_const(&shr(&raw, frac_width), field_mask);
    let bias_uop = int_const(&bits_dtype, bias);

    biased_exp.try_sub(&bias_uop).expect("ilogb2k: sub failed")
}
/// Selects a replacement value depending on whether `x` is `+inf`, NaN, `-inf`
/// or anything else.
///
/// The nested selects resolve in this order:
/// 1. `x == +inf` → `inf_val`
/// 2. `x != x` (the NaN self-inequality test) → `nan_val`
/// 3. `x == -inf` → `neg_inf_val`
/// 4. otherwise → `ratio`
///
/// The infinity constants are created with [`float_const`] using the dtype of `x`.
///
/// # Panics
///
/// Panics if any compare or select operation reports an error.
pub fn lazy_map_numbers(
    x: &Arc<UOp>,
    inf_val: &Arc<UOp>,
    neg_inf_val: &Arc<UOp>,
    nan_val: &Arc<UOp>,
    ratio: &Arc<UOp>,
) -> Arc<UOp> {
    let x_dtype = x.dtype();
    let plus_infinity = float_const(&x_dtype, f64::INFINITY);
    let minus_infinity = float_const(&x_dtype, f64::NEG_INFINITY);

    // Predicates are built up front, in the same order as the selects consume them.
    let differs_from_pos_inf = x.try_cmpne(&plus_infinity).expect("lazy_map: cmpne pos_inf");
    let differs_from_neg_inf = x.try_cmpne(&minus_infinity).expect("lazy_map: cmpne neg_inf");
    let nan_mask = x.try_cmpne(x).expect("lazy_map: cmpne nan");

    // Innermost choice: regular value versus the -inf replacement.
    let finite_or_neg_inf =
        UOp::try_where(differs_from_neg_inf, Arc::clone(ratio), Arc::clone(neg_inf_val))
            .expect("lazy_map: where1");
    // NaN takes precedence over the innermost choice.
    let nan_resolved = UOp::try_where(nan_mask, Arc::clone(nan_val), finite_or_neg_inf)
        .expect("lazy_map: where2");
    // +inf takes precedence over everything else.
    UOp::try_where(differs_from_pos_inf, nan_resolved, Arc::clone(inf_val))
        .expect("lazy_map: where3")
}