1use std::collections::HashMap;
2use std::fs;
3use std::path::PathBuf;
4
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(rename_all = "camelCase")]
13pub struct StrategyGroupSummary {
14 pub id: String,
15 pub name: String,
16 pub symbol: String,
17 pub is_active: bool,
18 pub trading_mode: String,
19 pub today_pnl: f64,
20 pub agent_loop_status: String,
21}
22
23#[derive(Serialize, Deserialize, Clone, Debug)]
24#[serde(rename_all = "camelCase")]
25pub struct StrategyGroup {
26 pub id: String,
27 pub name: String,
28 pub vault_address: Option<String>,
29 pub is_active: bool,
30 pub created_at: String,
31 pub symbol: String,
32 pub interval_secs: u64,
33 pub regime_rules: Vec<RegimeRule>,
34 pub default_regime: String,
35 pub hysteresis: HysteresisConfig,
36 pub playbooks: HashMap<String, Playbook>,
37}
38
39#[derive(Serialize, Deserialize, Clone, Debug)]
40#[serde(rename_all = "camelCase")]
41pub struct RegimeRule {
42 pub regime: String,
43 pub conditions: Vec<TaRule>,
44 pub priority: u32,
45}
46
47#[derive(Serialize, Deserialize, Clone, Debug)]
48#[serde(rename_all = "camelCase")]
49pub struct TaRule {
50 pub indicator: String,
51 pub params: Vec<f64>,
52 pub condition: String,
53 pub threshold: f64,
54 pub threshold_upper: Option<f64>,
55 pub signal: String,
56 #[serde(default, skip_serializing_if = "Option::is_none")]
59 pub action: Option<String>,
60}
61
62#[derive(Serialize, Deserialize, Clone, Debug)]
63#[serde(rename_all = "camelCase")]
64pub struct Playbook {
65 #[serde(default)]
67 pub rules: Vec<TaRule>,
68
69 #[serde(default)]
71 pub entry_rules: Vec<TaRule>,
72
73 #[serde(default)]
75 pub exit_rules: Vec<TaRule>,
76
77 pub system_prompt: String,
78 pub max_position_size: f64,
79 pub stop_loss_pct: Option<f64>,
80 pub take_profit_pct: Option<f64>,
81
82 #[serde(default)]
84 pub timeout_secs: Option<u64>,
85
86 #[serde(default)]
88 pub side: Option<String>,
89}
90
91impl Playbook {
92 pub fn effective_entry_rules(&self) -> &[TaRule] {
94 if !self.entry_rules.is_empty() {
95 &self.entry_rules
96 } else {
97 &self.rules
98 }
99 }
100
101 pub fn effective_exit_rules(&self) -> &[TaRule] {
103 &self.exit_rules
104 }
105}
106
107#[derive(Serialize, Deserialize, Clone, Debug)]
108#[serde(rename_all = "camelCase")]
109pub struct HysteresisConfig {
110 pub min_hold_secs: u64,
111 pub confirmation_count: u32,
112}
113
114impl Default for HysteresisConfig {
115 fn default() -> Self {
116 Self {
117 min_hold_secs: 3600,
118 confirmation_count: 3,
119 }
120 }
121}
122
123fn strategy_groups_path() -> Option<PathBuf> {
128 dirs::data_dir().map(|d| d.join("hyper-agent").join("strategy_groups.json"))
129}
130
/// Public entry point for loading the persisted strategy groups.
///
/// Thin wrapper that delegates to the private `load_strategy_groups_from_disk`;
/// returns an empty list when no groups can be loaded.
pub fn load_strategy_groups_from_disk_pub() -> Vec<StrategyGroup> {
    load_strategy_groups_from_disk()
}
135
136fn load_strategy_groups_from_disk() -> Vec<StrategyGroup> {
137 let path = match strategy_groups_path() {
138 Some(p) if p.exists() => p,
139 _ => return Vec::new(),
140 };
141 let data = match fs::read_to_string(&path) {
142 Ok(d) => d,
143 Err(_) => return Vec::new(),
144 };
145 serde_json::from_str(&data).unwrap_or_default()
146}
147
148pub fn save_strategy_groups_to_disk(groups: &[StrategyGroup]) -> Result<(), String> {
149 let path = strategy_groups_path().ok_or("Could not determine data directory")?;
150 if let Some(parent) = path.parent() {
151 fs::create_dir_all(parent).map_err(|e| format!("Failed to create data dir: {}", e))?;
152 }
153 let json =
154 serde_json::to_string_pretty(groups).map_err(|e| format!("Serialize error: {}", e))?;
155 fs::write(&path, json).map_err(|e| format!("Failed to write strategy_groups file: {}", e))?;
156 Ok(())
157}
158
// Unit tests for the strategy-group data model: serde round-trips, camelCase key
// naming, optional-field defaults, and the legacy `rules` fallback in
// `Playbook::effective_entry_rules`. None of these tests touch the filesystem
// (the load/save functions are not exercised here).
#[cfg(test)]
mod tests {
    use super::*;

