use crate::{
base::{
math::decimal::DecimalError::{
IntermediateDecimalConversionError, InvalidPrecision, RoundingError,
},
scalar::Scalar,
},
sql::parse::{
ConversionError::{self, DecimalConversionError},
ConversionResult,
},
};
use proof_of_sql_parser::intermediate_decimal::{IntermediateDecimal, IntermediateDecimalError};
use serde::{Deserialize, Deserializer, Serialize};
use thiserror::Error;
/// Errors that can occur when parsing, validating, or rescaling decimals.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum DecimalError {
    /// The decimal's textual form or value is invalid (not produced in this module;
    /// presumably raised by callers that parse decimal literals — confirm).
    #[error("Invalid decimal format or value: {0}")]
    InvalidDecimal(String),
    /// A precision outside the supported range was requested (see `Precision::new`).
    #[error("Decimal precision is not valid: {0}")]
    InvalidPrecision(String),
    /// A scale value is not valid; presumably carries the offending scale — confirm at
    /// the construction sites.
    #[error("Decimal scale is not valid: {0}")]
    InvalidScale(i16),
    /// The requested conversion would need rounding or more digits than allowed, e.g. a
    /// negative scale factor or a target precision that is too small.
    #[error("Unsupported operation: cannot round decimal: {0}")]
    RoundingError(String),
    /// Wraps an error reported by `proof_of_sql_parser`'s `IntermediateDecimal`.
    #[error("Intermediate decimal conversion error: {0}")]
    IntermediateDecimalConversionError(IntermediateDecimalError),
}
impl From<IntermediateDecimalError> for ConversionError {
    /// Lifts a parser-level decimal error into a `ConversionError`, so `?` can be used on
    /// `IntermediateDecimal` operations inside functions returning `ConversionResult`.
    fn from(err: IntermediateDecimalError) -> Self {
        let wrapped = IntermediateDecimalConversionError(err);
        DecimalConversionError(wrapped)
    }
}
/// Total number of decimal digits a decimal value may hold.
///
/// The inner `u8` is private to this module; construct values with `Precision::new`,
/// which rejects 0 and anything above `MAX_SUPPORTED_PRECISION`.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)]
pub struct Precision(u8);
/// Largest precision accepted by `Precision::new`.
// NOTE(review): presumably chosen so every 75-digit value fits in the scalar field
// (Curve25519's order is ~7.2e75) — confirm before changing.
pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75;
impl Precision {
    /// Creates a validated precision.
    ///
    /// # Errors
    ///
    /// Returns `DecimalError::InvalidPrecision` wrapped in `DecimalConversionError` if
    /// `value` is 0 or exceeds `MAX_SUPPORTED_PRECISION`.
    pub fn new(value: u8) -> Result<Self, ConversionError> {
        if value == 0 {
            // A decimal must be able to hold at least one digit. Reported separately so the
            // message does not claim that 0 "exceeds" the maximum.
            return Err(DecimalConversionError(InvalidPrecision(
                "Failed to parse precision. Value of 0 is below the minimum supported precision of 1"
                    .to_string(),
            )));
        }
        if value > MAX_SUPPORTED_PRECISION {
            return Err(DecimalConversionError(InvalidPrecision(format!(
                "Failed to parse precision. Value of {} exceeds max supported precision of {}",
                value, MAX_SUPPORTED_PRECISION
            ))));
        }
        Ok(Precision(value))
    }

    /// Returns the raw precision (number of decimal digits).
    pub fn value(&self) -> u8 {
        self.0
    }
}
impl<'de> Deserialize<'de> for Precision {
fn deserialize<D>(deserializer: D) -> Result<Precision, D::Error>
where
D: Deserializer<'de>,
{
let value = u8::deserialize(deserializer)?;
Precision::new(value).map_err(serde::de::Error::custom)
}
}
/// A fixed-point decimal stored as a scalar.
///
/// The represented number is `value * 10^(-scale)`; for example, `123.45` at scale 2 is
/// stored as `12345`.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize)]
pub struct Decimal<S: Scalar> {
    /// The unscaled value (the digits without the decimal point).
    pub value: S,
    /// Maximum number of decimal digits `value` may have.
    pub precision: Precision,
    /// Number of digits after the decimal point.
    pub scale: i8,
}
impl<S: Scalar> Decimal<S> {
    /// Creates a decimal from a raw scalar `value`, `precision`, and `scale`.
    ///
    /// No validation is performed; the caller must ensure `value` is consistent with
    /// `precision` and `scale`.
    pub fn new(value: S, precision: Precision, scale: i8) -> Self {
        Decimal {
            value,
            precision,
            scale,
        }
    }

