1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
/// Comparison applied between a [`Condition`]'s `left` operand and its
/// `right` [`CompareTarget`].
///
/// How each variant is evaluated is decided by the strategy evaluator, which
/// is not part of this module; the notes below follow the variant names.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Operator {
    /// `left` crosses from below the target to above it.
    CrossesAbove,
    /// `left` crosses from above the target to below it.
    CrossesBelow,
    /// `left` is above the target.
    IsAbove,
    /// `left` is below the target.
    IsBelow,
    /// `left` lies between two bounds (presumably paired with
    /// [`CompareTarget::Range`] — confirm against the evaluator).
    IsBetween,
    /// `left` equals the target.
    Equals,
    /// `left` is rising; the `u32` is presumably a look-back length in bars
    /// (TODO confirm).
    IsRising(u32),
    /// `left` is falling; the `u32` is presumably a look-back length in bars
    /// (TODO confirm).
    IsFalling(u32),
}
25
/// Right-hand side of a [`Condition`]: what the `left` indicator is compared
/// against.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub enum CompareTarget {
    /// A constant numeric value.
    Value(f64),
    /// Another indicator, referenced by name (e.g. `"sma_50"`).
    Indicator(String),
    /// A named indicator together with a multiplier (presumably
    /// `indicator * multiplier` — confirm against the evaluator).
    Scaled { indicator: String, multiplier: f64 },
    /// A pair of bounds, presumably `(lower, upper)` for
    /// [`Operator::IsBetween`] (TODO confirm ordering).
    Range(f64, f64),
    /// No right-hand operand, presumably for operators such as
    /// [`Operator::IsRising`] that only look at `left`.
    None,
}
41
42#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
44#[derive(Debug, Clone, PartialEq)]
45pub struct Condition {
46 pub left: String, pub operator: Operator,
48 pub right: CompareTarget,
49}
50
51impl Condition {
52 pub fn new(left: impl Into<String>, operator: Operator, right: CompareTarget) -> Self {
53 Self {
54 left: left.into(),
55 operator,
56 right,
57 }
58 }
59}
60
61#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63#[derive(Debug, Clone, PartialEq)]
64pub enum ConditionGroup {
65 AllOf(Vec<ConditionNode>),
67 AnyOf(Vec<ConditionNode>),
69}
70
71#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
73#[derive(Debug, Clone, PartialEq)]
74pub enum ConditionNode {
75 Condition(Condition),
76 Group(ConditionGroup),
77}
78
/// Stop-loss rule attached to every [`Strategy`].
///
/// The numeric payload is interpreted by the execution engine; the notes
/// below follow the variant names.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum StopLoss {
    /// Stop distance given as a percentage (e.g. `2.0`).
    FixedPercent(f64),
    /// Stop distance given as a multiple of ATR.
    AtrMultiple(f64),
    /// Trailing stop; the payload's unit is not established here (TODO
    /// confirm whether it is a percentage).
    Trailing(f64),
}
90
/// Optional take-profit rule for a [`Strategy`].
///
/// As with [`StopLoss`], the payload is interpreted by the execution engine.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TakeProfit {
    /// Target distance given as a percentage.
    FixedPercent(f64),
    /// Target distance given as a multiple of ATR.
    AtrMultiple(f64),
}
100
/// Deepest level a condition node may sit at, counting a tree's root as
/// depth 0; `validate_condition_node` rejects anything deeper.
const MAX_NESTING_DEPTH: usize = 2;

