Skip to main content

mantis_ta/strategy/
types.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
/// Relation tested between a condition's left-hand indicator and its target.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    /// The left side crosses above the right side.
    CrossesAbove,
    /// The left side crosses below the right side.
    CrossesBelow,
    /// The left side is strictly above the right side.
    IsAbove,
    /// The left side is strictly below the right side.
    IsBelow,
    /// The left side lies between a lower and an upper bound.
    IsBetween,
    /// The left side equals the right side (within epsilon).
    Equals,
    /// The left side is rising over N bars (current > value N bars ago).
    IsRising(u32),
    /// The left side is falling over N bars (current < value N bars ago).
    IsFalling(u32),
}
25
/// What the left-hand indicator of a condition is compared against.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CompareTarget {
    /// A fixed scalar value.
    Value(f64),
    /// The output of another indicator, referenced by name.
    Indicator(String),
    /// An indicator's output multiplied by a factor (e.g., ATR * 2.0).
    Scaled {
        indicator: String,
        multiplier: f64,
    },
    /// A range of values, given as (lower, upper).
    Range(f64, f64),
    /// No target at all; used for unary operators like IsRising/IsFalling.
    None,
}
41
42/// A single condition: left indicator, operator, right target.
43#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
44#[derive(Debug, Clone, PartialEq)]
45pub struct Condition {
46    pub left: String, // indicator name/id
47    pub operator: Operator,
48    pub right: CompareTarget,
49}
50
51impl Condition {
52    pub fn new(left: impl Into<String>, operator: Operator, right: CompareTarget) -> Self {
53        Self {
54            left: left.into(),
55            operator,
56            right,
57        }
58    }
59}
60
61/// Logical grouping of conditions.
62#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63#[derive(Debug, Clone, PartialEq)]
64pub enum ConditionGroup {
65    /// All sub-conditions must be true
66    AllOf(Vec<ConditionNode>),
67    /// Any sub-condition must be true
68    AnyOf(Vec<ConditionNode>),
69}
70
71/// A node in the condition tree.
72#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
73#[derive(Debug, Clone, PartialEq)]
74pub enum ConditionNode {
75    Condition(Condition),
76    Group(ConditionGroup),
77}
78
/// How the stop-loss level of a strategy is defined.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StopLoss {
    /// A fixed percentage below the entry.
    FixedPercent(f64),
    /// A multiple of ATR below the entry.
    AtrMultiple(f64),
    /// Trailing stop: a fixed percentage below the highest price.
    Trailing(f64),
}
90
/// How the take-profit level of a strategy is defined.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TakeProfit {
    /// A fixed percentage above the entry.
    FixedPercent(f64),
    /// A multiple of ATR above the entry.
    AtrMultiple(f64),
}
100
/// Deepest level a condition node may sit at (SPEC §5.3).
///
/// The root of a condition tree is depth 0; validation rejects any node whose
/// depth is greater than this value.
const MAX_NESTING_DEPTH: usize = 2;

