use bytes::Bytes;
use nautilus_common::{cache::CacheConfig, live::get_runtime};
use nautilus_core::{
UUID4,
python::{to_pyruntime_err, to_pyvalue_err},
};
use nautilus_model::{
data::{CustomData, DataType},
identifiers::{AccountId, ClientOrderId, PositionId, TraderId},
python::{
account::account_any_to_pyobject, instruments::instrument_any_to_pyobject,
orders::order_any_to_pyobject,
},
};
use pyo3::{
IntoPyObjectExt,
prelude::*,
types::{PyBytes, PyDict},
};
use serde_json::Value;
use crate::redis::{
cache::{RedisCacheConfig, RedisCacheDatabase},
queries::DatabaseQueries,
};
#[pymethods]
impl RedisCacheDatabase {
    /// Creates a Redis cache database for `trader_id` / `instance_id`.
    ///
    /// `config_json` is the JSON-encoded cache configuration. When `database_config_json` is
    /// given it supplies the Redis connection settings. Otherwise they are taken from a legacy
    /// `database` entry inside `config_json`, or defaulted when that entry is absent or null.
    ///
    /// # Errors
    ///
    /// Returns an error converted with `to_pyvalue_err` when the configuration cannot be
    /// parsed. Returns an error converted with `to_pyruntime_err` when construction fails.
    #[new]
    #[pyo3(signature = (trader_id, instance_id, config_json, database_config_json=None))]
    fn py_new(
        trader_id: TraderId,
        instance_id: UUID4,
        config_json: &[u8],
        database_config_json: Option<&[u8]>,
    ) -> PyResult<Self> {
        let (cache_config, redis_config) = parse_inputs(config_json, database_config_json)?;
        get_runtime()
            .block_on(async {
                Self::new(trader_id, instance_id, cache_config, redis_config).await
            })
            .map_err(to_pyruntime_err)
    }

    /// Closes the database by delegating to `close`.
    #[pyo3(name = "close")]
    fn py_close(&mut self) {
        self.close();
    }

    /// Runs `flushdb` on the shared runtime and blocks until it completes.
    #[pyo3(name = "flushdb")]
    fn py_flushdb(&mut self) {
        let runtime = get_runtime();
        runtime.block_on(async { self.flushdb().await });
    }

    /// Returns the keys produced by `keys(pattern)`.
    ///
    /// # Errors
    ///
    /// Returns an error converted with `to_pyruntime_err` if the lookup fails.
    #[pyo3(name = "keys")]
    fn py_keys(&mut self, pattern: &str) -> PyResult<Vec<String>> {
        get_runtime()
            .block_on(async { self.keys(pattern).await })
            .map_err(to_pyruntime_err)
    }

