Skip to main content

hyper_agent_core/
agent_adjuster.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use hyper_strategy::strategy_config::{HysteresisConfig, RegimeRule, StrategyGroup, TaRule};
5
6// ---------------------------------------------------------------------------
7// Types
8// ---------------------------------------------------------------------------
9
10/// Represents adjustments Claude can make to a StrategyGroup.
11/// Only whitelisted fields can be modified.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct StrategyAdjustment {
14    #[serde(alias = "regimeRules", alias = "regime_rules")]
15    pub regime_rules: Option<Vec<RegimeRule>>,
16    #[serde(alias = "defaultRegime", alias = "default_regime")]
17    pub default_regime: Option<String>,
18    pub hysteresis: Option<HysteresisConfig>,
19    #[serde(alias = "playbookOverrides", alias = "playbook_overrides")]
20    pub playbook_overrides: Option<HashMap<String, PlaybookOverride>>,
21}
22
23/// Adjustable fields within a Playbook.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct PlaybookOverride {
26    pub rules: Option<Vec<TaRule>>,
27    #[serde(alias = "maxPositionSize", alias = "max_position_size")]
28    pub max_position_size: Option<f64>,
29    #[serde(alias = "stopLossPct", alias = "stop_loss_pct")]
30    pub stop_loss_pct: Option<f64>,
31    #[serde(alias = "takeProfitPct", alias = "take_profit_pct")]
32    pub take_profit_pct: Option<f64>,
33}
34
/// Errors produced while validating or persisting a [`StrategyAdjustment`].
#[derive(Debug, Clone, PartialEq)]
pub enum AdjustmentError {
    /// A playbook's `max_position_size` is above the configured limit.
    MaxPositionExceeded {
        /// The size that was requested.
        requested: f64,
        /// The limit it was checked against.
        limit: f64,
    },
    /// A numeric field is outside its permitted range.
    InvalidThreshold {
        /// Name of the offending field, e.g. `"stop_loss_pct"`.
        indicator: String,
        /// The rejected value.
        value: f64,
        /// Human-readable description of what would have been accepted.
        reason: String,
    },
    /// Serializing, writing or saving strategy groups failed. Carries the
    /// underlying error's message.
    FileWriteError(String),
}

impl std::fmt::Display for AdjustmentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxPositionExceeded { requested, limit } => {
                write!(f, "max_position_size {} exceeds limit {}", requested, limit)
            }
            Self::InvalidThreshold {
                indicator,
                value,
                reason,
            } => {
                write!(
                    f,
                    "Invalid threshold for {}: {} ({})",
                    indicator, value, reason
                )
            }
            Self::FileWriteError(msg) => write!(f, "File write error: {}", msg),
        }
    }
}

