use std::cmp;
use std::cmp::Ordering;
use az::SaturatingAs;
use crate::diag::{bail, At, HintedString, SourceResult, StrResult};
use crate::eval::ops;
use crate::foundations::{cast, func, Decimal, IntoValue, Module, Scope, Value};
use crate::layout::{Angle, Fr, Length, Ratio};
use crate::syntax::{Span, Spanned};
use crate::utils::{round_int_with_precision, round_with_precision};
pub fn module() -> Module {
let mut scope = Scope::new();
scope.define_func::<abs>();
scope.define_func::<pow>();
scope.define_func::<exp>();
scope.define_func::<sqrt>();
scope.define_func::<root>();
scope.define_func::<sin>();
scope.define_func::<cos>();
scope.define_func::<tan>();
scope.define_func::<asin>();
scope.define_func::<acos>();
scope.define_func::<atan>();
scope.define_func::<atan2>();
scope.define_func::<sinh>();
scope.define_func::<cosh>();
scope.define_func::<tanh>();
scope.define_func::<log>();
scope.define_func::<ln>();
scope.define_func::<fact>();
scope.define_func::<perm>();
scope.define_func::<binom>();
scope.define_func::<gcd>();
scope.define_func::<lcm>();
scope.define_func::<floor>();
scope.define_func::<ceil>();
scope.define_func::<trunc>();
scope.define_func::<fract>();
scope.define_func::<round>();
scope.define_func::<clamp>();
scope.define_func::<min>();
scope.define_func::<max>();
scope.define_func::<even>();
scope.define_func::<odd>();
scope.define_func::<rem>();
scope.define_func::<div_euclid>();
scope.define_func::<rem_euclid>();
scope.define_func::<quo>();
scope.define("inf", f64::INFINITY);
scope.define("pi", std::f64::consts::PI);
scope.define("tau", std::f64::consts::TAU);
scope.define("e", std::f64::consts::E);
Module::new("calc", scope)
}
#[func(title = "Absolute")]
pub fn abs(
value: ToAbs,
) -> Value {
value.0
}
/// Argument type for `abs`: wraps the absolute value of the numeric value
/// that was passed in.
///
/// The absolute value is computed during the cast below, so `abs` only has
/// to unwrap it.
pub struct ToAbs(Value);