    /// Loads all persisted state through `DatabaseQueries::load_all`.
    ///
    /// Returns a Python dict with the keys `currencies`, `instruments`, `synthetics`,
    /// `accounts`, `orders` and `positions`, in that insertion order. Each value is a nested
    /// dict. Instruments, accounts and orders are converted with their dedicated
    /// `*_to_pyobject` helpers.
    ///
    /// # Errors
    ///
    /// Returns an error converted with `to_pyruntime_err` if loading fails. Returns an error
    /// converted with `to_pyvalue_err` if a dict insertion fails.
    #[pyo3(name = "load_all")]
    fn py_load_all(&mut self) -> PyResult<Py<PyAny>> {
        let cache_map = get_runtime()
            .block_on(async {
                DatabaseQueries::load_all(&self.con, self.get_encoding(), self.get_trader_key())
                    .await
            })
            .map_err(to_pyruntime_err)?;

        Python::attach(|py| {
            let snapshot = PyDict::new(py);

            let currencies = PyDict::new(py);
            cache_map
                .currencies
                .into_iter()
                .try_for_each(|(code, currency)| -> PyResult<()> {
                    currencies
                        .set_item(code.to_string(), currency)
                        .map_err(to_pyvalue_err)
                })?;
            snapshot
                .set_item("currencies", currencies)
                .map_err(to_pyvalue_err)?;

            let instruments = PyDict::new(py);
            cache_map
                .instruments
                .into_iter()
                .try_for_each(|(instrument_id, instrument)| -> PyResult<()> {
                    let py_instrument = instrument_any_to_pyobject(py, instrument)?;
                    instruments
                        .set_item(instrument_id, py_instrument)
                        .map_err(to_pyvalue_err)
                })?;
            snapshot
                .set_item("instruments", instruments)
                .map_err(to_pyvalue_err)?;

            let synthetics = PyDict::new(py);
            cache_map
                .synthetics
                .into_iter()
                .try_for_each(|(synthetic_id, synthetic)| -> PyResult<()> {
                    synthetics
                        .set_item(synthetic_id, synthetic)
                        .map_err(to_pyvalue_err)
                })?;
            snapshot
                .set_item("synthetics", synthetics)
                .map_err(to_pyvalue_err)?;

            let accounts = PyDict::new(py);
            cache_map
                .accounts
                .into_iter()
                .try_for_each(|(account_id, account)| -> PyResult<()> {
                    let py_account = account_any_to_pyobject(py, account)?;
                    accounts
                        .set_item(account_id, py_account)
                        .map_err(to_pyvalue_err)
                })?;
            snapshot
                .set_item("accounts", accounts)
                .map_err(to_pyvalue_err)?;

            let orders = PyDict::new(py);
            cache_map
                .orders
                .into_iter()
                .try_for_each(|(order_id, order)| -> PyResult<()> {
                    let py_order = order_any_to_pyobject(py, order)?;
                    orders.set_item(order_id, py_order).map_err(to_pyvalue_err)
                })?;
            snapshot
                .set_item("orders", orders)
                .map_err(to_pyvalue_err)?;

            let positions = PyDict::new(py);
            cache_map
                .positions
                .into_iter()
                .try_for_each(|(position_id, position)| -> PyResult<()> {
                    positions
                        .set_item(position_id, position)
                        .map_err(to_pyvalue_err)
                })?;
            snapshot
                .set_item("positions", positions)
                .map_err(to_pyvalue_err)?;

            snapshot.into_py_any(py)
        })
    }

    /// Reads the entries stored under `key` and returns each one as Python `bytes`.
    ///
    /// # Errors
    ///
    /// Returns an error converted with `to_pyruntime_err` if the read fails.
    #[pyo3(name = "read")]
    fn py_read(&mut self, py: Python, key: &str) -> PyResult<Vec<Py<PyAny>>> {
        let payloads = get_runtime()
            .block_on(async { self.read(key).await })
            .map_err(to_pyruntime_err)?;
        let objects = payloads
            .into_iter()
            .map(|payload| PyBytes::new(py, payload.as_ref()).into())
            .collect::<Vec<Py<PyAny>>>();
        Ok(objects)
    }

