use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, EnumIter, EnumString, IntoStaticStr};
use crate::{
data::{
common::RiskMetricsConfig,
config::{
EconomicCalendarConfig, OhlcvFutureConfig, OhlcvSpotConfig, TpoFutureConfig,
TpoSpotConfig, TradeSpotConfig, VolumeProfileSpotConfig,
},
domain::{DataBroker, Exchange, Period, SpotPair, Symbol},
episode::EpisodeLength,
filter::FilterConfig,
},
error::{ChapatyResult, EnvError},
gym::Reward,
transport::source::{DataSource, SourceGroup},
};
/// Reward applied when the agent submits an action the environment rejects.
///
/// The wrapped [`Reward`] is expected to be zero or negative;
/// [`EnvConfig::with_invalid_action_penalty`] asserts this. The default is a
/// zero (neutral) penalty.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct InvalidActionPenalty(pub Reward);

impl Default for InvalidActionPenalty {
    /// A neutral penalty: invalid actions earn `Reward(0)`.
    fn default() -> Self {
        let neutral = Reward(0);
        Self(neutral)
    }
}

impl From<InvalidActionPenalty> for Reward {
    /// Unwraps the penalty into the plain [`Reward`] it carries.
    fn from(InvalidActionPenalty(reward): InvalidActionPenalty) -> Self {
        reward
    }
}
/// Ready-made environment configurations selectable by name.
///
/// Each variant expands into a complete [`EnvConfig`] through
/// `From<EnvPreset> for EnvConfig`. The strum derives provide parsing from and
/// formatting to the variant name, a static-string view, a variant count, and
/// iteration over all presets.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[derive(Display, EnumString, EnumCount, EnumIter, IntoStaticStr)]
pub enum EnvPreset {
    /// Binance BTC/USDT spot data with daily OHLCV candles, restricted to the
    /// years 2018 through 2025.
    BtcUsdtEod,
}
impl From<EnvPreset> for EnvConfig {
    /// Expands a preset into a fully populated [`EnvConfig`].
    fn from(preset: EnvPreset) -> Self {
        match preset {
            EnvPreset::BtcUsdtEod => {
                // Daily BTC/USDT spot candles from Binance, no indicators.
                let daily_btc = OhlcvSpotConfig {
                    broker: DataBroker::Binance,
                    exchange: Some(Exchange::Binance),
                    symbol: Symbol::Spot(SpotPair::BtcUsdt),
                    period: Period::Day(1),
                    indicators: Vec::new(),
                    batch_size: 1000,
                };

                // Only years 2018..=2025 (inclusive) are eligible for episodes.
                let years: BTreeSet<_> = (2018..=2025).collect();
                let year_filter = FilterConfig {
                    allowed_years: Some(years),
                    ..FilterConfig::default()
                };

                let base = EnvConfig::default().add_ohlcv_spot(DataSource::Chapaty, daily_btc);
                base.with_filter_config(year_filter)
            }
        }
    }
}
/// Bias used when simulating order execution; [`ExecutionBias::Pessimistic`]
/// is the default.
///
/// NOTE(review): presumably this decides how fills that are ambiguous within a
/// bar are resolved (in or against the agent's favour) — confirm against the
/// execution simulator.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum ExecutionBias {
    /// Optimistic execution assumptions.
    Optimistic,
    /// Pessimistic execution assumptions (the default).
    #[default]
    Pessimistic,
}
/// Complete description of an environment: which market-data feeds to load,
/// how to filter them and cut them into episodes, and the related settings.
///
/// Build one from [`EnvConfig::default`] plus the `add_*` / `with_*` builder
/// methods, or convert an [`EnvPreset`]. Fields are private; read them through
/// the accessor methods.
///
/// NOTE: `hash()` fingerprints this struct by encoding it with `postcard`,
/// which writes struct fields in declaration order. Reordering, adding, or
/// removing fields therefore changes the hash of every existing configuration.
#[derive(Clone, Serialize, Deserialize)]
pub struct EnvConfig {
    // Data feeds, each kind grouped by the `DataSource` it is loaded from.
    ohlcv_spot: Vec<SourceGroup<OhlcvSpotConfig>>,
    ohlcv_future: Vec<SourceGroup<OhlcvFutureConfig>>,
    trade_spot: Vec<SourceGroup<TradeSpotConfig>>,
    tpo_spot: Vec<SourceGroup<TpoSpotConfig>>,
    tpo_future: Vec<SourceGroup<TpoFutureConfig>>,
    volume_profile_spot: Vec<SourceGroup<VolumeProfileSpotConfig>>,
    economic_calendar: Vec<SourceGroup<EconomicCalendarConfig>>,
    // `None` by default; when absent, `allowed_years()` falls back to 1990..=2040.
    filter_config: Option<FilterConfig>,
    episode_length: EpisodeLength,
    risk_metrics_cfg: RiskMetricsConfig,
    // Defaults to 2; `with_trade_hint` caps it at 32.
    trade_hint: usize,
    // Defaults to a zero penalty; `with_invalid_action_penalty` asserts it is <= 0.
    invalid_action_penalty: InvalidActionPenalty,
}
impl Default for EnvConfig {
    /// An empty configuration.
    ///
    /// It has no data feeds, no filter, default episode length and risk metrics,
    /// a trade hint of 2, and the default (zero) invalid-action penalty.
    fn default() -> Self {
        let episode_length = EpisodeLength::default();
        let risk_metrics_cfg = RiskMetricsConfig::default();
        let invalid_action_penalty = InvalidActionPenalty::default();

        Self {
            ohlcv_spot: Vec::default(),
            ohlcv_future: Vec::default(),
            trade_spot: Vec::default(),
            tpo_spot: Vec::default(),
            tpo_future: Vec::default(),
            volume_profile_spot: Vec::default(),
            economic_calendar: Vec::default(),
            filter_config: Option::None,
            episode_length,
            risk_metrics_cfg,
            trade_hint: 2,
            invalid_action_penalty,
        }
    }
}
/// Builder methods that register data feeds.
///
/// Every `add_*` method consumes the configuration and returns the updated
/// one. Feeds that share a `DataSource` are appended, in call order, to the
/// same group. A new group is created (after the existing ones) the first time
/// a source is seen for that feed kind.
impl EnvConfig {
    /// Registers an OHLCV spot feed loaded from `source`.
    #[must_use]
    pub fn add_ohlcv_spot(self, source: DataSource, config: OhlcvSpotConfig) -> Self {
        Self {
            ohlcv_spot: update_source_group(self.ohlcv_spot, source, config),
            ..self
        }
    }