impl std::error::Error for AdjustmentError {}
72
73// ---------------------------------------------------------------------------
74// Core functions
75// ---------------------------------------------------------------------------
76
77/// Validate an adjustment against risk limits.
78pub fn validate_adjustment(
79    adjustment: &StrategyAdjustment,
80    max_position_usdc: f64,
81) -> Result<(), AdjustmentError> {
82    if let Some(ref overrides) = adjustment.playbook_overrides {
83        for (_, pb_override) in overrides {
84            if let Some(max_pos) = pb_override.max_position_size {
85                if max_pos > max_position_usdc {
86                    return Err(AdjustmentError::MaxPositionExceeded {
87                        requested: max_pos,
88                        limit: max_position_usdc,
89                    });
90                }
91                if max_pos < 0.0 {
92                    return Err(AdjustmentError::InvalidThreshold {
93                        indicator: "max_position_size".into(),
94                        value: max_pos,
95                        reason: "must be non-negative".into(),
96                    });
97                }
98            }
99            if let Some(sl) = pb_override.stop_loss_pct {
100                if sl < 0.0 || sl > 100.0 {
101                    return Err(AdjustmentError::InvalidThreshold {
102                        indicator: "stop_loss_pct".into(),
103                        value: sl,
104                        reason: "must be between 0 and 100".into(),
105                    });
106                }
107            }
108            if let Some(tp) = pb_override.take_profit_pct {
109                if tp < 0.0 || tp > 1000.0 {
110                    return Err(AdjustmentError::InvalidThreshold {
111                        indicator: "take_profit_pct".into(),
112                        value: tp,
113                        reason: "must be between 0 and 1000".into(),
114                    });
115                }
116            }
117        }
118    }
119    Ok(())
120}
121
122/// Apply a validated adjustment to a StrategyGroup.
123/// Immutable fields (id, name, symbol, interval_secs) are never changed.
124pub fn apply_adjustment(group: &mut StrategyGroup, adjustment: &StrategyAdjustment) {
125    if let Some(ref rules) = adjustment.regime_rules {
126        group.regime_rules = rules.clone();
127    }
128    if let Some(ref regime) = adjustment.default_regime {
129        group.default_regime = regime.clone();
130    }
131    if let Some(ref hyst) = adjustment.hysteresis {
132        group.hysteresis = hyst.clone();
133    }
134    if let Some(ref overrides) = adjustment.playbook_overrides {
135        for (regime_name, pb_override) in overrides {
136            if let Some(playbook) = group.playbooks.get_mut(regime_name) {
137                if let Some(ref rules) = pb_override.rules {
138                    playbook.rules = rules.clone();
139                }
140                if let Some(max_pos) = pb_override.max_position_size {
141                    playbook.max_position_size = max_pos;
142                }
143                if let Some(sl) = pb_override.stop_loss_pct {
144                    playbook.stop_loss_pct = Some(sl);
145                }
146                if let Some(tp) = pb_override.take_profit_pct {
147                    playbook.take_profit_pct = Some(tp);
148                }
149            }
150        }
151    }
152}
153
154/// Save a StrategyGroup to disk using the existing save function.
155/// Creates a backup first.
156pub fn save_with_backup(groups: &[StrategyGroup]) -> Result<(), AdjustmentError> {
157    use hyper_strategy::strategy_config::{
158        load_strategy_groups_from_disk_pub, save_strategy_groups_to_disk,
159    };
160
161    // Backup current version
162    let current = load_strategy_groups_from_disk_pub();
163    if !current.is_empty() {
164        let backup = serde_json::to_string_pretty(&current)
165            .map_err(|e| AdjustmentError::FileWriteError(e.to_string()))?;
166        let backup_path = dirs::data_dir()
167            .unwrap_or_else(|| std::path::PathBuf::from("."))
168            .join("hyper-agent")
169            .join("strategy_groups.bak.json");
170        std::fs::create_dir_all(backup_path.parent().unwrap())
171            .map_err(|e| AdjustmentError::FileWriteError(e.to_string()))?;
172        std::fs::write(&backup_path, backup)
173            .map_err(|e| AdjustmentError::FileWriteError(e.to_string()))?;
174    }
175
176    save_strategy_groups_to_disk(groups)
177        .map_err(|e| AdjustmentError::FileWriteError(e.to_string()))?;
178    Ok(())
179}
180
181// ---------------------------------------------------------------------------
182// Tests
183// ---------------------------------------------------------------------------
184
#[cfg(test)]
mod tests {
    use super::*;
    use hyper_strategy::strategy_config::{HysteresisConfig, Playbook, StrategyGroup};

    /// A group with a single "bull" playbook (max size 1000, SL 5%, TP 10%).
    fn make_group() -> StrategyGroup {
        let mut playbooks = HashMap::new();
        playbooks.insert(
            "bull".to_string(),
            Playbook {
                rules: vec![],
                entry_rules: vec![],
                exit_rules: vec![],
                system_prompt: "bull".into(),
                max_position_size: 1000.0,
                stop_loss_pct: Some(5.0),
                take_profit_pct: Some(10.0),
                timeout_secs: None,
                side: None,
            },
        );
        StrategyGroup {
            id: "sg-test".into(),
            name: "Test".into(),
            vault_address: None,
            is_active: true,
            created_at: "2026-01-01".into(),
            symbol: "BTC-USD".into(),
            interval_secs: 300,
            regime_rules: vec![],
            default_regime: "bull".into(),
            hysteresis: HysteresisConfig {
                min_hold_secs: 3600,
                confirmation_count: 3,
            },
            playbooks,
        }
    }

