use std::fmt;
use rust_decimal::prelude::*;
use rust_decimal::Decimal;
use crate::program::types::OrderSide;
/// Precision settings of one orderbook, used to translate decimal prices and sizes
/// into integer lamport amounts.
#[derive(Debug, Clone)]
pub struct OrderbookDecimals {
    /// Identifier of the orderbook (presumably; not read by the functions in this module).
    pub orderbook_id: String,
    /// Exponent for base amounts: `scale_price_size` multiplies the size by
    /// `10^base_decimals` to obtain base lamports.
    pub base_decimals: u8,
    /// Exponent for quote amounts: `scale_price_size` multiplies `price * size` by
    /// `10^quote_decimals` to obtain quote lamports.
    pub quote_decimals: u8,
    /// Price precision in decimal digits.
    /// NOTE(review): not read by the functions in this module — confirm its use in callers.
    pub price_decimals: u8,
    /// Price tick measured in quote lamports, used by `align_price_to_tick`;
    /// a value of 0 or 1 disables alignment.
    pub tick_size: u64,
}
/// Rounds `price` toward zero onto the orderbook's tick grid.
///
/// `tick_size` is measured in quote lamports (units of `10^-quote_decimals`). The price is
/// first truncated to whole quote lamports, then truncated to a multiple of `tick_size`
/// lamports, and finally converted back to a quote-denominated price.
///
/// A `tick_size` of 0 or 1 disables alignment and returns `price` unchanged.
///
/// If `10^quote_decimals` does not fit in a `u64` (i.e. `quote_decimals > 19`), or scaling
/// `price` by it overflows `Decimal`, the grid cannot be computed. In that case `price` is
/// returned unchanged instead of panicking (or, in release builds, silently using a wrapped
/// multiplier). `scale_price_size` rejects the same oversized `quote_decimals` with
/// `ScalingError::Overflow`.
pub fn align_price_to_tick(price: Decimal, decimals: &OrderbookDecimals) -> Decimal {
    if decimals.tick_size <= 1 {
        return price;
    }
    // Checked, unlike a bare `pow`: an oversized exponent panics in debug builds and wraps
    // to a meaningless multiplier in release builds.
    let quote_multiplier = match 10u64.checked_pow(u32::from(decimals.quote_decimals)) {
        Some(multiplier) => Decimal::from(multiplier),
        None => return price,
    };
    // `Decimal`'s `*` operator panics on overflow; fall back to the unaligned price instead.
    let scaled = match price.checked_mul(quote_multiplier) {
        Some(scaled) => scaled,
        None => return price,
    };
    let tick = Decimal::from(decimals.tick_size);
    let lamports = scaled.trunc();
    // Integer lamports -> largest multiple of `tick` not exceeding them (toward zero).
    let aligned_lamports = (lamports / tick).trunc() * tick;
    aligned_lamports / quote_multiplier
}
/// Integer lamport amounts produced by `scale_price_size`.
///
/// Which token each field holds depends on the order side: for a bid, `amount_in` is the
/// quote amount and `amount_out` the base amount; for an ask the two are swapped.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct ScaledAmounts {
    /// Quote lamports for a bid, base lamports for an ask.
    pub amount_in: u64,
    /// Base lamports for a bid, quote lamports for an ask.
    pub amount_out: u64,
}
/// Reasons a price/size pair cannot be converted into lamport amounts.
#[derive(Debug, Clone)]
pub enum ScalingError {
    /// The price was zero or negative; holds the offending price as text.
    NonPositivePrice(String),
    /// The size was zero or negative; holds the offending size as text.
    NonPositiveSize(String),
    /// A computation or conversion did not fit its target type; `context` names the step.
    Overflow { context: String },
    /// A computed lamport amount came out as zero.
    ZeroAmount,
    /// A lamport amount that must be integral had a fractional part; `value` describes it.
    FractionalAmount { value: String },
    /// A decimal input (`input`) was rejected for `reason`.
    /// NOTE(review): not constructed by the functions in this module.
    InvalidDecimal { input: String, reason: String },
}

impl fmt::Display for ScalingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NonPositivePrice(price) => write!(f, "Price must be positive, got {}", price),
            Self::NonPositiveSize(size) => write!(f, "Size must be positive, got {}", size),
            Self::Overflow { context } => write!(f, "Overflow: {}", context),
            Self::ZeroAmount => f.write_str("Computed amount is zero"),
            Self::FractionalAmount { value } => {
                write!(f, "Fractional lamports not allowed: {}", value)
            }
            Self::InvalidDecimal { input, reason } => {
                write!(f, "Invalid decimal '{}': {}", input, reason)
            }
        }
    }
}

