#[cfg(feature = "python")]
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
/// Raw units per whole number in a [`FixedPoint`]: 10^8, i.e. 8 fractional decimal digits.
pub const SCALE: i64 = 10_i64.pow(8);
/// Divisor applied to `threshold_decimal_bps` by `FixedPoint::compute_range_thresholds`:
/// one unit of that argument is 1/100_000 of the value (e.g. 250 => 0.25%).
pub const BASIS_POINTS_SCALE: u32 = 10_u32.pow(5);
/// Signed fixed-point decimal stored as a raw `i64` count of 1/[`SCALE`] units
/// (8 fractional decimal digits).
///
/// Equality and ordering are derived from the raw integer. Every value shares the
/// same scale, so this matches numeric order. `Default` is zero.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "api", derive(utoipa::ToSchema))]
pub struct FixedPoint(pub i64);
impl FixedPoint {
pub fn from_str(s: &str) -> Result<Self, FixedPointError> {
<Self as FromStr>::from_str(s)
}
pub fn compute_range_thresholds(&self, threshold_decimal_bps: u32) -> (FixedPoint, FixedPoint) {
let delta = (self.0 as i128 * threshold_decimal_bps as i128) / BASIS_POINTS_SCALE as i128;
let delta = delta as i64;
let upper = FixedPoint(self.0 + delta);
let lower = FixedPoint(self.0 - delta);
(upper, lower)
}
#[inline]
pub fn compute_range_thresholds_cached(
&self,
threshold_ratio: i64,
) -> (FixedPoint, FixedPoint) {
let delta = (self.0 as i128 * threshold_ratio as i128) / SCALE as i128;
let delta = delta as i64;
let upper = FixedPoint(self.0 + delta);
let lower = FixedPoint(self.0 - delta);
(upper, lower)
}
#[inline]
pub fn from_f64(value: f64) -> Self {
FixedPoint((value * SCALE as f64).round() as i64)
}
#[inline]
pub fn to_f64(&self) -> f64 {
self.0 as f64 / SCALE as f64
}
}
impl fmt::Display for FixedPoint {
    /// Formats as `[-]integer.ffffffff`, always with exactly 8 fractional digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `unsigned_abs` is defined for every i64. `i64::MIN.abs()` overflows: it panics
        // in debug builds and wraps to a negative value in release, which corrupted the output.
        let abs_value = self.0.unsigned_abs();
        let scale = SCALE.unsigned_abs();
        let integer_part = abs_value / scale;
        let fractional_part = abs_value % scale;
        let sign = if self.0 < 0 { "-" } else { "" };
        write!(f, "{sign}{integer_part}.{fractional_part:08}")
    }
}
impl FromStr for FixedPoint {
    type Err = FixedPointError;

    /// Parses `[+|-]digits[.digits]` with at most 8 fractional digits.
    ///
    /// # Errors
    /// - [`FixedPointError::InvalidFormat`] in any of these cases:
    ///   - empty input;
    ///   - a missing integer part (`".5"`);
    ///   - an empty fraction (`"1."`);
    ///   - more than one `.`;
    ///   - any non-digit character other than a single leading sign.
    /// - [`FixedPointError::TooManyDecimals`] when the fraction has more than 8 digits.
    /// - [`FixedPointError::Overflow`] when the value is outside the `i64` raw range,
    ///   i.e. outside `-92233720368.54775808 ..= 92233720368.54775807`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const MAX_DECIMALS: usize = 8;

        if s.is_empty() {
            return Err(FixedPointError::InvalidFormat);
        }
        // Take the sign from the text, not from the parsed integer part.
        // For "-0.5" the integer part is 0, which cannot carry a sign.
        let (negative, unsigned) = match s.strip_prefix('-') {
            Some(rest) => (true, rest),
            None => (false, s.strip_prefix('+').unwrap_or(s)),
        };
        let (int_str, frac_str_opt) = match unsigned.split_once('.') {
            Some((int_str, frac_str)) => {
                if frac_str.contains('.') {
                    return Err(FixedPointError::InvalidFormat);
                }
                (int_str, Some(frac_str))
            }
            None => (unsigned, None),
        };
        let integer_part = parse_unsigned_digits(int_str)?;
        let fractional_part = match frac_str_opt {
            Some(frac_str) => {
                let frac_len = frac_str.len();
                if frac_len > MAX_DECIMALS {
                    return Err(FixedPointError::TooManyDecimals);
                }
                // Right-pad to 8 digits: "5" means 0.5, i.e. 50_000_000 raw units.
                let padding = (MAX_DECIMALS - frac_len) as u32;
                parse_unsigned_digits(frac_str)? * 10u64.pow(padding)
            }
            None => 0,
        };
        // i128 holds u64::MAX * SCALE without overflow, and also fits the magnitude
        // of i64::MIN, which is one larger than i64::MAX.
        let magnitude =
            i128::from(integer_part) * i128::from(SCALE) + i128::from(fractional_part);
        let signed = if negative { -magnitude } else { magnitude };
        if !(i128::from(i64::MIN)..=i128::from(i64::MAX)).contains(&signed) {
            return Err(FixedPointError::Overflow);
        }
        // The range check above guarantees this cast is lossless.
        Ok(FixedPoint(signed as i64))
    }
}