    // Single-threshold rule: RSI(14) "lt" 30, no upper bound, no action.
    fn sample_ta_rule() -> TaRule {
        TaRule {
            indicator: "RSI".to_string(),
            params: vec![14.0],
            condition: "lt".to_string(),
            threshold: 30.0,
            threshold_upper: None,
            signal: "oversold".to_string(),
            action: None,
        }
    }

    // Two-threshold rule: "BB" with condition "between" and `threshold_upper` set.
    fn sample_ta_rule_between() -> TaRule {
        TaRule {
            indicator: "BB".to_string(),
            params: vec![20.0, 2.0],
            condition: "between".to_string(),
            threshold: -1.0,
            threshold_upper: Some(1.0),
            signal: "inside_bands".to_string(),
            action: None,
        }
    }

    // Fully populated group: two regime rules, two playbooks (legacy `rules` only),
    // a vault address, and default hysteresis.
    fn sample_strategy_group() -> StrategyGroup {
        let mut playbooks = HashMap::new();
        playbooks.insert(
            "bull".to_string(),
            Playbook {
                rules: vec![sample_ta_rule()],
                entry_rules: vec![],
                exit_rules: vec![],
                system_prompt: "You are a bull-market trading agent.".to_string(),
                max_position_size: 1000.0,
                stop_loss_pct: Some(5.0),
                take_profit_pct: Some(10.0),
                timeout_secs: None,
                side: None,
            },
        );
        playbooks.insert(
            "bear".to_string(),
            Playbook {
                rules: vec![sample_ta_rule_between()],
                entry_rules: vec![],
                exit_rules: vec![],
                system_prompt: "You are a bear-market trading agent.".to_string(),
                max_position_size: 500.0,
                stop_loss_pct: Some(3.0),
                take_profit_pct: None,
                timeout_secs: None,
                side: None,
            },
        );

        StrategyGroup {
            id: "sg-001".to_string(),
            name: "BTC Regime Strategy".to_string(),
            vault_address: Some("0xabc123".to_string()),
            is_active: true,
            created_at: "2026-03-09T00:00:00Z".to_string(),
            symbol: "BTC-USD".to_string(),
            interval_secs: 300,
            regime_rules: vec![
                RegimeRule {
                    regime: "bull".to_string(),
                    conditions: vec![TaRule {
                        indicator: "EMA".to_string(),
                        params: vec![50.0, 200.0],
                        condition: "cross_above".to_string(),
                        threshold: 0.0,
                        threshold_upper: None,
                        signal: "golden_cross".to_string(),
                        action: None,
                    }],
                    priority: 1,
                },
                RegimeRule {
                    regime: "bear".to_string(),
                    conditions: vec![TaRule {
                        indicator: "EMA".to_string(),
                        params: vec![50.0, 200.0],
                        condition: "cross_below".to_string(),
                        threshold: 0.0,
                        threshold_upper: None,
                        signal: "death_cross".to_string(),
                        action: None,
                    }],
                    priority: 2,
                },
            ],
            default_regime: "neutral".to_string(),
            hysteresis: HysteresisConfig::default(),
            playbooks,
        }
    }

