use crate::base::scalar::{Scalar, ScalarConversionError};
use alloc::{
format,
string::{String, ToString},
};
use proof_of_sql_parser::intermediate_decimal::{IntermediateDecimal, IntermediateDecimalError};
use serde::{Deserialize, Deserializer, Serialize};
use snafu::Snafu;
/// Errors that can arise while constructing, converting or rescaling decimals.
#[derive(Snafu, Debug, Eq, PartialEq)]
pub enum DecimalError {
    /// The decimal's format or value is invalid, e.g. it could not be converted into a scalar.
    #[snafu(display("Invalid decimal format or value: {error}"))]
    InvalidDecimal {
        /// Description of the underlying failure.
        error: String,
    },
    /// The requested precision is outside the supported range.
    #[snafu(display("Decimal precision is not valid: {error}"))]
    InvalidPrecision {
        /// Description of why the precision was rejected.
        error: String,
    },
    /// The requested scale is not valid.
    #[snafu(display("Decimal scale is not valid: {scale}"))]
    InvalidScale {
        /// The offending scale value.
        scale: i16,
    },
    /// The operation would require rounding (for example scaling a decimal down),
    /// which is not supported.
    #[snafu(display("Unsupported operation: cannot round decimal: {error}"))]
    RoundingError {
        /// Description of why rounding would have been required.
        error: String,
    },
    /// An error produced while converting an [`IntermediateDecimal`]; its `Display`
    /// output is forwarded from the wrapped source error.
    #[snafu(transparent)]
    IntermediateDecimalConversionError {
        /// The wrapped conversion error.
        source: IntermediateDecimalError,
    },
}
/// Shorthand for a `Result` whose error type is [`DecimalError`].
pub type DecimalResult<T> = Result<T, DecimalError>;

/// Lets a [`DecimalError`] be turned directly into its human-readable message.
impl From<DecimalError> for String {
    fn from(err: DecimalError) -> Self {
        // Goes through the `Display` impl generated by snafu.
        ToString::to_string(&err)
    }
}
/// The number of decimal digits a decimal value may hold.
///
/// The inner field is private; values are built with [`Precision::new`], which
/// rejects `0` and anything above [`MAX_SUPPORTED_PRECISION`]. Deserialization
/// goes through the same check.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)]
pub struct Precision(u8);
/// Largest precision accepted by [`Precision::new`].
// NOTE(review): presumably chosen so any integer of this many digits fits in the
// scalar field used by the proof system — confirm before changing.
pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75;
impl Precision {
    /// Creates a precision, validating that it lies in `1..=MAX_SUPPORTED_PRECISION`.
    ///
    /// # Errors
    /// Returns [`DecimalError::InvalidPrecision`] if `value` is `0` or greater than
    /// [`MAX_SUPPORTED_PRECISION`].
    pub fn new(value: u8) -> Result<Self, DecimalError> {
        if value == 0 {
            // A zero precision is below the minimum, not above the maximum, so it
            // gets its own message rather than the "exceeds max" one.
            Err(DecimalError::InvalidPrecision {
                error: "Failed to parse precision. Value of 0 is below the minimum supported precision of 1"
                    .to_string(),
            })
        } else if value > MAX_SUPPORTED_PRECISION {
            Err(DecimalError::InvalidPrecision {
                error: format!(
                    "Failed to parse precision. Value of {} exceeds max supported precision of {}",
                    value, MAX_SUPPORTED_PRECISION
                ),
            })
        } else {
            Ok(Precision(value))
        }
    }

    /// Returns the raw number of decimal digits.
    pub fn value(&self) -> u8 {
        self.0
    }
}
impl<'de> Deserialize<'de> for Precision {
fn deserialize<D>(deserializer: D) -> Result<Precision, D::Error>
where
D: Deserializer<'de>,
{
let value = u8::deserialize(deserializer)?;
Precision::new(value).map_err(serde::de::Error::custom)
}
}
/// A fixed-point decimal stored as an unscaled scalar plus its precision and scale.
///
/// The stored `value` is the number multiplied by `10^scale`; for example `123.45`
/// at scale `2` is stored as `12345`.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize)]
pub struct Decimal<S: Scalar> {
    /// The unscaled value, i.e. the represented number times `10^scale`.
    pub value: S,
    /// Maximum number of decimal digits the value may have.
    pub precision: Precision,
    /// Number of fractional digits (the `i8` type also permits negative scales).
    pub scale: i8,
}
impl<S: Scalar> Decimal<S> {
    /// Creates a decimal from an already-scaled scalar, its precision and its scale.
    ///
    /// No validation is performed: the caller is responsible for `value` being
    /// consistent with `precision` and `scale`.
    pub fn new(value: S, precision: Precision, scale: i8) -> Self {
        Decimal {
            value,
            precision,
            scale,
        }
    }

