use std::fmt::{self, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::Neg;
use std::str::FromStr;
use ecow::{EcoString, eco_format};
use rust_decimal::MathematicalOps;
use typst_syntax::{Span, Spanned, ast};
use crate::World;
use crate::diag::{At, SourceResult, warning};
use crate::engine::Engine;
use crate::foundations::{Repr, Str, cast, func, repr, scope, ty};
// A fixed-point decimal number, wrapping `rust_decimal::Decimal`.
//
// Equality and ordering are derived from the wrapped value and therefore compare
// numerically: the tests in this file assert that `3.140 == 3.14000`.
// (Plain comments are used here on purpose so the `#[ty]`-generated docs stay as-is.)
#[ty(scope, cast)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Decimal(rust_decimal::Decimal);
impl Decimal {
pub const ZERO: Self = Self(rust_decimal::Decimal::ZERO);
pub const ONE: Self = Self(rust_decimal::Decimal::ONE);
pub const MIN: Self = Self(rust_decimal::Decimal::MIN);
pub const MAX: Self = Self(rust_decimal::Decimal::MAX);
pub const fn is_zero(self) -> bool {
self.0.is_zero()
}
pub const fn is_negative(self) -> bool {
self.0.is_sign_negative()
}
pub fn is_integer(self) -> bool {
self.0.is_integer()
}
pub fn abs(self) -> Self {
Self(self.0.abs())
}
pub fn floor(self) -> Self {
Self(self.0.floor())
}
pub fn ceil(self) -> Self {
Self(self.0.ceil())
}
pub fn trunc(self) -> Self {
Self(self.0.trunc())
}
pub fn fract(self) -> Self {
Self(self.0.fract())
}
pub fn round(self, digits: i32) -> Option<Self> {
if let Ok(positive_digits) = u32::try_from(digits) {
return Some(Self(self.0.round_dp_with_strategy(
positive_digits,
rust_decimal::RoundingStrategy::MidpointAwayFromZero,
)));
}
let mut num = self.0;
let old_scale = num.scale();
let digits = -digits as u32;
let (Ok(_), Some(ten_to_digits)) = (
num.set_scale(old_scale + digits),
rust_decimal::Decimal::TEN.checked_powi(digits as i64),
) else {
let mut zero = rust_decimal::Decimal::ZERO;
zero.set_sign_negative(self.is_negative());
return Some(Self(zero));
};
num = num.round_dp_with_strategy(
0,
rust_decimal::RoundingStrategy::MidpointAwayFromZero,
);
num.checked_mul(ten_to_digits).map(Self)
}
pub fn checked_add(self, other: Self) -> Option<Self> {
self.0.checked_add(other.0).map(Self)
}
pub fn checked_sub(self, other: Self) -> Option<Self> {
self.0.checked_sub(other.0).map(Self)
}
pub fn checked_mul(self, other: Self) -> Option<Self> {
self.0.checked_mul(other.0).map(Self)
}
pub fn checked_div(self, other: Self) -> Option<Self> {
self.0.checked_div(other.0).map(Self)
}
pub fn checked_div_euclid(self, other: Self) -> Option<Self> {
let q = self.0.checked_div(other.0)?.trunc();
if self
.0
.checked_rem(other.0)
.as_ref()
.is_some_and(rust_decimal::Decimal::is_sign_negative)
{
return if other.0.is_sign_positive() {
q.checked_sub(rust_decimal::Decimal::ONE).map(Self)
} else {
q.checked_add(rust_decimal::Decimal::ONE).map(Self)
};
}
Some(Self(q))
}
pub fn checked_rem_euclid(self, other: Self) -> Option<Self> {
let r = self.0.checked_rem(other.0)?;
Some(Self(if r.is_sign_negative() { r.checked_add(other.0.abs())? } else { r }))
}
pub fn checked_rem(self, other: Self) -> Option<Self> {
self.0.checked_rem(other.0).map(Self)
}
pub fn checked_powi(self, other: i64) -> Option<Self> {
self.0.checked_powi(other).map(Self)
}
}
// Typst-facing constructor for decimals. Plain comments are used here on purpose
// so the documentation generated by `#[func]` is not altered.
#[scope]
impl Decimal {
    // Converts a decimal, string, integer (or bool), or float into a decimal.
    // Strings may use the typographic minus sign; floats additionally trigger a
    // precision warning when they come from a float literal.
    #[func(constructor)]
    pub fn construct(
        engine: &mut Engine,
        value: Spanned<ToDecimal>,
    ) -> SourceResult<Decimal> {
        let span = value.span;
        let decimal = match value.v {
            ToDecimal::Decimal(decimal) => decimal,
            ToDecimal::Int(int) => Self::from(int),
            ToDecimal::Str(text) => {
                let normalized = text.replace(repr::MINUS_SIGN, "-");
                Self::from_str(&normalized)
                    .map_err(|_| eco_format!("invalid decimal: {text}"))
                    .at(span)?
            }
            ToDecimal::Float(float) => {
                warn_on_float_literal(engine, span);
                Self::try_from(float)
                    .map_err(|_| {
                        let shown = repr::format_float(float, None, true, "");
                        eco_format!("float is not a valid decimal: {}", shown)
                    })
                    .at(span)?
            }
        };
        Ok(decimal)
    }
}
/// Emits a warning when `span` points at a float literal in the source.
///
/// Returns `None` (and warns about nothing) whenever the span has no file,
/// the source cannot be loaded, or no syntax node is found for the span.
fn warn_on_float_literal(engine: &mut Engine, span: Span) -> Option<()> {
    let file = span.id()?;
    let source = engine.world.source(file).ok()?;
    let node = source.find(span)?;

    // Only literal floats are worth a hint; computed floats are left alone.
    if !node.is::<ast::Float>() {
        return Some(());
    }

    let literal = node.text().repr();
    engine.sink.warn(warning!(
        span,
        "creating a decimal using imprecise float literal";
        hint: "use a string in the decimal constructor to avoid loss \
               of precision: `decimal({})`",
        literal
    ));
    Some(())
}
/// Parses a decimal with rust_decimal's exact parser, which reports an error
/// rather than silently losing precision.
impl FromStr for Decimal {
    type Err = rust_decimal::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = rust_decimal::Decimal::from_str_exact(s)?;
        Ok(Self(parsed))
    }
}
impl From<i64> for Decimal {
fn from(value: i64) -> Self {
Self(rust_decimal::Decimal::from(value))
}
}
/// Converts a float into a decimal, keeping the float's binary value as closely
/// as rust_decimal's `from_f64_retain` allows. Fails when that conversion fails.
impl TryFrom<f64> for Decimal {
    type Error = ();

