use crate::{ArithmeticError, Decimal, RoundingMode};
/// Number of decimal places carried by a fixed-point integer (an oracle answer
/// or a token amount).
///
/// The named variants cover frequently used precisions; `Custom` holds any
/// other count. Equality is derived structurally, so `Custom(6)` is *not*
/// equal to `Six` even though both denote six decimals. Construct values via
/// `OracleDecimals::from(u8)` to obtain the canonical variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OracleDecimals {
    /// Six decimal places.
    Six,
    /// Eight decimal places.
    Eight,
    /// Eighteen decimal places.
    Eighteen,
    /// Any other number of decimal places.
    Custom(u8),
}
impl OracleDecimals {
    /// Returns the number of decimal places this precision represents.
    pub const fn value(self) -> u8 {
        match self {
            Self::Six => 6,
            Self::Eight => 8,
            Self::Eighteen => 18,
            Self::Custom(n) => n,
        }
    }

    /// Returns `10^value()` as a [`Decimal`], or `None` when `Decimal::powi`
    /// cannot produce it (presumably because the power is out of range).
    ///
    /// Prefer this over [`Self::scale_factor`] whenever an unrepresentable
    /// precision must be reported instead of silently saturated.
    pub fn checked_scale_factor(self) -> Option<Decimal> {
        // `i32::from` is lossless for `u8`, unlike an `as` cast in general.
        Decimal::from(10i64).powi(i32::from(self.value()))
    }