// Each accepted type is converted to its absolute value while casting.
// Fallible cases become cast errors instead of panics:
// - `i64::MIN` has no representable absolute value (`i64::abs` panics in
//   debug builds and returns `i64::MIN` in release builds), so it is
//   reported as "the result is too large".
// - Lengths go through `try_abs`, which returns `None` when the absolute
//   value of the length cannot be taken.
// The arm order is kept as-is: `i64` is listed before `f64`.
cast! {
    ToAbs,
    v: i64 => Self(v.checked_abs().ok_or_else(too_large)?.into_value()),
    v: f64 => Self(v.abs().into_value()),
    v: Length => Self(Value::Length(v.try_abs()
        .ok_or("cannot take absolute value of this length")?)),
    v: Angle => Self(Value::Angle(v.abs())),
    v: Ratio => Self(Value::Ratio(v.abs())),
    v: Fr => Self(Value::Fraction(v.abs())),
    v: Decimal => Self(Value::Decimal(v.abs()))
}
/// Raises a value to some exponent.
///
/// Integer bases with non-negative integer exponents are computed exactly,
/// and decimal bases with integer exponents stay decimal. Every other
/// combination is evaluated in floating point. Errors for `0^0`, for
/// integer exponents outside the `i32` range, for infinite, subnormal or
/// NaN float exponents, for overflow, and for non-real (NaN) results.
#[func(title = "Power")]
pub fn pow(
    span: Span,
    base: DecNum,
    exponent: Spanned<Num>,
) -> SourceResult<DecNum> {
    let power = exponent.v;

    // Validate the exponent before doing any arithmetic.
    if power.float() == 0.0 && base.is_zero() {
        bail!(span, "zero to the power of zero is undefined")
    }
    let out_of_range = matches!(power, Num::Int(i) if i32::try_from(i).is_err());
    if out_of_range {
        bail!(exponent.span, "exponent is too large")
    }
    let abnormal = matches!(power, Num::Float(f) if !f.is_normal() && f != 0.0);
    if abnormal {
        bail!(exponent.span, "exponent may not be infinite, subnormal, or NaN")
    }

    match (base, power) {
        // Exact integer power; `b` fits in `i32` and is non-negative here.
        (DecNum::Int(a), Num::Int(b)) if b >= 0 => {
            a.checked_pow(b as u32).map(DecNum::Int).ok_or_else(too_large).at(span)
        }
        (DecNum::Decimal(a), Num::Int(b)) => {
            a.checked_powi(b).map(DecNum::Decimal).ok_or_else(too_large).at(span)
        }
        (other_base, other_power) => {
            let Some(x) = other_base.float() else {
                return Err(cant_apply_to_decimal_and_float()).at(span);
            };

            // Special-case bases e and 2 to use the dedicated functions.
            let result = match other_power {
                _ if x == std::f64::consts::E => other_power.float().exp(),
                _ if x == 2.0 => other_power.float().exp2(),
                Num::Int(n) => x.powi(n as i32),
                Num::Float(y) => x.powf(y),
            };

            if result.is_nan() {
                bail!(span, "the result is not a real number")
            }
            Ok(DecNum::Float(result))
        }
    }
}
/// Raises e to the given power.
///
/// Errors for integer exponents outside the `i32` range, for infinite,
/// subnormal or NaN float exponents, and for NaN results.
#[func(title = "Exponential")]
pub fn exp(
    span: Span,
    exponent: Spanned<Num>,
) -> SourceResult<f64> {
    let out_of_range = matches!(exponent.v, Num::Int(i) if i32::try_from(i).is_err());
    if out_of_range {
        bail!(exponent.span, "exponent is too large")
    }

    let abnormal = matches!(exponent.v, Num::Float(f) if !f.is_normal() && f != 0.0);
    if abnormal {
        bail!(exponent.span, "exponent may not be infinite, subnormal, or NaN")
    }

    let result = exponent.v.float().exp();
    if result.is_nan() {
        bail!(span, "the result is not a real number")
    }
    Ok(result)
}
/// Calculates the square root of a number; negative inputs are an error.
#[func(title = "Square Root")]
pub fn sqrt(
    value: Spanned<Num>,
) -> SourceResult<f64> {
    let x = value.v.float();
    if x < 0.0 {
        bail!(value.span, "cannot take square root of negative number");
    }
    Ok(x.sqrt())
}
/// Calculates the real nth root of a number.
///
/// Odd roots of negative numbers are negative. Even roots of negative
/// numbers and the 0th root are errors.
#[func]
pub fn root(
    radicand: f64,
    index: Spanned<i64>,
) -> SourceResult<f64> {
    let n = index.v;
    if n == 0 {
        bail!(index.span, "cannot take the 0th root of a number");
    }

    let exponent = 1.0 / n as f64;
    if radicand < 0.0 {
        if n % 2 == 0 {
            bail!(
                index.span,
                "negative numbers do not have a real nth root when n is even"
            );
        }
        // Take the root of the magnitude and restore the sign.
        return Ok(-(-radicand).powf(exponent));
    }
    Ok(radicand.powf(exponent))
}
/// Calculates the sine of an angle; plain numbers are taken as radians.
#[func(title = "Sine")]
pub fn sin(
    angle: AngleLike,
) -> f64 {
    match angle {
        AngleLike::Int(n) => f64::sin(n as f64),
        AngleLike::Float(x) => f64::sin(x),
        AngleLike::Angle(a) => a.sin(),
    }
}
/// Calculates the cosine of an angle; plain numbers are taken as radians.
#[func(title = "Cosine")]
pub fn cos(
    angle: AngleLike,
) -> f64 {
    match angle {
        AngleLike::Int(n) => f64::cos(n as f64),
        AngleLike::Float(x) => f64::cos(x),
        AngleLike::Angle(a) => a.cos(),
    }
}
/// Calculates the tangent of an angle; plain numbers are taken as radians.
#[func(title = "Tangent")]
pub fn tan(
    angle: AngleLike,
) -> f64 {
    match angle {
        AngleLike::Int(n) => f64::tan(n as f64),
        AngleLike::Float(x) => f64::tan(x),
        AngleLike::Angle(a) => a.tan(),
    }
}
/// Calculates the arcsine of a number in `[-1, 1]`.
#[func(title = "Arcsine")]
pub fn asin(
    value: Spanned<Num>,
) -> SourceResult<Angle> {
    let x = value.v.float();
    // `|x| > 1` is equivalent to `x < -1 || x > 1`, including for NaN.
    if x.abs() > 1.0 {
        bail!(value.span, "value must be between -1 and 1");
    }
    Ok(Angle::rad(x.asin()))
}
/// Calculates the arccosine of a number in `[-1, 1]`.
#[func(title = "Arccosine")]
pub fn acos(
    value: Spanned<Num>,
) -> SourceResult<Angle> {
    let x = value.v.float();
    // `|x| > 1` is equivalent to `x < -1 || x > 1`, including for NaN.
    if x.abs() > 1.0 {
        bail!(value.span, "value must be between -1 and 1");
    }
    Ok(Angle::rad(x.acos()))
}
/// Calculates the arctangent of a number.
#[func(title = "Arctangent")]
pub fn atan(
    value: Num,
) -> Angle {
    let radians = f64::atan(value.float());
    Angle::rad(radians)
}
/// Calculates the four-quadrant arctangent of a coordinate.
///
/// Note the argument order: `x` comes first, but the computation is
/// `atan2(y, x)`.
#[func(title = "Four-quadrant Arctangent")]
pub fn atan2(
    x: Num,
    y: Num,
) -> Angle {
    let (horizontal, vertical) = (x.float(), y.float());
    Angle::rad(vertical.atan2(horizontal))
}
/// Calculates the hyperbolic sine of a number.
#[func(title = "Hyperbolic Sine")]
pub fn sinh(
    value: f64,
) -> f64 {
    f64::sinh(value)
}
/// Calculates the hyperbolic cosine of a number.
#[func(title = "Hyperbolic Cosine")]
pub fn cosh(
    value: f64,
) -> f64 {
    f64::cosh(value)
}
/// Calculates the hyperbolic tangent of a number.
#[func(title = "Hyperbolic Tangent")]
pub fn tanh(
    value: f64,
) -> f64 {
    f64::tanh(value)
}
/// Calculates the logarithm of a number in the given base (default 10).
///
/// Errors for non-positive values, for bases that are zero, NaN, infinite
/// or subnormal, and when the result is infinite or NaN.
#[func(title = "Logarithm")]
pub fn log(
    span: Span,
    value: Spanned<Num>,
    #[named]
    #[default(Spanned::new(10.0, Span::detached()))]
    base: Spanned<f64>,
) -> SourceResult<f64> {
    let number = value.v.float();
    if number <= 0.0 {
        bail!(value.span, "value must be strictly positive")
    }

    let b = base.v;
    if !b.is_normal() {
        bail!(base.span, "base may not be zero, NaN, infinite, or subnormal")
    }

    // Prefer the dedicated functions for the common bases.
    let result = match b {
        _ if b == std::f64::consts::E => number.ln(),
        _ if b == 2.0 => number.log2(),
        _ if b == 10.0 => number.log10(),
        _ => number.log(b),
    };

    // `!is_finite()` covers both infinite and NaN results.
    if !result.is_finite() {
        bail!(span, "the result is not a real number")
    }
    Ok(result)
}
/// Calculates the natural logarithm of a number.
///
/// Errors if the value is not strictly positive or if the result is not a
/// finite real number.
#[func(title = "Natural Logarithm")]
pub fn ln(
    span: Span,
    value: Spanned<Num>,
) -> SourceResult<f64> {
    let number = value.v.float();
    if number <= 0.0 {
        bail!(value.span, "value must be strictly positive")
    }

    let result = number.ln();
    // `number` is strictly positive here, so the logarithm can never be
    // `-inf`: even the smallest subnormal has a finite logarithm. It can
    // only be `+inf` for an infinite input, or NaN for a NaN input. Report
    // both with the same message `log` uses.
    if result.is_infinite() || result.is_nan() {
        bail!(span, "the result is not a real number")
    }

    Ok(result)
}
#[func(title = "Factorial")]
pub fn fact(
number: u64,
) -> StrResult<i64> {
Ok(fact_impl(1, number).ok_or_else(too_large)?)
}
#[func(title = "Permutation")]
pub fn perm(
base: u64,
numbers: u64,
) -> StrResult<i64> {
if base < numbers {
return Ok(0);
}
Ok(fact_impl(base - numbers + 1, base).ok_or_else(too_large)?)
}
/// Computes the product `start * (start + 1) * ... * end`, treating a
/// `start` of zero as one.
///
/// By convention, the result is `0` when `start > end + 1`. The empty
/// product (`start == end + 1`) is `1`. Returns `None` when the product
/// does not fit in an `i64`.
fn fact_impl(start: u64, end: u64) -> Option<i64> {
    // By convention. `saturating_add` keeps `end == u64::MAX` from
    // overflowing; a plain `end + 1` panics in debug builds and wraps to 0
    // in release builds, which wrongly returned `Some(0)`.
    if end.saturating_add(1) < start {
        return Some(0);
    }

    let real_start: u64 = cmp::max(1, start);
    let mut count: u64 = 1;
    // Terminates quickly even for huge ranges: the product overflows
    // within a few dozen factors and `?` bails out.
    for i in real_start..=end {
        count = count.checked_mul(i)?;
    }

    count.try_into().ok()
}
#[func(title = "Binomial")]
pub fn binom(
n: u64,
k: u64,
) -> StrResult<i64> {
Ok(binom_impl(n, k).ok_or_else(too_large)?)
}
/// Computes `n choose k`, returning `Some(0)` when `k > n` and `None` when
/// the result does not fit in an `i64`.
fn binom_impl(n: u64, k: u64) -> Option<i64> {
    if k > n {
        return Some(0);
    }

    // By symmetry, C(n, k) == C(n, n - k); iterate over the smaller one.
    let real_k = cmp::min(n - k, k);
    if real_k == 0 {
        return Some(1);
    }

    // After step `i`, `result` holds C(n, i + 1). The intermediate product
    // C(n, i) * (n - i) can exceed `u64::MAX` even when the final
    // coefficient fits in an `i64` (e.g. C(66, 33)), so it is computed in
    // `u128`. Because `result` is kept <= `i64::MAX` and `n - i` < 2^64, the
    // product stays below 2^127 and cannot overflow.
    const LIMIT: u128 = i64::MAX as u128;
    let mut result: u128 = 1;
    for i in 0..real_k {
        result = result.checked_mul(u128::from(n - i))? / u128::from(i + 1);
        // C(n, i) is non-decreasing in `i` for i <= n / 2, so once it leaves
        // the `i64` range the final result will too.
        if result > LIMIT {
            return None;
        }
    }

    result.try_into().ok()
}
/// Calculates the greatest common divisor of two integers.
///
/// The result is non-negative, with one exception. When the true gcd is
/// 2^63 (both inputs in `{0, i64::MIN}`, at least one being `i64::MIN`),
/// it cannot be represented and `i64::MIN` is returned.
#[func(title = "Greatest Common Divisor")]
pub fn gcd(
    a: i64,
    b: i64,
) -> i64 {
    let (mut a, mut b) = (a, b);
    while b != 0 {
        let temp = b;
        // Plain `%` panics for `i64::MIN % -1` because the matching division
        // overflows. `wrapping_rem` yields 0 there, the correct remainder.
        b = a.wrapping_rem(b);
        a = temp;
    }

    // `abs` overflows for `i64::MIN`; `wrapping_abs` returns it unchanged.
    a.wrapping_abs()
}
/// Calculates the least common multiple of two integers.
///
/// Errors with "the result is too large" if the result does not fit in an
/// `i64`.
#[func(title = "Least Common Multiple")]
pub fn lcm(
    a: i64,
    b: i64,
) -> StrResult<i64> {
    if a == b {
        // `checked_abs` fails only for `i64::MIN`, whose magnitude (2^63)
        // cannot be represented; `abs` would overflow there.
        return Ok(a.checked_abs().ok_or_else(too_large)?);
    }

    Ok(a.checked_div(gcd(a, b))
        .and_then(|quotient| quotient.checked_mul(b))
        // The product may be exactly `i64::MIN` (e.g. `lcm(i64::MIN, 1)`),
        // whose absolute value overflows; report it instead of panicking.
        .and_then(i64::checked_abs)
        .ok_or_else(too_large)?)
}
/// Rounds a number down to the nearest integer.
///
/// Integers pass through unchanged. Floats and decimals are floored and then
/// converted to `i64`; a failed conversion yields "the result is too large".
#[func]
pub fn floor(
    value: DecNum,
) -> StrResult<i64> {
    match value {
        DecNum::Float(x) => {
            let floored = x.floor();
            crate::foundations::convert_float_to_int(floored)
                .map_err(|_| too_large().into())
        }
        DecNum::Decimal(d) => {
            let floored = d.floor();
            i64::try_from(floored).map_err(|_| too_large().into())
        }
        DecNum::Int(n) => Ok(n),
    }
}
/// Rounds a number up to the nearest integer.
///
/// Integers pass through unchanged. Floats and decimals are rounded up and
/// then converted to `i64`; a failed conversion yields "the result is too
/// large".
#[func]
pub fn ceil(
    value: DecNum,
) -> StrResult<i64> {
    match value {
        DecNum::Float(x) => {
            let raised = x.ceil();
            crate::foundations::convert_float_to_int(raised)
                .map_err(|_| too_large().into())
        }
        DecNum::Decimal(d) => {
            let raised = d.ceil();
            i64::try_from(raised).map_err(|_| too_large().into())
        }
        DecNum::Int(n) => Ok(n),
    }
}
/// Returns the integer part of a number, discarding the fractional part.
///
/// Integers pass through unchanged. Floats and decimals are truncated and
/// then converted to `i64`; a failed conversion yields "the result is too
/// large".
#[func(title = "Truncate")]
pub fn trunc(
    value: DecNum,
) -> StrResult<i64> {
    match value {
        DecNum::Float(x) => {
            let whole = x.trunc();
            crate::foundations::convert_float_to_int(whole)
                .map_err(|_| too_large().into())
        }
        DecNum::Decimal(d) => {
            let whole = d.trunc();
            i64::try_from(whole).map_err(|_| too_large().into())
        }
        DecNum::Int(n) => Ok(n),
    }
}
/// Returns the fractional part of a number; always `0` for integers.
#[func(title = "Fractional")]
pub fn fract(
    value: DecNum,
) -> DecNum {
    match value {
        DecNum::Float(x) => DecNum::Float(x.fract()),
        DecNum::Decimal(d) => DecNum::Decimal(d.fract()),
        DecNum::Int(_) => DecNum::Int(0),
    }
}
/// Rounds a number to `digits` digits (default 0).
///
/// `digits` is saturated to `i16` for integers and floats, and to `i32` for
/// decimals. Integer and decimal rounding can fail with "the result is too
/// large".
#[func]
pub fn round(
    value: DecNum,
    #[named]
    #[default(0)]
    digits: i64,
) -> StrResult<DecNum> {
    let rounded = match value {
        DecNum::Int(n) => {
            let precision = digits.saturating_as::<i16>();
            DecNum::Int(round_int_with_precision(n, precision).ok_or_else(too_large)?)
        }
        DecNum::Float(x) => {
            let precision = digits.saturating_as::<i16>();
            DecNum::Float(round_with_precision(x, precision))
        }
        DecNum::Decimal(d) => {
            let precision = digits.saturating_as::<i32>();
            DecNum::Decimal(d.round(precision).ok_or_else(too_large)?)
        }
    };
    Ok(rounded)
}
/// Clamps a number between a minimum and a maximum value.
///
/// Errors if `max` is smaller than `min`, if either float bound is NaN, or
/// if a decimal is mixed with a float.
#[func]
pub fn clamp(
    span: Span,
    value: DecNum,
    min: DecNum,
    max: Spanned<DecNum>,
) -> SourceResult<DecNum> {
    // Validate the bounds first. `i64::clamp` and `f64::clamp` panic when
    // `min > max`, and `f64::clamp` also panics when either bound is NaN.
    // Since NaN compares false with everything, `max < min` alone would let
    // a NaN bound through to that panic.
    //
    // Incompatible types (decimal and float) make `apply2` return `None`.
    // They are ignored here because `apply3` below reports them before any
    // clamp function runs.
    let invalid_bounds = min
        .apply2(
            max.v,
            |min, max| max < min,
            |min, max| max < min || min.is_nan() || max.is_nan(),
            |min, max| max < min,
        )
        .unwrap_or(false);
    if invalid_bounds {
        bail!(max.span, "max must be greater than or equal to min")
    }

    value
        .apply3(min, max.v, i64::clamp, f64::clamp, Decimal::clamp)
        .ok_or_else(cant_apply_to_decimal_and_float)
        .at(span)
}
/// Determines the minimum of a sequence of values.
#[func(title = "Minimum")]
pub fn min(
    span: Span,
    #[variadic]
    values: Vec<Spanned<Value>>,
) -> SourceResult<Value> {
    // A later value replaces the running extremum only if it compares less.
    let goal = Ordering::Less;
    minmax(span, values, goal)
}
/// Determines the maximum of a sequence of values.
#[func(title = "Maximum")]
pub fn max(
    span: Span,
    #[variadic]
    values: Vec<Spanned<Value>>,
) -> SourceResult<Value> {
    // A later value replaces the running extremum only if it compares greater.
    let goal = Ordering::Greater;
    minmax(span, values, goal)
}
fn minmax(
span: Span,
values: Vec<Spanned<Value>>,
goal: Ordering,
) -> SourceResult<Value> {
let mut iter = values.into_iter();
let Some(Spanned { v: mut extremum, .. }) = iter.next() else {
bail!(span, "expected at least one value");
};
for Spanned { v, span } in iter {
let ordering = ops::compare(&v, &extremum).at(span)?;
if ordering == goal {
extremum = v;
}
}
Ok(extremum)
}
/// Determines whether an integer is even.
#[func]
pub fn even(
    value: i64,
) -> bool {
    // The lowest bit is clear for even numbers in two's complement,
    // including negative ones.
    (value & 1) == 0
}
/// Determines whether an integer is odd.
#[func]
pub fn odd(
    value: i64,
) -> bool {
    // The lowest bit is set for odd numbers in two's complement, including
    // negative ones.
    (value & 1) != 0
}
/// Calculates the remainder of two numbers.
///
/// Errors if the divisor is zero or if a decimal is mixed with a float.
#[func(title = "Remainder")]
pub fn rem(
    span: Span,
    dividend: DecNum,
    divisor: Spanned<DecNum>,
) -> SourceResult<DecNum> {
    if divisor.v.is_zero() {
        bail!(divisor.span, "divisor must not be zero");
    }

    dividend
        .apply2(
            divisor.v,
            // Plain `%` panics for `i64::MIN % -1` because the matching
            // division overflows. `wrapping_rem` returns 0 there, which is
            // the correct remainder; the divisor is known to be non-zero.
            |a, b| Some(DecNum::Int(a.wrapping_rem(b))),
            |a, b| Some(DecNum::Float(a % b)),
            |a, b| a.checked_rem(b).map(DecNum::Decimal),
        )
        .ok_or_else(cant_apply_to_decimal_and_float)
        .at(span)?
        .ok_or("dividend too small compared to divisor")
        .at(span)
}
/// Performs Euclidean division of two numbers.
///
/// Errors if the divisor is zero, if a decimal is mixed with a float, or if
/// the result is too large.
#[func(title = "Euclidean Division")]
pub fn div_euclid(
    span: Span,
    dividend: DecNum,
    divisor: Spanned<DecNum>,
) -> SourceResult<DecNum> {
    if divisor.v.is_zero() {
        bail!(divisor.span, "divisor must not be zero");
    }

    dividend
        .apply2(
            divisor.v,
            // `i64::div_euclid` panics when the quotient overflows
            // (`i64::MIN` divided by `-1`). `checked_div_euclid` returns
            // `None` instead, which is reported as too large below.
            |a, b| a.checked_div_euclid(b).map(DecNum::Int),
            |a, b| Some(DecNum::Float(a.div_euclid(b))),
            |a, b| a.checked_div_euclid(b).map(DecNum::Decimal),
        )
        .ok_or_else(cant_apply_to_decimal_and_float)
        .at(span)?
        .ok_or_else(too_large)
        .at(span)
}
/// Calculates the least non-negative remainder of a division.
///
/// Errors if the divisor is zero or if a decimal is mixed with a float.
#[func(title = "Euclidean Remainder")]
pub fn rem_euclid(
    span: Span,
    dividend: DecNum,
    divisor: Spanned<DecNum>,
) -> SourceResult<DecNum> {
    if divisor.v.is_zero() {
        bail!(divisor.span, "divisor must not be zero");
    }

    dividend
        .apply2(
            divisor.v,
            // `i64::rem_euclid` panics for `i64::MIN` and `-1` because the
            // underlying division overflows. `wrapping_rem_euclid` returns
            // 0 there, the correct remainder.
            |a, b| Some(DecNum::Int(a.wrapping_rem_euclid(b))),
            |a, b| Some(DecNum::Float(a.rem_euclid(b))),
            |a, b| a.checked_rem_euclid(b).map(DecNum::Decimal),
        )
        .ok_or_else(cant_apply_to_decimal_and_float)
        .at(span)?
        .ok_or("dividend too small compared to divisor")
        .at(span)
}
/// Calculates the quotient (floored division) of two numbers.
///
/// Always returns an integer. Errors if the divisor is zero, if a decimal is
/// mixed with a float, or if the result does not fit in a 64-bit integer.
#[func(title = "Quotient")]
pub fn quo(
    span: Span,
    dividend: DecNum,
    divisor: Spanned<DecNum>,
) -> SourceResult<i64> {
    if divisor.v.is_zero() {
        bail!(divisor.span, "divisor must not be zero");
    }

    let divided = dividend
        .apply2(
            divisor.v,
            |a, b| {
                // `checked_div` turns the overflowing `i64::MIN / -1` into
                // `None` (reported as too large below) instead of a panic.
                let truncated = a.checked_div(b)?;
                // Integer `/` rounds toward zero, so the `floor` below is a
                // no-op for ints. Without a correction, -7 / 2 would give -3
                // instead of floor(-3.5) = -4. Step down by one for an
                // inexact quotient of operands with opposite signs, so ints
                // agree with the float and decimal paths. The decrement
                // cannot overflow: it only happens when |b| >= 2.
                let inexact = a % b != 0;
                let opposite_signs = (a < 0) != (b < 0);
                let floored =
                    if inexact && opposite_signs { truncated - 1 } else { truncated };
                Some(DecNum::Int(floored))
            },
            |a, b| Some(DecNum::Float(a / b)),
            |a, b| a.checked_div(b).map(DecNum::Decimal),
        )
        .ok_or_else(cant_apply_to_decimal_and_float)
        .at(span)?
        .ok_or_else(too_large)
        .at(span)?;

    floor(divided).at(span)
}
/// A plain number: either an integer or a float.
#[derive(Debug, Copy, Clone)]
pub enum Num {
    Int(i64),
    Float(f64),
}

