use std::fmt::Debug;
use bytes::Bytes;
use nautilus_core::UUID4;
use nautilus_model::identifiers::TraderId;
use serde::{Deserialize, Serialize};
use ustr::Ustr;
use crate::enums::SerializationEncoding;
/// Connection settings for a backing database.
///
/// When deserialized, any missing field is filled from the `Default` implementation and
/// unknown fields are rejected (`#[serde(default, deny_unknown_fields)]`).
///
/// `Debug` is implemented by hand for this type (not derived) so that the password can be
/// masked in the formatted output.
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.common", from_py_object)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.common")
)]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct DatabaseConfig {
/// The database type identifier. Also accepted under the key `type` when deserializing.
#[serde(alias = "type")]
pub database_type: String,
/// The database host, if specified.
// NOTE(review): what `None` resolves to is decided by the consumer of this config — confirm.
pub host: Option<String>,
/// The database port, if specified.
// NOTE(review): what `None` resolves to is decided by the consumer of this config — confirm.
pub port: Option<u16>,
/// The username for the database connection, if any.
pub username: Option<String>,
/// The password for the database connection, if any.
pub password: Option<String>,
/// Whether SSL is requested for the connection.
pub ssl: bool,
/// The connection timeout.
// NOTE(review): unit is not visible here — presumably seconds; confirm against the adapter.
pub connection_timeout: u16,
/// The response timeout.
// NOTE(review): unit is not visible here — presumably seconds; confirm against the adapter.
pub response_timeout: u16,
/// The number of retries.
// NOTE(review): presumably applies to connection attempts with backoff — confirm.
pub number_of_retries: usize,
/// The exponent base.
// NOTE(review): presumably the base of an exponential backoff calculation — confirm.
pub exponent_base: u64,
/// The maximum delay.
// NOTE(review): presumably the cap on the delay between retries; unit not visible — confirm.
pub max_delay: u64,
/// The factor.
// NOTE(review): presumably a multiplier in the retry delay calculation — confirm.
pub factor: u64,
}
// Hand-written rather than derived so that the password value never appears in debug output;
// every other field is printed as-is.
impl Debug for DatabaseConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible while hiding the secret itself.
        let masked_password: Option<&str> = self.password.is_some().then_some("***");

        let mut out = f.debug_struct(stringify!(DatabaseConfig));
        out.field("database_type", &self.database_type);
        out.field("host", &self.host);
        out.field("port", &self.port);
        out.field("username", &self.username);
        out.field("password", &masked_password);
        out.field("ssl", &self.ssl);
        out.field("connection_timeout", &self.connection_timeout);
        out.field("response_timeout", &self.response_timeout);
        out.field("number_of_retries", &self.number_of_retries);
        out.field("exponent_base", &self.exponent_base);
        out.field("max_delay", &self.max_delay);
        out.field("factor", &self.factor);
        out.finish()
    }
}
impl Default for DatabaseConfig {
    /// Returns a configuration with `database_type` set to `"redis"`, no host, port or
    /// credentials, SSL disabled, `20` for both timeouts, `100` retries, and `2` / `1000` / `2`
    /// for `exponent_base` / `max_delay` / `factor`.
    fn default() -> Self {
        const TIMEOUT: u16 = 20;
        const RETRIES: usize = 100;
        const EXPONENT_BASE: u64 = 2;
        const MAX_DELAY: u64 = 1000;
        const FACTOR: u64 = 2;

        Self {
            database_type: String::from("redis"),
            // No endpoint or credentials are assumed by default.
            host: None,
            port: None,
            username: None,
            password: None,
            ssl: false,
            connection_timeout: TIMEOUT,
            response_timeout: TIMEOUT,
            number_of_retries: RETRIES,
            exponent_base: EXPONENT_BASE,
            max_delay: MAX_DELAY,
            factor: FACTOR,
        }
    }
}
/// Configuration for a message bus.
///
/// Instances can be created through the generated builder (`bon::Builder`), where every field
/// is optional and falls back to the `#[builder(default ...)]` value shown on the field (or
/// `None` for `Option` fields). When deserialized, missing fields are filled from the `Default`
/// implementation and unknown fields are rejected (`#[serde(default, deny_unknown_fields)]`).
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.common", from_py_object)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.common")
)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, bon::Builder)]
#[serde(default, deny_unknown_fields)]
pub struct MessageBusConfig {
/// The configuration for the backing database, if any.
pub database: Option<DatabaseConfig>,
/// The serialization encoding. Builder default: `SerializationEncoding::MsgPack`.
#[builder(default = SerializationEncoding::MsgPack)]
pub encoding: SerializationEncoding,
/// Whether timestamps are represented as ISO 8601 (per the field name). Builder default: `false`.
// NOTE(review): the alternative representation used when `false` is not visible here — confirm.
#[builder(default)]
pub timestamps_as_iso8601: bool,
/// The buffer interval in milliseconds (per the field name), if any.
// NOTE(review): presumably the interval between batched writes — confirm against the adapter.
pub buffer_interval_ms: Option<u32>,
/// The auto-trim setting in minutes (per the field name), if any.
// NOTE(review): presumably a lookback window for stream trimming — confirm against the adapter.
pub autotrim_mins: Option<u32>,
/// Whether a trader prefix is used. Builder default: `true`.
// NOTE(review): presumably affects stream key naming — confirm against the adapter.
#[builder(default = true)]
pub use_trader_prefix: bool,
/// Whether the trader ID is used. Builder default: `true`.
// NOTE(review): presumably affects stream key naming — confirm against the adapter.
#[builder(default = true)]
pub use_trader_id: bool,
/// Whether the instance ID is used. Builder default: `false`.
// NOTE(review): presumably affects stream key naming — confirm against the adapter.
#[builder(default)]
pub use_instance_id: bool,
/// The prefix for streams. Builder default: `"stream"`.
#[builder(default = "stream".to_string())]
pub streams_prefix: String,
/// Whether a separate stream is used per topic (per the field name). Builder default: `true`.
#[builder(default = true)]
pub stream_per_topic: bool,
/// The external streams, if any.
// NOTE(review): presumably stream keys the bus listens to — confirm against the consumer.
pub external_streams: Option<Vec<String>>,
/// The types filter, if any.
// NOTE(review): whether listed types are included or excluded is not visible here — confirm.
pub types_filter: Option<Vec<String>>,
/// The heartbeat interval in seconds (per the field name), if any.
pub heartbeat_interval_secs: Option<u16>,
}
impl Default for MessageBusConfig {
fn default() -> Self {
Self::builder().build()
}
}
/// Interface for a database adapter used by the message bus.
///
/// Implementors name the concrete type they construct via `DatabaseType`.
pub trait MessageBusDatabaseAdapter {
/// The type produced by `new` on success.
type DatabaseType;
/// Creates the adapter for the given trader ID, instance ID and message bus configuration.
///
/// # Errors
///
/// Returns an error if the adapter cannot be created; the specific failure conditions are
/// defined by each implementation.
fn new(
trader_id: TraderId,
instance_id: UUID4,
config: MessageBusConfig,
) -> anyhow::Result<Self::DatabaseType>;
/// Returns whether the adapter is closed.
fn is_closed(&self) -> bool;
/// Publishes `payload` under `topic`.
///
/// The method returns nothing, so a failure cannot be reported to the caller through the
/// return value.
fn publish(&self, topic: Ustr, payload: Bytes);
/// Closes the adapter.
fn close(&mut self);
}
#[cfg(test)]
mod tests {
    use rstest::*;
    use serde_json::json;

