use std::{collections::HashMap, str::FromStr};
use ahash::AHashMap;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use futures::future::join_all;
use nautilus_common::{cache::database::CacheMap, enums::SerializationEncoding};
use nautilus_model::{
accounts::AccountAny,
data::{CustomData, DataType, HasTsInit},
identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId},
instruments::{InstrumentAny, SyntheticInstrument},
orders::OrderAny,
position::Position,
types::Currency,
};
use redis::{AsyncCommands, aio::ConnectionManager};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use ustr::Ustr;
use super::get_index_key;
// Collection names. `read` builds full keys as `{trader_key}{REDIS_DELIMITER}{key}`, and
// `get_collection_key` takes the text of `key` before its first delimiter, which is compared
// against these names to choose the Redis command used for the read.
const INDEX: &str = "index";
const GENERAL: &str = "general";
const CURRENCIES: &str = "currencies";
const INSTRUMENTS: &str = "instruments";
const SYNTHETICS: &str = "synthetics";
const ACCOUNTS: &str = "accounts";
const ORDERS: &str = "orders";
const POSITIONS: &str = "positions";
const ACTORS: &str = "actors";
const STRATEGIES: &str = "strategies";
const CUSTOM: &str = "custom";
// Separator between the segments of a Redis key.
const REDIS_DELIMITER: char = ':';
// Index keys, matched in `read_index` against the value returned by `get_index_key`.
// The two `order_position` / `order_client` indexes are read as hashes; the rest as sets.
const INDEX_ORDER_IDS: &str = "index:order_ids";
const INDEX_ORDER_POSITION: &str = "index:order_position";
const INDEX_ORDER_CLIENT: &str = "index:order_client";
const INDEX_ORDERS: &str = "index:orders";
const INDEX_ORDERS_OPEN: &str = "index:orders_open";
const INDEX_ORDERS_CLOSED: &str = "index:orders_closed";
const INDEX_ORDERS_EMULATED: &str = "index:orders_emulated";
const INDEX_ORDERS_INFLIGHT: &str = "index:orders_inflight";
const INDEX_POSITIONS: &str = "index:positions";
const INDEX_POSITIONS_OPEN: &str = "index:positions_open";
const INDEX_POSITIONS_CLOSED: &str = "index:positions_closed";
/// Stateless namespace for the Redis read and load queries (e.g. building a `CacheMap`).
///
/// The struct carries no data; every operation is an associated function that takes a
/// `ConnectionManager` (plus the trader key and payload encoding where relevant).
#[derive(Debug)]
pub struct DatabaseQueries;
impl DatabaseQueries {
/// Serializes `payload` into bytes using the requested `encoding`.
///
/// The payload is first turned into a `serde_json::Value` so that `convert_timestamps` can
/// rewrite integer timestamp fields as RFC 3339 strings before the final encoding step.
///
/// # Errors
///
/// Returns an error if the payload cannot be converted to a JSON value, or if the
/// JSON/MessagePack encoder fails.
pub fn serialize_payload<T: Serialize>(
    encoding: SerializationEncoding,
    payload: &T,
) -> anyhow::Result<Vec<u8>> {
    let mut json_tree = serde_json::to_value(payload)?;
    convert_timestamps(&mut json_tree);

    let encoded = match encoding {
        SerializationEncoding::Json => serde_json::to_vec(&json_tree)
            .map_err(|e| anyhow::anyhow!("Failed to serialize json `payload`: {e}"))?,
        SerializationEncoding::MsgPack => rmp_serde::to_vec(&json_tree)
            .map_err(|e| anyhow::anyhow!("Failed to serialize msgpack `payload`: {e}"))?,
    };
    Ok(encoded)
}
pub fn deserialize_payload<T: DeserializeOwned>(
encoding: SerializationEncoding,
payload: &[u8],
) -> anyhow::Result<T> {
let mut value = match encoding {
SerializationEncoding::MsgPack => rmp_serde::from_slice(payload)
.map_err(|e| anyhow::anyhow!("Failed to deserialize msgpack `payload`: {e}"))?,
SerializationEncoding::Json => serde_json::from_slice(payload)
.map_err(|e| anyhow::anyhow!("Failed to deserialize json `payload`: {e}"))?,
};
convert_timestamp_strings(&mut value);
serde_json::from_value(value)
.map_err(|e| anyhow::anyhow!("Failed to convert value to target type: {e}"))
}
/// Collects every key matching the glob `pattern` by paging through Redis `SCAN`.
///
/// Each page is requested with `COUNT 5000`, and paging stops when the server returns
/// cursor `0`. Keys are accumulated in the order received; no de-duplication is performed.
///
/// # Errors
///
/// Returns an error if any `SCAN` round-trip fails.
pub async fn scan_keys(
    con: &mut ConnectionManager,
    pattern: String,
) -> anyhow::Result<Vec<String>> {
    let mut found: Vec<String> = Vec::new();
    let mut cursor: u64 = 0;

    loop {
        let (next_cursor, page): (u64, Vec<String>) = redis::cmd("SCAN")
            .arg(cursor)
            .arg("MATCH")
            .arg(&pattern)
            .arg("COUNT")
            .arg(5000)
            .query_async(con)
            .await?;

        found.extend(page);

        if next_cursor == 0 {
            return Ok(found);
        }
        cursor = next_cursor;
    }
}
/// Fetches all `keys` with a single `MGET` and returns one entry per key, in key order.
///
/// An entry is `None` when Redis returned no value for that key. An empty `keys` slice
/// short-circuits without touching Redis.
///
/// # Errors
///
/// Returns an error if the `MGET` round-trip fails.
pub async fn read_bulk(
    con: &ConnectionManager,
    keys: &[String],
) -> anyhow::Result<Vec<Option<Bytes>>> {
    if keys.is_empty() {
        return Ok(Vec::new());
    }

    let mut conn = con.clone();
    let raw: Vec<Option<Vec<u8>>> = redis::cmd("MGET").arg(keys).query_async(&mut conn).await?;

    Ok(raw.into_iter().map(|entry| entry.map(Bytes::from)).collect())
}
/// Like `read_bulk`, but issues one `MGET` per chunk of at most `batch_size` keys.
///
/// Results from all chunks are concatenated, so the output still has one entry per key in
/// key order (`None` where Redis returned no value).
///
/// # Errors
///
/// Returns an error if `batch_size` is zero or if any `MGET` round-trip fails.
pub async fn read_bulk_batched(
    con: &ConnectionManager,
    keys: &[String],
    batch_size: usize,
) -> anyhow::Result<Vec<Option<Bytes>>> {
    anyhow::ensure!(batch_size != 0, "`batch_size` must be greater than zero");

    if keys.is_empty() {
        return Ok(Vec::new());
    }

    let mut conn = con.clone();
    let mut collected: Vec<Option<Bytes>> = Vec::with_capacity(keys.len());

    for batch in keys.chunks(batch_size) {
        let raw: Vec<Option<Vec<u8>>> =
            redis::cmd("MGET").arg(batch).query_async(&mut conn).await?;
        collected.extend(raw.into_iter().map(|entry| entry.map(Bytes::from)));
    }

    Ok(collected)
}
pub async fn read(
con: &ConnectionManager,
trader_key: &str,
key: &str,
) -> anyhow::Result<Vec<Bytes>> {
let collection = Self::get_collection_key(key)?;
let full_key = format!("{trader_key}{REDIS_DELIMITER}{key}");
let mut con = con.clone();
match collection {
INDEX => Self::read_index(&mut con, &full_key).await,
GENERAL => Self::read_string(&mut con, &full_key).await,
CURRENCIES => Self::read_string(&mut con, &full_key).await,
INSTRUMENTS => Self::read_string(&mut con, &full_key).await,
SYNTHETICS => Self::read_string(&mut con, &full_key).await,
ACCOUNTS => Self::read_list(&mut con, &full_key).await,
ORDERS => Self::read_list(&mut con, &full_key).await,
POSITIONS => Self::read_list(&mut con, &full_key).await,
ACTORS => Self::read_string(&mut con, &full_key).await,
STRATEGIES => Self::read_string(&mut con, &full_key).await,
_ => anyhow::bail!("Unsupported operation: `read` for collection '{collection}'"),
}
}
pub async fn load_all(
con: &ConnectionManager,
encoding: SerializationEncoding,
trader_key: &str,
) -> anyhow::Result<CacheMap> {
let (currencies, instruments, synthetics, accounts, orders, positions) = tokio::try_join!(
Self::load_currencies(con, trader_key, encoding),
Self::load_instruments(con, trader_key, encoding),
Self::load_synthetics(con, trader_key, encoding),
Self::load_accounts(con, trader_key, encoding),
Self::load_orders(con, trader_key, encoding),
Self::load_positions(con, trader_key, encoding)
)
.map_err(|e| anyhow::anyhow!("Error loading cache data: {e}"))?;
let greeks = AHashMap::new();
let yield_curves = AHashMap::new();
Ok(CacheMap {
currencies,
instruments,
synthetics,
accounts,
orders,
positions,
greeks,
yield_curves,
})
}
/// Loads every currency stored under `{trader_key}:currencies:*`.
///
/// Keys are discovered with `SCAN` and fetched with a single `MGET`. The currency code is the
/// remainder of each key after the exact `{trader_key}:currencies:` prefix, which is the same
/// layout `load_currency` reads. The following are logged and skipped rather than failing the
/// whole load:
/// - keys lacking that prefix (or with an empty code);
/// - keys whose value is missing at read time;
/// - payloads that fail to deserialize.
///
/// # Errors
///
/// Returns an error if the `SCAN` or `MGET` round-trip fails.
pub async fn load_currencies(
    con: &ConnectionManager,
    trader_key: &str,
    encoding: SerializationEncoding,
) -> anyhow::Result<AHashMap<Ustr, Currency>> {
    let mut currencies = AHashMap::new();
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{CURRENCIES}*");
    log::debug!("Loading {pattern}");
    let mut con = con.clone();
    let keys = Self::scan_keys(&mut con, pattern).await?;
    if keys.is_empty() {
        return Ok(currencies);
    }

    // Strip the exact prefix instead of taking the last `:` segment. `rsplit(':').next()`
    // can never fail, so malformed keys (e.g. `currenciesX`) used to slip through as bogus
    // codes and the error branch below was unreachable.
    let key_prefix = format!("{trader_key}{REDIS_DELIMITER}{CURRENCIES}{REDIS_DELIMITER}");

    let bulk_values = Self::read_bulk(&con, &keys).await?;
    for (key, value_opt) in keys.iter().zip(bulk_values.iter()) {
        let currency_code = match key.strip_prefix(key_prefix.as_str()) {
            Some(code) if !code.is_empty() => Ustr::from(code),
            _ => {
                log::error!("Invalid key format: {key}");
                continue;
            }
        };

        if let Some(value_bytes) = value_opt {
            match Self::deserialize_payload(encoding, value_bytes) {
                Ok(currency) => {
                    currencies.insert(currency_code, currency);
                }
                Err(e) => {
                    log::error!("Failed to deserialize currency {currency_code}: {e}");
                }
            }
        } else {
            log::error!("Currency not found in Redis: {currency_code}");
        }
    }
    log::debug!("Loaded {} currencies(s)", currencies.len());
    Ok(currencies)
}
/// Loads every instrument stored under `{trader_key}:instruments:*`.
///
/// Keys are discovered with `SCAN`. Each instrument is then fetched with its own
/// `load_instrument` call, and all calls run concurrently via `join_all`. The instrument ID is
/// the remainder of each key after the exact `{trader_key}:instruments:` prefix, so an ID that
/// itself contains `:` is preserved intact. Keys that cannot be parsed, instruments that are
/// missing, and load failures are logged and skipped.
///
/// # Errors
///
/// Returns an error only if the `SCAN` round-trip fails.
pub async fn load_instruments(
    con: &ConnectionManager,
    trader_key: &str,
    encoding: SerializationEncoding,
) -> anyhow::Result<AHashMap<InstrumentId, InstrumentAny>> {
    let mut instruments = AHashMap::new();
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{INSTRUMENTS}*");
    log::debug!("Loading {pattern}");
    let mut con = con.clone();
    let keys = Self::scan_keys(&mut con, pattern).await?;

    // `load_instrument` reads `{trader_key}:instruments:{instrument_id}`. Recovering the ID by
    // stripping that exact prefix round-trips correctly; taking the last `:` segment would
    // truncate IDs containing `:` and then read the wrong key.
    let prefix_buf = format!("{trader_key}{REDIS_DELIMITER}{INSTRUMENTS}{REDIS_DELIMITER}");
    let key_prefix = prefix_buf.as_str();

    let futures: Vec<_> = keys
        .iter()
        .map(|key| {
            let con = con.clone();
            async move {
                let Some(code) = key.strip_prefix(key_prefix).filter(|c| !c.is_empty()) else {
                    log::error!("Invalid key format: {key}");
                    return None;
                };
                let instrument_id = match InstrumentId::from_str(code) {
                    Ok(id) => id,
                    Err(e) => {
                        log::error!("Failed to convert to InstrumentId for {key}: {e}");
                        return None;
                    }
                };
                match Self::load_instrument(&con, trader_key, &instrument_id, encoding).await {
                    Ok(Some(instrument)) => Some((instrument_id, instrument)),
                    Ok(None) => {
                        log::error!("Instrument not found: {instrument_id}");
                        None
                    }
                    Err(e) => {
                        log::error!("Failed to load instrument {instrument_id}: {e}");
                        None
                    }
                }
            }
        })
        .collect();
    instruments.extend(join_all(futures).await.into_iter().flatten());
    log::debug!("Loaded {} instruments(s)", instruments.len());
    Ok(instruments)
}
/// Loads every synthetic instrument stored under `{trader_key}:synthetics:*`.
///
/// Keys are discovered with `SCAN`. Each synthetic is then fetched with its own
/// `load_synthetic` call, and all calls run concurrently via `join_all`. The instrument ID is
/// the remainder of each key after the exact `{trader_key}:synthetics:` prefix, so an ID that
/// itself contains `:` is preserved intact. Keys that cannot be parsed, synthetics that are
/// missing, and load failures are logged and skipped.
///
/// # Errors
///
/// Returns an error only if the `SCAN` round-trip fails.
pub async fn load_synthetics(
    con: &ConnectionManager,
    trader_key: &str,
    encoding: SerializationEncoding,
) -> anyhow::Result<AHashMap<InstrumentId, SyntheticInstrument>> {
    let mut synthetics = AHashMap::new();
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{SYNTHETICS}*");
    log::debug!("Loading {pattern}");
    let mut con = con.clone();
    let keys = Self::scan_keys(&mut con, pattern).await?;

    // `load_synthetic` reads `{trader_key}:synthetics:{instrument_id}`; strip that exact
    // prefix so the recovered ID round-trips even if it contains `:`.
    let prefix_buf = format!("{trader_key}{REDIS_DELIMITER}{SYNTHETICS}{REDIS_DELIMITER}");
    let key_prefix = prefix_buf.as_str();

    let futures: Vec<_> = keys
        .iter()
        .map(|key| {
            let con = con.clone();
            async move {
                let Some(code) = key.strip_prefix(key_prefix).filter(|c| !c.is_empty()) else {
                    log::error!("Invalid key format: {key}");
                    return None;
                };
                let instrument_id = match InstrumentId::from_str(code) {
                    Ok(id) => id,
                    Err(e) => {
                        log::error!("Failed to parse InstrumentId for {key}: {e}");
                        return None;
                    }
                };
                match Self::load_synthetic(&con, trader_key, &instrument_id, encoding).await {
                    Ok(Some(synthetic)) => Some((instrument_id, synthetic)),
                    Ok(None) => {
                        log::error!("Synthetic not found: {instrument_id}");
                        None
                    }
                    Err(e) => {
                        log::error!("Failed to load synthetic {instrument_id}: {e}");
                        None
                    }
                }
            }
        })
        .collect();
    synthetics.extend(join_all(futures).await.into_iter().flatten());
    log::debug!("Loaded {} synthetics(s)", synthetics.len());
    Ok(synthetics)
}
/// Loads every account stored under `{trader_key}:accounts:*`.
///
/// Keys are discovered with `SCAN`. Each account is then fetched with its own `load_account`
/// call, and all calls run concurrently via `join_all`. The account ID is the remainder of
/// each key after the exact `{trader_key}:accounts:` prefix. Keys without that prefix or with
/// an empty ID, accounts that are missing, and load failures are logged and skipped.
///
/// # Errors
///
/// Returns an error only if the `SCAN` round-trip fails.
pub async fn load_accounts(
    con: &ConnectionManager,
    trader_key: &str,
    encoding: SerializationEncoding,
) -> anyhow::Result<AHashMap<AccountId, AccountAny>> {
    let mut accounts = AHashMap::new();
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{ACCOUNTS}*");
    log::debug!("Loading {pattern}");
    let mut con = con.clone();
    let keys = Self::scan_keys(&mut con, pattern).await?;

    // `load_account` reads `{trader_key}:accounts:{account_id}`; strip that exact prefix so
    // the recovered ID round-trips even if it contains `:`, and malformed keys are rejected
    // (the old `rsplit(':').next()` could never fail, leaving the error branch dead).
    let prefix_buf = format!("{trader_key}{REDIS_DELIMITER}{ACCOUNTS}{REDIS_DELIMITER}");
    let key_prefix = prefix_buf.as_str();

    let futures: Vec<_> = keys
        .iter()
        .map(|key| {
            let con = con.clone();
            async move {
                let Some(code) = key.strip_prefix(key_prefix).filter(|c| !c.is_empty()) else {
                    log::error!("Invalid key format: {key}");
                    return None;
                };
                let account_id = AccountId::from(code);
                match Self::load_account(&con, trader_key, &account_id, encoding).await {
                    Ok(Some(account)) => Some((account_id, account)),
                    Ok(None) => {
                        log::error!("Account not found: {account_id}");
                        None
                    }
                    Err(e) => {
                        log::error!("Failed to load account {account_id}: {e}");
                        None
                    }
                }
            }
        })
        .collect();
    accounts.extend(join_all(futures).await.into_iter().flatten());
    log::debug!("Loaded {} accounts(s)", accounts.len());
    Ok(accounts)
}
/// Loads every order stored under `{trader_key}:orders:*`.
///
/// Keys are discovered with `SCAN`. Each order is then fetched with its own `load_order` call,
/// and all calls run concurrently via `join_all`. The client order ID is the remainder of each
/// key after the exact `{trader_key}:orders:` prefix. Keys without that prefix or with an empty
/// ID, orders that are missing, and load failures are logged and skipped.
///
/// # Errors
///
/// Returns an error only if the `SCAN` round-trip fails.
pub async fn load_orders(
    con: &ConnectionManager,
    trader_key: &str,
    encoding: SerializationEncoding,
) -> anyhow::Result<AHashMap<ClientOrderId, OrderAny>> {
    let mut orders = AHashMap::new();
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{ORDERS}*");
    log::debug!("Loading {pattern}");
    let mut con = con.clone();
    let keys = Self::scan_keys(&mut con, pattern).await?;

    // `load_order` reads `{trader_key}:orders:{client_order_id}`; strip that exact prefix so
    // the recovered ID round-trips even if it contains `:`, and malformed keys are rejected.
    let prefix_buf = format!("{trader_key}{REDIS_DELIMITER}{ORDERS}{REDIS_DELIMITER}");
    let key_prefix = prefix_buf.as_str();

    let futures: Vec<_> = keys
        .iter()
        .map(|key| {
            let con = con.clone();
            async move {
                let Some(code) = key.strip_prefix(key_prefix).filter(|c| !c.is_empty()) else {
                    log::error!("Invalid key format: {key}");
                    return None;
                };
                let client_order_id = ClientOrderId::from(code);
                match Self::load_order(&con, trader_key, &client_order_id, encoding).await {
                    Ok(Some(order)) => Some((client_order_id, order)),
                    Ok(None) => {
                        log::error!("Order not found: {client_order_id}");
                        None
                    }
                    Err(e) => {
                        log::error!("Failed to load order {client_order_id}: {e}");
                        None
                    }
                }
            }
        })
        .collect();
    orders.extend(join_all(futures).await.into_iter().flatten());
    log::debug!("Loaded {} order(s)", orders.len());
    Ok(orders)
}
/// Loads every position stored under `{trader_key}:positions:*`.
///
/// Keys are discovered with `SCAN`. Each position is then fetched with its own
/// `load_position` call, and all calls run concurrently via `join_all`. The position ID is the
/// remainder of each key after the exact `{trader_key}:positions:` prefix. Keys without that
/// prefix or with an empty ID, positions that are missing, and load failures are logged and
/// skipped.
///
/// # Errors
///
/// Returns an error only if the `SCAN` round-trip fails.
pub async fn load_positions(
    con: &ConnectionManager,
    trader_key: &str,
    encoding: SerializationEncoding,
) -> anyhow::Result<AHashMap<PositionId, Position>> {
    let mut positions = AHashMap::new();
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{POSITIONS}*");
    log::debug!("Loading {pattern}");
    let mut con = con.clone();
    let keys = Self::scan_keys(&mut con, pattern).await?;

    // `load_position` reads `{trader_key}:positions:{position_id}`; strip that exact prefix
    // so the recovered ID round-trips even if it contains `:`, and malformed keys are rejected.
    let prefix_buf = format!("{trader_key}{REDIS_DELIMITER}{POSITIONS}{REDIS_DELIMITER}");
    let key_prefix = prefix_buf.as_str();

    let futures: Vec<_> = keys
        .iter()
        .map(|key| {
            let con = con.clone();
            async move {
                let Some(code) = key.strip_prefix(key_prefix).filter(|c| !c.is_empty()) else {
                    log::error!("Invalid key format: {key}");
                    return None;
                };
                let position_id = PositionId::from(code);
                match Self::load_position(&con, trader_key, &position_id, encoding).await {
                    Ok(Some(position)) => Some((position_id, position)),
                    Ok(None) => {
                        log::error!("Position not found: {position_id}");
                        None
                    }
                    Err(e) => {
                        log::error!("Failed to load position {position_id}: {e}");
                        None
                    }
                }
            }
        })
        .collect();
    positions.extend(join_all(futures).await.into_iter().flatten());
    log::debug!("Loaded {} position(s)", positions.len());
    Ok(positions)
}
/// Loads custom data entries stored under `{trader_key}:custom*` that match `data_type`.
///
/// Every stored value is decoded with `CustomData::from_json_bytes`; entries that fail to
/// decode are logged at warn level and skipped. An entry is kept when all of these hold:
/// - its type name equals the requested type name, or equals the last `:`/`.`-separated
///   segment of the requested type name;
/// - its identifier equals the requested identifier (a missing identifier counts as `""`);
/// - both metadata are absent, or both are present and serialize to equal JSON values.
///
/// Kept entries are returned sorted by `ts_init` (stable sort, so ties keep read order).
///
/// # Errors
///
/// Returns an error if the `SCAN` or `MGET` round-trip fails.
pub async fn load_custom_data(
    con: &ConnectionManager,
    trader_key: &str,
    data_type: &DataType,
) -> anyhow::Result<Vec<CustomData>> {
    let pattern = format!("{trader_key}{REDIS_DELIMITER}{CUSTOM}*");
    log::debug!("Loading custom data {pattern}");
    let mut conn = con.clone();
    let keys = Self::scan_keys(&mut conn, pattern).await?;
    if keys.is_empty() {
        return Ok(Vec::new());
    }
    let payloads = Self::read_bulk(&conn, &keys).await?;

    let wanted_type = data_type.type_name();
    let wanted_short = wanted_type
        .rsplit([':', '.'])
        .next()
        .unwrap_or(wanted_type);
    let wanted_identifier = data_type.identifier().unwrap_or("");

    let mut selected = Vec::new();
    // `flatten` drops the `None` entries for keys that had no value.
    for raw in payloads.into_iter().flatten() {
        let custom = match CustomData::from_json_bytes(raw.as_ref()) {
            Ok(decoded) => decoded,
            Err(e) => {
                log::warn!("Failed to deserialize custom data from Redis: {e}");
                continue;
            }
        };

        let stored_type = custom.data_type.type_name();
        if stored_type != wanted_type && stored_type != wanted_short {
            continue;
        }
        if custom.data_type.identifier().unwrap_or("") != wanted_identifier {
            continue;
        }
        let same_metadata = match (data_type.metadata(), custom.data_type.metadata()) {
            (Some(a), Some(b)) => serde_json::to_value(a).ok() == serde_json::to_value(b).ok(),
            (None, None) => true,
            _ => false,
        };
        if same_metadata {
            selected.push(custom);
        }
    }

    selected.sort_by_key(|item| item.ts_init());
    log::debug!("Loaded {} custom data item(s)", selected.len());
    Ok(selected)
}
/// Loads a single currency from `{trader_key}:currencies:{code}`.
///
/// Returns `Ok(None)` when `read` yields no value for the key.
///
/// # Errors
///
/// Returns an error if the Redis read fails or the stored payload cannot be deserialized
/// as a `Currency`.
pub async fn load_currency(
    con: &ConnectionManager,
    trader_key: &str,
    code: &Ustr,
    encoding: SerializationEncoding,
) -> anyhow::Result<Option<Currency>> {
    let key = format!("{CURRENCIES}{REDIS_DELIMITER}{code}");
    let result = Self::read(con, trader_key, &key).await?;
    let Some(payload) = result.first() else {
        return Ok(None);
    };
    // Decode as `Currency` explicitly, like every other `load_*` helper. Previously inference
    // chose `Option<Currency>`, which silently turned a stored JSON `null` into `None` instead
    // of reporting a malformed payload.
    let currency: Currency = Self::deserialize_payload(encoding, payload)?;
    Ok(Some(currency))
}
/// Loads a single instrument from `{trader_key}:instruments:{instrument_id}`.
///
/// Returns `Ok(None)` when `read` yields no value for the key.
///
/// # Errors
///
/// Returns an error if the Redis read fails or the payload cannot be deserialized.
pub async fn load_instrument(
    con: &ConnectionManager,
    trader_key: &str,
    instrument_id: &InstrumentId,
    encoding: SerializationEncoding,
) -> anyhow::Result<Option<InstrumentAny>> {
    let lookup = format!("{INSTRUMENTS}{REDIS_DELIMITER}{instrument_id}");
    let entries = Self::read(con, trader_key, &lookup).await?;
    match entries.first() {
        Some(first) => Self::deserialize_payload::<InstrumentAny>(encoding, first).map(Some),
        None => Ok(None),
    }
}
/// Loads a single synthetic instrument from `{trader_key}:synthetics:{instrument_id}`.
///
/// Returns `Ok(None)` when `read` yields no value for the key.
///
/// # Errors
///
/// Returns an error if the Redis read fails or the payload cannot be deserialized.
pub async fn load_synthetic(
    con: &ConnectionManager,
    trader_key: &str,
    instrument_id: &InstrumentId,
    encoding: SerializationEncoding,
) -> anyhow::Result<Option<SyntheticInstrument>> {
    let lookup = format!("{SYNTHETICS}{REDIS_DELIMITER}{instrument_id}");
    let entries = Self::read(con, trader_key, &lookup).await?;
    match entries.first() {
        Some(first) => Self::deserialize_payload::<SyntheticInstrument>(encoding, first).map(Some),
        None => Ok(None),
    }
}
/// Loads a single account from `{trader_key}:accounts:{account_id}`.
///
/// `read` fetches this collection with `LRANGE 0 -1`, and only the first list element is
/// decoded here. Returns `Ok(None)` when the list read back is empty.
///
/// NOTE(review): if the write path appends later account states to the same list, this
/// returns the oldest entry rather than the latest — confirm against the writer.
///
/// # Errors
///
/// Returns an error if the Redis read fails or the first element cannot be deserialized.
pub async fn load_account(
    con: &ConnectionManager,
    trader_key: &str,
    account_id: &AccountId,
    encoding: SerializationEncoding,
) -> anyhow::Result<Option<AccountAny>> {
    let lookup = format!("{ACCOUNTS}{REDIS_DELIMITER}{account_id}");
    let entries = Self::read(con, trader_key, &lookup).await?;
    match entries.first() {
        Some(first) => Self::deserialize_payload::<AccountAny>(encoding, first).map(Some),
        None => Ok(None),
    }
}
/// Loads a single order from `{trader_key}:orders:{client_order_id}`.
///
/// `read` fetches this collection with `LRANGE 0 -1`, and only the first list element is
/// decoded here. Returns `Ok(None)` when the list read back is empty.
///
/// NOTE(review): if the write path appends later order states to the same list, this
/// returns the oldest entry rather than the latest — confirm against the writer.
///
/// # Errors
///
/// Returns an error if the Redis read fails or the first element cannot be deserialized.
pub async fn load_order(
    con: &ConnectionManager,
    trader_key: &str,
    client_order_id: &ClientOrderId,
    encoding: SerializationEncoding,
) -> anyhow::Result<Option<OrderAny>> {
    let lookup = format!("{ORDERS}{REDIS_DELIMITER}{client_order_id}");
    let entries = Self::read(con, trader_key, &lookup).await?;
    match entries.first() {
        Some(first) => Self::deserialize_payload::<OrderAny>(encoding, first).map(Some),
        None => Ok(None),
    }
}
/// Loads a single position from `{trader_key}:positions:{position_id}`.
///
/// `read` fetches this collection with `LRANGE 0 -1`, and only the first list element is
/// decoded here. Returns `Ok(None)` when the list read back is empty.
///
/// NOTE(review): if the write path appends later position states to the same list, this
/// returns the oldest entry rather than the latest — confirm against the writer.
///
/// # Errors
///
/// Returns an error if the Redis read fails or the first element cannot be deserialized.
pub async fn load_position(
    con: &ConnectionManager,
    trader_key: &str,
    position_id: &PositionId,
    encoding: SerializationEncoding,
) -> anyhow::Result<Option<Position>> {
    let lookup = format!("{POSITIONS}{REDIS_DELIMITER}{position_id}");
    let entries = Self::read(con, trader_key, &lookup).await?;
    match entries.first() {
        Some(first) => Self::deserialize_payload::<Position>(encoding, first).map(Some),
        None => Ok(None),
    }
}
/// Returns the collection name of `key`: the text before its first `REDIS_DELIMITER`.
///
/// # Errors
///
/// Returns an error if `key` contains no delimiter.
fn get_collection_key(key: &str) -> anyhow::Result<&str> {
    match key.split_once(REDIS_DELIMITER) {
        Some((collection, _rest)) => Ok(collection),
        None => Err(anyhow::anyhow!(
            "Invalid `key`, missing a '{REDIS_DELIMITER}' delimiter, was {key}"
        )),
    }
}
async fn read_index(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
let index_key = get_index_key(key)?;
match index_key {
INDEX_ORDER_IDS => Self::read_set(conn, key).await,
INDEX_ORDER_POSITION => Self::read_hset(conn, key).await,
INDEX_ORDER_CLIENT => Self::read_hset(conn, key).await,
INDEX_ORDERS => Self::read_set(conn, key).await,
INDEX_ORDERS_OPEN => Self::read_set(conn, key).await,
INDEX_ORDERS_CLOSED => Self::read_set(conn, key).await,
INDEX_ORDERS_EMULATED => Self::read_set(conn, key).await,
INDEX_ORDERS_INFLIGHT => Self::read_set(conn, key).await,
INDEX_POSITIONS => Self::read_set(conn, key).await,
INDEX_POSITIONS_OPEN => Self::read_set(conn, key).await,
INDEX_POSITIONS_CLOSED => Self::read_set(conn, key).await,
_ => anyhow::bail!("Index unknown '{index_key}' on read"),
}
}
/// Reads a string value with `GET`.
///
/// Returns a one-element vector holding the bytes, or an empty vector when the value read
/// back is empty (an empty stored string is treated the same as "nothing stored").
async fn read_string(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
    let raw: Vec<u8> = conn.get(key).await?;
    let values = if raw.is_empty() {
        Vec::new()
    } else {
        vec![Bytes::from(raw)]
    };
    Ok(values)
}
async fn read_set(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
let result: Vec<Bytes> = conn.smembers(key).await?;
Ok(result)
}
/// Reads a Redis hash with `HGETALL` and returns it as a single JSON object payload.
///
/// The field/value pairs are collected into a `HashMap<String, String>`, so the key order
/// of the emitted JSON object is not deterministic.
async fn read_hset(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
    let fields: HashMap<String, String> = conn.hgetall(key).await?;
    let encoded: Vec<u8> = serde_json::to_vec(&fields)?;
    Ok(vec![Bytes::from(encoded)])
}
async fn read_list(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
let result: Vec<Bytes> = conn.lrange(key, 0, -1).await?;
Ok(result)
}
}
/// Returns `true` for JSON object keys whose values the conversion helpers treat as
/// nanosecond timestamps: exactly `expire_time_ns`, or any key beginning with `ts_`.
fn is_timestamp_field(key: &str) -> bool {
    key.strip_prefix("ts_").is_some() || key == "expire_time_ns"
}
/// Recursively rewrites timestamp fields in `value` from integer nanoseconds to RFC 3339
/// strings.
///
/// An object entry is rewritten only if both hold:
/// - its key passes `is_timestamp_field`;
/// - its value is a JSON number representable as `u64`.
///
/// The string uses nanosecond precision and a `Z` suffix. Values above `i64::MAX` wrap when
/// cast to `i64` (producing a pre-1970 date); the `as u64` cast in
/// `convert_timestamp_strings` maps them back. Arrays and nested objects are walked; other
/// values are left untouched.
fn convert_timestamps(value: &mut Value) {
    if let Value::Array(items) = value {
        items.iter_mut().for_each(convert_timestamps);
    } else if let Value::Object(fields) = value {
        for (name, field) in fields.iter_mut() {
            if is_timestamp_field(name)
                && let Some(nanos) = field.as_u64()
            {
                let instant = DateTime::<Utc>::from_timestamp_nanos(nanos as i64);
                let rendered = instant.to_rfc3339_opts(chrono::SecondsFormat::Nanos, true);
                *field = Value::String(rendered);
            }
            convert_timestamps(field);
        }
    }
}
/// Recursively rewrites RFC 3339 timestamp strings in `value` back to integer nanoseconds.
///
/// This is the inverse of `convert_timestamps`. An object entry is replaced by its nanosecond
/// count since the Unix epoch only if all of these hold:
/// - its key passes `is_timestamp_field`;
/// - its value is a string that parses as RFC 3339;
/// - that instant is representable as `i64` nanoseconds.
///
/// The `i64 -> u64` cast is bit-preserving, so timestamps that `convert_timestamps` wrapped
/// below 1970 return to their original `u64` value.
///
/// Strings outside the representable nanosecond range (chrono documents roughly years
/// 1677–2262) are now left as strings for the later `serde_json::from_value` in
/// `deserialize_payload` to accept or reject. Previously they triggered a panic on data read
/// from Redis.
fn convert_timestamp_strings(value: &mut Value) {
    match value {
        Value::Object(map) => {
            for (key, v) in map {
                if is_timestamp_field(key)
                    && let Value::String(s) = v
                    && let Ok(dt) = DateTime::parse_from_rfc3339(s)
                    && let Some(nanos) = dt.with_timezone(&Utc).timestamp_nanos_opt()
                {
                    *v = Value::Number((nanos as u64).into());
                }
                convert_timestamp_strings(v);
            }
        }
        Value::Array(arr) => {
            for item in arr {
                convert_timestamp_strings(item);
            }
        }
        _ => {}
    }
}