use std::{str::FromStr, sync::Arc};
use ndarray::Array;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumString};
use crate::{
error::{ChapatyResult, DataError},
impl_add_sub_mul_div_primitive, impl_from_primitive,
};
pub mod trading;
/// Reward value, stored as a signed 64-bit integer newtype.
///
/// Ordering and equality are the derived integer ordering/equality of the inner `i64`.
/// NOTE(review): the unit/scale of the integer (raw units vs. a fixed-point scaling)
/// is not visible here — confirm with the code that produces `Reward`s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Reward(pub i64);
// Project macros imported from `crate`; presumably they implement conversions from
// `i64` and `+ - * /` with `i64` operands for `Reward` — see their definitions.
impl_from_primitive!(Reward, i64);
impl_add_sub_mul_div_primitive!(Reward, i64);
/// Newtype around the [`Reward`] used as the invalid-action penalty.
///
/// The default penalty is `Reward(0)`, and the wrapped value can be recovered with
/// `Reward::from(penalty)` / `penalty.into()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct InvalidActionPenalty(pub Reward);

impl Default for InvalidActionPenalty {
    /// A penalty of `Reward(0)`.
    fn default() -> Self {
        let zero = Reward(0);
        InvalidActionPenalty(zero)
    }
}

impl From<InvalidActionPenalty> for Reward {
    /// Unwraps the penalty into the plain [`Reward`] it carries.
    fn from(InvalidActionPenalty(reward): InvalidActionPenalty) -> Self {
        reward
    }
}
/// Status of an environment, with one `is_*` predicate per variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnvStatus {
    Ready,
    Running,
    EpisodeDone,
    Done,
}

impl EnvStatus {
    /// Returns `true` only for [`EnvStatus::Ready`].
    pub fn is_ready(&self) -> bool {
        *self == EnvStatus::Ready
    }

    /// Returns `true` only for [`EnvStatus::Running`].
    pub fn is_running(&self) -> bool {
        *self == EnvStatus::Running
    }

    /// Returns `true` only for [`EnvStatus::EpisodeDone`].
    pub fn is_episode_done(&self) -> bool {
        *self == EnvStatus::EpisodeDone
    }

    /// Returns `true` only for [`EnvStatus::Done`] (not for `EpisodeDone`).
    pub fn is_done(&self) -> bool {
        *self == EnvStatus::Done
    }
}
/// Classification of a single step, with predicates for each finishing variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepOutcome {
    InProgress,
    Terminated,
    Truncated,
    Done,
}

impl StepOutcome {
    /// Returns `true` only for [`StepOutcome::Done`].
    pub fn is_done(&self) -> bool {
        matches!(*self, StepOutcome::Done)
    }

    /// Returns `true` only for [`StepOutcome::Terminated`].
    pub fn is_terminated(&self) -> bool {
        matches!(*self, StepOutcome::Terminated)
    }

    /// Returns `true` only for [`StepOutcome::Truncated`].
    pub fn is_truncated(&self) -> bool {
        matches!(*self, StepOutcome::Truncated)
    }

    /// Returns `true` for [`StepOutcome::Terminated`] or [`StepOutcome::Truncated`].
    ///
    /// [`StepOutcome::Done`] and [`StepOutcome::InProgress`] are *not* terminal here.
    pub fn is_terminal(&self) -> bool {
        // NOTE(review): `Done` is deliberately(?) excluded — confirm with callers.
        matches!(*self, StepOutcome::Terminated | StepOutcome::Truncated)
    }
}
/// One axis of a parameter grid: values from `start` towards `end` (exclusive) in
/// increments of `step`.
///
/// The bounds are parsed from decimal string literals, and the number of decimal
/// places those literals need is stored as `precision`. [`GridAxis::generate`] rounds
/// every produced value to that many places so floating-point error
/// (e.g. `0.1 + 2.0 * 0.1 == 0.30000000000000004`) does not leak into the grid.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GridAxis {
    /// First grid value (inclusive).
    start: f64,
    /// Bound the grid stops before (exclusive, up to floating-point error).
    end: f64,
    /// Distance between consecutive grid values; may be negative.
    step: f64,
    /// Decimal places each generated value is rounded to.
    precision: u32,
}

