use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::{RollingStats, TrainingAnomaly};
/// Kind of corrective rule carried by a [`RulePatch`].
///
/// Only some variants are produced by `JidokaMLFeedback::generate_patch` in
/// this file; the per-variant notes say which.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum RuleType {
/// Emitted for `AnomalyType::GradientExplosion`, with a `max_norm` parameter.
GradientClipping,
/// Emitted for `AnomalyType::GradientVanishing`, with a `factor` parameter.
LearningRateDecay,
/// Emitted for `AnomalyType::LossSpike`, with a `warmup_steps` parameter.
LearningRateWarmup,
/// Not emitted by `generate_patch` in this file.
BatchSizeIncrease,
/// Not emitted by `generate_patch` in this file.
Regularization,
/// Fallback rule type: used by `RulePatch::default()` and for every anomaly
/// type that has no dedicated patch.
ManualReview,
}
/// A suggested rule change: a [`RuleType`] plus free-form string parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RulePatch {
/// Which kind of rule this patch applies.
pub rule_type: RuleType,
/// Parameter name -> value, both stored as strings (e.g. `"max_norm"` -> `"1.0"`).
pub parameters: HashMap<String, String>,
}
impl Default for RulePatch {
fn default() -> Self {
Self {
rule_type: RuleType::ManualReview,
parameters: HashMap::new(),
}
}
}
/// Payload-free category of a training anomaly, used as the key under which
/// `JidokaMLFeedback` groups occurrences into an [`AnomalyPattern`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AnomalyType {
/// Mapped from `TrainingAnomaly::NonFiniteLoss`.
NonFiniteLoss,
/// Mapped from `TrainingAnomaly::GradientExplosion`.
GradientExplosion,
/// Mapped from `TrainingAnomaly::GradientVanishing`.
GradientVanishing,
/// Mapped from `TrainingAnomaly::LossSpike`.
LossSpike,
/// Mapped from `TrainingAnomaly::LowConfidence`.
LowConfidence,
/// Has a suggested fix, but `classify_type` in this file never produces it.
OracleMismatch,
}
/// Aggregated record of one [`AnomalyType`] as tracked by `JidokaMLFeedback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyPattern {
/// The anomaly category this record aggregates.
pub pattern_type: AnomalyType,
/// How many times this category has been recorded (starts at 1).
pub frequency: u64,
/// Free-form key/value context. `record_anomaly` creates it empty and nothing
/// in this file populates it.
pub context: HashMap<String, String>,
/// Human-readable remediation hint, filled in when the pattern is first recorded.
pub suggested_fix: Option<String>,
}
/// Tracks recorded training anomalies per category and emits [`RulePatch`]es
/// once a category has been seen often enough.
pub struct JidokaMLFeedback {
// One entry per distinct anomaly category seen so far.
patterns: Vec<AnomalyPattern>,
// Every patch emitted so far, in emission order.
patches: Vec<RulePatch>,
// Rolling statistic updated on every recorded anomaly; constructed with 100
// (presumably the window size — confirm against `RollingStats`).
anomaly_rate: RollingStats,
// A patch is emitted whenever a pattern's frequency is >= this value.
auto_patch_threshold: u64,
}
impl Default for JidokaMLFeedback {
fn default() -> Self {
Self::new()
}
}
/// Anomaly bookkeeping and automatic patch generation.
impl JidokaMLFeedback {
    /// Creates an empty tracker with an auto-patch threshold of 3 and a
    /// rolling statistic constructed with `RollingStats::new(100)`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            patterns: Vec::new(),
            patches: Vec::new(),
            anomaly_rate: RollingStats::new(100),
            auto_patch_threshold: 3,
        }
    }

    /// Builder-style setter for the auto-patch threshold.
    ///
    /// A patch is emitted whenever a pattern's frequency is `>= threshold`,
    /// so a threshold of 0 or 1 emits a patch on every recorded anomaly.
    #[must_use]
    pub fn with_auto_patch_threshold(mut self, threshold: u64) -> Self {
        self.auto_patch_threshold = threshold;
        self
    }

    /// Records one anomaly and returns a patch if its category has reached
    /// the auto-patch threshold.
    ///
    /// The anomaly is reduced to its [`AnomalyType`]. The matching pattern's
    /// frequency is incremented, or a new pattern with frequency 1 and a
    /// suggested fix is created. Once the frequency is at or above the
    /// threshold, a patch is generated, appended to the patch history and
    /// returned. This happens on *every* call from then on, not only the
    /// first time the threshold is crossed.
    ///
    /// Returns `None` while the category is still below the threshold.
    pub fn record_anomaly(&mut self, anomaly: TrainingAnomaly) -> Option<RulePatch> {
        let pattern_type = self.classify_type(&anomaly);
        self.anomaly_rate.update(1.0);

        let existing = self
            .patterns
            .iter()
            .position(|p| p.pattern_type == pattern_type);

        let idx = match existing {
            Some(idx) => {
                let pattern = &mut self.patterns[idx];
                // Saturate rather than overflow: wrapping to 0 in a release
                // build would silently stop patch generation.
                pattern.frequency = pattern.frequency.saturating_add(1);
                idx
            }
            None => {
                let suggested_fix = Some(self.suggest_fix(pattern_type));
                self.patterns.push(AnomalyPattern {
                    pattern_type,
                    frequency: 1,
                    context: HashMap::new(),
                    suggested_fix,
                });
                self.patterns.len() - 1
            }
        };

        // Single threshold check for both new and pre-existing patterns.
        let pattern = &self.patterns[idx];
        if pattern.frequency < self.auto_patch_threshold {
            return None;
        }

        // `generate_patch` only needs shared borrows, so the pattern is
        // passed by reference rather than cloned.
        let patch = self.generate_patch(pattern);
        self.patches.push(patch.clone());
        Some(patch)
    }

    /// Maps a [`TrainingAnomaly`] to its payload-free [`AnomalyType`].
    // Kept as a method (rather than an associated fn) so existing
    // `self.classify_type(..)` call sites keep compiling.
    #[allow(clippy::unused_self)]
    fn classify_type(&self, anomaly: &TrainingAnomaly) -> AnomalyType {
        match anomaly {
            TrainingAnomaly::NonFiniteLoss => AnomalyType::NonFiniteLoss,
            TrainingAnomaly::GradientExplosion { .. } => AnomalyType::GradientExplosion,
            TrainingAnomaly::GradientVanishing { .. } => AnomalyType::GradientVanishing,
            TrainingAnomaly::LossSpike { .. } => AnomalyType::LossSpike,
            TrainingAnomaly::LowConfidence { .. } => AnomalyType::LowConfidence,
        }
    }

    /// Returns the human-readable remediation hint for an anomaly category.
    // Kept as a method so existing `self.suggest_fix(..)` call sites keep compiling.
    #[allow(clippy::unused_self)]
    fn suggest_fix(&self, pattern_type: AnomalyType) -> String {
        let hint = match pattern_type {
            AnomalyType::GradientExplosion => "Apply gradient clipping with max_norm=1.0",
            AnomalyType::GradientVanishing => "Use skip connections or residual architecture",
            AnomalyType::LossSpike => "Reduce learning rate or add warmup",
            AnomalyType::NonFiniteLoss => "Check for numerical stability issues",
            AnomalyType::LowConfidence => "Increase model capacity or training data",
            AnomalyType::OracleMismatch => "Review training data distribution",
        };
        hint.to_string()
    }

    /// Builds the patch for a pattern's category.
    ///
    /// Gradient explosion, loss spike and gradient vanishing each get a
    /// dedicated rule with a single parameter. Every other category falls
    /// back to [`RulePatch::default`] (manual review, no parameters).
    // Kept as a method so existing `self.generate_patch(..)` call sites keep compiling.
    #[allow(clippy::unused_self)]
    fn generate_patch(&self, pattern: &AnomalyPattern) -> RulePatch {
        // The match is exhaustive on purpose: adding an `AnomalyType` variant
        // must force a decision here instead of silently becoming ManualReview.
        let (rule_type, key, value) = match pattern.pattern_type {
            AnomalyType::GradientExplosion => (RuleType::GradientClipping, "max_norm", "1.0"),
            AnomalyType::LossSpike => (RuleType::LearningRateWarmup, "warmup_steps", "1000"),
            AnomalyType::GradientVanishing => (RuleType::LearningRateDecay, "factor", "10.0"),
            AnomalyType::NonFiniteLoss
            | AnomalyType::LowConfidence
            | AnomalyType::OracleMismatch => return RulePatch::default(),
        };

        let mut parameters = HashMap::new();
        parameters.insert(key.to_string(), value.to_string());
        RulePatch {
            rule_type,
            parameters,
        }
    }

    /// All anomaly patterns seen so far, one per category, in first-seen order.
    #[must_use]
    pub fn patterns(&self) -> &[AnomalyPattern] {
        &self.patterns
    }

    /// Every patch emitted so far, in emission order.
    #[must_use]
    pub fn patches(&self) -> &[RulePatch] {
        &self.patches
    }

    /// Mean of the rolling anomaly statistic, as reported by `RollingStats::mean`.
    ///
    /// NOTE(review): within this impl the statistic is only ever updated with
    /// `1.0` (once per recorded anomaly) and never with a non-anomalous
    /// sample, so this may not be a true rate — confirm against `RollingStats`.
    #[must_use]
    pub fn anomaly_rate(&self) -> f64 {
        self.anomaly_rate.mean()
    }

    /// Clears all patterns, emitted patches and the rolling statistic.
    /// The auto-patch threshold is left unchanged.
    pub fn reset(&mut self) {
        self.patterns.clear();
        self.patches.clear();
        self.anomaly_rate.reset();
    }
}