    /// Registers an OHLCV futures feed loaded from `source`.
    #[must_use]
    pub fn add_ohlcv_future(self, source: DataSource, config: OhlcvFutureConfig) -> Self {
        Self {
            ohlcv_future: update_source_group(self.ohlcv_future, source, config),
            ..self
        }
    }

    /// Registers a spot trade feed loaded from `source`.
    #[must_use]
    pub fn add_trade_spot(self, source: DataSource, config: TradeSpotConfig) -> Self {
        Self {
            trade_spot: update_source_group(self.trade_spot, source, config),
            ..self
        }
    }

    /// Registers a spot TPO feed loaded from `source`.
    #[must_use]
    pub fn add_tpo_spot(self, source: DataSource, config: TpoSpotConfig) -> Self {
        Self {
            tpo_spot: update_source_group(self.tpo_spot, source, config),
            ..self
        }
    }

    /// Registers a futures TPO feed loaded from `source`.
    #[must_use]
    pub fn add_tpo_future(self, source: DataSource, config: TpoFutureConfig) -> Self {
        Self {
            tpo_future: update_source_group(self.tpo_future, source, config),
            ..self
        }
    }

    /// Registers a spot volume-profile feed loaded from `source`.
    #[must_use]
    pub fn add_volume_profile_spot(
        self,
        source: DataSource,
        config: VolumeProfileSpotConfig,
    ) -> Self {
        Self {
            volume_profile_spot: update_source_group(self.volume_profile_spot, source, config),
            ..self
        }
    }

