finance_query/backtesting/
config.rs1use std::fmt;
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7
8use super::error::{BacktestError, Result};
9
/// A user-supplied commission model, shared cheaply between clones.
///
/// The wrapped closure receives `(size, price)` for a fill and returns the
/// commission charged for it. It is stored behind an [`Arc`], so cloning a
/// `CommissionFn` only bumps a reference count, and it must be
/// `Send + Sync + 'static` so the wrapper can be shared across threads.
#[derive(Clone)]
pub struct CommissionFn(Arc<dyn Fn(f64, f64) -> f64 + Send + Sync>);

impl CommissionFn {
    /// Wraps `f`, a closure mapping `(size, price)` to a commission amount.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(f64, f64) -> f64 + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(f64, f64) -> f64 + Send + Sync> = Arc::new(f);
        CommissionFn(shared)
    }

    /// Invokes the wrapped closure for a fill of `size` units at `price`.
    #[inline]
    pub(crate) fn call(&self, size: f64, price: f64) -> f64 {
        let model: &(dyn Fn(f64, f64) -> f64 + Send + Sync) = &*self.0;
        model(size, price)
    }
}

impl fmt::Debug for CommissionFn {
    /// Closures have no printable state, so only a fixed placeholder is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("CommissionFn(<closure>)")
    }
}
54
55#[non_exhaustive]
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct BacktestConfig {
77 pub initial_capital: f64,
79
80 pub commission: f64,
82
83 pub commission_pct: f64,
85
86 pub slippage_pct: f64,
88
89 pub position_size_pct: f64,
91
92 pub max_positions: Option<usize>,
94
95 pub allow_short: bool,
97
98 pub min_signal_strength: f64,
100
101 pub stop_loss_pct: Option<f64>,
103
104 pub take_profit_pct: Option<f64>,
106
107 pub close_at_end: bool,
109
110 pub risk_free_rate: f64,
115
116 pub trailing_stop_pct: Option<f64>,
127
128 pub reinvest_dividends: bool,
135
136 pub bars_per_year: f64,
143
144 pub spread_pct: f64,
156
157 pub transaction_tax_pct: f64,
167
168 #[serde(skip)]
181 pub commission_fn: Option<CommissionFn>,
182}
183
184impl Default for BacktestConfig {
185 fn default() -> Self {
186 Self {
187 initial_capital: 10_000.0,
188 commission: 0.0,
189 commission_pct: 0.001, slippage_pct: 0.001, position_size_pct: 1.0, max_positions: Some(1), allow_short: false,
194 min_signal_strength: 0.0,
195 stop_loss_pct: None,
196 take_profit_pct: None,
197 close_at_end: true,
198 risk_free_rate: 0.0,
199 trailing_stop_pct: None,
200 reinvest_dividends: false,
201 bars_per_year: 252.0,
202 spread_pct: 0.0,
203 transaction_tax_pct: 0.0,
204 commission_fn: None,
205 }
206 }
207}
208
209impl BacktestConfig {
210 pub fn zero_cost() -> Self {
215 Self {
216 commission: 0.0,
217 commission_pct: 0.0,
218 slippage_pct: 0.0,
219 spread_pct: 0.0,
220 transaction_tax_pct: 0.0,
221 commission_fn: None,
222 ..Default::default()
223 }
224 }
225
226 pub fn builder() -> BacktestConfigBuilder {
228 BacktestConfigBuilder::default()
229 }
230
231 pub fn validate(&self) -> Result<()> {
233 if self.initial_capital <= 0.0 {
234 return Err(BacktestError::invalid_param(
235 "initial_capital",
236 "must be positive",
237 ));
238 }
239
240 if self.commission < 0.0 {
241 return Err(BacktestError::invalid_param(
242 "commission",
243 "cannot be negative",
244 ));
245 }
246
247 if !(0.0..=1.0).contains(&self.commission_pct) {
248 return Err(BacktestError::invalid_param(
249 "commission_pct",
250 "must be between 0.0 and 1.0",
251 ));
252 }
253
254 if !(0.0..=1.0).contains(&self.slippage_pct) {
255 return Err(BacktestError::invalid_param(
256 "slippage_pct",
257 "must be between 0.0 and 1.0",
258 ));
259 }
260
261 if self.position_size_pct <= 0.0 || self.position_size_pct > 1.0 {
262 return Err(BacktestError::invalid_param(
263 "position_size_pct",
264 "must be between 0.0 (exclusive) and 1.0 (inclusive)",
265 ));
266 }
267
268 if !(0.0..=1.0).contains(&self.min_signal_strength) {
269 return Err(BacktestError::invalid_param(
270 "min_signal_strength",
271 "must be between 0.0 and 1.0",
272 ));
273 }
274
275 if let Some(sl) = self.stop_loss_pct
276 && !(0.0..=1.0).contains(&sl)
277 {
278 return Err(BacktestError::invalid_param(
279 "stop_loss_pct",
280 "must be between 0.0 and 1.0",
281 ));
282 }
283
284 if let Some(tp) = self.take_profit_pct
285 && !(0.0..=1.0).contains(&tp)
286 {
287 return Err(BacktestError::invalid_param(
288 "take_profit_pct",
289 "must be between 0.0 and 1.0",
290 ));
291 }
292
293 if !(0.0..=1.0).contains(&self.risk_free_rate) {
294 return Err(BacktestError::invalid_param(
295 "risk_free_rate",
296 "must be between 0.0 and 1.0",
297 ));
298 }
299
300 if let Some(trail) = self.trailing_stop_pct
301 && !(0.0..=1.0).contains(&trail)
302 {
303 return Err(BacktestError::invalid_param(
304 "trailing_stop_pct",
305 "must be between 0.0 and 1.0",
306 ));
307 }
308
309 if self.bars_per_year <= 0.0 {
310 return Err(BacktestError::invalid_param(
311 "bars_per_year",
312 "must be positive (e.g. 252 for daily, 52 for weekly)",
313 ));
314 }
315
316 if !(0.0..=1.0).contains(&self.spread_pct) {
317 return Err(BacktestError::invalid_param(
318 "spread_pct",
319 "must be between 0.0 and 1.0",
320 ));
321 }
322
323 if !(0.0..=1.0).contains(&self.transaction_tax_pct) {
324 return Err(BacktestError::invalid_param(
325 "transaction_tax_pct",
326 "must be between 0.0 and 1.0",
327 ));
328 }
329
330 Ok(())
331 }
332
333 pub fn calculate_commission(&self, size: f64, price: f64) -> f64 {
342 if let Some(ref f) = self.commission_fn {
343 f.call(size, price)
344 } else {
345 self.commission + (size * price * self.commission_pct)
346 }
347 }
348
349 pub fn apply_entry_slippage(&self, price: f64, is_long: bool) -> f64 {
351 if is_long {
352 price * (1.0 + self.slippage_pct)
353 } else {
354 price * (1.0 - self.slippage_pct)
355 }
356 }
357
358 pub fn apply_exit_slippage(&self, price: f64, is_long: bool) -> f64 {
360 if is_long {
361 price * (1.0 - self.slippage_pct)
362 } else {
363 price * (1.0 + self.slippage_pct)
364 }
365 }
366
367 pub fn apply_entry_spread(&self, price: f64, is_long: bool) -> f64 {
372 let half = self.spread_pct / 2.0;
373 if is_long {
374 price * (1.0 + half)
375 } else {
376 price * (1.0 - half)
377 }
378 }
379
380 pub fn apply_exit_spread(&self, price: f64, is_long: bool) -> f64 {
385 let half = self.spread_pct / 2.0;
386 if is_long {
387 price * (1.0 - half)
388 } else {
389 price * (1.0 + half)
390 }
391 }
392
393 pub fn calculate_transaction_tax(&self, trade_value: f64, is_buy: bool) -> f64 {
401 if is_buy {
402 trade_value * self.transaction_tax_pct
403 } else {
404 0.0
405 }
406 }
407
408 pub fn calculate_position_size(&self, available_capital: f64, price: f64) -> f64 {
421 let capital_to_use = available_capital * self.position_size_pct;
422
423 let adjusted_capital = if self.commission_fn.is_some() {
424 capital_to_use / (1.0 + self.spread_pct + self.transaction_tax_pct)
427 } else {
428 let friction =
433 1.0 + 2.0 * self.commission_pct + self.spread_pct + self.transaction_tax_pct;
434 capital_to_use / friction - 2.0 * self.commission
435 };
436
437 (adjusted_capital / price).max(0.0)
438 }
439}
440
441#[derive(Default)]
443pub struct BacktestConfigBuilder {
444 config: BacktestConfig,
445}
446
447impl BacktestConfigBuilder {
448 pub fn initial_capital(mut self, capital: f64) -> Self {
450 self.config.initial_capital = capital;
451 self
452 }
453
454 pub fn commission(mut self, fee: f64) -> Self {
456 self.config.commission = fee;
457 self
458 }
459
460 pub fn commission_pct(mut self, pct: f64) -> Self {
462 self.config.commission_pct = pct;
463 self
464 }
465
466 pub fn slippage_pct(mut self, pct: f64) -> Self {
468 self.config.slippage_pct = pct;
469 self
470 }
471
472 pub fn position_size_pct(mut self, pct: f64) -> Self {
474 self.config.position_size_pct = pct;
475 self
476 }
477
478 pub fn max_positions(mut self, max: usize) -> Self {
480 self.config.max_positions = Some(max);
481 self
482 }
483
484 pub fn unlimited_positions(mut self) -> Self {
486 self.config.max_positions = None;
487 self
488 }
489
490 pub fn allow_short(mut self, allow: bool) -> Self {
492 self.config.allow_short = allow;
493 self
494 }
495
496 pub fn min_signal_strength(mut self, threshold: f64) -> Self {
498 self.config.min_signal_strength = threshold;
499 self
500 }
501
502 pub fn stop_loss_pct(mut self, pct: f64) -> Self {
504 self.config.stop_loss_pct = Some(pct);
505 self
506 }
507
508 pub fn take_profit_pct(mut self, pct: f64) -> Self {
510 self.config.take_profit_pct = Some(pct);
511 self
512 }
513
514 pub fn close_at_end(mut self, close: bool) -> Self {
516 self.config.close_at_end = close;
517 self
518 }
519
520 pub fn risk_free_rate(mut self, rate: f64) -> Self {
524 self.config.risk_free_rate = rate;
525 self
526 }
527
528 pub fn trailing_stop_pct(mut self, pct: f64) -> Self {
533 self.config.trailing_stop_pct = Some(pct);
534 self
535 }
536
537 pub fn reinvest_dividends(mut self, reinvest: bool) -> Self {
541 self.config.reinvest_dividends = reinvest;
542 self
543 }
544
545 pub fn bars_per_year(mut self, n: f64) -> Self {
553 self.config.bars_per_year = n;
554 self
555 }
556
557 pub fn spread_pct(mut self, pct: f64) -> Self {
563 self.config.spread_pct = pct;
564 self
565 }
566
567 pub fn transaction_tax_pct(mut self, pct: f64) -> Self {
572 self.config.transaction_tax_pct = pct;
573 self
574 }
575
576 pub fn commission_fn<F>(mut self, f: F) -> Self
594 where
595 F: Fn(f64, f64) -> f64 + Send + Sync + 'static,
596 {
597 self.config.commission_fn = Some(CommissionFn::new(f));
598 self
599 }
600
601 pub fn build(self) -> Result<BacktestConfig> {
603 self.config.validate()?;
604 Ok(self.config)
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} within {tol}, got {actual}"
        );
    }

    /// Builds the configuration, panicking if validation rejects it.
    #[track_caller]
    fn build_ok(builder: BacktestConfigBuilder) -> BacktestConfig {
        builder.build().expect("configuration should pass validation")
    }

    /// Reports whether validation rejects the builder's configuration.
    fn is_rejected(builder: BacktestConfigBuilder) -> bool {
        builder.build().is_err()
    }

    /// A configuration whose only non-default setting is a 4 bp spread.
    fn four_bp_spread() -> BacktestConfig {
        build_ok(BacktestConfig::builder().spread_pct(0.0004))
    }

    #[test]
    fn test_default_config() {
        let cfg = BacktestConfig::default();
        assert_eq!(cfg.initial_capital, 10_000.0);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_builder() {
        let cfg = build_ok(
            BacktestConfig::builder()
                .initial_capital(50_000.0)
                .commission_pct(0.002)
                .allow_short(true)
                .stop_loss_pct(0.05)
                .take_profit_pct(0.10),
        );

        assert_eq!(cfg.initial_capital, 50_000.0);
        assert_eq!(cfg.commission_pct, 0.002);
        assert!(cfg.allow_short);
        assert_eq!(cfg.stop_loss_pct, Some(0.05));
        assert_eq!(cfg.take_profit_pct, Some(0.10));
    }

    #[test]
    fn test_validation_failures() {
        assert!(is_rejected(BacktestConfig::builder().initial_capital(-100.0)));
        assert!(is_rejected(BacktestConfig::builder().commission_pct(1.5)));
        assert!(is_rejected(BacktestConfig::builder().stop_loss_pct(2.0)));
    }

    #[test]
    fn test_commission_calculation() {
        let cfg = build_ok(BacktestConfig::builder().commission(5.0).commission_pct(0.01));
        // 5.0 flat + 10 * 100 * 1% = 15.0
        assert_close(cfg.calculate_commission(10.0, 100.0), 15.0, 0.01);
    }

    #[test]
    fn test_slippage() {
        let cfg = build_ok(BacktestConfig::builder().slippage_pct(0.01));

        // Long: pay up on entry, receive less on exit.
        assert_close(cfg.apply_entry_slippage(100.0, true), 101.0, 0.01);
        assert_close(cfg.apply_exit_slippage(100.0, true), 99.0, 0.01);

        // Short: the mirror image.
        assert_close(cfg.apply_entry_slippage(100.0, false), 99.0, 0.01);
        assert_close(cfg.apply_exit_slippage(100.0, false), 101.0, 0.01);
    }

    #[test]
    fn test_position_sizing() {
        let cfg = build_ok(
            BacktestConfig::builder()
                .position_size_pct(0.5)
                .commission_pct(0.0),
        );
        // Half of 10_000 at 100 per unit.
        assert_close(cfg.calculate_position_size(10_000.0, 100.0), 50.0, 0.01);
    }

    #[test]
    fn test_risk_free_rate() {
        let cfg = build_ok(BacktestConfig::builder().risk_free_rate(0.05));
        assert!((cfg.risk_free_rate - 0.05).abs() < f64::EPSILON);

        assert!(is_rejected(BacktestConfig::builder().risk_free_rate(1.5)));
    }

    #[test]
    fn test_trailing_stop() {
        let cfg = build_ok(BacktestConfig::builder().trailing_stop_pct(0.05));
        assert_eq!(cfg.trailing_stop_pct, Some(0.05));

        assert!(is_rejected(BacktestConfig::builder().trailing_stop_pct(1.5)));
    }

    #[test]
    fn test_position_sizing_with_commission() {
        let cfg = build_ok(
            BacktestConfig::builder()
                .position_size_pct(0.5)
                .commission_pct(0.001),
        );
        // Percentage commission is reserved for both legs: 1 + 2 * 0.001.
        let expected = 5000.0 / 1.002 / 100.0;
        assert_close(cfg.calculate_position_size(10_000.0, 100.0), expected, 0.01);
    }

    #[test]
    fn test_position_size_zero_rejected() {
        assert!(is_rejected(BacktestConfig::builder().position_size_pct(0.0)));
    }

    #[test]
    fn test_bars_per_year_validation() {
        let daily = BacktestConfig::default();
        assert!((daily.bars_per_year - 252.0).abs() < f64::EPSILON);
        assert!(daily.validate().is_ok());

        let weekly = build_ok(BacktestConfig::builder().bars_per_year(52.0));
        assert!((weekly.bars_per_year - 52.0).abs() < f64::EPSILON);

        for bad in [0.0, -1.0] {
            assert!(is_rejected(BacktestConfig::builder().bars_per_year(bad)));
        }
    }

    #[test]
    fn test_position_sizing_accounts_for_exit_commission() {
        let comm = 0.01;
        let cfg = build_ok(
            BacktestConfig::builder()
                .commission_pct(comm)
                .position_size_pct(1.0),
        );
        let expected = 10_000.0 / (1.0 + 2.0 * comm) / 100.0;
        assert_close(cfg.calculate_position_size(10_000.0, 100.0), expected, 0.001);
    }

    #[test]
    fn test_position_sizing_flat_commission_reduces_size() {
        let with_flat = build_ok(
            BacktestConfig::builder()
                .commission(10.0)
                .commission_pct(0.0)
                .position_size_pct(1.0),
        )
        .calculate_position_size(10_000.0, 100.0);

        let without_flat = build_ok(
            BacktestConfig::builder()
                .commission_pct(0.0)
                .position_size_pct(1.0),
        )
        .calculate_position_size(10_000.0, 100.0);

        assert!(with_flat < without_flat);

        // The flat fee is reserved once per leg: (10_000 - 2 * 10) / 100.
        let expected = (10_000.0 - 20.0) / 100.0;
        assert_close(with_flat, expected, 0.001);
    }

    #[test]
    fn test_position_sizing_flat_commission_exceeds_capital_returns_zero() {
        let cfg = build_ok(
            BacktestConfig::builder()
                .commission(6_000.0)
                .position_size_pct(1.0),
        );
        assert_eq!(cfg.calculate_position_size(10_000.0, 100.0), 0.0);
    }

    #[test]
    fn test_spread_entry_long() {
        // Half of the 4 bp spread is paid on entry: 100 * (1 + 0.0002).
        assert_close(four_bp_spread().apply_entry_spread(100.0, true), 100.02, 1e-10);
    }

    #[test]
    fn test_spread_exit_long() {
        assert_close(four_bp_spread().apply_exit_spread(100.0, true), 99.98, 1e-10);
    }

    #[test]
    fn test_spread_entry_short() {
        assert_close(four_bp_spread().apply_entry_spread(100.0, false), 99.98, 1e-10);
    }

    #[test]
    fn test_spread_exit_short() {
        assert_close(four_bp_spread().apply_exit_spread(100.0, false), 100.02, 1e-10);
    }

    #[test]
    fn test_spread_zero_is_noop() {
        let cfg = BacktestConfig::default();
        assert_close(cfg.apply_entry_spread(123.45, true), 123.45, 1e-10);
        assert_close(cfg.apply_exit_spread(123.45, false), 123.45, 1e-10);
    }

    #[test]
    fn test_spread_validation() {
        let accepted = |pct: f64| BacktestConfig::builder().spread_pct(pct).build().is_ok();
        assert!(!accepted(1.5));
        assert!(!accepted(-0.01));
        assert!(accepted(0.0));
        assert!(accepted(1.0));
    }

    #[test]
    fn test_transaction_tax_on_buy() {
        let cfg = build_ok(BacktestConfig::builder().transaction_tax_pct(0.005));
        assert_close(cfg.calculate_transaction_tax(10_000.0, true), 50.0, 1e-10);
    }

    #[test]
    fn test_transaction_tax_not_on_sell() {
        let cfg = build_ok(BacktestConfig::builder().transaction_tax_pct(0.005));
        assert_eq!(cfg.calculate_transaction_tax(10_000.0, false), 0.0);
    }

    #[test]
    fn test_transaction_tax_zero_default() {
        assert_eq!(
            BacktestConfig::default().calculate_transaction_tax(100_000.0, true),
            0.0
        );
    }

    #[test]
    fn test_transaction_tax_validation() {
        assert!(is_rejected(BacktestConfig::builder().transaction_tax_pct(1.5)));
        assert!(is_rejected(BacktestConfig::builder().transaction_tax_pct(-0.001)));
    }

    #[test]
    fn test_commission_fn_replaces_flat_and_pct() {
        // Per-unit fee of 0.005 with a 1.00 minimum.
        let cfg = build_ok(
            BacktestConfig::builder().commission_fn(|size, _price| (size * 0.005_f64).max(1.00)),
        );
        // 100 units -> 0.50, lifted to the 1.00 minimum.
        assert_close(cfg.calculate_commission(100.0, 50.0), 1.00, 1e-10);
        // 500 units -> 2.50, above the minimum.
        assert_close(cfg.calculate_commission(500.0, 50.0), 2.50, 1e-10);
    }

    #[test]
    fn test_commission_fn_ignores_flat_and_pct_fields() {
        let cfg = build_ok(
            BacktestConfig::builder()
                .commission(5.0)
                .commission_pct(0.01)
                .commission_fn(|size, price| size * price * 0.0005),
        );
        // Only the closure applies: 10 * 100 * 0.0005 = 0.50.
        assert_close(cfg.calculate_commission(10.0, 100.0), 0.50, 1e-10);
    }

    #[test]
    fn test_commission_fn_fallback_when_none() {
        let cfg = build_ok(BacktestConfig::builder().commission(1.0).commission_pct(0.002));
        // 1.0 flat + 10 * 100 * 0.2% = 3.0
        assert_close(cfg.calculate_commission(10.0, 100.0), 3.0, 1e-10);
    }

    #[test]
    fn test_position_sizing_includes_spread_and_tax() {
        let (spread, tax) = (0.0004, 0.005);
        let cfg = build_ok(
            BacktestConfig::builder()
                .commission_pct(0.0)
                .spread_pct(spread)
                .transaction_tax_pct(tax)
                .position_size_pct(1.0),
        );
        let expected = 10_000.0 / (1.0 + spread + tax) / 100.0;
        assert_close(cfg.calculate_position_size(10_000.0, 100.0), expected, 0.01);
    }

    #[test]
    fn test_zero_cost_clears_new_fields() {
        let cfg = BacktestConfig::zero_cost();
        assert_eq!(cfg.spread_pct, 0.0);
        assert_eq!(cfg.transaction_tax_pct, 0.0);
        assert!(cfg.commission_fn.is_none());
    }
}