    /// Reads several keys in one call.
    ///
    /// Each element of the result is Python `bytes`, or `None` where `read_bulk` produced no
    /// value for that key.
    ///
    /// # Errors
    ///
    /// Returns an error converted with `to_pyruntime_err` if the read fails.
    #[pyo3(name = "read_bulk")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_read_bulk(&mut self, py: Python, keys: Vec<String>) -> PyResult<Vec<Option<Py<PyAny>>>> {
        let payloads = get_runtime()
            .block_on(async { self.read_bulk(&keys).await })
            .map_err(to_pyruntime_err)?;
        Ok(payloads
            .into_iter()
            .map(|maybe_payload| {
                maybe_payload.map(|payload| PyBytes::new(py, payload.as_ref()).into())
            })
            .collect::<Vec<Option<Py<PyAny>>>>())
    }

    /// Inserts `payload` (a list of byte strings) under `key`.
    #[pyo3(name = "insert")]
    fn py_insert(&mut self, key: String, payload: Vec<Vec<u8>>) -> PyResult<()> {
        let frames = payload.into_iter().map(Bytes::from).collect::<Vec<Bytes>>();
        self.insert(key, Some(frames)).map_err(to_pyvalue_err)
    }

    /// Updates `key` with `payload` (a list of byte strings).
    #[pyo3(name = "update")]
    fn py_update(&mut self, key: String, payload: Vec<Vec<u8>>) -> PyResult<()> {
        let frames = payload.into_iter().map(Bytes::from).collect::<Vec<Bytes>>();
        self.update(key, Some(frames)).map_err(to_pyvalue_err)
    }

    /// Deletes `key`, optionally passing `payload` through to `delete`.
    #[pyo3(name = "delete")]
    #[pyo3(signature = (key, payload=None))]
    fn py_delete(&mut self, key: String, payload: Option<Vec<Vec<u8>>>) -> PyResult<()> {
        let frames = payload
            .map(|chunks| chunks.into_iter().map(Bytes::from).collect::<Vec<Bytes>>());
        self.delete(key, frames).map_err(to_pyvalue_err)
    }

    /// Deletes the order identified by `client_order_id`.
    #[pyo3(name = "delete_order")]
    fn py_delete_order(&mut self, client_order_id: &str) -> PyResult<()> {
        self.delete_order(&ClientOrderId::new(client_order_id))
            .map_err(to_pyvalue_err)
    }

    /// Deletes the position identified by `position_id`.
    #[pyo3(name = "delete_position")]
    fn py_delete_position(&mut self, position_id: &str) -> PyResult<()> {
        self.delete_position(&PositionId::new(position_id))
            .map_err(to_pyvalue_err)
    }

    /// Deletes the account event `event_id` for the account `account_id`.
    #[pyo3(name = "delete_account_event")]
    fn py_delete_account_event(&mut self, account_id: &str, event_id: &str) -> PyResult<()> {
        self.delete_account_event(&AccountId::new(account_id), event_id)
            .map_err(to_pyvalue_err)
    }

    /// Persists a custom data item.
    #[pyo3(name = "add_custom_data")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_add_custom_data(&mut self, data: CustomData) -> PyResult<()> {
        self.add_custom_data(&data).map_err(to_pyvalue_err)
    }

    /// Loads the custom data stored for `data_type`.
    ///
    /// The load runs inside `py.detach`, so the GIL is released while it executes.
    #[pyo3(name = "load_custom_data")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_load_custom_data(
        &mut self,
        py: Python<'_>,
        data_type: DataType,
    ) -> PyResult<Vec<CustomData>> {
        py.detach(|| {
            self.load_custom_data(&data_type)
                .map_err(to_pyvalue_err)
        })
    }
}
/// Parses the JSON cache configuration and the Redis connection configuration.
///
/// Any top-level `database` entry is removed from `config_json` before it is deserialized
/// into a [`CacheConfig`]. The Redis settings are resolved in this order:
/// 1. `database_config_json`, when provided (the legacy entry is then ignored);
/// 2. the removed legacy `database` entry, via `config_from_legacy_database`;
/// 3. [`RedisCacheConfig::default`].
///
/// # Errors
///
/// Returns a Python `ValueError` when either input is not valid JSON for its target type,
/// or when the legacy entry is rejected.
fn parse_inputs(
    config_json: &[u8],
    database_config_json: Option<&[u8]>,
) -> PyResult<(CacheConfig, RedisCacheConfig)> {
    let mut root: Value = serde_json::from_slice(config_json).map_err(to_pyvalue_err)?;

    // Strip the legacy nested database section so it does not reach `CacheConfig`.
    let legacy = if let Some(map) = root.as_object_mut() {
        map.remove("database")
    } else {
        None
    };

    let cache_config: CacheConfig = serde_json::from_value(root).map_err(to_pyvalue_err)?;

    let redis_config: RedisCacheConfig = if let Some(raw) = database_config_json {
        serde_json::from_slice(raw).map_err(to_pyvalue_err)?
    } else if let Some(section) = legacy {
        config_from_legacy_database(section)?
    } else {
        RedisCacheConfig::default()
    };

    Ok((cache_config, redis_config))
}
fn config_from_legacy_database(mut value: Value) -> PyResult<RedisCacheConfig> {
if value.is_null() {
return Ok(RedisCacheConfig::default());
}
remove_legacy_selector(&mut value, "cache database")?;
serde_json::from_value(value).map_err(to_pyvalue_err)
}
/// Strips the legacy database type selector keys (`database_type` and `type`) from `value`.
///
/// Both keys are removed, not just the first one found, so that neither leaks into the
/// subsequent [`RedisCacheConfig`] deserialization. Each selector that is present must be
/// the string `"redis"`; `database_type` is checked before `type`. Values that are not JSON
/// objects are left untouched.
///
/// `label` is interpolated into the error messages (e.g. `"cache database"`).
///
/// # Errors
///
/// Returns a Python `ValueError` if a selector is not a string, or is a string other than
/// `"redis"`.
fn remove_legacy_selector(value: &mut Value, label: &str) -> PyResult<()> {
    let Some(object) = value.as_object_mut() else {
        return Ok(());
    };

    // Remove both keys up front: previously, when `database_type` was present, a `type` key
    // was left in the object unvalidated and passed through to deserialization.
    let selectors = [object.remove("database_type"), object.remove("type")];

    for selector in selectors.iter().flatten() {
        let Some(selector) = selector.as_str() else {
            return Err(to_pyvalue_err(format!(
                "invalid {label} type selector, expected string"
            )));
        };
        if selector != "redis" {
            return Err(to_pyvalue_err(format!(
                "invalid {label} type selector, expected 'redis', was '{selector}'"
            )));
        }
    }

    Ok(())
}
/// Tests for the configuration parsing helpers used by `RedisCacheDatabase::py_new`.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde_json::json;

    use super::*;

    #[rstest]
    fn test_parse_inputs_accepts_legacy_database() {
        let config_json = serde_json::to_vec(&json!({
            "database": {
                "type": "redis",
                "host": "redis.example.com",
                "port": 6380,
                "password": "secret",
                "ssl": true,
            },
            "encoding": "json",
            "buffer_interval_ms": 25,
        }))
        .unwrap();

        let (config, database) = parse_inputs(&config_json, None).unwrap();

        assert_eq!(config.buffer_interval_ms, Some(25));
        assert_eq!(database.host, Some("redis.example.com".to_string()));
        assert_eq!(database.port, Some(6380));
        assert_eq!(database.password, Some("secret".to_string()));
        assert!(database.ssl);
    }

    // Covers the `database_type` spelling of the legacy selector.
    #[rstest]
    fn test_parse_inputs_accepts_database_type_selector() {
        let config_json = serde_json::to_vec(&json!({
            "database": {
                "database_type": "redis",
                "host": "redis.example.com",
            },
        }))
        .unwrap();

        let (_, database) = parse_inputs(&config_json, None).unwrap();

        assert_eq!(database.host, Some("redis.example.com".to_string()));
    }

    #[rstest]
    fn test_parse_inputs_defaults_null_legacy_database() {
        let config_json = serde_json::to_vec(&json!({
            "database": null,
            "buffer_interval_ms": 50,
        }))
        .unwrap();

        let (config, database) = parse_inputs(&config_json, None).unwrap();

        assert_eq!(config.buffer_interval_ms, Some(50));
        assert_eq!(database, RedisCacheConfig::default());
    }

    #[rstest]
    fn test_parse_inputs_defaults_missing_legacy_database() {
        let config_json = serde_json::to_vec(&json!({
            "buffer_interval_ms": 10,
        }))
        .unwrap();

        let (config, database) = parse_inputs(&config_json, None).unwrap();

        assert_eq!(config.buffer_interval_ms, Some(10));
        assert_eq!(database, RedisCacheConfig::default());
    }

    #[rstest]
    fn test_parse_inputs_prefers_explicit_database_config() {
        let config_json = serde_json::to_vec(&json!({
            "database": {
                "type": "redis",
                "host": "legacy.example.com",
            },
        }))
        .unwrap();
        let database_config_json = serde_json::to_vec(&json!({
            "host": "explicit.example.com",
            "port": 6381,
        }))
        .unwrap();

        let (_, database) = parse_inputs(&config_json, Some(&database_config_json)).unwrap();

        assert_eq!(database.host, Some("explicit.example.com".to_string()));
        assert_eq!(database.port, Some(6381));
    }

    #[rstest]
    fn test_parse_inputs_rejects_non_redis_legacy_database() {
        // `PyErr` formatting below requires an initialized interpreter.
        Python::initialize();
        let config_json = serde_json::to_vec(&json!({
            "database": {
                "type": "postgres",
            },
        }))
        .unwrap();

        let error = parse_inputs(&config_json, None).unwrap_err();

        assert_eq!(
            error.to_string(),
            "ValueError: invalid cache database type selector, expected 'redis', was 'postgres'"
        );
    }

    #[rstest]
    fn test_parse_inputs_rejects_non_string_legacy_selector() {
        Python::initialize();
        let config_json = serde_json::to_vec(&json!({
            "database": {
                "type": 6379,
            },
        }))
        .unwrap();

        let error = parse_inputs(&config_json, None).unwrap_err();

        assert_eq!(
            error.to_string(),
            "ValueError: invalid cache database type selector, expected string"
        );
    }

    #[rstest]
    fn test_parse_inputs_rejects_invalid_json() {
        Python::initialize();
        assert!(parse_inputs(b"not json", None).is_err());
    }
}