use super::metrics::{Metrics, compute_metrics};
use super::strategy::{BacktestResult, Strategy, run_backtest};
/// Evaluates `run_fn` for every parameter in parallel and scores each resulting
/// return series with [`compute_metrics`].
///
/// Every element of `params` is handed to `run_fn`. The per-period return series
/// it produces is passed to [`compute_metrics`] together with `periods_per_year`
/// and `risk_free`. The output holds exactly one entry per parameter, in the same
/// order as `params`, because rayon's indexed `collect` preserves input order.
/// An entry is `None` whenever [`compute_metrics`] yields `None` for that series.
///
/// Only compiled with the `parallel` feature.
#[cfg(feature = "parallel")]
pub fn sweep<F, P>(
    params: &[P],
    periods_per_year: f64,
    risk_free: f64,
    run_fn: F,
) -> Vec<Option<Metrics>>
where
    F: Fn(&P) -> Vec<f64> + Sync,
    P: Sync,
{
    use rayon::prelude::*;

    // Work done for a single parameter; captures `run_fn` and the scalars by reference.
    let evaluate = |param: &P| -> Option<Metrics> {
        let series = run_fn(param);
        compute_metrics(&series, periods_per_year, risk_free)
    };

    params.par_iter().map(evaluate).collect()
}
/// Runs one backtest per parameter in parallel, building a fresh strategy for each.
///
/// For every element of `params`, `make_strategy` constructs a strategy. That
/// strategy is passed by reference to [`run_backtest`] together with:
/// - the shared `price_series`,
/// - `initial_cash`,
/// - `cost_model` (passed by value to each run),
/// - `periods_per_year`,
/// - `risk_free`.
///
/// The returned vector has one [`BacktestResult`] per parameter, in the same
/// order as `params`.
///
/// Only compiled with the `parallel` feature.
#[cfg(feature = "parallel")]
pub fn sweep_strategy<F, P, S>(
    params: &[P],
    price_series: &[Vec<(crate::Symbol, i64)>],
    initial_cash: i64,
    cost_model: super::CostModel,
    periods_per_year: f64,
    risk_free: f64,
    make_strategy: F,
) -> Vec<BacktestResult>
where
    F: Fn(&P) -> S + Sync,
    P: Sync,
    S: Strategy,
{
    use rayon::prelude::*;

    // Build and backtest a single strategy instance; the strategy never leaves this closure.
    let backtest_one = |param: &P| -> BacktestResult {
        let strat = make_strategy(param);
        run_backtest(
            &strat,
            price_series,
            initial_cash,
            cost_model,
            periods_per_year,
            risk_free,
        )
    };

    params.par_iter().map(backtest_one).collect()
}
#[cfg(test)]
#[cfg(feature = "parallel")]
mod tests {
    use super::*;

    /// `sweep` yields one `Some` per parameter, and larger return scales give
    /// larger total returns (which also checks that output order matches `params`).
    #[test]
    fn sweep_basic() {
        let params = vec![1.0_f64, 2.0, 3.0];
        let results = sweep(&params, 12.0, 0.0, |&scale| {
            vec![0.01 * scale, -0.005 * scale, 0.02 * scale]
        });
        assert_eq!(results.len(), 3);
        for r in &results {
            assert!(r.is_some());
        }
        let r1 = results[0].as_ref().unwrap().total_return;
        let r2 = results[1].as_ref().unwrap().total_return;
        let r3 = results[2].as_ref().unwrap().total_return;
        assert!(r2 > r1);
        assert!(r3 > r2);
    }

    /// Results must stay aligned with `params` even when there are more
    /// parameters than worker threads. All returns are positive and scale with
    /// the parameter, so total returns must be strictly increasing.
    #[test]
    fn sweep_preserves_param_order() {
        let params: Vec<f64> = (1..=16_i32).map(f64::from).collect();
        let results = sweep(&params, 12.0, 0.0, |&scale| {
            vec![0.01 * scale, 0.005 * scale, 0.02 * scale]
        });
        assert_eq!(results.len(), params.len());
        let totals: Vec<_> = results
            .iter()
            .map(|r| {
                r.as_ref()
                    .expect("metrics for a non-empty return series")
                    .total_return
            })
            .collect();
        assert!(totals.windows(2).all(|w| w[1] > w[0]));
    }

    /// An empty parameter list produces an empty result without calling `run_fn`'s output.
    #[test]
    fn sweep_empty_params() {
        let params: Vec<f64> = vec![];
        let results = sweep(&params, 12.0, 0.0, |_: &f64| vec![0.01]);
        assert!(results.is_empty());
    }

    /// `sweep_strategy` returns one backtest result (with metrics) per parameter.
    #[test]
    fn sweep_strategy_basic() {
        use crate::Symbol;
        use crate::portfolio::{CostModel, EqualWeight};

        fn sym(s: &str) -> Symbol {
            Symbol::new(s)
        }

        let prices = vec![
            vec![(sym("A"), 100_00)],
            vec![(sym("A"), 110_00)],
            vec![(sym("A"), 105_00)],
        ];
        let params = vec![100_000_00_i64, 500_000_00, 1_000_000_00];
        let results = sweep_strategy(
            &params,
            &prices,
            1_000_000_00,
            CostModel::zero(),
            12.0,
            0.0,
            |_| EqualWeight,
        );
        assert_eq!(results.len(), 3);
        for r in &results {
            assert!(r.metrics.is_some());
        }
    }
}