    /// An adjustment that only overrides the sizing parameters of `regime`'s playbook.
    fn sizing_adjustment(
        regime: &str,
        max_position_size: Option<f64>,
        stop_loss_pct: Option<f64>,
        take_profit_pct: Option<f64>,
    ) -> StrategyAdjustment {
        StrategyAdjustment {
            regime_rules: None,
            default_regime: None,
            hysteresis: None,
            playbook_overrides: Some(HashMap::from([(
                regime.to_string(),
                PlaybookOverride {
                    rules: None,
                    max_position_size,
                    stop_loss_pct,
                    take_profit_pct,
                },
            )])),
        }
    }

    #[test]
    fn validate_passes_within_limits() {
        let adj = sizing_adjustment("bull", Some(500.0), Some(3.0), Some(15.0));
        assert!(validate_adjustment(&adj, 10000.0).is_ok());
    }

    #[test]
    fn validate_accepts_inclusive_boundaries() {
        let upper = sizing_adjustment("bull", Some(10000.0), Some(100.0), Some(1000.0));
        assert!(validate_adjustment(&upper, 10000.0).is_ok());
        let lower = sizing_adjustment("bull", Some(0.0), Some(0.0), Some(0.0));
        assert!(validate_adjustment(&lower, 10000.0).is_ok());
    }

    #[test]
    fn validate_rejects_exceeding_max_position() {
        let adj = sizing_adjustment("bull", Some(50000.0), None, None);
        assert!(matches!(
            validate_adjustment(&adj, 10000.0),
            Err(AdjustmentError::MaxPositionExceeded { .. })
        ));
    }

    #[test]
    fn validate_rejects_negative_max_position() {
        let adj = sizing_adjustment("bull", Some(-1.0), None, None);
        assert!(matches!(
            validate_adjustment(&adj, 10000.0),
            Err(AdjustmentError::InvalidThreshold { .. })
        ));
    }

    #[test]
    fn validate_rejects_negative_stop_loss() {
        let adj = sizing_adjustment("bull", None, Some(-5.0), None);
        assert!(matches!(
            validate_adjustment(&adj, 10000.0),
            Err(AdjustmentError::InvalidThreshold { .. })
        ));
    }

    #[test]
    fn validate_rejects_take_profit_above_range() {
        let adj = sizing_adjustment("bull", None, None, Some(1500.0));
        assert!(matches!(
            validate_adjustment(&adj, 10000.0),
            Err(AdjustmentError::InvalidThreshold { .. })
        ));
    }

    #[test]
    fn apply_changes_playbook_params() {
        let mut group = make_group();
        let adj = StrategyAdjustment {
            default_regime: Some("neutral".into()),
            ..sizing_adjustment("bull", Some(2000.0), Some(3.0), None)
        };
        apply_adjustment(&mut group, &adj);
        assert_eq!(group.default_regime, "neutral");
        let bull = group.playbooks.get("bull").unwrap();
        assert_eq!(bull.max_position_size, 2000.0);
        assert_eq!(bull.stop_loss_pct, Some(3.0));
        assert_eq!(bull.take_profit_pct, Some(10.0)); // unchanged
    }

    #[test]
    fn apply_ignores_overrides_for_unknown_regimes() {
        let mut group = make_group();
        apply_adjustment(
            &mut group,
            &sizing_adjustment("bear", Some(2000.0), Some(1.0), Some(2.0)),
        );
        assert!(!group.playbooks.contains_key("bear"));
        let bull = group.playbooks.get("bull").unwrap();
        assert_eq!(bull.max_position_size, 1000.0);
        assert_eq!(bull.stop_loss_pct, Some(5.0));
        assert_eq!(bull.take_profit_pct, Some(10.0));
    }

    #[test]
    fn apply_preserves_immutable_fields() {
        let mut group = make_group();
        let original_id = group.id.clone();
        let original_name = group.name.clone();
        let original_symbol = group.symbol.clone();
        let original_interval = group.interval_secs;
        let adj = StrategyAdjustment {
            default_regime: Some("neutral".into()),
            ..sizing_adjustment("bull", Some(2000.0), None, None)
        };
        apply_adjustment(&mut group, &adj);
        assert_eq!(group.id, original_id);
        assert_eq!(group.name, original_name);
        assert_eq!(group.symbol, original_symbol);
        assert_eq!(group.interval_secs, original_interval);
    }
}