impl Num {
    /// Returns the value as an `f64`; integers are converted with `as`, so
    /// very large magnitudes round to the nearest representable float.
    fn float(self) -> f64 {
        match self {
            Num::Float(x) => x,
            Num::Int(n) => n as f64,
        }
    }
}
// Conversions between `Num` and Typst values.
// - `self =>` turns a `Num` back into an int or float value.
// - When casting from a value, the `i64` arm is listed before the `f64` arm.
//   NOTE(review): presumably the arms are tried in order so that integers
//   stay `Num::Int` rather than being widened to floats; confirm against
//   the `cast!` macro before reordering.
cast! {
Num,
self => match self {
Self::Int(v) => v.into_value(),
Self::Float(v) => v.into_value(),
},
v: i64 => Self::Int(v),
v: f64 => Self::Float(v),
}
/// A number that is an integer, a float, or a decimal.
///
/// Mixed operations promote integers to the other operand's type. Decimals
/// and floats are never mixed implicitly.
#[derive(Debug, Copy, Clone)]
pub enum DecNum {
    Int(i64),
    Float(f64),
    Decimal(Decimal),
}

impl DecNum {
    /// Whether the number equals zero (`-0.0` counts as zero for floats).
    fn is_zero(self) -> bool {
        match self {
            Self::Int(n) => n == 0,
            Self::Float(x) => x == 0.0,
            Self::Decimal(d) => d.is_zero(),
        }
    }