    /// Returns `10^value()` as a [`Decimal`].
    ///
    /// If the power cannot be computed, this saturates to [`Decimal::MAX`]
    /// rather than failing. `Decimal::MAX` is not a power of ten, so any value
    /// scaled by it is wrong. Use [`Self::checked_scale_factor`] when that
    /// matters.
    pub fn scale_factor(self) -> Decimal {
        self.checked_scale_factor().unwrap_or(Decimal::MAX)
    }
}
impl From<u8> for OracleDecimals {
fn from(n: u8) -> Self {
match n {
6 => Self::Six,
8 => Self::Eight,
18 => Self::Eighteen,
_ => Self::Custom(n),
}
}
}
/// Converts a raw fixed-point oracle integer into a [`Decimal`] by dividing it
/// by `10^decimals`.
///
/// For example, `250012345678` with [`OracleDecimals::Eight`] becomes
/// `2500.12345678`.
///
/// # Errors
/// - [`ArithmeticError::Overflow`] if `10^decimals` cannot be represented as a
///   `Decimal`.
/// - [`ArithmeticError::DivisionByZero`] if `Decimal::checked_div` fails. The
///   divisor is a power of ten and never zero; the variant is kept for
///   compatibility with existing callers.
pub fn normalize_oracle_price(
    raw_value: i64,
    decimals: OracleDecimals,
) -> Result<Decimal, ArithmeticError> {
    // Use a checked power instead of `OracleDecimals::scale_factor`, which
    // saturates to `Decimal::MAX` and would yield a silently wrong price.
    let scale = Decimal::from(10i64)
        .powi(i32::from(decimals.value()))
        .ok_or(ArithmeticError::Overflow)?;
    Decimal::from(raw_value)
        .checked_div(scale)
        .ok_or(ArithmeticError::DivisionByZero)
}
/// Converts a raw `i128` fixed-point oracle integer into a [`Decimal`] by
/// dividing it by `10^decimals`.
///
/// # Errors
/// - [`ArithmeticError::Overflow`] if `10^decimals` cannot be represented as a
///   `Decimal`.
/// - Any error returned by `Decimal::try_from_i128` (presumably when
///   `raw_value` is outside the representable range; confirm in `Decimal`).
/// - [`ArithmeticError::DivisionByZero`] if `Decimal::checked_div` fails. The
///   divisor is a power of ten and never zero; the variant is kept for
///   compatibility with existing callers.
pub fn normalize_oracle_price_i128(
    raw_value: i128,
    decimals: OracleDecimals,
) -> Result<Decimal, ArithmeticError> {
    // Use a checked power instead of `OracleDecimals::scale_factor`, which
    // saturates to `Decimal::MAX` and would yield a silently wrong price.
    let scale = Decimal::from(10i64)
        .powi(i32::from(decimals.value()))
        .ok_or(ArithmeticError::Overflow)?;
    Decimal::try_from_i128(raw_value)?
        .checked_div(scale)
        .ok_or(ArithmeticError::DivisionByZero)
}
pub fn denormalize_oracle_price(
value: Decimal,
decimals: OracleDecimals,
) -> Result<i64, ArithmeticError> {
let scale = decimals.scale_factor();
let scaled = value
.checked_mul(scale)
.ok_or(ArithmeticError::Overflow)?
.round(0, RoundingMode::TowardZero);
let (mantissa, _) = scaled.to_parts();
i64::try_from(mantissa).map_err(|_| ArithmeticError::Overflow)
}
pub fn denormalize_oracle_price_i128(
value: Decimal,
decimals: OracleDecimals,
) -> Result<i128, ArithmeticError> {
let scale = decimals.scale_factor();
let scaled = value
.checked_mul(scale)
.ok_or(ArithmeticError::Overflow)?
.round(0, RoundingMode::TowardZero);
let (mantissa, _) = scaled.to_parts();
Ok(mantissa)
}
/// Rescales a fixed-point `i64` from `from` decimal places to `to` decimal
/// places.
///
/// - Increasing precision multiplies by `10^(to - from)`.
/// - Decreasing precision divides by `10^(from - to)` with truncation toward
///   zero, dropping the extra fractional digits.
/// - Rescaling zero always yields zero.
///
/// # Errors
/// [`ArithmeticError::Overflow`] if increasing precision overflows `i64`.
pub fn convert_decimals(
    value: i64,
    from: OracleDecimals,
    to: OracleDecimals,
) -> Result<i64, ArithmeticError> {
    let diff = i32::from(to.value()) - i32::from(from.value());
    if diff == 0 || value == 0 {
        return Ok(value);
    }
    match 10i64.checked_pow(diff.unsigned_abs()) {
        Some(factor) if diff > 0 => value.checked_mul(factor).ok_or(ArithmeticError::Overflow),
        Some(factor) => Ok(value / factor),
        // `10^k` only overflows i64 for k >= 19, and every i64 has magnitude
        // below 10^19, so truncating division by such a factor is exactly 0.
        None if diff < 0 => Ok(0),
        None => Err(ArithmeticError::Overflow),
    }
}
/// Rescales a fixed-point `i64` from `from` to `to` decimal places, returning
/// an `i128` so that larger upscales fit.
///
/// - Increasing precision multiplies by `10^(to - from)`.
/// - Decreasing precision divides by `10^(from - to)` with truncation toward
///   zero.
/// - Rescaling zero always yields zero.
///
/// # Errors
/// [`ArithmeticError::Overflow`] if increasing precision overflows `i128`.
pub fn convert_decimals_i128(
    value: i64,
    from: OracleDecimals,
    to: OracleDecimals,
) -> Result<i128, ArithmeticError> {
    let wide = i128::from(value);
    let diff = i32::from(to.value()) - i32::from(from.value());
    if diff == 0 || wide == 0 {
        return Ok(wide);
    }
    match 10i128.checked_pow(diff.unsigned_abs()) {
        Some(factor) if diff > 0 => wide.checked_mul(factor).ok_or(ArithmeticError::Overflow),
        Some(factor) => Ok(wide / factor),
        // `10^k` only overflows i128 for k >= 39, far beyond any i64 magnitude,
        // so truncating division by such a factor is exactly 0.
        None if diff < 0 => Ok(0),
        None => Err(ArithmeticError::Overflow),
    }
}
/// Rescales a token amount between two decimal precisions, yielding an `i64`.
///
/// Delegates to [`convert_decimals`]; its rounding (truncation toward zero when
/// precision drops) and overflow rules apply unchanged.
pub fn scale_token_amount(
    amount: i64,
    from_decimals: OracleDecimals,
    to_decimals: OracleDecimals,
) -> Result<i64, ArithmeticError> {
    let scaled = convert_decimals(amount, from_decimals, to_decimals)?;
    Ok(scaled)
}
/// Rescales a token amount between two decimal precisions, yielding an `i128`.
///
/// Delegates to [`convert_decimals_i128`]; its rounding and overflow rules
/// apply unchanged.
pub fn scale_token_amount_i128(
    amount: i64,
    from_decimals: OracleDecimals,
    to_decimals: OracleDecimals,
) -> Result<i128, ArithmeticError> {
    let scaled = convert_decimals_i128(amount, from_decimals, to_decimals)?;
    Ok(scaled)
}
/// Computes `amount * price` and returns it as a raw `i64` with
/// `result_decimals` decimal places.
///
/// Each input is first converted to a [`Decimal`] using its own precision. The
/// product is then handed to [`denormalize_oracle_price`], which applies the
/// truncation toward zero.
///
/// # Errors
/// - Propagates errors from [`normalize_oracle_price`] and
///   [`denormalize_oracle_price`].
/// - Returns [`ArithmeticError::Overflow`] if the `Decimal` product overflows.
pub fn calculate_value(
    amount: i64,
    amount_decimals: OracleDecimals,
    price: i64,
    price_decimals: OracleDecimals,
    result_decimals: OracleDecimals,
) -> Result<i64, ArithmeticError> {
    let quantity = normalize_oracle_price(amount, amount_decimals)?;
    let unit_price = normalize_oracle_price(price, price_decimals)?;
    let total = match quantity.checked_mul(unit_price) {
        Some(product) => product,
        None => return Err(ArithmeticError::Overflow),
    };
    denormalize_oracle_price(total, result_decimals)
}
/// Computes `amount * price` and returns it as a raw `i128` with
/// `result_decimals` decimal places.
///
/// This is the same pipeline as `calculate_value`, but the final conversion
/// uses [`denormalize_oracle_price_i128`], so the result is not narrowed to
/// `i64`.
///
/// # Errors
/// - Propagates errors from [`normalize_oracle_price`] and
///   [`denormalize_oracle_price_i128`].
/// - Returns [`ArithmeticError::Overflow`] if the `Decimal` product overflows.
pub fn calculate_value_i128(
    amount: i64,
    amount_decimals: OracleDecimals,
    price: i64,
    price_decimals: OracleDecimals,
    result_decimals: OracleDecimals,
) -> Result<i128, ArithmeticError> {
    // The receiver is evaluated before the argument, so `amount` is still
    // normalized (and can fail) before `price`.
    let total = normalize_oracle_price(amount, amount_decimals)?
        .checked_mul(normalize_oracle_price(price, price_decimals)?)
        .ok_or(ArithmeticError::Overflow)?;
    denormalize_oracle_price_i128(total, result_decimals)
}
/// Converts a `(price, exponent)` pair into a [`Decimal`] equal to
/// `price * 10^exponent`.
///
/// For example, `(25, 2)` becomes `2500` and `(250012345678, -8)` becomes
/// `2500.12345678`.
///
/// # Errors
/// - [`ArithmeticError::Overflow`] in any of these cases:
///   - `exponent` is `i32::MIN`, whose absolute value is not representable.
///   - `10^|exponent|` cannot be computed.
///   - Scaling up overflows.
/// - [`ArithmeticError::DivisionByZero`] if `Decimal::checked_div` fails while
///   scaling down.
pub fn normalize_pyth_price(price: i64, exponent: i32) -> Result<Decimal, ArithmeticError> {
    let price_dec = Decimal::from(price);
    if exponent == 0 {
        return Ok(price_dec);
    }
    // `i32::MIN.abs()` overflows: it panics in debug builds and wraps to a
    // negative exponent in release builds. Report it as an error instead.
    let magnitude = exponent.checked_abs().ok_or(ArithmeticError::Overflow)?;
    let scale = Decimal::from(10i64)
        .powi(magnitude)
        .ok_or(ArithmeticError::Overflow)?;
    if exponent > 0 {
        price_dec
            .checked_mul(scale)
            .ok_or(ArithmeticError::Overflow)
    } else {
        price_dec
            .checked_div(scale)
            .ok_or(ArithmeticError::DivisionByZero)
    }
}
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::string::ToString;
    use core::str::FromStr;

    #[test]
    fn test_normalize_chainlink_price() {
        let raw = 250012345678i64;
        let price = normalize_oracle_price(raw, OracleDecimals::Eight).unwrap();
        assert_eq!(price.to_string(), "2500.12345678");
    }

    #[test]
    fn test_denormalize_price() {
        let price = Decimal::from_str("2500.12345678").unwrap();
        let raw = denormalize_oracle_price(price, OracleDecimals::Eight).unwrap();
        assert_eq!(raw, 250012345678);
    }

    #[test]
    fn test_convert_8_to_6_decimals() {
        let chainlink = 250012345678i64;
        let usdc = convert_decimals(chainlink, OracleDecimals::Eight, OracleDecimals::Six).unwrap();
        assert_eq!(usdc, 2500123456);
    }

    #[test]
    fn test_convert_8_to_18_decimals_i128() {
        let chainlink = 250012345678i64;
        let onchain =
            convert_decimals_i128(chainlink, OracleDecimals::Eight, OracleDecimals::Eighteen)
                .unwrap();
        assert_eq!(onchain, 2500123456780000000000i128);
    }

    // Round-trips an 8-decimal raw value through normalize/denormalize.
    #[test]
    fn test_normalize_denormalize_roundtrip_8_decimals() {
        let original = 250012345678i64;
        let normalized = normalize_oracle_price(original, OracleDecimals::Eight).unwrap();
        let recovered = denormalize_oracle_price(normalized, OracleDecimals::Eight).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_convert_same_decimals_is_identity() {
        let same = convert_decimals(-42, OracleDecimals::Eight, OracleDecimals::Eight).unwrap();
        assert_eq!(same, -42);
    }

    #[test]
    fn test_convert_downscale_truncates_toward_zero() {
        let neg = convert_decimals(-199, OracleDecimals::Eight, OracleDecimals::Six).unwrap();
        assert_eq!(neg, -1);
        let pos = convert_decimals(199, OracleDecimals::Eight, OracleDecimals::Six).unwrap();
        assert_eq!(pos, 1);
    }

    #[test]
    fn test_convert_upscale_overflow_is_error() {
        assert!(convert_decimals(i64::MAX, OracleDecimals::Six, OracleDecimals::Eight).is_err());
    }

    #[test]
    fn test_scale_usdc_to_8_decimals() {
        let usdc = 1_000_000_000i64;
        let scaled = scale_token_amount(usdc, OracleDecimals::Six, OracleDecimals::Eight).unwrap();
        assert_eq!(scaled, 100_000_000_000);
    }

    #[test]
    fn test_scale_usdc_to_18_decimals_i128() {
        let usdc = 1_000_000_000i64;
        let scaled =
            scale_token_amount_i128(usdc, OracleDecimals::Six, OracleDecimals::Eighteen).unwrap();
        assert_eq!(scaled, 1_000_000_000_000_000_000_000i128);
    }

    #[test]
    fn test_pyth_positive_exponent() {
        let price = normalize_pyth_price(25, 2).unwrap();
        assert_eq!(price.to_string(), "2500");
    }

    #[test]
    fn test_pyth_negative_exponent() {
        let price = normalize_pyth_price(250012345678, -8).unwrap();
        assert_eq!(price.to_string(), "2500.12345678");
    }

    #[test]
    fn test_pyth_zero_exponent() {
        let price = normalize_pyth_price(2500, 0).unwrap();
        assert_eq!(price.to_string(), "2500");
    }

    #[test]
    fn test_calculate_usdc_value() {
        let usdc_amount = 1_000_000_000i64;
        let usdc_price = 100000000i64;
        let value = calculate_value(
            usdc_amount,
            OracleDecimals::Six,
            usdc_price,
            OracleDecimals::Eight,
            OracleDecimals::Six,
        )
        .unwrap();
        assert_eq!(value, 1_000_000_000);
    }

    #[test]
    fn test_calculate_btc_value() {
        let btc_amount = 10_000_000i64;
        let btc_price = 5000000000000i64;
        let value = calculate_value(
            btc_amount,
            OracleDecimals::Eight,
            btc_price,
            OracleDecimals::Eight,
            OracleDecimals::Six,
        )
        .unwrap();
        assert_eq!(value, 5_000_000_000);
    }

    #[test]
    fn test_oracle_decimals_from_u8() {
        assert_eq!(OracleDecimals::from(6), OracleDecimals::Six);
        assert_eq!(OracleDecimals::from(8), OracleDecimals::Eight);
        assert_eq!(OracleDecimals::from(18), OracleDecimals::Eighteen);
        assert_eq!(OracleDecimals::from(12), OracleDecimals::Custom(12));
    }

    #[test]
    fn test_oracle_decimals_value() {
        assert_eq!(OracleDecimals::Six.value(), 6);
        assert_eq!(OracleDecimals::Eighteen.value(), 18);
        assert_eq!(OracleDecimals::Custom(12).value(), 12);
    }
}