    /// Rescales this decimal to `new_precision` and `new_scale`.
    ///
    /// Only scaling up is supported: the stored value is multiplied by
    /// `10^(new_scale - self.scale)`.
    ///
    /// # Errors
    /// Returns [`DecimalError::RoundingError`] if `new_scale` is smaller than the
    /// current scale (that would require rounding), or if `new_precision` cannot hold
    /// the current digits plus the added fractional digits.
    pub fn with_precision_and_scale(
        &self,
        new_precision: Precision,
        new_scale: i8,
    ) -> DecimalResult<Decimal<S>> {
        // Compute in i16: `new_scale - self.scale` can overflow i8 (e.g. 100 - (-100)).
        let scale_factor = i16::from(new_scale) - i16::from(self.scale);
        if scale_factor < 0 {
            return Err(DecimalError::RoundingError {
                error: "Scale factor must be non-negative".to_string(),
            });
        }
        if i16::from(new_precision.value()) < i16::from(self.precision.value()) + scale_factor {
            return Err(DecimalError::RoundingError {
                error: "New precision is too small to hold the rescaled value".to_string(),
            });
        }
        // The precision check bounds `scale_factor` by the precision range, so this
        // conversion cannot fail in practice; map it to an error rather than panic.
        let scale_factor = i8::try_from(scale_factor)
            .map_err(|_| DecimalError::InvalidScale { scale: scale_factor })?;
        let scaled_value = scale_scalar(self.value, scale_factor)?;
        Ok(Decimal::new(scaled_value, new_precision, new_scale))
    }

    /// Represents the integer `value` as a decimal with `scale` fractional digits.
    ///
    /// The stored scalar is `value * 10^scale`.
    ///
    /// # Errors
    /// Returns [`DecimalError::RoundingError`] if `precision` is below 19 (the digit
    /// count of `i64::MAX`), if `scale` is negative, or if `precision` cannot hold
    /// 19 integer digits plus `scale` fractional digits.
    pub fn from_i64(value: i64, precision: Precision, scale: i8) -> DecimalResult<Self> {
        // i64::MAX = 9_223_372_036_854_775_807 has 19 decimal digits.
        const MINIMAL_PRECISION: u8 = 19;
        Self::check_integer_fits(precision, scale, MINIMAL_PRECISION)?;
        let scaled_value = scale_scalar(S::from(&value), scale)?;
        Ok(Decimal::new(scaled_value, precision, scale))
    }

    /// Represents the integer `value` as a decimal with `scale` fractional digits.
    ///
    /// The stored scalar is `value * 10^scale`.
    ///
    /// # Errors
    /// Returns [`DecimalError::RoundingError`] if `precision` is below 39 (the digit
    /// count of `i128::MAX`), if `scale` is negative, or if `precision` cannot hold
    /// 39 integer digits plus `scale` fractional digits.
    pub fn from_i128(value: i128, precision: Precision, scale: i8) -> DecimalResult<Self> {
        // i128::MAX = 170_141_183_460_469_231_731_687_303_715_884_105_727 has 39 digits.
        const MINIMAL_PRECISION: u8 = 39;
        Self::check_integer_fits(precision, scale, MINIMAL_PRECISION)?;
        let scaled_value = scale_scalar(S::from(&value), scale)?;
        Ok(Decimal::new(scaled_value, precision, scale))
    }

    /// Checks that an integer of up to `integer_digits` digits, given `scale`
    /// fractional digits, fits within `precision`.
    fn check_integer_fits(
        precision: Precision,
        scale: i8,
        integer_digits: u8,
    ) -> DecimalResult<()> {
        let raw_precision = precision.value();
        if raw_precision < integer_digits {
            return Err(DecimalError::RoundingError {
                error: format!("Precision must be at least {integer_digits}"),
            });
        }
        // `scale < 0` short-circuits first, so the cast only sees non-negative values;
        // widening to u16 keeps the sum from overflowing.
        if scale < 0 || u16::from(raw_precision) < u16::from(integer_digits) + scale as u16 {
            return Err(DecimalError::RoundingError {
                error: "Can not scale down a decimal".to_string(),
            });
        }
        Ok(())
    }
}
/// Rescales an [`IntermediateDecimal`] to `target_precision` / `target_scale` and
/// converts the resulting integer into a scalar.
///
/// # Errors
/// Errors from the rescaling step are propagated as
/// [`DecimalError::IntermediateDecimalConversionError`]. A failed big-integer to scalar
/// conversion becomes [`DecimalError::InvalidDecimal`].
pub(crate) fn try_into_to_scalar<S: Scalar>(
    decimal: &IntermediateDecimal,
    target_precision: Precision,
    target_scale: i8,
) -> DecimalResult<S> {
    let rescaled =
        decimal.try_into_bigint_with_precision_and_scale(target_precision.value(), target_scale)?;
    rescaled
        .try_into()
        .map_err(|conversion_error: ScalarConversionError| DecimalError::InvalidDecimal {
            error: conversion_error.to_string(),
        })
}
/// Multiplies `s` by `10^scale`.
///
/// # Errors
/// Returns [`DecimalError::RoundingError`] when `scale` is negative; only scaling up
/// is supported.
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> DecimalResult<S> {
    if scale < 0 {
        return Err(DecimalError::RoundingError {
            error: "Scale factor must be non-negative".to_string(),
        });
    }
    // With `scale == 0` the range is empty and `s` is returned unchanged.
    let ten = S::from(10);
    let mut scaled = s;
    (0..scale).for_each(|_| scaled *= ten);
    Ok(scaled)
}
#[cfg(test)]
mod scale_adjust_test {
    use super::*;
    use crate::base::scalar::Curve25519Scalar;
    use num_bigint::BigInt;

