use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Discretized snapshot of system health, used as the Q-learning state.
/// Metrics are bucketed into small integer levels so states recur often
/// enough for a tabular method to learn from.
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub struct SystemState {
    pub error_rate: u8,
    pub latency_level: u8,
    pub cpu_usage: u8,
    pub memory_usage: u8,
    pub active_failures: u8,
    pub service_health: String,
}

/// Remediation actions available to the agent.
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum RemediationAction {
    RestartService,
    ScaleUp(u32),
    ScaleDown(u32),
    ClearCache,
    RollbackDeployment,
    EnableCircuitBreaker,
    DisableRateLimiting,
    RestrictTraffic,
    NoAction,
}

/// Hyperparameters for tabular Q-learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QLearningConfig {
    pub learning_rate: f64,
    pub discount_factor: f64,
    pub exploration_rate: f64,
    pub exploration_decay: f64,
    pub min_exploration: f64,
}

impl Default for QLearningConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.1,
            discount_factor: 0.95,
            exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration: 0.01,
        }
    }
}

/// Learned value and visit count for one (state, action) pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QValue {
    pub value: f64,
    pub visit_count: u64,
}

/// Tabular Q-learning agent mapping (state, action) pairs to Q-values.
pub struct RLAgent {
    q_table: Arc<RwLock<HashMap<(SystemState, RemediationAction), QValue>>>,
    config: QLearningConfig,
    current_epsilon: f64,
}

impl RLAgent {
    pub fn new(config: QLearningConfig) -> Self {
        Self {
            q_table: Arc::new(RwLock::new(HashMap::new())),
            current_epsilon: config.exploration_rate,
            config,
        }
    }

    /// Epsilon-greedy selection: explore with probability `current_epsilon`,
    /// otherwise exploit the best known action for this state.
    pub async fn select_action(&self, state: &SystemState) -> RemediationAction {
        if rand::random::<f64>() < self.current_epsilon {
            self.random_action()
        } else {
            self.best_action(state).await
        }
    }

    async fn best_action(&self, state: &SystemState) -> RemediationAction {
        let q_table = self.q_table.read().await;
        let actions = self.possible_actions();

        let mut best_action = RemediationAction::NoAction;
        let mut best_value = f64::NEG_INFINITY;

        for action in actions {
            let key = (state.clone(), action.clone());
            // Unseen (state, action) pairs default to a Q-value of 0.0.
            let value = q_table.get(&key).map(|q| q.value).unwrap_or(0.0);

            if value > best_value {
                best_value = value;
                best_action = action;
            }
        }

        best_action
    }

    fn random_action(&self) -> RemediationAction {
        let actions = self.possible_actions();
        use rand::Rng;
        // `rand::rng()` and `random_range` follow the rand 0.9 API.
        let mut rng = rand::rng();
        let idx = rng.random_range(0..actions.len());
        actions[idx].clone()
    }

    fn possible_actions(&self) -> Vec<RemediationAction> {
        vec![
            RemediationAction::RestartService,
            RemediationAction::ScaleUp(2),
            RemediationAction::ScaleUp(4),
            RemediationAction::ScaleDown(2),
            RemediationAction::ClearCache,
            RemediationAction::RollbackDeployment,
            RemediationAction::EnableCircuitBreaker,
            RemediationAction::DisableRateLimiting,
            RemediationAction::RestrictTraffic,
            RemediationAction::NoAction,
        ]
    }

    /// Standard Q-learning update:
    /// Q(s, a) <- Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a)),
    /// followed by one step of epsilon decay toward `min_exploration`.
    pub async fn update(
        &mut self,
        state: &SystemState,
        action: &RemediationAction,
        reward: f64,
        next_state: &SystemState,
    ) {
        let mut q_table = self.q_table.write().await;

        let key = (state.clone(), action.clone());
        let current_q = q_table.get(&key).map(|q| q.value).unwrap_or(0.0);

        // Bootstrap from the best available action in the next state;
        // possible_actions() is non-empty, so the fold never stays at -inf.
        let actions = self.possible_actions();
        let max_next_q = actions
            .iter()
            .map(|a| {
                let next_key = (next_state.clone(), a.clone());
                q_table.get(&next_key).map(|q| q.value).unwrap_or(0.0)
            })
            .fold(f64::NEG_INFINITY, f64::max);

        let new_q = current_q
            + self.config.learning_rate
                * (reward + self.config.discount_factor * max_next_q - current_q);

        q_table
            .entry(key)
            .and_modify(|q| {
                q.value = new_q;
                q.visit_count += 1;
            })
            .or_insert(QValue {
                value: new_q,
                visit_count: 1,
            });

        self.current_epsilon =
            (self.current_epsilon * self.config.exploration_decay).max(self.config.min_exploration);
    }

    /// Reward shaping: improvement in each metric between `before` and
    /// `after` earns positive reward, weighted by its impact on health.
    pub fn calculate_reward(&self, before: &SystemState, after: &SystemState) -> f64 {
        let mut reward = 0.0;

        // Error-rate reduction is weighted heavily.
        reward += (before.error_rate as f64 - after.error_rate as f64) * 2.0;

        // Latency improvement.
        reward += (before.latency_level as f64 - after.latency_level as f64) * 1.5;

        // CPU pressure relief.
        reward += (before.cpu_usage as f64 - after.cpu_usage as f64) * 0.5;

        // Resolving active failures matters most per unit.
        reward += (before.active_failures as f64 - after.active_failures as f64) * 5.0;

        // Large bonuses or penalties for overall health transitions.
        reward += match (before.service_health.as_str(), after.service_health.as_str()) {
            ("critical", "healthy") => 50.0,
            ("critical", "degraded") => 25.0,
            ("degraded", "healthy") => 20.0,
            ("healthy", "degraded") => -30.0,
            ("healthy", "critical") => -50.0,
            ("degraded", "critical") => -40.0,
            _ => 0.0,
        };

        reward
    }

    pub async fn get_stats(&self) -> HashMap<String, serde_json::Value> {
        let q_table = self.q_table.read().await;

        let mut stats = HashMap::new();
        stats.insert("q_table_size".to_string(), serde_json::json!(q_table.len()));
        stats.insert("epsilon".to_string(), serde_json::json!(self.current_epsilon));

        let avg_q: f64 = if q_table.is_empty() {
            0.0
        } else {
            q_table.values().map(|q| q.value).sum::<f64>() / q_table.len() as f64
        };
        stats.insert("avg_q_value".to_string(), serde_json::json!(avg_q));

        let mut visited: Vec<_> = q_table.iter().collect();
        visited.sort_by_key(|(_, q)| std::cmp::Reverse(q.visit_count));

        let top_pairs: Vec<_> = visited
            .iter()
            .take(10)
            .map(|((state, action), q)| {
                serde_json::json!({
                    "state": state,
                    "action": action,
                    "q_value": q.value,
                    "visits": q.visit_count,
                })
            })
            .collect();

        stats.insert("top_state_actions".to_string(), serde_json::json!(top_pairs));

        stats
    }

    pub async fn save_model(&self, path: &str) -> Result<()> {
        let q_table = self.q_table.read().await;
        // JSON object keys must be strings, so a map keyed by
        // (SystemState, RemediationAction) cannot be serialized directly;
        // persist the table as a list of entries instead.
        let entries: Vec<_> = q_table.iter().collect();
        let data = serde_json::to_string_pretty(&entries)?;
        tokio::fs::write(path, data).await?;
        Ok(())
    }

    pub async fn load_model(&mut self, path: &str) -> Result<()> {
        let data = tokio::fs::read_to_string(path).await?;
        let entries: Vec<((SystemState, RemediationAction), QValue)> =
            serde_json::from_str(&data)?;

        let mut q_table = self.q_table.write().await;
        *q_table = entries.into_iter().collect();

        Ok(())
    }
}
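
// Intended usage, as a sketch of one observe-act-learn cycle. The
// `observe_state` and `execute` helpers below are hypothetical stand-ins
// for environment-specific code, not part of this module:
//
//     let state = observe_state().await;
//     let action = agent.select_action(&state).await;
//     execute(&action).await;
//     let next_state = observe_state().await;
//     let reward = agent.calculate_reward(&state, &next_state);
//     agent.update(&state, &action, reward, &next_state).await;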

/// Combines heuristic risk scoring with the RL agent's learned policy.
pub struct AdaptiveRiskAssessor {
    risk_history: Arc<RwLock<Vec<RiskAssessment>>>,
    rl_agent: Arc<RwLock<RLAgent>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
    pub timestamp: chrono::DateTime<chrono::Utc>,
    pub state: SystemState,
    pub risk_level: f64,
    pub recommended_actions: Vec<RemediationAction>,
    pub confidence: f64,
}

impl AdaptiveRiskAssessor {
    pub fn new(rl_agent: Arc<RwLock<RLAgent>>) -> Self {
        Self {
            risk_history: Arc::new(RwLock::new(Vec::new())),
            rl_agent,
        }
    }

    pub async fn assess_risk(&self, state: &SystemState) -> RiskAssessment {
        // Weighted sum of normalized metrics; the weights sum to 1.0.
        let mut risk_level = 0.0;
        risk_level += state.error_rate as f64 / 100.0 * 0.3;
        risk_level += state.latency_level as f64 / 100.0 * 0.2;
        risk_level += state.cpu_usage as f64 / 100.0 * 0.15;
        risk_level += state.memory_usage as f64 / 100.0 * 0.15;
        risk_level += state.active_failures as f64 / 10.0 * 0.2;

        // Overall health adds a flat penalty on top; clamp to [0, 1].
        risk_level += match state.service_health.as_str() {
            "critical" => 0.4,
            "degraded" => 0.2,
            _ => 0.0,
        };
        risk_level = risk_level.min(1.0);

        let agent = self.rl_agent.read().await;
        let action = agent.best_action(state).await;

        // Confidence grows with how often this (state, action) pair has
        // been visited, saturating at 100 visits.
        let q_table = agent.q_table.read().await;
        let key = (state.clone(), action.clone());
        let confidence = q_table
            .get(&key)
            .map(|q| (q.visit_count as f64 / 100.0).min(1.0))
            .unwrap_or(0.1);

        let assessment = RiskAssessment {
            timestamp: chrono::Utc::now(),
            state: state.clone(),
            risk_level,
            recommended_actions: vec![action],
            confidence,
        };

        let mut history = self.risk_history.write().await;
        history.push(assessment.clone());

        // Keep the history bounded to the most recent 1000 assessments.
        if history.len() > 1000 {
            let excess = history.len() - 1000;
            history.drain(0..excess);
        }

        assessment
    }

    pub async fn get_risk_trend(&self) -> Vec<(chrono::DateTime<chrono::Utc>, f64)> {
        let history = self.risk_history.read().await;
        history.iter().map(|a| (a.timestamp, a.risk_level)).collect()
    }
}
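
// Usage sketch: wire the assessor to a shared agent and assess the current
// state (the tests below do the same); the 0.7 threshold is illustrative,
// and acting on the recommendation is left to environment-specific code:
//
//     let agent = Arc::new(RwLock::new(RLAgent::new(QLearningConfig::default())));
//     let assessor = AdaptiveRiskAssessor::new(agent.clone());
//     let assessment = assessor.assess_risk(&state).await;
//     if assessment.risk_level > 0.7 {
//         // execute assessment.recommended_actions ...
//     }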

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rl_agent_learning() {
        let config = QLearningConfig::default();
        let mut agent = RLAgent::new(config);

        let state = SystemState {
            error_rate: 50,
            latency_level: 60,
            cpu_usage: 80,
            memory_usage: 70,
            active_failures: 3,
            service_health: "degraded".to_string(),
        };

        let action = RemediationAction::RestartService;

        let next_state = SystemState {
            error_rate: 10,
            latency_level: 20,
            cpu_usage: 40,
            memory_usage: 50,
            active_failures: 0,
            service_health: "healthy".to_string(),
        };

        let reward = agent.calculate_reward(&state, &next_state);
        agent.update(&state, &action, reward, &next_state).await;

        let stats = agent.get_stats().await;
        assert!(stats.contains_key("q_table_size"));
    }

    #[tokio::test]
    async fn test_risk_assessment() {
        let config = QLearningConfig::default();
        let agent = Arc::new(RwLock::new(RLAgent::new(config)));
        let assessor = AdaptiveRiskAssessor::new(agent);

        let state = SystemState {
            error_rate: 75,
            latency_level: 80,
            cpu_usage: 90,
            memory_usage: 85,
            active_failures: 5,
            service_health: "critical".to_string(),
        };

        let assessment = assessor.assess_risk(&state).await;

        assert!(assessment.risk_level > 0.5);
        assert!(!assessment.recommended_actions.is_empty());
    }
}