use crate::base::{
math::BigDecimalExt,
scalar::{Scalar, ScalarConversionError},
};
use alloc::string::{String, ToString};
use bigdecimal::{BigDecimal, ParseBigDecimalError};
use serde::{Deserialize, Deserializer, Serialize};
use snafu::Snafu;
/// Errors raised while handling an intermediate [`BigDecimal`] value: parsing it, checking
/// that it fits a target type, or converting it to an integer representation.
#[derive(Snafu, Debug, PartialEq)]
pub enum IntermediateDecimalError {
    /// The input could not be parsed as a `BigDecimal`; displays the parser's own message.
    #[snafu(display("{error}"))]
    ParseError {
        /// The parse error reported by the `bigdecimal` crate.
        error: ParseBigDecimalError,
    },
    /// The value does not fit in the requested target type.
    #[snafu(display("Value out of range for target type"))]
    OutOfRange,
    /// The conversion would discard a non-zero fractional part.
    #[snafu(display("Fractional part of decimal is non-zero"))]
    LossyCast,
    /// Converting the decimal to an integer failed.
    #[snafu(display("Conversion to integer failed"))]
    ConversionFailure,
}
// `Eq` is implemented by hand rather than derived. Presumably this is because
// `ParseBigDecimalError` only implements `PartialEq`. NOTE(review): confirm. The derived
// `PartialEq` above is treated as a full equivalence relation.
impl Eq for IntermediateDecimalError {}
/// Errors that can occur when validating, constructing, or converting decimal values.
#[derive(Snafu, Debug, Eq, PartialEq)]
pub enum DecimalError {
    /// The decimal's format or value is invalid. It is also used when a rescaled integer
    /// cannot be converted into a scalar.
    #[snafu(display("Invalid decimal format or value: {error}"))]
    InvalidDecimal {
        /// Human-readable description of the underlying failure.
        error: String,
    },
    /// A precision value was rejected (for example `0`, or above the supported maximum).
    #[snafu(display("Decimal precision is not valid: {error}"))]
    InvalidPrecision {
        /// The rejected precision value, rendered as a string.
        error: String,
    },
    /// A scale value was rejected.
    #[snafu(display("Decimal scale is not valid: {scale}"))]
    InvalidScale {
        /// Presumably the rejected scale value rendered as a string. TODO confirm with callers.
        scale: String,
    },
    /// A decimal could not be rounded; reported as an unsupported operation.
    #[snafu(display("Unsupported operation: cannot round decimal: {error}"))]
    RoundingError {
        /// Human-readable description of why rounding failed.
        error: String,
    },
    /// Wraps an [`IntermediateDecimalError`]. Being snafu-`transparent`, its `Display` and
    /// `source()` delegate to the wrapped error, and a `From` impl lets `?` convert
    /// intermediate errors directly.
    #[snafu(transparent)]
    IntermediateDecimalConversionError {
        /// The wrapped intermediate error.
        source: IntermediateDecimalError,
    },
}
/// Convenience alias for results whose error type is [`DecimalError`].
pub type DecimalResult<T> = Result<T, DecimalError>;
/// Converts a [`DecimalError`] into the message produced by its `Display` implementation.
impl From<DecimalError> for String {
    fn from(err: DecimalError) -> String {
        ToString::to_string(&err)
    }
}
/// A validated decimal precision in the range `1..=MAX_SUPPORTED_PRECISION`.
///
/// The inner value is private. Outside this module it can only be obtained through
/// [`Precision::new`] or the conversions that delegate to it.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)]
pub struct Precision(u8);

/// The largest precision value accepted by [`Precision::new`].
pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75;

impl Precision {
    /// Validates `value` and wraps it as a `Precision`.
    ///
    /// # Errors
    /// Returns [`DecimalError::InvalidPrecision`] if `value` is `0` or greater than
    /// `MAX_SUPPORTED_PRECISION`. The error carries `value` rendered as a string.
    pub fn new(value: u8) -> Result<Self, DecimalError> {
        if (1..=MAX_SUPPORTED_PRECISION).contains(&value) {
            Ok(Self(value))
        } else {
            Err(DecimalError::InvalidPrecision {
                error: value.to_string(),
            })
        }
    }

    /// Returns the raw precision as a `u8`.
    #[must_use]
    pub fn value(&self) -> u8 {
        let Self(inner) = *self;
        inner
    }
}
/// Narrows a `u64` to `u8` and then validates it with [`Precision::new`].
///
/// Values that do not fit in a `u8` are reported as [`DecimalError::InvalidPrecision`],
/// carrying the original `u64` rendered as a string.
impl TryFrom<u64> for Precision {
    type Error = DecimalError;

