finance_query/backtesting/portfolio/
config.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::backtesting::config::BacktestConfig;
7use crate::backtesting::error::{BacktestError, Result};
8
9#[non_exhaustive]
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub enum RebalanceMode {
13 #[default]
17 AvailableCapital,
18
19 EqualWeight,
39
40 CustomWeights(HashMap<String, f64>),
45}
46
47#[non_exhaustive]
49#[derive(Debug, Clone, Serialize, Deserialize, Default)]
50pub struct PortfolioConfig {
51 pub base: BacktestConfig,
53
54 pub max_allocation_per_symbol: Option<f64>,
58
59 pub max_total_positions: Option<usize>,
65
66 pub rebalance: RebalanceMode,
68}
69
70impl PortfolioConfig {
71 pub fn new(base: BacktestConfig) -> Self {
73 Self {
74 base,
75 ..Self::default()
76 }
77 }
78
79 pub fn max_allocation_per_symbol(mut self, pct: f64) -> Self {
81 self.max_allocation_per_symbol = Some(pct);
82 self
83 }
84
85 pub fn max_total_positions(mut self, max: usize) -> Self {
87 self.max_total_positions = Some(max);
88 self
89 }
90
91 pub fn rebalance(mut self, mode: RebalanceMode) -> Self {
93 self.rebalance = mode;
94 self
95 }
96
97 pub fn validate(&self, num_symbols: usize) -> Result<()> {
99 self.base.validate()?;
100
101 if let Some(cap) = self.max_allocation_per_symbol
102 && !(0.0..=1.0).contains(&cap)
103 {
104 return Err(BacktestError::invalid_param(
105 "max_allocation_per_symbol",
106 "must be between 0.0 and 1.0",
107 ));
108 }
109
110 if let RebalanceMode::CustomWeights(ref weights) = self.rebalance {
111 for (sym, &w) in weights {
112 if !(0.0..=1.0).contains(&w) {
113 return Err(BacktestError::invalid_param(
114 sym.as_str(),
115 "custom weight must be between 0.0 and 1.0",
116 ));
117 }
118 }
119 }
120
121 if num_symbols == 0 {
122 return Err(BacktestError::invalid_param(
123 "symbol_data",
124 "at least one symbol is required",
125 ));
126 }
127
128 Ok(())
129 }
130
131 pub(crate) fn allocation_target(
136 &self,
137 symbol: &str,
138 available_cash: f64,
139 initial_capital: f64,
140 num_symbols: usize,
141 ) -> f64 {
142 let base = match &self.rebalance {
143 RebalanceMode::AvailableCapital => available_cash * self.base.position_size_pct,
144 RebalanceMode::EqualWeight => {
145 let slots = self
146 .max_total_positions
147 .unwrap_or(num_symbols)
148 .min(num_symbols)
149 .max(1);
150 initial_capital / slots as f64
151 }
152 RebalanceMode::CustomWeights(weights) => {
153 let weight = weights.get(symbol).copied().unwrap_or(0.0);
154 initial_capital * weight
155 }
156 };
157
158 let cap = self
160 .max_allocation_per_symbol
161 .map(|pct| initial_capital * pct)
162 .unwrap_or(f64::MAX);
163
164 base.min(cap).min(available_cash).max(0.0)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two cash amounts agree to within a cent.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 0.01,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_default_config_validates() {
        let config = PortfolioConfig::default();
        assert!(config.validate(1).is_ok());
    }

    #[test]
    fn test_custom_weights_allocation() {
        let mut weights = HashMap::new();
        weights.insert("AAPL".to_string(), 0.5);
        weights.insert("MSFT".to_string(), 0.3);
        let config = PortfolioConfig::default().rebalance(RebalanceMode::CustomWeights(weights));

        assert_close(config.allocation_target("AAPL", 10_000.0, 10_000.0, 2), 5_000.0);
        // A symbol without a weight receives nothing.
        assert_close(config.allocation_target("GOOG", 10_000.0, 10_000.0, 2), 0.0);
    }

    #[test]
    fn test_max_allocation_cap() {
        let config = PortfolioConfig::default()
            .max_allocation_per_symbol(0.3)
            .rebalance(RebalanceMode::EqualWeight)
            .max_total_positions(2);
        // Equal weight alone would give 5,000; the 30% cap lowers it to 3,000.
        assert_close(config.allocation_target("AAPL", 10_000.0, 10_000.0, 2), 3_000.0);
    }

    #[test]
    fn test_equal_weight_slots_capped_by_symbol_count() {
        let config = PortfolioConfig::default()
            .rebalance(RebalanceMode::EqualWeight)
            .max_total_positions(10);
        // Only 4 symbols exist, so capital is split 4 ways rather than 10.
        assert_close(config.allocation_target("AAPL", 10_000.0, 10_000.0, 4), 2_500.0);
    }

    #[test]
    fn test_allocation_clamped_to_available_cash() {
        let config = PortfolioConfig::default().rebalance(RebalanceMode::EqualWeight);
        assert_close(config.allocation_target("AAPL", 1_000.0, 10_000.0, 2), 1_000.0);
    }

    #[test]
    fn test_allocation_never_negative() {
        let config = PortfolioConfig::default().rebalance(RebalanceMode::EqualWeight);
        assert_close(config.allocation_target("AAPL", -500.0, 10_000.0, 2), 0.0);
    }

    #[test]
    fn test_validation_zero_symbols() {
        let config = PortfolioConfig::default();
        assert!(config.validate(0).is_err());
    }

    #[test]
    fn test_validation_invalid_cap() {
        for cap in [1.5, -0.1, f64::NAN] {
            let config = PortfolioConfig::default().max_allocation_per_symbol(cap);
            assert!(config.validate(1).is_err(), "cap {cap} should be rejected");
        }
    }

    #[test]
    fn test_validation_invalid_custom_weight() {
        for weight in [1.2, -0.5, f64::NAN] {
            let weights = HashMap::from([("AAPL".to_string(), weight)]);
            let config =
                PortfolioConfig::default().rebalance(RebalanceMode::CustomWeights(weights));
            assert!(
                config.validate(1).is_err(),
                "weight {weight} should be rejected"
            );
        }
    }

    #[test]
    fn test_validation_custom_weight_boundaries_accepted() {
        let weights = HashMap::from([("AAPL".to_string(), 0.0), ("MSFT".to_string(), 1.0)]);
        let config = PortfolioConfig::default().rebalance(RebalanceMode::CustomWeights(weights));
        assert!(config.validate(2).is_ok());
    }
}