1use crate::indicators::Indicator;
20
21use super::{Signal, Strategy, StrategyContext};
22use crate::backtesting::signal::{SignalDirection, SignalStrength};
23
/// Policy used by [`EnsembleStrategy`] to merge the signals of its sub-strategies.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EnsembleMode {
    /// Act only when every non-hold sub-strategy points in the same direction;
    /// a single dissenting direction turns the result into a hold.
    Unanimous,

    /// Add up `weight * strength` per direction and follow the direction whose
    /// total is strictly the largest. This is the default mode.
    #[default]
    WeightedMajority,

    /// Forward the first non-hold signal, in the order strategies were added.
    AnySignal,

    /// Forward the non-hold signal that carries the highest strength.
    StrongestSignal,
}
74
75pub struct EnsembleStrategy {
86 name: String,
87 strategies: Vec<(Box<dyn Strategy>, f64)>,
88 mode: EnsembleMode,
89}
90
91impl EnsembleStrategy {
92 pub fn new(name: impl Into<String>) -> Self {
96 Self {
97 name: name.into(),
98 strategies: Vec::new(),
99 mode: EnsembleMode::default(),
100 }
101 }
102
103 pub fn add<S: Strategy + 'static>(mut self, strategy: S, weight: f64) -> Self {
108 self.strategies.push((Box::new(strategy), weight.max(0.0)));
109 self
110 }
111
112 pub fn mode(mut self, mode: EnsembleMode) -> Self {
114 self.mode = mode;
115 self
116 }
117
118 pub fn build(self) -> Self {
121 self
122 }
123
124 fn any_signal(&self, ctx: &StrategyContext) -> Signal {
127 for (strategy, _) in &self.strategies {
128 let signal = strategy.on_candle(ctx);
129 if !signal.is_hold() {
130 return signal;
131 }
132 }
133 Signal::hold()
134 }
135
136 fn unanimous(&self, ctx: &StrategyContext) -> Signal {
137 let mut first_dir: Option<SignalDirection> = None;
140 let mut first_signal: Option<Signal> = None;
141 let mut total_strength = 0.0_f64;
142 let mut active_count = 0_usize;
143
144 for (strategy, _) in &self.strategies {
145 let signal = strategy.on_candle(ctx);
146 if signal.is_hold() {
147 continue;
148 }
149 match first_dir {
150 None => {
151 first_dir = Some(signal.direction);
152 total_strength = signal.strength.value();
153 first_signal = Some(signal);
154 active_count = 1;
155 }
156 Some(dir) if dir == signal.direction => {
157 total_strength += signal.strength.value();
158 active_count += 1;
159 }
160 _ => return Signal::hold(), }
162 }
163
164 let Some(mut sig) = first_signal else {
165 return Signal::hold();
166 };
167
168 let dir = first_dir.unwrap();
169 let avg_strength = total_strength / active_count as f64;
170 let original_reason = sig.reason.take();
171 sig.strength = SignalStrength::clamped(avg_strength);
172 sig.reason = Some(format!(
173 "Unanimous ({} of {} agree): {}",
174 active_count,
175 self.strategies.len(),
176 original_reason.as_deref().unwrap_or(&dir.to_string())
177 ));
178 sig
179 }
180
181 fn weighted_majority(&self, ctx: &StrategyContext) -> Signal {
182 let total_potential: f64 = self.strategies.iter().map(|(_, w)| *w).sum();
186 if total_potential < f64::EPSILON {
187 return Signal::hold();
188 }
189
190 let mut long_weight = 0.0_f64;
191 let mut short_weight = 0.0_f64;
192 let mut exit_weight = 0.0_f64;
193 let mut scale_in_weight = 0.0_f64;
194 let mut scale_out_weight = 0.0_f64;
195
196 let mut scale_in_frac_score = 0.0_f64;
198 let mut scale_out_frac_score = 0.0_f64;
199
200 let mut best_long: Option<(Signal, f64)> = None;
203 let mut best_short: Option<(Signal, f64)> = None;
204 let mut best_exit: Option<(Signal, f64)> = None;
205 let mut best_scale_in: Option<(Signal, f64)> = None;
206 let mut best_scale_out: Option<(Signal, f64)> = None;
207
208 let has_position = ctx.has_position();
209
210 for (strategy, weight) in &self.strategies {
211 let signal = strategy.on_candle(ctx);
212 let score = weight * signal.strength.value();
214 match signal.direction {
215 SignalDirection::Long => {
216 long_weight += score;
217 if best_long.as_ref().is_none_or(|&(_, s)| score > s) {
218 best_long = Some((signal, score));
219 }
220 }
221 SignalDirection::Short => {
222 short_weight += score;
223 if best_short.as_ref().is_none_or(|&(_, s)| score > s) {
224 best_short = Some((signal, score));
225 }
226 }
227 SignalDirection::Exit if has_position => {
231 exit_weight += score;
232 if best_exit.as_ref().is_none_or(|&(_, s)| score > s) {
233 best_exit = Some((signal, score));
234 }
235 }
236 SignalDirection::ScaleIn if has_position => {
237 let frac = signal.scale_fraction.unwrap_or(0.0);
238 scale_in_weight += score;
239 scale_in_frac_score += frac * score;
240 if best_scale_in.as_ref().is_none_or(|&(_, s)| score > s) {
241 best_scale_in = Some((signal, score));
242 }
243 }
244 SignalDirection::ScaleOut if has_position => {
245 let frac = signal.scale_fraction.unwrap_or(0.0);
246 scale_out_weight += score;
247 scale_out_frac_score += frac * score;
248 if best_scale_out.as_ref().is_none_or(|&(_, s)| score > s) {
249 best_scale_out = Some((signal, score));
250 }
251 }
252 _ => {}
253 }
254 }
255
256 let total_active =
258 long_weight + short_weight + exit_weight + scale_in_weight + scale_out_weight;
259 if total_active < f64::EPSILON {
260 return Signal::hold();
261 }
262
263 let (winner, winner_score) = if long_weight > short_weight
265 && long_weight > exit_weight
266 && long_weight > scale_in_weight
267 && long_weight > scale_out_weight
268 {
269 (best_long, long_weight)
270 } else if short_weight > long_weight
271 && short_weight > exit_weight
272 && short_weight > scale_in_weight
273 && short_weight > scale_out_weight
274 {
275 (best_short, short_weight)
276 } else if exit_weight > long_weight
277 && exit_weight > short_weight
278 && exit_weight > scale_in_weight
279 && exit_weight > scale_out_weight
280 {
281 (best_exit, exit_weight)
282 } else if scale_in_weight > long_weight
283 && scale_in_weight > short_weight
284 && scale_in_weight > exit_weight
285 && scale_in_weight > scale_out_weight
286 {
287 (best_scale_in, scale_in_weight)
288 } else if scale_out_weight > long_weight
289 && scale_out_weight > short_weight
290 && scale_out_weight > exit_weight
291 && scale_out_weight > scale_in_weight
292 {
293 (best_scale_out, scale_out_weight)
294 } else {
295 return Signal::hold();
296 };
297
298 let Some((mut sig, _)) = winner else {
299 return Signal::hold();
300 };
301
302 sig.strength = SignalStrength::clamped(winner_score / total_potential);
304
305 if sig.direction == SignalDirection::ScaleIn && scale_in_weight > f64::EPSILON {
308 sig.scale_fraction = Some((scale_in_frac_score / scale_in_weight).clamp(0.0, 1.0));
309 } else if sig.direction == SignalDirection::ScaleOut && scale_out_weight > f64::EPSILON {
310 sig.scale_fraction = Some((scale_out_frac_score / scale_out_weight).clamp(0.0, 1.0));
311 }
312
313 sig.reason = Some(format!(
314 "WeightedMajority: long={long_weight:.2} short={short_weight:.2} \
315 exit={exit_weight:.2} scale_in={scale_in_weight:.2} scale_out={scale_out_weight:.2}"
316 ));
317 sig
318 }
319
320 fn strongest_signal(&self, ctx: &StrategyContext) -> Signal {
321 self.strategies
322 .iter()
323 .map(|(s, _)| s.on_candle(ctx))
324 .filter(|s| !s.is_hold())
325 .max_by(|a, b| a.strength.value().total_cmp(&b.strength.value()))
326 .unwrap_or_else(Signal::hold)
327 }
328}
329
330impl std::fmt::Debug for EnsembleStrategy {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("EnsembleStrategy")
333 .field("name", &self.name)
334 .field("strategies_count", &self.strategies.len())
335 .field("mode", &self.mode)
336 .finish()
337 }
338}
339
340impl Strategy for EnsembleStrategy {
341 fn name(&self) -> &str {
342 &self.name
343 }
344
345 fn required_indicators(&self) -> Vec<(String, Indicator)> {
346 let mut indicators: Vec<(String, Indicator)> = self
347 .strategies
348 .iter()
349 .flat_map(|(s, _)| s.required_indicators())
350 .collect();
351 indicators.sort_by(|a, b| a.0.cmp(&b.0));
352 indicators.dedup_by(|a, b| a.0 == b.0);
353 indicators
354 }
355
356 fn setup(&mut self, indicators: &std::collections::HashMap<String, Vec<Option<f64>>>) {
357 for (strategy, _) in &mut self.strategies {
358 strategy.setup(indicators);
359 }
360 }
361
362 fn warmup_period(&self) -> usize {
363 self.strategies
364 .iter()
365 .map(|(s, _)| s.warmup_period())
366 .max()
367 .unwrap_or(1)
368 }
369
370 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
371 if self.strategies.is_empty() {
372 return Signal::hold();
373 }
374 match self.mode {
375 EnsembleMode::AnySignal => self.any_signal(ctx),
376 EnsembleMode::Unanimous => self.unanimous(ctx),
377 EnsembleMode::WeightedMajority => self.weighted_majority(ctx),
378 EnsembleMode::StrongestSignal => self.strongest_signal(ctx),
379 }
380 }
381}
382
383#[cfg(test)]
384mod tests {
385 use super::*;
386 use crate::backtesting::signal::SignalDirection;
387 use crate::backtesting::strategy::Strategy;
388 use crate::indicators::Indicator;
389 use crate::models::chart::Candle;
390 use std::collections::HashMap;
391
392 fn make_candle(ts: i64, price: f64) -> Candle {
393 Candle {
394 timestamp: ts,
395 open: price,
396 high: price,
397 low: price,
398 close: price,
399 volume: 1000,
400 adj_close: None,
401 }
402 }
403
404 fn make_ctx<'a>(
405 candles: &'a [Candle],
406 indicators: &'a HashMap<String, Vec<Option<f64>>>,
407 ) -> StrategyContext<'a> {
408 StrategyContext {
409 candles,
410 index: candles.len() - 1,
411 position: None,
412 equity: 10_000.0,
413 indicators,
414 }
415 }
416
417 struct FixedStrategy {
419 direction: SignalDirection,
420 strength: f64,
421 }
422
423 impl Strategy for FixedStrategy {
424 fn name(&self) -> &str {
425 "Fixed"
426 }
427 fn required_indicators(&self) -> Vec<(String, Indicator)> {
428 vec![]
429 }
430 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
431 match self.direction {
432 SignalDirection::Long => {
433 let mut s = Signal::long(ctx.timestamp(), ctx.close());
434 s.strength = SignalStrength::clamped(self.strength);
435 s
436 }
437 SignalDirection::Short => {
438 let mut s = Signal::short(ctx.timestamp(), ctx.close());
439 s.strength = SignalStrength::clamped(self.strength);
440 s
441 }
442 SignalDirection::Exit => {
443 let mut s = Signal::exit(ctx.timestamp(), ctx.close());
444 s.strength = SignalStrength::clamped(self.strength);
445 s
446 }
447 SignalDirection::ScaleIn => {
448 let mut s = Signal::scale_in(0.1, ctx.timestamp(), ctx.close());
449 s.strength = SignalStrength::clamped(self.strength);
450 s
451 }
452 SignalDirection::ScaleOut => {
453 let mut s = Signal::scale_out(0.5, ctx.timestamp(), ctx.close());
454 s.strength = SignalStrength::clamped(self.strength);
455 s
456 }
457 _ => Signal::hold(),
458 }
459 }
460 }
461
462 fn make_ctx_with_position<'a>(
463 candles: &'a [Candle],
464 indicators: &'a HashMap<String, Vec<Option<f64>>>,
465 position: &'a crate::backtesting::Position,
466 ) -> StrategyContext<'a> {
467 StrategyContext {
468 candles,
469 index: candles.len() - 1,
470 position: Some(position),
471 equity: 10_000.0,
472 indicators,
473 }
474 }
475
476 fn candles() -> Vec<Candle> {
477 vec![make_candle(1, 100.0), make_candle(2, 101.0)]
478 }
479
480 fn empty_indicators() -> HashMap<String, Vec<Option<f64>>> {
481 HashMap::new()
482 }
483
484 #[test]
485 fn test_any_signal_returns_first_non_hold() {
486 let c = candles();
487 let ind = empty_indicators();
488 let ctx = make_ctx(&c, &ind);
489
490 let ensemble = EnsembleStrategy::new("test")
491 .add(
492 FixedStrategy {
493 direction: SignalDirection::Hold,
494 strength: 1.0,
495 },
496 1.0,
497 )
498 .add(
499 FixedStrategy {
500 direction: SignalDirection::Long,
501 strength: 0.8,
502 },
503 1.0,
504 )
505 .add(
506 FixedStrategy {
507 direction: SignalDirection::Short,
508 strength: 1.0,
509 },
510 1.0,
511 )
512 .mode(EnsembleMode::AnySignal)
513 .build();
514
515 let signal = ensemble.on_candle(&ctx);
516 assert_eq!(signal.direction, SignalDirection::Long);
517 }
518
519 #[test]
520 fn test_unanimous_all_agree() {
521 let c = candles();
522 let ind = empty_indicators();
523 let ctx = make_ctx(&c, &ind);
524
525 let ensemble = EnsembleStrategy::new("test")
526 .add(
527 FixedStrategy {
528 direction: SignalDirection::Long,
529 strength: 1.0,
530 },
531 1.0,
532 )
533 .add(
534 FixedStrategy {
535 direction: SignalDirection::Long,
536 strength: 0.6,
537 },
538 1.0,
539 )
540 .mode(EnsembleMode::Unanimous)
541 .build();
542
543 let signal = ensemble.on_candle(&ctx);
544 assert_eq!(signal.direction, SignalDirection::Long);
545 assert!((signal.strength.value() - 0.8).abs() < 1e-9); }
547
548 #[test]
549 fn test_unanimous_disagreement_returns_hold() {
550 let c = candles();
551 let ind = empty_indicators();
552 let ctx = make_ctx(&c, &ind);
553
554 let ensemble = EnsembleStrategy::new("test")
555 .add(
556 FixedStrategy {
557 direction: SignalDirection::Long,
558 strength: 1.0,
559 },
560 1.0,
561 )
562 .add(
563 FixedStrategy {
564 direction: SignalDirection::Short,
565 strength: 1.0,
566 },
567 1.0,
568 )
569 .mode(EnsembleMode::Unanimous)
570 .build();
571
572 let signal = ensemble.on_candle(&ctx);
573 assert!(signal.is_hold());
574 }
575
576 #[test]
577 fn test_weighted_majority_long_wins() {
578 let c = candles();
579 let ind = empty_indicators();
580 let ctx = make_ctx(&c, &ind);
581
582 let ensemble = EnsembleStrategy::new("test")
583 .add(
584 FixedStrategy {
585 direction: SignalDirection::Long,
586 strength: 1.0,
587 },
588 2.0,
589 )
590 .add(
591 FixedStrategy {
592 direction: SignalDirection::Short,
593 strength: 1.0,
594 },
595 1.0,
596 )
597 .mode(EnsembleMode::WeightedMajority)
598 .build();
599
600 let signal = ensemble.on_candle(&ctx);
601 assert_eq!(signal.direction, SignalDirection::Long);
602 assert!((signal.strength.value() - 2.0 / 3.0).abs() < 1e-9);
604 }
605
606 #[test]
607 fn test_weighted_majority_tie_returns_hold() {
608 let c = candles();
609 let ind = empty_indicators();
610 let ctx = make_ctx(&c, &ind);
611
612 let ensemble = EnsembleStrategy::new("test")
613 .add(
614 FixedStrategy {
615 direction: SignalDirection::Long,
616 strength: 1.0,
617 },
618 1.0,
619 )
620 .add(
621 FixedStrategy {
622 direction: SignalDirection::Short,
623 strength: 1.0,
624 },
625 1.0,
626 )
627 .mode(EnsembleMode::WeightedMajority)
628 .build();
629
630 let signal = ensemble.on_candle(&ctx);
631 assert!(signal.is_hold());
632 }
633
634 #[test]
635 fn test_strongest_signal() {
636 let c = candles();
637 let ind = empty_indicators();
638 let ctx = make_ctx(&c, &ind);
639
640 let ensemble = EnsembleStrategy::new("test")
641 .add(
642 FixedStrategy {
643 direction: SignalDirection::Long,
644 strength: 0.4,
645 },
646 1.0,
647 )
648 .add(
649 FixedStrategy {
650 direction: SignalDirection::Short,
651 strength: 0.9,
652 },
653 1.0,
654 )
655 .mode(EnsembleMode::StrongestSignal)
656 .build();
657
658 let signal = ensemble.on_candle(&ctx);
659 assert_eq!(signal.direction, SignalDirection::Short);
660 assert!((signal.strength.value() - 0.9).abs() < 1e-9);
661 }
662
663 #[test]
664 fn test_empty_ensemble_returns_hold() {
665 let c = candles();
666 let ind = empty_indicators();
667 let ctx = make_ctx(&c, &ind);
668
669 let ensemble = EnsembleStrategy::new("empty").build();
670 assert!(ensemble.on_candle(&ctx).is_hold());
671 }
672
673 #[test]
674 fn test_warmup_is_max_of_sub_strategies() {
675 struct WarmupStrategy(usize);
676 impl Strategy for WarmupStrategy {
677 fn name(&self) -> &str {
678 "Warmup"
679 }
680 fn required_indicators(&self) -> Vec<(String, Indicator)> {
681 vec![]
682 }
683 fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
684 Signal::hold()
685 }
686 fn warmup_period(&self) -> usize {
687 self.0
688 }
689 }
690
691 let ensemble = EnsembleStrategy::new("test")
692 .add(WarmupStrategy(10), 1.0)
693 .add(WarmupStrategy(25), 1.0)
694 .add(WarmupStrategy(5), 1.0)
695 .build();
696
697 assert_eq!(ensemble.warmup_period(), 25);
698 }
699
700 #[test]
701 fn test_weighted_majority_exit_ignored_when_flat() {
702 let c = candles();
704 let ind = empty_indicators();
705 let ctx = make_ctx(&c, &ind); let ensemble = EnsembleStrategy::new("test")
710 .add(
711 FixedStrategy {
712 direction: SignalDirection::Long,
713 strength: 1.0,
714 },
715 2.0,
716 )
717 .add(
718 FixedStrategy {
719 direction: SignalDirection::Exit,
720 strength: 1.0,
721 },
722 3.0,
723 )
724 .mode(EnsembleMode::WeightedMajority)
725 .build();
726
727 let signal = ensemble.on_candle(&ctx);
728 assert_eq!(signal.direction, SignalDirection::Long);
729 assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
732 }
733
734 #[test]
735 fn test_weighted_majority_scale_in_wins_when_position_open() {
736 use crate::backtesting::{Position, PositionSide};
737
738 let c = candles();
739 let ind = empty_indicators();
740 let pos = Position::new(
741 PositionSide::Long,
742 1,
743 100.0,
744 10.0,
745 0.0,
746 Signal::long(1, 100.0),
747 );
748 let ctx = make_ctx_with_position(&c, &ind, &pos);
749
750 let ensemble = EnsembleStrategy::new("test")
752 .add(
753 FixedStrategy {
754 direction: SignalDirection::Long,
755 strength: 1.0,
756 },
757 2.0,
758 )
759 .add(
760 FixedStrategy {
761 direction: SignalDirection::ScaleIn,
762 strength: 1.0,
763 },
764 3.0,
765 )
766 .mode(EnsembleMode::WeightedMajority)
767 .build();
768
769 let signal = ensemble.on_candle(&ctx);
770 assert_eq!(signal.direction, SignalDirection::ScaleIn);
771 assert!((signal.strength.value() - 0.6).abs() < 1e-9);
773 let frac = signal.scale_fraction.expect("scale_fraction must be set");
775 assert!((frac - 0.1).abs() < 1e-9, "expected 0.10, got {frac}");
776 }
777
778 #[test]
779 fn test_weighted_majority_scale_fraction_is_conviction_weighted_average() {
780 use crate::backtesting::{Position, PositionSide};
781
782 let c = candles();
783 let ind = empty_indicators();
784 let pos = Position::new(
785 PositionSide::Long,
786 1,
787 100.0,
788 10.0,
789 0.0,
790 Signal::long(1, 100.0),
791 );
792 let ctx = make_ctx_with_position(&c, &ind, &pos);
793
794 struct ScaleOutStrategy {
798 fraction: f64,
799 }
800 impl Strategy for ScaleOutStrategy {
801 fn name(&self) -> &str {
802 "ScaleOut"
803 }
804 fn required_indicators(&self) -> Vec<(String, Indicator)> {
805 vec![]
806 }
807 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
808 Signal::scale_out(self.fraction, ctx.timestamp(), ctx.close())
809 }
810 }
811
812 let ensemble = EnsembleStrategy::new("test")
813 .add(ScaleOutStrategy { fraction: 0.10 }, 1.0)
814 .add(ScaleOutStrategy { fraction: 0.50 }, 1.0)
815 .mode(EnsembleMode::WeightedMajority)
816 .build();
817
818 let signal = ensemble.on_candle(&ctx);
819 assert_eq!(signal.direction, SignalDirection::ScaleOut);
820 let frac = signal.scale_fraction.expect("scale_fraction must be set");
821 assert!((frac - 0.30).abs() < 1e-9, "expected 0.30, got {frac}");
822 }
823
824 #[test]
825 fn test_weighted_majority_scale_in_ignored_when_flat() {
826 let c = candles();
827 let ind = empty_indicators();
828 let ctx = make_ctx(&c, &ind); let ensemble = EnsembleStrategy::new("test")
832 .add(
833 FixedStrategy {
834 direction: SignalDirection::Long,
835 strength: 1.0,
836 },
837 2.0,
838 )
839 .add(
840 FixedStrategy {
841 direction: SignalDirection::ScaleIn,
842 strength: 1.0,
843 },
844 3.0,
845 )
846 .mode(EnsembleMode::WeightedMajority)
847 .build();
848
849 let signal = ensemble.on_candle(&ctx);
850 assert_eq!(signal.direction, SignalDirection::Long);
851 assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
854 }
855
856 #[test]
857 fn test_required_indicators_deduplication() {
858 struct IndStrategy(Vec<(String, Indicator)>);
859 impl Strategy for IndStrategy {
860 fn name(&self) -> &str {
861 "Ind"
862 }
863 fn required_indicators(&self) -> Vec<(String, Indicator)> {
864 self.0.clone()
865 }
866 fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
867 Signal::hold()
868 }
869 }
870
871 let ensemble = EnsembleStrategy::new("test")
872 .add(
873 IndStrategy(vec![
874 ("sma_10".to_string(), Indicator::Sma(10)),
875 ("sma_20".to_string(), Indicator::Sma(20)),
876 ]),
877 1.0,
878 )
879 .add(
880 IndStrategy(vec![
881 ("sma_20".to_string(), Indicator::Sma(20)), ("rsi_14".to_string(), Indicator::Rsi(14)),
883 ]),
884 1.0,
885 )
886 .build();
887
888 let indicators = ensemble.required_indicators();
889 assert_eq!(indicators.len(), 3);
890 assert!(indicators.iter().any(|(k, _)| k == "sma_10"));
891 assert!(indicators.iter().any(|(k, _)| k == "sma_20"));
892 assert!(indicators.iter().any(|(k, _)| k == "rsi_14"));
893 }
894}