use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt;

use serde::{Deserialize, Serialize};

use crate::bar_indicators::average::MovingAverageType;
use crate::bar_indicators::ohlcv_field::OhlcvField;
/// A single concrete parameter value, as produced by a [`ParameterRange`] or
/// supplied directly as a value list.
///
/// Several numeric variants overlap (`Float`/`F64` and `Int`/`USize`/`U8`);
/// the `as_*` accessors on this type convert between them where possible.
///
/// NOTE(review): serde's derived impls encode the variant *index* for some
/// binary formats, so reordering these variants may break previously
/// serialized data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ParameterValue {
    /// Signed integer; this is what `ParameterRange::IntRange` generates.
    Int(i64),
    /// Floating-point value; this is what `ParameterRange::FloatRange` generates.
    Float(f64),
    /// A moving-average kind.
    MaType(MovingAverageType),
    /// Free-form text.
    String(String),
    /// Boolean flag.
    Bool(bool),
    /// Unsigned machine-sized integer.
    USize(usize),
    /// Small unsigned integer.
    U8(u8),
    /// Floating-point value; `as_float`/`as_f64` treat it the same as `Float`.
    F64(f64),
    /// An [`OhlcvField`] selector (presumably the input series to read — confirm).
    Source(OhlcvField),
}
impl ParameterValue {
pub fn as_int(&self) -> Option<i64> {
match self {
ParameterValue::Int(v) => Some(*v),
ParameterValue::USize(v) => Some(*v as i64),
ParameterValue::U8(v) => Some(*v as i64),
_ => None,
}
}
pub fn as_float(&self) -> Option<f64> {
match self {
ParameterValue::Float(v) => Some(*v),
ParameterValue::F64(v) => Some(*v),
ParameterValue::Int(v) => Some(*v as f64),
ParameterValue::USize(v) => Some(*v as f64),
ParameterValue::U8(v) => Some(*v as f64),
_ => None,
}
}
pub fn as_ma_type(&self) -> Option<MovingAverageType> {
match self {
ParameterValue::MaType(t) => Some(*t),
_ => None,
}
}
pub fn as_usize(&self) -> Option<usize> {
match self {
ParameterValue::Int(v) if *v >= 0 => Some(*v as usize),
ParameterValue::USize(v) => Some(*v),
ParameterValue::U8(v) => Some(*v as usize),
_ => None,
}
}
pub fn as_u8(&self) -> Option<u8> {
match self {
ParameterValue::U8(v) => Some(*v),
ParameterValue::USize(v) if *v <= 255 => Some(*v as u8),
ParameterValue::Int(v) if *v >= 0 && *v <= 255 => Some(*v as u8),
_ => None,
}
}
pub fn as_f64(&self) -> Option<f64> {
match self {
ParameterValue::F64(v) => Some(*v),
ParameterValue::Float(v) => Some(*v),
ParameterValue::Int(v) => Some(*v as f64),
ParameterValue::USize(v) => Some(*v as f64),
ParameterValue::U8(v) => Some(*v as f64),
_ => None,
}
}
pub fn as_bool(&self) -> Option<bool> {
match self {
ParameterValue::Bool(v) => Some(*v),
_ => None,
}
}
pub fn as_string(&self) -> Option<&str> {
match self {
ParameterValue::String(s) => Some(s.as_str()),
_ => None,
}
}
pub fn as_source(&self) -> Option<OhlcvField> {
match self {
ParameterValue::Source(field) => Some(*field),
_ => None,
}
}
}
impl fmt::Display for ParameterValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ParameterValue::Int(v) => write!(f, "{}", v),
ParameterValue::Float(v) => write!(f, "{:.4}", v),
ParameterValue::F64(v) => write!(f, "{:.4}", v),
ParameterValue::MaType(v) => write!(f, "{}", v),
ParameterValue::String(v) => write!(f, "{}", v),
ParameterValue::Bool(v) => write!(f, "{}", v),
ParameterValue::USize(v) => write!(f, "{}", v),
ParameterValue::U8(v) => write!(f, "{}", v),
ParameterValue::Source(v) => write!(f, "{}", v.as_str()),
}
}
}
/// The set of values a single parameter takes inside a [`ParameterGrid`].
#[derive(Debug, Clone)]
pub enum ParameterRange {
    /// Integers `min, min + step, …` (positive `step`) up to and including `max`.
    IntRange { min: i64, max: i64, step: i64 },
    /// Floats `min + step * i` for each `i = 0, 1, …` whose value does not exceed `max`.
    FloatRange { min: f64, max: f64, step: f64 },
    /// An explicit list of values, emitted in the given order.
    ValueList(Vec<ParameterValue>),
}
impl ParameterRange {
pub fn generate(&self) -> Vec<ParameterValue> {
match self {
ParameterRange::IntRange { min, max, step } => {
let mut values = Vec::new();
let mut current = *min;
while current <= *max {
values.push(ParameterValue::Int(current));
current += step;
}
values
}
ParameterRange::FloatRange { min, max, step } => {
let mut values = Vec::new();
let steps = ((*max - *min) / *step).ceil() as usize + 1;
for i in 0..steps {
let value = *min + *step * (i as f64);
if value <= *max {
values.push(ParameterValue::Float(value));
}
}
values
}
ParameterRange::ValueList(vals) => vals.clone(),
}
}
pub fn count(&self) -> usize {
match self {
ParameterRange::IntRange { min, max, step } => {
if *step <= 0 { return 0; }
(((*max - *min) / *step) + 1).max(0) as usize
}
ParameterRange::FloatRange { min, max, step } => {
if *step <= 0.0 { return 0; }
((*max - *min) / *step).ceil() as usize + 1
}
ParameterRange::ValueList(vals) => vals.len(),
}
}
}
/// A named collection of parameter ranges whose Cartesian product enumerates
/// every parameter combination (see `generate_all`).
#[derive(Debug, Clone, Default)]
pub struct ParameterGrid {
    /// Range for each parameter, keyed by parameter name.
    ranges: HashMap<String, ParameterRange>,
    /// Parameter names in insertion order; this drives the iteration order of
    /// `generate_all` and `compute_at_index`.
    parameter_order: Vec<String>,
}
impl ParameterGrid {
pub fn new() -> Self {
Self {
ranges: HashMap::new(),
parameter_order: Vec::new(),
}
}
pub fn add_int_range(mut self, name: &str, min: i64, max: i64, step: i64) -> Self {
self.ranges.insert(
name.to_string(),
ParameterRange::IntRange { min, max, step },
);
self.parameter_order.push(name.to_string());
self
}
pub fn add_float_range(mut self, name: &str, min: f64, max: f64, step: f64) -> Self {
self.ranges.insert(
name.to_string(),
ParameterRange::FloatRange { min, max, step },
);
self.parameter_order.push(name.to_string());
self
}
pub fn add_ma_types(mut self, name: &str, types: Vec<MovingAverageType>) -> Self {
let values: Vec<ParameterValue> = types
.into_iter()
.map(ParameterValue::MaType)
.collect();
self.ranges.insert(
name.to_string(),
ParameterRange::ValueList(values),
);
self.parameter_order.push(name.to_string());
self
}
pub fn add_values(mut self, name: &str, values: Vec<ParameterValue>) -> Self {
self.ranges.insert(
name.to_string(),
ParameterRange::ValueList(values),
);
self.parameter_order.push(name.to_string());
self
}
pub fn total_combinations(&self) -> usize {
self.ranges.values()
.map(|r| r.count())
.product()
}
pub fn generate_all(&self) -> Vec<HashMap<String, ParameterValue>> {
if self.parameter_order.is_empty() {
return vec![HashMap::new()];
}
let mut param_values: Vec<(String, Vec<ParameterValue>)> = Vec::new();
for param_name in &self.parameter_order {
if let Some(range) = self.ranges.get(param_name) {
param_values.push((param_name.clone(), range.generate()));
}
}
Self::cartesian_product(¶m_values)
}
fn cartesian_product(
params: &[(String, Vec<ParameterValue>)],
) -> Vec<HashMap<String, ParameterValue>> {
if params.is_empty() {
return vec![HashMap::new()];
}
let (param_name, param_vals) = ¶ms[0];
let rest_combinations = Self::cartesian_product(¶ms[1..]);
let mut result = Vec::new();
for val in param_vals {
for rest_combo in &rest_combinations {
let mut combo = rest_combo.clone();
combo.insert(param_name.clone(), val.clone());
result.push(combo);
}
}
result
}
pub fn parameter_names(&self) -> &[String] {
&self.parameter_order
}
pub fn extract_int_range(&self, param_name: &str) -> Vec<usize> {
if let Some(range) = self.ranges.get(param_name) {
range.generate()
.into_iter()
.filter_map(|v| v.as_usize())
.collect()
} else {
Vec::new()
}
}
pub fn extract_ma_types(&self, param_name: &str) -> Vec<MovingAverageType> {
if let Some(range) = self.ranges.get(param_name) {
range.generate()
.into_iter()
.filter_map(|v| v.as_ma_type())
.collect()
} else {
Vec::new()
}
}
pub fn compute_at_index(&self, idx: usize) -> Option<HashMap<String, ParameterValue>> {
if idx >= self.total_combinations() {
return None;
}
if self.parameter_order.is_empty() {
return Some(HashMap::new());
}
let mut param_value_lists: Vec<(String, Vec<ParameterValue>)> = Vec::new();
for param_name in &self.parameter_order {
if let Some(range) = self.ranges.get(param_name) {
param_value_lists.push((param_name.clone(), range.generate()));
}
}
let mut current_idx = idx;
let mut result = HashMap::new();
for (param_name, param_values) in ¶m_value_lists {
let range_size = param_values.len();
if range_size == 0 {
return None;
}
let local_idx = current_idx % range_size;
current_idx /= range_size;
result.insert(param_name.clone(), param_values[local_idx].clone());
}
Some(result)
}
}
/// Fluent builder for a [`ParameterGrid`].
///
/// Each method forwards to the matching `ParameterGrid::add_*` call.
pub struct ParameterGridBuilder {
    grid: ParameterGrid,
}

