use std::{collections::HashSet, fmt::Debug, io::Read, sync::Arc};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::{
data::event::{
EconomicCalendarId, EmaId, MarketEvent, MarketId, OhlcvId, RsiId, SmaId, StreamId, TpoId,
TradesId, VolumeProfileId,
},
error::{ChapatyResult, IoError, SystemError},
gym::trading::config::EnvConfig,
io::{IoConfig, SerdeFormat},
sorted_vec_map::SortedVecMap,
};
/// A [`SortedVecMap`] from a stream identifier `S` to the boxed slice of
/// events (`<S as StreamId>::Event`) stored for that identifier.
pub type EventMap<S> = SortedVecMap<S, Box<[<S as StreamId>::Event]>>;
/// [`EventMap`] keyed by [`OhlcvId`].
pub type OhlcvEventMap = EventMap<OhlcvId>;
/// [`EventMap`] keyed by [`TradesId`].
pub type TradeEventMap = EventMap<TradesId>;
/// [`EventMap`] keyed by [`EconomicCalendarId`].
pub type EconomicCalEventMap = EventMap<EconomicCalendarId>;
/// [`EventMap`] keyed by [`VolumeProfileId`].
pub type VolumeProfileEventMap = EventMap<VolumeProfileId>;
/// [`EventMap`] keyed by [`TpoId`].
pub type TpoEventMap = EventMap<TpoId>;
/// [`EventMap`] keyed by [`EmaId`].
pub type EmaEventMap = EventMap<EmaId>;
/// [`EventMap`] keyed by [`SmaId`].
pub type SmaEventMap = EventMap<SmaId>;
/// [`EventMap`] keyed by [`RsiId`].
pub type RsiEventMap = EventMap<RsiId>;
/// Crate-internal bundle holding one [`EventMap`] per supported stream kind.
///
/// This struct is embedded in `SimulationData`, which is serialized with
/// postcard when cached.
/// NOTE(review): the encoded layout presumably depends on field order, so
/// reordering or inserting fields may invalidate existing cache files — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Streams {
/// Events keyed by [`OhlcvId`].
ohlcv: OhlcvEventMap,
/// Events keyed by [`TradesId`].
trade: TradeEventMap,
/// Events keyed by [`EconomicCalendarId`].
economic_cal: EconomicCalEventMap,
/// Events keyed by [`VolumeProfileId`].
volume_profile: VolumeProfileEventMap,
/// Events keyed by [`TpoId`].
tpo: TpoEventMap,
/// Events keyed by [`EmaId`].
ema: EmaEventMap,
/// Events keyed by [`SmaId`].
sma: SmaEventMap,
/// Events keyed by [`RsiId`].
rsi: RsiEventMap,
}
impl Default for Streams {
fn default() -> Self {
Self {
ohlcv: OhlcvEventMap::new(),
trade: TradeEventMap::new(),
economic_cal: EconomicCalEventMap::new(),
volume_profile: VolumeProfileEventMap::new(),
tpo: TpoEventMap::new(),
ema: EmaEventMap::new(),
sma: SmaEventMap::new(),
rsi: RsiEventMap::new(),
}
}
}
impl Streams {
    /// Returns `self` with the OHLCV event map replaced by `ohlcv`.
    pub(crate) fn with_ohlcv(mut self, ohlcv: OhlcvEventMap) -> Self {
        self.ohlcv = ohlcv;
        self
    }

    /// Returns `self` with the trade event map replaced by `trade`.
    pub(crate) fn with_trade(mut self, trade: TradeEventMap) -> Self {
        self.trade = trade;
        self
    }

    /// Returns `self` with the economic-calendar event map replaced by `economic_cal`.
    pub(crate) fn with_economic_news(mut self, economic_cal: EconomicCalEventMap) -> Self {
        self.economic_cal = economic_cal;
        self
    }

    /// Returns `self` with the volume-profile event map replaced by `volume_profile`.
    pub(crate) fn with_volume_profile(mut self, volume_profile: VolumeProfileEventMap) -> Self {
        self.volume_profile = volume_profile;
        self
    }

