// hyper_strategy/grid_trading.rs

use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// #226 — Grid Trading State
// ---------------------------------------------------------------------------

7/// How grid line spacing is calculated.
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum GridSpacing {
11    /// Fixed percentage between consecutive grid lines.
12    Fixed(f64),
13    /// ATR multiplier between consecutive grid lines.
14    AtrBased(f64),
15}
16
17/// Side of a grid line order.
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum GridSide {
21    Buy,
22    Sell,
23}
24
25/// A signal emitted when a grid line is crossed.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(rename_all = "camelCase")]
28pub struct GridSignal {
29    pub price: f64,
30    pub side: GridSide,
31    pub line_index: usize,
32}
33
34/// A single line in the grid.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct GridLine {
38    pub price: f64,
39    pub side: GridSide,
40    pub filled: bool,
41    pub fill_price: Option<f64>,
42}
43
44/// State for a grid trading strategy.
45///
46/// The grid is symmetric around `center_price`.  Lines below the center
47/// are **Buy** orders; lines above are **Sell** orders.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(rename_all = "camelCase")]
50pub struct GridState {
51    pub center_price: f64,
52    pub grid_lines: Vec<GridLine>,
53    pub spacing: GridSpacing,
54    pub active: bool,
55}
56
57impl GridState {
58    /// Create a new symmetric grid around `center`.
59    ///
60    /// `num_lines` is the number of lines *per side* (so total lines = 2 × num_lines).
61    ///
62    /// For `GridSpacing::Fixed(pct)`, each line is spaced `center * pct / 100`
63    /// apart.  For `GridSpacing::AtrBased(mult)`, the spacing is `atr * mult`
64    /// (requires `atr` to be `Some`).
65    pub fn new(center: f64, num_lines: u8, spacing: GridSpacing, atr: Option<f64>) -> Self {
66        let step = match &spacing {
67            GridSpacing::Fixed(pct) => center * pct / 100.0,
68            GridSpacing::AtrBased(mult) => {
69                let atr_val = atr.unwrap_or(0.0);
70                atr_val * mult
71            }
72        };
73
74        let mut grid_lines = Vec::with_capacity(num_lines as usize * 2);
75
76        // Buy lines below center (closest to center first)
77        for i in 1..=num_lines {
78            let price = center - step * i as f64;
79            grid_lines.push(GridLine {
80                price,
81                side: GridSide::Buy,
82                filled: false,
83                fill_price: None,
84            });
85        }
86
87        // Sell lines above center (closest to center first)
88        for i in 1..=num_lines {
89            let price = center + step * i as f64;
90            grid_lines.push(GridLine {
91                price,
92                side: GridSide::Sell,
93                filled: false,
94                fill_price: None,
95            });
96        }
97
98        Self {
99            center_price: center,
100            grid_lines,
101            spacing,
102            active: true,
103        }
104    }
105
106    /// Check which unfilled grid lines have been crossed by the current price
107    /// and mark them as filled.
108    ///
109    /// Returns signals for each newly filled line.
110    pub fn check_fills(&mut self, current_price: f64) -> Vec<GridSignal> {
111        let mut signals = Vec::new();
112
113        for (idx, line) in self.grid_lines.iter_mut().enumerate() {
114            if line.filled {
115                continue;
116            }
117
118            let triggered = match line.side {
119                GridSide::Buy => current_price <= line.price,
120                GridSide::Sell => current_price >= line.price,
121            };
122
123            if triggered {
124                line.filled = true;
125                line.fill_price = Some(current_price);
126                signals.push(GridSignal {
127                    price: line.price,
128                    side: line.side.clone(),
129                    line_index: idx,
130                });
131            }
132        }
133
134        signals
135    }
136
137    /// Whether the price has moved far enough from center to warrant a reset.
138    ///
139    /// Returns `true` if `|current_price − center| / center > threshold_pct / 100`.
140    pub fn should_reset(&self, current_price: f64, threshold_pct: f64) -> bool {
141        if self.center_price == 0.0 {
142            return false;
143        }
144        let deviation = (current_price - self.center_price).abs() / self.center_price;
145        deviation > threshold_pct / 100.0
146    }
147
148    /// Reset the grid around a new center price, clearing all fills.
149    pub fn reset(&mut self, new_center: f64, atr: Option<f64>) {
150        let num_lines = (self.grid_lines.len() / 2) as u8;
151        let new_grid = Self::new(new_center, num_lines, self.spacing.clone(), atr);
152        self.center_price = new_grid.center_price;
153        self.grid_lines = new_grid.grid_lines;
154    }
155
156    /// Count of unfilled grid lines.
157    pub fn unfilled_count(&self) -> usize {
158        self.grid_lines.iter().filter(|l| !l.filled).count()
159    }
160
161    /// Count of filled grid lines.
162    pub fn filled_count(&self) -> usize {
163        self.grid_lines.iter().filter(|l| l.filled).count()
164    }
165
166    /// Total number of grid lines.
167    pub fn total_lines(&self) -> usize {
168        self.grid_lines.len()
169    }
170}
171
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // --- Helpers ---

    /// Prices of every line on `side`, in `grid_lines` order.
    fn prices(g: &GridState, side: GridSide) -> Vec<f64> {
        g.grid_lines
            .iter()
            .filter(|l| l.side == side)
            .map(|l| l.price)
            .collect()
    }

    /// Assert that `actual` matches `expected` element-wise within 1e-10.
    fn assert_all_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < 1e-10,
                "expected {:?}, got {:?}",
                expected,
                actual
            );
        }
    }

    // --- Construction ---

    #[test]
    fn test_new_fixed_spacing() {
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        // 3 buy + 3 sell = 6 lines
        assert_eq!(g.grid_lines.len(), 6);
        assert_eq!(g.center_price, 100.0);
        assert!(g.active);
    }

    #[test]
    fn test_new_fixed_spacing_prices() {
        // center = 100, 1% spacing → step = 1.0
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        assert_all_close(&prices(&g, GridSide::Buy), &[99.0, 98.0, 97.0]);
        assert_all_close(&prices(&g, GridSide::Sell), &[101.0, 102.0, 103.0]);
    }

    #[test]
    fn test_new_atr_spacing() {
        // ATR = 5.0, multiplier = 2.0 → step = 10.0
        let g = GridState::new(100.0, 2, GridSpacing::AtrBased(2.0), Some(5.0));
        assert_eq!(g.grid_lines.len(), 4);
        assert_all_close(&prices(&g, GridSide::Buy), &[90.0, 80.0]);
        assert_all_close(&prices(&g, GridSide::Sell), &[110.0, 120.0]);
    }

    #[test]
    fn test_new_atr_spacing_no_atr_gives_zero_step() {
        let g = GridState::new(100.0, 2, GridSpacing::AtrBased(2.0), None);
        // All lines sit on the center because step = 0.
        assert!(g.grid_lines.iter().all(|l| (l.price - 100.0).abs() < 1e-10));
    }

    #[test]
    fn test_new_zero_lines() {
        let g = GridState::new(100.0, 0, GridSpacing::Fixed(1.0), None);
        assert!(g.grid_lines.is_empty());
    }

    #[test]
    fn test_all_lines_start_unfilled() {
        let g = GridState::new(100.0, 5, GridSpacing::Fixed(1.0), None);
        assert!(g
            .grid_lines
            .iter()
            .all(|l| !l.filled && l.fill_price.is_none()));
    }

    // --- check_fills ---

    #[test]
    fn test_check_fills_buy_triggered() {
        let mut g = GridState::new(100.0, 2, GridSpacing::Fixed(1.0), None);
        // Buy lines at 99 and 98; a drop to 98.5 crosses only 99.
        let signals = g.check_fills(98.5);
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].side, GridSide::Buy);
        assert_all_close(&[signals[0].price], &[99.0]);
    }

    #[test]
    fn test_check_fills_multiple_buys() {
        let mut g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        // A drop to 97 crosses all three buy lines (99, 98, 97).
        let signals = g.check_fills(97.0);
        let buy_count = signals.iter().filter(|s| s.side == GridSide::Buy).count();
        assert_eq!(buy_count, 3);
    }

    #[test]
    fn test_check_fills_sell_triggered() {
        let mut g = GridState::new(100.0, 2, GridSpacing::Fixed(1.0), None);
        // Sell lines at 101 and 102; a rise to 101.5 crosses only 101.
        let signals = g.check_fills(101.5);
        let sell_prices: Vec<f64> = signals
            .iter()
            .filter(|s| s.side == GridSide::Sell)
            .map(|s| s.price)
            .collect();
        assert_all_close(&sell_prices, &[101.0]);
    }

    #[test]
    fn test_check_fills_no_double_fill() {
        let mut g = GridState::new(100.0, 2, GridSpacing::Fixed(1.0), None);
        // Crosses both buy lines (99, 98).
        assert_eq!(g.check_fills(98.0).len(), 2);
        // Both buy lines are already filled and there is no line at 97,
        // so nothing re-triggers.
        assert!(g.check_fills(97.0).is_empty());
    }

    #[test]
    fn test_check_fills_records_fill_price() {
        let mut g = GridState::new(100.0, 1, GridSpacing::Fixed(1.0), None);
        g.check_fills(98.5); // crosses the buy line at 99 with a fill at 98.5
        let buy_line = g
            .grid_lines
            .iter()
            .find(|l| l.side == GridSide::Buy)
            .expect("grid has a buy line");
        assert!(buy_line.filled);
        assert_eq!(buy_line.fill_price, Some(98.5));
    }

    #[test]
    fn test_check_fills_at_center_no_triggers() {
        let mut g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        assert!(g.check_fills(100.0).is_empty());
    }

    // --- should_reset ---

    #[test]
    fn test_should_reset_within_threshold() {
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        // 5% threshold, price at 104 → deviation = 4% < 5%
        assert!(!g.should_reset(104.0, 5.0));
    }

    #[test]
    fn test_should_reset_beyond_threshold() {
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        // 5% threshold, price at 106 → deviation = 6% > 5%
        assert!(g.should_reset(106.0, 5.0));
    }

    #[test]
    fn test_should_reset_below_center() {
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        // Price at 93 → deviation = 7% > 5%
        assert!(g.should_reset(93.0, 5.0));
    }

    #[test]
    fn test_should_reset_at_exact_threshold() {
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        // Exactly 5% must NOT reset (strict `>`).
        assert!(!g.should_reset(105.0, 5.0));
    }

    #[test]
    fn test_should_reset_zero_center() {
        let mut g = GridState::new(100.0, 1, GridSpacing::Fixed(1.0), None);
        g.center_price = 0.0;
        assert!(!g.should_reset(50.0, 5.0));
    }

    // --- reset ---

    #[test]
    fn test_reset_rebuilds_grid() {
        let mut g = GridState::new(100.0, 2, GridSpacing::Fixed(1.0), None);
        g.check_fills(98.0); // fill some lines
        assert!(g.filled_count() > 0);

        g.reset(110.0, None);
        assert_eq!(g.center_price, 110.0);
        assert_eq!(g.unfilled_count(), 4); // everything unfilled again
        assert_eq!(g.filled_count(), 0);

        // New step = 110 * 1% = 1.1
        let buys = prices(&g, GridSide::Buy);
        assert_all_close(&buys[..1], &[110.0 - 1.1]);
    }

    #[test]
    fn test_reset_preserves_num_lines() {
        let mut g = GridState::new(100.0, 5, GridSpacing::Fixed(2.0), None);
        assert_eq!(g.total_lines(), 10);
        g.reset(200.0, None);
        assert_eq!(g.total_lines(), 10);
    }

    // --- unfilled_count / filled_count ---

    #[test]
    fn test_unfilled_count_all_unfilled() {
        let g = GridState::new(100.0, 3, GridSpacing::Fixed(1.0), None);
        assert_eq!(g.unfilled_count(), 6);
        assert_eq!(g.filled_count(), 0);
    }

    #[test]
    fn test_counts_after_fills() {
        let mut g = GridState::new(100.0, 2, GridSpacing::Fixed(1.0), None);
        g.check_fills(98.0); // crosses both buy lines (99, 98)
        assert_eq!(g.filled_count(), 2);
        assert_eq!(g.unfilled_count(), 2); // sell lines still unfilled
    }

    // --- Serialization ---

    #[test]
    fn test_grid_state_serialization_roundtrip() {
        let mut g = GridState::new(100.0, 2, GridSpacing::Fixed(1.5), None);
        g.check_fills(98.0);

        let json = serde_json::to_string(&g).unwrap();
        let parsed: GridState = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.center_price, 100.0);
        assert_eq!(parsed.grid_lines.len(), 4);
        assert!(parsed.active);
        assert_eq!(parsed.spacing, GridSpacing::Fixed(1.5));
        assert_eq!(parsed.filled_count(), g.filled_count());
    }

    #[test]
    fn test_grid_spacing_serialization() {
        for spacing in [GridSpacing::Fixed(1.5), GridSpacing::AtrBased(2.0)] {
            let json = serde_json::to_string(&spacing).unwrap();
            let parsed: GridSpacing = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, spacing);
        }
    }

    #[test]
    fn test_grid_spacing_wire_format() {
        // Externally tagged, snake_case variant names.
        assert_eq!(
            serde_json::to_string(&GridSpacing::Fixed(1.5)).unwrap(),
            r#"{"fixed":1.5}"#
        );
        assert_eq!(
            serde_json::to_string(&GridSpacing::AtrBased(2.0)).unwrap(),
            r#"{"atr_based":2.0}"#
        );
    }

    #[test]
    fn test_grid_side_serialization() {
        assert_eq!(serde_json::to_string(&GridSide::Buy).unwrap(), "\"buy\"");
        assert_eq!(serde_json::to_string(&GridSide::Sell).unwrap(), "\"sell\"");
    }

    #[test]
    fn test_camel_case_field_names() {
        let mut g = GridState::new(100.0, 1, GridSpacing::Fixed(1.0), None);
        let signals = g.check_fills(98.5);

        let state = serde_json::to_value(&g).unwrap();
        for key in ["centerPrice", "gridLines", "spacing", "active"] {
            assert!(state.get(key).is_some(), "missing state key {}", key);
        }

        let line = &state["gridLines"][0];
        for key in ["price", "side", "filled", "fillPrice"] {
            assert!(line.get(key).is_some(), "missing line key {}", key);
        }

        let signal = serde_json::to_value(&signals[0]).unwrap();
        assert_eq!(signal["lineIndex"], 0);
        assert_eq!(signal["side"], "buy");
    }

    // --- Integration scenario ---

    #[test]
    fn test_full_grid_scenario() {
        // Grid at 1000 with 2% spacing, 3 lines per side:
        // step = 1000 * 2 / 100 = 20
        // Buy:  980, 960, 940
        // Sell: 1020, 1040, 1060
        let mut g = GridState::new(1000.0, 3, GridSpacing::Fixed(2.0), None);
        assert_eq!(g.total_lines(), 6);
        assert_eq!(g.unfilled_count(), 6);

        // Price drops to 975 → fills buy at 980
        let signals = g.check_fills(975.0);
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].side, GridSide::Buy);

        // Price drops further to 955 → fills buy at 960
        let signals = g.check_fills(955.0);
        assert_eq!(signals.len(), 1);

        // Price bounces to 1025 → fills sell at 1020
        let signals = g.check_fills(1025.0);
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].side, GridSide::Sell);

        assert_eq!(g.filled_count(), 3);
        assert_eq!(g.unfilled_count(), 3);

        // Price moves far → 10% deviation > 5% threshold
        assert!(g.should_reset(1100.0, 5.0));

        // Reset at the new center
        g.reset(1100.0, None);
        assert_eq!(g.filled_count(), 0);
        assert_eq!(g.center_price, 1100.0);
    }
}