impl Default for ParameterGridBuilder {
    /// Same as [`ParameterGridBuilder::new`].
    fn default() -> Self {
        ParameterGridBuilder::new()
    }
}

impl ParameterGridBuilder {
    /// Starts from an empty grid.
    pub fn new() -> Self {
        ParameterGridBuilder {
            grid: ParameterGrid::new(),
        }
    }

    /// Adds an integer range parameter.
    pub fn int_range(self, name: &str, min: i64, max: i64, step: i64) -> Self {
        ParameterGridBuilder {
            grid: self.grid.add_int_range(name, min, max, step),
        }
    }

    /// Adds a floating-point range parameter.
    pub fn float_range(self, name: &str, min: f64, max: f64, step: f64) -> Self {
        ParameterGridBuilder {
            grid: self.grid.add_float_range(name, min, max, step),
        }
    }

    /// Adds a parameter enumerating the given moving-average types.
    pub fn ma_types(self, name: &str, types: Vec<MovingAverageType>) -> Self {
        ParameterGridBuilder {
            grid: self.grid.add_ma_types(name, types),
        }
    }

    /// Adds a parameter with an explicit list of values.
    pub fn values(self, name: &str, values: Vec<ParameterValue>) -> Self {
        ParameterGridBuilder {
            grid: self.grid.add_values(name, values),
        }
    }