    #[test]
    fn we_cannot_scale_past_max_precision() {
        // Adding fractional digits beyond the decimal's own digit count must fail.
        let decimal = "12345678901234567890123456789012345678901234567890123456789012345678900.0"
            .parse()
            .unwrap();
        let target_scale = 5;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(decimal.value().digits() as u8).unwrap(),
            target_scale
        )
        .is_err());
    }

    #[test]
    fn we_can_match_exact_decimals_from_queries_to_db() {
        let decimal: IntermediateDecimal = "123.45".parse().unwrap();
        let target_scale = 2;
        let target_precision = 20;
        let big_int =
            decimal.try_into_bigint_with_precision_and_scale(target_precision, target_scale);
        let expected_big_int: BigInt = "12345".parse().unwrap();
        assert_eq!(big_int, Ok(expected_big_int));
    }

    #[test]
    fn we_can_match_decimals_with_negative_scale() {
        // 120.00 at scale -1 is stored as 12.
        let decimal = "120.00".parse().unwrap();
        let target_scale = -1;
        let expected = [12, 0, 0, 0];
        let result = try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale,
        )
        .unwrap();
        assert_eq!(result, Curve25519Scalar::from(expected));
    }

    #[test]
    fn we_can_match_integers_with_negative_scale() {
        // 12300 at scale -2 is stored as 123.
        let decimal = "12300".parse().unwrap();
        let target_scale = -2;
        let expected_limbs = [123, 0, 0, 0];
        let limbs = try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(decimal.value().digits() as u8).unwrap(),
            target_scale,
        )
        .unwrap();
        assert_eq!(limbs, Curve25519Scalar::from(expected_limbs));
    }

    #[test]
    fn we_can_match_negative_decimals() {
        let decimal = "-123.45".parse().unwrap();
        let target_scale = 2;
        let expected_limbs = [12345, 0, 0, 0];
        let limbs = try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(decimal.value().digits() as u8).unwrap(),
            target_scale,
        )
        .unwrap();
        assert_eq!(limbs, -Curve25519Scalar::from(expected_limbs));
    }

    #[test]
    fn we_can_match_decimals_at_extrema() {
        // Too many fractional digits for the decimal's own digit count.
        let decimal = "1234567890123456789012345678901234567890123456789012345678901234567890.0"
            .parse()
            .unwrap();
        let target_scale = 6;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(decimal.value().digits() as u8).unwrap(),
            target_scale
        )
        .is_err());

        // Fits exactly within the maximum supported precision.
        let decimal =
            "99999999999999999999999999999999999999999999999999999999999999999999999999.0"
                .parse()
                .unwrap();
        let target_scale = 1;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_ok());

        // One more digit than the maximum supported precision allows.
        let decimal =
            "999999999999999999999999999999999999999999999999999999999999999999999999999.0"
                .parse()
                .unwrap();
        let target_scale = 1;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_err());

        // Smallest positive value at the maximum scale.
        let decimal =
            "0.000000000000000000000000000000000000000000000000000000000000000000000000001"
                .parse()
                .unwrap();
        let target_scale = MAX_SUPPORTED_PRECISION as i8;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(decimal.value().digits() as u8).unwrap(),
            target_scale
        )
        .is_ok());

        let decimal = "0.1".parse().unwrap();
        let target_scale = MAX_SUPPORTED_PRECISION as i8;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_ok());

        // 1.0 needs one integer digit plus 75 fractional digits, which is too many.
        let decimal = "1.0".parse().unwrap();
        let target_scale = 75;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(decimal.value().digits() as u8).unwrap(),
            target_scale
        )
        .is_err());

        // 1.0 with 74 fractional digits fits in the maximum precision.
        let decimal = "1.0".parse().unwrap();
        let target_scale = 74;
        assert!(try_into_to_scalar::<Curve25519Scalar>(
            &decimal,
            Precision::new(MAX_SUPPORTED_PRECISION).unwrap(),
            target_scale
        )
        .is_ok());
    }
}