    /// Returns this decimal rescaled to `new_precision` and `new_scale`.
    ///
    /// The stored value is multiplied by `10^(new_scale - self.scale)`, so the represented
    /// number is unchanged.
    ///
    /// # Errors
    ///
    /// Returns `DecimalError::RoundingError` (wrapped in `DecimalConversionError`) if
    /// `new_scale` is smaller than the current scale (scaling down would require rounding),
    /// or if `new_precision` is smaller than the current precision plus the added digits.
    pub fn with_precision_and_scale(
        &self,
        new_precision: Precision,
        new_scale: i8,
    ) -> ConversionResult<Decimal<S>> {
        // Widen to i16 first: the difference of two `i8` scales (e.g. 100 - (-100)) can
        // overflow `i8`, which panics in debug builds and wraps in release builds.
        let scale_factor = i16::from(new_scale) - i16::from(self.scale);
        if scale_factor < 0 {
            return Err(DecimalConversionError(RoundingError(
                "Scale factor must be non-negative".to_string(),
            )));
        }
        if i16::from(new_precision.value()) < i16::from(self.precision.value()) + scale_factor {
            return Err(DecimalConversionError(RoundingError(
                "New precision is too small to hold the rescaled value".to_string(),
            )));
        }
        // The factor is bounded by the new precision, so this cannot fail for precisions
        // created through `Precision::new`; avoid a silently truncating `as` cast anyway.
        let scale_factor = i8::try_from(scale_factor).map_err(|_| {
            DecimalConversionError(RoundingError("Scale factor is out of range".to_string()))
        })?;
        let scaled_value = scale_scalar(self.value, scale_factor)?;
        Ok(Decimal::new(scaled_value, new_precision, new_scale))
    }

    /// Builds a decimal from an `i64`, multiplying it by `10^scale` so that the decimal
    /// represents the same integer.
    ///
    /// # Errors
    ///
    /// Returns `DecimalError::RoundingError` if `precision` is below 19 (the digit count of
    /// `i64::MAX`), if `scale` is negative, or if `precision` is below `19 + scale`.
    pub fn from_i64(value: i64, precision: Precision, scale: i8) -> ConversionResult<Self> {
        const MINIMAL_PRECISION: u8 = 19;
        Self::from_integer_scalar(S::from(&value), MINIMAL_PRECISION, precision, scale)
    }

    /// Builds a decimal from an `i128`, multiplying it by `10^scale` so that the decimal
    /// represents the same integer.
    ///
    /// # Errors
    ///
    /// Returns `DecimalError::RoundingError` if `precision` is below 39 (the digit count of
    /// `i128::MAX`), if `scale` is negative, or if `precision` is below `39 + scale`.
    pub fn from_i128(value: i128, precision: Precision, scale: i8) -> ConversionResult<Self> {
        const MINIMAL_PRECISION: u8 = 39;
        Self::from_integer_scalar(S::from(&value), MINIMAL_PRECISION, precision, scale)
    }

    /// Shared implementation of `from_i64` and `from_i128`.
    ///
    /// `minimal_precision` is the number of decimal digits needed to represent every value
    /// of the source integer type; `value` is that integer already converted to `S`.
    fn from_integer_scalar(
        value: S,
        minimal_precision: u8,
        precision: Precision,
        scale: i8,
    ) -> ConversionResult<Self> {
        let raw_precision = precision.value();
        if raw_precision < minimal_precision {
            // Report the actual requirement (the i128 path previously claimed 19).
            return Err(DecimalConversionError(RoundingError(format!(
                "Precision must be at least {}",
                minimal_precision
            ))));
        }
        // Widen to u16 so `minimal_precision + scale` can never overflow `u8`; the
        // `scale < 0` check short-circuits before `unsigned_abs` is used.
        if scale < 0
            || u16::from(raw_precision)
                < u16::from(minimal_precision) + u16::from(scale.unsigned_abs())
        {
            return Err(DecimalConversionError(RoundingError(
                "Can not scale down a decimal".to_string(),
            )));
        }
        let scaled_value = scale_scalar(value, scale)?;
        Ok(Decimal::new(scaled_value, precision, scale))
    }
}
/// Converts an `IntermediateDecimal` into a scalar of type `S`, representing it with
/// `target_precision` digits and `target_scale` digits after the decimal point.
///
/// # Errors
///
/// Propagates the parser's error (via the `From<IntermediateDecimalError>` impl) if the
/// decimal does not fit the requested precision/scale. It also propagates any error from
/// converting the resulting big integer into `S`.
pub(crate) fn try_into_to_scalar<S: Scalar>(
    d: &IntermediateDecimal,
    target_precision: Precision,
    target_scale: i8,
) -> Result<S, ConversionError> {
    let raw_precision = target_precision.value();
    let big_int = d.try_into_bigint_with_precision_and_scale(raw_precision, target_scale)?;
    big_int.try_into()
}
/// Multiplies `s` by `10^scale`.
///
/// # Errors
///
/// Returns `DecimalError::RoundingError` if `scale` is negative, since dividing would
/// require rounding.
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> ConversionResult<S> {
    if scale.is_negative() {
        return Err(DecimalConversionError(RoundingError(
            "Scale factor must be non-negative".to_string(),
        )));
    }
    let ten = S::from(10);
    // Repeated multiplication; `scale` is at most 127, so this is a short loop.
    Ok((0..scale).fold(s, |mut acc, _| {
        acc *= ten;
        acc
    }))
}
#[cfg(test)]
mod scale_adjust_test {
    use super::*;
    use crate::base::scalar::Curve25519Scalar;
    use num_bigint::BigInt;