impl GridAxis {
    /// Builds an axis from float literals such as `"0.25"`, `"10"` or `"1e-3"`.
    ///
    /// The rounding precision is the larger of the decimal places required by `start`
    /// and by `step`, because every grid value has the form `start + k * step`
    /// (e.g. `start = "0.25"`, `step = "0.5"` needs 2 places, not 1).
    ///
    /// # Errors
    ///
    /// Returns an error (built via `DataError::from`) if any of the three strings is
    /// not a valid `f64` literal.
    pub fn new(start: &str, end: &str, step: &str) -> ChapatyResult<Self> {
        let start_f = f64::from_str(start).map_err(DataError::from)?;
        let end_f = f64::from_str(end).map_err(DataError::from)?;
        let step_f = f64::from_str(step).map_err(DataError::from)?;
        let precision = Self::decimal_places(start).max(Self::decimal_places(step));
        Ok(Self {
            start: start_f,
            end: end_f,
            step: step_f,
            precision,
        })
    }

    /// Returns `start, start + step, start + 2 * step, ...` up to (excluding) `end`,
    /// each value rounded to `precision` decimal places. With a negative `step` the
    /// values descend towards `end`.
    ///
    /// Returns an empty vector when the axis describes no values: `start == end`, a
    /// `step` pointing away from `end`, a zero or non-finite `step`, or non-finite
    /// bounds. `new` does not reject such inputs and deserialization bypasses `new`,
    /// so they are handled here instead of panicking.
    pub fn generate(&self) -> Vec<f64> {
        // `Array::range` panics when `((end - start) / step).ceil()` cannot be
        // converted to `usize` (NaN, infinite, or negative). This single ratio test
        // rejects all of those cases; a non-positive ratio is an empty axis.
        let span = (self.end - self.start) / self.step;
        if !span.is_finite() || span <= 0.0 {
            return Vec::new();
        }
        // Clamp before casting so an out-of-range (e.g. deserialized) precision
        // cannot wrap to a negative exponent.
        let factor = 10_f64.powi(self.precision.min(i32::MAX as u32) as i32);
        Array::range(self.start, self.end, self.step)
            .iter()
            .map(|&val| {
                let rounded = (val * factor).round() / factor;
                // Extreme precisions or magnitudes overflow the scaling; keep the
                // unrounded value rather than emitting inf/NaN.
                if rounded.is_finite() {
                    rounded
                } else {
                    val
                }
            })
            .collect()
    }

    /// Number of digits after the decimal point needed to write `literal` without an
    /// exponent, e.g. `"0.25"` -> 2, `"1.5e-3"` -> 4, `"2.5e1"` -> 0, `"5"` -> 0.
    ///
    /// Expects a string `f64::from_str` already accepted; literals without digits
    /// after a `.` (such as `"inf"` or `"NaN"`) yield 0.
    fn decimal_places(literal: &str) -> u32 {
        let (mantissa, exponent) = match literal.find(|c: char| c == 'e' || c == 'E') {
            // An exponent too large for i64 only occurs for literals that are 0 or
            // infinite as f64, where the precision no longer matters.
            Some(idx) => (
                &literal[..idx],
                literal[idx + 1..].parse::<i64>().unwrap_or(0),
            ),
            None => (literal, 0),
        };
        let fraction_digits = mantissa
            .split_once('.')
            .map_or(0, |(_, fraction)| fraction.len() as i64);
        let places = fraction_digits.saturating_sub(exponent).max(0);
        places.min(i64::from(i32::MAX)) as u32
    }
}
/// Identifies an agent either by an explicit name or as the `Random` agent.
///
/// `Default` is `AgentIdentifier::Random` (see `#[default]`). The derived `PartialOrd`/
/// `Ord` follow declaration order, so every `Named(_)` sorts before `Random`.
/// String forms come from strum: `serialize_all = "SCREAMING_SNAKE_CASE"` makes
/// `Random` render/parse as `RANDOM`, and `Named` is intended to render as the
/// wrapped name via `to_string = "{0}"`.
/// NOTE(review): with `EnumString`, parsing an arbitrary name presumably does not
/// produce `Named(..)` (strum matches fixed serializations) — confirm `FromStr`
/// round-tripping against the strum version in use before relying on it.
#[derive(
Clone,
Debug,
PartialEq,
Eq,
Hash,
Display,
Default,
PartialOrd,
Ord,
Serialize,
Deserialize,
EnumString,
)]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum AgentIdentifier {
/// Agent with an explicit, shared (`Arc`) name.
#[strum(to_string = "{0}")]
Named(Arc<String>),
/// Unnamed random agent; the default identifier.
#[default]
Random,
}