1use crate::evaluation::{Evaluation, Metrics};
2use crate::models::WinRate;
3use crate::utils::{cumulative_returns, normalise_returns, round_float};
4use ndarray::arr1;
5
6
/// Configuration and signal series for a single backtest run.
///
/// The signal series drives everything: a value of `0.0` is treated as
/// "no position", any non-zero value as an open position, and a change
/// between two different non-zero values as closing one position and
/// opening another.
#[derive(Debug, Clone)]
pub struct Backtest {
    /// Per-period position signal, aligned index-for-index with the return
    /// series passed to `run_backtest`.
    signals: Vec<f64>,
    /// Cost subtracted from the strategy's log return on bars where a
    /// position is entered or exited.
    trading_costs: f64,
    /// Multiplier applied to the signal-weighted returns of the first asset.
    weight_asset_1: f64,
    /// Multiplier applied to the signal-weighted returns of the second asset.
    weight_asset_2: f64,
}
13
14impl Backtest {
15 pub fn new(signals: Vec<f64>, trading_costs: f64, weight_asset_1: f64, weight_asset_2: f64) -> Self {
16 Self { weight_asset_1, trading_costs, weight_asset_2, signals }
17 }
18
19 fn trade_costs(&self) -> Vec<f64> {
22 let mut trading_costs: Vec<f64> = vec![0.0; self.signals.len()];
23 for i in 1..self.signals.len() {
24 let val: f64 = self.signals[i];
25 let prev_val: f64 = self.signals[i - 1];
26
27 if val == 0.0 && prev_val != 0.0 {
29 trading_costs[i - 1] = -self.trading_costs;
30
31 } else if val != 0.0 && prev_val == 0.0 {
33 trading_costs[i] = -self.trading_costs;
34
35 } else if val != 0.0 && prev_val != 0.0 && val != prev_val {
37 trading_costs[i - 1] = -self.trading_costs;
38 trading_costs[i] = -self.trading_costs;
39 }
40 }
41 trading_costs
42 }
43
44 fn win_rate_stats(&self, log_rets: &Vec<f64>) -> WinRate {
47 let mut opened: u32 = 0;
48 let mut closed: u32 = 0;
49 let mut closed_profit: u32 = 0;
50 let mut curr_profit: f64 = 0.0;
51 let mut is_open: bool = false;
52
53 for i in 1..self.signals.len() {
54 let val: f64 = self.signals[i];
55 let prev_val: f64 = self.signals[i - 1];
56
57 if val == 0.0 && prev_val != 0.0 {
59 is_open = false;
60 closed += 1;
61 if curr_profit > 0.0 {
62 closed_profit += 1;
63 }
64 curr_profit = 0.0;
65
66 } else if val != 0.0 && prev_val == 0.0 {
68 is_open = true;
69 opened += 1;
70 curr_profit += log_rets[i];
71
72 } else if val != 0.0 && prev_val != 0.0 && val != prev_val {
74 closed += 1;
75 if curr_profit > 0.0 {
76 closed_profit += 1;
77 }
78 curr_profit += log_rets[i];
79 is_open = true;
80 opened += 1;
81
82 } else if is_open {
84 curr_profit += log_rets[i];
85 }
86 }
87
88 let mut win_rate: f64 = 0.0;
89 if closed_profit > 0 && closed > 0 {
90 win_rate = closed_profit as f64 / closed as f64;
91 }
92
93 WinRate { win_rate: round_float(win_rate, 2), opened, closed, closed_profit }
94 }
95
96 fn add_vecs(&self, vec_1: &Vec<f64>, vec_2: &Vec<f64>) -> Vec<f64> {
99 let arr_1 = arr1(&vec_1);
100 let arr_2 = arr1(&vec_2);
101 let net_arr = arr_1 + arr_2;
102 net_arr.to_vec()
103 }
104
105 fn construct_portfolio_returns(&self, log_rets: Vec<f64>, trading_costs: &Vec<f64>, sign: f64, weight: f64) -> Vec<f64> {
111
112 let rets_arr = arr1(&log_rets);
114 let sig_arr = arr1(&self.signals);
115 let strat_log_rets_arr = rets_arr * sig_arr * sign * weight;
116 let strat_log_rets = strat_log_rets_arr.to_vec();
117
118 let strat_log_rets_with_costs: Vec<f64> = self.add_vecs(&strat_log_rets, trading_costs);
120
121 strat_log_rets_with_costs
123 }
124
125 pub fn run_backtest(&self, log_rets_1: Vec<f64>, log_rets_2_opt: Option<Vec<f64>>) -> Result<Metrics, String> {
128
129 let trading_costs: Vec<f64> = self.trade_costs();
131
132 let strat_log_rets_1: Vec<f64> = self.construct_portfolio_returns(log_rets_1, &trading_costs, 1.0, self.weight_asset_1);
134
135 let log_returns: Vec<f64> = match log_rets_2_opt {
137 Some(log_rets_2) => {
138 let strat_log_rets_2: Vec<f64> = self.construct_portfolio_returns(log_rets_2, &trading_costs, -1.0, self.weight_asset_2);
139 self.add_vecs(&strat_log_rets_1, &strat_log_rets_2)
140 },
141 None => strat_log_rets_1
142 };
143
144 let strat_cum_log_rets: Vec<f64> = cumulative_returns(&log_returns);
146
147 let cum_norm_returns: Vec<f64> = normalise_returns(&strat_cum_log_rets);
149
150 let win_rate_stats: WinRate = self.win_rate_stats(&log_returns);
152
153 let evaluation: Evaluation = Evaluation::new(log_returns, cum_norm_returns, win_rate_stats);
155 let eval_metrics: Metrics = evaluation.run_evaluation_metrics();
156
157 Ok(eval_metrics)
159 }
160}
161
162
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Signals;
    use tradestats::metrics::{spread_standard, rolling_zscore};
    use csv::Reader;
    use serde::Deserialize;

    /// One row of the fixture CSV: a pair of values, one per series.
    #[derive(Debug, Deserialize)]
    struct Record {
        series_1: f64,
        series_2: f64,
    }

    /// Loads the two series from `data/data.csv`.
    ///
    /// Panics with a descriptive message if the file cannot be opened or a
    /// row cannot be deserialized into a `Record`.
    pub fn get_test_data() -> (Vec<f64>, Vec<f64>) {
        let mut rdr: Reader<std::fs::File> =
            Reader::from_path("data/data.csv").expect("data/data.csv should be readable");

        rdr.deserialize::<Record>()
            .map(|row| {
                let record = row.expect("every row of data/data.csv should parse as a Record");
                (record.series_1, record.series_2)
            })
            .unzip()
    }

    /// End-to-end smoke test: build long and short signals from a rolling
    /// z-score of the spread, net them, and check the backtest completes.
    #[test]
    fn tests_backtest() {
        let (series_1, series_2) = get_test_data();
        let log_rets_1: Vec<f64> = tradestats::utils::log_returns(&series_1, true);
        let log_rets_2: Vec<f64> = tradestats::utils::log_returns(&series_2, true);

        let spread: Vec<f64> = spread_standard(&series_1, &series_2).unwrap();
        let roll_zscore: Vec<f64> = rolling_zscore(&spread, 21).unwrap();

        let trading_costs: f64 = 0.001;
        let weighting_asset_1: f64 = 1.0;
        let weighting_asset_2: f64 = 1.0;

        let json_long_str: &str = r#"{
            "eq": [-1.5, 0.0],
            "neq": [null, null],
            "gt": [null, 0.0],
            "lt": [-1.5, null],
            "signal_type": "Long"
        }"#;

        let params: Signals = serde_json::from_str(json_long_str).unwrap();
        let signals_obj: Signals = Signals::new(params.eq, params.neq, params.gt, params.lt, params.signal_type);
        let long_signals: Vec<f64> = signals_obj.generate_signals(&roll_zscore);

        let json_short_str: &str = r#"{
            "eq": [1.5, 0.0],
            "neq": [null, null],
            "gt": [1.5, null],
            "lt": [null, 0.0],
            "signal_type": "Short"
        }"#;

        let params: Signals = serde_json::from_str(json_short_str).unwrap();
        let signals_obj: Signals = Signals::new(params.eq, params.neq, params.gt, params.lt, params.signal_type);
        let short_signals: Vec<f64> = signals_obj.generate_signals(&roll_zscore);

        let net_signals: Vec<f64> = signals_obj.consolidate_signals(vec![long_signals, short_signals]);

        let backtest: Backtest = Backtest::new(net_signals, trading_costs, weighting_asset_1, weighting_asset_2);

        // Surface the error text on failure instead of asserting on a constant.
        if let Err(err) = backtest.run_backtest(log_rets_1, Some(log_rets_2)) {
            panic!("backtest failed: {}", err);
        }
    }
}