1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum TradeDirection {
11 Long,
12 Short,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "camelCase")]
18pub struct TurtleUnit {
19 pub entry_price: f64,
20 pub size: f64,
21}
22
/// State of a pyramided position: the units opened so far, the trade
/// direction, and the ATR value captured when the position was opened.
///
/// Serialized with camelCase field names (`units`, `direction`, `entryAtr`,
/// `maxUnits`). Field declaration order is also the serialized order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TurtlePyramidState {
    /// Units in the order they were added; the last element is the most recent entry.
    pub units: Vec<TurtleUnit>,
    /// Side of the trade; decides which way price must move to add units and
    /// on which side of the entries the stop sits.
    pub direction: TradeDirection,
    /// ATR recorded at entry. Both the add-unit spacing and the stop distance
    /// are computed as multiples of this value.
    pub entry_atr: f64,
    /// Upper bound on `units.len()`; once reached, further units are not added.
    pub max_units: u8,
}
39
40impl TurtlePyramidState {
41 pub fn new(direction: TradeDirection, entry_price: f64, size: f64, atr: f64) -> Self {
43 Self {
44 units: vec![TurtleUnit { entry_price, size }],
45 direction,
46 entry_atr: atr,
47 max_units: 4,
48 }
49 }
50
51 pub fn should_add_unit(&self, current_price: f64) -> bool {
54 if self.is_full() || self.units.is_empty() {
55 return false;
56 }
57 let last_entry = self.units.last().unwrap().entry_price;
58 let threshold = 0.5 * self.entry_atr;
59
60 match self.direction {
61 TradeDirection::Long => current_price >= last_entry + threshold,
62 TradeDirection::Short => current_price <= last_entry - threshold,
63 }
64 }
65
66 pub fn add_unit(&mut self, entry_price: f64, size: f64) {
70 if !self.is_full() {
71 self.units.push(TurtleUnit { entry_price, size });
72 }
73 }
74
75 pub fn stop_price(&self) -> f64 {
80 if self.units.is_empty() {
81 return 0.0;
82 }
83 let n2 = 2.0 * self.entry_atr;
84
85 match self.direction {
86 TradeDirection::Long => {
87 let lowest = self
88 .units
89 .iter()
90 .map(|u| u.entry_price)
91 .fold(f64::INFINITY, f64::min);
92 lowest - n2
93 }
94 TradeDirection::Short => {
95 let highest = self
96 .units
97 .iter()
98 .map(|u| u.entry_price)
99 .fold(f64::NEG_INFINITY, f64::max);
100 highest + n2
101 }
102 }
103 }
104
105 pub fn average_entry(&self) -> f64 {
107 if self.units.is_empty() {
108 return 0.0;
109 }
110 let total_size = self.total_size();
111 if total_size == 0.0 {
112 return 0.0;
113 }
114 let weighted_sum: f64 = self.units.iter().map(|u| u.entry_price * u.size).sum();
115 weighted_sum / total_size
116 }
117
118 pub fn total_size(&self) -> f64 {
120 self.units.iter().map(|u| u.size).sum()
121 }
122
123 pub fn is_full(&self) -> bool {
125 self.units.len() >= self.max_units as usize
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    // --- construction ---

    #[test]
    fn test_new_creates_single_unit() {
        let p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        assert_eq!(p.units.len(), 1);
        assert_eq!(p.units[0].entry_price, 100.0);
        assert_eq!(p.units[0].size, 1.0);
        assert_eq!(p.entry_atr, 5.0);
        assert_eq!(p.max_units, 4);
        assert_eq!(p.direction, TradeDirection::Long);
    }

    // --- should_add_unit: threshold is 0.5 * entry_atr beyond the last entry ---

    #[test]
    fn test_should_add_unit_long_below_threshold() {
        let p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);
        assert!(!p.should_add_unit(104.9));
    }

    #[test]
    fn test_should_add_unit_long_at_threshold() {
        let p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);
        assert!(p.should_add_unit(105.0));
    }

    #[test]
    fn test_should_add_unit_long_above_threshold() {
        let p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);
        assert!(p.should_add_unit(110.0));
    }

    #[test]
    fn test_should_add_unit_short_above_threshold() {
        let p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 10.0);
        assert!(!p.should_add_unit(95.1));
    }

    #[test]
    fn test_should_add_unit_short_at_threshold() {
        let p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 10.0);
        assert!(p.should_add_unit(95.0));
    }

    #[test]
    fn test_should_add_unit_short_below_threshold() {
        let p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 10.0);
        assert!(p.should_add_unit(90.0));
    }

    #[test]
    fn test_should_add_unit_false_when_full() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);
        p.add_unit(105.0, 1.0);
        p.add_unit(110.0, 1.0);
        p.add_unit(115.0, 1.0);
        assert!(p.is_full());
        assert!(!p.should_add_unit(120.0));
    }

    #[test]
    fn test_should_add_unit_false_when_empty() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);
        p.units.clear();
        assert!(!p.should_add_unit(1_000.0));
    }

    #[test]
    fn test_should_add_unit_checks_last_entry_not_first() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);
        p.add_unit(105.0, 1.0);
        assert!(!p.should_add_unit(109.9));
        assert!(p.should_add_unit(110.0));
    }

    #[test]
    fn test_should_add_unit_short_checks_last_entry_not_first() {
        let mut p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 10.0);
        p.add_unit(95.0, 1.0);
        assert!(!p.should_add_unit(90.1));
        assert!(p.should_add_unit(90.0));
    }

    // --- add_unit ---

    #[test]
    fn test_add_unit_increments_count() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.add_unit(102.5, 1.0);
        assert_eq!(p.units.len(), 2);
        p.add_unit(105.0, 1.0);
        assert_eq!(p.units.len(), 3);
        p.add_unit(107.5, 1.0);
        assert_eq!(p.units.len(), 4);
    }

    #[test]
    fn test_add_unit_noop_when_full() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.add_unit(102.5, 1.0);
        p.add_unit(105.0, 1.0);
        p.add_unit(107.5, 1.0);
        assert!(p.is_full());
        p.add_unit(110.0, 1.0);
        assert_eq!(p.units.len(), 4);
    }

    // --- stop_price: 2 * entry_atr beyond the lowest (long) / highest (short) entry ---

    #[test]
    fn test_stop_price_long_single_unit() {
        let p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        assert_eq!(p.stop_price(), 90.0);
    }

    #[test]
    fn test_stop_price_long_multiple_units() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.add_unit(105.0, 1.0);
        p.add_unit(110.0, 1.0);
        assert_eq!(p.stop_price(), 90.0);
    }

    #[test]
    fn test_stop_price_short_single_unit() {
        let p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 5.0);
        assert_eq!(p.stop_price(), 110.0);
    }

    #[test]
    fn test_stop_price_short_multiple_units() {
        let mut p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 5.0);
        p.add_unit(95.0, 1.0);
        p.add_unit(90.0, 1.0);
        assert_eq!(p.stop_price(), 110.0);
    }

    #[test]
    fn test_stop_price_empty_units() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.units.clear();
        assert_eq!(p.stop_price(), 0.0);
    }

    // --- average_entry ---

    #[test]
    fn test_average_entry_single_unit() {
        let p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        assert_eq!(p.average_entry(), 100.0);
    }

    #[test]
    fn test_average_entry_equal_sizes() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.add_unit(110.0, 1.0);
        assert!((p.average_entry() - 105.0).abs() < 1e-10);
    }

    #[test]
    fn test_average_entry_different_sizes() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 2.0, 5.0);
        p.add_unit(110.0, 1.0);
        assert!((p.average_entry() - 310.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_average_entry_empty() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.units.clear();
        assert_eq!(p.average_entry(), 0.0);
    }

    #[test]
    fn test_average_entry_zero_sizes() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 0.0, 5.0);
        p.add_unit(110.0, 0.0);
        assert_eq!(p.average_entry(), 0.0);
    }

    // --- total_size / is_full ---

    #[test]
    fn test_total_size() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        assert_eq!(p.total_size(), 1.0);
        p.add_unit(105.0, 2.0);
        assert_eq!(p.total_size(), 3.0);
        p.add_unit(110.0, 0.5);
        assert_eq!(p.total_size(), 3.5);
    }

    #[test]
    fn test_is_full() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        assert!(!p.is_full());
        p.add_unit(105.0, 1.0);
        assert!(!p.is_full());
        p.add_unit(110.0, 1.0);
        assert!(!p.is_full());
        p.add_unit(115.0, 1.0);
        assert!(p.is_full());
    }

    #[test]
    fn test_is_full_custom_max_units() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.max_units = 2;
        assert!(!p.is_full());
        p.add_unit(105.0, 1.0);
        assert!(p.is_full());
    }

    // --- end-to-end scenarios ---

    #[test]
    fn test_full_pyramid_scenario_long() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 10.0);

        assert!(p.should_add_unit(105.0));
        p.add_unit(105.0, 1.0);

        assert!(p.should_add_unit(110.0));
        p.add_unit(110.0, 1.0);

        assert!(p.should_add_unit(115.0));
        p.add_unit(115.0, 1.0);

        assert!(p.is_full());
        assert_eq!(p.total_size(), 4.0);
        assert_eq!(p.average_entry(), 107.5);
        assert_eq!(p.stop_price(), 80.0);
    }

    #[test]
    fn test_full_pyramid_scenario_short() {
        let mut p = TurtlePyramidState::new(TradeDirection::Short, 200.0, 1.0, 10.0);

        assert!(p.should_add_unit(195.0));
        p.add_unit(195.0, 1.0);

        assert!(p.should_add_unit(190.0));
        p.add_unit(190.0, 1.0);

        assert!(p.should_add_unit(185.0));
        p.add_unit(185.0, 1.0);

        assert!(p.is_full());
        assert_eq!(p.total_size(), 4.0);
        assert_eq!(p.average_entry(), 192.5);
        assert_eq!(p.stop_price(), 220.0);
    }

    // --- serde wire format ---

    #[test]
    fn test_serialization_roundtrip() {
        let mut p = TurtlePyramidState::new(TradeDirection::Long, 100.0, 1.0, 5.0);
        p.add_unit(102.5, 1.5);

        let json = serde_json::to_string(&p).unwrap();
        let parsed: TurtlePyramidState = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.units.len(), 2);
        assert_eq!(parsed.direction, TradeDirection::Long);
        assert_eq!(parsed.entry_atr, 5.0);
        assert_eq!(parsed.max_units, 4);
        assert_eq!(parsed.units[1].entry_price, 102.5);
        assert_eq!(parsed.units[1].size, 1.5);
    }

    #[test]
    fn test_serialization_uses_camel_case_keys() {
        let p = TurtlePyramidState::new(TradeDirection::Short, 100.0, 1.0, 5.0);
        let value = serde_json::to_value(&p).unwrap();

        assert_eq!(value["entryAtr"].as_f64(), Some(5.0));
        assert_eq!(value["maxUnits"].as_u64(), Some(4));
        assert_eq!(value["direction"].as_str(), Some("short"));
        assert_eq!(value["units"][0]["entryPrice"].as_f64(), Some(100.0));
        assert_eq!(value["units"][0]["size"].as_f64(), Some(1.0));
        assert!(value.get("entry_atr").is_none());
    }

    #[test]
    fn test_deserialize_from_camel_case_json() {
        let json = r#"{"units":[{"entryPrice":50.0,"size":2.0}],"direction":"short","entryAtr":1.5,"maxUnits":3}"#;
        let p: TurtlePyramidState = serde_json::from_str(json).unwrap();

        assert_eq!(p.direction, TradeDirection::Short);
        assert_eq!(p.max_units, 3);
        assert_eq!(p.units.len(), 1);
        assert_eq!(p.units[0].entry_price, 50.0);
        assert_eq!(p.units[0].size, 2.0);
        assert_eq!(p.stop_price(), 53.0);
    }

    #[test]
    fn test_direction_serialization() {
        let long_json = serde_json::to_string(&TradeDirection::Long).unwrap();
        assert_eq!(long_json, "\"long\"");
        let short_json = serde_json::to_string(&TradeDirection::Short).unwrap();
        assert_eq!(short_json, "\"short\"");
    }
}