/// Parses a non-empty run of ASCII digits (no sign, no whitespace) as `u64`.
///
/// Returns `InvalidFormat` for an empty string or any non-digit character.
/// After that check, the only way `parse` can fail is a value too large for
/// `u64`, which is reported as `Overflow`.
fn parse_unsigned_digits(digits: &str) -> Result<u64, FixedPointError> {
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return Err(FixedPointError::InvalidFormat);
    }
    digits.parse().map_err(|_| FixedPointError::Overflow)
}

/// Errors produced when constructing a [`FixedPoint`].
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum FixedPointError {
    /// The text is not a `[+|-]digits[.digits]` number.
    #[error("Invalid number format")]
    InvalidFormat,
    /// The text has more than 8 fractional digits.
    #[error("Too many decimal places (max 8)")]
    TooManyDecimals,
    /// The value does not fit in the `i64` raw representation.
    #[error("Arithmetic overflow")]
    Overflow,
}

/// Maps parse errors to Python exceptions. `InvalidFormat` and `TooManyDecimals`
/// become `ValueError`; `Overflow` becomes `OverflowError`.
#[cfg(feature = "python")]
impl From<FixedPointError> for PyErr {
    fn from(err: FixedPointError) -> PyErr {
        match err {
            FixedPointError::InvalidFormat => {
                pyo3::exceptions::PyValueError::new_err("Invalid number format")
            }
            FixedPointError::TooManyDecimals => {
                pyo3::exceptions::PyValueError::new_err("Too many decimal places (max 8)")
            }
            FixedPointError::Overflow => {
                pyo3::exceptions::PyOverflowError::new_err("Arithmetic overflow")
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_string() {
        assert_eq!(FixedPoint::from_str("0").unwrap().0, 0);
        assert_eq!(FixedPoint::from_str("1").unwrap().0, SCALE);
        assert_eq!(FixedPoint::from_str("1.5").unwrap().0, SCALE + SCALE / 2);
        assert_eq!(
            FixedPoint::from_str("50000.12345678").unwrap().0,
            5000012345678
        );
        assert_eq!(FixedPoint::from_str("-1.5").unwrap().0, -SCALE - SCALE / 2);
    }

    #[test]
    fn test_to_string() {
        assert_eq!(FixedPoint(0).to_string(), "0.00000000");
        assert_eq!(FixedPoint(SCALE).to_string(), "1.00000000");
        assert_eq!(FixedPoint(SCALE + SCALE / 2).to_string(), "1.50000000");
        assert_eq!(FixedPoint(5000012345678).to_string(), "50000.12345678");
        assert_eq!(FixedPoint(-SCALE).to_string(), "-1.00000000");
    }

    #[test]
    fn test_round_trip() {
        let test_values = [
            "0",
            "1",
            "1.5",
            "50000.12345678",
            "999999.99999999",
            "-1.5",
            "-50000.12345678",
            "-0.5",
            "-0.00000001",
        ];
        for val in &test_values {
            let fp = FixedPoint::from_str(val).unwrap();
            let back = fp.to_string();
            let fp2 = FixedPoint::from_str(&back).unwrap();
            assert_eq!(fp.0, fp2.0, "Round trip failed for {}", val);
        }
    }

    #[test]
    fn test_compute_thresholds() {
        let price = FixedPoint::from_str("50000.0").unwrap();
        let (upper, lower) = price.compute_range_thresholds(250);
        assert_eq!(upper.to_string(), "50125.00000000");
        assert_eq!(lower.to_string(), "49875.00000000");
    }

    #[test]
    fn test_error_cases() {
        assert!(FixedPoint::from_str("").is_err());
        assert!(FixedPoint::from_str("not_a_number").is_err());
        assert!(FixedPoint::from_str("1.123456789").is_err());
        assert!(FixedPoint::from_str("1.2.3").is_err());
        assert!(FixedPoint::from_str("-").is_err());
        assert!(FixedPoint::from_str("--1").is_err());
        assert!(FixedPoint::from_str("-+1").is_err());
        assert!(FixedPoint::from_str(".5").is_err());
        assert!(FixedPoint::from_str("1.").is_err());
    }

    #[test]
    fn test_from_str_rejects_signed_fraction() {
        assert_eq!(
            FixedPoint::from_str("1.-5").unwrap_err(),
            FixedPointError::InvalidFormat
        );
        assert_eq!(
            FixedPoint::from_str("1.+5").unwrap_err(),
            FixedPointError::InvalidFormat
        );
    }

    #[test]
    fn test_comparison() {
        let a = FixedPoint::from_str("50000.0").unwrap();
        let b = FixedPoint::from_str("50000.1").unwrap();
        let c = FixedPoint::from_str("49999.9").unwrap();
        assert!(a < b);
        assert!(b > a);
        assert!(c < a);
        assert_eq!(a, a);
    }

    #[test]
    fn test_from_str_too_many_decimals() {
        let err = FixedPoint::from_str("0.000000001").unwrap_err();
        assert_eq!(err, FixedPointError::TooManyDecimals);
    }

    #[test]
    fn test_from_str_negative_fractional() {
        let fp = FixedPoint::from_str("-0.5").unwrap();
        assert_eq!(fp.0, -50_000_000);
        assert_eq!(fp.to_f64(), -0.5);
        let fp2 = FixedPoint::from_str("-1.5").unwrap();
        assert_eq!(fp2.0, -150_000_000);
        assert_eq!(fp2.to_f64(), -1.5);
    }

    #[test]
    fn test_from_str_explicit_plus_sign() {
        assert_eq!(FixedPoint::from_str("+1.5").unwrap().0, 150_000_000);
    }

    #[test]
    fn test_from_str_leading_zeros() {
        let fp = FixedPoint::from_str("000.123").unwrap();
        assert_eq!(fp.0, 12_300_000);
    }

    #[test]
    fn test_from_str_extremes() {
        assert_eq!(
            FixedPoint::from_str("92233720368.54775807").unwrap().0,
            i64::MAX
        );
        assert_eq!(
            FixedPoint::from_str("-92233720368.54775808").unwrap().0,
            i64::MIN
        );
    }

    #[test]
    fn test_from_str_overflow() {
        for s in [
            "92233720368.54775808",
            "-92233720368.54775809",
            "100000000000",
            "99999999999999999999999",
        ] {
            assert_eq!(
                FixedPoint::from_str(s).unwrap_err(),
                FixedPointError::Overflow,
                "expected overflow for {}",
                s
            );
        }
    }

    #[test]
    fn test_to_f64_extreme_values() {
        let max_fp = FixedPoint(i64::MAX);
        let max_f64 = max_fp.to_f64();
        assert!(max_f64 > 92_233_720_368.0);
        assert!(max_f64.is_finite());
        let min_fp = FixedPoint(i64::MIN);
        let min_f64 = min_fp.to_f64();
        assert!(min_f64 < -92_233_720_368.0);
        assert!(min_f64.is_finite());
    }

    #[test]
    fn test_threshold_zero_ratio() {
        let price = FixedPoint::from_str("100.0").unwrap();
        let (upper, lower) = price.compute_range_thresholds_cached(0);
        assert_eq!(upper, price);
        assert_eq!(lower, price);
    }

    #[test]
    fn test_threshold_small_price_small_bps() {
        let price = FixedPoint::from_str("0.01").unwrap();
        let (upper, lower) = price.compute_range_thresholds(1);
        assert!(upper > price);
        assert!(lower < price);
    }

    #[test]
    fn test_fixedpoint_zero() {
        let zero = FixedPoint(0);
        assert_eq!(zero.to_f64(), 0.0);
        assert_eq!(zero.to_string(), "0.00000000");
        let (upper, lower) = zero.compute_range_thresholds(250);
        assert_eq!(upper, zero);
        assert_eq!(lower, zero);
    }

    #[test]
    fn test_from_f64_round_trip() {
        let val = 50000.12345678_f64;
        let fp = FixedPoint::from_f64(val);
        assert!((fp.to_f64() - val).abs() < 1e-8);
    }

    #[test]
    fn test_from_f64_zero() {
        assert_eq!(FixedPoint::from_f64(0.0), FixedPoint(0));
    }

    #[test]
    fn test_from_f64_positive() {
        assert_eq!(FixedPoint::from_f64(1.5).0, 150_000_000);
    }

    #[test]
    fn test_from_f64_negative() {
        assert_eq!(FixedPoint::from_f64(-1.5).0, -150_000_000);
    }

    #[test]
    fn test_from_f64_oracle_matches_from_str() {
        let oracle_values = ["50000.12345678", "0.01328000", "112070.01000000"];
        for s in &oracle_values {
            let from_str_val = FixedPoint::from_str(s).unwrap();
            let from_f64_val = FixedPoint::from_f64(s.parse::<f64>().unwrap());
            assert_eq!(
                from_str_val.0, from_f64_val.0,
                "Oracle mismatch for {}: from_str={}, from_f64={}",
                s, from_str_val.0, from_f64_val.0
            );
        }
    }

    #[test]
    fn test_fixedpoint_error_display() {
        assert_eq!(
            FixedPointError::InvalidFormat.to_string(),
            "Invalid number format"
        );
        assert_eq!(
            FixedPointError::TooManyDecimals.to_string(),
            "Too many decimal places (max 8)"
        );
        assert_eq!(FixedPointError::Overflow.to_string(), "Arithmetic overflow");
    }
}