mockforge_chaos/reinforcement_learning.rs

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

/// State representation for RL agent
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub struct SystemState {
    pub error_rate: u8,         // 0-100
    pub latency_level: u8,      // 0-100
    pub cpu_usage: u8,          // 0-100
    pub memory_usage: u8,       // 0-100
    pub active_failures: u8,    // Number of active failures
    pub service_health: String, // "healthy", "degraded", "critical"
}

/// Remediation action
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum RemediationAction {
    RestartService,
    ScaleUp(u32),
    ScaleDown(u32),
    ClearCache,
    RollbackDeployment,
    EnableCircuitBreaker,
    DisableRateLimiting,
    RestrictTraffic,
    NoAction,
}

/// Q-Learning parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QLearningConfig {
    pub learning_rate: f64,     // Alpha: 0.0 - 1.0
    pub discount_factor: f64,   // Gamma: 0.0 - 1.0
    pub exploration_rate: f64,  // Epsilon: 0.0 - 1.0
    pub exploration_decay: f64, // Epsilon decay rate
    pub min_exploration: f64,   // Minimum epsilon
}

impl Default for QLearningConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.1,
            discount_factor: 0.95,
            exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration: 0.01,
        }
    }
}
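// With these defaults, epsilon starts at 1.0 and is multiplied by 0.995 on every call
// to `RLAgent::update`: roughly 0.61 after 100 updates, and it reaches the 0.01 floor
// after about 920 updates (ln(0.01) / ln(0.995) ≈ 919).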

/// Q-Table entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QValue {
    pub value: f64,
    pub visit_count: u64,
}

/// Reinforcement Learning Agent
pub struct RLAgent {
    q_table: Arc<RwLock<HashMap<(SystemState, RemediationAction), QValue>>>,
    config: QLearningConfig,
    current_epsilon: f64,
}

impl RLAgent {
    pub fn new(config: QLearningConfig) -> Self {
        Self {
            q_table: Arc::new(RwLock::new(HashMap::new())),
            current_epsilon: config.exploration_rate,
            config,
        }
    }

    /// Select action using epsilon-greedy policy
    pub async fn select_action(&self, state: &SystemState) -> RemediationAction {
        if rand::random::<f64>() < self.current_epsilon {
            // Explore: random action
            self.random_action()
        } else {
            // Exploit: best known action
            self.best_action(state).await
        }
    }

    /// Get best action for given state
    async fn best_action(&self, state: &SystemState) -> RemediationAction {
        let q_table = self.q_table.read().await;
        let actions = self.possible_actions();

        let mut best_action = RemediationAction::NoAction;
        let mut best_value = f64::NEG_INFINITY;

        for action in actions {
            let key = (state.clone(), action.clone());
            let value = q_table.get(&key).map(|q| q.value).unwrap_or(0.0);

            if value > best_value {
                best_value = value;
                best_action = action;
            }
        }

        best_action
    }

    /// Get random action (for exploration)
    fn random_action(&self) -> RemediationAction {
        let actions = self.possible_actions();
        use rand::Rng;
        let mut rng = rand::rng();
        let idx = rng.random_range(0..actions.len());
        actions[idx].clone()
    }

    /// Get all possible actions
    fn possible_actions(&self) -> Vec<RemediationAction> {
        vec![
            RemediationAction::RestartService,
            RemediationAction::ScaleUp(2),
            RemediationAction::ScaleUp(4),
            RemediationAction::ScaleDown(2),
            RemediationAction::ClearCache,
            RemediationAction::RollbackDeployment,
            RemediationAction::EnableCircuitBreaker,
            RemediationAction::DisableRateLimiting,
            RemediationAction::RestrictTraffic,
            RemediationAction::NoAction,
        ]
    }

    /// Update Q-value based on observed outcome
    pub async fn update(
        &mut self,
        state: &SystemState,
        action: &RemediationAction,
        reward: f64,
        next_state: &SystemState,
    ) {
        let mut q_table = self.q_table.write().await;

        // Get current Q-value
        let key = (state.clone(), action.clone());
        let current_q = q_table.get(&key).map(|q| q.value).unwrap_or(0.0);

        // Get max Q-value for next state
        let actions = self.possible_actions();
        let max_next_q = actions
            .iter()
            .map(|a| {
                let next_key = (next_state.clone(), a.clone());
                q_table.get(&next_key).map(|q| q.value).unwrap_or(0.0)
            })
            .fold(f64::NEG_INFINITY, f64::max);

        // Q-learning update: Q(s,a) = Q(s,a) + α[r + γ·max Q(s',a') - Q(s,a)]
        let new_q = current_q
            + self.config.learning_rate
                * (reward + self.config.discount_factor * max_next_q - current_q);

        // Update Q-table
        q_table
            .entry(key)
            .and_modify(|q| {
                q.value = new_q;
                q.visit_count += 1;
            })
            .or_insert(QValue {
                value: new_q,
                visit_count: 1,
            });

        // Decay exploration rate
        self.current_epsilon =
            (self.current_epsilon * self.config.exploration_decay).max(self.config.min_exploration);
    }

    /// Calculate reward based on outcome
    pub fn calculate_reward(&self, before: &SystemState, after: &SystemState) -> f64 {
        let mut reward = 0.0;

        // Reward for reducing error rate
        reward += (before.error_rate as f64 - after.error_rate as f64) * 2.0;

        // Reward for reducing latency
        reward += (before.latency_level as f64 - after.latency_level as f64) * 1.5;

        // Reward for reducing CPU usage
        reward += (before.cpu_usage as f64 - after.cpu_usage as f64) * 0.5;

        // Reward for reducing active failures
        reward += (before.active_failures as f64 - after.active_failures as f64) * 5.0;

        // Health state bonus
        reward += match (before.service_health.as_str(), after.service_health.as_str()) {
            ("critical", "healthy") => 50.0,
            ("critical", "degraded") => 25.0,
            ("degraded", "healthy") => 20.0,
            ("healthy", "degraded") => -30.0,
            ("healthy", "critical") => -50.0,
            ("degraded", "critical") => -40.0,
            _ => 0.0,
        };

        reward
    }

    /// Get policy statistics
    pub async fn get_stats(&self) -> HashMap<String, serde_json::Value> {
        let q_table = self.q_table.read().await;

        let mut stats = HashMap::new();
        stats.insert("q_table_size".to_string(), serde_json::json!(q_table.len()));
        stats.insert("epsilon".to_string(), serde_json::json!(self.current_epsilon));

        // Calculate average Q-value
        let avg_q: f64 = if q_table.is_empty() {
            0.0
        } else {
            q_table.values().map(|q| q.value).sum::<f64>() / q_table.len() as f64
        };
        stats.insert("avg_q_value".to_string(), serde_json::json!(avg_q));

        // Most visited state-action pairs
        let mut visited: Vec<_> = q_table.iter().collect();
        visited.sort_by_key(|(_, q)| std::cmp::Reverse(q.visit_count));

        let top_pairs: Vec<_> = visited
            .iter()
            .take(10)
            .map(|((state, action), q)| {
                serde_json::json!({
                    "state": state,
                    "action": action,
                    "q_value": q.value,
                    "visits": q.visit_count,
                })
            })
            .collect();

        stats.insert("top_state_actions".to_string(), serde_json::json!(top_pairs));

        stats
    }

    /// Save Q-table to disk
    pub async fn save_model(&self, path: &str) -> Result<()> {
        let q_table = self.q_table.read().await;
        // JSON object keys must be strings, so persist the table as a flat list of
        // (state, action, q-value) entries rather than as a map with tuple keys.
        let entries: Vec<(SystemState, RemediationAction, QValue)> = q_table
            .iter()
            .map(|((state, action), q)| (state.clone(), action.clone(), q.clone()))
            .collect();
        let data = serde_json::to_string_pretty(&entries)?;
        tokio::fs::write(path, data).await?;
        Ok(())
    }

    /// Load Q-table from disk
    pub async fn load_model(&mut self, path: &str) -> Result<()> {
        let data = tokio::fs::read_to_string(path).await?;
        let entries: Vec<(SystemState, RemediationAction, QValue)> =
            serde_json::from_str(&data)?;

        let mut q_table = self.q_table.write().await;
        *q_table = entries
            .into_iter()
            .map(|(state, action, q)| ((state, action), q))
            .collect();

        Ok(())
    }
}
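
// Illustrative control loop (a sketch, not part of this module): `observe_state` and
// `apply_action` below are hypothetical hooks for whatever telemetry and remediation
// plumbing the caller wires in.
//
//     let mut agent = RLAgent::new(QLearningConfig::default());
//     loop {
//         let state = observe_state().await;
//         let action = agent.select_action(&state).await;
//         apply_action(&action).await;
//         let next_state = observe_state().await;
//         let reward = agent.calculate_reward(&state, &next_state);
//         agent.update(&state, &action, reward, &next_state).await;
//     }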

/// Adaptive Risk Assessment Engine
pub struct AdaptiveRiskAssessor {
    risk_history: Arc<RwLock<Vec<RiskAssessment>>>,
    rl_agent: Arc<RwLock<RLAgent>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
    pub timestamp: chrono::DateTime<chrono::Utc>,
    pub state: SystemState,
    pub risk_level: f64, // 0.0 - 1.0
    pub recommended_actions: Vec<RemediationAction>,
    pub confidence: f64, // 0.0 - 1.0
}

impl AdaptiveRiskAssessor {
    pub fn new(rl_agent: Arc<RwLock<RLAgent>>) -> Self {
        Self {
            risk_history: Arc::new(RwLock::new(Vec::new())),
            rl_agent,
        }
    }

    /// Assess risk for current system state
    pub async fn assess_risk(&self, state: &SystemState) -> RiskAssessment {
        let mut risk_level = 0.0;

        // Factor in various metrics
        risk_level += state.error_rate as f64 / 100.0 * 0.3;
        risk_level += state.latency_level as f64 / 100.0 * 0.2;
        risk_level += state.cpu_usage as f64 / 100.0 * 0.15;
        risk_level += state.memory_usage as f64 / 100.0 * 0.15;
        risk_level += state.active_failures as f64 / 10.0 * 0.2;

        // Health state impact
        risk_level += match state.service_health.as_str() {
            "critical" => 0.4,
            "degraded" => 0.2,
            _ => 0.0,
        };

        risk_level = risk_level.min(1.0);

        // Get recommended actions from RL agent
        let agent = self.rl_agent.read().await;
        let action = agent.best_action(state).await;

        // Calculate confidence based on Q-table visit counts
        let q_table = agent.q_table.read().await;
        let key = (state.clone(), action.clone());
        let confidence = q_table
            .get(&key)
            .map(|q| (q.visit_count as f64 / 100.0).min(1.0))
            .unwrap_or(0.1);

        let assessment = RiskAssessment {
            timestamp: chrono::Utc::now(),
            state: state.clone(),
            risk_level,
            recommended_actions: vec![action],
            confidence,
        };

        // Store in history
        let mut history = self.risk_history.write().await;
        history.push(assessment.clone());

        // Keep only last 1000 assessments
        if history.len() > 1000 {
            let excess = history.len() - 1000;
            history.drain(0..excess);
        }

        assessment
    }

    /// Get risk trend over time
    pub async fn get_risk_trend(&self) -> Vec<(chrono::DateTime<chrono::Utc>, f64)> {
        let history = self.risk_history.read().await;
        history.iter().map(|a| (a.timestamp, a.risk_level)).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rl_agent_learning() {
        let config = QLearningConfig::default();
        let mut agent = RLAgent::new(config);

        let state = SystemState {
            error_rate: 50,
            latency_level: 60,
            cpu_usage: 80,
            memory_usage: 70,
            active_failures: 3,
            service_health: "degraded".to_string(),
        };

        let action = RemediationAction::RestartService;

        let next_state = SystemState {
            error_rate: 10,
            latency_level: 20,
            cpu_usage: 40,
            memory_usage: 50,
            active_failures: 0,
            service_health: "healthy".to_string(),
        };

        let reward = agent.calculate_reward(&state, &next_state);
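        // Worked out from calculate_reward: (50 - 10) * 2.0 + (60 - 20) * 1.5
        // + (80 - 40) * 0.5 + (3 - 0) * 5.0 + 20.0 ("degraded" -> "healthy") = 195.0.
        assert!((reward - 195.0).abs() < 1e-9);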
        agent.update(&state, &action, reward, &next_state).await;
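        // With the default config (alpha = 0.1, gamma = 0.95) and an empty Q-table,
        // this single update sets Q(state, RestartService) to 0.1 * 195.0 = 19.5.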

        // Verify Q-value was updated
        let stats = agent.get_stats().await;
        assert!(stats.contains_key("q_table_size"));
    }

    #[tokio::test]
    async fn test_risk_assessment() {
        let config = QLearningConfig::default();
        let agent = Arc::new(RwLock::new(RLAgent::new(config)));
        let assessor = AdaptiveRiskAssessor::new(agent);

        let state = SystemState {
            error_rate: 75,
            latency_level: 80,
            cpu_usage: 90,
            memory_usage: 85,
            active_failures: 5,
            service_health: "critical".to_string(),
        };

        let assessment = assessor.assess_risk(&state).await;

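        // Weighted sum: 0.75*0.3 + 0.80*0.2 + 0.90*0.15 + 0.85*0.15 + (5/10)*0.2
        // + 0.4 ("critical") = 1.1475, capped at 1.0, so the assertions below hold.
        assert!(assessment.risk_level <= 1.0);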
        assert!(assessment.risk_level > 0.5);
        assert!(!assessment.recommended_actions.is_empty());
    }
}