use crate::{
fixed_point::{pow, Rounding, ONE, SCALE_OFFSET},
AmmMathError,
};
/// Computes the fixed-point price of bin `active_id` for a pool using `bin_step`.
///
/// The result is `(1 + bin_step / 10_000) ^ active_id`, encoded with `SCALE_OFFSET`
/// fractional bits so that [`ONE`] represents `1.0`. The per-bin increment
/// `bin_step / 10_000` is truncated to fixed point before being added to `ONE`.
///
/// # Errors
///
/// Returns [`AmmMathError::Overflow`] if scaling the step or adding it to `ONE`
/// overflows, or if [`pow`] yields no result (returns `None`).
pub fn get_price_from_id(active_id: i32, bin_step: u16) -> Result<u128, AmmMathError> {
    // Basis-point denominator: a `bin_step` of 1 is a 0.01% price step between bins.
    const BASIS_POINT_MAX: u128 = 10_000;

    let step_fraction = u128::from(bin_step)
        .checked_shl(u32::from(SCALE_OFFSET))
        .map(|scaled| scaled / BASIS_POINT_MAX)
        .ok_or(AmmMathError::Overflow)?;

    let base = ONE
        .checked_add(step_fraction)
        .ok_or(AmmMathError::Overflow)?;

    match pow(base, active_id) {
        Some(price) => Ok(price),
        None => Err(AmmMathError::Overflow),
    }
}
/// Converts a fixed-point price with `SCALE_OFFSET` fractional bits into an `f64`.
///
/// The conversion is lossy: `u128` values with more than 53 significant bits are
/// rounded to the nearest representable `f64`.
pub fn q64_64_to_decimal(price: u128) -> f64 {
    let numerator = price as f64;
    let scale = (1u128 << SCALE_OFFSET) as f64;
    numerator / scale
}
/// Rescales a price expressed in smallest units (lamports) into whole-token units.
///
/// Computes `price_per_lamport * 10^base_decimals / 10^quote_decimals`, i.e. it
/// treats the input as quote-lamports per base-lamport and returns quote-tokens
/// per base-token. Inverse of [`price_per_token_to_price_per_lamport`].
pub fn price_per_lamport_to_price_per_token(
    price_per_lamport: f64,
    base_decimals: u8,
    quote_decimals: u8,
) -> f64 {
    let base_scale = 10_f64.powi(i32::from(base_decimals));
    let quote_scale = 10_f64.powi(i32::from(quote_decimals));
    (price_per_lamport * base_scale) / quote_scale
}
/// Rescales a price expressed in whole-token units into smallest units (lamports).
///
/// Computes `price_per_token * 10^quote_decimals / 10^base_decimals`; this is the
/// inverse of [`price_per_lamport_to_price_per_token`] (up to floating-point rounding).
pub fn price_per_token_to_price_per_lamport(
    price_per_token: f64,
    base_decimals: u8,
    quote_decimals: u8,
) -> f64 {
    let quote_scale = 10_f64.powi(i32::from(quote_decimals));
    let base_scale = 10_f64.powi(i32::from(base_decimals));
    (price_per_token * quote_scale) / base_scale
}
/// Inverts [`get_price_from_id`]: maps a decimal price to a bin id.
///
/// The bin index is `ln(price_per_lamport) / ln(1 + bin_step / 10_000)`, computed
/// in `f64` and then rounded up or down according to `rounding`. Because of
/// floating-point error, a price lying exactly on a bin's price may land one bin
/// off from the id that produced it.
///
/// # Errors
///
/// Returns [`AmmMathError::Overflow`] when:
/// - `price_per_lamport` is not a finite, strictly positive number (this includes NaN
///   and infinities);
/// - `bin_step` is zero (every bin would share the same price, so no id exists);
/// - the rounded id does not fit in an `i32`.
pub fn get_id_from_price(
    price_per_lamport: f64,
    bin_step: u16,
    rounding: Rounding,
) -> Result<i32, AmmMathError> {
    // `is_finite` rejects NaN and +/-inf; NaN would otherwise slip past `<= 0.0`.
    if !price_per_lamport.is_finite() || price_per_lamport <= 0.0 || bin_step == 0 {
        return Err(AmmMathError::Overflow);
    }

    // ln(1 + x) via `ln_1p` keeps full precision for the small x = bin_step / 10_000.
    let log_base = (f64::from(bin_step) / 10_000.0).ln_1p();
    let raw = price_per_lamport.ln() / log_base;

    let rounded = match rounding {
        Rounding::Up => raw.ceil(),
        Rounding::Down => raw.floor(),
    };

    // `as i32` saturates silently on out-of-range floats; report it instead.
    if rounded < f64::from(i32::MIN) || rounded > f64::from(i32::MAX) {
        return Err(AmmMathError::Overflow);
    }
    Ok(rounded as i32)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_price_at_bin_zero() {
        let price = get_price_from_id(0, 1).unwrap();
        assert_eq!(price, ONE);
        let price2 = get_price_from_id(0, 100).unwrap();
        assert_eq!(price2, ONE);
    }

    #[test]
    fn test_price_monotonic() {
        let bin_step: u16 = 10;
        let mut prev = get_price_from_id(-100, bin_step).unwrap();
        for id in -99..=100 {
            let cur = get_price_from_id(id, bin_step).unwrap();
            assert!(cur > prev, "price at bin {id} should exceed bin {}", id - 1);
            prev = cur;
        }
    }

    #[test]
    fn test_negative_bin_below_one() {
        for bin_step in [1u16, 10, 100] {
            let price = get_price_from_id(-1, bin_step).unwrap();
            assert!(price < ONE, "price at bin -1 (step={bin_step}) should be < 1.0");
        }
    }

    #[test]
    fn test_known_value() {
        // Bin 100 with a 1 bps step is 1.0001^100 ≈ 1.0100496620928766.
        // e^0.01 ≈ 1.0100501670842 is only ~5e-7 away, so the tolerance must stay
        // well below that for this test to actually pin the formula.
        let price = get_price_from_id(100, 1).unwrap();
        let decimal = q64_64_to_decimal(price);
        let expected = 1.0001_f64.powi(100);
        let rel_err = (decimal - expected).abs() / expected;
        assert!(rel_err < 1e-9, "decimal={decimal} expected={expected}");
    }

    #[test]
    fn test_q64_64_to_decimal() {
        assert!((q64_64_to_decimal(ONE) - 1.0).abs() < 1e-15);
        assert!((q64_64_to_decimal(ONE * 2) - 2.0).abs() < 1e-15);
    }

    #[test]
    fn test_price_per_lamport_roundtrip() {
        let ppl = 1.5_f64;
        let ppt = price_per_lamport_to_price_per_token(ppl, 6, 9);
        let recovered = price_per_token_to_price_per_lamport(ppt, 6, 9);
        assert!((recovered - ppl).abs() < 1e-12);
    }

    #[test]
    fn test_get_id_from_price_roundtrip() {
        let bin_step: u16 = 10;
        for id in [-500, -1, 0, 1, 500, 5000] {
            let price = get_price_from_id(id, bin_step).unwrap();
            let decimal = q64_64_to_decimal(price);
            let recovered = get_id_from_price(decimal, bin_step, Rounding::Down).unwrap();
            // Float log error can push an exact bin price across the boundary.
            assert!(
                (recovered - id).abs() <= 1,
                "roundtrip failed for id={id}: recovered={recovered}"
            );
        }
    }

    #[test]
    fn test_get_id_from_price_unit_price() {
        // ln(1.0) == 0 exactly, so price 1.0 maps to bin 0 in either rounding mode.
        assert_eq!(get_id_from_price(1.0, 10, Rounding::Down).unwrap(), 0);
        assert_eq!(get_id_from_price(1.0, 10, Rounding::Up).unwrap(), 0);
    }

    #[test]
    fn test_get_id_from_price_invalid() {
        assert!(get_id_from_price(0.0, 10, Rounding::Down).is_err());
        assert!(get_id_from_price(-1.0, 10, Rounding::Down).is_err());
    }

    #[test]
    fn test_get_id_from_price_rounding() {
        let bin_step: u16 = 10;
        let p99 = q64_64_to_decimal(get_price_from_id(99, bin_step).unwrap());
        let p100 = q64_64_to_decimal(get_price_from_id(100, bin_step).unwrap());
        let mid = (p99 + p100) / 2.0;
        let down = get_id_from_price(mid, bin_step, Rounding::Down).unwrap();
        let up = get_id_from_price(mid, bin_step, Rounding::Up).unwrap();
        assert_eq!(down, 99);
        assert_eq!(up, 100);
    }
}