use crate::Decimal;
use crate::types::error::{MMError, MMResult};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Direction of a grid order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OrderSide {
    /// Buy (bid) order.
    Buy,
    /// Sell (ask) order.
    Sell,
}

/// Renders the bare variant name (`"Buy"` / `"Sell"`); width and fill flags
/// are not applied.
impl std::fmt::Display for OrderSide {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            OrderSide::Buy => "Buy",
            OrderSide::Sell => "Sell",
        };
        f.write_str(label)
    }
}
/// Selects how [`GridStrategy::calculate_price`] spaces successive price
/// levels; see that method for the exact formula used by each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum GridSpacingType {
    /// Default variant, returned by `GridSpacingType::default()` and therefore
    /// used by [`GridConfig::new`].
    #[default]
    Geometric,
    /// Alternative spacing; opt in via [`GridConfig::with_spacing_type`].
    Arithmetic,
}
/// Static parameters describing a symmetric price grid.
///
/// Prefer [`GridConfig::new`], which validates the numeric fields. The fields
/// are public, so a struct literal (or deserialization when the `serde`
/// feature is enabled) bypasses that validation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GridConfig {
    /// Number of price levels placed on each side of the reference price.
    pub levels_per_side: u32,
    /// Relative step between adjacent levels (e.g. `0.01` ~ 1%); how it is
    /// applied depends on `spacing_type`.
    pub grid_spacing: Decimal,
    /// Order size used at the innermost level on each side.
    pub base_size: Decimal,
    /// Optional per-level size increment beyond the first level; `None` keeps
    /// every level at `base_size`.
    pub size_progression: Option<Decimal>,
    /// Position magnitude used to scale down the side that would grow the
    /// inventory in `GridStrategy::generate_grid_with_inventory`.
    pub max_position: Decimal,
    /// How level prices are spaced.
    pub spacing_type: GridSpacingType,
}

impl GridConfig {
    /// Creates a validated configuration with no size progression and the
    /// default spacing type.
    ///
    /// # Errors
    ///
    /// Returns [`MMError::InvalidConfiguration`] if `levels_per_side` is zero,
    /// or if `grid_spacing`, `base_size` or `max_position` is not strictly
    /// positive. Checks run in that order; only the first failure is reported.
    pub fn new(
        levels_per_side: u32,
        grid_spacing: Decimal,
        base_size: Decimal,
        max_position: Decimal,
    ) -> MMResult<Self> {
        // (is_violated, message) pairs, evaluated in declaration order.
        let violations = [
            (levels_per_side == 0, "levels_per_side must be greater than 0"),
            (grid_spacing <= Decimal::ZERO, "grid_spacing must be positive"),
            (base_size <= Decimal::ZERO, "base_size must be positive"),
            (max_position <= Decimal::ZERO, "max_position must be positive"),
        ];
        if let Some(message) = violations
            .iter()
            .find_map(|&(violated, message)| violated.then_some(message))
        {
            return Err(MMError::InvalidConfiguration(message.to_string()));
        }
        Ok(Self {
            spacing_type: GridSpacingType::default(),
            size_progression: None,
            levels_per_side,
            grid_spacing,
            base_size,
            max_position,
        })
    }

    /// Returns the configuration with `size_progression` set to `progression`.
    /// The value is not validated.
    #[must_use]
    pub fn with_size_progression(self, progression: Decimal) -> Self {
        Self {
            size_progression: Some(progression),
            ..self
        }
    }

    /// Returns the configuration with `spacing_type` replaced.
    #[must_use]
    pub fn with_spacing_type(self, spacing_type: GridSpacingType) -> Self {
        Self {
            spacing_type,
            ..self
        }
    }
}
/// A single priced order slot in the grid.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GridOrder {
    /// Price of the order.
    pub price: Decimal,
    /// Order quantity.
    pub size: Decimal,
    /// Whether the order buys or sells.
    pub side: OrderSide,
    /// Signed grid level; [`GridStrategy::generate_grid`] uses negative levels
    /// for buys and positive levels for sells.
    pub level: i32,
}