/// Largest number of direct children a single condition group may hold
/// (SPEC §5.3); validation rejects groups with more children than this.
const MAX_CONDITIONS_PER_GROUP: usize = 20;
106
107/// A trading strategy composed of conditions and risk rules.
108#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
109#[derive(Debug, Clone, PartialEq)]
110pub struct Strategy {
111    pub name: String,
112    pub timeframe: crate::types::Timeframe,
113    pub entry: ConditionNode,
114    pub exit: Option<ConditionNode>,
115    pub stop_loss: StopLoss,
116    pub take_profit: Option<TakeProfit>,
117    pub max_position_size_pct: f64,
118    pub max_daily_loss_pct: f64,
119    pub max_drawdown_pct: f64,
120    pub max_concurrent_positions: usize,
121}
122
123impl Strategy {
124    /// Create a new strategy builder.
125    pub fn builder(name: impl Into<String>) -> StrategyBuilder {
126        StrategyBuilder {
127            name: name.into(),
128            timeframe: crate::types::Timeframe::D1,
129            entry: None,
130            exit: None,
131            stop_loss: None,
132            take_profit: None,
133            max_position_size_pct: 5.0,
134            max_daily_loss_pct: 2.0,
135            max_drawdown_pct: 10.0,
136            max_concurrent_positions: 1,
137        }
138    }
139}
140
141/// Fluent builder for constructing strategies with validation.
142#[derive(Debug)]
143pub struct StrategyBuilder {
144    name: String,
145    timeframe: crate::types::Timeframe,
146    entry: Option<ConditionNode>,
147    exit: Option<ConditionNode>,
148    stop_loss: Option<StopLoss>,
149    take_profit: Option<TakeProfit>,
150    max_position_size_pct: f64,
151    max_daily_loss_pct: f64,
152    max_drawdown_pct: f64,
153    max_concurrent_positions: usize,
154}
155
156impl StrategyBuilder {
157    pub fn timeframe(mut self, tf: crate::types::Timeframe) -> Self {
158        self.timeframe = tf;
159        self
160    }
161
162    pub fn entry(mut self, condition: ConditionNode) -> Self {
163        self.entry = Some(condition);
164        self
165    }
166
167    pub fn exit(mut self, condition: ConditionNode) -> Self {
168        self.exit = Some(condition);
169        self
170    }
171
172    pub fn stop_loss(mut self, sl: StopLoss) -> Self {
173        self.stop_loss = Some(sl);
174        self
175    }
176
177    pub fn take_profit(mut self, tp: TakeProfit) -> Self {
178        self.take_profit = Some(tp);
179        self
180    }
181
182    pub fn max_position_size_pct(mut self, pct: f64) -> Self {
183        self.max_position_size_pct = pct;
184        self
185    }
186
187    pub fn max_daily_loss_pct(mut self, pct: f64) -> Self {
188        self.max_daily_loss_pct = pct;
189        self
190    }
191
192    pub fn max_drawdown_pct(mut self, pct: f64) -> Self {
193        self.max_drawdown_pct = pct;
194        self
195    }
196
197    pub fn max_concurrent_positions(mut self, count: usize) -> Self {
198        self.max_concurrent_positions = count;
199        self
200    }
201
202    /// Build and validate the strategy (SPEC §5.3).
203    pub fn build(self) -> crate::types::Result<Strategy> {
204        let Some(entry) = self.entry else {
205            return Err(crate::types::MantisError::StrategyValidation(
206                "Strategy must have an entry condition".to_string(),
207            ));
208        };
209
210        let Some(stop_loss) = self.stop_loss else {
211            return Err(crate::types::MantisError::StrategyValidation(
212                "Strategy must have a stop-loss rule".to_string(),
213            ));
214        };
215
216        if self.max_position_size_pct < 0.1 || self.max_position_size_pct > 100.0 {
217            return Err(crate::types::MantisError::InvalidParameter {
218                param: "max_position_size_pct",
219                value: self.max_position_size_pct.to_string(),
220                reason: "must be between 0.1 and 100",
221            });
222        }
223
224        if self.max_daily_loss_pct < 0.1 || self.max_daily_loss_pct > 50.0 {
225            return Err(crate::types::MantisError::InvalidParameter {
226                param: "max_daily_loss_pct",
227                value: self.max_daily_loss_pct.to_string(),
228                reason: "must be between 0.1 and 50",
229            });
230        }
231
232        if self.max_drawdown_pct < 1.0 || self.max_drawdown_pct > 100.0 {
233            return Err(crate::types::MantisError::InvalidParameter {
234                param: "max_drawdown_pct",
235                value: self.max_drawdown_pct.to_string(),
236                reason: "must be between 1 and 100",
237            });
238        }
239
240        if self.max_concurrent_positions == 0 {
241            return Err(crate::types::MantisError::InvalidParameter {
242                param: "max_concurrent_positions",
243                value: "0".to_string(),
244                reason: "must be at least 1",
245            });
246        }
247
248        // Validate condition nesting depth and group sizes
249        validate_condition_node(&entry, 0)?;
250        if let Some(exit) = &self.exit {
251            validate_condition_node(exit, 0)?;
252        }
253
254        Ok(Strategy {
255            name: self.name,
256            timeframe: self.timeframe,
257            entry,
258            exit: self.exit,
259            stop_loss,
260            take_profit: self.take_profit,
261            max_position_size_pct: self.max_position_size_pct,
262            max_daily_loss_pct: self.max_daily_loss_pct,
263            max_drawdown_pct: self.max_drawdown_pct,
264            max_concurrent_positions: self.max_concurrent_positions,
265        })
266    }
267}
268
269/// Recursively validate condition nesting depth and group sizes (SPEC §5.3).
270fn validate_condition_node(node: &ConditionNode, depth: usize) -> crate::types::Result<()> {
271    if depth > MAX_NESTING_DEPTH {
272        return Err(crate::types::MantisError::StrategyValidation(format!(
273            "Condition nesting exceeds maximum depth of {MAX_NESTING_DEPTH}"
274        )));
275    }
276    if let ConditionNode::Group(group) = node {
277        let children = match group {
278            ConditionGroup::AllOf(c) | ConditionGroup::AnyOf(c) => c,
279        };
280        if children.len() > MAX_CONDITIONS_PER_GROUP {
281            return Err(crate::types::MantisError::StrategyValidation(format!(
282                "Condition group exceeds maximum of {MAX_CONDITIONS_PER_GROUP} conditions"
283            )));
284        }
285        for child in children {
286            validate_condition_node(child, depth + 1)?;
287        }
288    }
289    Ok(())
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MantisError;

    /// A single leaf condition used as the building block for every tree.
    fn sample_condition() -> ConditionNode {
        ConditionNode::Condition(Condition::new(
            "sma20",
            Operator::CrossesAbove,
            CompareTarget::Value(100.0),
        ))
    }

    /// Helper to build a valid strategy with all mandatory fields.
    fn valid_builder() -> StrategyBuilder {
        Strategy::builder("test")
            .entry(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
    }

    /// Wraps a leaf condition in `levels` single-child `AllOf` groups, so the
    /// leaf ends up at depth `levels` when the result is used as a tree root.
    fn nested_groups(levels: usize) -> ConditionNode {
        (0..levels).fold(sample_condition(), |inner, _| {
            ConditionNode::Group(ConditionGroup::AllOf(vec![inner]))
        })
    }

    /// Asserts that `result` is an `InvalidParameter` error for `expected`,
    /// so a test cannot pass because an unrelated check failed.
    fn assert_invalid_parameter(result: crate::types::Result<Strategy>, expected: &str) {
        match result {
            Err(MantisError::InvalidParameter { param, .. }) => assert_eq!(param, expected),
            other => panic!("expected InvalidParameter for {expected}, got {other:?}"),
        }
    }

    /// Asserts that `result` is a `StrategyValidation` error whose message
    /// mentions `needle`.
    fn assert_validation_error(result: crate::types::Result<Strategy>, needle: &str) {
        match result {
            Err(MantisError::StrategyValidation(message)) => assert!(
                message.contains(needle),
                "message {message:?} does not mention {needle:?}"
            ),
            other => panic!("expected StrategyValidation, got {other:?}"),
        }
    }

    #[test]
    fn builder_requires_entry() {
        let result = Strategy::builder("test")
            .exit(sample_condition())
            .stop_loss(StopLoss::FixedPercent(2.0))
            .build();
        assert_validation_error(result, "entry");
    }

    #[test]
    fn builder_requires_stop_loss() {
        let result = Strategy::builder("test").entry(sample_condition()).build();
        assert_validation_error(result, "stop-loss");
    }

    #[test]
    fn builder_validates_position_size() {
        let result = valid_builder().max_position_size_pct(150.0).build();
        assert_invalid_parameter(result, "max_position_size_pct");

        let result = valid_builder().max_position_size_pct(0.05).build();
        assert_invalid_parameter(result, "max_position_size_pct");
    }

    #[test]
    fn builder_validates_daily_loss_bounds() {
        let result = valid_builder().max_daily_loss_pct(51.0).build();
        assert_invalid_parameter(result, "max_daily_loss_pct");

        let result = valid_builder().max_daily_loss_pct(0.05).build();
        assert_invalid_parameter(result, "max_daily_loss_pct");
    }

    #[test]
    fn builder_validates_drawdown_bounds() {
        let result = valid_builder().max_drawdown_pct(0.5).build();
        assert_invalid_parameter(result, "max_drawdown_pct");

        let result = valid_builder().max_drawdown_pct(101.0).build();
        assert_invalid_parameter(result, "max_drawdown_pct");
    }

    #[test]
    fn builder_rejects_zero_concurrent_positions() {
        let result = valid_builder().max_concurrent_positions(0).build();
        assert_invalid_parameter(result, "max_concurrent_positions");
    }

    #[test]
    fn builder_accepts_boundary_values() {
        // Every range is inclusive at both ends.
        assert!(valid_builder().max_position_size_pct(0.1).build().is_ok());
        assert!(valid_builder().max_position_size_pct(100.0).build().is_ok());
        assert!(valid_builder().max_daily_loss_pct(0.1).build().is_ok());
        assert!(valid_builder().max_daily_loss_pct(50.0).build().is_ok());
        assert!(valid_builder().max_drawdown_pct(1.0).build().is_ok());
        assert!(valid_builder().max_drawdown_pct(100.0).build().is_ok());
    }

    #[test]
    fn builder_creates_valid_strategy() {
        let strategy = valid_builder().build().unwrap();
        assert_eq!(strategy.name, "test");
        assert_eq!(strategy.entry, sample_condition());
        assert_eq!(strategy.exit, None);
        assert_eq!(strategy.stop_loss, StopLoss::FixedPercent(2.0));
        assert_eq!(strategy.take_profit, None);
        assert_eq!(strategy.max_concurrent_positions, 1);
    }

    #[test]
    fn builder_rejects_excessive_nesting() {
        // depth 0: Group -> depth 1: Group -> depth 2: Group -> depth 3: Condition (too deep)
        let result = valid_builder().entry(nested_groups(3)).build();
        assert_validation_error(result, "nesting");
    }

    #[test]
    fn builder_accepts_valid_nesting() {
        // depth 0: Group -> depth 1: Group -> depth 2: Condition (within limit)
        let result = valid_builder().entry(nested_groups(2)).build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_validates_exit_condition() {
        // The exit tree is held to the same limits as the entry tree.
        let result = valid_builder().exit(nested_groups(3)).build();
        assert_validation_error(result, "nesting");
    }

    #[test]
    fn builder_rejects_oversized_group() {
        let conditions: Vec<ConditionNode> = (0..21).map(|_| sample_condition()).collect();
        let group = ConditionNode::Group(ConditionGroup::AllOf(conditions));

        let result = valid_builder().entry(group).build();
        assert_validation_error(result, "group");
    }

    #[test]
    fn builder_accepts_full_group() {
        // Exactly the maximum number of children is still allowed.
        let conditions: Vec<ConditionNode> = (0..20).map(|_| sample_condition()).collect();
        let group = ConditionNode::Group(ConditionGroup::AnyOf(conditions));

        let result = valid_builder().entry(group).build();
        assert!(result.is_ok());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn strategy_serde_round_trip() {
        let entry = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesAbove,
            CompareTarget::Indicator("sma_50".to_string()),
        ));
        let exit = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::CrossesBelow,
            CompareTarget::Indicator("sma_50".to_string()),
        ));
        let strategy = Strategy::builder("round_trip_test")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(2.0))
            .take_profit(TakeProfit::AtrMultiple(1.5))
            .max_concurrent_positions(3)
            .build()
            .unwrap();

        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: Strategy = serde_json::from_str(&json).unwrap();

        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn condition_group_nesting() {
        let cond1 = ConditionNode::Condition(Condition::new(
            "sma_20",
            Operator::IsAbove,
            CompareTarget::Value(100.0),
        ));
        let cond2 = ConditionNode::Condition(Condition::new(
            "rsi_14",
            Operator::IsBelow,
            CompareTarget::Value(70.0),
        ));
        let group = ConditionNode::Group(ConditionGroup::AllOf(vec![cond1, cond2]));
        assert!(matches!(group, ConditionNode::Group(_)));
    }
}