1use crate::indicators::Indicator;
20
21use super::{Signal, Strategy, StrategyContext};
22use crate::backtesting::signal::{SignalDirection, SignalStrength};
23
/// How an [`EnsembleStrategy`] turns its sub-strategies' signals into one.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EnsembleMode {
    /// Emit a signal only if the sub-strategies that vote all agree on its
    /// direction. Sub-strategies that hold abstain; any disagreement yields a
    /// hold. The emitted strength is the mean strength of the agreeing votes.
    Unanimous,

    /// Weighted vote (the default mode).
    ///
    /// Each sub-strategy adds `weight * strength` to the direction it signals;
    /// exit and scale-in / scale-out votes only count while a position is
    /// open. A direction wins only with a strictly larger total than every
    /// other direction (ties hold). The emitted strength is the winning total
    /// divided by the sum of all weights, and for scale-in / scale-out the
    /// fraction is the score-weighted mean of the winning votes' fractions.
    #[default]
    WeightedMajority,

    /// Follow the first sub-strategy, in the order they were added, that
    /// emits a non-hold signal the ensemble can act on.
    AnySignal,

    /// Follow the single highest-strength signal among the sub-strategies.
    StrongestSignal,
}
74
75pub struct EnsembleStrategy {
86 name: String,
87 strategies: Vec<(Box<dyn Strategy>, f64)>,
88 mode: EnsembleMode,
89}
90
91impl EnsembleStrategy {
92 pub fn new(name: impl Into<String>) -> Self {
96 Self {
97 name: name.into(),
98 strategies: Vec::new(),
99 mode: EnsembleMode::default(),
100 }
101 }
102
103 pub fn add<S: Strategy + 'static>(mut self, strategy: S, weight: f64) -> Self {
108 self.strategies.push((Box::new(strategy), weight.max(0.0)));
109 self
110 }
111
112 pub fn mode(mut self, mode: EnsembleMode) -> Self {
114 self.mode = mode;
115 self
116 }
117
118 pub fn build(self) -> Self {
121 self
122 }
123
124 fn any_signal(&self, ctx: &StrategyContext) -> Signal {
127 for (strategy, _) in &self.strategies {
128 let signal = strategy.on_candle(ctx);
129 if !signal.is_hold() {
130 return signal;
131 }
132 }
133 Signal::hold()
134 }
135
136 fn unanimous(&self, ctx: &StrategyContext) -> Signal {
137 let mut first_dir: Option<SignalDirection> = None;
140 let mut first_signal: Option<Signal> = None;
141 let mut total_strength = 0.0_f64;
142 let mut active_count = 0_usize;
143
144 for (strategy, _) in &self.strategies {
145 let signal = strategy.on_candle(ctx);
146 if signal.is_hold() {
147 continue;
148 }
149 match first_dir {
150 None => {
151 first_dir = Some(signal.direction);
152 total_strength = signal.strength.value();
153 first_signal = Some(signal);
154 active_count = 1;
155 }
156 Some(dir) if dir == signal.direction => {
157 total_strength += signal.strength.value();
158 active_count += 1;
159 }
160 _ => return Signal::hold(), }
162 }
163
164 let Some(mut sig) = first_signal else {
165 return Signal::hold();
166 };
167
168 let dir = first_dir.unwrap();
169 let avg_strength = total_strength / active_count as f64;
170 let original_reason = sig.reason.take();
171 sig.strength = SignalStrength::clamped(avg_strength);
172 sig.reason = Some(format!(
173 "Unanimous ({} of {} agree): {}",
174 active_count,
175 self.strategies.len(),
176 original_reason.as_deref().unwrap_or(&dir.to_string())
177 ));
178 sig
179 }
180
181 fn weighted_majority(&self, ctx: &StrategyContext) -> Signal {
182 let total_potential: f64 = self.strategies.iter().map(|(_, w)| *w).sum();
186 if total_potential < f64::EPSILON {
187 return Signal::hold();
188 }
189
190 let mut long_weight = 0.0_f64;
191 let mut short_weight = 0.0_f64;
192 let mut exit_weight = 0.0_f64;
193 let mut scale_in_weight = 0.0_f64;
194 let mut scale_out_weight = 0.0_f64;
195
196 let mut scale_in_frac_score = 0.0_f64;
198 let mut scale_out_frac_score = 0.0_f64;
199
200 let mut best_long: Option<(Signal, f64)> = None;
203 let mut best_short: Option<(Signal, f64)> = None;
204 let mut best_exit: Option<(Signal, f64)> = None;
205 let mut best_scale_in: Option<(Signal, f64)> = None;
206 let mut best_scale_out: Option<(Signal, f64)> = None;
207
208 let has_position = ctx.has_position();
209
210 for (strategy, weight) in &self.strategies {
211 let signal = strategy.on_candle(ctx);
212 let score = weight * signal.strength.value();
214 match signal.direction {
215 SignalDirection::Long => {
216 long_weight += score;
217 if best_long.as_ref().is_none_or(|&(_, s)| score > s) {
218 best_long = Some((signal, score));
219 }
220 }
221 SignalDirection::Short => {
222 short_weight += score;
223 if best_short.as_ref().is_none_or(|&(_, s)| score > s) {
224 best_short = Some((signal, score));
225 }
226 }
227 SignalDirection::Exit if has_position => {
231 exit_weight += score;
232 if best_exit.as_ref().is_none_or(|&(_, s)| score > s) {
233 best_exit = Some((signal, score));
234 }
235 }
236 SignalDirection::ScaleIn if has_position => {
237 let frac = signal.scale_fraction.unwrap_or(0.0);
238 scale_in_weight += score;
239 scale_in_frac_score += frac * score;
240 if best_scale_in.as_ref().is_none_or(|&(_, s)| score > s) {
241 best_scale_in = Some((signal, score));
242 }
243 }
244 SignalDirection::ScaleOut if has_position => {
245 let frac = signal.scale_fraction.unwrap_or(0.0);
246 scale_out_weight += score;
247 scale_out_frac_score += frac * score;
248 if best_scale_out.as_ref().is_none_or(|&(_, s)| score > s) {
249 best_scale_out = Some((signal, score));
250 }
251 }
252 _ => {}
253 }
254 }
255
256 let total_active =
258 long_weight + short_weight + exit_weight + scale_in_weight + scale_out_weight;
259 if total_active < f64::EPSILON {
260 return Signal::hold();
261 }
262
263 let (winner, winner_score) = if long_weight > short_weight
265 && long_weight > exit_weight
266 && long_weight > scale_in_weight
267 && long_weight > scale_out_weight
268 {
269 (best_long, long_weight)
270 } else if short_weight > long_weight
271 && short_weight > exit_weight
272 && short_weight > scale_in_weight
273 && short_weight > scale_out_weight
274 {
275 (best_short, short_weight)
276 } else if exit_weight > long_weight
277 && exit_weight > short_weight
278 && exit_weight > scale_in_weight
279 && exit_weight > scale_out_weight
280 {
281 (best_exit, exit_weight)
282 } else if scale_in_weight > long_weight
283 && scale_in_weight > short_weight
284 && scale_in_weight > exit_weight
285 && scale_in_weight > scale_out_weight
286 {
287 (best_scale_in, scale_in_weight)
288 } else if scale_out_weight > long_weight
289 && scale_out_weight > short_weight
290 && scale_out_weight > exit_weight
291 && scale_out_weight > scale_in_weight
292 {
293 (best_scale_out, scale_out_weight)
294 } else {
295 return Signal::hold();
296 };
297
298 let Some((mut sig, _)) = winner else {
299 return Signal::hold();
300 };
301
302 sig.strength = SignalStrength::clamped(winner_score / total_potential);
304
305 if sig.direction == SignalDirection::ScaleIn && scale_in_weight > f64::EPSILON {
308 sig.scale_fraction = Some((scale_in_frac_score / scale_in_weight).clamp(0.0, 1.0));
309 } else if sig.direction == SignalDirection::ScaleOut && scale_out_weight > f64::EPSILON {
310 sig.scale_fraction = Some((scale_out_frac_score / scale_out_weight).clamp(0.0, 1.0));
311 }
312
313 sig.reason = Some(format!(
314 "WeightedMajority: long={long_weight:.2} short={short_weight:.2} \
315 exit={exit_weight:.2} scale_in={scale_in_weight:.2} scale_out={scale_out_weight:.2}"
316 ));
317 sig
318 }
319
320 fn strongest_signal(&self, ctx: &StrategyContext) -> Signal {
321 self.strategies
322 .iter()
323 .map(|(s, _)| s.on_candle(ctx))
324 .filter(|s| !s.is_hold())
325 .max_by(|a, b| a.strength.value().total_cmp(&b.strength.value()))
326 .unwrap_or_else(Signal::hold)
327 }
328}
329
330impl std::fmt::Debug for EnsembleStrategy {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("EnsembleStrategy")
333 .field("name", &self.name)
334 .field("strategies_count", &self.strategies.len())
335 .field("mode", &self.mode)
336 .finish()
337 }
338}
339
340impl Strategy for EnsembleStrategy {
341 fn name(&self) -> &str {
342 &self.name
343 }
344
345 fn required_indicators(&self) -> Vec<(String, Indicator)> {
346 let mut indicators: Vec<(String, Indicator)> = self
347 .strategies
348 .iter()
349 .flat_map(|(s, _)| s.required_indicators())
350 .collect();
351 indicators.sort_by(|a, b| a.0.cmp(&b.0));
352 indicators.dedup_by(|a, b| a.0 == b.0);
353 indicators
354 }
355
356 fn warmup_period(&self) -> usize {
357 self.strategies
358 .iter()
359 .map(|(s, _)| s.warmup_period())
360 .max()
361 .unwrap_or(1)
362 }
363
364 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
365 if self.strategies.is_empty() {
366 return Signal::hold();
367 }
368 match self.mode {
369 EnsembleMode::AnySignal => self.any_signal(ctx),
370 EnsembleMode::Unanimous => self.unanimous(ctx),
371 EnsembleMode::WeightedMajority => self.weighted_majority(ctx),
372 EnsembleMode::StrongestSignal => self.strongest_signal(ctx),
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::signal::SignalDirection;
    use crate::backtesting::strategy::Strategy;
    use crate::backtesting::{Position, PositionSide};
    use crate::indicators::Indicator;
    use crate::models::chart::Candle;
    use std::collections::HashMap;

    type IndicatorMap = HashMap<String, Vec<Option<f64>>>;

    /// A flat candle whose open, high, low and close all equal `price`.
    fn make_candle(ts: i64, price: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: price,
            high: price,
            low: price,
            close: price,
            volume: 1000,
            adj_close: None,
        }
    }

    /// The two-candle history every test evaluates on.
    fn candles() -> Vec<Candle> {
        [(1, 100.0), (2, 101.0)]
            .iter()
            .map(|&(ts, price)| make_candle(ts, price))
            .collect()
    }

    /// No precomputed indicators; none of the test strategies read any.
    fn empty_indicators() -> IndicatorMap {
        IndicatorMap::new()
    }

    /// A context on the last candle with 10 000 equity and `position`.
    fn context<'a>(
        candles: &'a [Candle],
        indicators: &'a IndicatorMap,
        position: Option<&'a Position>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: candles.len() - 1,
            position,
            equity: 10_000.0,
            indicators,
        }
    }

    /// A long position of 10 units opened at 100.0 on the first candle.
    fn open_long_position() -> Position {
        Position::new(
            PositionSide::Long,
            1,
            100.0,
            10.0,
            0.0,
            Signal::long(1, 100.0),
        )
    }

    /// Evaluates `ensemble` on the last candle with no open position.
    fn run_flat(ensemble: &EnsembleStrategy) -> Signal {
        let history = candles();
        let indicators = empty_indicators();
        let ctx = context(&history, &indicators, None);
        ensemble.on_candle(&ctx)
    }

    /// Evaluates `ensemble` on the last candle while holding a long position.
    fn run_in_position(ensemble: &EnsembleStrategy) -> Signal {
        let history = candles();
        let indicators = empty_indicators();
        let position = open_long_position();
        let ctx = context(&history, &indicators, Some(&position));
        ensemble.on_candle(&ctx)
    }

    /// Always emits `direction` at `strength`; scale signals use fixed
    /// fractions (0.1 to scale in, 0.5 to scale out).
    struct FixedStrategy {
        direction: SignalDirection,
        strength: f64,
    }

    /// Shorthand constructor for [`FixedStrategy`].
    fn fixed(direction: SignalDirection, strength: f64) -> FixedStrategy {
        FixedStrategy {
            direction,
            strength,
        }
    }

    impl Strategy for FixedStrategy {
        fn name(&self) -> &str {
            "Fixed"
        }
        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![]
        }
        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            let (ts, px) = (ctx.timestamp(), ctx.close());
            let mut signal = match self.direction {
                SignalDirection::Long => Signal::long(ts, px),
                SignalDirection::Short => Signal::short(ts, px),
                SignalDirection::Exit => Signal::exit(ts, px),
                SignalDirection::ScaleIn => Signal::scale_in(0.1, ts, px),
                SignalDirection::ScaleOut => Signal::scale_out(0.5, ts, px),
                _ => return Signal::hold(),
            };
            signal.strength = SignalStrength::clamped(self.strength);
            signal
        }
    }

    #[test]
    fn test_any_signal_returns_first_non_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.8), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        assert_eq!(run_flat(&ensemble).direction, SignalDirection::Long);
    }

    #[test]
    fn test_unanimous_all_agree() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.6), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = run_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        // Mean of 1.0 and 0.6.
        assert!((signal.strength.value() - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_unanimous_disagreement_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        assert!(run_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_weighted_majority_long_wins() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = run_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_tie_returns_hold() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        assert!(run_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_strongest_signal() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 0.4), 1.0)
            .add(fixed(SignalDirection::Short, 0.9), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        let signal = run_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Short);
        assert!((signal.strength.value() - 0.9).abs() < 1e-9);
    }

    #[test]
    fn test_empty_ensemble_returns_hold() {
        let ensemble = EnsembleStrategy::new("empty").build();
        assert!(run_flat(&ensemble).is_hold());
    }

    #[test]
    fn test_warmup_is_max_of_sub_strategies() {
        struct WarmupStrategy(usize);
        impl Strategy for WarmupStrategy {
            fn name(&self) -> &str {
                "Warmup"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
            fn warmup_period(&self) -> usize {
                self.0
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(WarmupStrategy(10), 1.0)
            .add(WarmupStrategy(25), 1.0)
            .add(WarmupStrategy(5), 1.0)
            .build();

        assert_eq!(ensemble.warmup_period(), 25);
    }

    #[test]
    fn test_weighted_majority_exit_ignored_when_flat() {
        // The heavier exit vote is discarded while flat, but its weight still
        // counts toward the strength denominator (2 / (2 + 3)).
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Exit, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = run_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_scale_in_wins_when_position_open() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = run_in_position(&ensemble);
        assert_eq!(signal.direction, SignalDirection::ScaleIn);
        assert!((signal.strength.value() - 0.6).abs() < 1e-9);
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!((frac - 0.1).abs() < 1e-9, "expected 0.10, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_fraction_is_conviction_weighted_average() {
        struct ScaleOutStrategy {
            fraction: f64,
        }
        impl Strategy for ScaleOutStrategy {
            fn name(&self) -> &str {
                "ScaleOut"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, ctx: &StrategyContext) -> Signal {
                Signal::scale_out(self.fraction, ctx.timestamp(), ctx.close())
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(ScaleOutStrategy { fraction: 0.10 }, 1.0)
            .add(ScaleOutStrategy { fraction: 0.50 }, 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = run_in_position(&ensemble);
        assert_eq!(signal.direction, SignalDirection::ScaleOut);
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!((frac - 0.30).abs() < 1e-9, "expected 0.30, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_in_ignored_when_flat() {
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = run_flat(&ensemble);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_required_indicators_deduplication() {
        struct IndStrategy(Vec<(String, Indicator)>);
        impl Strategy for IndStrategy {
            fn name(&self) -> &str {
                "Ind"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                self.0.clone()
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(
                IndStrategy(vec![
                    ("sma_10".to_string(), Indicator::Sma(10)),
                    ("sma_20".to_string(), Indicator::Sma(20)),
                ]),
                1.0,
            )
            .add(
                IndStrategy(vec![
                    ("sma_20".to_string(), Indicator::Sma(20)),
                    ("rsi_14".to_string(), Indicator::Rsi(14)),
                ]),
                1.0,
            )
            .build();

        let indicators = ensemble.required_indicators();
        assert_eq!(indicators.len(), 3);
        for key in ["sma_10", "sma_20", "rsi_14"] {
            assert!(indicators.iter().any(|(k, _)| k == key));
        }
    }
}