/// Largest number of direct children a single `AllOf` / `AnyOf` group may
/// hold.
const MAX_CONDITIONS_PER_GROUP: usize = 20;
106
107#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
109#[derive(Debug, Clone, PartialEq)]
110pub struct Strategy {
111 pub name: String,
112 pub timeframe: crate::types::Timeframe,
113 pub entry: ConditionNode,
114 pub exit: Option<ConditionNode>,
115 pub stop_loss: StopLoss,
116 pub take_profit: Option<TakeProfit>,
117 pub max_position_size_pct: f64,
118 pub max_daily_loss_pct: f64,
119 pub max_drawdown_pct: f64,
120 pub max_concurrent_positions: usize,
121}
122
123impl Strategy {
124 pub fn builder(name: impl Into<String>) -> StrategyBuilder {
126 StrategyBuilder {
127 name: name.into(),
128 timeframe: crate::types::Timeframe::D1,
129 entry: None,
130 exit: None,
131 stop_loss: None,
132 take_profit: None,
133 max_position_size_pct: 5.0,
134 max_daily_loss_pct: 2.0,
135 max_drawdown_pct: 10.0,
136 max_concurrent_positions: 1,
137 }
138 }
139}
140
141#[derive(Debug)]
143pub struct StrategyBuilder {
144 name: String,
145 timeframe: crate::types::Timeframe,
146 entry: Option<ConditionNode>,
147 exit: Option<ConditionNode>,
148 stop_loss: Option<StopLoss>,
149 take_profit: Option<TakeProfit>,
150 max_position_size_pct: f64,
151 max_daily_loss_pct: f64,
152 max_drawdown_pct: f64,
153 max_concurrent_positions: usize,
154}
155
156impl StrategyBuilder {
157 pub fn timeframe(mut self, tf: crate::types::Timeframe) -> Self {
158 self.timeframe = tf;
159 self
160 }
161
162 pub fn entry(mut self, condition: ConditionNode) -> Self {
163 self.entry = Some(condition);
164 self
165 }
166
167 pub fn exit(mut self, condition: ConditionNode) -> Self {
168 self.exit = Some(condition);
169 self
170 }
171
172 pub fn stop_loss(mut self, sl: StopLoss) -> Self {
173 self.stop_loss = Some(sl);
174 self
175 }
176
177 pub fn take_profit(mut self, tp: TakeProfit) -> Self {
178 self.take_profit = Some(tp);
179 self
180 }
181
182 pub fn max_position_size_pct(mut self, pct: f64) -> Self {
183 self.max_position_size_pct = pct;
184 self
185 }
186
187 pub fn max_daily_loss_pct(mut self, pct: f64) -> Self {
188 self.max_daily_loss_pct = pct;
189 self
190 }
191
192 pub fn max_drawdown_pct(mut self, pct: f64) -> Self {
193 self.max_drawdown_pct = pct;
194 self
195 }
196
197 pub fn max_concurrent_positions(mut self, count: usize) -> Self {
198 self.max_concurrent_positions = count;
199 self
200 }
201
202 pub fn build(self) -> crate::types::Result<Strategy> {
204 let Some(entry) = self.entry else {
205 return Err(crate::types::MantisError::StrategyValidation(
206 "Strategy must have an entry condition".to_string(),
207 ));
208 };
209
210 let Some(stop_loss) = self.stop_loss else {
211 return Err(crate::types::MantisError::StrategyValidation(
212 "Strategy must have a stop-loss rule".to_string(),
213 ));
214 };
215
216 if self.max_position_size_pct < 0.1 || self.max_position_size_pct > 100.0 {
217 return Err(crate::types::MantisError::InvalidParameter {
218 param: "max_position_size_pct",
219 value: self.max_position_size_pct.to_string(),
220 reason: "must be between 0.1 and 100",
221 });
222 }
223
224 if self.max_daily_loss_pct < 0.1 || self.max_daily_loss_pct > 50.0 {
225 return Err(crate::types::MantisError::InvalidParameter {
226 param: "max_daily_loss_pct",
227 value: self.max_daily_loss_pct.to_string(),
228 reason: "must be between 0.1 and 50",
229 });
230 }
231
232 if self.max_drawdown_pct < 1.0 || self.max_drawdown_pct > 100.0 {
233 return Err(crate::types::MantisError::InvalidParameter {
234 param: "max_drawdown_pct",
235 value: self.max_drawdown_pct.to_string(),
236 reason: "must be between 1 and 100",
237 });
238 }
239
240 if self.max_concurrent_positions == 0 {
241 return Err(crate::types::MantisError::InvalidParameter {
242 param: "max_concurrent_positions",
243 value: "0".to_string(),
244 reason: "must be at least 1",
245 });
246 }
247
248 validate_condition_node(&entry, 0)?;
250 if let Some(exit) = &self.exit {
251 validate_condition_node(exit, 0)?;
252 }
253
254 Ok(Strategy {
255 name: self.name,
256 timeframe: self.timeframe,
257 entry,
258 exit: self.exit,
259 stop_loss,
260 take_profit: self.take_profit,
261 max_position_size_pct: self.max_position_size_pct,
262 max_daily_loss_pct: self.max_daily_loss_pct,
263 max_drawdown_pct: self.max_drawdown_pct,
264 max_concurrent_positions: self.max_concurrent_positions,
265 })
266 }
267}
268
269fn validate_condition_node(node: &ConditionNode, depth: usize) -> crate::types::Result<()> {
271 if depth > MAX_NESTING_DEPTH {
272 return Err(crate::types::MantisError::StrategyValidation(format!(
273 "Condition nesting exceeds maximum depth of {MAX_NESTING_DEPTH}"
274 )));
275 }
276 if let ConditionNode::Group(group) = node {
277 let children = match group {
278 ConditionGroup::AllOf(c) | ConditionGroup::AnyOf(c) => c,
279 };
280 if children.len() > MAX_CONDITIONS_PER_GROUP {
281 return Err(crate::types::MantisError::StrategyValidation(format!(
282 "Condition group exceeds maximum of {MAX_CONDITIONS_PER_GROUP} conditions"
283 )));
284 }
285 for child in children {
286 validate_condition_node(child, depth + 1)?;
287 }
288 }
289 Ok(())
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal leaf condition used throughout the tests.
    fn sample_condition() -> ConditionNode {
        ConditionNode::Condition(Condition::new(
            "sma20",
            Operator::CrossesAbove,
            CompareTarget::Value(100.0),
        ))
    }

    /// A builder that passes validation with all defaults.
    fn valid_builder() -> StrategyBuilder {
        Strategy::builder("test")
            .entry(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
    }

    #[test]
    fn builder_requires_entry() {
        let result = Strategy::builder("test")
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_requires_stop_loss() {
        let result = Strategy::builder("test").entry(sample_condition()).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_validates_position_size() {
        let result = valid_builder().max_position_size_pct(150.0).build();
        assert!(result.is_err());

        let result = valid_builder().max_position_size_pct(0.05).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_validates_daily_loss_bounds() {
        let result = valid_builder().max_daily_loss_pct(51.0).build();
        assert!(result.is_err());

        let result = valid_builder().max_daily_loss_pct(0.05).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_validates_drawdown_bounds() {
        let result = valid_builder().max_drawdown_pct(0.5).build();
        assert!(result.is_err());

        let result = valid_builder().max_drawdown_pct(100.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_accepts_inclusive_risk_bounds() {
        let lower = valid_builder()
            .max_position_size_pct(0.1)
            .max_daily_loss_pct(0.1)
            .max_drawdown_pct(1.0)
            .build();
        assert!(lower.is_ok());

        let upper = valid_builder()
            .max_position_size_pct(100.0)
            .max_daily_loss_pct(50.0)
            .max_drawdown_pct(100.0)
            .build();
        assert!(upper.is_ok());
    }

    #[test]
    fn builder_validates_concurrent_positions() {
        let result = valid_builder().max_concurrent_positions(0).build();
        assert!(result.is_err());

        let result = valid_builder().max_concurrent_positions(1).build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_creates_valid_strategy() {
        let result = valid_builder().build();
        assert!(result.is_ok());
        let strategy = result.unwrap();
        assert_eq!(strategy.name, "test");
    }

    #[test]
    fn builder_applies_defaults() {
        let strategy = valid_builder().build().unwrap();
        assert_eq!(strategy.timeframe, crate::types::Timeframe::D1);
        assert_eq!(strategy.max_position_size_pct, 5.0);
        assert_eq!(strategy.max_daily_loss_pct, 2.0);
        assert_eq!(strategy.max_drawdown_pct, 10.0);
        assert_eq!(strategy.max_concurrent_positions, 1);
        assert!(strategy.exit.is_none());
        assert!(strategy.take_profit.is_none());
    }

    #[test]
    fn builder_rejects_excessive_nesting() {
        // Three nested groups put the leaf at depth 3, one past the limit.
        let leaf = sample_condition();
        let depth2 = ConditionNode::Group(ConditionGroup::AllOf(vec![leaf]));
        let depth1 = ConditionNode::Group(ConditionGroup::AllOf(vec![depth2]));
        let depth0 = ConditionNode::Group(ConditionGroup::AllOf(vec![depth1]));

        let result = valid_builder().entry(depth0).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_accepts_valid_nesting() {
        // Two nested groups put the leaf at depth 2, exactly at the limit.
        let leaf = sample_condition();
        let depth1 = ConditionNode::Group(ConditionGroup::AllOf(vec![leaf]));
        let depth0 = ConditionNode::Group(ConditionGroup::AllOf(vec![depth1]));

        let result = valid_builder().entry(depth0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_rejects_oversized_group() {
        let conditions: Vec<ConditionNode> = (0..MAX_CONDITIONS_PER_GROUP + 1)
            .map(|_| sample_condition())
            .collect();
        let group = ConditionNode::Group(ConditionGroup::AllOf(conditions));

        let result = valid_builder().entry(group).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_accepts_group_at_size_limit() {
        let conditions: Vec<ConditionNode> = (0..MAX_CONDITIONS_PER_GROUP)
            .map(|_| sample_condition())
            .collect();
        let group = ConditionNode::Group(ConditionGroup::AnyOf(conditions));

        let result = valid_builder().entry(group).build();
        assert!(result.is_ok());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn strategy_serde_round_trip() {
        let entry = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesAbove,
            CompareTarget::Indicator("sma_50".to_string()),
        ));
        let exit = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesBelow,
            CompareTarget::Indicator("sma_50".to_string()),
        ));
        let strategy = Strategy::builder("round_trip_test")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::AtrMultiple(1.5))
            .max_concurrent_positions(3)
            .build()
            .unwrap();

        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: Strategy = serde_json::from_str(&json).unwrap();

        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn condition_group_nesting() {
        let cond1 = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::IsAbove,
            CompareTarget::Value(100.0),
        ));
        let cond2 = ConditionNode::Condition(Condition::new(
            "rsi_14",
            Operator::IsBelow,
            CompareTarget::Value(70.0),
        ));
        let group = ConditionNode::Group(ConditionGroup::AllOf(vec![cond1, cond2]));
        assert!(matches!(group, ConditionNode::Group(_)));
    }
}
433}