    use super::*;

    /// The `Default` implementation yields the documented baseline values.
    #[rstest]
    fn test_default_database_config() {
        let config = DatabaseConfig::default();
        assert_eq!(config.database_type, "redis");
        assert_eq!(config.host, None);
        assert_eq!(config.port, None);
        assert_eq!(config.username, None);
        assert_eq!(config.password, None);
        assert!(!config.ssl);
        assert_eq!(config.connection_timeout, 20);
        assert_eq!(config.response_timeout, 20);
        assert_eq!(config.number_of_retries, 100);
        assert_eq!(config.exponent_base, 2);
        assert_eq!(config.max_delay, 1000);
        assert_eq!(config.factor, 2);
    }

    /// Every field round-trips from JSON, with `type` accepted as an alias of `database_type`.
    #[rstest]
    fn test_deserialize_database_config() {
        let config_json = json!({
            "type": "redis",
            "host": "localhost",
            "port": 6379,
            "username": "user",
            "password": "pass",
            "ssl": true,
            "connection_timeout": 30,
            "response_timeout": 10,
            "number_of_retries": 3,
            "exponent_base": 2,
            "max_delay": 10,
            "factor": 2
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        assert_eq!(config.database_type, "redis");
        assert_eq!(config.host, Some("localhost".to_string()));
        assert_eq!(config.port, Some(6379));
        assert_eq!(config.username, Some("user".to_string()));
        assert_eq!(config.password, Some("pass".to_string()));
        assert!(config.ssl);
        assert_eq!(config.connection_timeout, 30);
        assert_eq!(config.response_timeout, 10);
        assert_eq!(config.number_of_retries, 3);
        assert_eq!(config.exponent_base, 2);
        assert_eq!(config.max_delay, 10);
        assert_eq!(config.factor, 2);
    }

    /// Fields omitted from the input fall back to the `Default` implementation.
    #[rstest]
    fn test_deserialize_database_config_partial_uses_defaults() {
        let config_json = json!({
            "host": "localhost",
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let expected = DatabaseConfig {
            host: Some("localhost".to_string()),
            ..Default::default()
        };
        assert_eq!(config, expected);
    }

    /// `deny_unknown_fields` turns unexpected keys into a deserialization error.
    #[rstest]
    fn test_deserialize_database_config_rejects_unknown_field() {
        let config_json = json!({
            "type": "redis",
            "unexpected": true,
        });
        let error = serde_json::from_value::<DatabaseConfig>(config_json).unwrap_err();
        assert!(error.to_string().contains("unknown field `unexpected`"));
    }

    /// The hand-written `Debug` implementation masks the password but keeps its presence visible.
    #[rstest]
    fn test_database_config_debug_redacts_password() {
        let config = DatabaseConfig {
            password: Some("hunter2".to_string()),
            ..Default::default()
        };
        let output = format!("{config:?}");
        assert!(output.contains("password: Some(\"***\")"));
        assert!(!output.contains("hunter2"));

        let output = format!("{:?}", DatabaseConfig::default());
        assert!(output.contains("password: None"));
    }

    /// The builder-backed `Default` implementation yields the documented baseline values.
    #[rstest]
    fn test_default_message_bus_config() {
        let config = MessageBusConfig::default();
        assert_eq!(config.database, None);
        assert_eq!(config.encoding, SerializationEncoding::MsgPack);
        assert!(!config.timestamps_as_iso8601);
        assert_eq!(config.buffer_interval_ms, None);
        assert_eq!(config.autotrim_mins, None);
        assert!(config.use_trader_prefix);
        assert!(config.use_trader_id);
        assert!(!config.use_instance_id);
        assert_eq!(config.streams_prefix, "stream");
        assert!(config.stream_per_topic);
        assert_eq!(config.external_streams, None);
        assert_eq!(config.types_filter, None);
        assert_eq!(config.heartbeat_interval_secs, None);
    }

    /// Every field round-trips from JSON, including the nested database configuration.
    #[rstest]
    fn test_deserialize_message_bus_config() {
        let config_json = json!({
            "database": {
                "type": "redis",
                "host": "localhost",
                "port": 6379,
                "username": "user",
                "password": "pass",
                "ssl": true,
                "connection_timeout": 30,
                "response_timeout": 10,
                "number_of_retries": 3,
                "exponent_base": 2,
                "max_delay": 10,
                "factor": 2
            },
            "encoding": "json",
            "timestamps_as_iso8601": true,
            "buffer_interval_ms": 100,
            "autotrim_mins": 60,
            "use_trader_prefix": false,
            "use_trader_id": false,
            "use_instance_id": true,
            "streams_prefix": "data_streams",
            "stream_per_topic": false,
            "external_streams": ["stream1", "stream2"],
            "types_filter": ["type1", "type2"],
            "heartbeat_interval_secs": 5
        });
        let config: MessageBusConfig = serde_json::from_value(config_json).unwrap();
        assert_eq!(
            config.database,
            Some(DatabaseConfig {
                database_type: "redis".to_string(),
                host: Some("localhost".to_string()),
                port: Some(6379),
                username: Some("user".to_string()),
                password: Some("pass".to_string()),
                ssl: true,
                connection_timeout: 30,
                response_timeout: 10,
                number_of_retries: 3,
                exponent_base: 2,
                max_delay: 10,
                factor: 2,
            })
        );
        assert_eq!(config.encoding, SerializationEncoding::Json);
        assert!(config.timestamps_as_iso8601);
        assert_eq!(config.buffer_interval_ms, Some(100));
        assert_eq!(config.autotrim_mins, Some(60));
        assert!(!config.use_trader_prefix);
        assert!(!config.use_trader_id);
        assert!(config.use_instance_id);
        assert_eq!(config.streams_prefix, "data_streams");
        assert!(!config.stream_per_topic);
        assert_eq!(
            config.external_streams,
            Some(vec!["stream1".to_string(), "stream2".to_string()])
        );
        assert_eq!(
            config.types_filter,
            Some(vec!["type1".to_string(), "type2".to_string()])
        );
        assert_eq!(config.heartbeat_interval_secs, Some(5));
    }

    /// `deny_unknown_fields` also applies to the message bus configuration.
    #[rstest]
    fn test_deserialize_message_bus_config_rejects_unknown_field() {
        let config_json = json!({
            "unexpected": true,
        });
        let error = serde_json::from_value::<MessageBusConfig>(config_json).unwrap_err();
        assert!(error.to_string().contains("unknown field `unexpected`"));
    }
}