1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
/// Comparison applied between a condition's left-hand indicator and its
/// [`CompareTarget`].
///
/// NOTE(review): evaluation semantics live in the evaluator, which is not in
/// this file; the per-variant notes below are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    /// Presumably: the left side moves from below the target to above it.
    CrossesAbove,
    /// Presumably: the left side moves from above the target to below it.
    CrossesBelow,
    /// Presumably: the left side is currently above the target.
    IsAbove,
    /// Presumably: the left side is currently below the target.
    IsBelow,
    /// Presumably: the left side lies within a [`CompareTarget::Range`].
    IsBetween,
    /// Presumably: the left side equals the target.
    Equals,
    /// Left side rising; the `u32` looks like a bar-count lookback — confirm.
    IsRising(u32),
    /// Left side falling; the `u32` looks like a bar-count lookback — confirm.
    IsFalling(u32),
}
25
/// Right-hand side of a condition: what the left-hand indicator is compared to.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CompareTarget {
    /// A fixed numeric threshold.
    Value(f64),
    /// Another indicator, referenced by name (e.g. `"sma_50"`).
    Indicator(String),
    /// Another indicator scaled by `multiplier` — presumably
    /// `indicator * multiplier`; confirm against the evaluator.
    Scaled { indicator: String, multiplier: f64 },
    /// Two numeric bounds, presumably `(low, high)` for use with
    /// [`Operator::IsBetween`]. Bound ordering is not checked here.
    Range(f64, f64),
    /// No right-hand side — presumably for operators that only inspect the
    /// left side (e.g. `IsRising` / `IsFalling`); confirm.
    None,
}
41
42#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
44#[derive(Debug, Clone, PartialEq)]
45pub struct Condition {
46 pub left: String, pub operator: Operator,
48 pub right: CompareTarget,
49}
50
51impl Condition {
52 pub fn new(left: impl Into<String>, operator: Operator, right: CompareTarget) -> Self {
53 Self {
54 left: left.into(),
55 operator,
56 right,
57 }
58 }
59}
60
61#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63#[derive(Debug, Clone, PartialEq)]
64pub enum ConditionGroup {
65 AllOf(Vec<ConditionNode>),
67 AnyOf(Vec<ConditionNode>),
69}
70
71#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
73#[derive(Debug, Clone, PartialEq)]
74pub enum ConditionNode {
75 Condition(Condition),
76 Group(ConditionGroup),
77}
78
/// Rule describing where a position's protective stop is placed.
///
/// NOTE(review): the meaning of each parameter is inferred from the variant
/// names; confirm against the execution engine.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StopLoss {
    /// Stop a fixed percentage away — presumably from the entry price.
    FixedPercent(f64),
    /// Stop a multiple of ATR away — presumably measured at entry.
    AtrMultiple(f64),
    /// Trailing stop; the value is presumably the trailing distance in percent.
    Trailing(f64),
}
90
/// Rule describing where a position's profit target is placed.
///
/// NOTE(review): the meaning of each parameter is inferred from the variant
/// names; confirm against the execution engine.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TakeProfit {
    /// Target a fixed percentage away — presumably from the entry price.
    FixedPercent(f64),
    /// Target a multiple of ATR away — presumably measured at entry.
    AtrMultiple(f64),
}
100
/// Deepest node level accepted in a condition tree. The root node is level 0
/// and any node below this level is rejected, so at most two levels of groups
/// may sit above a leaf condition.
const MAX_NESTING_DEPTH: usize = 2;

