// finance_query/backtesting/strategy/ensemble.rs

//! Ensemble strategy that combines multiple strategies with configurable voting modes.
//!
//! An ensemble aggregates signals from multiple sub-strategies and resolves them
//! using one of four voting modes (see [`EnsembleMode`]).
//!
//! # Example
//!
//! ```no_run
//! use finance_query::backtesting::{EnsembleStrategy, EnsembleMode, SmaCrossover, RsiReversal, MacdSignal};
//!
//! let strategy = EnsembleStrategy::new("Multi-Signal")
//!     .add(SmaCrossover::new(10, 50), 1.0)
//!     .add(RsiReversal::default(), 0.5)
//!     .add(MacdSignal::default(), 1.0)
//!     .mode(EnsembleMode::WeightedMajority)
//!     .build();
//! ```

19use crate::indicators::Indicator;
20
21use super::{Signal, Strategy, StrategyContext};
22use crate::backtesting::signal::{SignalDirection, SignalStrength};
23
/// Voting mode that determines how sub-strategy signals are combined.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EnsembleMode {
    /// All active sub-strategies must agree on the same direction ("no-dissent" semantics).
    ///
    /// Strategies that return `Hold` abstain from the vote and do not block a consensus.
    /// If any two *active* strategies disagree, the ensemble returns Hold. If every
    /// strategy abstains, it also returns Hold.
    ///
    /// The emitted signal is the first active signal in insertion order. Its `strength`
    /// is replaced by the average strength of all active voters, and its `reason` is
    /// rewritten as `"Unanimous (k of n agree): …"`. Its other fields (including any
    /// `scale_fraction`) are inherited unchanged.
    ///
    /// **Note**: abstainers never veto. If you require *every* sub-strategy to
    /// explicitly vote for the same direction, this mode cannot express that.
    /// Instead, implement a custom [`Strategy`] that evaluates each sub-strategy
    /// itself and returns Hold unless all of them emit the same non-Hold direction.
    Unanimous,

    /// Conviction-weighted vote: each vote is `weight × strength`.
    ///
    /// All five directions (`Long`, `Short`, `Exit`, `ScaleIn`, `ScaleOut`) are
    /// tallied independently. The direction whose score is *strictly* highest wins.
    /// A tie for first place yields Hold, and so does an ensemble whose weights sum
    /// to (approximately) zero or in which no counted vote carries positive conviction.
    ///
    /// **Strength denominator**: the output `strength` is
    /// `winner_score / Σ(all_weights)`, not `winner_score / Σ(active_scores)`.
    /// Dividing by total potential prevents a lone weak voter from being
    /// artificially amplified when the majority of sub-strategies abstain.
    ///
    /// **Scale fraction**: when `ScaleIn` or `ScaleOut` wins, the emitted
    /// `scale_fraction` is the conviction-weighted average of all same-direction
    /// voters' fractions, not the single highest-score contributor's fraction.
    ///
    /// **Position guard**: `Exit`, `ScaleIn`, and `ScaleOut` are only tallied
    /// when a position is currently open. While flat they are discarded so they
    /// cannot suppress entry signal strength. Their weights still count toward
    /// the total-potential denominator.
    ///
    /// **Reason**: the emitted `reason` is replaced by a summary of every
    /// direction's score.
    ///
    /// **Note on vote splitting**: `Exit` and `Short` are counted as independent
    /// factions. If their combined intent would dominate but each falls below
    /// `Long` individually, the ensemble may maintain a long position despite the
    /// majority wanting to exit. Document your ensemble weights accordingly.
    #[default]
    WeightedMajority,

    /// First non-Hold signal wins (strategies are evaluated in insertion order).
    ///
    /// The winning signal is returned unmodified. Strategies added after it are
    /// not evaluated on that candle.
    ///
    /// **Note**: this gives a permanent priority advantage to strategies added
    /// first via [`add`](EnsembleStrategy::add). Use [`StrongestSignal`](Self::StrongestSignal)
    /// to choose by conviction instead of by position in the list.
    AnySignal,

    /// The non-Hold signal with the highest `strength` value wins.
    ///
    /// The winning signal is returned unmodified. Exact strength ties are resolved
    /// in favour of the strategy added *last*, so the result is not completely
    /// independent of insertion order.
    StrongestSignal,
}