    /// Converts to `f64`, refusing decimals (`None`) so that precision is
    /// never lost implicitly.
    fn float(self) -> Option<f64> {
        match self {
            Self::Decimal(_) => None,
            Self::Float(x) => Some(x),
            Self::Int(n) => Some(n as f64),
        }
    }

    /// Converts to `Decimal`, refusing floats (`None`).
    fn decimal(self) -> Option<Decimal> {
        match self {
            Self::Float(_) => None,
            Self::Decimal(d) => Some(d),
            Self::Int(n) => Some(Decimal::from(n)),
        }
    }

    /// Applies a binary operation, picking the implementation by operand
    /// types:
    /// - two integers use `int`;
    /// - any float uses `float`, with integers converted to `f64`;
    /// - otherwise (a decimal with a decimal or an integer) `decimal` is
    ///   used, with integers converted via `Decimal::from`.
    ///
    /// Returns `None` when a decimal meets a float.
    fn apply2<T>(
        self,
        other: Self,
        int: impl FnOnce(i64, i64) -> T,
        float: impl FnOnce(f64, f64) -> T,
        decimal: impl FnOnce(Decimal, Decimal) -> T,
    ) -> Option<T> {
        match (self, other) {
            (Self::Int(a), Self::Int(b)) => Some(int(a, b)),
            (Self::Float(_), _) | (_, Self::Float(_)) => {
                let (a, b) = (self.float()?, other.float()?);
                Some(float(a, b))
            }
            _ => {
                let (a, b) = (self.decimal()?, other.decimal()?);
                Some(decimal(a, b))
            }
        }
    }

