// mantis_ta/strategy/types.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
/// Comparison operators for condition evaluation.
///
/// Most variants relate the left-hand indicator of a condition to a
/// right-hand target. `IsRising` and `IsFalling` instead carry a look-back
/// length in bars and compare the indicator with its own earlier value.
///
/// The type is `Copy` and implements `Eq` + `Hash`, so an operator can be used
/// directly as a key in hash-based collections.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
    /// Left crosses above right
    CrossesAbove,
    /// Left crosses below right
    CrossesBelow,
    /// Left is strictly above right
    IsAbove,
    /// Left is strictly below right
    IsBelow,
    /// Left is between lower and upper bounds
    IsBetween,
    /// Left equals right (within epsilon)
    Equals,
    /// Left is rising over N bars (current > N bars ago)
    IsRising(u32),
    /// Left is falling over N bars (current < N bars ago)
    IsFalling(u32),
}
25
/// What the left-hand indicator of a condition is compared against.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum CompareTarget {
    /// A constant scalar.
    Value(f64),
    /// The output of another indicator, referenced by name.
    Indicator(String),
    /// An indicator's output multiplied by a constant (e.g., ATR * 2.0).
    Scaled {
        /// Name of the indicator whose output is scaled.
        indicator: String,
        /// Factor applied to the indicator's output.
        multiplier: f64,
    },
    /// A pair of bounds, given as (lower, upper).
    Range(f64, f64),
    /// Nothing to compare against (used for unary operators like IsRising/IsFalling).
    None,
}
41
42/// A single condition: left indicator, operator, right target.
43#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
44#[derive(Debug, Clone, PartialEq)]
45pub struct Condition {
46    pub left: String, // indicator name/id
47    pub operator: Operator,
48    pub right: CompareTarget,
49}
50
51impl Condition {
52    pub fn new(left: impl Into<String>, operator: Operator, right: CompareTarget) -> Self {
53        Self {
54            left: left.into(),
55            operator,
56            right,
57        }
58    }
59}
60
61/// Logical grouping of conditions.
62#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63#[derive(Debug, Clone, PartialEq)]
64pub enum ConditionGroup {
65    /// All sub-conditions must be true
66    AllOf(Vec<ConditionNode>),
67    /// Any sub-condition must be true
68    AnyOf(Vec<ConditionNode>),
69}
70
71/// A node in the condition tree.
72#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
73#[derive(Debug, Clone, PartialEq)]
74pub enum ConditionNode {
75    Condition(Condition),
76    Group(ConditionGroup),
77}
78
/// How the protective stop for a position is placed.
///
/// Each variant carries a single `f64` parameter; its meaning is given on the
/// variant.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StopLoss {
    /// Stop at a fixed percentage below entry.
    FixedPercent(f64),
    /// Stop at an ATR multiple below entry.
    AtrMultiple(f64),
    /// Trailing stop: a fixed percentage below the highest price.
    Trailing(f64),
}
90
/// How the profit target for a position is placed.
///
/// Each variant carries a single `f64` parameter; its meaning is given on the
/// variant.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TakeProfit {
    /// Target at a fixed percentage above entry.
    FixedPercent(f64),
    /// Target at an ATR multiple above entry.
    AtrMultiple(f64),
}
100
/// Maximum nesting depth for condition groups (SPEC §5.3).
///
/// The validator counts the root node as depth 0 and rejects any node whose
/// depth is greater than this value.
const MAX_NESTING_DEPTH: usize = 2;