/// Largest number of direct children a single condition group may hold.
const MAX_CONDITIONS_PER_GROUP: usize = 20;
106
107#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
109#[derive(Debug, Clone, PartialEq)]
110pub struct Strategy {
111 pub name: String,
112 pub timeframe: crate::types::Timeframe,
113 pub entry: ConditionNode,
114 pub exit: ConditionNode,
115 pub stop_loss: StopLoss,
116 pub take_profit: TakeProfit,
117 pub max_position_size_pct: f64,
118 pub max_daily_loss_pct: f64,
119 pub max_drawdown_pct: f64,
120 pub max_concurrent_positions: usize,
121}
122
123impl Strategy {
124 pub fn builder(name: impl Into<String>) -> StrategyBuilder {
126 StrategyBuilder {
127 name: name.into(),
128 timeframe: crate::types::Timeframe::D1,
129 entry: None,
130 exit: None,
131 stop_loss: None,
132 take_profit: None,
133 max_position_size_pct: 5.0,
134 max_daily_loss_pct: 2.0,
135 max_drawdown_pct: 10.0,
136 max_concurrent_positions: 1,
137 }
138 }
139}
140
141#[derive(Debug)]
143pub struct StrategyBuilder {
144 name: String,
145 timeframe: crate::types::Timeframe,
146 entry: Option<ConditionNode>,
147 exit: Option<ConditionNode>,
148 stop_loss: Option<StopLoss>,
149 take_profit: Option<TakeProfit>,
150 max_position_size_pct: f64,
151 max_daily_loss_pct: f64,
152 max_drawdown_pct: f64,
153 max_concurrent_positions: usize,
154}
155
156impl StrategyBuilder {
157 pub fn timeframe(mut self, tf: crate::types::Timeframe) -> Self {
158 self.timeframe = tf;
159 self
160 }
161
162 pub fn entry(mut self, condition: ConditionNode) -> Self {
163 self.entry = Some(condition);
164 self
165 }
166
167 pub fn exit(mut self, condition: ConditionNode) -> Self {
168 self.exit = Some(condition);
169 self
170 }
171
172 pub fn stop_loss(mut self, sl: StopLoss) -> Self {
173 self.stop_loss = Some(sl);
174 self
175 }
176
177 pub fn take_profit(mut self, tp: TakeProfit) -> Self {
178 self.take_profit = Some(tp);
179 self
180 }
181
182 pub fn max_position_size_pct(mut self, pct: f64) -> Self {
183 self.max_position_size_pct = pct;
184 self
185 }
186
187 pub fn max_daily_loss_pct(mut self, pct: f64) -> Self {
188 self.max_daily_loss_pct = pct;
189 self
190 }
191
192 pub fn max_drawdown_pct(mut self, pct: f64) -> Self {
193 self.max_drawdown_pct = pct;
194 self
195 }
196
197 pub fn max_concurrent_positions(mut self, count: usize) -> Self {
198 self.max_concurrent_positions = count;
199 self
200 }
201
202 pub fn build(self) -> crate::types::Result<Strategy> {
204 let Some(entry) = self.entry else {
205 return Err(crate::types::MantisError::StrategyValidation(
206 "Strategy must have an entry condition".to_string(),
207 ));
208 };
209
210 let Some(exit) = self.exit else {
211 return Err(crate::types::MantisError::StrategyValidation(
212 "Strategy must have an exit condition".to_string(),
213 ));
214 };
215
216 let Some(stop_loss) = self.stop_loss else {
217 return Err(crate::types::MantisError::StrategyValidation(
218 "Strategy must have a stop-loss rule".to_string(),
219 ));
220 };
221
222 let Some(take_profit) = self.take_profit else {
223 return Err(crate::types::MantisError::StrategyValidation(
224 "Strategy must have a take-profit rule".to_string(),
225 ));
226 };
227
228 if self.max_position_size_pct < 0.1 || self.max_position_size_pct > 100.0 {
229 return Err(crate::types::MantisError::InvalidParameter {
230 param: "max_position_size_pct",
231 value: self.max_position_size_pct.to_string(),
232 reason: "must be between 0.1 and 100",
233 });
234 }
235
236 if self.max_daily_loss_pct < 0.1 || self.max_daily_loss_pct > 50.0 {
237 return Err(crate::types::MantisError::InvalidParameter {
238 param: "max_daily_loss_pct",
239 value: self.max_daily_loss_pct.to_string(),
240 reason: "must be between 0.1 and 50",
241 });
242 }
243
244 if self.max_drawdown_pct < 1.0 || self.max_drawdown_pct > 100.0 {
245 return Err(crate::types::MantisError::InvalidParameter {
246 param: "max_drawdown_pct",
247 value: self.max_drawdown_pct.to_string(),
248 reason: "must be between 1 and 100",
249 });
250 }
251
252 if self.max_concurrent_positions == 0 {
253 return Err(crate::types::MantisError::InvalidParameter {
254 param: "max_concurrent_positions",
255 value: "0".to_string(),
256 reason: "must be at least 1",
257 });
258 }
259
260 validate_condition_node(&entry, 0)?;
262 validate_condition_node(&exit, 0)?;
263
264 Ok(Strategy {
265 name: self.name,
266 timeframe: self.timeframe,
267 entry,
268 exit,
269 stop_loss,
270 take_profit,
271 max_position_size_pct: self.max_position_size_pct,
272 max_daily_loss_pct: self.max_daily_loss_pct,
273 max_drawdown_pct: self.max_drawdown_pct,
274 max_concurrent_positions: self.max_concurrent_positions,
275 })
276 }
277}
278
279fn validate_condition_node(node: &ConditionNode, depth: usize) -> crate::types::Result<()> {
281 if depth > MAX_NESTING_DEPTH {
282 return Err(crate::types::MantisError::StrategyValidation(format!(
283 "Condition nesting exceeds maximum depth of {}",
284 MAX_NESTING_DEPTH
285 )));
286 }
287 if let ConditionNode::Group(group) = node {
288 let children = match group {
289 ConditionGroup::AllOf(c) | ConditionGroup::AnyOf(c) => c,
290 };
291 if children.len() > MAX_CONDITIONS_PER_GROUP {
292 return Err(crate::types::MantisError::StrategyValidation(format!(
293 "Condition group exceeds maximum of {} conditions",
294 MAX_CONDITIONS_PER_GROUP
295 )));
296 }
297 for child in children {
298 validate_condition_node(child, depth + 1)?;
299 }
300 }
301 Ok(())
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf used throughout: `sma_20` crossing above the constant 100.0.
    fn sample_condition() -> ConditionNode {
        let leaf = Condition::new("sma_20", Operator::CrossesAbove, CompareTarget::Value(100.0));
        ConditionNode::Condition(leaf)
    }

    /// Builder with all four required parts set and default risk limits, so
    /// each test can perturb a single setting.
    fn valid_builder() -> StrategyBuilder {
        Strategy::builder("test")
            .entry(sample_condition())
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::FixedPercent(5.0))
    }

    /// Wraps `node` in `levels` nested single-child `AllOf` groups.
    fn nest_in_all_of(node: ConditionNode, levels: usize) -> ConditionNode {
        (0..levels).fold(node, |inner, _| {
            ConditionNode::Group(ConditionGroup::AllOf(vec![inner]))
        })
    }

    #[test]
    fn builder_requires_entry() {
        let incomplete = Strategy::builder("test")
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::FixedPercent(5.0));
        assert!(incomplete.build().is_err());
    }

    #[test]
    fn builder_requires_exit() {
        let incomplete = Strategy::builder("test")
            .entry(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::FixedPercent(5.0));
        assert!(incomplete.build().is_err());
    }

    #[test]
    fn builder_requires_stop_loss() {
        let incomplete = Strategy::builder("test")
            .entry(sample_condition())
            .exit(sample_condition())
            .take_profit(TakeProfit::FixedPercent(5.0));
        assert!(incomplete.build().is_err());
    }

    #[test]
    fn builder_requires_take_profit() {
        let incomplete = Strategy::builder("test")
            .entry(sample_condition())
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0));
        assert!(incomplete.build().is_err());
    }

    #[test]
    fn builder_validates_position_size() {
        for pct in [150.0, 0.05] {
            let outcome = valid_builder().max_position_size_pct(pct).build();
            assert!(outcome.is_err(), "max_position_size_pct = {pct} should be rejected");
        }
    }

    #[test]
    fn builder_validates_daily_loss_bounds() {
        for pct in [51.0, 0.05] {
            let outcome = valid_builder().max_daily_loss_pct(pct).build();
            assert!(outcome.is_err(), "max_daily_loss_pct = {pct} should be rejected");
        }
    }

    #[test]
    fn builder_validates_drawdown_bounds() {
        assert!(valid_builder().max_drawdown_pct(0.5).build().is_err());
    }

    #[test]
    fn builder_creates_valid_strategy() {
        let strategy = valid_builder()
            .build()
            .expect("a fully specified builder should build");
        assert_eq!(strategy.name, "test");
    }

    #[test]
    fn builder_rejects_excessive_nesting() {
        // Three group levels put the leaf at depth 3, one past MAX_NESTING_DEPTH.
        let too_deep = nest_in_all_of(sample_condition(), 3);
        assert!(valid_builder().entry(too_deep).build().is_err());
    }

    #[test]
    fn builder_accepts_valid_nesting() {
        // Two group levels leave the leaf at depth 2, exactly at the limit.
        let at_limit = nest_in_all_of(sample_condition(), 2);
        assert!(valid_builder().entry(at_limit).build().is_ok());
    }

    #[test]
    fn builder_rejects_oversized_group() {
        // One child more than MAX_CONDITIONS_PER_GROUP (20).
        let children: Vec<ConditionNode> =
            std::iter::repeat_with(sample_condition).take(21).collect();
        let oversized = ConditionNode::Group(ConditionGroup::AllOf(children));
        assert!(valid_builder().entry(oversized).build().is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn strategy_serde_round_trip() {
        let sma_cross = |operator: Operator| {
            ConditionNode::Condition(Condition::new(
                "sma_20",
                operator,
                CompareTarget::Indicator("sma_50".to_string()),
            ))
        };
        let strategy = Strategy::builder("round_trip_test")
            .entry(sma_cross(Operator::CrossesAbove))
            .exit(sma_cross(Operator::CrossesBelow))
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::AtrMultiple(1.5))
            .max_concurrent_positions(3)
            .build()
            .unwrap();

        let json = serde_json::to_string(&strategy).unwrap();
        let restored: Strategy = serde_json::from_str(&json).unwrap();

        assert_eq!(strategy, restored);
    }

    #[test]
    fn condition_group_nesting() {
        let sma_above = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::IsAbove,
            CompareTarget::Value(100.0),
        ));
        let rsi_below = ConditionNode::Condition(Condition::new(
            "rsi_14",
            Operator::IsBelow,
            CompareTarget::Value(70.0),
        ));
        let group = ConditionNode::Group(ConditionGroup::AllOf(vec![sma_above, rsi_below]));
        assert!(matches!(group, ConditionNode::Group(_)));
    }
}