    /// Registers an economic-calendar feed loaded from `source`.
    #[must_use]
    pub fn add_economic_calendar(self, source: DataSource, config: EconomicCalendarConfig) -> Self {
        Self {
            economic_calendar: update_source_group(self.economic_calendar, source, config),
            ..self
        }
    }
}
/// Builder methods for scalar settings.
///
/// Each method consumes the configuration and returns the updated one.
impl EnvConfig {
    /// Largest value accepted by [`EnvConfig::with_trade_hint`]; larger inputs
    /// are clamped down to it.
    const MAX_TRADE_HINT: u32 = 32;

    /// Sets the data filter, replacing any previously configured one.
    #[must_use]
    pub fn with_filter_config(self, filter_config: FilterConfig) -> Self {
        Self {
            filter_config: Some(filter_config),
            ..self
        }
    }

    /// Sets the episode length.
    #[must_use]
    pub fn with_episode_length(self, episode_length: EpisodeLength) -> Self {
        Self {
            episode_length,
            ..self
        }
    }

    /// Sets the risk-metrics configuration.
    #[must_use]
    pub fn with_risk_metrics_cfg(self, risk_metrics_cfg: RiskMetricsConfig) -> Self {
        Self {
            risk_metrics_cfg,
            ..self
        }
    }

    /// Sets the trade hint. Values above 32 are silently clamped to 32.
    #[must_use]
    pub fn with_trade_hint(self, trade_hint: u32) -> Self {
        // After clamping the value is at most 32, so `as usize` cannot truncate
        // on any target.
        Self {
            trade_hint: trade_hint.min(Self::MAX_TRADE_HINT) as usize,
            ..self
        }
    }

    /// Sets the reward applied when the agent takes an invalid action.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped reward is positive: a penalty must be `<= 0`.
    #[must_use]
    pub fn with_invalid_action_penalty(self, penalty: InvalidActionPenalty) -> Self {
        assert!(
            penalty.0.0 <= 0,
            "Invalid action penalty must be <= 0, got {}",
            penalty.0.0
        );
        Self {
            invalid_action_penalty: penalty,
            ..self
        }
    }
}
/// Read-only accessors and derived quantities.
impl EnvConfig {
    /// First year treated as allowed when no year filter is configured.
    const FALLBACK_FIRST_YEAR: u16 = 1990;
    /// Last year (inclusive) treated as allowed when no year filter is configured.
    const FALLBACK_LAST_YEAR: u16 = 2040;

    /// Registered OHLCV spot feeds, grouped by source.
    pub fn ohlcv_spot(&self) -> &[SourceGroup<OhlcvSpotConfig>] {
        &self.ohlcv_spot
    }

    /// Registered OHLCV futures feeds, grouped by source.
    pub fn ohlcv_future(&self) -> &[SourceGroup<OhlcvFutureConfig>] {
        &self.ohlcv_future
    }

    /// Registered spot trade feeds, grouped by source.
    pub fn trade_spot(&self) -> &[SourceGroup<TradeSpotConfig>] {
        &self.trade_spot
    }

    /// Registered spot TPO feeds, grouped by source.
    pub fn tpo_spot(&self) -> &[SourceGroup<TpoSpotConfig>] {
        &self.tpo_spot
    }

    /// Registered futures TPO feeds, grouped by source.
    pub fn tpo_future(&self) -> &[SourceGroup<TpoFutureConfig>] {
        &self.tpo_future
    }

    /// Registered spot volume-profile feeds, grouped by source.
    pub fn volume_profile_spot(&self) -> &[SourceGroup<VolumeProfileSpotConfig>] {
        &self.volume_profile_spot
    }

    /// Registered economic-calendar feeds, grouped by source.
    pub fn economic_calendar(&self) -> &[SourceGroup<EconomicCalendarConfig>] {
        &self.economic_calendar
    }