impl std::error::Error for ScalingError {}
/// Converts a decimal `price` and `size` into integer lamport amounts for an order.
///
/// * `size` is truncated to `base_decimals` fractional digits, which drops sub-lamport
///   noise (e.g. from an `f64` round trip). It is then scaled by `10^base_decimals` into
///   base lamports.
/// * `price * size` is scaled by `10^quote_decimals` and truncated toward zero into quote
///   lamports.
///
/// For [`OrderSide::Bid`] the result is `(amount_in, amount_out) = (quote, base)`; for
/// [`OrderSide::Ask`] it is `(base, quote)`.
///
/// # Errors
///
/// * [`ScalingError::NonPositivePrice`] / [`ScalingError::NonPositiveSize`] if `price` or
///   `size` is zero or negative.
/// * [`ScalingError::Overflow`] if `10^base_decimals` or `10^quote_decimals` does not fit
///   in a `u64`, if an intermediate `Decimal` product overflows, or if either lamport amount
///   does not fit in a `u64`.
/// * [`ScalingError::FractionalAmount`] if the base amount is not integral. This is a
///   defensive check; truncating `size` to `base_decimals` digits should already prevent it.
/// * [`ScalingError::ZeroAmount`] if either lamport amount is zero after truncation.
pub fn scale_price_size(
    price: Decimal,
    size: Decimal,
    side: OrderSide,
    decimals: &OrderbookDecimals,
) -> Result<ScaledAmounts, ScalingError> {
    if price <= Decimal::ZERO {
        return Err(ScalingError::NonPositivePrice(price.to_string()));
    }
    if size <= Decimal::ZERO {
        return Err(ScalingError::NonPositiveSize(size.to_string()));
    }

    let base_multiplier = pow10_multiplier(decimals.base_decimals)?;
    let quote_multiplier = pow10_multiplier(decimals.quote_decimals)?;

    // Drop digits finer than one base lamport before scaling.
    let size = size.trunc_with_scale(u32::from(decimals.base_decimals));

    let base_lamports = size
        .checked_mul(base_multiplier)
        .ok_or_else(|| overflow_error("size * 10^base_decimals"))?;
    let quote_lamports = price
        .checked_mul(size)
        .ok_or_else(|| overflow_error("price * size"))?
        .checked_mul(quote_multiplier)
        .ok_or_else(|| overflow_error("price * size * 10^quote_decimals"))?
        .trunc();

    if base_lamports.fract() != Decimal::ZERO {
        return Err(ScalingError::FractionalAmount {
            value: format!("base_lamports = {}", base_lamports),
        });
    }

    let base_u64 = lamports_to_u64("base_lamports", base_lamports)?;
    let quote_u64 = lamports_to_u64("quote_lamports", quote_lamports)?;
    if base_u64 == 0 || quote_u64 == 0 {
        return Err(ScalingError::ZeroAmount);
    }

    let (amount_in, amount_out) = match side {
        OrderSide::Bid => (quote_u64, base_u64),
        OrderSide::Ask => (base_u64, quote_u64),
    };
    Ok(ScaledAmounts {
        amount_in,
        amount_out,
    })
}

/// Returns `10^exponent` as a [`Decimal`], or [`ScalingError::Overflow`] if it exceeds `u64`.
fn pow10_multiplier(exponent: u8) -> Result<Decimal, ScalingError> {
    10u64
        .checked_pow(u32::from(exponent))
        .map(Decimal::from)
        .ok_or_else(|| overflow_error(format!("10^{} overflow", exponent)))
}

/// Converts an integral lamport amount to `u64`; `label` names it in the overflow message.
fn lamports_to_u64(label: &str, lamports: Decimal) -> Result<u64, ScalingError> {
    lamports
        .to_u64()
        .ok_or_else(|| overflow_error(format!("{} {} does not fit in u64", label, lamports)))
}