    /// Returns `self` with the TPO event map replaced by `tpo`.
    pub(crate) fn with_tpo(mut self, tpo: TpoEventMap) -> Self {
        self.tpo = tpo;
        self
    }

    /// Returns `self` with the EMA event map replaced by `ema`.
    pub(crate) fn with_ema(mut self, ema: EmaEventMap) -> Self {
        self.ema = ema;
        self
    }

    /// Returns `self` with the SMA event map replaced by `sma`.
    pub(crate) fn with_sma(mut self, sma: SmaEventMap) -> Self {
        self.sma = sma;
        self
    }

    /// Returns `self` with the RSI event map replaced by `rsi`.
    pub(crate) fn with_rsi(mut self, rsi: RsiEventMap) -> Self {
        self.rsi = rsi;
        self
    }
}
impl Streams {
    /// Exposes every event map as a `&dyn StreamTimeInfo`, in field
    /// declaration order, so callers can aggregate over all stream kinds.
    fn as_array(&self) -> [&dyn StreamTimeInfo; 8] {
        // Destructuring without `..` makes the compiler flag this function
        // whenever a field is added to `Streams`.
        let Self {
            ohlcv,
            trade,
            economic_cal,
            volume_profile,
            tpo,
            ema,
            sma,
            rsi,
        } = self;
        [
            ohlcv,
            trade,
            economic_cal,
            volume_profile,
            tpo,
            ema,
            sma,
            rsi,
        ]
    }
}
/// Event streams plus the summary values derived from them when the data is
/// built by `SimulationDataBuilder::build`.
///
/// The value can be cached through `SimulationData::write` and restored
/// through `SimulationData::read`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SimulationData {
/// All event maps, one per stream kind.
streams: Streams,
/// Market ids collected from the OHLCV and trade stream keys, sorted and
/// de-duplicated by the builder.
market_ids: Arc<[MarketId]>,
/// Earliest availability timestamp over all streams as computed by the
/// builder (`DateTime::<Utc>::MIN_UTC` when no stream had events).
global_availability_start: DateTime<Utc>,
/// Earliest open timestamp over all streams as computed by the builder
/// (`DateTime::<Utc>::MIN_UTC` when no stream had events).
global_open_start: DateTime<Utc>,
/// Hash of the `EnvConfig` passed to the builder; used as the default
/// cache file stem when writing.
hash: String,
}
impl SimulationData {
    /// Borrows the OHLCV event map.
    pub fn ohlcv(&self) -> &OhlcvEventMap {
        &self.streams.ohlcv
    }

    /// Borrows the trade event map.
    pub fn trade(&self) -> &TradeEventMap {
        &self.streams.trade
    }

    /// Borrows the economic-calendar event map.
    pub fn economic_cal(&self) -> &EconomicCalEventMap {
        &self.streams.economic_cal
    }

    /// Borrows the volume-profile event map.
    pub fn volume_profile(&self) -> &VolumeProfileEventMap {
        &self.streams.volume_profile
    }

    /// Borrows the TPO event map.
    pub fn tpo(&self) -> &TpoEventMap {
        &self.streams.tpo
    }

    /// Borrows the EMA event map.
    pub fn ema(&self) -> &EmaEventMap {
        &self.streams.ema
    }

    /// Borrows the SMA event map.
    pub fn sma(&self) -> &SmaEventMap {
        &self.streams.sma
    }

    /// Borrows the RSI event map.
    pub fn rsi(&self) -> &RsiEventMap {
        &self.streams.rsi
    }

    /// Returns a new handle to the shared market-id slice; only the
    /// reference count is bumped, the ids themselves are not copied.
    pub fn market_ids(&self) -> Arc<[MarketId]> {
        Arc::clone(&self.market_ids)
    }

    /// Returns the stored global availability start timestamp.
    pub fn global_availability_start(&self) -> DateTime<Utc> {
        self.global_availability_start
    }