impl GridOrder {
    /// Bundles the given fields into an order; no validation is performed.
    #[must_use]
    pub fn new(price: Decimal, size: Decimal, side: OrderSide, level: i32) -> Self {
        Self {
            level,
            side,
            size,
            price,
        }
    }

    /// Notional value of the order, `price * size`.
    #[must_use]
    pub fn notional(&self) -> Decimal {
        let Self { price, size, .. } = self;
        *price * *size
    }
}
/// Symmetric grid quoting strategy driven by a [`GridConfig`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GridStrategy {
    /// Grid parameters consulted by every order-generation method.
    config: GridConfig,
    /// Stored reference price (zero after `GridStrategy::new`). The grid
    /// generation methods take their reference price as an argument and do
    /// not read this field.
    reference_price: Decimal,
}
impl GridStrategy {
    /// Creates a strategy whose stored reference price is zero.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn new(config: GridConfig) -> MMResult<Self> {
        Ok(Self {
            config,
            reference_price: Decimal::ZERO,
        })
    }

    /// Creates a strategy with an explicit stored reference price.
    ///
    /// # Errors
    ///
    /// Returns [`MMError::InvalidConfiguration`] if `reference_price` is not
    /// strictly positive.
    pub fn with_reference_price(config: GridConfig, reference_price: Decimal) -> MMResult<Self> {
        if reference_price <= Decimal::ZERO {
            return Err(MMError::InvalidConfiguration(
                "reference_price must be positive".to_string(),
            ));
        }
        Ok(Self {
            config,
            reference_price,
        })
    }

    /// Returns the grid configuration.
    #[must_use]
    pub fn config(&self) -> &GridConfig {
        &self.config
    }

    /// Returns the stored reference price.
    #[must_use]
    pub fn reference_price(&self) -> Decimal {
        self.reference_price
    }

    /// Overwrites the stored reference price.
    ///
    /// NOTE(review): unlike [`GridStrategy::with_reference_price`] this accepts
    /// zero or negative prices; validate at the call site if that matters.
    pub fn update_reference_price(&mut self, price: Decimal) {
        self.reference_price = price;
    }

    /// `levels_per_side` as an `i32` level index, saturating at `i32::MAX`
    /// because [`GridOrder::level`] is `i32` while the configured count is `u32`.
    fn signed_levels(&self) -> i32 {
        // Lossless: the value is clamped into `0..=i32::MAX` before the cast.
        self.config.levels_per_side.min(i32::MAX as u32) as i32
    }

    /// Computes `base^exp` by square-and-multiply (O(log exp) multiplications).
    fn pow_u32(base: Decimal, exp: u32) -> Decimal {
        let mut result = Decimal::ONE;
        let mut square = base;
        let mut remaining = exp;
        while remaining > 0 {
            if remaining & 1 == 1 {
                result *= square;
            }
            remaining >>= 1;
            // Only square when another bit remains, so we never compute a
            // power larger than the one requested (avoids spurious overflow).
            if remaining > 0 {
                square = square * square;
            }
        }
        result
    }

    /// Builds the full grid around `reference_price`, sorted by ascending price.
    ///
    /// Each level `k` in `1..=levels_per_side` contributes a buy at level `-k`
    /// and a sell at level `+k`. Both are sized by
    /// [`GridStrategy::calculate_level_size`].
    ///
    /// Levels whose price or size is not strictly positive are omitted. Examples
    /// are arithmetic buy levels at or below zero, or outer levels under a
    /// negative size progression. The result can therefore hold fewer than
    /// [`GridStrategy::total_orders`] orders.
    #[must_use]
    pub fn generate_grid(&self, reference_price: Decimal) -> Vec<GridOrder> {
        let levels = self.signed_levels();
        let mut orders = Vec::with_capacity(levels.unsigned_abs() as usize * 2);
        for side in [OrderSide::Buy, OrderSide::Sell] {
            // Buys sit below the reference (negative levels), sells above.
            let sign = match side {
                OrderSide::Buy => -1,
                OrderSide::Sell => 1,
            };
            for level in 1..=levels {
                let signed_level = sign * level;
                let price = self.calculate_price(reference_price, signed_level);
                let size = self.calculate_level_size(level);
                // An order with a non-positive price or size is meaningless.
                if price > Decimal::ZERO && size > Decimal::ZERO {
                    orders.push(GridOrder::new(price, size, side, signed_level));
                }
            }
        }
        orders.sort_by_key(|order| order.price);
        orders
    }

    /// Builds the grid and shrinks the side that would grow the current position.
    ///
    /// A long inventory scales the buy orders, and a short inventory scales the
    /// sell orders. The factor is `max(0, 1 - |current_inventory| / max_position)`,
    /// so that side disappears once the position reaches `max_position`. Orders
    /// whose size ends up at or below `1e-8` are dropped.
    ///
    /// A non-positive `max_position` is treated as "no position allowed" instead
    /// of dividing by zero. Such a value can only come from bypassing
    /// [`GridConfig::new`], e.g. a struct literal or deserialization.
    #[must_use]
    pub fn generate_grid_with_inventory(
        &self,
        reference_price: Decimal,
        current_inventory: Decimal,
    ) -> Vec<GridOrder> {
        // Dust threshold below which orders are not worth placing.
        let min_order_size = Decimal::new(1, 8);
        let inventory_ratio = if self.config.max_position > Decimal::ZERO {
            current_inventory.abs() / self.config.max_position
        } else {
            Decimal::ONE
        };
        let scale_factor = (Decimal::ONE - inventory_ratio).max(Decimal::ZERO);
        // The side whose fills would push the position further from flat.
        let growing_side = if current_inventory > Decimal::ZERO {
            Some(OrderSide::Buy)
        } else if current_inventory < Decimal::ZERO {
            Some(OrderSide::Sell)
        } else {
            None
        };
        let mut orders = self.generate_grid(reference_price);
        for order in &mut orders {
            if Some(order.side) == growing_side {
                order.size *= scale_factor;
            }
        }
        orders.retain(|o| o.size > min_order_size);
        orders
    }

    /// Returns the price of the signed grid `level` around `reference_price`.
    ///
    /// * [`GridSpacingType::Geometric`]: `reference_price * (1 + grid_spacing)^level`.
    ///   Every step up multiplies, and every step down divides, by the same
    ///   ratio. Prices therefore stay positive for a positive reference.
    /// * [`GridSpacingType::Arithmetic`]:
    ///   `reference_price + level * grid_spacing * reference_price`. This gives
    ///   equal absolute steps. Buy levels reach zero or below once
    ///   `|level| * grid_spacing >= 1`.
    ///
    /// Level `0` returns `reference_price`.
    ///
    /// # Panics
    ///
    /// May panic if the compounded geometric factor exceeds the range of
    /// `Decimal`. That only happens for extreme spacing/level combinations.
    #[must_use]
    pub fn calculate_price(&self, reference_price: Decimal, level: i32) -> Decimal {
        match self.config.spacing_type {
            GridSpacingType::Geometric => {
                let ratio = Decimal::ONE + self.config.grid_spacing;
                let factor = Self::pow_u32(ratio, level.unsigned_abs());
                if level >= 0 {
                    reference_price * factor
                } else {
                    reference_price / factor
                }
            }
            GridSpacingType::Arithmetic => {
                let level_decimal = Decimal::from(level);
                reference_price + level_decimal * self.config.grid_spacing * reference_price
            }
        }
    }

    /// Returns the order size for grid `level`; only its magnitude matters.
    ///
    /// Without a progression every level uses `base_size`. With a progression
    /// `p`, level `k` is sized `base_size * max(0, 1 + (|k| - 1) * p)`, so levels
    /// 0 and +/-1 both use `base_size`. The clamp prevents a negative
    /// progression from producing negative sizes.
    #[must_use]
    pub fn calculate_level_size(&self, level: i32) -> Decimal {
        match self.config.size_progression {
            Some(progression) => {
                let steps = Decimal::from(level.unsigned_abs().saturating_sub(1));
                let multiplier = (Decimal::ONE + steps * progression).max(Decimal::ZERO);
                self.config.base_size * multiplier
            }
            None => self.config.base_size,
        }
    }

    /// Upper bound on the number of orders in the grid (`2 * levels_per_side`,
    /// saturating). [`GridStrategy::generate_grid`] may return fewer.
    #[must_use]
    pub fn total_orders(&self) -> u32 {
        self.config.levels_per_side.saturating_mul(2)
    }

    /// Prices of the outermost buy and sell levels, as `(lowest, highest)`.
    ///
    /// These come straight from [`GridStrategy::calculate_price`]. With
    /// arithmetic spacing the lower bound can be zero or negative; such levels
    /// are left out by [`GridStrategy::generate_grid`].
    #[must_use]
    pub fn price_range(&self, reference_price: Decimal) -> (Decimal, Decimal) {
        let levels = self.signed_levels();
        let lowest = self.calculate_price(reference_price, -levels);
        let highest = self.calculate_price(reference_price, levels);
        (lowest, highest)
    }

    /// Sum of `price * size` over every order of
    /// [`GridStrategy::generate_grid`] for `reference_price`.
    #[must_use]
    pub fn max_notional_exposure(&self, reference_price: Decimal) -> Decimal {
        self.generate_grid(reference_price)
            .iter()
            .map(GridOrder::notional)
            .sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dec;
    #[test]
    fn test_config_valid() {
        let config = GridConfig::new(5, dec!(0.005), dec!(1.0), dec!(100.0));
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.levels_per_side, 5);
        assert_eq!(config.grid_spacing, dec!(0.005));
        assert_eq!(config.base_size, dec!(1.0));
        assert_eq!(config.max_position, dec!(100.0));
    }
    #[test]
    fn test_config_invalid_levels() {
        let config = GridConfig::new(0, dec!(0.005), dec!(1.0), dec!(100.0));
        assert!(config.is_err());
    }
    #[test]
    fn test_config_invalid_spacing() {
        let config = GridConfig::new(5, dec!(0.0), dec!(1.0), dec!(100.0));
        assert!(config.is_err());
        let config = GridConfig::new(5, dec!(-0.005), dec!(1.0), dec!(100.0));
        assert!(config.is_err());
    }
    #[test]
    fn test_config_invalid_base_size() {
        let config = GridConfig::new(5, dec!(0.005), dec!(0.0), dec!(100.0));
        assert!(config.is_err());
    }
    #[test]
    fn test_config_invalid_max_position() {
        let config = GridConfig::new(5, dec!(0.005), dec!(1.0), dec!(0.0));
        assert!(config.is_err());
    }
    #[test]
    fn test_config_with_progression() {
        let config = GridConfig::new(5, dec!(0.005), dec!(1.0), dec!(100.0))
            .unwrap()
            .with_size_progression(dec!(0.2));
        assert_eq!(config.size_progression, Some(dec!(0.2)));
    }
    #[test]
    fn test_strategy_new() {
        let config = GridConfig::new(5, dec!(0.005), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config);
        assert!(strategy.is_ok());
    }
    #[test]
    fn test_strategy_with_reference_price() {
        let config = GridConfig::new(5, dec!(0.005), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::with_reference_price(config, dec!(100.0)).unwrap();
        assert_eq!(strategy.reference_price(), dec!(100.0));
    }
    #[test]
    fn test_strategy_invalid_reference_price() {
        let config = GridConfig::new(5, dec!(0.005), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::with_reference_price(config, dec!(0.0));
        assert!(strategy.is_err());
    }
    #[test]
    fn test_generate_grid_symmetric() {
        let config = GridConfig::new(3, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid(dec!(100.0));
        assert_eq!(orders.len(), 6);
        let buys: Vec<_> = orders.iter().filter(|o| o.side == OrderSide::Buy).collect();
        let sells: Vec<_> = orders
            .iter()
            .filter(|o| o.side == OrderSide::Sell)
            .collect();
        assert_eq!(buys.len(), 3);
        assert_eq!(sells.len(), 3);
        assert!(buys.iter().all(|o| o.price < dec!(100.0)));
        assert!(sells.iter().all(|o| o.price > dec!(100.0)));
    }
    #[test]
    fn test_generate_grid_prices() {
        // Geometric (default) spacing: adjacent levels differ by a constant 1.01 ratio.
        let config = GridConfig::new(2, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid(dec!(100.0));
        let prices: Vec<_> = orders.iter().map(|o| o.price.round_dp(4)).collect();
        assert_eq!(
            prices,
            vec![dec!(98.0296), dec!(99.0099), dec!(101), dec!(102.01)]
        );
    }
    #[test]
    fn test_geometric_spacing_has_constant_ratio() {
        let config = GridConfig::new(3, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        assert_eq!(strategy.calculate_price(dec!(100.0), 0), dec!(100.0));
        assert_eq!(strategy.calculate_price(dec!(100.0), 2), dec!(102.01));
        // One step down is undone by one step up.
        let down = strategy.calculate_price(dec!(100.0), -1);
        assert_eq!((down * dec!(1.01)).round_dp(10), dec!(100));
        // Geometric must differ from arithmetic beyond the first level.
        let arithmetic = GridStrategy::new(
            GridConfig::new(3, dec!(0.01), dec!(1.0), dec!(100.0))
                .unwrap()
                .with_spacing_type(GridSpacingType::Arithmetic),
        )
        .unwrap();
        assert_ne!(
            strategy.calculate_price(dec!(100.0), 2),
            arithmetic.calculate_price(dec!(100.0), 2)
        );
    }
    #[test]
    fn test_arithmetic_grid_skips_non_positive_prices() {
        let config = GridConfig::new(3, dec!(0.5), dec!(1.0), dec!(100.0))
            .unwrap()
            .with_spacing_type(GridSpacingType::Arithmetic);
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid(dec!(100.0));
        let prices: Vec<_> = orders.iter().map(|o| o.price).collect();
        assert_eq!(prices, vec![dec!(50), dec!(150), dec!(200), dec!(250)]);
    }
    #[test]
    fn test_negative_progression_never_yields_negative_sizes() {
        let config = GridConfig::new(4, dec!(0.01), dec!(1.0), dec!(100.0))
            .unwrap()
            .with_size_progression(dec!(-0.5));
        let strategy = GridStrategy::new(config).unwrap();
        assert_eq!(strategy.calculate_level_size(2), dec!(0.5));
        assert_eq!(strategy.calculate_level_size(4), dec!(0));
        let orders = strategy.generate_grid(dec!(100.0));
        assert_eq!(orders.len(), 4);
        assert!(orders.iter().all(|o| o.size > dec!(0)));
    }
    #[test]
    fn test_generate_grid_with_inventory_long() {
        let config = GridConfig::new(2, dec!(0.01), dec!(1.0), dec!(10.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid_with_inventory(dec!(100.0), dec!(5.0));
        let buys: Vec<_> = orders.iter().filter(|o| o.side == OrderSide::Buy).collect();
        let sells: Vec<_> = orders
            .iter()
            .filter(|o| o.side == OrderSide::Sell)
            .collect();
        assert!(buys.iter().all(|o| o.size == dec!(0.5)));
        assert!(sells.iter().all(|o| o.size == dec!(1.0)));
    }
    #[test]
    fn test_generate_grid_with_inventory_short() {
        let config = GridConfig::new(2, dec!(0.01), dec!(1.0), dec!(10.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid_with_inventory(dec!(100.0), dec!(-5.0));
        let buys: Vec<_> = orders.iter().filter(|o| o.side == OrderSide::Buy).collect();
        let sells: Vec<_> = orders
            .iter()
            .filter(|o| o.side == OrderSide::Sell)
            .collect();
        assert!(buys.iter().all(|o| o.size == dec!(1.0)));
        assert!(sells.iter().all(|o| o.size == dec!(0.5)));
    }
    #[test]
    fn test_generate_grid_with_max_inventory() {
        let config = GridConfig::new(2, dec!(0.01), dec!(1.0), dec!(10.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid_with_inventory(dec!(100.0), dec!(10.0));
        let buys: Vec<_> = orders.iter().filter(|o| o.side == OrderSide::Buy).collect();
        assert!(buys.is_empty());
    }
    #[test]
    fn test_inventory_with_non_positive_max_position_does_not_panic() {
        let config = GridConfig {
            levels_per_side: 2,
            grid_spacing: dec!(0.01),
            base_size: dec!(1.0),
            size_progression: None,
            max_position: dec!(0),
            spacing_type: GridSpacingType::Geometric,
        };
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid_with_inventory(dec!(100.0), dec!(1.0));
        assert_eq!(orders.len(), 2);
        assert!(orders.iter().all(|o| o.side == OrderSide::Sell));
    }
    #[test]
    fn test_calculate_level_size_no_progression() {
        let config = GridConfig::new(5, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        assert_eq!(strategy.calculate_level_size(1), dec!(1.0));
        assert_eq!(strategy.calculate_level_size(3), dec!(1.0));
        assert_eq!(strategy.calculate_level_size(5), dec!(1.0));
    }
    #[test]
    fn test_calculate_level_size_with_progression() {
        let config = GridConfig::new(5, dec!(0.01), dec!(1.0), dec!(100.0))
            .unwrap()
            .with_size_progression(dec!(0.2));
        let strategy = GridStrategy::new(config).unwrap();
        assert_eq!(strategy.calculate_level_size(1), dec!(1.0));
        assert_eq!(strategy.calculate_level_size(2), dec!(1.2));
        assert_eq!(strategy.calculate_level_size(3), dec!(1.4));
    }
    #[test]
    fn test_price_range() {
        // Geometric: 100 / 1.01^3 and 100 * 1.01^3.
        let config = GridConfig::new(3, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        let (low, high) = strategy.price_range(dec!(100.0));
        assert_eq!(low.round_dp(4), dec!(97.0590));
        assert_eq!(high, dec!(103.0301));
    }
    #[test]
    fn test_price_range_arithmetic() {
        let config = GridConfig::new(3, dec!(0.01), dec!(1.0), dec!(100.0))
            .unwrap()
            .with_spacing_type(GridSpacingType::Arithmetic);
        let strategy = GridStrategy::new(config).unwrap();
        let (low, high) = strategy.price_range(dec!(100.0));
        assert_eq!(low, dec!(97));
        assert_eq!(high, dec!(103));
    }
    #[test]
    fn test_total_orders() {
        let config = GridConfig::new(5, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let strategy = GridStrategy::new(config).unwrap();
        assert_eq!(strategy.total_orders(), 10);
    }
    #[test]
    fn test_arithmetic_spacing() {
        let config = GridConfig::new(2, dec!(0.01), dec!(1.0), dec!(100.0))
            .unwrap()
            .with_spacing_type(GridSpacingType::Arithmetic);
        let strategy = GridStrategy::new(config).unwrap();
        let orders = strategy.generate_grid(dec!(100.0));
        let prices: Vec<_> = orders.iter().map(|o| o.price).collect();
        assert_eq!(prices, vec![dec!(98), dec!(99), dec!(101), dec!(102)]);
    }
    #[test]
    fn test_order_side_display() {
        assert_eq!(OrderSide::Buy.to_string(), "Buy");
        assert_eq!(OrderSide::Sell.to_string(), "Sell");
    }
    #[test]
    fn test_grid_order_notional() {
        let order = GridOrder::new(dec!(100.0), dec!(5.0), OrderSide::Buy, -1);
        assert_eq!(order.notional(), dec!(500.0));
    }
    #[test]
    fn test_update_reference_price() {
        let config = GridConfig::new(5, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let mut strategy = GridStrategy::new(config).unwrap();
        strategy.update_reference_price(dec!(150.0));
        assert_eq!(strategy.reference_price(), dec!(150.0));
    }
    #[cfg(feature = "serde")]
    #[test]
    fn test_serialization() {
        let config = GridConfig::new(5, dec!(0.01), dec!(1.0), dec!(100.0)).unwrap();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GridConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, deserialized);
    }
}