mod bayesian;
mod grid;
pub use bayesian::BayesianSearch;
pub use grid::GridSearch;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::result::BacktestResult;
/// A single concrete value assigned to a strategy parameter.
///
/// Values are either integers or floats; use [`ParamValue::as_int`] /
/// [`ParamValue::as_float`] to read one regardless of which variant it is.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ParamValue {
/// An integer-valued parameter.
Int(i64),
/// A floating-point-valued parameter.
Float(f64),
}
impl ParamValue {
pub fn as_int(&self) -> i64 {
match self {
ParamValue::Int(v) => *v,
ParamValue::Float(v) => *v as i64,
}
}
pub fn as_float(&self) -> f64 {
match self {
ParamValue::Int(v) => *v as f64,
ParamValue::Float(v) => *v,
}
}
}
impl std::fmt::Display for ParamValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ParamValue::Int(v) => write!(f, "{v}"),
ParamValue::Float(v) => write!(f, "{v:.4}"),
}
}
}
/// The search space for one strategy parameter.
///
/// Build one with [`ParamRange::int_range`], [`ParamRange::float_range`],
/// [`ParamRange::int_bounds`], [`ParamRange::float_bounds`], or construct the
/// `Values` variant directly with an explicit candidate list.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ParamRange {
/// Integers from `start` up to `end` (inclusive), spaced by `step`.
/// A non-positive `step` yields no grid points.
IntRange {
start: i64,
end: i64,
step: i64,
},
/// Floats from `start` to `end`, spaced by `step`.
/// A `step` of `0.0` (what `float_bounds` stores) yields no grid points,
/// so such a range is only usable through sampling rather than grid expansion.
FloatRange {
start: f64,
end: f64,
step: f64,
},
/// An explicit list of candidate values.
Values(Vec<ParamValue>),
}
impl ParamRange {
    /// Creates an integer range from `start` to `end` (inclusive) spaced by `step`.
    ///
    /// A non-positive `step` makes [`expand`](Self::expand) produce no values.
    pub fn int_range(start: i64, end: i64, step: i64) -> Self {
        Self::IntRange { start, end, step }
    }

    /// Creates a float range from `start` to `end` spaced by `step`.
    ///
    /// A `step` that is not strictly positive makes [`expand`](Self::expand)
    /// produce no values.
    pub fn float_range(start: f64, end: f64, step: f64) -> Self {
        Self::FloatRange { start, end, step }
    }

    /// Creates an integer range covering every integer in `start..=end` (step 1).
    pub fn int_bounds(start: i64, end: i64) -> Self {
        Self::IntRange {
            start,
            end,
            step: 1,
        }
    }

    /// Creates continuous float bounds `start..=end`.
    ///
    /// The step is stored as `0.0`, so [`expand`](Self::expand) yields no grid
    /// points. Within this impl, such a range only produces values via
    /// [`sample_at`](Self::sample_at).
    pub fn float_bounds(start: f64, end: f64) -> Self {
        Self::FloatRange {
            start,
            end,
            step: 0.0,
        }
    }

    /// Maps a position `t` in `[0, 1]` onto this range.
    ///
    /// `t` is clamped to `[0, 1]`; a `NaN` position is treated as `0.0`. The
    /// `step` of integer and float ranges is ignored here.
    /// - `IntRange` splits `start..=end` into `end - start + 1` equal buckets, so
    ///   a uniform `t` picks every integer with equal probability; `t == 1.0`
    ///   maps to `end`.
    /// - `FloatRange` interpolates linearly between `start` and `end`.
    /// - `Values` picks the element whose bucket contains `t`. An empty list
    ///   yields `ParamValue::Int(0)`.
    pub(crate) fn sample_at(&self, t: f64) -> ParamValue {
        // `f64::clamp` passes NaN through, so map it to the lower bound explicitly.
        let t = if t.is_nan() { 0.0 } else { t.clamp(0.0, 1.0) };
        match self {
            ParamRange::IntRange { start, end, .. } => {
                // Widen to i128 so spans wider than i64::MAX (e.g. the full i64
                // domain) cannot overflow. The result is clamped back into i64.
                let (lo, hi) = (i128::from(*start), i128::from(*end));
                let span = (hi - lo) as f64;
                let offset = (t * (span + 1.0)).floor() as i128;
                let v = (lo + offset).min(hi);
                ParamValue::Int(v.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64)
            }
            ParamRange::FloatRange { start, end, .. } => {
                ParamValue::Float(start + t * (end - start))
            }
            ParamRange::Values(vals) if vals.is_empty() => ParamValue::Int(0),
            ParamRange::Values(vals) => {
                let idx = (t * vals.len() as f64).floor() as usize;
                vals[idx.min(vals.len() - 1)].clone()
            }
        }
    }

