// simular/domains/ml/jidoka.rs
use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use super::{RollingStats, TrainingAnomaly};
5
/// Category of corrective action that an auto-generated [`RulePatch`]
/// proposes for a training run.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum RuleType {
    /// Clip gradient norms (parameterized by `max_norm`).
    GradientClipping,
    /// Decay the learning rate (parameterized by `factor`).
    LearningRateDecay,
    /// Add a learning-rate warmup phase (parameterized by `warmup_steps`).
    LearningRateWarmup,
    /// Increase the training batch size.
    BatchSizeIncrease,
    /// Strengthen regularization.
    Regularization,
    /// No automatic fix known; a human should look at the run.
    /// This is the [`RulePatch::default`] rule type.
    ManualReview,
}
26
/// A concrete corrective action suggested in response to recurring
/// training anomalies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RulePatch {
    /// Which kind of corrective action to apply.
    pub rule_type: RuleType,
    /// Stringly-typed parameters for the action,
    /// e.g. `"max_norm" -> "1.0"` for gradient clipping.
    pub parameters: HashMap<String, String>,
}
35
36impl Default for RulePatch {
37 fn default() -> Self {
38 Self {
39 rule_type: RuleType::ManualReview,
40 parameters: HashMap::new(),
41 }
42 }
43}
44
/// Coarse classification of a [`TrainingAnomaly`] used to aggregate
/// recurring problems into [`AnomalyPattern`]s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AnomalyType {
    /// Loss became NaN/Inf.
    NonFiniteLoss,
    /// Gradient norms blew up.
    GradientExplosion,
    /// Gradient norms collapsed toward zero.
    GradientVanishing,
    /// Sudden jump in the loss curve.
    LossSpike,
    /// Model confidence dropped below expectations.
    LowConfidence,
    /// Disagreement with an oracle/reference signal.
    /// NOTE(review): never produced by `classify_type` in this file —
    /// presumably mapped from elsewhere; confirm.
    OracleMismatch,
}
61
/// Aggregated record of how often one anomaly type has been observed,
/// together with a remediation hint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyPattern {
    /// The anomaly classification this pattern tracks.
    pub pattern_type: AnomalyType,
    /// Number of times this anomaly type has been recorded so far.
    pub frequency: u64,
    /// Extra key/value context. Currently always left empty by
    /// `record_anomaly`; reserved for future enrichment.
    pub context: HashMap<String, String>,
    /// Human-readable remediation hint, filled in when the pattern is
    /// first created.
    pub suggested_fix: Option<String>,
}
74
/// Jidoka-style ("stop the line") feedback loop for ML training:
/// classifies incoming anomalies, tracks how often each type recurs, and
/// auto-generates [`RulePatch`]es once a type repeats often enough.
pub struct JidokaMLFeedback {
    /// One entry per distinct anomaly type observed.
    patterns: Vec<AnomalyPattern>,
    /// Every patch auto-generated so far (may contain repeats, since a
    /// pattern keeps emitting patches once past the threshold).
    patches: Vec<RulePatch>,
    /// Rolling window (100 samples) fed by `record_anomaly`.
    /// NOTE(review): only anomalies push samples (always 1.0) — with no
    /// clean-step signal the mean is constant; confirm intended.
    anomaly_rate: RollingStats,
    /// Occurrence count at which a patch is auto-generated (default 3).
    auto_patch_threshold: u64,
}
88
89impl Default for JidokaMLFeedback {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl JidokaMLFeedback {
96 #[must_use]
98 pub fn new() -> Self {
99 Self {
100 patterns: Vec::new(),
101 patches: Vec::new(),
102 anomaly_rate: RollingStats::new(100),
103 auto_patch_threshold: 3,
104 }
105 }
106
107 #[must_use]
109 pub fn with_auto_patch_threshold(mut self, threshold: u64) -> Self {
110 self.auto_patch_threshold = threshold;
111 self
112 }
113
114 pub fn record_anomaly(&mut self, anomaly: TrainingAnomaly) -> Option<RulePatch> {
116 let pattern_type = self.classify_type(&anomaly);
117 self.anomaly_rate.update(1.0);
118
119 let existing_idx = self
121 .patterns
122 .iter()
123 .position(|p| p.pattern_type == pattern_type);
124
125 if let Some(idx) = existing_idx {
126 self.patterns[idx].frequency += 1;
127
128 if self.patterns[idx].frequency >= self.auto_patch_threshold {
130 let pattern_clone = self.patterns[idx].clone();
131 let patch = self.generate_patch(&pattern_clone);
132 self.patches.push(patch.clone());
133 return Some(patch);
134 }
135 } else {
136 let suggested_fix = Some(self.suggest_fix(pattern_type));
137 let new_pattern = AnomalyPattern {
138 pattern_type,
139 frequency: 1,
140 context: HashMap::new(),
141 suggested_fix,
142 };
143
144 if self.auto_patch_threshold <= 1 {
146 let patch = self.generate_patch(&new_pattern);
147 self.patterns.push(new_pattern);
148 self.patches.push(patch.clone());
149 return Some(patch);
150 }
151
152 self.patterns.push(new_pattern);
153 }
154
155 None
156 }
157
158 #[allow(clippy::unused_self)]
160 fn classify_type(&self, anomaly: &TrainingAnomaly) -> AnomalyType {
161 match anomaly {
162 TrainingAnomaly::NonFiniteLoss => AnomalyType::NonFiniteLoss,
163 TrainingAnomaly::GradientExplosion { .. } => AnomalyType::GradientExplosion,
164 TrainingAnomaly::GradientVanishing { .. } => AnomalyType::GradientVanishing,
165 TrainingAnomaly::LossSpike { .. } => AnomalyType::LossSpike,
166 TrainingAnomaly::LowConfidence { .. } => AnomalyType::LowConfidence,
167 }
168 }
169
170 #[allow(clippy::unused_self)]
172 fn suggest_fix(&self, pattern_type: AnomalyType) -> String {
173 match pattern_type {
174 AnomalyType::GradientExplosion => "Apply gradient clipping with max_norm=1.0",
175 AnomalyType::GradientVanishing => "Use skip connections or residual architecture",
176 AnomalyType::LossSpike => "Reduce learning rate or add warmup",
177 AnomalyType::NonFiniteLoss => "Check for numerical stability issues",
178 AnomalyType::LowConfidence => "Increase model capacity or training data",
179 AnomalyType::OracleMismatch => "Review training data distribution",
180 }
181 .to_string()
182 }
183
184 #[allow(clippy::unused_self)]
186 fn generate_patch(&self, pattern: &AnomalyPattern) -> RulePatch {
187 let mut params = HashMap::new();
188
189 match pattern.pattern_type {
190 AnomalyType::GradientExplosion => {
191 params.insert("max_norm".to_string(), "1.0".to_string());
192 RulePatch {
193 rule_type: RuleType::GradientClipping,
194 parameters: params,
195 }
196 }
197 AnomalyType::LossSpike => {
198 params.insert("warmup_steps".to_string(), "1000".to_string());
199 RulePatch {
200 rule_type: RuleType::LearningRateWarmup,
201 parameters: params,
202 }
203 }
204 AnomalyType::GradientVanishing => {
205 params.insert("factor".to_string(), "10.0".to_string());
206 RulePatch {
207 rule_type: RuleType::LearningRateDecay,
208 parameters: params,
209 }
210 }
211 _ => RulePatch::default(),
212 }
213 }
214
215 #[must_use]
217 pub fn patterns(&self) -> &[AnomalyPattern] {
218 &self.patterns
219 }
220
221 #[must_use]
223 pub fn patches(&self) -> &[RulePatch] {
224 &self.patches
225 }
226
227 #[must_use]
229 pub fn anomaly_rate(&self) -> f64 {
230 self.anomaly_rate.mean()
231 }
232
233 pub fn reset(&mut self) {
235 self.patterns.clear();
236 self.patches.clear();
237 self.anomaly_rate.reset();
238 }
239}