1use crate::indicators::Indicator;
20
21use super::{Signal, Strategy, StrategyContext};
22use crate::backtesting::signal::{SignalDirection, SignalStrength};
23
/// Rule an [`EnsembleStrategy`] uses to merge the signals of its sub-strategies.
///
/// Hold signals never count as a vote in any mode.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EnsembleMode {
    /// Every non-hold signal must point the same way.
    ///
    /// Any disagreement yields Hold. Otherwise the first agreeing signal is returned,
    /// carrying the mean strength of all agreeing signals.
    Unanimous,

    /// Weighted vote (the default).
    ///
    /// Each sub-strategy adds `weight * strength` to the direction it signals. The
    /// direction with a strictly greatest total wins; a tie for first yields Hold.
    /// Exit / ScaleIn / ScaleOut votes only count while a position is open.
    #[default]
    WeightedMajority,

    /// The first non-hold signal, in the order strategies were added, is used as-is.
    AnySignal,

    /// The non-hold signal with the highest strength is used as-is.
    StrongestSignal,
}
74
75pub struct EnsembleStrategy {
86 name: String,
87 strategies: Vec<(Box<dyn Strategy>, f64)>,
88 mode: EnsembleMode,
89}
90
91impl EnsembleStrategy {
92 pub fn new(name: impl Into<String>) -> Self {
96 Self {
97 name: name.into(),
98 strategies: Vec::new(),
99 mode: EnsembleMode::default(),
100 }
101 }
102
103 pub fn add<S: Strategy + 'static>(mut self, strategy: S, weight: f64) -> Self {
108 self.strategies.push((Box::new(strategy), weight.max(0.0)));
109 self
110 }
111
112 pub fn mode(mut self, mode: EnsembleMode) -> Self {
114 self.mode = mode;
115 self
116 }
117
118 pub fn build(self) -> Self {
121 self
122 }
123
124 fn any_signal(&self, ctx: &StrategyContext) -> Signal {
127 for (strategy, _) in &self.strategies {
128 let signal = strategy.on_candle(ctx);
129 if !signal.is_hold() {
130 return signal;
131 }
132 }
133 Signal::hold()
134 }
135
136 fn unanimous(&self, ctx: &StrategyContext) -> Signal {
137 let mut first_dir: Option<SignalDirection> = None;
140 let mut first_signal: Option<Signal> = None;
141 let mut total_strength = 0.0_f64;
142 let mut active_count = 0_usize;
143
144 for (strategy, _) in &self.strategies {
145 let signal = strategy.on_candle(ctx);
146 if signal.is_hold() {
147 continue;
148 }
149 match first_dir {
150 None => {
151 first_dir = Some(signal.direction);
152 total_strength = signal.strength.value();
153 first_signal = Some(signal);
154 active_count = 1;
155 }
156 Some(dir) if dir == signal.direction => {
157 total_strength += signal.strength.value();
158 active_count += 1;
159 }
160 _ => return Signal::hold(), }
162 }
163
164 let Some(mut sig) = first_signal else {
165 return Signal::hold();
166 };
167
168 let dir = first_dir.unwrap();
169 let avg_strength = total_strength / active_count as f64;
170 let original_reason = sig.reason.take();
171 sig.strength = SignalStrength::clamped(avg_strength);
172 sig.reason = Some(format!(
173 "Unanimous ({} of {} agree): {}",
174 active_count,
175 self.strategies.len(),
176 original_reason.as_deref().unwrap_or(&dir.to_string())
177 ));
178 sig
179 }
180
181 fn weighted_majority(&self, ctx: &StrategyContext) -> Signal {
182 let total_potential: f64 = self.strategies.iter().map(|(_, w)| *w).sum();
186 if total_potential < f64::EPSILON {
187 return Signal::hold();
188 }
189
190 let mut long_weight = 0.0_f64;
191 let mut short_weight = 0.0_f64;
192 let mut exit_weight = 0.0_f64;
193 let mut scale_in_weight = 0.0_f64;
194 let mut scale_out_weight = 0.0_f64;
195
196 let mut scale_in_frac_score = 0.0_f64;
198 let mut scale_out_frac_score = 0.0_f64;
199
200 let mut best_long: Option<(Signal, f64)> = None;
203 let mut best_short: Option<(Signal, f64)> = None;
204 let mut best_exit: Option<(Signal, f64)> = None;
205 let mut best_scale_in: Option<(Signal, f64)> = None;
206 let mut best_scale_out: Option<(Signal, f64)> = None;
207
208 let has_position = ctx.has_position();
209
210 for (strategy, weight) in &self.strategies {
211 let signal = strategy.on_candle(ctx);
212 let score = weight * signal.strength.value();
214 match signal.direction {
215 SignalDirection::Long => {
216 long_weight += score;
217 if best_long.as_ref().is_none_or(|&(_, s)| score > s) {
218 best_long = Some((signal, score));
219 }
220 }
221 SignalDirection::Short => {
222 short_weight += score;
223 if best_short.as_ref().is_none_or(|&(_, s)| score > s) {
224 best_short = Some((signal, score));
225 }
226 }
227 SignalDirection::Exit if has_position => {
231 exit_weight += score;
232 if best_exit.as_ref().is_none_or(|&(_, s)| score > s) {
233 best_exit = Some((signal, score));
234 }
235 }
236 SignalDirection::ScaleIn if has_position => {
237 let frac = signal.scale_fraction.unwrap_or(0.0);
238 scale_in_weight += score;
239 scale_in_frac_score += frac * score;
240 if best_scale_in.as_ref().is_none_or(|&(_, s)| score > s) {
241 best_scale_in = Some((signal, score));
242 }
243 }
244 SignalDirection::ScaleOut if has_position => {
245 let frac = signal.scale_fraction.unwrap_or(0.0);
246 scale_out_weight += score;
247 scale_out_frac_score += frac * score;
248 if best_scale_out.as_ref().is_none_or(|&(_, s)| score > s) {
249 best_scale_out = Some((signal, score));
250 }
251 }
252 _ => {}
253 }
254 }
255
256 let total_active =
258 long_weight + short_weight + exit_weight + scale_in_weight + scale_out_weight;
259 if total_active < f64::EPSILON {
260 return Signal::hold();
261 }
262
263 let (winner, winner_score) = if long_weight > short_weight
265 && long_weight > exit_weight
266 && long_weight > scale_in_weight
267 && long_weight > scale_out_weight
268 {
269 (best_long, long_weight)
270 } else if short_weight > long_weight
271 && short_weight > exit_weight
272 && short_weight > scale_in_weight
273 && short_weight > scale_out_weight
274 {
275 (best_short, short_weight)
276 } else if exit_weight > long_weight
277 && exit_weight > short_weight
278 && exit_weight > scale_in_weight
279 && exit_weight > scale_out_weight
280 {
281 (best_exit, exit_weight)
282 } else if scale_in_weight > long_weight
283 && scale_in_weight > short_weight
284 && scale_in_weight > exit_weight
285 && scale_in_weight > scale_out_weight
286 {
287 (best_scale_in, scale_in_weight)
288 } else if scale_out_weight > long_weight
289 && scale_out_weight > short_weight
290 && scale_out_weight > exit_weight
291 && scale_out_weight > scale_in_weight
292 {
293 (best_scale_out, scale_out_weight)
294 } else {
295 return Signal::hold();
296 };
297
298 let Some((mut sig, _)) = winner else {
299 return Signal::hold();
300 };
301
302 sig.strength = SignalStrength::clamped(winner_score / total_potential);
304
305 if sig.direction == SignalDirection::ScaleIn && scale_in_weight > f64::EPSILON {
308 sig.scale_fraction = Some((scale_in_frac_score / scale_in_weight).clamp(0.0, 1.0));
309 } else if sig.direction == SignalDirection::ScaleOut && scale_out_weight > f64::EPSILON {
310 sig.scale_fraction = Some((scale_out_frac_score / scale_out_weight).clamp(0.0, 1.0));
311 }
312
313 sig.reason = Some(format!(
314 "WeightedMajority: long={long_weight:.2} short={short_weight:.2} \
315 exit={exit_weight:.2} scale_in={scale_in_weight:.2} scale_out={scale_out_weight:.2}"
316 ));
317 sig
318 }
319
320 fn strongest_signal(&self, ctx: &StrategyContext) -> Signal {
321 self.strategies
322 .iter()
323 .map(|(s, _)| s.on_candle(ctx))
324 .filter(|s| !s.is_hold())
325 .max_by(|a, b| a.strength.value().total_cmp(&b.strength.value()))
326 .unwrap_or_else(Signal::hold)
327 }
328}
329
330impl std::fmt::Debug for EnsembleStrategy {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("EnsembleStrategy")
333 .field("name", &self.name)
334 .field("strategies_count", &self.strategies.len())
335 .field("mode", &self.mode)
336 .finish()
337 }
338}
339
340impl Strategy for EnsembleStrategy {
341 fn name(&self) -> &str {
342 &self.name
343 }
344
345 fn required_indicators(&self) -> Vec<(String, Indicator)> {
346 let mut indicators: Vec<(String, Indicator)> = self
347 .strategies
348 .iter()
349 .flat_map(|(s, _)| s.required_indicators())
350 .collect();
351 indicators.sort_by(|a, b| a.0.cmp(&b.0));
352 indicators.dedup_by(|a, b| a.0 == b.0);
353 indicators
354 }
355
356 fn setup(&mut self, indicators: &std::collections::HashMap<String, Vec<Option<f64>>>) {
357 for (strategy, _) in &mut self.strategies {
358 strategy.setup(indicators);
359 }
360 }
361
362 fn warmup_period(&self) -> usize {
363 self.strategies
364 .iter()
365 .map(|(s, _)| s.warmup_period())
366 .max()
367 .unwrap_or(1)
368 }
369
370 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
371 if self.strategies.is_empty() {
372 return Signal::hold();
373 }
374 match self.mode {
375 EnsembleMode::AnySignal => self.any_signal(ctx),
376 EnsembleMode::Unanimous => self.unanimous(ctx),
377 EnsembleMode::WeightedMajority => self.weighted_majority(ctx),
378 EnsembleMode::StrongestSignal => self.strongest_signal(ctx),
379 }
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::signal::SignalDirection;
    use crate::backtesting::strategy::Strategy;
    use crate::backtesting::{Position, PositionSide};
    use crate::indicators::Indicator;
    use crate::models::chart::Candle;
    use std::collections::HashMap;

    /// Builds a candle whose open/high/low/close all equal `price`.
    fn make_candle(ts: i64, price: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: price,
            high: price,
            low: price,
            close: price,
            volume: 1000,
            adj_close: None,
            provider_id: None,
        }
    }

    /// Context on the last candle with no open position.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: candles.len() - 1,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    /// Context on the last candle with `position` open.
    fn make_ctx_with_position<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
        position: &'a Position,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: candles.len() - 1,
            position: Some(position),
            equity: 10_000.0,
            indicators,
        }
    }

    /// Sub-strategy that always emits `direction` with the given strength.
    /// ScaleIn uses fraction 0.1 and ScaleOut uses 0.5.
    struct FixedStrategy {
        direction: SignalDirection,
        strength: f64,
    }

    impl Strategy for FixedStrategy {
        fn name(&self) -> &str {
            "Fixed"
        }
        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![]
        }
        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            match self.direction {
                SignalDirection::Long => {
                    let mut s = Signal::long(ctx.timestamp(), ctx.close());
                    s.strength = SignalStrength::clamped(self.strength);
                    s
                }
                SignalDirection::Short => {
                    let mut s = Signal::short(ctx.timestamp(), ctx.close());
                    s.strength = SignalStrength::clamped(self.strength);
                    s
                }
                SignalDirection::Exit => {
                    let mut s = Signal::exit(ctx.timestamp(), ctx.close());
                    s.strength = SignalStrength::clamped(self.strength);
                    s
                }
                SignalDirection::ScaleIn => {
                    let mut s = Signal::scale_in(0.1, ctx.timestamp(), ctx.close());
                    s.strength = SignalStrength::clamped(self.strength);
                    s
                }
                SignalDirection::ScaleOut => {
                    let mut s = Signal::scale_out(0.5, ctx.timestamp(), ctx.close());
                    s.strength = SignalStrength::clamped(self.strength);
                    s
                }
                _ => Signal::hold(),
            }
        }
    }

    /// Shorthand constructor for [`FixedStrategy`].
    fn fixed(direction: SignalDirection, strength: f64) -> FixedStrategy {
        FixedStrategy {
            direction,
            strength,
        }
    }

    /// An open long position, for tests that need position-management votes counted.
    fn open_long_position() -> Position {
        Position::new(
            PositionSide::Long,
            1,
            100.0,
            10.0,
            0.0,
            Signal::long(1, 100.0),
        )
    }

    fn candles() -> Vec<Candle> {
        vec![make_candle(1, 100.0), make_candle(2, 101.0)]
    }

    fn empty_indicators() -> HashMap<String, Vec<Option<f64>>> {
        HashMap::new()
    }

    #[test]
    fn test_any_signal_returns_first_non_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.8), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
    }

    #[test]
    fn test_any_signal_all_hold_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::AnySignal)
            .build();

        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    #[test]
    fn test_unanimous_all_agree() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.6), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        // Mean of the agreeing strengths: (1.0 + 0.6) / 2.
        assert!((signal.strength.value() - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_unanimous_ignores_abstainers() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .add(fixed(SignalDirection::Long, 0.5), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 0.75).abs() < 1e-9);
        let reason = signal.reason.as_deref().unwrap_or_default();
        assert!(
            reason.starts_with("Unanimous (2 of 3 agree)"),
            "unexpected reason: {reason}"
        );
    }

    #[test]
    fn test_unanimous_disagreement_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::Unanimous)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert!(signal.is_hold());
    }

    #[test]
    fn test_weighted_majority_long_wins() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_tie_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert!(signal.is_hold());
    }

    #[test]
    fn test_weighted_majority_non_positive_weights_return_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        // Zero stays zero and the negative weight is stored as zero, so the total is 0.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 0.0)
            .add(fixed(SignalDirection::Short, 1.0), -2.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    #[test]
    fn test_strongest_signal() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 0.4), 1.0)
            .add(fixed(SignalDirection::Short, 0.9), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Short);
        assert!((signal.strength.value() - 0.9).abs() < 1e-9);
    }

    #[test]
    fn test_strongest_signal_all_hold_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Hold, 1.0), 1.0)
            .mode(EnsembleMode::StrongestSignal)
            .build();

        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    #[test]
    fn test_empty_ensemble_returns_hold() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("empty").build();
        assert!(ensemble.on_candle(&ctx).is_hold());
    }

    #[test]
    fn test_empty_ensemble_warmup_defaults_to_one() {
        let ensemble = EnsembleStrategy::new("empty").build();
        assert_eq!(ensemble.warmup_period(), 1);
    }

    #[test]
    fn test_warmup_is_max_of_sub_strategies() {
        struct WarmupStrategy(usize);
        impl Strategy for WarmupStrategy {
            fn name(&self) -> &str {
                "Warmup"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
            fn warmup_period(&self) -> usize {
                self.0
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(WarmupStrategy(10), 1.0)
            .add(WarmupStrategy(25), 1.0)
            .add(WarmupStrategy(5), 1.0)
            .build();

        assert_eq!(ensemble.warmup_period(), 25);
    }

    #[test]
    fn test_weighted_majority_exit_ignored_when_flat() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        // The Exit vote is dropped while flat but its weight still dilutes strength.
        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::Exit, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_weighted_majority_scale_in_wins_when_position_open() {
        let c = candles();
        let ind = empty_indicators();
        let pos = open_long_position();
        let ctx = make_ctx_with_position(&c, &ind, &pos);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::ScaleIn);
        assert!((signal.strength.value() - 0.6).abs() < 1e-9);
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!((frac - 0.1).abs() < 1e-9, "expected 0.10, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_fraction_is_conviction_weighted_average() {
        let c = candles();
        let ind = empty_indicators();
        let pos = open_long_position();
        let ctx = make_ctx_with_position(&c, &ind, &pos);

        /// Emits a ScaleOut with a configurable fraction at default strength.
        struct ScaleOutStrategy {
            fraction: f64,
        }
        impl Strategy for ScaleOutStrategy {
            fn name(&self) -> &str {
                "ScaleOut"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                vec![]
            }
            fn on_candle(&self, ctx: &StrategyContext) -> Signal {
                Signal::scale_out(self.fraction, ctx.timestamp(), ctx.close())
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(ScaleOutStrategy { fraction: 0.10 }, 1.0)
            .add(ScaleOutStrategy { fraction: 0.50 }, 1.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::ScaleOut);
        let frac = signal.scale_fraction.expect("scale_fraction must be set");
        assert!((frac - 0.30).abs() < 1e-9, "expected 0.30, got {frac}");
    }

    #[test]
    fn test_weighted_majority_scale_in_ignored_when_flat() {
        let c = candles();
        let ind = empty_indicators();
        let ctx = make_ctx(&c, &ind);

        let ensemble = EnsembleStrategy::new("test")
            .add(fixed(SignalDirection::Long, 1.0), 2.0)
            .add(fixed(SignalDirection::ScaleIn, 1.0), 3.0)
            .mode(EnsembleMode::WeightedMajority)
            .build();

        let signal = ensemble.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Long);
        assert!((signal.strength.value() - 2.0 / 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_required_indicators_deduplication() {
        struct IndStrategy(Vec<(String, Indicator)>);
        impl Strategy for IndStrategy {
            fn name(&self) -> &str {
                "Ind"
            }
            fn required_indicators(&self) -> Vec<(String, Indicator)> {
                self.0.clone()
            }
            fn on_candle(&self, _ctx: &StrategyContext) -> Signal {
                Signal::hold()
            }
        }

        let ensemble = EnsembleStrategy::new("test")
            .add(
                IndStrategy(vec![
                    ("sma_10".to_string(), Indicator::Sma(10)),
                    ("sma_20".to_string(), Indicator::Sma(20)),
                ]),
                1.0,
            )
            .add(
                IndStrategy(vec![
                    ("sma_20".to_string(), Indicator::Sma(20)),
                    ("rsi_14".to_string(), Indicator::Rsi(14)),
                ]),
                1.0,
            )
            .build();

        let indicators = ensemble.required_indicators();
        assert_eq!(indicators.len(), 3);
        assert!(indicators.iter().any(|(k, _)| k == "sma_10"));
        assert!(indicators.iter().any(|(k, _)| k == "sma_20"));
        assert!(indicators.iter().any(|(k, _)| k == "rsi_14"));
        // The merged list is sorted by key.
        let keys: Vec<&str> = indicators.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(keys, ["rsi_14", "sma_10", "sma_20"]);
    }

    #[test]
    fn test_debug_reports_strategy_count_not_contents() {
        let ensemble = EnsembleStrategy::new("dbg")
            .add(fixed(SignalDirection::Long, 1.0), 1.0)
            .add(fixed(SignalDirection::Short, 1.0), 1.0)
            .build();

        let rendered = format!("{ensemble:?}");
        assert!(rendered.contains("strategies_count: 2"), "{rendered}");
        assert!(rendered.contains("WeightedMajority"), "{rendered}");
    }
}