75/// A strategy that aggregates signals from multiple sub-strategies.
76///
77/// Build with the fluent builder methods [`add`](Self::add), [`mode`](Self::mode),
78/// then finalise with [`build`](Self::build).
79///
80/// All six [`SignalDirection`](crate::backtesting::SignalDirection) variants are
81/// fully supported. In [`EnsembleMode::WeightedMajority`], `ScaleIn` and `ScaleOut`
82/// participate in the vote with the same position guard as `Exit` — they are only
83/// tallied when a position is open. In `Unanimous`, `AnySignal`, and
84/// `StrongestSignal` modes they are treated like any other non-Hold direction.
85pub struct EnsembleStrategy {
86    name: String,
87    strategies: Vec<(Box<dyn Strategy>, f64)>,
88    mode: EnsembleMode,
89}
90
91impl EnsembleStrategy {
92    /// Create a new ensemble with the given name.
93    ///
94    /// The default voting mode is [`EnsembleMode::WeightedMajority`].
95    pub fn new(name: impl Into<String>) -> Self {
96        Self {
97            name: name.into(),
98            strategies: Vec::new(),
99            mode: EnsembleMode::default(),
100        }
101    }
102
103    /// Add a sub-strategy with the given weight.
104    ///
105    /// Weight is only meaningful for [`EnsembleMode::WeightedMajority`]; other
106    /// modes ignore it. Negative weights are treated as zero.
107    pub fn add<S: Strategy + 'static>(mut self, strategy: S, weight: f64) -> Self {
108        self.strategies.push((Box::new(strategy), weight.max(0.0)));
109        self
110    }
111
112    /// Set the voting mode.
113    pub fn mode(mut self, mode: EnsembleMode) -> Self {
114        self.mode = mode;
115        self
116    }
117
118    /// Finalise the ensemble. Returns `self` (all configuration happens in the
119    /// builder methods).
120    pub fn build(self) -> Self {
121        self
122    }
123
124    // ── voting helpers ────────────────────────────────────────────────────────
125
126    fn any_signal(&self, ctx: &StrategyContext) -> Signal {
127        for (strategy, _) in &self.strategies {
128            let signal = strategy.on_candle(ctx);
129            if !signal.is_hold() {
130                return signal;
131            }
132        }
133        Signal::hold()
134    }
135
136    fn unanimous(&self, ctx: &StrategyContext) -> Signal {
137        // Evaluated iteratively — no allocations in the hot path.
138        // As soon as any two active sub-strategies disagree we bail out early.
139        let mut first_dir: Option<SignalDirection> = None;
140        let mut first_signal: Option<Signal> = None;
141        let mut total_strength = 0.0_f64;
142        let mut active_count = 0_usize;
143
144        for (strategy, _) in &self.strategies {
145            let signal = strategy.on_candle(ctx);
146            if signal.is_hold() {
147                continue;
148            }
149            match first_dir {
150                None => {
151                    first_dir = Some(signal.direction);
152                    total_strength = signal.strength.value();
153                    first_signal = Some(signal);
154                    active_count = 1;
155                }
156                Some(dir) if dir == signal.direction => {
157                    total_strength += signal.strength.value();
158                    active_count += 1;
159                }
160                _ => return Signal::hold(), // disagreement — short-circuit
161            }
162        }
163
164        let Some(mut sig) = first_signal else {
165            return Signal::hold();
166        };
167
168        let dir = first_dir.unwrap();
169        let avg_strength = total_strength / active_count as f64;
170        let original_reason = sig.reason.take();
171        sig.strength = SignalStrength::clamped(avg_strength);
172        sig.reason = Some(format!(
173            "Unanimous ({} of {} agree): {}",
174            active_count,
175            self.strategies.len(),
176            original_reason.as_deref().unwrap_or(&dir.to_string())
177        ));
178        sig
179    }
180
181    fn weighted_majority(&self, ctx: &StrategyContext) -> Signal {
182        // Denominator = sum of ALL strategy weights (total potential conviction).
183        // Prevents a lone weak voter from being artificially amplified when
184        // the majority of sub-strategies abstain.
185        let total_potential: f64 = self.strategies.iter().map(|(_, w)| *w).sum();
186        if total_potential < f64::EPSILON {
187            return Signal::hold();
188        }
189
190        let mut long_weight = 0.0_f64;
191        let mut short_weight = 0.0_f64;
192        let mut exit_weight = 0.0_f64;
193        let mut scale_in_weight = 0.0_f64;
194        let mut scale_out_weight = 0.0_f64;
195
196        // Σ(scale_fraction × score) for conviction-weighted average fraction.
197        let mut scale_in_frac_score = 0.0_f64;
198        let mut scale_out_frac_score = 0.0_f64;
199
200        // Track (signal, score) so we inherit metadata from the highest-conviction
201        // contributor, not merely the first one encountered.
202        let mut best_long: Option<(Signal, f64)> = None;
203        let mut best_short: Option<(Signal, f64)> = None;
204        let mut best_exit: Option<(Signal, f64)> = None;
205        let mut best_scale_in: Option<(Signal, f64)> = None;
206        let mut best_scale_out: Option<(Signal, f64)> = None;
207
208        let has_position = ctx.has_position();
209
210        for (strategy, weight) in &self.strategies {
211            let signal = strategy.on_candle(ctx);
212            // Vote score = static weight × dynamic conviction
213            let score = weight * signal.strength.value();
214            match signal.direction {
215                SignalDirection::Long => {
216                    long_weight += score;
217                    if best_long.as_ref().is_none_or(|&(_, s)| score > s) {
218                        best_long = Some((signal, score));
219                    }
220                }
221                SignalDirection::Short => {
222                    short_weight += score;
223                    if best_short.as_ref().is_none_or(|&(_, s)| score > s) {
224                        best_short = Some((signal, score));
225                    }
226                }
227                // Exit, ScaleIn, ScaleOut require an open position — while flat
228                // they are discarded so they cannot suppress entry signal strength.
229                // Their weights still count toward total_potential (denominator).
230                SignalDirection::Exit if has_position => {
231                    exit_weight += score;
232                    if best_exit.as_ref().is_none_or(|&(_, s)| score > s) {
233                        best_exit = Some((signal, score));
234                    }
235                }
236                SignalDirection::ScaleIn if has_position => {
237                    let frac = signal.scale_fraction.unwrap_or(0.0);
238                    scale_in_weight += score;
239                    scale_in_frac_score += frac * score;
240                    if best_scale_in.as_ref().is_none_or(|&(_, s)| score > s) {
241                        best_scale_in = Some((signal, score));
242                    }
243                }
244                SignalDirection::ScaleOut if has_position => {
245                    let frac = signal.scale_fraction.unwrap_or(0.0);
246                    scale_out_weight += score;
247                    scale_out_frac_score += frac * score;
248                    if best_scale_out.as_ref().is_none_or(|&(_, s)| score > s) {
249                        best_scale_out = Some((signal, score));
250                    }
251                }
252                _ => {}
253            }
254        }
255
256        // At least one strategy must have cast a non-Hold vote.
257        let total_active =
258            long_weight + short_weight + exit_weight + scale_in_weight + scale_out_weight;
259        if total_active < f64::EPSILON {
260            return Signal::hold();
261        }
262
263        // Determine winner (strict majority across all five directions; ties → Hold)
264        let (winner, winner_score) = if long_weight > short_weight
265            && long_weight > exit_weight
266            && long_weight > scale_in_weight
267            && long_weight > scale_out_weight
268        {
269            (best_long, long_weight)
270        } else if short_weight > long_weight
271            && short_weight > exit_weight
272            && short_weight > scale_in_weight
273            && short_weight > scale_out_weight
274        {
275            (best_short, short_weight)
276        } else if exit_weight > long_weight
277            && exit_weight > short_weight
278            && exit_weight > scale_in_weight
279            && exit_weight > scale_out_weight
280        {
281            (best_exit, exit_weight)
282        } else if scale_in_weight > long_weight
283            && scale_in_weight > short_weight
284            && scale_in_weight > exit_weight
285            && scale_in_weight > scale_out_weight
286        {
287            (best_scale_in, scale_in_weight)
288        } else if scale_out_weight > long_weight
289            && scale_out_weight > short_weight
290            && scale_out_weight > exit_weight
291            && scale_out_weight > scale_in_weight
292        {
293            (best_scale_out, scale_out_weight)
294        } else {
295            return Signal::hold();
296        };
297
298        let Some((mut sig, _)) = winner else {
299            return Signal::hold();
300        };
301
302        // Strength = winner score / total potential (not just active votes).
303        sig.strength = SignalStrength::clamped(winner_score / total_potential);
304
305        // Replace inherited scale_fraction with the conviction-weighted average
306        // of all same-direction voters to avoid single-contributor size bias.
307        if sig.direction == SignalDirection::ScaleIn && scale_in_weight > f64::EPSILON {
308            sig.scale_fraction = Some((scale_in_frac_score / scale_in_weight).clamp(0.0, 1.0));
309        } else if sig.direction == SignalDirection::ScaleOut && scale_out_weight > f64::EPSILON {
310            sig.scale_fraction = Some((scale_out_frac_score / scale_out_weight).clamp(0.0, 1.0));
311        }
312
313        sig.reason = Some(format!(
314            "WeightedMajority: long={long_weight:.2} short={short_weight:.2} \
315             exit={exit_weight:.2} scale_in={scale_in_weight:.2} scale_out={scale_out_weight:.2}"
316        ));
317        sig
318    }
319
320    fn strongest_signal(&self, ctx: &StrategyContext) -> Signal {
321        self.strategies
322            .iter()
323            .map(|(s, _)| s.on_candle(ctx))
324            .filter(|s| !s.is_hold())
325            .max_by(|a, b| a.strength.value().total_cmp(&b.strength.value()))
326            .unwrap_or_else(Signal::hold)
327    }
328}
329
330impl std::fmt::Debug for EnsembleStrategy {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        f.debug_struct("EnsembleStrategy")
333            .field("name", &self.name)
334            .field("strategies_count", &self.strategies.len())
335            .field("mode", &self.mode)
336            .finish()
337    }
338}
339
340impl Strategy for EnsembleStrategy {
341    fn name(&self) -> &str {
342        &self.name
343    }
344
345    fn required_indicators(&self) -> Vec<(String, Indicator)> {
346        let mut indicators: Vec<(String, Indicator)> = self
347            .strategies
348            .iter()
349            .flat_map(|(s, _)| s.required_indicators())
350            .collect();
351        indicators.sort_by(|a, b| a.0.cmp(&b.0));
352        indicators.dedup_by(|a, b| a.0 == b.0);
353        indicators
354    }
355
356    fn setup(&mut self, indicators: &std::collections::HashMap<String, Vec<Option<f64>>>) {
357        for (strategy, _) in &mut self.strategies {
358            strategy.setup(indicators);
359        }
360    }
361
362    fn warmup_period(&self) -> usize {
363        self.strategies
364            .iter()
365            .map(|(s, _)| s.warmup_period())
366            .max()
367            .unwrap_or(1)
368    }
369
370    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
371        if self.strategies.is_empty() {
372            return Signal::hold();
373        }
374        match self.mode {
375            EnsembleMode::AnySignal => self.any_signal(ctx),
376            EnsembleMode::Unanimous => self.unanimous(ctx),
377            EnsembleMode::WeightedMajority => self.weighted_majority(ctx),
378            EnsembleMode::StrongestSignal => self.strongest_signal(ctx),
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::signal::SignalDirection;
    use crate::backtesting::strategy::Strategy;
    use crate::backtesting::{Position, PositionSide};
    use crate::indicators::Indicator;
    use crate::models::chart::Candle;
    use std::collections::HashMap;

    // ── fixtures ──────────────────────────────────────────────────────────────

    fn make_candle(ts: i64, price: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: price,
            high: price,
            low: price,
            close: price,
            volume: 1000,
            adj_close: None,
            provider_id: None,
        }
    }

    fn candles() -> Vec<Candle> {
        vec![make_candle(1, 100.0), make_candle(2, 101.0)]
    }

    fn empty_indicators() -> HashMap<String, Vec<Option<f64>>> {
        HashMap::new()
    }

    /// Context pointing at the last candle, optionally holding `position`.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
        position: Option<&'a Position>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: candles.len() - 1,
            position,
            equity: 10_000.0,
            indicators,
        }
    }

    /// An open long position, so that position-guarded votes (Exit/ScaleIn/ScaleOut) count.
    fn open_long() -> Position {
        Position::new(
            PositionSide::Long,
            1,
            100.0,
            10.0,
            0.0,
            Signal::long(1, 100.0),
        )
    }

    /// Evaluate `ensemble` once on the fixture candles with no open position.
    fn eval_flat(ensemble: &EnsembleStrategy) -> Signal {
        let c = candles();
        let ind = empty_indicators();
        ensemble.on_candle(&make_ctx(&c, &ind, None))
    }

    /// Evaluate `ensemble` once on the fixture candles while a long position is open.
    fn eval_with_position(ensemble: &EnsembleStrategy) -> Signal {
        let c = candles();
        let ind = empty_indicators();
        let pos = open_long();
        ensemble.on_candle(&make_ctx(&c, &ind, Some(&pos)))
    }

    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// A strategy that always emits `direction` with the given `strength`.
    ///
    /// `ScaleIn` signals carry a 10% fraction and `ScaleOut` signals 50%; `Hold`
    /// produces `Signal::hold()`.
    struct FixedStrategy {
        direction: SignalDirection,
        strength: f64,
    }

    fn fixed(direction: SignalDirection, strength: f64) -> FixedStrategy {
        FixedStrategy {
            direction,
            strength,
        }
    }

    impl Strategy for FixedStrategy {
        fn name(&self) -> &str {
            "Fixed"
        }
        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![]
        }
        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            let (ts, px) = (ctx.timestamp(), ctx.close());
            let mut s = match self.direction {
                SignalDirection::Long => Signal::long(ts, px),
                SignalDirection::Short => Signal::short(ts, px),
                SignalDirection::Exit => Signal::exit(ts, px),
                SignalDirection::ScaleIn => Signal::scale_in(0.1, ts, px),
                SignalDirection::ScaleOut => Signal::scale_out(0.5, ts, px),
                _ => return Signal::hold(),
            };
            s.strength = SignalStrength::clamped(self.strength);
            s
        }
    }

    // ── AnySignal ─────────────────────────────────────────────────────────────

    #[test]
    fn test_any_signal_returns_first_non_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.8), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        assert_eq!(eval_flat(&ensemble).direction, SignalDirection::Long);
    }

    #[test]
    fn test_any_signal_all_hold_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        assert!(eval_flat(&ensemble).is_hold());
    }

    // ── Unanimous ─────────────────────────────────────────────────────────────

    #[test]
    fn test_unanimous_all_agree() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.6), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = eval_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!(approx(signal.strength.value(), 0.8)); // avg of 1.0 and 0.6
    }

    #[test]
    fn test_unanimous_ignores_abstainers() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.5), 1.0)
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = eval_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!(approx(signal.strength.value(), 0.75));
        assert!(signal
            .reason
            .as_deref()
            .is_some_and(|r| r.starts_with("Unanimous (2 of 3 agree)")));
    }

    #[test]
    fn test_unanimous_disagreement_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        assert!(eval_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_unanimous_all_hold_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        assert!(eval_flat(&ensemble).is_hold());
    }

    // ── WeightedMajority ──────────────────────────────────────────────────────

    #[test]
    fn test_weighted_majority_long_wins() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = eval_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        // strength = 2.0 / 3.0
        assert!(approx(signal.strength.value(), 2.0 / 3.0));
    }

    #[test]
    fn test_weighted_majority_tie_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        assert!(eval_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_weighted_majority_zero_or_negative_weights_return_hold() {
        // A negative weight is clamped to zero, so the total potential is zero.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 0.0)
            .add(fixed(SignalDirection::Short, 1.0), -2.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        assert!(eval_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_weighted_majority_exit_wins_when_position_open() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Exit, 1.0), 2.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = eval_with_position(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Exit);
        assert!(approx(signal.strength.value(), 2.0 / 3.0));
    }

    #[test]
    fn test_weighted_majority_exit_ignored_when_flat() {
        // Exit weight would dominate if counted (3.0 vs Long 2.0), but while flat
        // it must be discarded and Long should win.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Exit, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = eval_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        // strength = 2.0 / 5.0 = 0.4 (exit weight counts in denominator even when
        // its vote is discarded, correctly suppressing overconfidence)
        assert!(approx(signal.strength.value(), 2.0 / 5.0));
    }

    #[test]
    fn test_weighted_majority_scale_in_wins_when_position_open() {
        // ScaleIn has the highest conviction score (3.0) while Long has 2.0.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = eval_with_position(&ensemble);
        assert_eq!(signal.direction, SignalDirection::ScaleIn);
        // strength = 3.0 / 5.0
        assert!(approx(signal.strength.value(), 0.6));
        // scale_fraction is the conviction-weighted average of all ScaleIn voters
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!(approx(frac, 0.1), "expected 0.10, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_fraction_is_conviction_weighted_average() {
        // Strategy A: ScaleOut 10% with score 1.0 × 1.0 = 1.0
        // Strategy B: ScaleOut 50% with score 1.0 × 1.0 = 1.0
        // Weighted-average fraction = (0.10 × 1.0 + 0.50 × 1.0) / 2.0 = 0.30
        struct ScaleOutStrategy {
            fraction: f64,
        }
        impl Strategy for ScaleOutStrategy {
            fn name(&self) -> &str {
                "ScaleOut"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, ctx: &StrategyContext) -> Signal {
                Signal::scale_out(self.fraction, ctx.timestamp(), ctx.close())
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(ScaleOutStrategy { fraction: 0.10 }, 1.0)
            .add(ScaleOutStrategy { fraction: 0.50 }, 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = eval_with_position(&ensemble);
        assert_eq!(signal.direction, SignalDirection::ScaleOut);
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!(approx(frac, 0.30), "expected 0.30, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_in_ignored_when_flat() {
        // ScaleIn would dominate (3.0 vs Long 2.0) but must be ignored while flat.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = eval_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        // strength = 2.0 / 5.0 = 0.4 (scale_in weight counts in denominator even
        // when its vote is discarded while flat)
        assert!(approx(signal.strength.value(), 2.0 / 5.0));
    }

    // ── StrongestSignal ───────────────────────────────────────────────────────

    #[test]
    fn test_strongest_signal() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 0.4), 1.0)
            .add(fixed(SignalDirection::Short, 0.9), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        let signal = eval_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Short);
        assert!(approx(signal.strength.value(), 0.9));
    }

    #[test]
    fn test_strongest_signal_tie_prefers_later_strategy() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 0.7), 1.0)
            .add(fixed(SignalDirection::Short, 0.7), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        assert_eq!(eval_flat(&ensemble).direction, SignalDirection::Short);
    }

    #[test]
    fn test_strongest_signal_all_hold_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        assert!(eval_flat(&ensemble).is_hold());
    }

    // ── Strategy plumbing ─────────────────────────────────────────────────────

    #[test]
    fn test_empty_ensemble_returns_hold() {
        let ensemble = EnsembleStrategy::new("empty").build();
        assert!(eval_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_warmup_is_max_of_sub_strategies() {
        struct WarmupStrategy(usize);
        impl Strategy for WarmupStrategy {
            fn name(&self) -> &str {
                "Warmup"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
            fn warmup_period(&self) -> usize {
                self.0
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(WarmupStrategy(10), 1.0)
            .add(WarmupStrategy(25), 1.0)
            .add(WarmupStrategy(5), 1.0)
            .build();

        assert_eq!(ensemble.warmup_period(), 25);
    }

    #[test]
    fn test_required_indicators_deduplication() {
        struct IndStrategy(Vec<(String, Indicator)>);
        impl Strategy for IndStrategy {
            fn name(&self) -> &str {
                "Ind"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                self.0.clone()
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(
                IndStrategy(vec![
                    ("sma_10".to_string(), Indicator::Sma(10)),
                    ("sma_20".to_string(), Indicator::Sma(20)),
                ]),
                1.0,
            )
            .add(
                IndStrategy(vec![
                    ("sma_20".to_string(), Indicator::Sma(20)), // duplicate
                    ("rsi_14".to_string(), Indicator::Rsi(14)),
                ]),
                1.0,
            )
            .build();

        let indicators = ensemble.required_indicators();
        let keys: Vec<&str> = indicators.iter().map(|(k, _)| k.as_str()).collect();
        // Deduplicated and sorted by key.
        assert_eq!(keys, ["rsi_14", "sma_10", "sma_20"]);
    }
}