    fn try_from(value: u64) -> Result<Self, Self::Error> {
        match u8::try_from(value) {
            Ok(narrow) => Self::new(narrow),
            Err(_) => Err(DecimalError::InvalidPrecision {
                error: value.to_string(),
            }),
        }
    }
}
impl<'de> Deserialize<'de> for Precision {
fn deserialize<D>(deserializer: D) -> Result<Precision, D::Error>
where
D: Deserializer<'de>,
{
let value = u8::deserialize(deserializer)?;
Precision::new(value).map_err(serde::de::Error::custom)
}
}
/// Rescales `d` to `target_scale` under `target_precision` and converts the resulting
/// integer into a scalar of type `S`.
///
/// For example, `123.45` at scale `2` becomes the integer `12345` before conversion.
///
/// # Errors
/// - Any error from `try_into_bigint_with_precision_and_scale`, propagated with `?` and
///   converted into [`DecimalError`].
/// - [`DecimalError::InvalidDecimal`], holding the `ScalarConversionError` message, if the
///   integer cannot be represented as `S`.
pub(crate) fn try_convert_intermediate_decimal_to_scalar<S: Scalar>(
    d: &BigDecimal,
    target_precision: Precision,
    target_scale: i8,
) -> DecimalResult<S> {
    let unscaled =
        d.try_into_bigint_with_precision_and_scale(target_precision.value(), target_scale)?;
    unscaled
        .try_into()
        .map_err(|conversion_err: ScalarConversionError| DecimalError::InvalidDecimal {
            error: conversion_err.to_string(),
        })
}
#[cfg(test)]
mod scale_adjust_test {
    use super::*;
    use crate::base::scalar::test_scalar::TestScalar;
    use num_bigint::BigInt;

    /// Builds a [`Precision`] equal to the digit count reported by `decimal.precision()`.
    ///
    /// The count saturates at `u8::MAX` before validation. The function panics (failing the
    /// calling test) if the result is rejected by [`Precision::new`].
    fn precision_of(decimal: &BigDecimal) -> Precision {
        Precision::new(u8::try_from(decimal.precision()).unwrap_or(u8::MAX)).unwrap()
    }

    #[test]
    fn we_cannot_scale_past_max_precision() {
        let decimal: BigDecimal =
            "12345678901234567890123456789012345678901234567890123456789012345678900.0"
                .parse()
                .unwrap();
        let target_scale = 5;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            precision_of(&decimal),
            target_scale
        )
        .is_err());
    }

    #[test]
    fn we_can_match_exact_decimals_from_queries_to_db() {
        let decimal: BigDecimal = "123.45".parse().unwrap();
        let target_scale = 2;
        let target_precision = 20;
        let big_int =
            decimal.try_into_bigint_with_precision_and_scale(target_precision, target_scale);
        let expected_big_int: BigInt = "12345".parse().unwrap();
        assert_eq!(big_int, Ok(expected_big_int));
    }

    #[test]
    fn we_can_match_decimals_with_negative_scale() {
        let decimal: BigDecimal = "120.00".parse().unwrap();
        let target_scale = -1;
        let expected = [12, 0, 0, 0];
        let result = try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale,
        )
        .unwrap();
        assert_eq!(result, TestScalar::from(expected));
    }

    #[test]
    fn we_can_match_integers_with_negative_scale() {
        let decimal: BigDecimal = "12300".parse().unwrap();
        let target_scale = -2;
        let expected_limbs = [123, 0, 0, 0];
        let limbs = try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            precision_of(&decimal),
            target_scale,
        )
        .unwrap();
        assert_eq!(limbs, TestScalar::from(expected_limbs));
    }

    #[test]
    fn we_can_match_negative_decimals() {
        let decimal: BigDecimal = "-123.45".parse().unwrap();
        let target_scale = 2;
        let expected_limbs = [12345, 0, 0, 0];
        let limbs = try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            precision_of(&decimal),
            target_scale,
        )
        .unwrap();
        assert_eq!(limbs, -TestScalar::from(expected_limbs));
    }

    #[test]
    fn we_can_match_decimals_at_extrema() {
        // Checked conversion instead of an `as` cast; MAX_SUPPORTED_PRECISION fits in an i8.
        let max_scale = i8::try_from(MAX_SUPPORTED_PRECISION).unwrap();

        // Scaling up beyond the input's own digit count must fail.
        let decimal: BigDecimal =
            "1234567890123456789012345678901234567890123456789012345678901234567890.0"
                .parse()
                .unwrap();
        let target_scale = 6;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            precision_of(&decimal),
            target_scale
        )
        .is_err());

        // At the precision limit: fits.
        let decimal: BigDecimal =
            "99999999999999999999999999999999999999999999999999999999999999999999999999.0"
                .parse()
                .unwrap();
        let target_scale = 1;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_ok());

        // One more digit than the previous case: exceeds the limit and must fail.
        let decimal: BigDecimal =
            "999999999999999999999999999999999999999999999999999999999999999999999999999.0"
                .parse()
                .unwrap();
        let target_scale = 1;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_err());

        // The smallest positive value at the maximum scale is representable.
        let decimal: BigDecimal =
            "0.000000000000000000000000000000000000000000000000000000000000000000000000001"
                .parse()
                .unwrap();
        let target_scale = max_scale;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            precision_of(&decimal),
            target_scale
        )
        .is_ok());

        let decimal: BigDecimal = "0.1".parse().unwrap();
        let target_scale = max_scale;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_ok());

        // "1.0" scaled to the maximum scale overflows its own small precision.
        let decimal: BigDecimal = "1.0".parse().unwrap();
        let target_scale = max_scale;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            precision_of(&decimal),
            target_scale
        )
        .is_err());

        // One step below the maximum scale still fits within MAX_SUPPORTED_PRECISION.
        let decimal: BigDecimal = "1.0".parse().unwrap();
        let target_scale = max_scale - 1;
        assert!(try_convert_intermediate_decimal_to_scalar::<TestScalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_ok());
    }
}