/// Builds a [`ScalingError::Overflow`] carrying `context`.
fn overflow_error(context: impl Into<String>) -> ScalingError {
    ScalingError::Overflow {
        context: context.into(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a decimal literal used as test input.
    fn dec(literal: &str) -> Decimal {
        Decimal::from_str(literal).unwrap()
    }

    /// Builds an orderbook config; `price_decimals` is fixed at 2 for every fixture.
    fn book(orderbook_id: &str, base_decimals: u8, quote_decimals: u8, tick_size: u64) -> OrderbookDecimals {
        OrderbookDecimals {
            orderbook_id: orderbook_id.to_string(),
            base_decimals,
            quote_decimals,
            price_decimals: 2,
            tick_size,
        }
    }

    fn decimals_6_6() -> OrderbookDecimals {
        book("test", 6, 6, 0)
    }

    fn decimals_6_9() -> OrderbookDecimals {
        book("test", 6, 9, 0)
    }

    /// Runs `scale_price_size` on decimal literals.
    fn scale(
        price: &str,
        size: &str,
        side: OrderSide,
        decimals: &OrderbookDecimals,
    ) -> Result<ScaledAmounts, ScalingError> {
        scale_price_size(dec(price), dec(size), side, decimals)
    }

    /// Shorthand for the expected result of a successful scaling.
    fn amounts(amount_in: u64, amount_out: u64) -> ScaledAmounts {
        ScaledAmounts {
            amount_in,
            amount_out,
        }
    }

    #[test]
    fn test_bid_basic() {
        let scaled = scale("0.65", "100", OrderSide::Bid, &decimals_6_6()).unwrap();
        assert_eq!(scaled, amounts(65_000_000, 100_000_000));
    }

    #[test]
    fn test_ask_basic() {
        let scaled = scale("0.65", "100", OrderSide::Ask, &decimals_6_6()).unwrap();
        assert_eq!(scaled, amounts(100_000_000, 65_000_000));
    }

    #[test]
    fn test_different_decimals() {
        let scaled = scale("0.65", "100", OrderSide::Bid, &decimals_6_9()).unwrap();
        assert_eq!(scaled, amounts(65_000_000_000, 100_000_000));
    }

    #[test]
    fn test_zero_price_rejected() {
        let outcome = scale("0", "100", OrderSide::Bid, &decimals_6_6());
        assert!(matches!(outcome, Err(ScalingError::NonPositivePrice(_))));
    }

    #[test]
    fn test_negative_price_rejected() {
        let outcome = scale("-0.5", "100", OrderSide::Bid, &decimals_6_6());
        assert!(matches!(outcome, Err(ScalingError::NonPositivePrice(_))));
    }

    #[test]
    fn test_zero_size_rejected() {
        let outcome = scale("0.65", "0", OrderSide::Bid, &decimals_6_6());
        assert!(matches!(outcome, Err(ScalingError::NonPositiveSize(_))));
    }

    #[test]
    fn test_negative_size_rejected() {
        let outcome = scale("0.65", "-10", OrderSide::Bid, &decimals_6_6());
        assert!(matches!(outcome, Err(ScalingError::NonPositiveSize(_))));
    }

    #[test]
    fn test_sub_lamport_size_becomes_zero() {
        let outcome = scale("1", "0.0000001", OrderSide::Bid, &decimals_6_6());
        assert!(matches!(outcome, Err(ScalingError::ZeroAmount)));
    }

    #[test]
    fn test_f64_noise_in_size_is_truncated() {
        let scaled = scale("1", "15.763000000000002", OrderSide::Bid, &decimals_6_6()).unwrap();
        assert_eq!(scaled, amounts(15_763_000, 15_763_000));
    }

    #[test]
    fn test_overflow_u64_rejected() {
        let outcome = scale("1", "99999999999999999999", OrderSide::Bid, &decimals_6_6());
        assert!(matches!(outcome, Err(ScalingError::Overflow { .. })));
    }

    #[test]
    fn test_small_valid_amounts() {
        let scaled = scale("1", "0.000001", OrderSide::Bid, &decimals_6_6()).unwrap();
        assert_eq!(scaled, amounts(1, 1));
    }

    #[test]
    fn test_whole_number_price_and_size() {
        let scaled = scale("2", "50", OrderSide::Ask, &decimals_6_6()).unwrap();
        assert_eq!(scaled, amounts(50_000_000, 100_000_000));
    }

    #[test]
    fn test_align_price_to_tick_basic() {
        let config = book("t", 8, 6, 1000);
        assert_eq!(align_price_to_tick(dec("0.6005"), &config), dec("0.6"));
    }

    #[test]
    fn test_align_price_to_tick_exact() {
        let config = book("t", 8, 6, 1000);
        assert_eq!(align_price_to_tick(dec("0.65"), &config), dec("0.65"));
    }

    #[test]
    fn test_align_price_no_tick() {
        let config = book("t", 6, 6, 0);
        let untouched = dec("0.12345");
        assert_eq!(align_price_to_tick(untouched, &config), untouched);
    }
}