use kraken_types::Level;
use rust_decimal::Decimal;
use std::cmp::Reverse;
use std::collections::BTreeMap;
/// Two-sided price-level book stored in ordered maps.
///
/// Each side holds at most one [`Level`] per price, and both sides iterate
/// in key order.
#[derive(Debug, Clone, Default)]
pub struct TreeBook {
    /// Bid levels keyed by `Reverse(price)`, so iteration runs from the
    /// highest price to the lowest.
    bids: BTreeMap<Reverse<Decimal>, Level>,
    /// Ask levels keyed by price, so iteration runs from the lowest price
    /// to the highest.
    asks: BTreeMap<Decimal, Level>,
}
impl TreeBook {
pub fn new() -> Self {
Self {
bids: BTreeMap::new(),
asks: BTreeMap::new(),
}
}
pub fn insert_bid(&mut self, price: Decimal, qty: Decimal) {
if qty.is_zero() {
self.bids.remove(&Reverse(price));
} else {
self.bids.insert(Reverse(price), Level::new(price, qty));
}
}
pub fn insert_ask(&mut self, price: Decimal, qty: Decimal) {
if qty.is_zero() {
self.asks.remove(&price);
} else {
self.asks.insert(price, Level::new(price, qty));
}
}
pub fn remove_bid(&mut self, price: &Decimal) {
self.bids.remove(&Reverse(*price));
}
pub fn remove_ask(&mut self, price: &Decimal) {
self.asks.remove(price);
}
pub fn best_bid(&self) -> Option<&Level> {
self.bids.values().next()
}
pub fn best_ask(&self) -> Option<&Level> {
self.asks.values().next()
}
pub fn best_bid_price(&self) -> Option<Decimal> {
self.best_bid().map(|l| l.price)
}
pub fn best_ask_price(&self) -> Option<Decimal> {
self.best_ask().map(|l| l.price)
}
pub fn bids(&self) -> impl Iterator<Item = &Level> {
self.bids.values()
}
pub fn asks(&self) -> impl Iterator<Item = &Level> {
self.asks.values()
}
pub fn bids_vec(&self) -> Vec<Level> {
self.bids.values().cloned().collect()
}
pub fn asks_vec(&self) -> Vec<Level> {
self.asks.values().cloned().collect()
}
pub fn top_bids(&self, n: usize) -> Vec<Level> {
self.bids.values().take(n).cloned().collect()
}
pub fn top_asks(&self, n: usize) -> Vec<Level> {
self.asks.values().take(n).cloned().collect()
}
pub fn bid_count(&self) -> usize {
self.bids.len()
}
pub fn ask_count(&self) -> usize {
self.asks.len()
}
pub fn level_count(&self) -> usize {
self.bid_count() + self.ask_count()
}
pub fn is_empty(&self) -> bool {
self.bids.is_empty() && self.asks.is_empty()
}
pub fn clear(&mut self) {
self.bids.clear();
self.asks.clear();
}
pub fn truncate(&mut self, max_depth: usize) {
if self.bids.len() > max_depth {
let keys_to_remove: Vec<_> = self.bids.keys().skip(max_depth).cloned().collect();
for key in keys_to_remove {
self.bids.remove(&key);
}
}
if self.asks.len() > max_depth {
let keys_to_remove: Vec<_> = self.asks.keys().skip(max_depth).cloned().collect();
for key in keys_to_remove {
self.asks.remove(&key);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Bids come back highest price first, whatever the insertion order.
    #[test]
    fn test_bid_order() {
        let mut book = TreeBook::new();
        book.insert_bid(dec!(100), dec!(1));
        book.insert_bid(dec!(101), dec!(2));
        book.insert_bid(dec!(99), dec!(3));

        let prices: Vec<Decimal> = book.bids().map(|level| level.price).collect();
        assert_eq!(prices, vec![dec!(101), dec!(100), dec!(99)]);
    }

    /// Asks come back lowest price first, whatever the insertion order.
    #[test]
    fn test_ask_order() {
        let mut book = TreeBook::new();
        book.insert_ask(dec!(100), dec!(1));
        book.insert_ask(dec!(101), dec!(2));
        book.insert_ask(dec!(99), dec!(3));

        let prices: Vec<Decimal> = book.asks().map(|level| level.price).collect();
        assert_eq!(prices, vec![dec!(99), dec!(100), dec!(101)]);
    }

    /// Inserting a zero quantity at an existing price deletes that level.
    #[test]
    fn test_zero_qty_removes_level() {
        let mut book = TreeBook::new();

        book.insert_bid(dec!(100), dec!(1));
        assert_eq!(book.bid_count(), 1);

        book.insert_bid(dec!(100), dec!(0));
        assert_eq!(book.bid_count(), 0);
    }

    /// Best bid is the highest bid price; best ask is the lowest ask price.
    #[test]
    fn test_best_bid_ask() {
        let mut book = TreeBook::new();
        book.insert_bid(dec!(99), dec!(1));
        book.insert_bid(dec!(100), dec!(1));
        book.insert_ask(dec!(101), dec!(1));
        book.insert_ask(dec!(102), dec!(1));

        assert_eq!(book.best_bid_price(), Some(dec!(100)));
        assert_eq!(book.best_ask_price(), Some(dec!(101)));
    }

    /// Truncation caps each side at the requested depth and keeps the best prices.
    #[test]
    fn test_truncate() {
        let mut book = TreeBook::new();
        for step in 1..=20 {
            book.insert_bid(Decimal::from(step), dec!(1));
            book.insert_ask(Decimal::from(100 + step), dec!(1));
        }
        assert_eq!((book.bid_count(), book.ask_count()), (20, 20));

        book.truncate(10);

        assert_eq!((book.bid_count(), book.ask_count()), (10, 10));
        assert_eq!(book.best_bid_price(), Some(dec!(20)));
        assert_eq!(book.best_ask_price(), Some(dec!(101)));
    }
}