    /// Returns the stored global open start timestamp.
    pub fn global_open_start(&self) -> DateTime<Utc> {
        self.global_open_start
    }
}
impl SimulationData {
/// Loads cached simulation data from the storage location in `io_cfg`.
///
/// The file name is `<file_stem>.<format>` when `io_cfg.file_stem` is set,
/// otherwise `<hash of env_cfg>.<format>`. The file is read to the end and
/// deserialized on tokio's blocking pool.
///
/// # Errors
///
/// Returns an error when hashing `env_cfg` fails, when the reader cannot be
/// obtained, when reading or deserializing fails (`IoError::ReadFailed`), or
/// when the blocking task cannot be joined (`SystemError::Generic`).
#[tracing::instrument(skip(io_cfg, env_cfg), fields(format = ?io_cfg.format))]
pub(crate) async fn read(env_cfg: &EnvConfig, io_cfg: &IoConfig<'_>) -> ChapatyResult<Self> {
    let IoConfig {
        format,
        location,
        buffer_size,
        file_stem: custom_file_stem,
    } = io_cfg;

    let hash = env_cfg.hash()?;
    let filename = if let Some(stem) = custom_file_stem {
        format!("{stem}.{format}")
    } else {
        format!("{hash}.{format}")
    };

    tracing::debug!(
        filename = %filename,
        hash = %hash,
        "Attempting to read cached simulation data"
    );

    // Any failure to obtain the reader is reported as a cache miss and
    // handed back to the caller unchanged.
    let (mut reader, file_size) = location
        .reader_with_size(&filename, *buffer_size)
        .await
        .map_err(|e| {
            tracing::warn!(
                filename = %filename,
                error = %e,
                "Cache miss: simulation data not found"
            );
            e
        })?;

    let join_outcome = tokio::task::spawn_blocking(move || match format {
        SerdeFormat::Postcard => {
            // Capacity used when the storage backend reports no size (100 MiB).
            const FALLBACK_CAPACITY: u64 = 100 * 1024 * 1024;
            let mut bytes = Vec::with_capacity(file_size.unwrap_or(FALLBACK_CAPACITY) as usize);
            reader
                .read_to_end(&mut bytes)
                .map_err(|e| IoError::ReadFailed(e.to_string()))?;
            postcard::from_bytes(&bytes).map_err(|e| IoError::ReadFailed(e.to_string()).into())
        }
    })
    .await;
    let result = join_outcome.map_err(|e| SystemError::Generic(e.to_string()))?;

    if let Err(e) = &result {
        tracing::warn!(
            filename = %filename,
            hash = %hash,
            error = %e,
            "Cache miss: deserialization failed (possible schema mismatch)"
        );
    } else {
        tracing::info!(
            filename = %filename,
            hash = %hash,
            "Successfully loaded cached simulation data"
        );
    }
    result
}
/// Serializes `self` in `cfg.format` and writes it to `cfg.location`.
///
/// The file name is `<file_stem>.<format>` when `cfg.file_stem` is set,
/// otherwise `<self.hash>.<format>`. Serialization and the final flush run
/// on tokio's blocking pool.
///
/// # Errors
///
/// Returns an error when the writer cannot be obtained, when serialization
/// fails or the writer cannot be flushed (`IoError::WriteFailed`), or when
/// the blocking task cannot be joined (`SystemError::Generic`).
#[tracing::instrument(skip(self, cfg), fields(hash = %self.hash, format = ?cfg.format))]
pub(crate) async fn write(self: Arc<Self>, cfg: &IoConfig<'_>) -> ChapatyResult<()> {
    let IoConfig {
        format,
        location,
        buffer_size,
        file_stem: custom_file_stem,
    } = cfg;
    let filename = match custom_file_stem {
        Some(name) => format!("{name}.{format}"),
        None => format!("{}.{format}", self.hash),
    };
    tracing::debug!(
        filename = %filename,
        "Writing simulation data to storage"
    );
    let mut writer = location.writer(&filename, *buffer_size).await?;
    let result = tokio::task::spawn_blocking(move || -> ChapatyResult<()> {
        match format {
            SerdeFormat::Postcard => {
                postcard::to_io(&*self, &mut writer)
                    .map_err(|e| IoError::WriteFailed(e.to_string()))?;
            }
        }
        // A failed flush means buffered bytes never reached storage, so it
        // must surface as a write failure instead of being discarded;
        // otherwise a truncated cache file would be reported as a success.
        writer
            .flush()
            .map_err(|e| IoError::WriteFailed(e.to_string()))?;
        Ok(())
    })
    .await
    .map_err(|e| SystemError::Generic(e.to_string()))?;
    match &result {
        Ok(_) => tracing::info!(
            filename = %filename,
            "Successfully wrote simulation data"
        ),
        Err(e) => tracing::error!(
            filename = %filename,
            error = %e,
            "Failed to write simulation data"
        ),
    }
    result
}
/// Returns the largest `max_stream_len()` reported by any of the stream
/// maps, i.e. the length of the longest single event series; `0` when every
/// map reports zero.
pub(crate) fn max_capacity_hint(&self) -> usize {
    self.streams
        .as_array()
        .iter()
        .fold(0_usize, |longest, stream| {
            usize::max(longest, stream.max_stream_len())
        })
}
}
/// Timing and size summary exposed by a collection of event streams.
pub trait StreamTimeInfo {
/// Earliest availability timestamp the collection can report, or `None`
/// when it has nothing to report.
fn min_availability(&self) -> Option<DateTime<Utc>>;
/// Earliest open timestamp the collection can report, or `None` when it
/// has nothing to report.
fn min_open_time(&self) -> Option<DateTime<Utc>>;
/// Length of the longest single event series in the collection.
fn max_stream_len(&self) -> usize;
}
impl<S: StreamId> StreamTimeInfo for EventMap<S> {
    /// Minimum `point_in_time` over the *first* event of every non-empty
    /// series; `None` when no series has an event.
    fn min_availability(&self) -> Option<DateTime<Utc>> {
        self.iter()
            .filter_map(|(_, events)| events.first())
            .map(MarketEvent::point_in_time)
            .min()
    }

