use crate::lower::LoweredOp;
use crate::units::{UnitError, Units};
impl LoweredOp {
    /// Absolute tolerance used when deciding whether a constant exponent is an integer.
    const INTEGER_EXPONENT_TOLERANCE: f64 = 1e-9;

    /// Infers the units of this expression and verifies its dimensional consistency.
    ///
    /// `var_units[i]` supplies the units of `Self::Var(i)`. Literal and named constants
    /// are treated as dimensionless.
    ///
    /// Rules applied:
    /// - `Neg` keeps its operand's units.
    /// - `Add` / `Sub` require both operands to have identical units.
    /// - `Mul` / `Div` combine operand units via `Units::mul` / `Units::div`.
    /// - `Pow` requires a dimensionless exponent. If the base is dimensionless, the result
    ///   is dimensionless. Otherwise the exponent must be a constant (literal, named
    ///   constant, or a negation of either) whose value is a finite integer in `i32` range.
    /// - Transcendental functions (`exp`, `ln`, trigonometric, hyperbolic and their
    ///   inverses) require a dimensionless argument and produce a dimensionless result.
    ///
    /// # Errors
    ///
    /// - [`UnitError::VarIndexOutOfRange`] if a `Var` index is `>= var_units.len()`.
    /// - [`UnitError::IncompatibleAddSub`] if the operands of `Add` / `Sub` disagree.
    /// - [`UnitError::NonDimensionlessArgument`] if a `Pow` exponent or a transcendental
    ///   function argument carries units.
    /// - [`UnitError::NonRationalPower`] if a dimensioned base is raised to an exponent
    ///   that is not a constant finite integer, or if `Units::pow_int` fails.
    ///
    /// Errors from sub-expressions are propagated as-is; left operands are checked first.
    pub fn check_units(&self, var_units: &[Units]) -> Result<Units, UnitError> {
        match self {
            Self::Const(_) | Self::NamedConst(_) => Ok(Units::DIMENSIONLESS),
            Self::Var(i) => var_units
                .get(*i)
                .copied()
                .ok_or_else(|| UnitError::VarIndexOutOfRange {
                    index: *i,
                    n_vars: var_units.len(),
                }),
            Self::Neg(x) => x.check_units(var_units),
            Self::Add(a, b) | Self::Sub(a, b) => {
                let ua = a.check_units(var_units)?;
                let ub = b.check_units(var_units)?;
                if ua == ub {
                    Ok(ua)
                } else {
                    Err(UnitError::IncompatibleAddSub {
                        left: ua,
                        right: ub,
                    })
                }
            }
            Self::Mul(a, b) => Ok(a.check_units(var_units)?.mul(&b.check_units(var_units)?)),
            Self::Div(a, b) => Ok(a.check_units(var_units)?.div(&b.check_units(var_units)?)),
            Self::Pow(base, exp) => Self::check_pow_units(base, exp, var_units),
            // Listed explicitly (no catch-all) so that adding a variant to `LoweredOp`
            // is a compile error here rather than a runtime panic.
            Self::Exp(x) => Self::check_dimensionless_arg(x, "exp", var_units),
            Self::Ln(x) => Self::check_dimensionless_arg(x, "ln", var_units),
            Self::Sin(x) => Self::check_dimensionless_arg(x, "sin", var_units),
            Self::Cos(x) => Self::check_dimensionless_arg(x, "cos", var_units),
            Self::Tan(x) => Self::check_dimensionless_arg(x, "tan", var_units),
            Self::Sinh(x) => Self::check_dimensionless_arg(x, "sinh", var_units),
            Self::Cosh(x) => Self::check_dimensionless_arg(x, "cosh", var_units),
            Self::Tanh(x) => Self::check_dimensionless_arg(x, "tanh", var_units),
            Self::Arcsin(x) => Self::check_dimensionless_arg(x, "arcsin", var_units),
            Self::Arccos(x) => Self::check_dimensionless_arg(x, "arccos", var_units),
            Self::Arctan(x) => Self::check_dimensionless_arg(x, "arctan", var_units),
            Self::Arcsinh(x) => Self::check_dimensionless_arg(x, "arcsinh", var_units),
            Self::Arccosh(x) => Self::check_dimensionless_arg(x, "arccosh", var_units),
            Self::Arctanh(x) => Self::check_dimensionless_arg(x, "arctanh", var_units),
        }
    }

    /// Computes the units of `base ^ exp`, following the `Pow` rule of `check_units`.
    fn check_pow_units(base: &Self, exp: &Self, var_units: &[Units]) -> Result<Units, UnitError> {
        let base_units = base.check_units(var_units)?;
        let exp_units = exp.check_units(var_units)?;
        if !exp_units.is_dimensionless() {
            return Err(UnitError::NonDimensionlessArgument {
                op: "Pow(exponent)",
                got: exp_units,
            });
        }
        // A dimensionless base stays dimensionless under any dimensionless exponent,
        // so the exponent's value does not matter here.
        if base_units.is_dimensionless() {
            return Ok(Units::DIMENSIONLESS);
        }
        // A dimensioned base needs a statically known integer exponent; otherwise the
        // resulting units cannot be expressed with `pow_int`.
        let power = exp
            .constant_value()
            .and_then(Self::as_integer_exponent)
            .ok_or_else(|| UnitError::NonRationalPower { base_units })?;
        base_units
            .pow_int(power)
            .map_err(|_| UnitError::NonRationalPower { base_units })
    }

    /// Returns the numeric value of this node if it is a constant exponent: a literal,
    /// a named constant, or a (possibly repeated) negation of one. Returns `None` otherwise.
    fn constant_value(&self) -> Option<f64> {
        match self {
            Self::Const(n) => Some(*n),
            Self::NamedConst(nc) => Some(nc.value()),
            Self::Neg(x) => x.constant_value().map(|v| -v),
            _ => None,
        }
    }

    /// Converts `v` to an `i32` exponent when it is finite, in `i32` range, and within
    /// `INTEGER_EXPONENT_TOLERANCE` of an integer; returns `None` otherwise.
    fn as_integer_exponent(v: f64) -> Option<i32> {
        // NaN fails every comparison, so it must be rejected up front: otherwise
        // `NaN.round() as i32` saturates to 0 and the tolerance test below passes.
        if !v.is_finite() {
            return None;
        }
        let rounded = v.round();
        let near_integer = (v - rounded).abs() <= Self::INTEGER_EXPONENT_TOLERANCE;
        let in_range = rounded >= f64::from(i32::MIN) && rounded <= f64::from(i32::MAX);
        // The range check above makes this cast exact (no saturation).
        (near_integer && in_range).then(|| rounded as i32)
    }

    /// Checks that `arg` is dimensionless, as required by the transcendental function
    /// named `op`, and returns dimensionless units for the function's result.
    fn check_dimensionless_arg(
        arg: &Self,
        op: &'static str,
        var_units: &[Units],
    ) -> Result<Units, UnitError> {
        let got = arg.check_units(var_units)?;
        if got.is_dimensionless() {
            Ok(Units::DIMENSIONLESS)
        } else {
            Err(UnitError::NonDimensionlessArgument { op, got })
        }
    }
}