Skip to main content

hyper_strategy/
turtle_pyramid.rs

1use serde::{Deserialize, Serialize};
2
// ---------------------------------------------------------------------------
// #225 — Turtle Trading Pyramid
// ---------------------------------------------------------------------------
6
7/// Direction of a Turtle pyramid position.
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum TradeDirection {
11    Long,
12    Short,
13}
14
15/// A single unit in the Turtle pyramid.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "camelCase")]
18pub struct TurtleUnit {
19    pub entry_price: f64,
20    pub size: f64,
21}
22
23/// State for a Turtle Trading pyramid position (up to `max_units` units).
24///
25/// The classic Turtle system adds a new unit each time price moves 0.5 ATR
26/// from the last entry, up to a maximum of 4 units.  The stop for the entire
27/// position is placed at 2N (2 × entry ATR) below the *lowest* unit entry
28/// for longs, or above the *highest* unit entry for shorts.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(rename_all = "camelCase")]
31pub struct TurtlePyramidState {
32    pub units: Vec<TurtleUnit>,
33    pub direction: TradeDirection,
34    /// ATR at the time of the first entry.
35    pub entry_atr: f64,
36    /// Maximum number of units allowed (default 4).
37    pub max_units: u8,
38}
39
40impl TurtlePyramidState {
41    /// Create a new pyramid with the first unit.
42    pub fn new(direction: TradeDirection, entry_price: f64, size: f64, atr: f64) -> Self {
43        Self {
44            units: vec![TurtleUnit { entry_price, size }],
45            direction,
46            entry_atr: atr,
47            max_units: 4,
48        }
49    }
50
51    /// Whether price has moved 0.5 ATR from the last entry in the
52    /// direction of the trade, indicating we should add another unit.
53    pub fn should_add_unit(&self, current_price: f64) -> bool {
54        if self.is_full() || self.units.is_empty() {
55            return false;
56        }
57        let last_entry = self.units.last().unwrap().entry_price;
58        let threshold = 0.5 * self.entry_atr;
59
60        match self.direction {
61            TradeDirection::Long => current_price >= last_entry + threshold,
62            TradeDirection::Short => current_price <= last_entry - threshold,
63        }
64    }
65
66    /// Add a new unit to the pyramid.
67    ///
68    /// Does nothing if the pyramid is already full.
69    pub fn add_unit(&mut self, entry_price: f64, size: f64) {
70        if !self.is_full() {
71            self.units.push(TurtleUnit { entry_price, size });
72        }
73    }
74
75    /// Compute the stop price for the entire position.
76    ///
77    /// - **Long**: `lowest_entry − 2 × entry_atr`
78    /// - **Short**: `highest_entry + 2 × entry_atr`
79    pub fn stop_price(&self) -> f64 {
80        if self.units.is_empty() {
81            return 0.0;
82        }
83        let n2 = 2.0 * self.entry_atr;
84
85        match self.direction {
86            TradeDirection::Long => {
87                let lowest = self
88                    .units
89                    .iter()
90                    .map(|u| u.entry_price)
91                    .fold(f64::INFINITY, f64::min);
92                lowest - n2
93            }
94            TradeDirection::Short => {
95                let highest = self
96                    .units
97                    .iter()
98                    .map(|u| u.entry_price)
99                    .fold(f64::NEG_INFINITY, f64::max);
100                highest + n2
101            }
102        }
103    }
104
105    /// Weighted average entry price across all units.
106    pub fn average_entry(&self) -> f64 {
107        if self.units.is_empty() {
108            return 0.0;
109        }
110        let total_size = self.total_size();
111        if total_size == 0.0 {
112            return 0.0;
113        }
114        let weighted_sum: f64 = self.units.iter().map(|u| u.entry_price * u.size).sum();
115        weighted_sum / total_size
116    }
117
118    /// Total size across all units.
119    pub fn total_size(&self) -> f64 {
120        self.units.iter().map(|u| u.size).sum()
121    }
122
123    /// Whether the pyramid is at its maximum number of units.
124    pub fn is_full(&self) -> bool {
125        self.units.len() >= self.max_units as usize
126    }
127}
128
129// ---------------------------------------------------------------------------
130// Tests
131// ---------------------------------------------------------------------------
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Long pyramid whose first unit is one lot at `price`.
    fn long_at(price: f64, atr: f64) -> TurtlePyramidState {
        TurtlePyramidState::new(TradeDirection::Long, price, 1.0, atr)
    }

    /// Short pyramid whose first unit is one lot at `price`.
    fn short_at(price: f64, atr: f64) -> TurtlePyramidState {
        TurtlePyramidState::new(TradeDirection::Short, price, 1.0, atr)
    }

    /// Append a one-lot unit at each price, in order.
    fn add_lots(pyramid: &mut TurtlePyramidState, prices: &[f64]) {
        for &price in prices {
            pyramid.add_unit(price, 1.0);
        }
    }

    // --- Construction ---

    #[test]
    fn test_new_creates_single_unit() {
        let pyramid = long_at(100.0, 5.0);
        assert_eq!(pyramid.units.len(), 1);
        let first = &pyramid.units[0];
        assert_eq!(first.entry_price, 100.0);
        assert_eq!(first.size, 1.0);
        assert_eq!(pyramid.entry_atr, 5.0);
        assert_eq!(pyramid.max_units, 4);
        assert_eq!(pyramid.direction, TradeDirection::Long);
    }

    // --- should_add_unit ---

    #[test]
    fn test_should_add_unit_long_below_threshold() {
        // 0.5 × ATR(10) = 5, so a long needs price >= 105.
        assert!(!long_at(100.0, 10.0).should_add_unit(104.9));
    }

    #[test]
    fn test_should_add_unit_long_at_threshold() {
        assert!(long_at(100.0, 10.0).should_add_unit(105.0));
    }

    #[test]
    fn test_should_add_unit_long_above_threshold() {
        assert!(long_at(100.0, 10.0).should_add_unit(110.0));
    }

    #[test]
    fn test_should_add_unit_short_above_threshold() {
        // A short needs price <= 100 - 5 = 95.
        assert!(!short_at(100.0, 10.0).should_add_unit(95.1));
    }

    #[test]
    fn test_should_add_unit_short_at_threshold() {
        assert!(short_at(100.0, 10.0).should_add_unit(95.0));
    }

    #[test]
    fn test_should_add_unit_short_below_threshold() {
        assert!(short_at(100.0, 10.0).should_add_unit(90.0));
    }

    #[test]
    fn test_should_add_unit_false_when_full() {
        let mut pyramid = long_at(100.0, 10.0);
        add_lots(&mut pyramid, &[105.0, 110.0, 115.0]);
        assert!(pyramid.is_full());
        // The move is large enough, but there is no room for another unit.
        assert!(!pyramid.should_add_unit(120.0));
    }

    #[test]
    fn test_should_add_unit_checks_last_entry_not_first() {
        let mut pyramid = long_at(100.0, 10.0);
        add_lots(&mut pyramid, &[105.0]);
        // The reference is now 105, so the trigger moves up to 110.
        assert!(!pyramid.should_add_unit(109.9));
        assert!(pyramid.should_add_unit(110.0));
    }

    // --- add_unit ---

    #[test]
    fn test_add_unit_increments_count() {
        let mut pyramid = long_at(100.0, 5.0);
        for (expected_len, price) in [(2, 102.5), (3, 105.0), (4, 107.5)] {
            pyramid.add_unit(price, 1.0);
            assert_eq!(pyramid.units.len(), expected_len);
        }
    }

    #[test]
    fn test_add_unit_noop_when_full() {
        let mut pyramid = long_at(100.0, 5.0);
        add_lots(&mut pyramid, &[102.5, 105.0, 107.5]);
        assert!(pyramid.is_full());
        pyramid.add_unit(110.0, 1.0); // ignored: already at max_units
        assert_eq!(pyramid.units.len(), 4);
    }

    // --- stop_price ---

    #[test]
    fn test_stop_price_long_single_unit() {
        // 100 - 2 × 5 = 90
        assert_eq!(long_at(100.0, 5.0).stop_price(), 90.0);
    }

    #[test]
    fn test_stop_price_long_multiple_units() {
        let mut pyramid = long_at(100.0, 5.0);
        add_lots(&mut pyramid, &[105.0, 110.0]);
        // Anchored to the lowest entry (100): 100 - 10 = 90.
        assert_eq!(pyramid.stop_price(), 90.0);
    }

    #[test]
    fn test_stop_price_short_single_unit() {
        // 100 + 2 × 5 = 110
        assert_eq!(short_at(100.0, 5.0).stop_price(), 110.0);
    }

    #[test]
    fn test_stop_price_short_multiple_units() {
        let mut pyramid = short_at(100.0, 5.0);
        add_lots(&mut pyramid, &[95.0, 90.0]);
        // Anchored to the highest entry (100): 100 + 10 = 110.
        assert_eq!(pyramid.stop_price(), 110.0);
    }

    #[test]
    fn test_stop_price_empty_units() {
        let mut pyramid = long_at(100.0, 5.0);
        pyramid.units.clear();
        assert_eq!(pyramid.stop_price(), 0.0);
    }

    // --- average_entry ---

    #[test]
    fn test_average_entry_single_unit() {
        assert_eq!(long_at(100.0, 5.0).average_entry(), 100.0);
    }

    #[test]
    fn test_average_entry_equal_sizes() {
        let mut pyramid = long_at(100.0, 5.0);
        pyramid.add_unit(110.0, 1.0);
        // (100 + 110) / 2 = 105
        assert!((pyramid.average_entry() - 105.0).abs() < 1e-10);
    }

    #[test]
    fn test_average_entry_different_sizes() {
        let mut pyramid = TurtlePyramidState::new(TradeDirection::Long, 100.0, 2.0, 5.0);
        pyramid.add_unit(110.0, 1.0);
        // (100 × 2 + 110 × 1) / 3 = 310 / 3
        assert!((pyramid.average_entry() - 310.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_average_entry_empty() {
        let mut pyramid = long_at(100.0, 5.0);
        pyramid.units.clear();
        assert_eq!(pyramid.average_entry(), 0.0);
    }

    #[test]
    fn test_average_entry_zero_sizes() {
        let mut pyramid = TurtlePyramidState::new(TradeDirection::Long, 100.0, 0.0, 5.0);
        pyramid.add_unit(110.0, 0.0);
        assert_eq!(pyramid.average_entry(), 0.0);
    }

    // --- total_size ---

    #[test]
    fn test_total_size() {
        let mut pyramid = long_at(100.0, 5.0);
        assert_eq!(pyramid.total_size(), 1.0);
        pyramid.add_unit(105.0, 2.0);
        assert_eq!(pyramid.total_size(), 3.0);
        pyramid.add_unit(110.0, 0.5);
        assert_eq!(pyramid.total_size(), 3.5);
    }

    // --- is_full ---

    #[test]
    fn test_is_full() {
        let mut pyramid = long_at(100.0, 5.0);
        for price in [105.0, 110.0, 115.0] {
            assert!(!pyramid.is_full());
            pyramid.add_unit(price, 1.0);
        }
        assert!(pyramid.is_full());
    }

    #[test]
    fn test_is_full_custom_max_units() {
        let mut pyramid = long_at(100.0, 5.0);
        pyramid.max_units = 2;
        assert!(!pyramid.is_full());
        pyramid.add_unit(105.0, 1.0);
        assert!(pyramid.is_full());
    }

    // --- Full pyramid scenario ---

    #[test]
    fn test_full_pyramid_scenario_long() {
        // ATR = 10: units every 5 points, stop 20 below the lowest entry.
        let mut pyramid = long_at(100.0, 10.0);
        for price in [105.0, 110.0, 115.0] {
            assert!(pyramid.should_add_unit(price));
            pyramid.add_unit(price, 1.0);
        }

        assert!(pyramid.is_full());
        assert_eq!(pyramid.total_size(), 4.0);
        assert_eq!(pyramid.average_entry(), 107.5); // (100+105+110+115)/4
        assert_eq!(pyramid.stop_price(), 80.0); // 100 - 2*10
    }

    #[test]
    fn test_full_pyramid_scenario_short() {
        let mut pyramid = short_at(200.0, 10.0);
        for price in [195.0, 190.0, 185.0] {
            assert!(pyramid.should_add_unit(price));
            pyramid.add_unit(price, 1.0);
        }

        assert!(pyramid.is_full());
        assert_eq!(pyramid.total_size(), 4.0);
        assert_eq!(pyramid.average_entry(), 192.5); // (200+195+190+185)/4
        assert_eq!(pyramid.stop_price(), 220.0); // 200 + 2*10
    }

    // --- Serialization ---

    #[test]
    fn test_serialization_roundtrip() {
        let mut original = long_at(100.0, 5.0);
        original.add_unit(102.5, 1.5);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TurtlePyramidState = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.units.len(), 2);
        assert_eq!(decoded.direction, TradeDirection::Long);
        assert_eq!(decoded.entry_atr, 5.0);
        assert_eq!(decoded.max_units, 4);
        let second = &decoded.units[1];
        assert_eq!(second.entry_price, 102.5);
        assert_eq!(second.size, 1.5);
    }

    #[test]
    fn test_direction_serialization() {
        let cases = [
            (TradeDirection::Long, "\"long\""),
            (TradeDirection::Short, "\"short\""),
        ];
        for (direction, expected) in cases {
            assert_eq!(serde_json::to_string(&direction).unwrap(), expected);
        }
    }
}