    /// Minimum `opened_at` over the *first* event of every non-empty series;
    /// `None` when no series has an event.
    fn min_open_time(&self) -> Option<DateTime<Utc>> {
        self.iter()
            .filter_map(|(_, events)| events.first())
            .map(MarketEvent::opened_at)
            .min()
    }

    /// Number of events in the longest series, or `0` for an empty map.
    fn max_stream_len(&self) -> usize {
        self.iter()
            .map(|(_, events)| events.len())
            .fold(0_usize, usize::max)
    }
}
/// Crate-internal builder that turns a populated [`Streams`] value into a
/// [`SimulationData`], deriving the market ids and global start timestamps
/// from the stream contents.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SimulationDataBuilder {
/// Streams that will be moved into the built `SimulationData`.
streams: Streams,
}
impl SimulationDataBuilder {
pub(crate) fn new(streams: Streams) -> Self {
Self { streams }
}
pub(crate) fn build(self, env_cfg: EnvConfig) -> ChapatyResult<SimulationData> {
let hash = env_cfg.hash()?;
let global_availability_start = self.global_availability_start();
let global_open_start = self.global_open_start();
let market_ids = self.collect_sorted_market_ids();
Ok(SimulationData {
streams: self.streams,
market_ids: market_ids.into(),
global_availability_start,
global_open_start,
hash,
})
}
}
impl SimulationDataBuilder {
    /// Earliest `min_availability()` reported by any stream map, or
    /// `DateTime::<Utc>::MIN_UTC` when no map reports one.
    fn global_availability_start(&self) -> DateTime<Utc> {
        let earliest = self
            .streams
            .as_array()
            .iter()
            .filter_map(|stream| stream.min_availability())
            .min();
        match earliest {
            Some(timestamp) => timestamp,
            None => DateTime::<Utc>::MIN_UTC,
        }
    }

    /// Earliest `min_open_time()` reported by any stream map, or
    /// `DateTime::<Utc>::MIN_UTC` when no map reports one.
    fn global_open_start(&self) -> DateTime<Utc> {
        let earliest = self
            .streams
            .as_array()
            .iter()
            .filter_map(|stream| stream.min_open_time())
            .min();
        match earliest {
            Some(timestamp) => timestamp,
            None => DateTime::<Utc>::MIN_UTC,
        }
    }

