Skip to main content

hyper_strategy/
strategy_config.rs

1use std::collections::HashMap;
2use std::fs;
3use std::path::PathBuf;
4
5use serde::{Deserialize, Serialize};
6
7// ---------------------------------------------------------------------------
8// Data types
9// ---------------------------------------------------------------------------
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(rename_all = "camelCase")]
13pub struct StrategyGroupSummary {
14    pub id: String,
15    pub name: String,
16    pub symbol: String,
17    pub is_active: bool,
18    pub trading_mode: String,
19    pub today_pnl: f64,
20    pub agent_loop_status: String,
21}
22
23#[derive(Serialize, Deserialize, Clone, Debug)]
24#[serde(rename_all = "camelCase")]
25pub struct StrategyGroup {
26    pub id: String,
27    pub name: String,
28    pub vault_address: Option<String>,
29    pub is_active: bool,
30    pub created_at: String,
31    pub symbol: String,
32    pub interval_secs: u64,
33    pub regime_rules: Vec<RegimeRule>,
34    pub default_regime: String,
35    pub hysteresis: HysteresisConfig,
36    pub playbooks: HashMap<String, Playbook>,
37}
38
39#[derive(Serialize, Deserialize, Clone, Debug)]
40#[serde(rename_all = "camelCase")]
41pub struct RegimeRule {
42    pub regime: String,
43    pub conditions: Vec<TaRule>,
44    pub priority: u32,
45}
46
47#[derive(Serialize, Deserialize, Clone, Debug)]
48#[serde(rename_all = "camelCase")]
49pub struct TaRule {
50    pub indicator: String,
51    pub params: Vec<f64>,
52    pub condition: String,
53    pub threshold: f64,
54    pub threshold_upper: Option<f64>,
55    pub signal: String,
56    /// Optional explicit action override (e.g. "buy", "sell", "close").
57    /// When absent, the rule executor infers the action from the signal name.
58    #[serde(default, skip_serializing_if = "Option::is_none")]
59    pub action: Option<String>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug)]
63#[serde(rename_all = "camelCase")]
64pub struct Playbook {
65    /// Legacy unified rules list (kept for backward compatibility).
66    #[serde(default)]
67    pub rules: Vec<TaRule>,
68
69    /// Entry-specific rules (new). When non-empty, takes precedence over `rules`.
70    #[serde(default)]
71    pub entry_rules: Vec<TaRule>,
72
73    /// Exit-specific rules (new).
74    #[serde(default)]
75    pub exit_rules: Vec<TaRule>,
76
77    pub system_prompt: String,
78    pub max_position_size: f64,
79    pub stop_loss_pct: Option<f64>,
80    pub take_profit_pct: Option<f64>,
81
82    /// Optional timeout in seconds for the playbook execution.
83    #[serde(default)]
84    pub timeout_secs: Option<u64>,
85
86    /// Optional side hint (e.g. "long", "short").
87    #[serde(default)]
88    pub side: Option<String>,
89}
90
91impl Playbook {
92    /// Returns entry rules: prefers `entry_rules` if non-empty, falls back to `rules`.
93    pub fn effective_entry_rules(&self) -> &[TaRule] {
94        if !self.entry_rules.is_empty() {
95            &self.entry_rules
96        } else {
97            &self.rules
98        }
99    }
100
101    /// Returns exit rules.
102    pub fn effective_exit_rules(&self) -> &[TaRule] {
103        &self.exit_rules
104    }
105}
106
107#[derive(Serialize, Deserialize, Clone, Debug)]
108#[serde(rename_all = "camelCase")]
109pub struct HysteresisConfig {
110    pub min_hold_secs: u64,
111    pub confirmation_count: u32,
112}
113
114impl Default for HysteresisConfig {
115    fn default() -> Self {
116        Self {
117            min_hold_secs: 3600,
118            confirmation_count: 3,
119        }
120    }
121}
122
123// ---------------------------------------------------------------------------
124// Persistence helpers
125// ---------------------------------------------------------------------------
126
127fn strategy_groups_path() -> Option<PathBuf> {
128    dirs::data_dir().map(|d| d.join("hyper-agent").join("strategy_groups.json"))
129}
130
131/// Public accessor for loading strategy groups (used by agent_loop).
132pub fn load_strategy_groups_from_disk_pub() -> Vec<StrategyGroup> {
133    load_strategy_groups_from_disk()
134}
135
136fn load_strategy_groups_from_disk() -> Vec<StrategyGroup> {
137    let path = match strategy_groups_path() {
138        Some(p) if p.exists() => p,
139        _ => return Vec::new(),
140    };
141    let data = match fs::read_to_string(&path) {
142        Ok(d) => d,
143        Err(_) => return Vec::new(),
144    };
145    serde_json::from_str(&data).unwrap_or_default()
146}
147
148pub fn save_strategy_groups_to_disk(groups: &[StrategyGroup]) -> Result<(), String> {
149    let path = strategy_groups_path().ok_or("Could not determine data directory")?;
150    if let Some(parent) = path.parent() {
151        fs::create_dir_all(parent).map_err(|e| format!("Failed to create data dir: {}", e))?;
152    }
153    let json =
154        serde_json::to_string_pretty(groups).map_err(|e| format!("Serialize error: {}", e))?;
155    fs::write(&path, json).map_err(|e| format!("Failed to write strategy_groups file: {}", e))?;
156    Ok(())
157}
158
159// ---------------------------------------------------------------------------
160// Tests
161// ---------------------------------------------------------------------------
162
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Fixtures
    // -----------------------------------------------------------------------

    /// Builds a `TaRule` that carries no explicit `action` override.
    fn rule(
        indicator: &str,
        params: &[f64],
        condition: &str,
        threshold: f64,
        threshold_upper: Option<f64>,
        signal: &str,
    ) -> TaRule {
        TaRule {
            indicator: indicator.to_owned(),
            params: params.to_vec(),
            condition: condition.to_owned(),
            threshold,
            threshold_upper,
            signal: signal.to_owned(),
            action: None,
        }
    }

    /// RSI(14) < 30 rule without an upper threshold.
    fn sample_ta_rule() -> TaRule {
        rule("RSI", &[14.0], "lt", 30.0, None, "oversold")
    }

    /// BB(20, 2) "between" rule with both thresholds set.
    fn sample_ta_rule_between() -> TaRule {
        rule("BB", &[20.0, 2.0], "between", -1.0, Some(1.0), "inside_bands")
    }

    /// Builds a `Playbook` that only fills the legacy `rules` list; entry and
    /// exit rules are empty and `timeout_secs` / `side` are unset.
    fn legacy_playbook(
        rules: Vec<TaRule>,
        system_prompt: &str,
        max_position_size: f64,
        stop_loss_pct: Option<f64>,
        take_profit_pct: Option<f64>,
    ) -> Playbook {
        Playbook {
            rules,
            entry_rules: Vec::new(),
            exit_rules: Vec::new(),
            system_prompt: system_prompt.to_owned(),
            max_position_size,
            stop_loss_pct,
            take_profit_pct,
            timeout_secs: None,
            side: None,
        }
    }

    /// Builds a regime rule with a single EMA(50, 200) condition.
    fn ema_cross_rule(regime: &str, condition: &str, signal: &str, priority: u32) -> RegimeRule {
        RegimeRule {
            regime: regime.to_owned(),
            conditions: vec![rule("EMA", &[50.0, 200.0], condition, 0.0, None, signal)],
            priority,
        }
    }

    /// Strategy group with two regime rules and two playbooks ("bull", "bear").
    fn sample_strategy_group() -> StrategyGroup {
        let playbooks: HashMap<String, Playbook> = vec![
            (
                "bull",
                legacy_playbook(
                    vec![sample_ta_rule()],
                    "You are a bull-market trading agent.",
                    1000.0,
                    Some(5.0),
                    Some(10.0),
                ),
            ),
            (
                "bear",
                legacy_playbook(
                    vec![sample_ta_rule_between()],
                    "You are a bear-market trading agent.",
                    500.0,
                    Some(3.0),
                    None,
                ),
            ),
        ]
        .into_iter()
        .map(|(regime, playbook)| (regime.to_owned(), playbook))
        .collect();

        StrategyGroup {
            id: "sg-001".to_owned(),
            name: "BTC Regime Strategy".to_owned(),
            vault_address: Some("0xabc123".to_owned()),
            is_active: true,
            created_at: "2026-03-09T00:00:00Z".to_owned(),
            symbol: "BTC-USD".to_owned(),
            interval_secs: 300,
            regime_rules: vec![
                ema_cross_rule("bull", "cross_above", "golden_cross", 1),
                ema_cross_rule("bear", "cross_below", "death_cross", 2),
            ],
            default_regime: "neutral".to_owned(),
            hysteresis: HysteresisConfig::default(),
            playbooks,
        }
    }

    /// Serializes `value` to compact JSON and parses it straight back.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    // -----------------------------------------------------------------------
    // Serialization tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_strategy_group_serialization_roundtrip() {
        let encoded = serde_json::to_string_pretty(&sample_strategy_group()).unwrap();
        let decoded: StrategyGroup = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.id, "sg-001");
        assert_eq!(decoded.name, "BTC Regime Strategy");
        assert_eq!(decoded.symbol, "BTC-USD");
        assert_eq!(decoded.interval_secs, 300);
        assert_eq!(decoded.regime_rules.len(), 2);
        assert_eq!(decoded.default_regime, "neutral");
        assert_eq!(decoded.playbooks.len(), 2);
        assert!(decoded.is_active);
    }

    #[test]
    fn test_hysteresis_defaults() {
        let HysteresisConfig {
            min_hold_secs,
            confirmation_count,
        } = HysteresisConfig::default();
        assert_eq!(min_hold_secs, 3600);
        assert_eq!(confirmation_count, 3);
    }

    #[test]
    fn test_ta_rule_with_threshold_upper() {
        let decoded = json_roundtrip(&sample_ta_rule_between());
        assert_eq!(decoded.condition, "between");
        assert_eq!(decoded.threshold_upper, Some(1.0));
    }

    #[test]
    fn test_ta_rule_without_threshold_upper() {
        let decoded = json_roundtrip(&sample_ta_rule());
        assert_eq!(decoded.threshold_upper, None);
    }

    #[test]
    fn test_playbook_serialization() {
        let original = legacy_playbook(
            vec![sample_ta_rule()],
            "Trade carefully.",
            2000.0,
            None,
            Some(15.0),
        );
        let decoded = json_roundtrip(&original);
        assert_eq!(decoded.max_position_size, 2000.0);
        assert_eq!(decoded.stop_loss_pct, None);
        assert_eq!(decoded.take_profit_pct, Some(15.0));
    }

    #[test]
    fn test_regime_rule_serialization() {
        let original = RegimeRule {
            regime: "volatile".to_owned(),
            conditions: vec![rule("ATR", &[14.0], "gt", 50.0, None, "high_volatility")],
            priority: 1,
        };
        let decoded = json_roundtrip(&original);
        assert_eq!(decoded.regime, "volatile");
        assert_eq!(decoded.conditions.len(), 1);
        assert_eq!(decoded.priority, 1);
    }

    #[test]
    fn test_strategy_group_without_vault() {
        let without_vault = StrategyGroup {
            vault_address: None,
            ..sample_strategy_group()
        };
        let decoded = json_roundtrip(&without_vault);
        assert_eq!(decoded.vault_address, None);
    }

    #[test]
    fn test_camel_case_keys() {
        let encoded = serde_json::to_string(&sample_strategy_group()).unwrap();
        let expected_keys = [
            "isActive",
            "createdAt",
            "intervalSecs",
            "regimeRules",
            "defaultRegime",
            "vaultAddress",
            "minHoldSecs",
            "confirmationCount",
            "systemPrompt",
            "maxPositionSize",
            "stopLossPct",
            "takeProfitPct",
            "thresholdUpper",
        ];
        for key in expected_keys.iter() {
            // Match the key including its surrounding quotes.
            let quoted = format!("\"{}\"", key);
            assert!(encoded.contains(&quoted), "missing key {}", quoted);
        }
    }

    #[test]
    fn test_deserialize_from_json_string() {
        let raw = r#"{
            "id": "sg-test",
            "name": "Test Group",
            "vaultAddress": null,
            "isActive": false,
            "createdAt": "2026-01-01T00:00:00Z",
            "symbol": "ETH-USD",
            "intervalSecs": 60,
            "regimeRules": [],
            "defaultRegime": "neutral",
            "hysteresis": {
                "minHoldSecs": 1800,
                "confirmationCount": 2
            },
            "playbooks": {}
        }"#;
        let group: StrategyGroup = serde_json::from_str(raw).unwrap();
        assert_eq!(group.id, "sg-test");
        assert_eq!(group.symbol, "ETH-USD");
        assert_eq!(group.interval_secs, 60);
        assert_eq!(group.hysteresis.min_hold_secs, 1800);
        assert_eq!(group.hysteresis.confirmation_count, 2);
        assert!(group.regime_rules.is_empty());
        assert!(group.playbooks.is_empty());
    }

    #[test]
    fn test_ta_rule_action_field_optional() {
        // Input that omits "action" must still load (backward compat).
        let without_action = r#"{
            "indicator": "RSI",
            "params": [14.0],
            "condition": "lt",
            "threshold": 30.0,
            "signal": "oversold"
        }"#;
        let implicit: TaRule = serde_json::from_str(without_action).unwrap();
        assert_eq!(implicit.action, None);

        // Input that supplies "action" keeps the value.
        let with_action = r#"{
            "indicator": "RSI",
            "params": [14.0],
            "condition": "lt",
            "threshold": 30.0,
            "signal": "oversold",
            "action": "buy"
        }"#;
        let explicit: TaRule = serde_json::from_str(with_action).unwrap();
        assert_eq!(explicit.action.as_deref(), Some("buy"));
    }

    // -----------------------------------------------------------------------
    // entry_rules / exit_rules backward-compat tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_backward_compat_only_rules() {
        // Legacy shape: "rules" only, no entry/exit lists.
        let raw = r#"{
            "rules": [{
                "indicator": "RSI",
                "params": [14.0],
                "condition": "lt",
                "threshold": 30.0,
                "signal": "oversold"
            }],
            "systemPrompt": "hello",
            "maxPositionSize": 100.0,
            "stopLossPct": null,
            "takeProfitPct": null
        }"#;
        let playbook: Playbook = serde_json::from_str(raw).unwrap();
        assert_eq!(playbook.rules.len(), 1);
        assert!(playbook.entry_rules.is_empty());
        assert!(playbook.exit_rules.is_empty());

        // With no entry rules, the legacy list is what gets used.
        let entry = playbook.effective_entry_rules();
        assert_eq!(entry.len(), 1);
        assert_eq!(entry[0].indicator, "RSI");
        assert!(playbook.effective_exit_rules().is_empty());
        assert_eq!(playbook.timeout_secs, None);
        assert_eq!(playbook.side, None);
    }

    #[test]
    fn test_new_format_entry_exit_rules() {
        let raw = r#"{
            "entryRules": [{
                "indicator": "RSI",
                "params": [14.0],
                "condition": "lt",
                "threshold": 30.0,
                "signal": "oversold"
            }],
            "exitRules": [{
                "indicator": "RSI",
                "params": [14.0],
                "condition": "gt",
                "threshold": 70.0,
                "signal": "overbought"
            }],
            "systemPrompt": "hello",
            "maxPositionSize": 100.0,
            "stopLossPct": null,
            "takeProfitPct": null
        }"#;
        let playbook: Playbook = serde_json::from_str(raw).unwrap();
        assert!(playbook.rules.is_empty());

        let entry = playbook.effective_entry_rules();
        assert_eq!(entry.len(), 1);
        assert_eq!(entry[0].signal, "oversold");

        let exit = playbook.effective_exit_rules();
        assert_eq!(exit.len(), 1);
        assert_eq!(exit[0].signal, "overbought");
    }

    #[test]
    fn test_mixed_rules_and_entry_rules_entry_wins() {
        let playbook = Playbook {
            entry_rules: vec![sample_ta_rule_between()],
            ..legacy_playbook(vec![sample_ta_rule()], "mixed", 100.0, None, None)
        };
        // When both lists are populated, entry_rules is the one returned.
        let entry = playbook.effective_entry_rules();
        assert_eq!(entry.len(), 1);
        assert_eq!(entry[0].indicator, "BB");
    }

    #[test]
    fn test_side_and_timeout_serde() {
        let playbook = Playbook {
            timeout_secs: Some(300),
            side: Some("long".to_owned()),
            ..legacy_playbook(Vec::new(), "test", 50.0, None, None)
        };
        let encoded = serde_json::to_string(&playbook).unwrap();
        assert!(encoded.contains("\"timeoutSecs\":300"));
        assert!(encoded.contains("\"side\":\"long\""));

        let decoded: Playbook = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.timeout_secs, Some(300));
        assert_eq!(decoded.side.as_deref(), Some("long"));
    }

    #[test]
    fn test_old_json_without_new_fields_deserializes() {
        // Smallest legacy payload: none of the newer fields are present.
        let raw = r#"{
            "rules": [],
            "systemPrompt": "old format",
            "maxPositionSize": 200.0,
            "stopLossPct": 5.0,
            "takeProfitPct": 10.0
        }"#;
        let playbook: Playbook = serde_json::from_str(raw).unwrap();
        assert!(playbook.entry_rules.is_empty());
        assert!(playbook.exit_rules.is_empty());
        assert_eq!(playbook.timeout_secs, None);
        assert_eq!(playbook.side, None);
        assert_eq!(playbook.system_prompt, "old format");
    }
}