    /// Finishes building and returns the grid.
    pub fn build(self) -> ParameterGrid {
        let ParameterGridBuilder { grid } = self;
        grid
    }
}
/// Splits `s` into exactly three comma-separated fields and parses each as `T`.
///
/// Surrounding whitespace in each field is ignored. A wrong field count
/// returns `format_msg` as the error; a field that fails to parse reports
/// which one (min, max or step).
fn parse_triple<T: std::str::FromStr>(s: &str, format_msg: &str) -> Result<(T, T, T), String> {
    let parts: Vec<&str> = s.split(',').map(str::trim).collect();
    let (min, max, step) = match parts.as_slice() {
        [min, max, step] => (*min, *max, *step),
        _ => return Err(format_msg.to_string()),
    };
    let min = min.parse().map_err(|_| "Invalid min value")?;
    let max = max.parse().map_err(|_| "Invalid max value")?;
    let step = step.parse().map_err(|_| "Invalid step value")?;
    Ok((min, max, step))
}

/// Parses an integer range written as `"min,max,step"`, e.g. `"10, 30, 5"`.
///
/// # Errors
/// Returns an error when:
/// - there are not exactly three fields;
/// - a field is not a valid `usize`;
/// - `step` is zero, which would describe a range that never advances.
pub fn parse_range_str(s: &str) -> Result<(usize, usize, usize), String> {
    let (min, max, step) = parse_triple::<usize>(s, "Range format: min,max,step")?;
    if step == 0 {
        return Err("Invalid step value: must be greater than zero".to_string());
    }
    Ok((min, max, step))
}