    /// Enumerates the grid points of this range.
    ///
    /// - `IntRange`: `start, start + step, …` for every value `<= end`. It is
    ///   empty when `step <= 0` or `start > end`, and it stops cleanly instead
    ///   of overflowing near `i64::MAX`.
    /// - `FloatRange`: `round((end - start) / step) + 1` points. The sequence
    ///   starts at `start` and is spaced by `step`. The final point, when it is
    ///   not the first, is replaced by `end` exactly to absorb floating-point
    ///   drift. It is empty when `step` is not strictly positive, when either
    ///   bound is non-finite/NaN, or when `start > end`.
    /// - `Values`: a copy of the explicit list, in its original order.
    pub(crate) fn expand(&self) -> Vec<ParamValue> {
        match self {
            ParamRange::IntRange { start, end, step } => {
                if *step <= 0 {
                    return vec![];
                }
                // `checked_add` ends the sequence instead of overflowing when
                // `end` is close to i64::MAX.
                std::iter::successors(Some(*start), |cur| cur.checked_add(*step))
                    .take_while(|cur| cur <= end)
                    .map(ParamValue::Int)
                    .collect()
            }
            ParamRange::FloatRange { start, end, step } => {
                let (start, end, step) = (*start, *end, *step);
                // Each comparison is false for NaN, so NaN in any field fails the
                // check. Infinite bounds would make the point count unbounded.
                let well_formed =
                    step > 0.0 && start.is_finite() && end.is_finite() && start <= end;
                if !well_formed {
                    return vec![];
                }
                let steps = ((end - start) / step).round() as usize;
                (0..=steps)
                    .map(|i| {
                        // Snap the last point onto `end` to absorb accumulated
                        // floating-point error, but never replace `start` itself.
                        let v = if i == steps && i > 0 {
                            end
                        } else {
                            start + i as f64 * step
                        };
                        ParamValue::Float(v)
                    })
                    .collect()
            }
            ParamRange::Values(vals) => vals.clone(),
        }
    }
}
/// The objective an optimisation run maximises.
///
/// Each variant maps to one field of the backtest metrics via
/// `OptimizeMetric::score`. Higher scores rank better; `MinDrawdown` negates
/// the drawdown so that a smaller drawdown scores higher.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizeMetric {
/// Scores by `metrics.total_return_pct`.
TotalReturn,
/// Scores by `metrics.sharpe_ratio`.
SharpeRatio,
/// Scores by `metrics.sortino_ratio`.
SortinoRatio,
/// Scores by `metrics.calmar_ratio`.
CalmarRatio,
/// Scores by `metrics.profit_factor`.
ProfitFactor,
/// Scores by `metrics.win_rate`.
WinRate,
/// Scores by the negated `metrics.max_drawdown_pct`.
MinDrawdown,
}
impl OptimizeMetric {
pub(crate) fn score(&self, result: &BacktestResult) -> f64 {
match self {
OptimizeMetric::TotalReturn => result.metrics.total_return_pct,
OptimizeMetric::SharpeRatio => result.metrics.sharpe_ratio,
OptimizeMetric::SortinoRatio => result.metrics.sortino_ratio,
OptimizeMetric::CalmarRatio => result.metrics.calmar_ratio,
OptimizeMetric::ProfitFactor => result.metrics.profit_factor,
OptimizeMetric::WinRate => result.metrics.win_rate,
OptimizeMetric::MinDrawdown => -result.metrics.max_drawdown_pct,
}
}
}
/// One evaluated parameter combination together with its backtest outcome.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationResult {
/// The parameter assignment that was evaluated.
/// NOTE(review): keys are presumably parameter names — confirm at the search call sites.
pub params: HashMap<String, ParamValue>,
/// The backtest produced with `params`.
pub result: BacktestResult,
}
/// Summary of a complete optimisation run.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationReport {
/// Name of the strategy that was optimised.
pub strategy_name: String,
/// Number of parameter combinations in the search space.
/// NOTE(review): for sampling-based searches this may differ from `n_evaluations` — confirm.
pub total_combinations: usize,
/// Every successfully evaluated combination.
/// NOTE(review): presumably ordered best-first via `sort_results_best_first` — confirm at construction.
pub results: Vec<OptimizationResult>,
/// The best-scoring combination under the chosen metric.
pub best: OptimizationResult,
/// How many evaluations failed and were skipped rather than aborting the run.
/// NOTE(review): inferred from the name — verify against the search implementations.
pub skipped_errors: usize,
/// Score history across evaluations.
/// NOTE(review): likely the best-so-far score after each evaluation — confirm in the search code.
pub convergence_curve: Vec<f64>,
/// Number of backtests actually run.
pub n_evaluations: usize,
}
/// Sorts `results` in place so the highest `metric` score comes first.
///
/// Results whose score is NaN are moved to the end. The sort is stable
/// (`sort_by`), so equal scores keep their original relative order.
pub(crate) fn sort_results_best_first(results: &mut [OptimizationResult], metric: OptimizeMetric) {
    use std::cmp::Ordering;

    results.sort_by(|lhs, rhs| {
        let left = metric.score(&lhs.result);
        let right = metric.score(&rhs.result);
        if left.is_nan() || right.is_nan() {
            // `false < true`, so the non-NaN side sorts first; two NaNs tie.
            left.is_nan().cmp(&right.is_nan())
        } else {
            // Compare right against left to get descending order.
            right.partial_cmp(&left).unwrap_or(Ordering::Equal)
        }
    });
}