    fn try_from(value: f64) -> Result<Self, Self::Error> {
        match rust_decimal::Decimal::from_f64_retain(value) {
            Some(inner) => Ok(Self(inner)),
            None => Err(()),
        }
    }
}
impl TryFrom<Decimal> for f64 {
type Error = rust_decimal::Error;
fn try_from(value: Decimal) -> Result<Self, Self::Error> {
value.0.try_into()
}
}
impl TryFrom<Decimal> for i64 {
type Error = rust_decimal::Error;
fn try_from(value: Decimal) -> Result<Self, Self::Error> {
value.0.try_into()
}
}
/// Formats the decimal with a typographic minus sign (`repr::MINUS_SIGN`) for
/// values whose sign bit is set, followed by the magnitude. Formatter options
/// are forwarded to the magnitude.
impl Display for Decimal {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let magnitude = self.0.abs();
        if self.0.is_sign_negative() {
            f.write_str(repr::MINUS_SIGN)?;
        }
        Display::fmt(&magnitude, f)
    }
}
/// Represents the decimal as a constructor call with a string argument, e.g.
/// `decimal("1.5")`. The inner text uses rust_decimal's own formatting rather
/// than this type's `Display` impl.
impl Repr for Decimal {
    fn repr(&self) -> EcoString {
        let plain: EcoString = eco_format!("{}", self.0);
        eco_format!("decimal({})", plain.repr())
    }
}
/// Negates the decimal by negating the wrapped value.
impl Neg for Decimal {
    type Output = Self;

    fn neg(self) -> Self {
        Self(self.0.neg())
    }
}
/// Hashes the serialized bytes of the wrapped decimal, which include its scale.
///
/// As a result, numerically equal decimals with different scales (e.g. `3.140`
/// and `3.14000`) hash differently; the tests in this file pin that behavior.
/// This departs from the usual rule that equal values hash equally.
/// NOTE(review): presumably intentional so cached results can distinguish
/// values that display differently — confirm before changing.
impl Hash for Decimal {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let bytes = self.0.serialize();
        Hash::hash(&bytes, state);
    }
}
/// A value that the decimal constructor accepts.
pub enum ToDecimal {
    /// An existing decimal, passed through unchanged.
    Decimal(Decimal),
    /// A string to parse as a decimal.
    Str(EcoString),
    /// An integer (booleans are also cast to this variant as `0` / `1`).
    Int(i64),
    /// A float, converted with possible loss of precision.
    Float(f64),
}

// NOTE(review): the arm order is kept as-is; it presumably decides which
// conversion is tried first when a value fits several types.
cast! {
    ToDecimal,
    decimal: Decimal => Self::Decimal(decimal),
    int: i64 => Self::Int(int),
    flag: bool => Self::Int(i64::from(flag)),
    float: f64 => Self::Float(float),
    text: Str => Self::Str(text.into()),
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use typst_utils::hash128;

    use super::Decimal;

    /// Parses a decimal from test input that is known to be valid.
    fn parse(text: &str) -> Decimal {
        Decimal::from_str(text).unwrap()
    }

    #[test]
    fn test_decimals_with_equal_scales_hash_identically() {
        let (first, second) = (parse("3.14"), parse("3.14"));
        assert_eq!(first, second);
        assert_eq!(hash128(&first), hash128(&second));
    }

    #[test]
    fn test_decimals_with_different_scales_hash_differently() {
        let (first, second) = (parse("3.140"), parse("3.14000"));
        assert_eq!(first, second);
        assert_ne!(hash128(&first), hash128(&second));
    }

    /// Asserts that rounding `value` to `digits` places yields `expected`.
    #[track_caller]
    fn test_round(value: &str, digits: i32, expected: &str) {
        let rounded = parse(value).round(digits);
        assert_eq!(rounded, Some(parse(expected)));
    }

    #[test]
    fn test_decimal_positive_round() {
        let cases = [
            ("312.55553", 0, "313.00000"),
            ("312.55553", 3, "312.556"),
            ("312.5555300000", 3, "312.556"),
            ("-312.55553", 3, "-312.556"),
            ("312.55553", 28, "312.55553"),
            ("312.55553", 2341, "312.55553"),
            ("-312.55553", 2341, "-312.55553"),
        ];
        for (value, digits, expected) in cases {
            test_round(value, digits, expected);
        }
    }

    #[test]
    fn test_decimal_negative_round() {
        let cases = [
            ("4596.55553", -1, "4600"),
            ("4596.555530000000", -1, "4600"),
            ("-4596.55553", -3, "-5000"),
            ("4596.55553", -28, "0"),
            ("-4596.55553", -2341, "0"),
        ];
        for (value, digits, expected) in cases {
            test_round(value, digits, expected);
        }
        assert_eq!(Decimal::MAX.round(-1), None);
    }
}