Skip to main content

simular/domains/ml/
jidoka.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use super::{RollingStats, TrainingAnomaly};
5
6// ============================================================================
7// Jidoka ML Feedback Loop
8// ============================================================================
9
10/// Rule patch types for Kaizen improvements.
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12pub enum RuleType {
13    /// Clip gradients to max norm.
14    GradientClipping,
15    /// Reduce learning rate.
16    LearningRateDecay,
17    /// Add warmup steps.
18    LearningRateWarmup,
19    /// Increase batch size.
20    BatchSizeIncrease,
21    /// Add regularization.
22    Regularization,
23    /// Manual review required.
24    ManualReview,
25}
26
27/// Improvement patch generated from anomaly pattern.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RulePatch {
30    /// Type of rule patch.
31    pub rule_type: RuleType,
32    /// Parameters for the patch.
33    pub parameters: HashMap<String, String>,
34}
35
36impl Default for RulePatch {
37    fn default() -> Self {
38        Self {
39            rule_type: RuleType::ManualReview,
40            parameters: HashMap::new(),
41        }
42    }
43}
44
45/// Anomaly pattern classification.
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
47pub enum AnomalyType {
48    /// Loss became NaN/Inf.
49    NonFiniteLoss,
50    /// Gradient norm exceeded threshold.
51    GradientExplosion,
52    /// Gradient vanished below threshold.
53    GradientVanishing,
54    /// Loss spike (statistical outlier).
55    LossSpike,
56    /// Prediction confidence below threshold.
57    LowConfidence,
58    /// Model output inconsistent with oracle.
59    OracleMismatch,
60}
61
62/// Pattern detected during training/inference.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AnomalyPattern {
65    /// Pattern type.
66    pub pattern_type: AnomalyType,
67    /// Frequency of occurrence.
68    pub frequency: u64,
69    /// Context information.
70    pub context: HashMap<String, String>,
71    /// Suggested fix description.
72    pub suggested_fix: Option<String>,
73}
74
75/// Jidoka feedback loop for ML simulation.
76///
77/// Each detected anomaly generates improvement patches (Kaizen).
78pub struct JidokaMLFeedback {
79    /// Anomaly patterns detected.
80    patterns: Vec<AnomalyPattern>,
81    /// Generated fixes (rule patches).
82    patches: Vec<RulePatch>,
83    /// Rolling stats for anomaly rate tracking.
84    anomaly_rate: RollingStats,
85    /// Threshold for auto-patch generation.
86    auto_patch_threshold: u64,
87}
88
89impl Default for JidokaMLFeedback {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl JidokaMLFeedback {
96    /// Create new Jidoka feedback loop.
97    #[must_use]
98    pub fn new() -> Self {
99        Self {
100            patterns: Vec::new(),
101            patches: Vec::new(),
102            anomaly_rate: RollingStats::new(100),
103            auto_patch_threshold: 3,
104        }
105    }
106
107    /// Set threshold for automatic patch generation.
108    #[must_use]
109    pub fn with_auto_patch_threshold(mut self, threshold: u64) -> Self {
110        self.auto_patch_threshold = threshold;
111        self
112    }
113
114    /// Record anomaly and potentially generate improvement patch.
115    pub fn record_anomaly(&mut self, anomaly: TrainingAnomaly) -> Option<RulePatch> {
116        let pattern_type = self.classify_type(&anomaly);
117        self.anomaly_rate.update(1.0);
118
119        // Check if we've seen this pattern before
120        let existing_idx = self
121            .patterns
122            .iter()
123            .position(|p| p.pattern_type == pattern_type);
124
125        if let Some(idx) = existing_idx {
126            self.patterns[idx].frequency += 1;
127
128            // After threshold occurrences, generate automated fix
129            if self.patterns[idx].frequency >= self.auto_patch_threshold {
130                let pattern_clone = self.patterns[idx].clone();
131                let patch = self.generate_patch(&pattern_clone);
132                self.patches.push(patch.clone());
133                return Some(patch);
134            }
135        } else {
136            let suggested_fix = Some(self.suggest_fix(pattern_type));
137            let new_pattern = AnomalyPattern {
138                pattern_type,
139                frequency: 1,
140                context: HashMap::new(),
141                suggested_fix,
142            };
143
144            // Check if threshold is 1 (generate patch on first occurrence)
145            if self.auto_patch_threshold <= 1 {
146                let patch = self.generate_patch(&new_pattern);
147                self.patterns.push(new_pattern);
148                self.patches.push(patch.clone());
149                return Some(patch);
150            }
151
152            self.patterns.push(new_pattern);
153        }
154
155        None
156    }
157
158    /// Classify anomaly into pattern type.
159    #[allow(clippy::unused_self)]
160    fn classify_type(&self, anomaly: &TrainingAnomaly) -> AnomalyType {
161        match anomaly {
162            TrainingAnomaly::NonFiniteLoss => AnomalyType::NonFiniteLoss,
163            TrainingAnomaly::GradientExplosion { .. } => AnomalyType::GradientExplosion,
164            TrainingAnomaly::GradientVanishing { .. } => AnomalyType::GradientVanishing,
165            TrainingAnomaly::LossSpike { .. } => AnomalyType::LossSpike,
166            TrainingAnomaly::LowConfidence { .. } => AnomalyType::LowConfidence,
167        }
168    }
169
170    /// Suggest fix for pattern type.
171    #[allow(clippy::unused_self)]
172    fn suggest_fix(&self, pattern_type: AnomalyType) -> String {
173        match pattern_type {
174            AnomalyType::GradientExplosion => "Apply gradient clipping with max_norm=1.0",
175            AnomalyType::GradientVanishing => "Use skip connections or residual architecture",
176            AnomalyType::LossSpike => "Reduce learning rate or add warmup",
177            AnomalyType::NonFiniteLoss => "Check for numerical stability issues",
178            AnomalyType::LowConfidence => "Increase model capacity or training data",
179            AnomalyType::OracleMismatch => "Review training data distribution",
180        }
181        .to_string()
182    }
183
184    /// Generate rule patch from pattern (Kaizen improvement).
185    #[allow(clippy::unused_self)]
186    fn generate_patch(&self, pattern: &AnomalyPattern) -> RulePatch {
187        let mut params = HashMap::new();
188
189        match pattern.pattern_type {
190            AnomalyType::GradientExplosion => {
191                params.insert("max_norm".to_string(), "1.0".to_string());
192                RulePatch {
193                    rule_type: RuleType::GradientClipping,
194                    parameters: params,
195                }
196            }
197            AnomalyType::LossSpike => {
198                params.insert("warmup_steps".to_string(), "1000".to_string());
199                RulePatch {
200                    rule_type: RuleType::LearningRateWarmup,
201                    parameters: params,
202                }
203            }
204            AnomalyType::GradientVanishing => {
205                params.insert("factor".to_string(), "10.0".to_string());
206                RulePatch {
207                    rule_type: RuleType::LearningRateDecay,
208                    parameters: params,
209                }
210            }
211            _ => RulePatch::default(),
212        }
213    }
214
215    /// Get all detected patterns.
216    #[must_use]
217    pub fn patterns(&self) -> &[AnomalyPattern] {
218        &self.patterns
219    }
220
221    /// Get all generated patches.
222    #[must_use]
223    pub fn patches(&self) -> &[RulePatch] {
224        &self.patches
225    }
226
227    /// Get current anomaly rate.
228    #[must_use]
229    pub fn anomaly_rate(&self) -> f64 {
230        self.anomaly_rate.mean()
231    }
232
233    /// Reset feedback loop state.
234    pub fn reset(&mut self) {
235        self.patterns.clear();
236        self.patches.clear();
237        self.anomaly_rate.reset();
238    }
239}