    /// Applies a ternary operation, picking the implementation by operand
    /// types:
    /// - three integers use `int`;
    /// - any decimal makes all operands decimals;
    /// - otherwise all operands become floats.
    ///
    /// Returns `None` when a decimal meets a float.
    fn apply3(
        self,
        other: Self,
        third: Self,
        int: impl FnOnce(i64, i64, i64) -> i64,
        float: impl FnOnce(f64, f64, f64) -> f64,
        decimal: impl FnOnce(Decimal, Decimal, Decimal) -> Decimal,
    ) -> Option<Self> {
        let has_decimal =
            [self, other, third].iter().any(|n| matches!(n, Self::Decimal(_)));

        match (self, other, third) {
            (Self::Int(a), Self::Int(b), Self::Int(c)) => Some(Self::Int(int(a, b, c))),
            _ if has_decimal => {
                let (a, b, c) = (self.decimal()?, other.decimal()?, third.decimal()?);
                Some(Self::Decimal(decimal(a, b, c)))
            }
            _ => {
                let (a, b, c) = (self.float()?, other.float()?, third.float()?);
                Some(Self::Float(float(a, b, c)))
            }
        }
    }
}
// Conversions between `DecNum` and Typst values.
// - `self =>` turns each variant back into the matching value kind.
// - When casting from a value, `i64` is listed before `f64`.
//   NOTE(review): presumably the arms are tried in order so that integers
//   stay `DecNum::Int`; confirm against the `cast!` macro before reordering.
cast! {
DecNum,
self => match self {
Self::Int(v) => v.into_value(),
Self::Float(v) => v.into_value(),
Self::Decimal(v) => v.into_value(),
},
v: i64 => Self::Int(v),
v: f64 => Self::Float(v),
v: Decimal => Self::Decimal(v),
}
/// Argument of the trigonometric functions: an explicit angle or a plain
/// number, which those functions pass straight to `f64`'s methods (i.e.
/// treat as radians).
pub enum AngleLike {
    /// An explicit angle value.
    Angle(Angle),
    /// An integer, interpreted as radians.
    Int(i64),
    /// A float, interpreted as radians.
    Float(f64),
}
// Casting from a Typst value into `AngleLike`: integers, floats and angles
// are accepted. There is no `self =>` arm, so no conversion back into a
// value is declared here.
// NOTE(review): `i64` is listed before `f64`, presumably so that integers
// stay `AngleLike::Int`; confirm against the `cast!` macro before reordering.
cast! {
AngleLike,
v: i64 => Self::Int(v),
v: f64 => Self::Float(v),
v: Angle => Self::Angle(v),
}
/// Shared error message for results that overflow or do not fit the target
/// numeric type.
#[cold]
fn too_large() -> &'static str {
    const MESSAGE: &str = "the result is too large";
    MESSAGE
}
/// Error for operations that would mix a decimal with a float, with a hint
/// on how to opt into the lossy conversion explicitly.
#[cold]
fn cant_apply_to_decimal_and_float() -> HintedString {
    let message = "cannot apply this operation to a decimal and a float";
    let hint = "if loss of precision is acceptable, explicitly cast the \
                decimal to a float with `float(value)`";
    HintedString::new(message.into()).with_hint(hint)
}