    /// Collects the market ids of all OHLCV and trade stream keys,
    /// de-duplicated and sorted ascending. Other stream kinds are not
    /// consulted.
    fn collect_sorted_market_ids(&self) -> Vec<MarketId> {
        let from_ohlcv = self.streams.ohlcv.keys().map(MarketId::from);
        let from_trades = self.streams.trade.keys().map(MarketId::from);

        // The set removes ids that appear in both stream kinds.
        let unique: HashSet<MarketId> = from_ohlcv.chain(from_trades).collect();

        let mut ordered: Vec<MarketId> = unique.into_iter().collect();
        ordered.sort();
        ordered
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
DataSource, SelfHostedApi, StorageLocation,
data::{
config::{EconomicCalendarConfig, OhlcvSpotConfig},
domain::{
CountryCode, DataBroker, EconomicCategory, EconomicEventImpact, Exchange, Period,
Price, Quantity, SpotPair, Symbol,
},
event::{EconomicCalendarId, EconomicEvent, Ohlcv, OhlcvId, TradeEvent, TradesId},
},
transport::source::EndpointUrl,
};
fn make_ohlcv(open_ts: &str, close_ts: &str) -> Ohlcv {
Ohlcv {
open_timestamp: DateTime::parse_from_rfc3339(open_ts)
.unwrap()
.with_timezone(&Utc),
close_timestamp: DateTime::parse_from_rfc3339(close_ts)
.unwrap()
.with_timezone(&Utc),
open: Price(100.0),
high: Price(105.0),
low: Price(95.0),
close: Price(102.0),
volume: Quantity(1000.0),
quote_asset_volume: None,
number_of_trades: None,
taker_buy_base_asset_volume: None,
taker_buy_quote_asset_volume: None,
}
}
fn make_trade(ts: &str) -> TradeEvent {
TradeEvent {
timestamp: DateTime::parse_from_rfc3339(ts)
.unwrap()
.with_timezone(&Utc),
price: Price(100.0),
quantity: crate::data::domain::Quantity(1.0),
trade_id: None,
quote_asset_volume: None,
is_buyer_maker: None,
is_best_match: None,
}
}
/// Builds a Binance one-minute `OhlcvId` for `symbol`.
fn make_ohlcv_id(symbol: Symbol) -> OhlcvId {
    let broker = DataBroker::Binance;
    let exchange = Exchange::Binance;
    let period = Period::Minute(1);
    OhlcvId {
        broker,
        exchange,
        symbol,
        period,
    }
}
/// Builds a Binance `TradesId` for `symbol`.
fn make_trade_id(symbol: Symbol) -> TradesId {
    let broker = DataBroker::Binance;
    let exchange = Exchange::Binance;
    TradesId {
        broker,
        exchange,
        symbol,
    }
}
/// Builds an `EconomicCalendarId` fixture: InvestingCom broker, US country
/// code, employment category, no data-source or importance filter.
fn make_economic_calendar_id() -> EconomicCalendarId {
    let country_code = Some(CountryCode::Us);
    let category = Some(EconomicCategory::Employment);
    EconomicCalendarId {
        broker: DataBroker::InvestingCom,
        data_source: None,
        country_code,
        category,
        importance: None,
    }
}
/// Builds a high-impact "Nonfarm Payrolls" `EconomicEvent` fixture at the
/// given RFC 3339 timestamp. Panics if the timestamp does not parse.
fn make_economic_event(ts: &str) -> EconomicEvent {
    let timestamp = DateTime::parse_from_rfc3339(ts)
        .unwrap()
        .with_timezone(&Utc);
    EconomicEvent {
        timestamp,
        data_source: String::from("investingcom"),
        category: String::from("Employment"),
        news_name: String::from("Nonfarm Payrolls"),
        country_code: CountryCode::Us,
        currency_code: String::from("USD"),
        economic_impact: EconomicEventImpact::High,
        news_type: Some(String::from("NFP")),
        news_type_confidence: Some(0.95),
        news_type_source: Some(String::from("classifier")),
        period: Some(String::from("mom")),
        actual: None,
        forecast: None,
        previous: None,
    }
}
#[test]
fn returns_earliest_availability_from_multiple_streams() {
    // One OHLCV event and one trade on the same symbol; the trade's
    // timestamp (09:30) is the earliest of the two availabilities.
    let symbol = Symbol::Spot(SpotPair::BtcUsdt);

    let mut ohlcv_map = OhlcvEventMap::new();
    ohlcv_map.insert(
        make_ohlcv_id(symbol),
        vec![make_ohlcv("2024-01-01T10:00:00Z", "2024-01-01T10:01:00Z")].into_boxed_slice(),
    );

    let mut trade_map = TradeEventMap::new();
    trade_map.insert(
        make_trade_id(symbol),
        vec![make_trade("2024-01-01T09:30:00Z")].into_boxed_slice(),
    );

    let builder = SimulationDataBuilder::new(
        Streams::default()
            .with_ohlcv(ohlcv_map)
            .with_trade(trade_map),
    );

    let expected = DateTime::parse_from_rfc3339("2024-01-01T09:30:00Z")
        .unwrap()
        .with_timezone(&Utc);
    assert_eq!(builder.global_availability_start(), expected);
}
#[test]
fn returns_earliest_from_single_stream() {
    // Two consecutive-in-order OHLCV events in one series; the expected
    // availability is the close timestamp of the first one (09:01).
    let earlier = make_ohlcv("2024-01-01T09:00:00Z", "2024-01-01T09:01:00Z");
    let later = make_ohlcv("2024-01-01T10:00:00Z", "2024-01-01T10:01:00Z");

    let mut ohlcv_map = OhlcvEventMap::new();
    ohlcv_map.insert(
        make_ohlcv_id(Symbol::Spot(SpotPair::BtcUsdt)),
        vec![earlier, later].into_boxed_slice(),
    );

    let builder = SimulationDataBuilder::new(Streams::default().with_ohlcv(ohlcv_map));

    let expected = DateTime::parse_from_rfc3339("2024-01-01T09:01:00Z")
        .unwrap()
        .with_timezone(&Utc);
    assert_eq!(builder.global_availability_start(), expected);
}
#[test]
fn global_availability_returns_min_utc_when_no_events() {
    // With no events in any stream the builder falls back to MIN_UTC.
    let builder = SimulationDataBuilder::new(Streams::default());
    assert_eq!(
        builder.global_availability_start(),
        DateTime::<Utc>::MIN_UTC
    );
}
#[test]
fn returns_earliest_open_from_multiple_streams() {
    // The OHLCV event opens at 09:00, before the 09:30 trade, so 09:00 is
    // the expected global open start.
    let symbol = Symbol::Spot(SpotPair::BtcUsdt);

    let mut ohlcv_map = OhlcvEventMap::new();
    ohlcv_map.insert(
        make_ohlcv_id(symbol),
        vec![make_ohlcv("2024-01-01T09:00:00Z", "2024-01-01T09:01:00Z")].into_boxed_slice(),
    );

    let mut trade_map = TradeEventMap::new();
    trade_map.insert(
        make_trade_id(symbol),
        vec![make_trade("2024-01-01T09:30:00Z")].into_boxed_slice(),
    );

    let builder = SimulationDataBuilder::new(
        Streams::default()
            .with_ohlcv(ohlcv_map)
            .with_trade(trade_map),
    );

    let expected = DateTime::parse_from_rfc3339("2024-01-01T09:00:00Z")
        .unwrap()
        .with_timezone(&Utc);
    assert_eq!(builder.global_open_start(), expected);
}
#[test]
fn global_open_returns_min_utc_when_no_events() {
    // With no events in any stream the builder falls back to MIN_UTC.
    let builder = SimulationDataBuilder::new(Streams::default());
    assert_eq!(builder.global_open_start(), DateTime::<Utc>::MIN_UTC);
}
#[test]
fn returns_sorted_unique_market_ids() {
    // Two different symbols are inserted in ETH-then-BTC order; the result
    // must contain both ids in strictly ascending order.
    let btc_id = make_ohlcv_id(Symbol::Spot(SpotPair::BtcUsdt));
    let eth_id = make_ohlcv_id(Symbol::Spot(SpotPair::EthUsdt));
    let ohlcv = make_ohlcv("2024-01-01T10:00:00Z", "2024-01-01T10:01:00Z");

    let mut ohlcv_map = OhlcvEventMap::new();
    ohlcv_map.insert(eth_id, vec![ohlcv].into_boxed_slice());
    ohlcv_map.insert(btc_id, vec![ohlcv].into_boxed_slice());

    let builder = SimulationDataBuilder::new(Streams::default().with_ohlcv(ohlcv_map));
    let market_ids = builder.collect_sorted_market_ids();

    assert_eq!(market_ids.len(), 2);
    for pair in market_ids.windows(2) {
        assert!(pair[0] < pair[1], "MarketIds should be sorted");
    }
}
#[test]
fn deduplicates_market_ids_from_ohlcv_and_trade() {
    // The same symbol appears in both the OHLCV and the trade stream and
    // must yield a single market id.
    let symbol = Symbol::Spot(SpotPair::BtcUsdt);

    let mut ohlcv_map = OhlcvEventMap::new();
    ohlcv_map.insert(
        make_ohlcv_id(symbol),
        vec![make_ohlcv("2024-01-01T10:00:00Z", "2024-01-01T10:01:00Z")].into_boxed_slice(),
    );

    let mut trade_map = TradeEventMap::new();
    trade_map.insert(
        make_trade_id(symbol),
        vec![make_trade("2024-01-01T09:30:00Z")].into_boxed_slice(),
    );

    let builder = SimulationDataBuilder::new(
        Streams::default()
            .with_ohlcv(ohlcv_map)
            .with_trade(trade_map),
    );
    let market_ids = builder.collect_sorted_market_ids();

    assert_eq!(market_ids.len(), 1);
    assert_eq!(market_ids[0].symbol, symbol);
}
#[test]
fn returns_empty_when_no_price_authoritative_sources() {
    // Without OHLCV or trade streams there is nothing to derive ids from.
    let builder = SimulationDataBuilder::new(Streams::default());
    let market_ids = builder.collect_sorted_market_ids();
    assert!(market_ids.is_empty());
}
/// Builds an `EnvConfig` fixture with one OHLCV spot source and one economic
/// calendar source, both pointing at the same self-hosted test endpoint.
fn make_test_env_config() -> EnvConfig {
    // Both sources use an identical endpoint definition.
    let self_hosted = || {
        DataSource::SelfHosted(SelfHostedApi {
            endpoint: EndpointUrl::from("http://test:50051"),
            api_key: None,
        })
    };

    let ohlcv_cfg = OhlcvSpotConfig {
        broker: DataBroker::Binance,
        symbol: Symbol::Spot(SpotPair::BtcUsdt),
        exchange: Some(Exchange::Binance),
        period: Period::Minute(1),
        batch_size: 100,
        indicators: vec![],
    };

    let calendar_cfg = EconomicCalendarConfig {
        broker: DataBroker::InvestingCom,
        data_source: None,
        country_code: Some(CountryCode::Us),
        category: Some(EconomicCategory::Employment),
        importance: None,
        batch_size: 1000,
    };

    EnvConfig::default()
        .add_ohlcv_spot(self_hosted(), ohlcv_cfg)
        .add_economic_calendar(self_hosted(), calendar_cfg)
}
fn make_test_simulation_data(env_cfg: EnvConfig) -> SimulationData {
let symbol = Symbol::Spot(SpotPair::BtcUsdt);
let ohlcv_id = make_ohlcv_id(symbol);
let ohlcv1 = make_ohlcv("2024-01-01T09:00:00Z", "2024-01-01T09:01:00Z");
let ohlcv2 = make_ohlcv("2024-01-01T09:01:00Z", "2024-01-01T09:02:00Z");
let ohlcv3 = make_ohlcv("2024-01-01T09:02:00Z", "2024-01-01T09:03:00Z");
let mut ohlcv_map = OhlcvEventMap::new();
ohlcv_map.insert(ohlcv_id, Box::new([ohlcv1, ohlcv2, ohlcv3]));
let eco_cal_id = make_economic_calendar_id();
let eco_event1 = make_economic_event("2024-01-01T08:30:00Z");
let eco_event2 = make_economic_event("2024-01-01T10:00:00Z");
let mut eco_cal_map = EconomicCalEventMap::new();
eco_cal_map.insert(eco_cal_id, Box::new([eco_event1, eco_event2]));
let streams = Streams::default()
.with_ohlcv(ohlcv_map)
.with_economic_news(eco_cal_map);
SimulationDataBuilder { streams }
.build(env_cfg)
.expect("Failed to build SimulationData")
}
#[tokio::test]
async fn file_based_roundtrip_succeeds() {
let env_cfg = make_test_env_config();
let sim_data = Arc::new(make_test_simulation_data(env_cfg.clone()));
let temp_dir = std::env::temp_dir().join("chapaty_test_cache");
let storage = StorageLocation::Local { path: &temp_dir };
let io_cfg = IoConfig::new(storage);
sim_data
.clone()
.write(&io_cfg)
.await
.expect("write() failed");
let hash = env_cfg.hash().expect("Failed to hash env config");
let cache_path = temp_dir.join(format!("{hash}.postcard"));
assert!(cache_path.exists(), "Cache file was not created");
let loaded = SimulationData::read(&env_cfg, &io_cfg)
.await
.expect("read() failed");
assert_eq!(sim_data.market_ids().len(), loaded.market_ids().len());
assert_eq!(sim_data.ohlcv().len(), loaded.ohlcv().len());
assert_eq!(sim_data.economic_cal().len(), loaded.economic_cal().len());
assert_eq!(sim_data.global_open_start(), loaded.global_open_start());
assert_eq!(
sim_data.global_availability_start(),
loaded.global_availability_start()
);
std::fs::remove_file(&cache_path).expect("Failed to remove cache file");
std::fs::remove_dir(&temp_dir).expect("Failed to remove temp dir");
}
#[tokio::test]
async fn with_file_stem_roundtrip_uses_custom_name() {
const CUSTOM_NAME: &str = "my_custom_cache";
let env_cfg = make_test_env_config();
let sim_data = Arc::new(make_test_simulation_data(env_cfg.clone()));
let temp_dir = std::env::temp_dir().join("chapaty_test_cache_custom");
let storage = StorageLocation::Local { path: &temp_dir };
let io_cfg = IoConfig::new(storage).with_file_stem(CUSTOM_NAME);
sim_data
.clone()
.write(&io_cfg)
.await
.expect("write() failed");
let expected_path = temp_dir.join(format!("{CUSTOM_NAME}.postcard"));
assert!(
expected_path.exists(),
"Expected cache file '{CUSTOM_NAME}.postcard' was not created"
);
let hash = env_cfg.hash().expect("Failed to hash env config");
let hash_path = temp_dir.join(format!("{hash}.postcard"));
assert!(
!hash_path.exists(),
"Hash-named file should not exist when with_file_stem is set"
);
let loaded = SimulationData::read(&env_cfg, &io_cfg)
.await
.expect("read() failed");
assert_eq!(sim_data.market_ids().len(), loaded.market_ids().len());
assert_eq!(sim_data.ohlcv().len(), loaded.ohlcv().len());
assert_eq!(sim_data.economic_cal().len(), loaded.economic_cal().len());
assert_eq!(sim_data.global_open_start(), loaded.global_open_start());
assert_eq!(
sim_data.global_availability_start(),
loaded.global_availability_start()
);
std::fs::remove_file(&expected_path).expect("Failed to remove cache file");
std::fs::remove_dir(&temp_dir).expect("Failed to remove temp dir");
}
}