/// Parses a float range written as `"min,max,step"`, e.g. `"0.5, 2.0, 0.25"`.
///
/// # Errors
/// Returns an error when:
/// - there are not exactly three fields;
/// - a field is not a valid `f64`;
/// - `min` or `max` is infinite or NaN;
/// - `step` is not a positive finite number.
pub fn parse_float_range_str(s: &str) -> Result<(f64, f64, f64), String> {
    let (min, max, step) = parse_triple::<f64>(s, "Float range format: min,max,step")?;
    if !min.is_finite() || !max.is_finite() {
        return Err("Invalid range bounds: min and max must be finite".to_string());
    }
    // `step > 0.0` is false for NaN, so this also rejects a NaN step.
    if !(step > 0.0 && step.is_finite()) {
        return Err("Invalid step value: must be a positive finite number".to_string());
    }
    Ok((min, max, step))
}
/// Parses a comma-separated list of moving-average names into [`MovingAverageType`]s.
///
/// Matching is case-insensitive and ignores whitespace around each name.
/// Unrecognised names are skipped silently, and duplicates are kept.
///
/// Some names are mapped onto a *different* type:
/// - `frama` and `lr` become `SMA`;
/// - `ama_ring` becomes `AMA`.
///
/// NOTE(review): presumably these are stand-ins for variants that do not
/// exist yet — confirm.
pub fn parse_ma_types_str(s: &str) -> Vec<MovingAverageType> {
    let lookup = |token: &str| -> Option<MovingAverageType> {
        let name = token.trim().to_lowercase();
        let kind = match name.as_str() {
            "simple" | "sma" => MovingAverageType::SMA,
            "exponential" | "ema" => MovingAverageType::EMA,
            "wilder" | "rma" => MovingAverageType::RMA,
            "weighted" | "wma" => MovingAverageType::WMA,
            "ama" => MovingAverageType::AMA,
            "dema" => MovingAverageType::DEMA,
            "frama" => MovingAverageType::SMA,
            "hma" => MovingAverageType::HMA,
            "tema" => MovingAverageType::TEMA,
            "tma" => MovingAverageType::TMA,
            "vwap" => MovingAverageType::VWAP,
            "lr" => MovingAverageType::SMA,
            "ama_ring" => MovingAverageType::AMA,
            _ => return None,
        };
        Some(kind)
    };
    s.split(',').filter_map(lookup).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the integer values of `key` across `combinations`, in order.
    fn ints_for(combinations: &[HashMap<String, ParameterValue>], key: &str) -> Vec<i64> {
        combinations
            .iter()
            .filter_map(|combo| combo.get(key).and_then(ParameterValue::as_int))
            .collect()
    }

    #[test]
    fn test_int_range_generation() {
        let grid = ParameterGrid::new().add_int_range("period", 10, 30, 10);
        let combinations = grid.generate_all();
        assert_eq!(combinations.len(), 3);
        assert_eq!(ints_for(&combinations, "period"), vec![10, 20, 30]);
    }

    #[test]
    fn test_multiple_ranges() {
        let grid = ParameterGrid::new()
            .add_int_range("fast_period", 5, 15, 5)
            .add_int_range("slow_period", 20, 40, 10);
        let combinations = grid.generate_all();
        assert_eq!(combinations.len(), 9);
        assert_eq!(grid.total_combinations(), 9);
        assert!(combinations
            .iter()
            .all(|c| c.contains_key("fast_period") && c.contains_key("slow_period")));
    }

    #[test]
    fn test_ma_types() {
        let grid = ParameterGrid::new().add_ma_types(
            "ma_type",
            vec![MovingAverageType::SMA, MovingAverageType::EMA],
        );
        let combinations = grid.generate_all();
        assert_eq!(combinations.len(), 2);
        assert_eq!(
            grid.extract_ma_types("ma_type"),
            vec![MovingAverageType::SMA, MovingAverageType::EMA]
        );
    }

    #[test]
    fn test_total_combinations_count() {
        let grid = ParameterGrid::new()
            .add_int_range("period", 10, 30, 10)
            .add_ma_types(
                "ma_type",
                vec![MovingAverageType::SMA, MovingAverageType::EMA],
            );
        assert_eq!(grid.total_combinations(), 6);
        assert_eq!(grid.generate_all().len(), 6);
    }

    #[test]
    fn test_float_range_generation() {
        let grid = ParameterGrid::new().add_float_range("threshold", 0.0, 1.0, 0.25);
        let values: Vec<f64> = grid
            .generate_all()
            .iter()
            .filter_map(|c| c.get("threshold").and_then(ParameterValue::as_float))
            .collect();
        assert_eq!(values, vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        assert_eq!(grid.total_combinations(), 5);
    }

    #[test]
    fn test_empty_grid_has_single_empty_combination() {
        let grid = ParameterGrid::new();
        let expected: Vec<HashMap<String, ParameterValue>> = vec![HashMap::new()];
        assert_eq!(grid.total_combinations(), 1);
        assert_eq!(grid.generate_all(), expected);
        let empty: HashMap<String, ParameterValue> = HashMap::new();
        assert_eq!(grid.compute_at_index(0), Some(empty));
        assert_eq!(grid.compute_at_index(1), None);
    }

    #[test]
    fn test_compute_at_index_single_parameter() {
        let grid = ParameterGrid::new().add_int_range("period", 1, 3, 1);
        let all = grid.generate_all();
        for (idx, combo) in all.iter().enumerate() {
            assert_eq!(grid.compute_at_index(idx).as_ref(), Some(combo));
        }
        assert_eq!(grid.compute_at_index(all.len()), None);
    }

    #[test]
    fn test_extract_int_range() {
        let grid = ParameterGrid::new().add_int_range("period", 5, 15, 5);
        assert_eq!(grid.extract_int_range("period"), vec![5, 10, 15]);
        assert!(grid.extract_int_range("missing").is_empty());
    }

    #[test]
    fn test_parse_range_strings() {
        assert_eq!(parse_range_str("10,30,5"), Ok((10, 30, 5)));
        assert_eq!(
            parse_range_str("10,30"),
            Err("Range format: min,max,step".to_string())
        );
        assert_eq!(parse_range_str("a,30,5"), Err("Invalid min value".to_string()));
        assert_eq!(parse_float_range_str("0.5,2.0,0.25"), Ok((0.5, 2.0, 0.25)));
    }

    #[test]
    fn test_parse_ma_types_str() {
        assert_eq!(
            parse_ma_types_str("sma, EMA ,unknown"),
            vec![MovingAverageType::SMA, MovingAverageType::EMA]
        );
    }
}