1use std::cmp::Ordering;
4use std::collections::{BTreeSet, HashMap, HashSet};
5
6use crate::backtesting::config::BacktestConfig;
7use crate::backtesting::engine::{BacktestEngine, update_trailing_hwm};
8use crate::backtesting::error::{BacktestError, Result};
9use crate::backtesting::position::{Position, PositionSide, Trade};
10use crate::backtesting::result::{BacktestResult, EquityPoint, PerformanceMetrics, SignalRecord};
11use crate::backtesting::signal::{Signal, SignalDirection};
12use crate::backtesting::strategy::{Strategy, StrategyContext};
13use crate::models::chart::{Candle, Dividend};
14
15use super::config::PortfolioConfig;
16use super::result::{AllocationSnapshot, PortfolioResult};
17
18#[non_exhaustive]
22#[derive(Debug, Clone)]
23pub struct SymbolData {
24 pub symbol: String,
26
27 pub candles: Vec<Candle>,
29
30 pub dividends: Vec<Dividend>,
34}
35
36impl SymbolData {
37 pub fn new(symbol: impl Into<String>, candles: Vec<Candle>) -> Self {
39 Self {
40 symbol: symbol.into(),
41 candles,
42 dividends: vec![],
43 }
44 }
45
46 pub fn with_dividends(mut self, dividends: Vec<Dividend>) -> Self {
48 self.dividends = dividends;
49 self
50 }
51}
52
53pub struct PortfolioEngine {
58 config: PortfolioConfig,
59}
60
61impl PortfolioEngine {
62 pub fn new(config: PortfolioConfig) -> Self {
64 Self { config }
65 }
66
67 pub fn run<S, F>(&self, symbol_data: &[SymbolData], factory: F) -> Result<PortfolioResult>
80 where
81 S: Strategy,
82 F: Fn(&str) -> S,
83 {
84 let n_symbols = symbol_data.len();
85 self.config.validate(n_symbols)?;
86
87 let initial_capital = self.config.base.initial_capital;
88
89 let helper_engine = BacktestEngine::new(self.config.base.clone());
91
92 let mut states: HashMap<String, SymbolState<S>> = HashMap::with_capacity(n_symbols);
93 for data in symbol_data {
94 let strategy = factory(&data.symbol);
95 let warmup = strategy.warmup_period();
96 if data.candles.len() < warmup {
97 return Err(BacktestError::insufficient_data(warmup, data.candles.len()));
98 }
99 let strategy_name = strategy.name().to_string();
100 let indicators = helper_engine.compute_indicators(&data.candles, &strategy)?;
101 let ts_index: HashMap<i64, usize> = data
102 .candles
103 .iter()
104 .enumerate()
105 .map(|(i, c)| (c.timestamp, i))
106 .collect();
107
108 let sym_initial_capital = self.config.allocation_target(
112 &data.symbol,
113 initial_capital,
114 initial_capital,
115 n_symbols,
116 );
117
118 states.insert(
119 data.symbol.clone(),
120 SymbolState {
121 candles: data.candles.clone(),
122 dividends: data.dividends.clone(),
123 ts_index,
124 indicators,
125 strategy,
126 warmup,
127 position: None,
128 hwm: None,
129 div_idx: 0,
130 trades: vec![],
131 signals: vec![],
132 realized_pnl: 0.0,
133 equity_curve: vec![],
134 sym_peak: sym_initial_capital,
135 sym_initial_capital,
136 strategy_name,
137 },
138 );
139 }
140
141 let master_timeline: BTreeSet<i64> = states
143 .values()
144 .flat_map(|s| s.candles.iter().map(|c| c.timestamp))
145 .collect();
146
147 let mut cash = initial_capital;
149 let mut portfolio_equity_curve: Vec<EquityPoint> = Vec::new();
150 let mut allocation_history: Vec<AllocationSnapshot> = Vec::new();
151 let mut portfolio_peak = initial_capital;
152
153 for ×tamp in &master_timeline {
155 let active_symbols: Vec<String> = states
158 .keys()
159 .filter(|sym| states[*sym].ts_index.contains_key(×tamp))
160 .cloned()
161 .collect();
162
163 let mut auto_exits: Vec<(String, Signal)> = Vec::new();
165
166 for sym in &active_symbols {
167 let state = states.get_mut(sym).unwrap();
168 let candle_idx = state.ts_index[×tamp];
169 let candle = &state.candles[candle_idx];
170
171 update_trailing_hwm(state.position.as_ref(), &mut state.hwm, candle);
174
175 while state.div_idx < state.dividends.len()
177 && state.dividends[state.div_idx].timestamp <= timestamp
178 {
179 if let Some(ref mut pos) = state.position {
180 let per_share = state.dividends[state.div_idx].amount;
181 let income = if pos.is_long() {
182 per_share * pos.quantity
183 } else {
184 -(per_share * pos.quantity)
185 };
186 pos.credit_dividend(
187 income,
188 candle.close,
189 self.config.base.reinvest_dividends,
190 );
191 }
192 state.div_idx += 1;
193 }
194
195 if let Some(ref pos) = state.position
197 && let Some(exit_signal) =
198 check_sl_tp(pos, candle, state.hwm, &self.config.base)
199 {
200 auto_exits.push((sym.clone(), exit_signal));
201 }
202 }
203
204 let mut exited_this_bar: HashSet<String> = HashSet::new();
207 for (sym, exit_signal) in auto_exits {
208 let state = states.get_mut(&sym).unwrap();
209 let fill_price = exit_signal.price;
210
211 let Some(pos) = state.position.take() else {
212 continue;
213 };
214 let exit_price_slipped = self
215 .config
216 .base
217 .apply_exit_slippage(fill_price, pos.is_long());
218 let exit_price = self
219 .config
220 .base
221 .apply_exit_spread(exit_price_slipped, pos.is_long());
222 let exit_comm = self
223 .config
224 .base
225 .calculate_commission(pos.quantity, exit_price);
226 let exit_tax = self
227 .config
228 .base
229 .calculate_transaction_tax(exit_price * pos.quantity, !pos.is_long());
230 let exit_reason = exit_signal.reason.clone();
231 let exit_tags = exit_signal.tags.clone();
232 let trade =
233 pos.close_with_tax(timestamp, exit_price, exit_comm, exit_tax, exit_signal);
234 if trade.is_long() {
235 cash += trade.exit_value() - exit_comm + trade.unreinvested_dividends;
236 } else {
237 cash -=
238 trade.exit_value() + exit_comm + exit_tax - trade.unreinvested_dividends;
239 }
240 state.realized_pnl += trade.pnl;
241 state.trades.push(trade);
242 state.hwm = None;
243 state.signals.push(SignalRecord {
244 timestamp,
245 price: fill_price,
246 direction: SignalDirection::Exit,
247 strength: 1.0,
248 reason: exit_reason,
249 executed: true,
250 tags: exit_tags,
251 });
252 exited_this_bar.insert(sym);
253 }
254
255 let mut pending_entries: Vec<(String, Signal)> = Vec::new();
257
258 for sym in &active_symbols {
259 if exited_this_bar.contains(sym) {
261 continue;
262 }
263
264 let candle_idx = {
267 let state = states.get_mut(sym).unwrap();
268 let idx = state.ts_index[×tamp];
269 if idx < state.warmup.saturating_sub(1) {
270 continue;
271 }
272 idx
273 }; let portfolio_equity = compute_portfolio_equity(cash, &states, timestamp);
277
278 let state = states.get_mut(sym).unwrap();
280
281 let ctx = StrategyContext {
282 candles: &state.candles[..=candle_idx],
283 index: candle_idx,
284 position: state.position.as_ref(),
285 equity: portfolio_equity,
286 indicators: &state.indicators,
287 };
288
289 let signal = state.strategy.on_candle(&ctx);
290
291 if signal.is_hold() {
292 continue;
293 }
294 if signal.strength.value() < self.config.base.min_signal_strength {
295 state.signals.push(SignalRecord {
296 timestamp: signal.timestamp,
297 price: signal.price,
298 direction: signal.direction,
299 strength: signal.strength.value(),
300 reason: signal.reason.clone(),
301 executed: false,
302 tags: signal.tags.clone(),
303 });
304 continue;
305 }
306
307 match signal.direction {
308 SignalDirection::Exit => {
309 if let Some(pos) = state.position.take() {
311 if let Some(fill_candle) = state.candles.get(candle_idx + 1) {
312 let exit_price_slipped = self
313 .config
314 .base
315 .apply_exit_slippage(fill_candle.open, pos.is_long());
316 let exit_price = self
317 .config
318 .base
319 .apply_exit_spread(exit_price_slipped, pos.is_long());
320 let exit_comm = self
321 .config
322 .base
323 .calculate_commission(pos.quantity, exit_price);
324 let exit_tax = self.config.base.calculate_transaction_tax(
325 exit_price * pos.quantity,
326 !pos.is_long(),
327 );
328 let trade = pos.close_with_tax(
329 fill_candle.timestamp,
330 exit_price,
331 exit_comm,
332 exit_tax,
333 signal.clone(),
334 );
335 if trade.is_long() {
336 cash += trade.exit_value() - exit_comm
337 + trade.unreinvested_dividends;
338 } else {
339 cash -= trade.exit_value() + exit_comm + exit_tax
340 - trade.unreinvested_dividends;
341 }
342 state.realized_pnl += trade.pnl;
343 state.trades.push(trade);
344 state.hwm = None;
345 state.signals.push(SignalRecord {
346 timestamp: signal.timestamp,
347 price: signal.price,
348 direction: signal.direction,
349 strength: signal.strength.value(),
350 reason: signal.reason,
351 executed: true,
352 tags: signal.tags,
353 });
354 } else {
355 state.position = Some(pos);
357 state.signals.push(SignalRecord {
358 timestamp: signal.timestamp,
359 price: signal.price,
360 direction: signal.direction,
361 strength: signal.strength.value(),
362 reason: signal.reason,
363 executed: false,
364 tags: signal.tags,
365 });
366 }
367 }
368 }
369 SignalDirection::Long | SignalDirection::Short => {
370 pending_entries.push((sym.clone(), signal));
372 }
373 SignalDirection::ScaleIn => {
374 let fraction = signal.scale_fraction.unwrap_or(0.0).clamp(0.0, 1.0);
375 let executed = fraction > 0.0
376 && state.position.is_some()
377 && state
378 .candles
379 .get(candle_idx + 1)
380 .is_some_and(|fill_candle| {
381 let pos = state.position.as_mut().unwrap();
382 let is_long = pos.is_long();
383 let fill_price = self.config.base.apply_entry_spread(
384 self.config
385 .base
386 .apply_entry_slippage(fill_candle.open, is_long),
387 is_long,
388 );
389 if fill_price <= 0.0 {
390 return false;
391 }
392 let add_value = portfolio_equity * fraction;
393 let add_qty = add_value / fill_price;
394 let commission =
395 self.config.base.calculate_commission(add_qty, fill_price);
396 let entry_tax = self
397 .config
398 .base
399 .calculate_transaction_tax(add_value, is_long);
400 let total_cost = if is_long {
401 add_value + commission + entry_tax
402 } else {
403 commission
404 };
405 if add_qty <= 0.0 || total_cost > cash {
406 return false;
407 }
408 if is_long {
409 cash -= add_value + commission + entry_tax;
410 } else {
411 cash += add_value - commission;
412 }
413 pos.scale_in(fill_price, add_qty, commission, entry_tax);
414 true
415 });
416 state.signals.push(SignalRecord {
417 timestamp: signal.timestamp,
418 price: signal.price,
419 direction: signal.direction,
420 strength: signal.strength.value(),
421 reason: signal.reason,
422 executed,
423 tags: signal.tags,
424 });
425 }
426 SignalDirection::ScaleOut => {
427 let fraction = signal.scale_fraction.unwrap_or(0.0).clamp(0.0, 1.0);
428 let executed = fraction > 0.0 && {
429 let pos_meta =
431 state.position.as_ref().map(|p| (p.is_long(), p.quantity));
432 match (state.candles.get(candle_idx + 1), pos_meta) {
433 (Some(fill_candle), Some((is_long, qty_full))) => {
434 let exit_price = self.config.base.apply_exit_spread(
435 self.config
436 .base
437 .apply_exit_slippage(fill_candle.open, is_long),
438 is_long,
439 );
440 let qty_to_close = if fraction >= 1.0 {
441 qty_full
442 } else {
443 qty_full * fraction
444 };
445 let commission = self
446 .config
447 .base
448 .calculate_commission(qty_to_close, exit_price);
449 let exit_tax = self.config.base.calculate_transaction_tax(
450 exit_price * qty_to_close,
451 !is_long,
452 );
453 let trade = if fraction >= 1.0 {
454 let pos = state.position.take().unwrap();
455 state.hwm = None;
456 pos.close_with_tax(
457 fill_candle.timestamp,
458 exit_price,
459 commission,
460 exit_tax,
461 signal.clone(),
462 )
463 } else {
464 state.position.as_mut().unwrap().partial_close(
465 fraction,
466 fill_candle.timestamp,
467 exit_price,
468 commission,
469 exit_tax,
470 signal.clone(),
471 )
472 };
473 if trade.is_long() {
474 cash += trade.exit_value() - commission
475 + trade.unreinvested_dividends;
476 } else {
477 cash -= trade.exit_value() + commission + exit_tax
478 - trade.unreinvested_dividends;
479 }
480 state.realized_pnl += trade.pnl;
481 state.trades.push(trade);
482 true
483 }
484 _ => false,
485 }
486 };
487 state.signals.push(SignalRecord {
488 timestamp: signal.timestamp,
489 price: signal.price,
490 direction: signal.direction,
491 strength: signal.strength.value(),
492 reason: signal.reason,
493 executed,
494 tags: signal.tags,
495 });
496 }
497 SignalDirection::Hold => {}
498 }
499 }
500
501 pending_entries.sort_by(|(sym_a, sig_a), (sym_b, sig_b)| {
504 sig_b
505 .strength
506 .value()
507 .partial_cmp(&sig_a.strength.value())
508 .unwrap_or(Ordering::Equal)
509 .then_with(|| sym_a.cmp(sym_b))
510 });
511
512 let open_positions_count: usize =
513 states.values().filter(|s| s.position.is_some()).count();
514 let mut positions_open = open_positions_count;
515
516 for (sym, signal) in pending_entries {
517 let (has_position, signal_price, fill_open, fill_ts) = {
524 let state = states.get(&sym).unwrap();
525 let idx = state.ts_index[×tamp];
526 let signal_price = state.candles[idx].close;
527 let next = state.candles.get(idx + 1).map(|c| (c.open, c.timestamp));
528 (
529 state.position.is_some(),
530 signal_price,
531 next.map(|(o, _)| o),
532 next.map(|(_, t)| t),
533 )
534 }; if has_position {
537 continue;
538 }
539
540 let (Some(fill_open), Some(fill_ts)) = (fill_open, fill_ts) else {
542 states.get_mut(&sym).unwrap().signals.push(SignalRecord {
543 timestamp: signal.timestamp,
544 price: signal.price,
545 direction: signal.direction,
546 strength: signal.strength.value(),
547 reason: signal.reason,
548 executed: false,
549 tags: signal.tags,
550 });
551 continue;
552 };
553
554 if let Some(max) = self.config.max_total_positions
556 && positions_open >= max
557 {
558 states.get_mut(&sym).unwrap().signals.push(SignalRecord {
559 timestamp: signal.timestamp,
560 price: signal.price,
561 direction: signal.direction,
562 strength: signal.strength.value(),
563 reason: signal.reason,
564 executed: false,
565 tags: signal.tags,
566 });
567 continue;
568 }
569
570 if signal.direction == SignalDirection::Short && !self.config.base.allow_short {
571 continue;
572 }
573
574 let is_long = signal.direction == SignalDirection::Long;
575 let target_capital =
576 self.config
577 .allocation_target(&sym, cash, initial_capital, n_symbols);
578
579 if target_capital <= 0.0 {
580 continue;
581 }
582
583 let entry_price_slipped = self.config.base.apply_entry_slippage(fill_open, is_long);
584 let entry_price = self
588 .config
589 .base
590 .apply_entry_spread(entry_price_slipped, is_long);
591
592 let (flat_reserve, pct_friction) = if self.config.base.commission_fn.is_some() {
608 (0.0, 0.0)
609 } else {
610 (self.config.base.commission, self.config.base.commission_pct)
611 };
612 let tax_friction = if is_long {
613 self.config.base.transaction_tax_pct
614 } else {
615 0.0
616 };
617 let effective_target = (target_capital - flat_reserve).max(0.0);
618 let quantity =
619 effective_target / (entry_price * (1.0 + pct_friction + tax_friction));
620 let entry_comm = self.config.base.calculate_commission(quantity, entry_price);
621 let entry_tax = self
622 .config
623 .base
624 .calculate_transaction_tax(entry_price * quantity, is_long);
625 let entry_cost = entry_price * quantity + entry_comm + entry_tax;
626
627 if is_long {
628 if entry_cost > cash {
629 continue;
630 }
631 } else if entry_comm > cash {
632 continue;
633 }
634
635 if is_long {
637 cash -= entry_cost;
638 } else {
639 cash += entry_price * quantity - entry_comm;
640 }
641 let side = if is_long {
642 PositionSide::Long
643 } else {
644 PositionSide::Short
645 };
646
647 let state = states.get_mut(&sym).unwrap();
648 state.position = Some(Position::new_with_tax(
649 side,
650 fill_ts,
651 entry_price,
652 quantity,
653 entry_comm,
654 entry_tax,
655 signal.clone(),
656 ));
657 state.hwm = Some(entry_price);
658 state.signals.push(SignalRecord {
659 timestamp: signal.timestamp,
660 price: signal_price,
661 direction: signal.direction,
662 strength: signal.strength.value(),
663 reason: signal.reason,
664 executed: true,
665 tags: signal.tags,
666 });
667 positions_open += 1;
668 }
669
670 let portfolio_equity = compute_portfolio_equity(cash, &states, timestamp);
672
673 if portfolio_equity > portfolio_peak {
674 portfolio_peak = portfolio_equity;
675 }
676 let drawdown_pct = if portfolio_peak > 0.0 {
677 (portfolio_peak - portfolio_equity) / portfolio_peak
678 } else {
679 0.0
680 };
681
682 portfolio_equity_curve.push(EquityPoint {
683 timestamp,
684 equity: portfolio_equity,
685 drawdown_pct,
686 });
687
688 for sym in &active_symbols {
690 let state = states.get_mut(sym).unwrap();
691 let candle_idx = state.ts_index[×tamp];
692 let close = state.candles[candle_idx].close;
693 let unrealized = state
694 .position
695 .as_ref()
696 .map(|pos| pos.unrealized_pnl(close))
697 .unwrap_or(0.0);
698 let sym_equity = state.sym_initial_capital + state.realized_pnl + unrealized;
699 if sym_equity > state.sym_peak {
700 state.sym_peak = sym_equity;
701 }
702 let sym_drawdown = if state.sym_peak > 0.0 {
703 (state.sym_peak - sym_equity) / state.sym_peak
704 } else {
705 0.0
706 };
707 state.equity_curve.push(EquityPoint {
708 timestamp,
709 equity: sym_equity,
710 drawdown_pct: sym_drawdown,
711 });
712 }
713
714 let position_values: HashMap<String, f64> = states
716 .iter()
717 .filter_map(|(sym, s)| {
718 s.position.as_ref().and_then(|pos| {
719 close_at_or_before(s, timestamp).map(|close| {
720 (
721 sym.clone(),
722 pos.current_value(close) + pos.unreinvested_dividends,
723 )
724 })
725 })
726 })
727 .collect();
728
729 allocation_history.push(AllocationSnapshot {
730 timestamp,
731 cash,
732 positions: position_values,
733 });
734 }
735
736 if self.config.base.close_at_end {
738 for state in states.values_mut() {
739 if let Some(pos) = state.position.take() {
740 let last_candle = state.candles.last().unwrap();
741 let exit_price_slipped = self
742 .config
743 .base
744 .apply_exit_slippage(last_candle.close, pos.is_long());
745 let exit_price = self
746 .config
747 .base
748 .apply_exit_spread(exit_price_slipped, pos.is_long());
749 let exit_comm = self
750 .config
751 .base
752 .calculate_commission(pos.quantity, exit_price);
753 let exit_tax = self
754 .config
755 .base
756 .calculate_transaction_tax(exit_price * pos.quantity, !pos.is_long());
757 let exit_signal = Signal::exit(last_candle.timestamp, last_candle.close)
758 .with_reason("End of backtest");
759 let trade = pos.close_with_tax(
760 last_candle.timestamp,
761 exit_price,
762 exit_comm,
763 exit_tax,
764 exit_signal,
765 );
766 if trade.is_long() {
767 cash += trade.exit_value() - exit_comm + trade.unreinvested_dividends;
768 } else {
769 cash -= trade.exit_value() + exit_comm + exit_tax
770 - trade.unreinvested_dividends;
771 }
772 state.realized_pnl += trade.pnl;
773 state.trades.push(trade);
774 state.hwm = None;
775
776 let sym_equity = state.sym_initial_capital + state.realized_pnl;
777 sync_terminal_equity_point(
778 &mut state.equity_curve,
779 last_candle.timestamp,
780 sym_equity,
781 );
782 }
783 }
784 }
785
786 let final_equity: f64 = cash
788 + states
789 .values()
790 .map(|s| {
791 s.position
792 .as_ref()
793 .zip(s.candles.last())
794 .map(|(pos, c)| pos.current_value(c.close) + pos.unreinvested_dividends)
795 .unwrap_or(0.0)
796 })
797 .sum::<f64>();
798
799 if let Some(last_ts) = master_timeline.last().copied() {
800 sync_terminal_equity_point(&mut portfolio_equity_curve, last_ts, final_equity);
801 }
802
803 let symbol_results: HashMap<String, BacktestResult> = states
805 .into_iter()
806 .map(|(sym, state)| {
807 let sym_final_equity = state
811 .equity_curve
812 .last()
813 .map(|ep| ep.equity)
814 .unwrap_or(state.sym_initial_capital);
815
816 let exec_count = state.signals.iter().filter(|s| s.executed).count();
817 let metrics = PerformanceMetrics::calculate(
818 &state.trades,
819 &state.equity_curve,
820 state.sym_initial_capital,
821 state.signals.len(),
822 exec_count,
823 self.config.base.risk_free_rate,
824 self.config.base.bars_per_year,
825 );
826
827 let start_ts = state.candles.first().map(|c| c.timestamp).unwrap_or(0);
828 let end_ts = state.candles.last().map(|c| c.timestamp).unwrap_or(0);
829
830 let result = BacktestResult {
831 symbol: sym.clone(),
832 strategy_name: state.strategy_name.clone(),
833 config: self.config.base.clone(),
834 start_timestamp: start_ts,
835 end_timestamp: end_ts,
836 initial_capital: state.sym_initial_capital,
837 final_equity: sym_final_equity,
838 metrics,
839 trades: state.trades,
840 equity_curve: state.equity_curve,
841 signals: state.signals,
842 open_position: state.position,
843 benchmark: None,
844 diagnostics: vec![],
845 };
846
847 (sym, result)
848 })
849 .collect();
850
851 let all_trades: Vec<Trade> = symbol_results
853 .values()
854 .flat_map(|r| r.trades.iter().cloned())
855 .collect();
856
857 let total_signals: usize = symbol_results.values().map(|r| r.signals.len()).sum();
858 let executed_signals: usize = symbol_results
859 .values()
860 .flat_map(|r| r.signals.iter())
861 .filter(|s| s.executed)
862 .count();
863
864 let mut portfolio_metrics = PerformanceMetrics::calculate(
865 &all_trades,
866 &portfolio_equity_curve,
867 initial_capital,
868 total_signals,
869 executed_signals,
870 self.config.base.risk_free_rate,
871 self.config.base.bars_per_year,
872 );
873 portfolio_metrics.time_in_market_pct =
874 compute_portfolio_time_in_market(&allocation_history);
875
876 Ok(PortfolioResult {
877 symbols: symbol_results,
878 portfolio_equity_curve,
879 portfolio_metrics,
880 initial_capital,
881 final_equity,
882 allocation_history,
883 })
884 }
885}
886
887struct SymbolState<S: Strategy> {
891 candles: Vec<Candle>,
892 dividends: Vec<Dividend>,
893 ts_index: HashMap<i64, usize>,
894 indicators: HashMap<String, Vec<Option<f64>>>,
895 strategy: S,
896 warmup: usize,
897 position: Option<Position>,
898 hwm: Option<f64>,
899 div_idx: usize,
900 trades: Vec<Trade>,
901 signals: Vec<SignalRecord>,
902 realized_pnl: f64,
904 equity_curve: Vec<EquityPoint>,
906 sym_peak: f64,
908 sym_initial_capital: f64,
914 strategy_name: String,
916}
917
918fn compute_portfolio_equity<S: Strategy>(
920 cash: f64,
921 states: &HashMap<String, SymbolState<S>>,
922 timestamp: i64,
923) -> f64 {
924 cash + states
925 .values()
926 .filter_map(|s| {
927 s.position.as_ref().and_then(|pos| {
928 close_at_or_before(s, timestamp)
929 .map(|close| pos.current_value(close) + pos.unreinvested_dividends)
930 })
931 })
932 .sum::<f64>()
933}
934
935fn close_at_or_before<S: Strategy>(state: &SymbolState<S>, timestamp: i64) -> Option<f64> {
936 if let Some(&idx) = state.ts_index.get(×tamp) {
938 return Some(state.candles[idx].close);
939 }
940 match state
943 .candles
944 .binary_search_by_key(×tamp, |c| c.timestamp)
945 {
946 Ok(idx) | Err(idx) if idx > 0 => Some(state.candles[idx.saturating_sub(1)].close),
947 _ => None,
948 }
949}
950
951fn compute_portfolio_time_in_market(allocation_history: &[AllocationSnapshot]) -> f64 {
956 if allocation_history.len() < 2 {
957 return 0.0;
958 }
959
960 let total_span = allocation_history.last().map(|s| s.timestamp).unwrap_or(0)
961 - allocation_history.first().map(|s| s.timestamp).unwrap_or(0);
962
963 if total_span <= 0 {
964 return 0.0;
965 }
966
967 let mut exposed_secs: i64 = 0;
968 for window in allocation_history.windows(2) {
969 let current = &window[0];
970 let next = &window[1];
971 if !current.positions.is_empty() {
972 exposed_secs += (next.timestamp - current.timestamp).max(0);
973 }
974 }
975
976 (exposed_secs as f64 / total_span as f64).clamp(0.0, 1.0)
977}
978
979fn check_sl_tp(
993 pos: &Position,
994 candle: &Candle,
995 hwm: Option<f64>,
996 config: &BacktestConfig,
997) -> Option<Signal> {
998 if let Some(sl_pct) = config.stop_loss_pct {
1000 let stop_price = if pos.is_long() {
1001 pos.entry_price * (1.0 - sl_pct)
1002 } else {
1003 pos.entry_price * (1.0 + sl_pct)
1004 };
1005 let triggered = if pos.is_long() {
1006 candle.low <= stop_price
1007 } else {
1008 candle.high >= stop_price
1009 };
1010 if triggered {
1011 let fill_price = if pos.is_long() {
1012 candle.open.min(stop_price)
1013 } else {
1014 candle.open.max(stop_price)
1015 };
1016 let return_pct = pos.unrealized_return_pct(fill_price);
1017 return Some(
1018 Signal::exit(candle.timestamp, fill_price)
1019 .with_reason(format!("Stop-loss triggered ({:.1}%)", return_pct)),
1020 );
1021 }
1022 }
1023
1024 if let Some(tp_pct) = config.take_profit_pct {
1026 let tp_price = if pos.is_long() {
1027 pos.entry_price * (1.0 + tp_pct)
1028 } else {
1029 pos.entry_price * (1.0 - tp_pct)
1030 };
1031 let triggered = if pos.is_long() {
1032 candle.high >= tp_price
1033 } else {
1034 candle.low <= tp_price
1035 };
1036 if triggered {
1037 let fill_price = if pos.is_long() {
1038 candle.open.max(tp_price)
1039 } else {
1040 candle.open.min(tp_price)
1041 };
1042 let return_pct = pos.unrealized_return_pct(fill_price);
1043 return Some(
1044 Signal::exit(candle.timestamp, fill_price)
1045 .with_reason(format!("Take-profit triggered ({:.1}%)", return_pct)),
1046 );
1047 }
1048 }
1049
1050 if let Some(trail_pct) = config.trailing_stop_pct
1052 && let Some(extreme) = hwm
1053 && extreme > 0.0
1054 {
1055 let trail_stop_price = if pos.is_long() {
1056 extreme * (1.0 - trail_pct)
1057 } else {
1058 extreme * (1.0 + trail_pct)
1059 };
1060 let triggered = if pos.is_long() {
1061 candle.low <= trail_stop_price
1062 } else {
1063 candle.high >= trail_stop_price
1064 };
1065 if triggered {
1066 let fill_price = if pos.is_long() {
1067 candle.open.min(trail_stop_price)
1068 } else {
1069 candle.open.max(trail_stop_price)
1070 };
1071 let adverse_move_pct = if pos.is_long() {
1072 (extreme - fill_price) / extreme
1073 } else {
1074 (fill_price - extreme) / extreme
1075 };
1076 return Some(
1077 Signal::exit(candle.timestamp, fill_price).with_reason(format!(
1078 "Trailing stop triggered ({:.1}% adverse move)",
1079 adverse_move_pct * 100.0
1080 )),
1081 );
1082 }
1083 }
1084
1085 None
1086}
1087
1088fn sync_terminal_equity_point(equity_curve: &mut Vec<EquityPoint>, timestamp: i64, equity: f64) {
1089 if let Some(last) = equity_curve.last_mut()
1090 && last.timestamp == timestamp
1091 {
1092 last.equity = equity;
1093 } else {
1094 equity_curve.push(EquityPoint {
1095 timestamp,
1096 equity,
1097 drawdown_pct: 0.0,
1098 });
1099 }
1100
1101 let peak = equity_curve
1102 .iter()
1103 .map(|point| point.equity)
1104 .fold(f64::NEG_INFINITY, f64::max);
1105 let drawdown = if peak.is_finite() && peak > 0.0 {
1106 (peak - equity) / peak
1107 } else {
1108 0.0
1109 };
1110
1111 if let Some(last) = equity_curve.last_mut() {
1112 last.drawdown_pct = drawdown;
1113 }
1114}
1115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::portfolio::config::{PortfolioConfig, RebalanceMode};
    use crate::backtesting::strategy::{Strategy, StrategyContext};
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::indicators::Indicator;

    /// Seconds per day; synthetic daily candles are spaced by this.
    const DAY_SECS: i64 = 86_400;

    /// Goes short on the first bar (when flat) and never exits.
    #[derive(Clone)]
    struct EnterShortHold;

    impl Strategy for EnterShortHold {
        fn name(&self) -> &str {
            "Enter Short Hold"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            Vec::new()
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            let opening_bar = ctx.index == 0 && !ctx.has_position();
            if opening_bar {
                Signal::short(ctx.timestamp(), ctx.close())
            } else {
                Signal::hold()
            }
        }
    }

    /// Enters long on bar `entry_idx` and exits on bar `exit_idx`.
    #[derive(Clone)]
    struct TimedLongStrategy {
        entry_idx: usize,
        exit_idx: usize,
    }

    impl Strategy for TimedLongStrategy {
        fn name(&self) -> &str {
            "Timed Long"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            Vec::new()
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            let holding = ctx.has_position();
            if ctx.index == self.entry_idx && !holding {
                Signal::long(ctx.timestamp(), ctx.close())
            } else if ctx.index == self.exit_idx && holding {
                Signal::exit(ctx.timestamp(), ctx.close())
            } else {
                Signal::hold()
            }
        }
    }

    /// Synthetic candle with open == close == `price` and a +/-0.5% range.
    fn flat_candle(timestamp: i64, price: f64) -> Candle {
        Candle {
            timestamp,
            open: price,
            high: price * 1.005,
            low: price * 0.995,
            close: price,
            volume: 1_000,
            adj_close: Some(price),
        }
    }

    /// Daily candles starting at timestamp 0.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        (0_i64..)
            .zip(prices)
            .map(|(day, &price)| flat_candle(day * DAY_SECS, price))
            .collect()
    }

    /// Candles at explicit timestamps (pairs beyond the shorter slice are dropped).
    fn make_candles_with_timestamps(prices: &[f64], timestamps: &[i64]) -> Vec<Candle> {
        prices
            .iter()
            .zip(timestamps)
            .map(|(&price, &ts)| flat_candle(ts, price))
            .collect()
    }

    /// Goes long on the first bar when `enabled`; otherwise always holds.
    #[derive(Clone)]
    struct FirstBarLongElseHold {
        enabled: bool,
    }

    impl Strategy for FirstBarLongElseHold {
        fn name(&self) -> &str {
            "First Bar Long"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            Vec::new()
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            let should_enter = self.enabled && ctx.index == 0 && !ctx.has_position();
            if should_enter {
                Signal::long(ctx.timestamp(), ctx.close())
            } else {
                Signal::hold()
            }
        }
    }

    /// Linear price series: `start`, `start + rate`, ... (`n` values).
    fn trending_prices(n: usize, start: f64, rate: f64) -> Vec<f64> {
        (0..n).map(|step| start + step as f64 * rate).collect()
    }

    #[test]
    fn test_two_symbol_basic() {
        let symbol_data = vec![
            SymbolData::new("AAPL", make_candles(&trending_prices(100, 100.0, 0.5))),
            SymbolData::new("MSFT", make_candles(&trending_prices(100, 50.0, 0.25))),
        ];

        let config = PortfolioConfig::new(
            BacktestConfig::builder()
                .initial_capital(20_000.0)
                .commission_pct(0.0)
                .slippage_pct(0.0)
                .build()
                .unwrap(),
        )
        .max_total_positions(2);

        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |_| SmaCrossover::new(5, 20))
            .unwrap();

        for sym in ["AAPL", "MSFT"] {
            assert!(result.symbols.contains_key(sym));
        }
        assert!(result.final_equity > 0.0);
        assert!(!result.portfolio_equity_curve.is_empty());
    }

    #[test]
    fn test_max_total_positions_respected() {
        let prices = trending_prices(100, 100.0, 1.0);
        let symbol_data = vec![
            SymbolData::new("A", make_candles(&prices)),
            SymbolData::new("B", make_candles(&prices)),
        ];

        let config = PortfolioConfig::new(BacktestConfig::default()).max_total_positions(1);

        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |_| SmaCrossover::new(5, 20))
            .unwrap();

        for snapshot in &result.allocation_history {
            assert!(
                snapshot.positions.len() <= 1,
                "more than 1 position open at timestamp {}",
                snapshot.timestamp
            );
        }
    }

    #[test]
    fn test_equal_weight_allocation() {
        let prices = trending_prices(100, 100.0, 0.5);
        let symbol_data = vec![
            SymbolData::new("X", make_candles(&prices)),
            SymbolData::new("Y", make_candles(&prices)),
        ];

        let config = PortfolioConfig::new(
            BacktestConfig::builder()
                .initial_capital(10_000.0)
                .commission_pct(0.0)
                .slippage_pct(0.0)
                .build()
                .unwrap(),
        )
        .rebalance(RebalanceMode::EqualWeight)
        .max_total_positions(2);

        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |_| SmaCrossover::new(5, 20))
            .unwrap();

        assert!(result.final_equity > 0.0);
    }

    #[test]
    fn test_dividend_credited() {
        let prices = trending_prices(50, 100.0, 0.2);
        let dividends: Vec<Dividend> = [20, 40]
            .iter()
            .map(|&day| Dividend {
                timestamp: day * DAY_SECS,
                amount: 1.0,
            })
            .collect();
        let symbol_data =
            vec![SymbolData::new("DIV", make_candles(&prices)).with_dividends(dividends)];

        let config = PortfolioConfig::new(
            BacktestConfig::builder()
                .commission_pct(0.0)
                .slippage_pct(0.0)
                .build()
                .unwrap(),
        );

        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |_| SmaCrossover::new(5, 20))
            .unwrap();

        let total_div: f64 = result.symbols["DIV"]
            .trades
            .iter()
            .map(|trade| trade.dividend_income)
            .sum();
        assert!(total_div >= 0.0);
    }

    #[test]
    fn test_empty_symbol_data_fails() {
        let outcome = PortfolioEngine::new(PortfolioConfig::default())
            .run::<SmaCrossover, _>(&[], |_| SmaCrossover::new(5, 20));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_short_dividend_is_liability() {
        let candles = make_candles(&[100.0, 100.0, 100.0]);
        let ex_date = candles[1].timestamp;
        let dividends = vec![Dividend {
            timestamp: ex_date,
            amount: 1.0,
        }];
        let symbol_data = vec![SymbolData::new("DIVS", candles).with_dividends(dividends)];

        let config = PortfolioConfig::new(
            BacktestConfig::builder()
                .initial_capital(10_000.0)
                .allow_short(true)
                .commission_pct(0.0)
                .slippage_pct(0.0)
                .build()
                .unwrap(),
        );

        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |_| EnterShortHold)
            .unwrap();

        let trades = &result.symbols["DIVS"].trades;
        assert_eq!(trades.len(), 1);
        assert!(trades[0].dividend_income < 0.0);
        assert!(result.final_equity < 10_000.0);
    }

    #[test]
    fn test_portfolio_time_in_market_uses_union_exposure() {
        let prices = [100.0, 101.0, 102.0, 103.0, 104.0];
        let symbol_data = vec![
            SymbolData::new("A", make_candles(&prices)),
            SymbolData::new("B", make_candles(&prices)),
        ];

        let config = PortfolioConfig::new(
            BacktestConfig::builder()
                .initial_capital(10_000.0)
                .position_size_pct(0.5)
                .commission_pct(0.0)
                .slippage_pct(0.0)
                .close_at_end(false)
                .build()
                .unwrap(),
        )
        .max_total_positions(2);

        // A holds over bars 0..2, B over bars 1..3: union exposure is 3 of 4 days.
        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |sym| {
                let (entry_idx, exit_idx) = if sym == "A" { (0, 2) } else { (1, 3) };
                TimedLongStrategy {
                    entry_idx,
                    exit_idx,
                }
            })
            .unwrap();

        let actual = result.portfolio_metrics.time_in_market_pct;
        assert!(
            (actual - 0.75).abs() < 1e-9,
            "expected 0.75 union exposure, got {actual}"
        );
    }

    #[test]
    fn test_portfolio_marks_open_positions_on_sparse_timestamps() {
        let symbol_data = vec![
            SymbolData::new("A", make_candles_with_timestamps(&[100.0, 110.0], &[0, 2])),
            SymbolData::new(
                "B",
                make_candles_with_timestamps(&[50.0, 50.0, 50.0], &[0, 1, 2]),
            ),
        ];

        let config = PortfolioConfig::new(
            BacktestConfig::builder()
                .initial_capital(10_000.0)
                .position_size_pct(1.0)
                .commission_pct(0.0)
                .slippage_pct(0.0)
                .close_at_end(false)
                .build()
                .unwrap(),
        )
        .max_total_positions(2);

        let result = PortfolioEngine::new(config)
            .run(&symbol_data, |sym| FirstBarLongElseHold {
                enabled: sym == "A",
            })
            .unwrap();

        // A has no candle at t=1, so its position must be carried forward there.
        let snapshot_t1 = result
            .allocation_history
            .iter()
            .find(|snapshot| snapshot.timestamp == 1)
            .expect("snapshot at timestamp 1");
        assert!(
            snapshot_t1.positions.contains_key("A"),
            "open A position should be valued at t=1"
        );
        assert!(
            snapshot_t1.total_equity() > 8_000.0,
            "equity should include carried-forward A valuation, got {}",
            snapshot_t1.total_equity()
        );
    }
}