use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, EnumIter, EnumString, IntoStaticStr};
use crate::{
ApiKey, EndpointUrl, SelfHostedApi,
data::{
common::{ProfileAggregation, RiskMetricsConfig},
config::{
EconomicCalendarConfig, OhlcvFutureConfig, OhlcvSpotConfig, TpoFutureConfig,
TpoSpotConfig, TradeSpotConfig, VolumeProfileSpotConfig,
},
domain::{
ContractMonth, ContractYear, CountryCode, DataBroker, EconomicCategory,
EconomicEventImpact, Exchange, FutureContract, FutureRoot, Period, SpotPair, Symbol,
},
episode::EpisodeLength,
filter::{EconomicCalendarPolicy, FilterConfig},
indicator::{SmaWindow, TechnicalIndicator},
},
error::{ChapatyResult, EnvError},
gym::InvalidActionPenalty,
transport::source::{DataSource, SourceGroup},
};
/// Bias setting for simulated order execution.
///
/// `Pessimistic` is the [`Default`] variant.
///
/// NOTE(review): presumably this selects best-case vs worst-case assumptions when an execution
/// outcome is ambiguous — confirm against the execution simulator.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum ExecutionBias {
    Optimistic,
    /// The default bias.
    #[default]
    Pessimistic,
}
/// Named, ready-made environment setups.
///
/// Expand a preset into a full configuration with `EnvConfig::from(preset)` (or
/// `preset.into()`). Every preset reads all of its feeds from the source built by
/// `self_hosted_source()`.
///
/// Through `strum`, each variant's string form (`Display`, `FromStr`, `IntoStaticStr`) is its
/// snake_case name. The declaration order below defines the derived `PartialOrd`/`Ord` ordering
/// and the `EnumIter` iteration order, so append new presets instead of reordering existing ones.
///
/// NOTE(review): the `6eh6` in the NinjaTrader preset names reads like CME ticker 6EH6
/// (month code H = March), but those presets build a `ContractMonth::June`, `ContractYear::Y6`
/// contract — confirm which of the two is intended.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    EnumString,
    Display,
    PartialOrd,
    Ord,
    EnumIter,
    IntoStaticStr,
    EnumCount,
)]
#[strum(serialize_all = "snake_case")]
pub enum EnvPreset {
    /// Binance BTC/USDT spot, 1-day OHLCV; years 2017–2026; infinite episodes.
    BinanceBtcUsdt1d,
    /// Binance BTC/USDT spot, 1-minute OHLCV; years 2017–2026; infinite episodes.
    BinanceBtcUsdt1m,
    /// Binance BTC/USDT spot, 1-minute and 15-minute OHLCV; years 2017–2026; infinite episodes.
    BinanceBtcUsdt1m15m,
    /// NinjaTrader CME EUR/USD future (June, `Y6`), 1-minute and 5-minute OHLCV, plus the
    /// Investing.com calendar filtered to high-impact US employment events; years 2008–2026;
    /// daily episodes; trade hint 4.
    NinjaTraderCme6eh61m5mUsEmpHigh,
    /// As `NinjaTraderCme6eh61m5mUsEmpHigh` but 1-minute OHLCV only, with the filter's
    /// economic-news policy set to `OnlyWithEvents`; trade hint 2.
    NinjaTraderCme6eh61mUsEmpHighEventsOnly,
    /// As `NinjaTraderCme6eh61m5mUsEmpHigh` (1-minute and 5-minute OHLCV), with the filter's
    /// economic-news policy set to `OnlyWithEvents`; trade hint 4.
    NinjaTraderCme6eh61m5mUsEmpHighEventsOnly,
    /// Binance BTC/USDT spot, 1-day OHLCV with SMA(20) and SMA(50) indicators; years 2017–2026;
    /// infinite episodes.
    BinanceBtcUsdt1dSma20Sma50,
    /// Binance BTC/USDT spot, 1-hour and 1-minute OHLCV, plus a daily volume profile with
    /// `ticks_per_bin = 10_000`; years 2017–2026; daily episodes.
    BinanceBtcUsdt1h1mVolumeProfile1d100Usdt,
    /// Binance BTC/USDT spot, 1-hour and 1-minute OHLCV, plus a daily TPO profile with
    /// `ticks_per_bin = 100`; years 2017–2026; daily episodes.
    BinanceBtcUsdt1h1mTpo1d1Usdt,
    /// NinjaTrader CME EUR/USD future (June, `Y6`), 1-minute OHLCV, plus a daily TPO profile
    /// with default binning; years 2006–2026; daily episodes.
    NinjaTraderCme6eh61mTpo1d,
}
fn self_hosted_source() -> DataSource {
let api_key = std::env::var("CHAPATY_API_KEY").ok().map(ApiKey);
DataSource::SelfHosted(SelfHostedApi {
endpoint: EndpointUrl("http://[::1]:50051".to_string()),
api_key,
})
}
impl From<EnvPreset> for EnvConfig {
fn from(preset: EnvPreset) -> Self {
let source = self_hosted_source();
match preset {
EnvPreset::BinanceBtcUsdt1d => {
let market_config = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
period: Period::Day(1),
batch_size: 1000,
exchange: Some(Exchange::Binance),
indicators: Vec::new(),
};
let allowed_years = (2017..=2026).collect::<BTreeSet<_>>();
let filter = FilterConfig {
allowed_years: Some(allowed_years),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_spot(source.clone(), market_config)
.with_episode_length(EpisodeLength::Infinite)
.with_filter_config(filter)
}
EnvPreset::BinanceBtcUsdt1m => {
let market_config = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
period: Period::Minute(1),
batch_size: 1000,
exchange: Some(Exchange::Binance),
indicators: Vec::new(),
};
let filter = FilterConfig {
allowed_years: Some((2017..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_spot(source.clone(), market_config)
.with_episode_length(EpisodeLength::Infinite)
.with_filter_config(filter)
}
EnvPreset::BinanceBtcUsdt1m15m => {
let ohlcv_1m = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Minute(1),
batch_size: 1000,
indicators: Vec::new(),
};
let ohlcv_15m = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Minute(15),
batch_size: 1000,
indicators: Vec::new(),
};
let filter = FilterConfig {
allowed_years: Some((2017..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_spot(source.clone(), ohlcv_1m)
.add_ohlcv_spot(source.clone(), ohlcv_15m)
.with_episode_length(EpisodeLength::Infinite)
.with_filter_config(filter)
}
EnvPreset::NinjaTraderCme6eh61m5mUsEmpHigh => {
let ohlcv_1m = OhlcvFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
period: Period::Minute(1),
batch_size: 1000,
indicators: vec![],
};
let ohlcv_5m = OhlcvFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
period: Period::Minute(5),
batch_size: 1000,
indicators: vec![],
};
let calendar = EconomicCalendarConfig {
broker: DataBroker::InvestingCom,
data_source: None,
country_code: Some(CountryCode::Us),
category: Some(EconomicCategory::Employment),
importance: Some(EconomicEventImpact::High),
batch_size: 1000,
};
let filter = FilterConfig {
allowed_years: Some((2008..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_future(source.clone(), ohlcv_1m)
.add_ohlcv_future(source.clone(), ohlcv_5m)
.with_episode_length(EpisodeLength::Day)
.with_filter_config(filter)
.add_economic_calendar(source.clone(), calendar)
.with_trade_hint(4)
}
EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly => {
let ohlcv = OhlcvFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
period: Period::Minute(1),
batch_size: 1000,
indicators: vec![],
};
let calendar = EconomicCalendarConfig {
broker: DataBroker::InvestingCom,
data_source: None,
country_code: Some(CountryCode::Us),
category: Some(EconomicCategory::Employment),
importance: Some(EconomicEventImpact::High),
batch_size: 1000,
};
let filter = FilterConfig {
allowed_years: Some((2008..=2026).collect::<BTreeSet<_>>()),
economic_news_policy: Some(EconomicCalendarPolicy::OnlyWithEvents),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_future(source.clone(), ohlcv)
.with_episode_length(EpisodeLength::Day)
.with_filter_config(filter)
.add_economic_calendar(source.clone(), calendar)
.with_trade_hint(2)
}
EnvPreset::NinjaTraderCme6eh61m5mUsEmpHighEventsOnly => {
let ohlcv_1m = OhlcvFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
period: Period::Minute(1),
batch_size: 1000,
indicators: vec![],
};
let ohlcv_5m = OhlcvFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
period: Period::Minute(5),
batch_size: 1000,
indicators: vec![],
};
let calendar = EconomicCalendarConfig {
broker: DataBroker::InvestingCom,
data_source: None,
country_code: Some(CountryCode::Us),
category: Some(EconomicCategory::Employment),
importance: Some(EconomicEventImpact::High),
batch_size: 1000,
};
let filter = FilterConfig {
allowed_years: Some((2008..=2026).collect::<BTreeSet<_>>()),
economic_news_policy: Some(EconomicCalendarPolicy::OnlyWithEvents),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_future(source.clone(), ohlcv_1m)
.add_ohlcv_future(source.clone(), ohlcv_5m)
.with_episode_length(EpisodeLength::Day)
.with_filter_config(filter)
.add_economic_calendar(source.clone(), calendar)
.with_trade_hint(4)
}
EnvPreset::BinanceBtcUsdt1dSma20Sma50 => {
let market_config = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Day(1),
batch_size: 1000,
indicators: vec![
TechnicalIndicator::Sma(SmaWindow(20)),
TechnicalIndicator::Sma(SmaWindow(50)),
],
};
let filter = FilterConfig {
allowed_years: Some((2017..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_spot(source.clone(), market_config)
.with_episode_length(EpisodeLength::Infinite)
.with_filter_config(filter)
}
EnvPreset::BinanceBtcUsdt1h1mVolumeProfile1d100Usdt => {
let ohlcv_1h = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Hour(1),
batch_size: 1000,
indicators: vec![],
};
let ohlcv_1m = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Minute(1),
batch_size: 1000,
indicators: vec![],
};
let vp = VolumeProfileSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
aggregation: Some(ProfileAggregation {
time_frame: Some(Period::Day(1)),
ticks_per_bin: Some(10_000),
..ProfileAggregation::default()
}),
batch_size: 1000,
};
let filter = FilterConfig {
allowed_years: Some((2017..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_spot(source.clone(), ohlcv_1h)
.add_ohlcv_spot(source.clone(), ohlcv_1m)
.add_volume_profile_spot(source.clone(), vp)
.with_episode_length(EpisodeLength::Day)
.with_filter_config(filter)
}
EnvPreset::BinanceBtcUsdt1h1mTpo1d1Usdt => {
let ohlcv_1h = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Hour(1),
batch_size: 1000,
indicators: vec![],
};
let ohlcv_1m = OhlcvSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
period: Period::Minute(1),
batch_size: 1000,
indicators: vec![],
};
let tpo = TpoSpotConfig {
broker: DataBroker::Binance,
symbol: Symbol::Spot(SpotPair::BtcUsdt),
exchange: Some(Exchange::Binance),
aggregation: Some(ProfileAggregation {
time_frame: Some(Period::Day(1)),
ticks_per_bin: Some(100),
..ProfileAggregation::default()
}),
batch_size: 1000,
};
let filter = FilterConfig {
allowed_years: Some((2017..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_spot(source.clone(), ohlcv_1h)
.add_ohlcv_spot(source.clone(), ohlcv_1m)
.add_tpo_spot(source.clone(), tpo)
.with_episode_length(EpisodeLength::Day)
.with_filter_config(filter)
}
EnvPreset::NinjaTraderCme6eh61mTpo1d => {
let ohlcv = OhlcvFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
period: Period::Minute(1),
batch_size: 1000,
indicators: vec![],
};
let tpo = TpoFutureConfig {
broker: DataBroker::NinjaTrader,
symbol: Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::June,
year: ContractYear::Y6,
}),
exchange: Some(Exchange::Cme),
aggregation: Some(ProfileAggregation {
time_frame: Some(Period::Day(1)),
..ProfileAggregation::default()
}),
batch_size: 1000,
};
let filter = FilterConfig {
allowed_years: Some((2006..=2026).collect::<BTreeSet<_>>()),
..FilterConfig::default()
};
EnvConfig::default()
.add_ohlcv_future(source.clone(), ohlcv)
.add_tpo_future(source.clone(), tpo)
.with_episode_length(EpisodeLength::Day)
.with_filter_config(filter)
}
}
}
}
/// Complete description of an environment.
///
/// It records which data feeds to load (each kind grouped per [`DataSource`]), how to filter
/// the data and cut it into episodes, and a few simulation knobs. Build one from
/// [`EnvConfig::default`] plus the `add_*` / `with_*` builder methods, or from an
/// [`EnvPreset`].
///
/// `hash` serializes the whole struct with postcard, which encodes fields in declaration order.
/// Reordering, adding or removing fields therefore changes every config hash.
#[derive(Clone, Serialize, Deserialize)]
pub struct EnvConfig {
    // Feed lists, one `SourceGroup` per distinct data source; appended via `add_*`.
    ohlcv_spot: Vec<SourceGroup<OhlcvSpotConfig>>,
    ohlcv_future: Vec<SourceGroup<OhlcvFutureConfig>>,
    trade_spot: Vec<SourceGroup<TradeSpotConfig>>,
    tpo_spot: Vec<SourceGroup<TpoSpotConfig>>,
    tpo_future: Vec<SourceGroup<TpoFutureConfig>>,
    volume_profile_spot: Vec<SourceGroup<VolumeProfileSpotConfig>>,
    economic_calendar: Vec<SourceGroup<EconomicCalendarConfig>>,
    /// Optional data filter; `None` by default. If it is absent, or has no `allowed_years`,
    /// `allowed_years()` falls back to 1990..=2040.
    filter_config: Option<FilterConfig>,
    /// Episode length; `EpisodeLength::default()` unless overridden.
    episode_length: EpisodeLength,
    /// Risk-metric settings; `RiskMetricsConfig::default()` unless overridden.
    risk_metrics_cfg: RiskMetricsConfig,
    /// Defaults to 2; `with_trade_hint` caps it at 32.
    /// NOTE(review): presumably a per-episode trade-count/capacity hint — confirm with consumers.
    trade_hint: usize,
    /// Penalty for invalid actions; `with_invalid_action_penalty` asserts its value is <= 0.
    invalid_action_penalty: InvalidActionPenalty,
}
impl Default for EnvConfig {
fn default() -> Self {
Self {
ohlcv_spot: Vec::new(),
ohlcv_future: Vec::new(),
trade_spot: Vec::new(),
tpo_spot: Vec::new(),
tpo_future: Vec::new(),
volume_profile_spot: Vec::new(),
economic_calendar: Vec::new(),
filter_config: None,
episode_length: EpisodeLength::default(),
risk_metrics_cfg: RiskMetricsConfig::default(),
trade_hint: 2,
invalid_action_penalty: InvalidActionPenalty::default(),
}
}
}
/// Feed builders.
///
/// Each method appends `config` to the group whose source equals `source`, or starts a new
/// group for that source. It returns the updated configuration.
impl EnvConfig {
    /// Adds a spot OHLCV feed read from `source`.
    pub fn add_ohlcv_spot(mut self, source: DataSource, config: OhlcvSpotConfig) -> Self {
        let existing = std::mem::take(&mut self.ohlcv_spot);
        self.ohlcv_spot = update_source_group(existing, source, config);
        self
    }

    /// Adds a futures OHLCV feed read from `source`.
    pub fn add_ohlcv_future(mut self, source: DataSource, config: OhlcvFutureConfig) -> Self {
        let existing = std::mem::take(&mut self.ohlcv_future);
        self.ohlcv_future = update_source_group(existing, source, config);
        self
    }

    /// Adds a spot trade feed read from `source`.
    pub fn add_trade_spot(mut self, source: DataSource, config: TradeSpotConfig) -> Self {
        let existing = std::mem::take(&mut self.trade_spot);
        self.trade_spot = update_source_group(existing, source, config);
        self
    }

    /// Adds a spot TPO profile feed read from `source`.
    pub fn add_tpo_spot(mut self, source: DataSource, config: TpoSpotConfig) -> Self {
        let existing = std::mem::take(&mut self.tpo_spot);
        self.tpo_spot = update_source_group(existing, source, config);
        self
    }

    /// Adds a futures TPO profile feed read from `source`.
    pub fn add_tpo_future(mut self, source: DataSource, config: TpoFutureConfig) -> Self {
        let existing = std::mem::take(&mut self.tpo_future);
        self.tpo_future = update_source_group(existing, source, config);
        self
    }

    /// Adds a spot volume-profile feed read from `source`.
    pub fn add_volume_profile_spot(
        mut self,
        source: DataSource,
        config: VolumeProfileSpotConfig,
    ) -> Self {
        let existing = std::mem::take(&mut self.volume_profile_spot);
        self.volume_profile_spot = update_source_group(existing, source, config);
        self
    }

    /// Adds an economic-calendar feed read from `source`.
    pub fn add_economic_calendar(
        mut self,
        source: DataSource,
        config: EconomicCalendarConfig,
    ) -> Self {
        let existing = std::mem::take(&mut self.economic_calendar);
        self.economic_calendar = update_source_group(existing, source, config);
        self
    }
}
/// Scalar-setting builders; each consumes and returns the configuration.
impl EnvConfig {
    /// Sets the data filter, replacing any previous one.
    pub fn with_filter_config(mut self, filter_config: FilterConfig) -> Self {
        self.filter_config = Some(filter_config);
        self
    }

    /// Sets the episode length.
    pub fn with_episode_length(mut self, episode_length: EpisodeLength) -> Self {
        self.episode_length = episode_length;
        self
    }

    /// Sets the risk-metric configuration.
    pub fn with_risk_metrics_cfg(mut self, risk_metrics_cfg: RiskMetricsConfig) -> Self {
        self.risk_metrics_cfg = risk_metrics_cfg;
        self
    }

    /// Sets the trade hint. Values above 32 are silently capped at 32.
    pub fn with_trade_hint(mut self, trade_hint: u32) -> Self {
        const TRADE_HINT_CAP: u32 = 32;
        let capped = if trade_hint > TRADE_HINT_CAP { TRADE_HINT_CAP } else { trade_hint };
        self.trade_hint = capped as usize;
        self
    }

    /// Sets the penalty applied to invalid actions.
    ///
    /// # Panics
    ///
    /// Panics if the penalty's inner value is greater than 0.
    pub fn with_invalid_action_penalty(mut self, penalty: InvalidActionPenalty) -> Self {
        assert!(
            penalty.0.0 <= 0,
            "Invalid action penalty must be <= 0, got {}",
            penalty.0.0
        );
        self.invalid_action_penalty = penalty;
        self
    }
}
/// Read-only accessors and derived quantities.
impl EnvConfig {
    /// Spot OHLCV feeds, grouped by data source.
    pub fn ohlcv_spot(&self) -> &[SourceGroup<OhlcvSpotConfig>] {
        &self.ohlcv_spot
    }

    /// Futures OHLCV feeds, grouped by data source.
    pub fn ohlcv_future(&self) -> &[SourceGroup<OhlcvFutureConfig>] {
        &self.ohlcv_future
    }

    /// Spot trade feeds, grouped by data source.
    pub fn trade_spot(&self) -> &[SourceGroup<TradeSpotConfig>] {
        &self.trade_spot
    }

    /// Spot TPO profile feeds, grouped by data source.
    pub fn tpo_spot(&self) -> &[SourceGroup<TpoSpotConfig>] {
        &self.tpo_spot
    }

    /// Futures TPO profile feeds, grouped by data source.
    pub fn tpo_future(&self) -> &[SourceGroup<TpoFutureConfig>] {
        &self.tpo_future
    }

    /// Spot volume-profile feeds, grouped by data source.
    pub fn volume_profile_spot(&self) -> &[SourceGroup<VolumeProfileSpotConfig>] {
        &self.volume_profile_spot
    }

    /// Economic-calendar feeds, grouped by data source.
    pub fn economic_calendar(&self) -> &[SourceGroup<EconomicCalendarConfig>] {
        &self.economic_calendar
    }

    /// The configured data filter, if any.
    pub fn filter_config(&self) -> Option<&FilterConfig> {
        self.filter_config.as_ref()
    }

    /// The configured episode length.
    pub fn episode_length(&self) -> EpisodeLength {
        self.episode_length
    }

    /// The configured risk-metric settings.
    pub fn risk_metrics_cfg(&self) -> RiskMetricsConfig {
        self.risk_metrics_cfg
    }

    /// The configured trade hint (at most 32 when set via `with_trade_hint`).
    pub fn trade_hint(&self) -> usize {
        self.trade_hint
    }

    /// The configured invalid-action penalty.
    pub fn invalid_action_penalty(&self) -> InvalidActionPenalty {
        self.invalid_action_penalty
    }

    /// Years the environment may draw data from, in ascending order.
    ///
    /// Returns the filter's explicit `allowed_years` set (possibly empty) when one is
    /// configured. Otherwise, with no filter or a filter without a year set, returns every
    /// year in `1990..=2040`.
    pub fn allowed_years(&self) -> Vec<u16> {
        // Fallback span used when no explicit year set is configured.
        const FALLBACK_FIRST_YEAR: u16 = 1990;
        const FALLBACK_LAST_YEAR: u16 = 2040;

        match self.filter_config().and_then(|c| c.allowed_years.as_ref()) {
            // A `BTreeSet` already iterates in ascending order, so no extra sort is needed.
            Some(years) => years.iter().copied().collect(),
            None => (FALLBACK_FIRST_YEAR..=FALLBACK_LAST_YEAR).collect(),
        }
    }

    /// Upper bound on the number of episodes.
    ///
    /// Computed as the episode length's per-year maximum times the number of allowed years,
    /// and never less than 1.
    pub fn max_episode_capacity(&self) -> usize {
        let max_episodes_per_year = self.episode_length().max_episodes();
        let number_of_years = self.allowed_years().len();
        // Saturate instead of overflowing. A plain `*` panics in debug builds and silently wraps
        // to a tiny capacity in release builds if the per-year count is very large.
        max_episodes_per_year.saturating_mul(number_of_years).max(1)
    }
}
/// Appends `config` to the group in `groups` whose `source` equals `source`.
///
/// If no group has that source, a new group holding just `config` is pushed at the end. Only
/// the first matching group is updated. Returns the (possibly modified) list.
fn update_source_group<T>(
    mut groups: Vec<SourceGroup<T>>,
    source: DataSource,
    config: T,
) -> Vec<SourceGroup<T>> {
    match groups.iter().position(|group| group.source == source) {
        Some(index) => groups[index].items.push(config),
        None => groups.push(SourceGroup {
            source,
            items: vec![config],
        }),
    }
    groups
}
impl EnvConfig {
    /// Content hash of this configuration.
    ///
    /// The hash is BLAKE3 over the postcard encoding of `self`, rendered through the hash's
    /// `Display` implementation.
    ///
    /// NOTE(review): the encoding includes every feed's `DataSource`. If that carries
    /// credentials (e.g. the `CHAPATY_API_KEY` value placed into `SelfHostedApi`), the hash
    /// varies with them — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns the `EnvError::Encoding` error (converted via `?`) if postcard serialization
    /// fails.
    pub fn hash(&self) -> ChapatyResult<String> {
        let encoded = postcard::to_stdvec(self).map_err(EnvError::Encoding)?;
        let digest = blake3::hash(&encoded);
        Ok(digest.to_string())
    }

    /// Whether at least one spot OHLCV, futures OHLCV or spot trade feed is configured.
    ///
    /// Profile and calendar feeds alone do not make a configuration valid.
    pub fn is_valid(&self) -> bool {
        let no_price_feed = self.ohlcv_spot.is_empty()
            && self.ohlcv_future.is_empty()
            && self.trade_spot.is_empty();
        !no_price_feed
    }
}