/// Maximum conditions per group (SPEC §5.3).
///
/// The validator rejects a group whose number of direct children is greater
/// than this value.
const MAX_CONDITIONS_PER_GROUP: usize = 20;
106
107/// A trading strategy composed of conditions and risk rules.
108#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
109#[derive(Debug, Clone, PartialEq)]
110pub struct Strategy {
111    pub name: String,
112    pub timeframe: crate::types::Timeframe,
113    pub entry: ConditionNode,
114    pub exit: ConditionNode,
115    pub stop_loss: StopLoss,
116    pub take_profit: TakeProfit,
117    pub max_position_size_pct: f64,
118    pub max_daily_loss_pct: f64,
119    pub max_drawdown_pct: f64,
120    pub max_concurrent_positions: usize,
121}
122
123impl Strategy {
124    /// Create a new strategy builder.
125    pub fn builder(name: impl Into<String>) -> StrategyBuilder {
126        StrategyBuilder {
127            name: name.into(),
128            timeframe: crate::types::Timeframe::D1,
129            entry: None,
130            exit: None,
131            stop_loss: None,
132            take_profit: None,
133            max_position_size_pct: 5.0,
134            max_daily_loss_pct: 2.0,
135            max_drawdown_pct: 10.0,
136            max_concurrent_positions: 1,
137        }
138    }
139}
140
141/// Fluent builder for constructing strategies with validation.
142#[derive(Debug)]
143pub struct StrategyBuilder {
144    name: String,
145    timeframe: crate::types::Timeframe,
146    entry: Option<ConditionNode>,
147    exit: Option<ConditionNode>,
148    stop_loss: Option<StopLoss>,
149    take_profit: Option<TakeProfit>,
150    max_position_size_pct: f64,
151    max_daily_loss_pct: f64,
152    max_drawdown_pct: f64,
153    max_concurrent_positions: usize,
154}
155
156impl StrategyBuilder {
157    pub fn timeframe(mut self, tf: crate::types::Timeframe) -> Self {
158        self.timeframe = tf;
159        self
160    }
161
162    pub fn entry(mut self, condition: ConditionNode) -> Self {
163        self.entry = Some(condition);
164        self
165    }
166
167    pub fn exit(mut self, condition: ConditionNode) -> Self {
168        self.exit = Some(condition);
169        self
170    }
171
172    pub fn stop_loss(mut self, sl: StopLoss) -> Self {
173        self.stop_loss = Some(sl);
174        self
175    }
176
177    pub fn take_profit(mut self, tp: TakeProfit) -> Self {
178        self.take_profit = Some(tp);
179        self
180    }
181
182    pub fn max_position_size_pct(mut self, pct: f64) -> Self {
183        self.max_position_size_pct = pct;
184        self
185    }
186
187    pub fn max_daily_loss_pct(mut self, pct: f64) -> Self {
188        self.max_daily_loss_pct = pct;
189        self
190    }
191
192    pub fn max_drawdown_pct(mut self, pct: f64) -> Self {
193        self.max_drawdown_pct = pct;
194        self
195    }
196
197    pub fn max_concurrent_positions(mut self, count: usize) -> Self {
198        self.max_concurrent_positions = count;
199        self
200    }
201
202    /// Build and validate the strategy (SPEC §5.3).
203    pub fn build(self) -> crate::types::Result<Strategy> {
204        let Some(entry) = self.entry else {
205            return Err(crate::types::MantisError::StrategyValidation(
206                "Strategy must have an entry condition".to_string(),
207            ));
208        };
209
210        let Some(exit) = self.exit else {
211            return Err(crate::types::MantisError::StrategyValidation(
212                "Strategy must have an exit condition".to_string(),
213            ));
214        };
215
216        let Some(stop_loss) = self.stop_loss else {
217            return Err(crate::types::MantisError::StrategyValidation(
218                "Strategy must have a stop-loss rule".to_string(),
219            ));
220        };
221
222        let Some(take_profit) = self.take_profit else {
223            return Err(crate::types::MantisError::StrategyValidation(
224                "Strategy must have a take-profit rule".to_string(),
225            ));
226        };
227
228        if self.max_position_size_pct < 0.1 || self.max_position_size_pct > 100.0 {
229            return Err(crate::types::MantisError::InvalidParameter {
230                param: "max_position_size_pct",
231                value: self.max_position_size_pct.to_string(),
232                reason: "must be between 0.1 and 100",
233            });
234        }
235
236        if self.max_daily_loss_pct < 0.1 || self.max_daily_loss_pct > 50.0 {
237            return Err(crate::types::MantisError::InvalidParameter {
238                param: "max_daily_loss_pct",
239                value: self.max_daily_loss_pct.to_string(),
240                reason: "must be between 0.1 and 50",
241            });
242        }
243
244        if self.max_drawdown_pct < 1.0 || self.max_drawdown_pct > 100.0 {
245            return Err(crate::types::MantisError::InvalidParameter {
246                param: "max_drawdown_pct",
247                value: self.max_drawdown_pct.to_string(),
248                reason: "must be between 1 and 100",
249            });
250        }
251
252        if self.max_concurrent_positions == 0 {
253            return Err(crate::types::MantisError::InvalidParameter {
254                param: "max_concurrent_positions",
255                value: "0".to_string(),
256                reason: "must be at least 1",
257            });
258        }
259
260        // Validate condition nesting depth and group sizes
261        validate_condition_node(&entry, 0)?;
262        validate_condition_node(&exit, 0)?;
263
264        Ok(Strategy {
265            name: self.name,
266            timeframe: self.timeframe,
267            entry,
268            exit,
269            stop_loss,
270            take_profit,
271            max_position_size_pct: self.max_position_size_pct,
272            max_daily_loss_pct: self.max_daily_loss_pct,
273            max_drawdown_pct: self.max_drawdown_pct,
274            max_concurrent_positions: self.max_concurrent_positions,
275        })
276    }
277}
278
279/// Recursively validate condition nesting depth and group sizes (SPEC §5.3).
280fn validate_condition_node(node: &ConditionNode, depth: usize) -> crate::types::Result<()> {
281    if depth > MAX_NESTING_DEPTH {
282        return Err(crate::types::MantisError::StrategyValidation(format!(
283            "Condition nesting exceeds maximum depth of {}",
284            MAX_NESTING_DEPTH
285        )));
286    }
287    if let ConditionNode::Group(group) = node {
288        let children = match group {
289            ConditionGroup::AllOf(c) | ConditionGroup::AnyOf(c) => c,
290        };
291        if children.len() > MAX_CONDITIONS_PER_GROUP {
292            return Err(crate::types::MantisError::StrategyValidation(format!(
293                "Condition group exceeds maximum of {} conditions",
294                MAX_CONDITIONS_PER_GROUP
295            )));
296        }
297        for child in children {
298            validate_condition_node(child, depth + 1)?;
299        }
300    }
301    Ok(())
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MantisError;

    /// A minimal leaf condition used wherever the test only needs "some" node.
    fn sample_condition() -> ConditionNode {
        ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesAbove,
            CompareTarget::Value(100.0),
        ))
    }

    /// Helper to build a valid strategy with all mandatory fields.
    fn valid_builder() -> StrategyBuilder {
        Strategy::builder("test")
            .entry(sample_condition())
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::FixedPercent(5.0))
    }

    /// Wraps a leaf in `levels` single-child groups, so the leaf ends up at depth `levels`.
    fn nested(levels: usize) -> ConditionNode {
        (0..levels).fold(sample_condition(), |inner, _| {
            ConditionNode::Group(ConditionGroup::AllOf(vec![inner]))
        })
    }

    /// A flat `AllOf` group with `count` leaf conditions.
    fn group_of(count: usize) -> ConditionNode {
        ConditionNode::Group(ConditionGroup::AllOf(
            (0..count).map(|_| sample_condition()).collect(),
        ))
    }

    /// Asserts that `result` is a `StrategyValidation` error whose message mentions `needle`.
    #[track_caller]
    fn assert_validation_error(result: crate::types::Result<Strategy>, needle: &str) {
        match result {
            Err(MantisError::StrategyValidation(msg)) => {
                assert!(msg.contains(needle), "unexpected message: {msg}");
            }
            other => panic!("expected StrategyValidation error, got {other:?}"),
        }
    }

    /// Asserts that `result` is an `InvalidParameter` error for the parameter `expected`.
    #[track_caller]
    fn assert_invalid_parameter(result: crate::types::Result<Strategy>, expected: &str) {
        match result {
            Err(MantisError::InvalidParameter { param, .. }) => assert_eq!(param, expected),
            other => panic!("expected InvalidParameter error, got {other:?}"),
        }
    }

    #[test]
    fn builder_requires_entry() {
        let result = Strategy::builder("test")
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::FixedPercent(5.0))
            .build();
        assert_validation_error(result, "entry");
    }

    #[test]
    fn builder_requires_exit() {
        let result = Strategy::builder("test")
            .entry(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::FixedPercent(5.0))
            .build();
        assert_validation_error(result, "exit");
    }

    #[test]
    fn builder_requires_stop_loss() {
        let result = Strategy::builder("test")
            .entry(sample_condition())
            .exit(sample_condition())
            .take_profit(TakeProfit::FixedPercent(5.0))
            .build();
        assert_validation_error(result, "stop-loss");
    }

    #[test]
    fn builder_requires_take_profit() {
        let result = Strategy::builder("test")
            .entry(sample_condition())
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .build();
        assert_validation_error(result, "take-profit");
    }

    #[test]
    fn builder_validates_position_size() {
        let result = valid_builder().max_position_size_pct(150.0).build();
        assert_invalid_parameter(result, "max_position_size_pct");

        let result = valid_builder().max_position_size_pct(0.05).build();
        assert_invalid_parameter(result, "max_position_size_pct");
    }

    #[test]
    fn builder_validates_daily_loss_bounds() {
        let result = valid_builder().max_daily_loss_pct(51.0).build();
        assert_invalid_parameter(result, "max_daily_loss_pct");

        let result = valid_builder().max_daily_loss_pct(0.05).build();
        assert_invalid_parameter(result, "max_daily_loss_pct");
    }

    #[test]
    fn builder_validates_drawdown_bounds() {
        let result = valid_builder().max_drawdown_pct(0.5).build();
        assert_invalid_parameter(result, "max_drawdown_pct");

        let result = valid_builder().max_drawdown_pct(101.0).build();
        assert_invalid_parameter(result, "max_drawdown_pct");
    }

    #[test]
    fn builder_rejects_zero_concurrent_positions() {
        let result = valid_builder().max_concurrent_positions(0).build();
        assert_invalid_parameter(result, "max_concurrent_positions");
    }

    #[test]
    fn builder_accepts_inclusive_bounds() {
        assert!(valid_builder().max_position_size_pct(0.1).build().is_ok());
        assert!(valid_builder().max_position_size_pct(100.0).build().is_ok());
        assert!(valid_builder().max_daily_loss_pct(0.1).build().is_ok());
        assert!(valid_builder().max_daily_loss_pct(50.0).build().is_ok());
        assert!(valid_builder().max_drawdown_pct(1.0).build().is_ok());
        assert!(valid_builder().max_drawdown_pct(100.0).build().is_ok());
    }

    #[test]
    fn builder_creates_valid_strategy() {
        let result = valid_builder().build();
        assert!(result.is_ok());
        let strategy = result.unwrap();
        assert_eq!(strategy.name, "test");
        assert_eq!(strategy.entry, sample_condition());
        assert_eq!(strategy.exit, sample_condition());
        assert_eq!(strategy.stop_loss, StopLoss::FixedPercent(2.0));
        assert_eq!(strategy.take_profit, TakeProfit::FixedPercent(5.0));
        assert_eq!(strategy.max_concurrent_positions, 1);
    }

    #[test]
    fn builder_rejects_excessive_nesting() {
        // One wrapping group too many pushes the leaf past the depth limit.
        let result = valid_builder().entry(nested(MAX_NESTING_DEPTH + 1)).build();
        assert_validation_error(result, "nesting");
    }

    #[test]
    fn builder_rejects_excessive_nesting_in_exit() {
        let result = valid_builder().exit(nested(MAX_NESTING_DEPTH + 1)).build();
        assert_validation_error(result, "nesting");
    }

    #[test]
    fn builder_accepts_valid_nesting() {
        // The leaf sits exactly at the depth limit.
        let result = valid_builder().entry(nested(MAX_NESTING_DEPTH)).build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_rejects_oversized_group() {
        let result = valid_builder()
            .entry(group_of(MAX_CONDITIONS_PER_GROUP + 1))
            .build();
        assert_validation_error(result, "group");
    }

    #[test]
    fn builder_accepts_full_group() {
        let result = valid_builder()
            .entry(group_of(MAX_CONDITIONS_PER_GROUP))
            .build();
        assert!(result.is_ok());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn strategy_serde_round_trip() {
        let entry = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesAbove,
            CompareTarget::Indicator("sma_50".to_string()),
        ));
        let exit = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesBelow,
            CompareTarget::Indicator("sma_50".to_string()),
        ));
        let strategy = Strategy::builder("round_trip_test")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::AtrMultiple(1.5))
            .max_concurrent_positions(3)
            .build()
            .unwrap();

        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: Strategy = serde_json::from_str(&json).unwrap();

        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn condition_group_nesting() {
        let cond1 = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::IsAbove,
            CompareTarget::Value(100.0),
        ));
        let cond2 = ConditionNode::Condition(Condition::new(
            "rsi_14",
            Operator::IsBelow,
            CompareTarget::Value(70.0),
        ));
        let group = ConditionNode::Group(ConditionGroup::AllOf(vec![cond1, cond2]));
        assert!(matches!(
            &group,
            ConditionNode::Group(ConditionGroup::AllOf(children)) if children.len() == 2
        ));
    }
}