Skip to main content

finance_query/backtesting/strategy/
ensemble.rs

//! Ensemble strategy that combines multiple strategies with configurable voting modes.
//!
//! An ensemble aggregates signals from multiple sub-strategies and resolves them
//! using one of four voting modes (see [`EnsembleMode`]).
//!
//! # Example
//!
//! ```no_run
//! use finance_query::backtesting::{EnsembleStrategy, EnsembleMode, SmaCrossover, RsiReversal, MacdSignal};
//!
//! let strategy = EnsembleStrategy::new("Multi-Signal")
//!     .add(SmaCrossover::new(10, 50), 1.0)
//!     .add(RsiReversal::default(), 0.5)
//!     .add(MacdSignal::default(), 1.0)
//!     .mode(EnsembleMode::WeightedMajority)
//!     .build();
//! ```
18
19use crate::indicators::Indicator;
20
21use super::{Signal, Strategy, StrategyContext};
22use crate::backtesting::signal::{SignalDirection, SignalStrength};
23
/// Voting mode that determines how sub-strategy signals are combined.
///
/// In every mode a sub-strategy that returns `Hold` abstains, and the ensemble
/// returns `Hold` when no sub-strategy casts an effective vote.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EnsembleMode {
    /// All active sub-strategies must agree on the same direction ("no-dissent" semantics).
    ///
    /// Strategies that return `Hold` abstain from the vote and do not block a consensus.
    /// If any two *active* strategies disagree, the ensemble returns Hold.
    /// The resulting signal strength is the average of all active strengths. All other
    /// fields are inherited from the first active strategy (in insertion order), except
    /// `reason`, which reports the agreement count, e.g. `"Unanimous (2 of 3 agree): …"`.
    ///
    /// **Note**: abstainers never count against a consensus. If you require *every*
    /// sub-strategy to explicitly vote for the same direction, wrap the sub-strategies
    /// in a custom [`Strategy`] impl that returns Hold as soon as any of them holds.
    Unanimous,

    /// Conviction-weighted vote: each vote is `weight × signal_strength`.
    ///
    /// All five directions (`Long`, `Short`, `Exit`, `ScaleIn`, `ScaleOut`) are
    /// tallied independently. The winning direction must score *strictly* higher
    /// than every other direction; a tie for the top score yields Hold.
    ///
    /// **Strength denominator**: the output `signal_strength` is
    /// `winner_score / Σ(all_weights)`, not `winner_score / Σ(active_scores)`.
    /// Dividing by total potential prevents a lone weak voter from being
    /// artificially amplified when the majority of sub-strategies abstain.
    ///
    /// **Metadata**: the emitted signal inherits its remaining fields from the
    /// highest-scoring contributor in the winning direction (the earliest-added one
    /// on equal scores); its `reason` is replaced with a per-direction score summary.
    ///
    /// **Scale fraction**: when `ScaleIn` or `ScaleOut` wins, the emitted
    /// `scale_fraction` is the conviction-weighted average of all same-direction
    /// voters' fractions, not the single highest-score contributor's fraction.
    ///
    /// **Position guard**: `Exit`, `ScaleIn`, and `ScaleOut` are only tallied
    /// when a position is currently open. While flat they are discarded so they
    /// cannot suppress entry signal strength. Their weights still count toward
    /// the total-potential denominator.
    ///
    /// **Note on vote splitting**: `Exit` and `Short` are counted as independent
    /// factions. If their combined intent would dominate but each falls below
    /// `Long` individually, the ensemble may maintain a long position despite the
    /// majority wanting to exit. Choose your ensemble weights accordingly.
    #[default]
    WeightedMajority,

    /// First non-Hold signal wins (strategies are evaluated in insertion order, and
    /// strategies after the winner are not evaluated at all).
    ///
    /// **Note**: this gives a permanent priority advantage to strategies added
    /// first via [`add`](EnsembleStrategy::add). Use [`StrongestSignal`](Self::StrongestSignal)
    /// if you want the winner chosen by conviction rather than by position (exact
    /// strength ties are still resolved by insertion order there).
    AnySignal,

    /// The non-Hold signal with the highest `signal_strength` value wins.
    ///
    /// Every sub-strategy is evaluated. On an exact strength tie the most recently
    /// added of the tied strategies wins.
    StrongestSignal,
}
74
75/// A strategy that aggregates signals from multiple sub-strategies.
76///
77/// Build with the fluent builder methods [`add`](Self::add), [`mode`](Self::mode),
78/// then finalise with [`build`](Self::build).
79///
80/// All six [`SignalDirection`](crate::backtesting::SignalDirection) variants are
81/// fully supported. In [`EnsembleMode::WeightedMajority`], `ScaleIn` and `ScaleOut`
82/// participate in the vote with the same position guard as `Exit` — they are only
83/// tallied when a position is open. In `Unanimous`, `AnySignal`, and
84/// `StrongestSignal` modes they are treated like any other non-Hold direction.
85pub struct EnsembleStrategy {
86    name: String,
87    strategies: Vec<(Box<dyn Strategy>, f64)>,
88    mode: EnsembleMode,
89}
90
91impl EnsembleStrategy {
92    /// Create a new ensemble with the given name.
93    ///
94    /// The default voting mode is [`EnsembleMode::WeightedMajority`].
95    pub fn new(name: impl Into<String>) -> Self {
96        Self {
97            name: name.into(),
98            strategies: Vec::new(),
99            mode: EnsembleMode::default(),
100        }
101    }
102
103    /// Add a sub-strategy with the given weight.
104    ///
105    /// Weight is only meaningful for [`EnsembleMode::WeightedMajority`]; other
106    /// modes ignore it. Negative weights are treated as zero.
107    pub fn add<S: Strategy + 'static>(mut self, strategy: S, weight: f64) -> Self {
108        self.strategies.push((Box::new(strategy), weight.max(0.0)));
109        self
110    }
111
112    /// Set the voting mode.
113    pub fn mode(mut self, mode: EnsembleMode) -> Self {
114        self.mode = mode;
115        self
116    }
117
118    /// Finalise the ensemble. Returns `self` (all configuration happens in the
119    /// builder methods).
120    pub fn build(self) -> Self {
121        self
122    }
123
124    // ── voting helpers ────────────────────────────────────────────────────────
125
126    fn any_signal(&self, ctx: &StrategyContext) -> Signal {
127        for (strategy, _) in &self.strategies {
128            let signal = strategy.on_candle(ctx);
129            if !signal.is_hold() {
130                return signal;
131            }
132        }
133        Signal::hold()
134    }
135
136    fn unanimous(&self, ctx: &StrategyContext) -> Signal {
137        // Evaluated iteratively — no allocations in the hot path.
138        // As soon as any two active sub-strategies disagree we bail out early.
139        let mut first_dir: Option<SignalDirection> = None;
140        let mut first_signal: Option<Signal> = None;
141        let mut total_strength = 0.0_f64;
142        let mut active_count = 0_usize;
143
144        for (strategy, _) in &self.strategies {
145            let signal = strategy.on_candle(ctx);
146            if signal.is_hold() {
147                continue;
148            }
149            match first_dir {
150                None => {
151                    first_dir = Some(signal.direction);
152                    total_strength = signal.strength.value();
153                    first_signal = Some(signal);
154                    active_count = 1;
155                }
156                Some(dir) if dir == signal.direction => {
157                    total_strength += signal.strength.value();
158                    active_count += 1;
159                }
160                _ => return Signal::hold(), // disagreement — short-circuit
161            }
162        }
163
164        let Some(mut sig) = first_signal else {
165            return Signal::hold();
166        };
167
168        let dir = first_dir.unwrap();
169        let avg_strength = total_strength / active_count as f64;
170        let original_reason = sig.reason.take();
171        sig.strength = SignalStrength::clamped(avg_strength);
172        sig.reason = Some(format!(
173            "Unanimous ({} of {} agree): {}",
174            active_count,
175            self.strategies.len(),
176            original_reason.as_deref().unwrap_or(&dir.to_string())
177        ));
178        sig
179    }
180
181    fn weighted_majority(&self, ctx: &StrategyContext) -> Signal {
182        // Denominator = sum of ALL strategy weights (total potential conviction).
183        // Prevents a lone weak voter from being artificially amplified when
184        // the majority of sub-strategies abstain.
185        let total_potential: f64 = self.strategies.iter().map(|(_, w)| *w).sum();
186        if total_potential < f64::EPSILON {
187            return Signal::hold();
188        }
189
190        let mut long_weight = 0.0_f64;
191        let mut short_weight = 0.0_f64;
192        let mut exit_weight = 0.0_f64;
193        let mut scale_in_weight = 0.0_f64;
194        let mut scale_out_weight = 0.0_f64;
195
196        // Σ(scale_fraction × score) for conviction-weighted average fraction.
197        let mut scale_in_frac_score = 0.0_f64;
198        let mut scale_out_frac_score = 0.0_f64;
199
200        // Track (signal, score) so we inherit metadata from the highest-conviction
201        // contributor, not merely the first one encountered.
202        let mut best_long: Option<(Signal, f64)> = None;
203        let mut best_short: Option<(Signal, f64)> = None;
204        let mut best_exit: Option<(Signal, f64)> = None;
205        let mut best_scale_in: Option<(Signal, f64)> = None;
206        let mut best_scale_out: Option<(Signal, f64)> = None;
207
208        let has_position = ctx.has_position();
209
210        for (strategy, weight) in &self.strategies {
211            let signal = strategy.on_candle(ctx);
212            // Vote score = static weight × dynamic conviction
213            let score = weight * signal.strength.value();
214            match signal.direction {
215                SignalDirection::Long => {
216                    long_weight += score;
217                    if best_long.as_ref().is_none_or(|&(_, s)| score > s) {
218                        best_long = Some((signal, score));
219                    }
220                }
221                SignalDirection::Short => {
222                    short_weight += score;
223                    if best_short.as_ref().is_none_or(|&(_, s)| score > s) {
224                        best_short = Some((signal, score));
225                    }
226                }
227                // Exit, ScaleIn, ScaleOut require an open position — while flat
228                // they are discarded so they cannot suppress entry signal strength.
229                // Their weights still count toward total_potential (denominator).
230                SignalDirection::Exit if has_position => {
231                    exit_weight += score;
232                    if best_exit.as_ref().is_none_or(|&(_, s)| score > s) {
233                        best_exit = Some((signal, score));
234                    }
235                }
236                SignalDirection::ScaleIn if has_position => {
237                    let frac = signal.scale_fraction.unwrap_or(0.0);
238                    scale_in_weight += score;
239                    scale_in_frac_score += frac * score;
240                    if best_scale_in.as_ref().is_none_or(|&(_, s)| score > s) {
241                        best_scale_in = Some((signal, score));
242                    }
243                }
244                SignalDirection::ScaleOut if has_position => {
245                    let frac = signal.scale_fraction.unwrap_or(0.0);
246                    scale_out_weight += score;
247                    scale_out_frac_score += frac * score;
248                    if best_scale_out.as_ref().is_none_or(|&(_, s)| score > s) {
249                        best_scale_out = Some((signal, score));
250                    }
251                }
252                _ => {}
253            }
254        }
255
256        // At least one strategy must have cast a non-Hold vote.
257        let total_active =
258            long_weight + short_weight + exit_weight + scale_in_weight + scale_out_weight;
259        if total_active < f64::EPSILON {
260            return Signal::hold();
261        }
262
263        // Determine winner (strict majority across all five directions; ties → Hold)
264        let (winner, winner_score) = if long_weight > short_weight
265            && long_weight > exit_weight
266            && long_weight > scale_in_weight
267            && long_weight > scale_out_weight
268        {
269            (best_long, long_weight)
270        } else if short_weight > long_weight
271            && short_weight > exit_weight
272            && short_weight > scale_in_weight
273            && short_weight > scale_out_weight
274        {
275            (best_short, short_weight)
276        } else if exit_weight > long_weight
277            && exit_weight > short_weight
278            && exit_weight > scale_in_weight
279            && exit_weight > scale_out_weight
280        {
281            (best_exit, exit_weight)
282        } else if scale_in_weight > long_weight
283            && scale_in_weight > short_weight
284            && scale_in_weight > exit_weight
285            && scale_in_weight > scale_out_weight
286        {
287            (best_scale_in, scale_in_weight)
288        } else if scale_out_weight > long_weight
289            && scale_out_weight > short_weight
290            && scale_out_weight > exit_weight
291            && scale_out_weight > scale_in_weight
292        {
293            (best_scale_out, scale_out_weight)
294        } else {
295            return Signal::hold();
296        };
297
298        let Some((mut sig, _)) = winner else {
299            return Signal::hold();
300        };
301
302        // Strength = winner score / total potential (not just active votes).
303        sig.strength = SignalStrength::clamped(winner_score / total_potential);
304
305        // Replace inherited scale_fraction with the conviction-weighted average
306        // of all same-direction voters to avoid single-contributor size bias.
307        if sig.direction == SignalDirection::ScaleIn && scale_in_weight > f64::EPSILON {
308            sig.scale_fraction = Some((scale_in_frac_score / scale_in_weight).clamp(0.0, 1.0));
309        } else if sig.direction == SignalDirection::ScaleOut && scale_out_weight > f64::EPSILON {
310            sig.scale_fraction = Some((scale_out_frac_score / scale_out_weight).clamp(0.0, 1.0));
311        }
312
313        sig.reason = Some(format!(
314            "WeightedMajority: long={long_weight:.2} short={short_weight:.2} \
315             exit={exit_weight:.2} scale_in={scale_in_weight:.2} scale_out={scale_out_weight:.2}"
316        ));
317        sig
318    }
319
320    fn strongest_signal(&self, ctx: &StrategyContext) -> Signal {
321        self.strategies
322            .iter()
323            .map(|(s, _)| s.on_candle(ctx))
324            .filter(|s| !s.is_hold())
325            .max_by(|a, b| a.strength.value().total_cmp(&b.strength.value()))
326            .unwrap_or_else(Signal::hold)
327    }
328}
329
330impl std::fmt::Debug for EnsembleStrategy {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        f.debug_struct("EnsembleStrategy")
333            .field("name", &self.name)
334            .field("strategies_count", &self.strategies.len())
335            .field("mode", &self.mode)
336            .finish()
337    }
338}
339
340impl Strategy for EnsembleStrategy {
341    fn name(&self) -> &str {
342        &self.name
343    }
344
345    fn required_indicators(&self) -> Vec<(String, Indicator)> {
346        let mut indicators: Vec<(String, Indicator)> = self
347            .strategies
348            .iter()
349            .flat_map(|(s, _)| s.required_indicators())
350            .collect();
351        indicators.sort_by(|a, b| a.0.cmp(&b.0));
352        indicators.dedup_by(|a, b| a.0 == b.0);
353        indicators
354    }
355
356    fn setup(&mut self, indicators: &std::collections::HashMap<String, Vec<Option<f64>>>) {
357        for (strategy, _) in &mut self.strategies {
358            strategy.setup(indicators);
359        }
360    }
361
362    fn warmup_period(&self) -> usize {
363        self.strategies
364            .iter()
365            .map(|(s, _)| s.warmup_period())
366            .max()
367            .unwrap_or(1)
368    }
369
370    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
371        if self.strategies.is_empty() {
372            return Signal::hold();
373        }
374        match self.mode {
375            EnsembleMode::AnySignal => self.any_signal(ctx),
376            EnsembleMode::Unanimous => self.unanimous(ctx),
377            EnsembleMode::WeightedMajority => self.weighted_majority(ctx),
378            EnsembleMode::StrongestSignal => self.strongest_signal(ctx),
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::signal::SignalDirection;
    use crate::backtesting::strategy::Strategy;
    use crate::indicators::Indicator;
    use crate::models::chart::Candle;
    use std::collections::HashMap;

    // ── fixtures ──────────────────────────────────────────────────────────────

    fn make_candle(ts: i64, price: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: price,
            high: price,
            low: price,
            close: price,
            volume: 1000,
            adj_close: None,
        }
    }

    /// Context on the last candle with no open position.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: candles.len() - 1,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    /// Context on the last candle with `position` open.
    fn make_ctx_with_position<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
        position: &'a crate::backtesting::Position,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: candles.len() - 1,
            position: Some(position),
            equity: 10_000.0,
            indicators,
        }
    }

    /// A long position opened at 100.0, used by tests that need an open position.
    fn open_long_position() -> crate::backtesting::Position {
        crate::backtesting::Position::new(
            crate::backtesting::PositionSide::Long,
            1,
            100.0,
            10.0,
            0.0,
            Signal::long(1, 100.0),
        )
    }

    fn candles() -> Vec<Candle> {
        vec![make_candle(1, 100.0), make_candle(2, 101.0)]
    }

    fn empty_indicators() -> HashMap<String, Vec<Option<f64>>> {
        HashMap::new()
    }

    /// A strategy that always emits the given direction at the given strength.
    ///
    /// `ScaleIn` signals carry a 10% fraction, `ScaleOut` signals 50%.
    struct FixedStrategy {
        direction: SignalDirection,
        strength: f64,
    }

    impl Strategy for FixedStrategy {
        fn name(&self) -> &str {
            "Fixed"
        }
        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![]
        }
        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            let (ts, price) = (ctx.timestamp(), ctx.close());
            let mut signal = match self.direction {
                SignalDirection::Long => Signal::long(ts, price),
                SignalDirection::Short => Signal::short(ts, price),
                SignalDirection::Exit => Signal::exit(ts, price),
                SignalDirection::ScaleIn => Signal::scale_in(0.1, ts, price),
                SignalDirection::ScaleOut => Signal::scale_out(0.5, ts, price),
                _ => return Signal::hold(),
            };
            signal.strength = SignalStrength::clamped(self.strength);
            signal
        }
    }

    /// Shorthand constructor for `FixedStrategy`.
    fn fixed(direction: SignalDirection, strength: f64) -> FixedStrategy {
        FixedStrategy {
            direction,
            strength,
        }
    }

    // ── AnySignal ─────────────────────────────────────────────────────────────

    #[test]
    fn test_any_signal_returns_first_non_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.8), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
    }

    #[test]
    fn test_any_signal_all_hold_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    // ── Unanimous ─────────────────────────────────────────────────────────────

    #[test]
    fn test_unanimous_all_agree() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.6), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 0.8).abs() < 1e-9); // avg of 1.0 and 0.6
    }

    #[test]
    fn test_unanimous_disagreement_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert!(signal.is_hold());
    }

    #[test]
    fn test_unanimous_abstainers_do_not_block_consensus() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.5), 1.0)
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        // Only the two active voters are averaged.
        assert!((signal.strength.value() - 0.75).abs() < 1e-9);
        let reason = signal.reason.as_deref().expect("reason must be set");
        assert!(
            reason.starts_with("Unanimous (2 of 3 agree)"),
            "unexpected reason: {reason}"
        );
    }

    #[test]
    fn test_unanimous_all_hold_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    // ── WeightedMajority ──────────────────────────────────────────────────────

    #[test]
    fn test_weighted_majority_long_wins() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        // strength = 2.0 / 3.0
        assert!((signal.strength.value() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_tie_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert!(signal.is_hold());
    }

    #[test]
    fn test_weighted_majority_negative_weight_treated_as_zero() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        // The Long voter's -5.0 weight collapses to 0.0, so Short wins outright.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), -5.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Short);
        assert!((signal.strength.value() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_all_zero_weights_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 0.0)
            .add(fixed(SignalDirection::Short, 1.0), 0.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    #[test]
    fn test_weighted_majority_exit_ignored_when_flat() {
        // Exit votes should not suppress Long conviction when there is no position.
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind); // position = None

        // Exit weight would dominate if counted (3.0 vs Long 2.0), but while flat
        // it must be discarded and Long should win.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Exit, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        // strength = 2.0 / 5.0 = 0.4 (exit weight counts in denominator even when
        // its vote is discarded, correctly suppressing overconfidence)
        assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_exit_wins_when_position_open() {
        let c = candles();
        let ind = empty_indicators();
        let pos = open_long_position();
        let ctx = make_ctx_with_position(&c, &ind, &pos);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Exit, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Exit);
        // strength = 3.0 / 5.0
        assert!((signal.strength.value() - 0.6).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_scale_in_wins_when_position_open() {
        let c = candles();
        let ind = empty_indicators();
        let pos = open_long_position();
        let ctx = make_ctx_with_position(&c, &ind, &pos);

        // ScaleIn has the highest conviction score (3.0) while Long has 2.0.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::ScaleIn);
        // strength = 3.0 / 5.0
        assert!((signal.strength.value() - 0.6).abs() < 1e-9);
        // scale_fraction is the conviction-weighted average of all ScaleIn voters
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!((frac - 0.1).abs() < 1e-9, "expected 0.10, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_fraction_is_conviction_weighted_average() {
        let c = candles();
        let ind = empty_indicators();
        let pos = open_long_position();
        let ctx = make_ctx_with_position(&c, &ind, &pos);

        // Strategy A: ScaleOut 10% with score 1.0 × 1.0 = 1.0
        // Strategy B: ScaleOut 50% with score 1.0 × 1.0 = 1.0
        // Weighted-average fraction = (0.10 × 1.0 + 0.50 × 1.0) / 2.0 = 0.30
        struct ScaleOutStrategy {
            fraction: f64,
        }
        impl Strategy for ScaleOutStrategy {
            fn name(&self) -> &str {
                "ScaleOut"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, ctx: &StrategyContext) -> Signal {
                Signal::scale_out(self.fraction, ctx.timestamp(), ctx.close())
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(ScaleOutStrategy { fraction: 0.10 }, 1.0)
            .add(ScaleOutStrategy { fraction: 0.50 }, 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::ScaleOut);
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!((frac - 0.30).abs() < 1e-9, "expected 0.30, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_in_ignored_when_flat() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind); // position = None

        // ScaleIn would dominate (3.0 vs Long 2.0) but must be ignored while flat.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        // strength = 2.0 / 5.0 = 0.4 (scale_in weight counts in denominator even
        // when its vote is discarded while flat)
        assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
    }

    // ── StrongestSignal ───────────────────────────────────────────────────────

    #[test]
    fn test_strongest_signal() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 0.4), 1.0)
            .add(fixed(SignalDirection::Short, 0.9), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Short);
        assert!((signal.strength.value() - 0.9).abs() < 1e-9);
    }

    #[test]
    fn test_strongest_signal_tie_goes_to_last_added() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        // `Iterator::max_by` yields the last of equal maxima.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 0.7), 1.0)
            .add(fixed(SignalDirection::Short, 0.7), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Short);
    }

    // ── structural behaviour ──────────────────────────────────────────────────

    #[test]
    fn test_empty_ensemble_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("empty").build();
        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    #[test]
    fn test_warmup_is_max_of_sub_strategies() {
        struct WarmupStrategy(usize);
        impl Strategy for WarmupStrategy {
            fn name(&self) -> &str {
                "Warmup"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
            fn warmup_period(&self) -> usize {
                self.0
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(WarmupStrategy(10), 1.0)
            .add(WarmupStrategy(25), 1.0)
            .add(WarmupStrategy(5), 1.0)
            .build();

        assert_eq!(ensemble.warmup_period(), 25);
    }

    #[test]
    fn test_required_indicators_deduplication() {
        struct IndStrategy(Vec<(String, Indicator)>);
        impl Strategy for IndStrategy {
            fn name(&self) -> &str {
                "Ind"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                self.0.clone()
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(
                IndStrategy(vec![
                    ("sma_10".to_string(), Indicator::Sma(10)),
                    ("sma_20".to_string(), Indicator::Sma(20)),
                ]),
                1.0,
            )
            .add(
                IndStrategy(vec![
                    ("sma_20".to_string(), Indicator::Sma(20)), // duplicate
                    ("rsi_14".to_string(), Indicator::Rsi(14)),
                ]),
                1.0,
            )
            .build();

        let indicators = ensemble.required_indicators();
        assert_eq!(indicators.len(), 3);
        assert!(indicators.iter().any(|(k, _)| k == "sma_10"));
        assert!(indicators.iter().any(|(k, _)| k == "sma_20"));
        assert!(indicators.iter().any(|(k, _)| k == "rsi_14"));
    }
}