pub mod crossover;
pub mod news;
use ndarray::Array;
use serde::{Deserialize, Serialize};
use std::{fmt::Debug, str::FromStr, sync::Arc};
use strum::{Display, EnumString};
use crate::{
error::{ChapatyResult, DataError},
gym::trading::{action::Actions, observation::Observation},
};
/// One axis of a parameter grid: values from `start` towards `end` (exclusive) in
/// increments of `step`.
///
/// The bounds are given as decimal strings so that the number of decimal places can be
/// recorded; [`GridAxis::generate`] rounds every produced value to that precision, which
/// removes binary floating-point noise such as `0.30000000000000004`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GridAxis {
    /// First value of the axis.
    start: f64,
    /// Exclusive bound of the axis.
    end: f64,
    /// Distance between consecutive values (negative for descending axes).
    step: f64,
    /// Decimal places every generated value is rounded to.
    precision: u32,
}

impl GridAxis {
    /// Parses the axis bounds from decimal strings.
    ///
    /// `start`, `end` and `step` accept anything [`f64::from_str`] accepts, including
    /// scientific notation (`"1e-3"`). The rounding precision is the larger of the number
    /// of decimal places in `start` and in `step`, because every grid value has the form
    /// `start + k * step`.
    ///
    /// # Errors
    ///
    /// Returns the [`DataError`] converted from the parse error if any argument is not a
    /// valid `f64`.
    pub fn new(start: &str, end: &str, step: &str) -> ChapatyResult<Self> {
        let start_f = f64::from_str(start).map_err(DataError::from)?;
        let end_f = f64::from_str(end).map_err(DataError::from)?;
        let step_f = f64::from_str(step).map_err(DataError::from)?;
        // The step alone is not enough: with start = 0.05 and step = 0.1 the grid points
        // need two decimals, not one.
        let precision = Self::decimal_places(start).max(Self::decimal_places(step));
        Ok(Self {
            start: start_f,
            end: end_f,
            step: step_f,
            precision,
        })
    }

    /// Returns the axis values `start, start + step, …` strictly before `end`, each rounded
    /// to `precision` decimal places.
    ///
    /// Returns an empty vector when the axis is empty or degenerate: `step` is zero or
    /// points away from `end`, or a bound/step is NaN or infinite.
    pub fn generate(&self) -> Vec<f64> {
        // Element count as ceil((end - start) / step). Checking it first maps NaN/infinite
        // counts (e.g. a zero step) and reversed directions to an empty axis instead of
        // handing them to `Array::range`.
        let count = ((self.end - self.start) / self.step).ceil();
        if !(count.is_finite() && count > 0.0) {
            return Vec::new();
        }
        // An absurd precision overflows to infinity; such values are left unrounded below.
        let factor = i32::try_from(self.precision).map_or(f64::INFINITY, |p| 10_f64.powi(p));
        let ascending = self.step > 0.0;
        Array::range(self.start, self.end, self.step)
            .iter()
            .map(|&val| {
                let scaled = val * factor;
                if scaled.is_finite() {
                    scaled.round() / factor
                } else {
                    val
                }
            })
            // Floating-point error in the count can make `Array::range` emit `end` itself
            // (e.g. 1 → 1.3 by 0.1); keep the bound exclusive.
            .filter(|&val| if ascending { val < self.end } else { val > self.end })
            .collect()
    }

    /// Number of decimal places needed to write `literal` without an exponent,
    /// e.g. `"0.25"` → 2, `"1.5e-3"` → 4, `"2.5e1"` → 0.
    ///
    /// Expects a string that already parsed as an `f64`; special values such as `"inf"`
    /// yield 0.
    fn decimal_places(literal: &str) -> u32 {
        let (mantissa, exponent) = match literal.find(|c: char| c == 'e' || c == 'E') {
            Some(idx) => (
                &literal[..idx],
                literal[idx + 1..].parse::<i64>().unwrap_or(0),
            ),
            None => (literal, 0),
        };
        let fraction_digits = mantissa
            .split_once('.')
            .map_or(0, |(_, fraction)| fraction.len() as i64);
        u32::try_from(fraction_digits.saturating_sub(exponent).max(0)).unwrap_or(u32::MAX)
    }
}
/// Identifier of an [`Agent`], as returned by [`Agent::identifier`].
///
/// `Display` and `FromStr` come from `strum`, with `SCREAMING_SNAKE_CASE` applied to
/// variant names, so [`AgentIdentifier::Random`] renders as `RANDOM`. Serde
/// (de)serialization is derived separately and does not use these `strum` strings.
/// The derived `Default` is `Random`, and the derived `Ord` places `Named(..)` before
/// `Random` (declaration order).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Display,
    Default,
    PartialOrd,
    Ord,
    Serialize,
    Deserialize,
    EnumString,
)]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum AgentIdentifier {
    /// A user-chosen name; `Display` is meant to print the inner string via `"{0}"`.
    ///
    /// NOTE(review): `EnumString` also matches `to_string` values when parsing and fills
    /// data fields with `Default::default()`. That suggests `"{0}".parse()` yields
    /// `Named("")` and no real name parses to `Named`. Confirm against the strum version
    /// in use if `FromStr` is relied on.
    #[strum(to_string = "{0}")]
    Named(Arc<String>),
    /// The default identifier.
    #[default]
    Random,
}
/// Something that can act in the trading gym: it turns an [`Observation`] into [`Actions`].
///
/// Only [`Agent::act`] must be implemented; `identifier` and `reset` have defaults.
pub trait Agent {
    /// Chooses the [`Actions`] to take in response to `obs`.
    ///
    /// # Errors
    ///
    /// Implementations report failures through [`ChapatyResult`].
    fn act(&mut self, obs: Observation) -> ChapatyResult<Actions>;

    /// Returns this agent's identifier.
    ///
    /// The default is a placeholder [`AgentIdentifier::Named`] whose text asks implementors
    /// to override this method. A new `Arc` is allocated on every call.
    fn identifier(&self) -> AgentIdentifier {
        let placeholder = String::from("UnnamedAgent: override Agent::identifier()");
        AgentIdentifier::Named(Arc::new(placeholder))
    }

    /// Resets the agent. The default implementation does nothing.
    fn reset(&mut self) {
        // Nothing to clear unless an implementation keeps state.
    }
}
/// Lets a boxed agent trait object be used wherever an [`Agent`] is expected by forwarding
/// every method to the boxed value.
impl Agent for Box<dyn Agent> {
    fn act(&mut self, obs: Observation) -> ChapatyResult<Actions> {
        // Dispatch explicitly through `dyn Agent`. Letting `self` coerce implicitly could
        // select this very impl (a `Box<dyn Agent>` is itself an `Agent`) and recurse.
        <dyn Agent as Agent>::act(&mut **self, obs)
    }

    fn identifier(&self) -> AgentIdentifier {
        <dyn Agent as Agent>::identifier(&**self)
    }

    fn reset(&mut self) {
        <dyn Agent as Agent>::reset(&mut **self)
    }
}