    // Serialize → deserialize preserves the top-level fields and collection sizes.
    #[test]
    fn test_strategy_group_serialization_roundtrip() {
        let group = sample_strategy_group();
        let json = serde_json::to_string_pretty(&group).unwrap();
        let parsed: StrategyGroup = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "sg-001");
        assert_eq!(parsed.name, "BTC Regime Strategy");
        assert_eq!(parsed.symbol, "BTC-USD");
        assert_eq!(parsed.interval_secs, 300);
        assert_eq!(parsed.regime_rules.len(), 2);
        assert_eq!(parsed.default_regime, "neutral");
        assert_eq!(parsed.playbooks.len(), 2);
        assert!(parsed.is_active);
    }

    // Pins the `Default` values of `HysteresisConfig`.
    #[test]
    fn test_hysteresis_defaults() {
        let h = HysteresisConfig::default();
        assert_eq!(h.min_hold_secs, 3600);
        assert_eq!(h.confirmation_count, 3);
    }

    // `threshold_upper: Some(_)` survives a round-trip.
    #[test]
    fn test_ta_rule_with_threshold_upper() {
        let rule = sample_ta_rule_between();
        let json = serde_json::to_string(&rule).unwrap();
        let parsed: TaRule = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.condition, "between");
        assert_eq!(parsed.threshold_upper, Some(1.0));
    }

    // `threshold_upper: None` survives a round-trip.
    #[test]
    fn test_ta_rule_without_threshold_upper() {
        let rule = sample_ta_rule();
        let json = serde_json::to_string(&rule).unwrap();
        let parsed: TaRule = serde_json::from_str(&json).unwrap();
        assert!(parsed.threshold_upper.is_none());
    }

    // Mixed `Some`/`None` optional percentages round-trip correctly.
    #[test]
    fn test_playbook_serialization() {
        let playbook = Playbook {
            rules: vec![sample_ta_rule()],
            entry_rules: vec![],
            exit_rules: vec![],
            system_prompt: "Trade carefully.".to_string(),
            max_position_size: 2000.0,
            stop_loss_pct: None,
            take_profit_pct: Some(15.0),
            timeout_secs: None,
            side: None,
        };
        let json = serde_json::to_string(&playbook).unwrap();
        let parsed: Playbook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_position_size, 2000.0);
        assert!(parsed.stop_loss_pct.is_none());
        assert_eq!(parsed.take_profit_pct, Some(15.0));
    }

    // A standalone `RegimeRule` round-trips with its nested conditions.
    #[test]
    fn test_regime_rule_serialization() {
        let rule = RegimeRule {
            regime: "volatile".to_string(),
            conditions: vec![TaRule {
                indicator: "ATR".to_string(),
                params: vec![14.0],
                condition: "gt".to_string(),
                threshold: 50.0,
                threshold_upper: None,
                signal: "high_volatility".to_string(),
                action: None,
            }],
            priority: 1,
        };
        let json = serde_json::to_string(&rule).unwrap();
        let parsed: RegimeRule = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.regime, "volatile");
        assert_eq!(parsed.conditions.len(), 1);
        assert_eq!(parsed.priority, 1);
    }

    // `vault_address: None` (serialized as null) round-trips back to `None`.
    #[test]
    fn test_strategy_group_without_vault() {
        let mut group = sample_strategy_group();
        group.vault_address = None;
        let json = serde_json::to_string(&group).unwrap();
        let parsed: StrategyGroup = serde_json::from_str(&json).unwrap();
        assert!(parsed.vault_address.is_none());
    }

    // `rename_all = "camelCase"` is applied on every nested type. `thresholdUpper`
    // appears because Option fields without skip_serializing_if are emitted even as null.
    #[test]
    fn test_camel_case_keys() {
        let group = sample_strategy_group();
        let json = serde_json::to_string(&group).unwrap();
        assert!(json.contains("\"isActive\""));
        assert!(json.contains("\"createdAt\""));
        assert!(json.contains("\"intervalSecs\""));
        assert!(json.contains("\"regimeRules\""));
        assert!(json.contains("\"defaultRegime\""));
        assert!(json.contains("\"vaultAddress\""));
        assert!(json.contains("\"minHoldSecs\""));
        assert!(json.contains("\"confirmationCount\""));
        assert!(json.contains("\"systemPrompt\""));
        assert!(json.contains("\"maxPositionSize\""));
        assert!(json.contains("\"stopLossPct\""));
        assert!(json.contains("\"takeProfitPct\""));
        assert!(json.contains("\"thresholdUpper\""));
    }

    // Parses a hand-written camelCase document with empty collections.
    #[test]
    fn test_deserialize_from_json_string() {
        let json = r#"{
            "id": "sg-test",
            "name": "Test Group",
            "vaultAddress": null,
            "isActive": false,
            "createdAt": "2026-01-01T00:00:00Z",
            "symbol": "ETH-USD",
            "intervalSecs": 60,
            "regimeRules": [],
            "defaultRegime": "neutral",
            "hysteresis": {
                "minHoldSecs": 1800,
                "confirmationCount": 2
            },
            "playbooks": {}
        }"#;
        let parsed: StrategyGroup = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.id, "sg-test");
        assert_eq!(parsed.symbol, "ETH-USD");
        assert_eq!(parsed.interval_secs, 60);
        assert_eq!(parsed.hysteresis.min_hold_secs, 1800);
        assert_eq!(parsed.hysteresis.confirmation_count, 2);
        assert!(parsed.regime_rules.is_empty());
        assert!(parsed.playbooks.is_empty());
    }

    // `action` may be omitted (→ None) or present. The first document also omits
    // `thresholdUpper`, which therefore also deserializes when missing.
    #[test]
    fn test_ta_rule_action_field_optional() {
        let json = r#"{
            "indicator": "RSI",
            "params": [14.0],
            "condition": "lt",
            "threshold": 30.0,
            "signal": "oversold"
        }"#;
        let parsed: TaRule = serde_json::from_str(json).unwrap();
        assert!(parsed.action.is_none());

        let json2 = r#"{
            "indicator": "RSI",
            "params": [14.0],
            "condition": "lt",
            "threshold": 30.0,
            "signal": "oversold",
            "action": "buy"
        }"#;
        let parsed2: TaRule = serde_json::from_str(json2).unwrap();
        assert_eq!(parsed2.action, Some("buy".to_string()));
    }

    // Legacy format: only `rules` present. It must load, and `effective_entry_rules`
    // must fall back to `rules`, while exit rules stay empty.
    #[test]
    fn test_backward_compat_only_rules() {
        let json = r#"{
            "rules": [{
                "indicator": "RSI",
                "params": [14.0],
                "condition": "lt",
                "threshold": 30.0,
                "signal": "oversold"
            }],
            "systemPrompt": "hello",
            "maxPositionSize": 100.0,
            "stopLossPct": null,
            "takeProfitPct": null
        }"#;
        let pb: Playbook = serde_json::from_str(json).unwrap();
        assert_eq!(pb.rules.len(), 1);
        assert!(pb.entry_rules.is_empty());
        assert!(pb.exit_rules.is_empty());
        assert_eq!(pb.effective_entry_rules().len(), 1);
        assert_eq!(pb.effective_entry_rules()[0].indicator, "RSI");
        assert!(pb.effective_exit_rules().is_empty());
        assert!(pb.timeout_secs.is_none());
        assert!(pb.side.is_none());
    }

    // New format: explicit `entryRules` / `exitRules`, no legacy `rules`.
    #[test]
    fn test_new_format_entry_exit_rules() {
        let json = r#"{
            "entryRules": [{
                "indicator": "RSI",
                "params": [14.0],
                "condition": "lt",
                "threshold": 30.0,
                "signal": "oversold"
            }],
            "exitRules": [{
                "indicator": "RSI",
                "params": [14.0],
                "condition": "gt",
                "threshold": 70.0,
                "signal": "overbought"
            }],
            "systemPrompt": "hello",
            "maxPositionSize": 100.0,
            "stopLossPct": null,
            "takeProfitPct": null
        }"#;
        let pb: Playbook = serde_json::from_str(json).unwrap();
        assert!(pb.rules.is_empty());
        assert_eq!(pb.effective_entry_rules().len(), 1);
        assert_eq!(pb.effective_entry_rules()[0].signal, "oversold");
        assert_eq!(pb.effective_exit_rules().len(), 1);
        assert_eq!(pb.effective_exit_rules()[0].signal, "overbought");
    }

    // When both `rules` and `entry_rules` are set, `entry_rules` takes precedence.
    #[test]
    fn test_mixed_rules_and_entry_rules_entry_wins() {
        let pb = Playbook {
            rules: vec![sample_ta_rule()],
            entry_rules: vec![sample_ta_rule_between()],
            exit_rules: vec![],
            system_prompt: "mixed".to_string(),
            max_position_size: 100.0,
            stop_loss_pct: None,
            take_profit_pct: None,
            timeout_secs: None,
            side: None,
        };
        assert_eq!(pb.effective_entry_rules().len(), 1);
        assert_eq!(pb.effective_entry_rules()[0].indicator, "BB");
    }

    // `timeout_secs` / `side` use camelCase keys and round-trip when set.
    #[test]
    fn test_side_and_timeout_serde() {
        let pb = Playbook {
            rules: vec![],
            entry_rules: vec![],
            exit_rules: vec![],
            system_prompt: "test".to_string(),
            max_position_size: 50.0,
            stop_loss_pct: None,
            take_profit_pct: None,
            timeout_secs: Some(300),
            side: Some("long".to_string()),
        };
        let json = serde_json::to_string(&pb).unwrap();
        assert!(json.contains("\"timeoutSecs\":300"));
        assert!(json.contains("\"side\":\"long\""));

        let parsed: Playbook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.timeout_secs, Some(300));
        assert_eq!(parsed.side, Some("long".to_string()));
    }

    // JSON predating entry/exit rules, timeout and side still deserializes, with
    // those fields defaulted.
    #[test]
    fn test_old_json_without_new_fields_deserializes() {
        let json = r#"{
            "rules": [],
            "systemPrompt": "old format",
            "maxPositionSize": 200.0,
            "stopLossPct": 5.0,
            "takeProfitPct": 10.0
        }"#;
        let pb: Playbook = serde_json::from_str(json).unwrap();
        assert!(pb.entry_rules.is_empty());
        assert!(pb.exit_rules.is_empty());
        assert!(pb.timeout_secs.is_none());
        assert!(pb.side.is_none());
        assert_eq!(pb.system_prompt, "old format");
    }
}