    /// Parses `literal` and converts it at `target_scale`, using the decimal's own digit
    /// count as the target precision.
    fn convert_with_own_precision(
        literal: &str,
        target_scale: i8,
    ) -> Result<Curve25519Scalar, ConversionError> {
        let decimal: IntermediateDecimal = literal.parse().unwrap();
        let precision = Precision::new(decimal.value().digits() as u8).unwrap();
        try_into_to_scalar(&decimal, precision, target_scale)
    }

    /// Parses `literal` and converts it at `target_scale`, using `MAX_SUPPORTED_PRECISION`
    /// as the target precision.
    fn convert_with_max_precision(
        literal: &str,
        target_scale: i8,
    ) -> Result<Curve25519Scalar, ConversionError> {
        let decimal: IntermediateDecimal = literal.parse().unwrap();
        let precision = Precision::new(MAX_SUPPORTED_PRECISION).unwrap();
        try_into_to_scalar(&decimal, precision, target_scale)
    }

    #[test]
    fn we_cannot_scale_past_max_precision() {
        let outcome = convert_with_own_precision(
            "12345678901234567890123456789012345678901234567890123456789012345678900.0",
            5,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn we_can_match_exact_decimals_from_queries_to_db() {
        let decimal: IntermediateDecimal = "123.45".parse().unwrap();
        let expected_big_int: BigInt = "12345".parse().unwrap();
        assert_eq!(
            decimal.try_into_bigint_with_precision_and_scale(20, 2),
            Ok(expected_big_int)
        );
    }

    #[test]
    fn we_can_match_decimals_with_negative_scale() {
        let expected = [12, 0, 0, 0];
        let result = convert_with_max_precision("120.00", -1).unwrap();
        assert_eq!(result, Curve25519Scalar::from(expected));
    }

    #[test]
    fn we_can_match_integers_with_negative_scale() {
        let expected_limbs = [123, 0, 0, 0];
        let limbs = convert_with_own_precision("12300", -2).unwrap();
        assert_eq!(limbs, Curve25519Scalar::from(expected_limbs));
    }

    #[test]
    fn we_can_match_negative_decimals() {
        let expected_limbs = [12345, 0, 0, 0];
        let limbs = convert_with_own_precision("-123.45", 2).unwrap();
        assert_eq!(limbs, -Curve25519Scalar::from(expected_limbs));
    }

    #[test]
    fn we_can_match_decimals_at_extrema() {
        // Scaling a long value further exceeds its own precision.
        assert!(convert_with_own_precision(
            "1234567890123456789012345678901234567890123456789012345678901234567890.0",
            6
        )
        .is_err());
        // The largest all-nines value tested still fits at the maximum precision...
        assert!(convert_with_max_precision(
            "99999999999999999999999999999999999999999999999999999999999999999999999999.0",
            1
        )
        .is_ok());
        // ...but one more digit does not.
        assert!(convert_with_max_precision(
            "999999999999999999999999999999999999999999999999999999999999999999999999999.0",
            1
        )
        .is_err());
        // Tiny fractional values can use the maximum scale.
        assert!(convert_with_own_precision(
            "0.000000000000000000000000000000000000000000000000000000000000000000000000001",
            MAX_SUPPORTED_PRECISION as i8
        )
        .is_ok());
        assert!(convert_with_max_precision("0.1", MAX_SUPPORTED_PRECISION as i8).is_ok());
        // An integer digit plus 75 fractional digits is too much; 74 fits at max precision.
        assert!(convert_with_own_precision("1.0", 75).is_err());
        assert!(convert_with_max_precision("1.0", 74).is_ok());
    }
}