    /// The configured data filter, if any.
    pub fn filter_config(&self) -> Option<&FilterConfig> {
        self.filter_config.as_ref()
    }

    /// The configured episode length.
    pub fn episode_length(&self) -> EpisodeLength {
        self.episode_length
    }

    /// The configured risk-metrics settings.
    pub fn risk_metrics_cfg(&self) -> RiskMetricsConfig {
        self.risk_metrics_cfg
    }

    /// The trade hint (at most 32 when set through `with_trade_hint`).
    pub fn trade_hint(&self) -> usize {
        self.trade_hint
    }

    /// The reward applied to invalid actions.
    pub fn invalid_action_penalty(&self) -> InvalidActionPenalty {
        self.invalid_action_penalty
    }

    /// Years eligible for episodes, in ascending order.
    ///
    /// If a filter with `allowed_years` is configured, exactly those years are
    /// returned (possibly none). Otherwise every year from 1990 to 2040
    /// inclusive is returned.
    pub fn allowed_years(&self) -> Vec<u16> {
        match self.filter_config().and_then(|c| c.allowed_years.as_ref()) {
            // The years are stored in a `BTreeSet`, whose iteration order is
            // already ascending, so no extra sort is needed.
            Some(years) => years.iter().copied().collect(),
            None => (Self::FALLBACK_FIRST_YEAR..=Self::FALLBACK_LAST_YEAR).collect(),
        }
    }

    /// Upper bound on the number of episodes the configuration can produce.
    ///
    /// This is the maximum number of episodes per year times the number of
    /// allowed years. It saturates at `usize::MAX` instead of overflowing and
    /// is never less than 1.
    pub fn max_episode_capacity(&self) -> usize {
        self.episode_length()
            .max_episodes()
            .saturating_mul(self.allowed_year_count())
            .max(1)
    }

    /// Number of entries `allowed_years()` would return, computed without
    /// allocating the list.
    fn allowed_year_count(&self) -> usize {
        match self.filter_config().and_then(|c| c.allowed_years.as_ref()) {
            Some(years) => years.len(),
            None => usize::from(Self::FALLBACK_LAST_YEAR - Self::FALLBACK_FIRST_YEAR) + 1,
        }
    }
}
/// Adds `config` to the group for `source` and returns the updated groups.
///
/// If a group for `source` already exists, `config` is appended to its items.
/// Otherwise a new single-item group is pushed onto the end of `groups`.
fn update_source_group<T>(
    mut groups: Vec<SourceGroup<T>>,
    source: DataSource,
    config: T,
) -> Vec<SourceGroup<T>> {
    match groups.iter_mut().find(|group| group.source == source) {
        Some(existing) => existing.items.push(config),
        None => groups.push(SourceGroup {
            source,
            items: vec![config],
        }),
    }
    groups
}
/// Identity and sanity checks.
impl EnvConfig {
    /// Returns a hex-encoded BLAKE3 digest of this configuration.
    ///
    /// The configuration is first serialized with `postcard`, and the digest is
    /// taken over those bytes. Configurations with identical encodings
    /// therefore produce identical hashes.
    ///
    /// # Errors
    ///
    /// Returns the crate error converted from [`EnvError::Encoding`] if
    /// serialization fails.
    pub fn hash(&self) -> ChapatyResult<String> {
        let bytes = postcard::to_stdvec(self).map_err(EnvError::Encoding)?;
        // `blake3::Hash`'s `Display` impl renders the digest as lowercase hex.
        Ok(blake3::hash(&bytes).to_string())
    }

    /// Reports whether at least one OHLCV (spot or futures) or spot-trade feed
    /// has been registered.
    ///
    /// TPO, volume-profile, and economic-calendar feeds do not make a
    /// configuration valid on their own.
    pub fn is_valid(&self) -> bool {
        !self.ohlcv_spot.is_empty() || !self.ohlcv_future.is_empty() || !self.trade_spot.is_empty()
    }
}