Skip to main content

hyper_playbook/
backtest.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5#[allow(deprecated)]
6use hyper_ta::technical_analysis::calculate_indicators;
7use hyper_ta::Candle;
8
9use crate::engine::{PlaybookEngine, TickAction};
10
11// ---------------------------------------------------------------------------
12// Data types
13// ---------------------------------------------------------------------------
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17pub struct TradeRecord {
18    pub regime: String,
19    pub side: String,
20    pub entry_price: f64,
21    pub exit_price: f64,
22    pub pnl: f64,
23    pub hold_time_secs: u64,
24    pub exit_reason: String,
25}
26
/// Aggregate outcome of one `BacktestRunner::run` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BacktestResult {
    /// Number of completed trades in `trades`.
    pub total_trades: u32,
    /// Fraction of trades with strictly positive `pnl` (0.0..=1.0); 0.0 when there are no trades.
    pub win_rate: f64,
    /// Sum of every trade's `pnl`.
    pub total_pnl: f64,
    /// Largest peak-to-trough decline of cumulative PnL, reported as a non-positive number.
    pub max_drawdown: f64,
    /// Per-trade Sharpe ratio (mean PnL / sample standard deviation); not annualized,
    /// no risk-free rate.
    pub sharpe_ratio: f64,
    /// Mean `hold_time_secs` over all trades, truncated by integer division; 0 when there are no trades.
    pub avg_hold_time_secs: u64,
    /// Number of ticks spent in each regime, keyed by the regime name the engine reported.
    pub regime_distribution: HashMap<String, u32>,
    /// Every completed trade, in the order it was closed.
    pub trades: Vec<TradeRecord>,
    /// Number of engine ticks executed (one per evaluated candle).
    pub total_ticks: u32,
}
40
41// ---------------------------------------------------------------------------
42// Internal tracking
43// ---------------------------------------------------------------------------
44
/// A position whose entry has filled but which has not been closed yet.
struct OpenTrade {
    /// Regime reported on the tick where the entry filled.
    regime: String,
    /// Direction label; `BacktestRunner::run` always sets `"long"`.
    side: String,
    /// Entry price carried by the engine's fill action.
    entry_price: f64,
    /// `time` of the candle on which the fill was reported.
    entry_time: u64,
}
51
52// ---------------------------------------------------------------------------
53// BacktestRunner
54// ---------------------------------------------------------------------------
55
56pub struct BacktestRunner {
57    engine: PlaybookEngine,
58}
59
60impl BacktestRunner {
61    pub fn new(engine: PlaybookEngine) -> Self {
62        Self { engine }
63    }
64
65    /// Run backtest over historical candles.
66    /// Uses a sliding window of `window_size` candles for indicator calculation.
67    pub async fn run(&mut self, candles: &[Candle], window_size: usize) -> BacktestResult {
68        let mut trades: Vec<TradeRecord> = Vec::new();
69        let mut regime_distribution: HashMap<String, u32> = HashMap::new();
70        let mut total_ticks = 0u32;
71
72        // Track open trade state
73        let mut open_trade: Option<OpenTrade> = None;
74
75        for i in window_size..candles.len() {
76            let window = &candles[i + 1 - window_size..=i];
77            #[allow(deprecated)]
78            let indicators = calculate_indicators(window);
79            let now = candles[i].time;
80            let current_price = candles[i].close;
81
82            let tick = self.engine.tick(&indicators, now).await;
83            total_ticks += 1;
84
85            // Track regime distribution
86            *regime_distribution.entry(tick.regime.clone()).or_insert(0) += 1;
87
88            // Track trades based on TickAction
89            match &tick.action {
90                TickAction::OrderPlaced { .. } => {
91                    // Order placed but not yet filled — nothing to record yet
92                }
93                TickAction::OrderFilled { entry_price, .. } => {
94                    open_trade = Some(OpenTrade {
95                        regime: tick.regime.clone(),
96                        side: "long".into(), // simplified: PaperOrderExecutor doesn't track side
97                        entry_price: *entry_price,
98                        entry_time: now,
99                    });
100                }
101                TickAction::PositionClosed { reason } | TickAction::ForceClose { reason } => {
102                    if let Some(trade) = open_trade.take() {
103                        let pnl = if trade.side == "sell" || trade.side == "short" {
104                            trade.entry_price - current_price
105                        } else {
106                            current_price - trade.entry_price
107                        };
108                        trades.push(TradeRecord {
109                            regime: trade.regime,
110                            side: trade.side,
111                            entry_price: trade.entry_price,
112                            exit_price: current_price,
113                            pnl,
114                            hold_time_secs: now.saturating_sub(trade.entry_time),
115                            exit_reason: reason.clone(),
116                        });
117                    }
118                }
119                TickAction::OrderCancelled { .. } | TickAction::None => {}
120            }
121        }
122
123        // Calculate metrics
124        let total_trades = trades.len() as u32;
125        let wins = trades.iter().filter(|t| t.pnl > 0.0).count() as f64;
126        let win_rate = if total_trades > 0 {
127            wins / total_trades as f64
128        } else {
129            0.0
130        };
131        let total_pnl: f64 = trades.iter().map(|t| t.pnl).sum();
132        let max_drawdown = calculate_max_drawdown(&trades);
133        let sharpe_ratio = calculate_sharpe_ratio(&trades);
134        let avg_hold_time_secs = if total_trades > 0 {
135            trades.iter().map(|t| t.hold_time_secs).sum::<u64>() / total_trades as u64
136        } else {
137            0
138        };
139
140        BacktestResult {
141            total_trades,
142            win_rate,
143            total_pnl,
144            max_drawdown,
145            sharpe_ratio,
146            avg_hold_time_secs,
147            regime_distribution,
148            trades,
149            total_ticks,
150        }
151    }
152}
153
154// ---------------------------------------------------------------------------
155// Helper functions
156// ---------------------------------------------------------------------------
157
158fn calculate_max_drawdown(trades: &[TradeRecord]) -> f64 {
159    let mut equity = 0.0_f64;
160    let mut peak = 0.0_f64;
161    let mut max_dd = 0.0_f64;
162    for trade in trades {
163        equity += trade.pnl;
164        if equity > peak {
165            peak = equity;
166        }
167        let dd = peak - equity;
168        if dd > max_dd {
169            max_dd = dd;
170        }
171    }
172    -max_dd // negative number
173}
174
175fn calculate_sharpe_ratio(trades: &[TradeRecord]) -> f64 {
176    if trades.len() < 2 {
177        return 0.0;
178    }
179    let pnls: Vec<f64> = trades.iter().map(|t| t.pnl).collect();
180    let mean = pnls.iter().sum::<f64>() / pnls.len() as f64;
181    let variance = pnls.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (pnls.len() - 1) as f64;
182    let std_dev = variance.sqrt();
183    if std_dev == 0.0 {
184        return 0.0;
185    }
186    mean / std_dev
187}
188
189// ---------------------------------------------------------------------------
190// Tests
191// ---------------------------------------------------------------------------
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::PaperOrderExecutor;
    use hyper_strategy::strategy_config::{
        HysteresisConfig, Playbook, RegimeRule, StrategyGroup, TaRule,
    };

    // -- Candle generators --

    /// Generate exactly `count` synthetic candles spaced 300 seconds apart, starting at time
    /// 1_700_000_000. The first third trends up, the middle third ranges, and the last third
    /// trends down. Close prices are floored at 1000.0.
    fn generate_synthetic_candles(count: usize) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(count);
        let base_time = 1_700_000_000u64;
        let interval = 300u64;
        let mut price = 50_000.0;

        for i in 0..count {
            let phase = if i < count / 3 {
                // Phase 1: trending up
                1
            } else if i < 2 * count / 3 {
                // Phase 2: ranging
                2
            } else {
                // Phase 3: trending down
                3
            };

            let delta = match phase {
                1 => 50.0 + 20.0 * ((i as f64 * 0.1).sin()),   // uptrend
                2 => 30.0 * ((i as f64 * 0.3).sin()),          // ranging
                3 => -60.0 + 15.0 * ((i as f64 * 0.15).sin()), // downtrend
                _ => 0.0,
            };

            price += delta;
            if price < 1000.0 {
                price = 1000.0;
            }

            let open = price - delta * 0.3;
            let high = price.max(open) + 100.0;
            let low = price.min(open) - 100.0;
            let volume = 1000.0 + 500.0 * ((i as f64 * 0.05).sin()).abs();

            candles.push(Candle {
                time: base_time + (i as u64) * interval,
                open,
                high,
                low,
                close: price,
                volume,
            });
        }

        candles
    }

    // -- Trade fixtures --

    /// Build a `TradeRecord` whose only meaningful field is `pnl`. The metric helpers under
    /// test (`calculate_max_drawdown`, `calculate_sharpe_ratio`) read no other field.
    fn trade_with_pnl(pnl: f64) -> TradeRecord {
        TradeRecord {
            regime: "a".into(),
            side: "long".into(),
            entry_price: 0.0,
            exit_price: 0.0,
            pnl,
            hold_time_secs: 0,
            exit_reason: String::new(),
        }
    }

    /// One `TradeRecord` per PnL, in the given order.
    fn trades_with_pnls(pnls: &[f64]) -> Vec<TradeRecord> {
        pnls.iter().copied().map(trade_with_pnl).collect()
    }

    // -- Strategy group for tests --

    fn make_ta_rule(
        indicator: &str,
        params: Vec<f64>,
        condition: &str,
        threshold: f64,
        signal: &str,
    ) -> TaRule {
        TaRule {
            indicator: indicator.to_string(),
            params,
            condition: condition.to_string(),
            threshold,
            threshold_upper: None,
            signal: signal.to_string(),
            action: None,
        }
    }

    fn make_ta_rule_between(
        indicator: &str,
        params: Vec<f64>,
        lo: f64,
        hi: f64,
        signal: &str,
    ) -> TaRule {
        TaRule {
            indicator: indicator.to_string(),
            params,
            condition: "between".to_string(),
            threshold: lo,
            threshold_upper: Some(hi),
            signal: signal.to_string(),
            action: None,
        }
    }

    fn simple_strategy_group() -> StrategyGroup {
        let mut playbooks = HashMap::new();

        // bull playbook: entry when RSI > 60, exit when RSI < 40
        playbooks.insert(
            "bull".to_string(),
            Playbook {
                rules: vec![],
                entry_rules: vec![make_ta_rule("RSI", vec![14.0], "gt", 60.0, "buy_momentum")],
                exit_rules: vec![make_ta_rule("RSI", vec![14.0], "lt", 40.0, "momentum_lost")],
                system_prompt: "bull".into(),
                max_position_size: 1000.0,
                stop_loss_pct: Some(5.0),
                take_profit_pct: Some(10.0),
                timeout_secs: Some(600),
                side: Some("buy".into()),
            },
        );

        // bear playbook: zero position size (no trading)
        playbooks.insert(
            "bear".to_string(),
            Playbook {
                rules: vec![],
                entry_rules: vec![],
                exit_rules: vec![],
                system_prompt: "bear".into(),
                max_position_size: 0.0,
                stop_loss_pct: Some(3.0),
                take_profit_pct: None,
                timeout_secs: None,
                side: None,
            },
        );

        // neutral playbook: entry when RSI < 30, exit when RSI between 45-55
        playbooks.insert(
            "neutral".to_string(),
            Playbook {
                rules: vec![],
                entry_rules: vec![make_ta_rule("RSI", vec![14.0], "lt", 30.0, "oversold_buy")],
                exit_rules: vec![make_ta_rule_between(
                    "RSI",
                    vec![14.0],
                    45.0,
                    55.0,
                    "rsi_neutral_exit",
                )],
                system_prompt: "neutral".into(),
                max_position_size: 500.0,
                stop_loss_pct: Some(5.0),
                take_profit_pct: Some(10.0),
                timeout_secs: Some(300),
                side: None,
            },
        );

        StrategyGroup {
            id: "sg-backtest".into(),
            name: "Backtest Test".into(),
            vault_address: None,
            is_active: true,
            created_at: "2026-01-01T00:00:00Z".into(),
            symbol: "BTC-USD".into(),
            interval_secs: 300,
            regime_rules: vec![
                RegimeRule {
                    regime: "bull".into(),
                    conditions: vec![make_ta_rule("ADX", vec![14.0], "gt", 50.0, "strong_bull")],
                    priority: 1,
                },
                RegimeRule {
                    regime: "bear".into(),
                    conditions: vec![make_ta_rule("ADX", vec![14.0], "lt", 10.0, "weak_bear")],
                    priority: 2,
                },
            ],
            default_regime: "neutral".into(),
            hysteresis: HysteresisConfig {
                min_hold_secs: 0,
                confirmation_count: 1,
            },
            playbooks,
        }
    }

    fn new_engine() -> PlaybookEngine {
        PlaybookEngine::new(
            "BTC-USD".into(),
            simple_strategy_group(),
            Box::new(PaperOrderExecutor::new()),
        )
    }

    // -----------------------------------------------------------------------
    // 1. Backtest run with synthetic candles — basic smoke test
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_backtest_run_synthetic_candles() {
        let candles = generate_synthetic_candles(250);
        let engine = new_engine();
        let mut runner = BacktestRunner::new(engine);

        let result = runner.run(&candles, 50).await;

        // We should have processed candles[50..250] = 200 ticks
        assert_eq!(result.total_ticks, 200);
        assert!(
            !result.regime_distribution.is_empty(),
            "regime_distribution should have entries"
        );
    }

    // -----------------------------------------------------------------------
    // 2. BacktestResult serializes to JSON
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_backtest_result_serializes_to_json() {
        let candles = generate_synthetic_candles(250);
        let engine = new_engine();
        let mut runner = BacktestRunner::new(engine);

        let result = runner.run(&candles, 50).await;

        let json = serde_json::to_string(&result).unwrap();
        let parsed: BacktestResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.total_ticks, result.total_ticks);
        assert_eq!(parsed.total_trades, result.total_trades);
        assert!((parsed.total_pnl - result.total_pnl).abs() < 1e-10);
    }

    // -----------------------------------------------------------------------
    // 3. max_drawdown calculation with known values
    // -----------------------------------------------------------------------

    #[test]
    fn test_max_drawdown_known_values() {
        // Trades: +10, -5, +20, -30, +5
        // Equity: 10, 5, 25, -5, 0
        // Peak:   10, 10, 25, 25, 25
        // DD:     0,  5,  0,  30, 25
        // max_dd = -30
        let trades = trades_with_pnls(&[10.0, -5.0, 20.0, -30.0, 5.0]);

        let dd = calculate_max_drawdown(&trades);
        assert!(
            (dd - (-30.0)).abs() < 1e-10,
            "max drawdown should be -30.0, got {}",
            dd
        );
    }

    #[test]
    fn test_max_drawdown_empty() {
        let dd = calculate_max_drawdown(&[]);
        assert!((dd - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_max_drawdown_all_wins() {
        let trades = trades_with_pnls(&[10.0, 10.0]);
        let dd = calculate_max_drawdown(&trades);
        assert!((dd - 0.0).abs() < 1e-10, "no drawdown for all wins");
    }

    // -----------------------------------------------------------------------
    // 4. sharpe_ratio calculation with known values
    // -----------------------------------------------------------------------

    #[test]
    fn test_sharpe_ratio_known_values() {
        // PnLs: [10, -5, 20, -30, 5]
        // mean = 0.0
        // variance = (100 + 25 + 400 + 900 + 25) / 4 = 1450/4 = 362.5
        // std_dev = sqrt(362.5) ~ 19.039
        // sharpe = 0.0 / 19.039 = 0.0
        let trades = trades_with_pnls(&[10.0, -5.0, 20.0, -30.0, 5.0]);
        let sr = calculate_sharpe_ratio(&trades);
        assert!(
            sr.abs() < 1e-10,
            "sharpe should be 0 when mean is 0, got {}",
            sr
        );
    }

    #[test]
    fn test_sharpe_ratio_positive() {
        // PnLs: [10, 20, 15]
        // mean = 15.0
        // variance = (25 + 25 + 0) / 2 = 25.0
        // std_dev = 5.0
        // sharpe = 15.0 / 5.0 = 3.0
        let trades = trades_with_pnls(&[10.0, 20.0, 15.0]);
        let sr = calculate_sharpe_ratio(&trades);
        assert!((sr - 3.0).abs() < 1e-10, "sharpe should be 3.0, got {}", sr);
    }

    #[test]
    fn test_sharpe_ratio_single_trade() {
        let trades = trades_with_pnls(&[10.0]);
        let sr = calculate_sharpe_ratio(&trades);
        assert!(
            (sr - 0.0).abs() < 1e-10,
            "sharpe should be 0 for single trade"
        );
    }

    #[test]
    fn test_sharpe_ratio_empty() {
        let sr = calculate_sharpe_ratio(&[]);
        assert!((sr - 0.0).abs() < 1e-10);
    }

    // -----------------------------------------------------------------------
    // 5. Backtest with adaptive_trend template
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_backtest_with_adaptive_trend_template() {
        use hyper_strategy::strategy_templates::build_template;

        let sg = build_template("adaptive_trend", "BTC-USD").unwrap();
        let engine = PlaybookEngine::new("BTC-USD".into(), sg, Box::new(PaperOrderExecutor::new()));
        let mut runner = BacktestRunner::new(engine);

        let candles = generate_synthetic_candles(300);
        let result = runner.run(&candles, 50).await;

        assert!(result.total_ticks > 0, "should have processed ticks");
        assert_eq!(result.total_ticks, 250);
        assert!(
            !result.regime_distribution.is_empty(),
            "should have regime distribution entries"
        );

        // Verify JSON roundtrip
        let json = serde_json::to_string(&result).unwrap();
        let _: BacktestResult = serde_json::from_str(&json).unwrap();
    }

    // -----------------------------------------------------------------------
    // 6. Empty candles / edge cases
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_backtest_empty_candles() {
        let engine = new_engine();
        let mut runner = BacktestRunner::new(engine);
        let result = runner.run(&[], 50).await;

        assert_eq!(result.total_ticks, 0);
        assert_eq!(result.total_trades, 0);
        assert!((result.win_rate - 0.0).abs() < 1e-10);
        assert!((result.total_pnl - 0.0).abs() < 1e-10);
        assert!((result.max_drawdown - 0.0).abs() < 1e-10);
        assert!((result.sharpe_ratio - 0.0).abs() < 1e-10);
        assert_eq!(result.avg_hold_time_secs, 0);
    }

    #[tokio::test]
    async fn test_backtest_window_larger_than_candles() {
        let candles = generate_synthetic_candles(30);
        let engine = new_engine();
        let mut runner = BacktestRunner::new(engine);

        let result = runner.run(&candles, 50).await;
        assert_eq!(result.total_ticks, 0);
    }
}