use crate::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use super::Pattern;
/// Tuning knobs for pattern consolidation and EWC-style regularization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidationConfig {
/// Base EWC penalty strength; the consolidator's working lambda starts here.
pub lambda: f32,
/// Lower clamp applied to the working lambda.
pub min_lambda: f32,
/// Upper clamp applied to the working lambda.
pub max_lambda: f32,
/// EMA decay used when folding squared gradients into Fisher estimates.
pub fisher_decay: f32,
/// Usage count at or above which a pattern counts as "important" for lambda adaptation.
pub min_usage_for_importance: u32,
/// Patterns below this average quality are candidates for pruning.
pub min_quality_threshold: f32,
/// Similarity threshold (metric defined by `Pattern::similarity`) above which
/// two same-category patterns are merged.
pub merge_similarity_threshold: f32,
/// Age horizon (seconds since last access) used for recency decay in scoring.
pub max_unused_age_secs: u64,
/// If true, `adapt_lambda` rescales lambda from the count of important patterns.
pub auto_adapt_lambda: bool,
}
impl Default for ConsolidationConfig {
fn default() -> Self {
Self {
lambda: 2000.0,
min_lambda: 100.0,
max_lambda: 15000.0,
fisher_decay: 0.999,
min_usage_for_importance: 3,
min_quality_threshold: 0.3,
merge_similarity_threshold: 0.85,
max_unused_age_secs: 86400 * 7, auto_adapt_lambda: true,
}
}
}
/// Diagonal Fisher information estimate for one pattern's weights (EWC-style).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FisherInformation {
/// Per-weight importance estimate; after the first `update` it mirrors
/// `ema_grad_squared` (initialized to 1.0 as a uniform prior).
pub diagonal: Vec<f32>,
/// Number of gradient samples folded in so far.
pub sample_count: u64,
/// Exponential moving average of squared gradients, one entry per weight.
pub ema_grad_squared: Vec<f32>,
}
impl FisherInformation {
pub fn new(dim: usize) -> Self {
Self {
diagonal: vec![1.0; dim],
sample_count: 0,
ema_grad_squared: vec![0.0; dim],
}
}
pub fn update(&mut self, gradient: &[f32], decay: f32) {
if gradient.len() != self.diagonal.len() {
return;
}
self.sample_count += 1;
for (i, &g) in gradient.iter().enumerate() {
self.ema_grad_squared[i] = decay * self.ema_grad_squared[i] + (1.0 - decay) * g * g;
self.diagonal[i] = self.ema_grad_squared[i];
}
}
pub fn importance(&self, dim: usize) -> f32 {
if dim < self.diagonal.len() {
self.diagonal[dim]
} else {
0.0
}
}
pub fn total_importance(&self) -> f32 {
self.diagonal.iter().sum()
}
pub fn merge(&mut self, other: &FisherInformation, self_weight: f32) {
if self.diagonal.len() != other.diagonal.len() {
return;
}
let other_weight = 1.0 - self_weight;
for i in 0..self.diagonal.len() {
self.diagonal[i] = self.diagonal[i] * self_weight + other.diagonal[i] * other_weight;
self.ema_grad_squared[i] =
self.ema_grad_squared[i] * self_weight + other.ema_grad_squared[i] * other_weight;
}
self.sample_count = ((self.sample_count as f32 * self_weight)
+ (other.sample_count as f32 * other_weight)) as u64;
}
}
/// A pattern's computed importance together with its factor breakdown.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImportanceScore {
/// Id of the scored pattern.
pub pattern_id: u64,
/// Weighted blend of the factors below (see `ImportanceScore::compute`).
pub score: f32,
/// The individual factors that produced `score`.
pub factors: ImportanceFactors,
}
/// Individual components of an importance score; each is intended to lie in [0, 1].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ImportanceFactors {
/// Log-scaled usage count, capped at 1.0.
pub usage_factor: f32,
/// The pattern's average quality.
pub quality_factor: f32,
/// Exponential decay based on time since last access.
pub recency_factor: f32,
/// The pattern's success rate.
pub success_factor: f32,
/// Mean diagonal Fisher importance (0.5 when no Fisher data exists).
pub fisher_factor: f32,
}
impl ImportanceScore {
pub fn compute(
pattern: &Pattern,
fisher: Option<&FisherInformation>,
max_age_secs: u64,
) -> Self {
let mut factors = ImportanceFactors::default();
factors.usage_factor = (pattern.usage_count as f32 + 1.0).ln() / 10.0;
factors.usage_factor = factors.usage_factor.min(1.0);
factors.quality_factor = pattern.avg_quality;
let age_secs = (chrono::Utc::now() - pattern.last_accessed).num_seconds() as f32;
let decay_rate = -age_secs / max_age_secs as f32;
factors.recency_factor = decay_rate.exp();
factors.success_factor = pattern.success_rate();
if let Some(fi) = fisher {
factors.fisher_factor = (fi.total_importance() / fi.diagonal.len() as f32).min(1.0);
} else {
factors.fisher_factor = 0.5; }
let score = 0.25 * factors.usage_factor
+ 0.25 * factors.quality_factor
+ 0.15 * factors.recency_factor
+ 0.20 * factors.success_factor
+ 0.15 * factors.fisher_factor;
Self {
pattern_id: pattern.id,
score,
factors,
}
}
}
/// Outcome of one consolidation pass: which patterns to merge/prune, plus summary metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidationResult {
/// Ids of patterns absorbed into a similar surviving pattern.
pub merged_pattern_ids: Vec<u64>,
/// Ids of low-importance, low-quality patterns selected for removal.
pub pruned_pattern_ids: Vec<u64>,
/// Pattern count before consolidation.
pub patterns_before: usize,
/// Pattern count after applying merges and prunes.
pub patterns_after: usize,
/// Total importance score of the surviving patterns.
pub importance_preserved: f32,
/// When this consolidation pass ran.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// EWC lambda in effect during this pass.
pub lambda_used: f32,
/// Detailed statistics for the pass.
pub stats: ConsolidationStats,
}
/// Per-pass consolidation statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConsolidationStats {
/// Number of patterns selected for merging.
pub merged_count: usize,
/// Number of patterns selected for pruning.
pub pruned_count: usize,
/// Mean importance of pruned patterns (0.0 when none were pruned).
pub avg_pruned_importance: f32,
/// Mean importance of surviving patterns (0.0 when none survive).
pub avg_kept_importance: f32,
/// Wall-clock duration of the pass in milliseconds.
pub processing_time_ms: u64,
}
/// EWC-style consolidator: tracks per-pattern Fisher information and plans
/// pruning/merging of stored patterns.
pub struct PatternConsolidator {
/// Static configuration supplied at construction.
config: ConsolidationConfig,
/// Per-pattern Fisher estimates keyed by pattern id.
fisher_info: HashMap<u64, FisherInformation>,
/// Current working EWC penalty strength (may be adapted at runtime).
lambda: f32,
// NOTE(review): reported by `stats()` but never incremented anywhere in this
// file — confirm whether callers are expected to drive these counters.
consolidation_count: u64,
total_consolidated: u64,
}
impl PatternConsolidator {
pub fn new(config: ConsolidationConfig) -> Self {
let lambda = config.lambda;
Self {
config,
fisher_info: HashMap::new(),
lambda,
consolidation_count: 0,
total_consolidated: 0,
}
}
pub fn consolidate_patterns(&self, patterns: &[Pattern]) -> Result<ConsolidationResult> {
let start = std::time::Instant::now();
let patterns_before = patterns.len();
let scores: Vec<ImportanceScore> = patterns
.iter()
.map(|p| {
ImportanceScore::compute(
p,
self.fisher_info.get(&p.id),
self.config.max_unused_age_secs,
)
})
.collect();
let pruned_ids: Vec<u64> = scores
.iter()
.filter(|s| {
let pattern = patterns.iter().find(|p| p.id == s.pattern_id);
if let Some(p) = pattern {
s.score < 0.2 && p.avg_quality < self.config.min_quality_threshold
} else {
false
}
})
.map(|s| s.pattern_id)
.collect();
let merged_ids = self.find_mergeable_patterns(patterns, &pruned_ids)?;
let pruned_importance: f32 = scores
.iter()
.filter(|s| pruned_ids.contains(&s.pattern_id))
.map(|s| s.score)
.sum();
let kept_importance: f32 = scores
.iter()
.filter(|s| !pruned_ids.contains(&s.pattern_id) && !merged_ids.contains(&s.pattern_id))
.map(|s| s.score)
.sum();
let patterns_after = patterns_before - pruned_ids.len() - merged_ids.len();
let processing_time_ms = start.elapsed().as_millis() as u64;
let stats = ConsolidationStats {
merged_count: merged_ids.len(),
pruned_count: pruned_ids.len(),
avg_pruned_importance: if pruned_ids.is_empty() {
0.0
} else {
pruned_importance / pruned_ids.len() as f32
},
avg_kept_importance: if patterns_after == 0 {
0.0
} else {
kept_importance / patterns_after as f32
},
processing_time_ms,
};
Ok(ConsolidationResult {
merged_pattern_ids: merged_ids,
pruned_pattern_ids: pruned_ids,
patterns_before,
patterns_after,
importance_preserved: kept_importance,
timestamp: chrono::Utc::now(),
lambda_used: self.lambda,
stats,
})
}
fn find_mergeable_patterns(
&self,
patterns: &[Pattern],
exclude_ids: &[u64],
) -> Result<Vec<u64>> {
let mut merged = Vec::new();
let mut checked = std::collections::HashSet::new();
for i in 0..patterns.len() {
if exclude_ids.contains(&patterns[i].id) || checked.contains(&patterns[i].id) {
continue;
}
for j in (i + 1)..patterns.len() {
if exclude_ids.contains(&patterns[j].id)
|| checked.contains(&patterns[j].id)
|| merged.contains(&patterns[j].id)
{
continue;
}
if patterns[i].category != patterns[j].category {
continue;
}
let sim = patterns[i].similarity(&patterns[j].embedding);
if sim > self.config.merge_similarity_threshold {
merged.push(patterns[j].id);
checked.insert(patterns[j].id);
}
}
checked.insert(patterns[i].id);
}
Ok(merged)
}
pub fn prune_low_quality(&self, patterns: &[Pattern]) -> Vec<u64> {
patterns
.iter()
.filter(|p| {
p.avg_quality < self.config.min_quality_threshold
&& p.usage_count < self.config.min_usage_for_importance
})
.map(|p| p.id)
.collect()
}
pub fn merge_patterns(&self, patterns: &[Pattern]) -> Option<Pattern> {
if patterns.is_empty() {
return None;
}
if patterns.len() == 1 {
return Some(patterns[0].clone());
}
let mut merged = patterns[0].clone();
for pattern in &patterns[1..] {
merged.merge(pattern);
}
Some(merged)
}
pub fn update_fisher(&mut self, pattern_id: u64, gradient: &[f32]) {
let fisher = self
.fisher_info
.entry(pattern_id)
.or_insert_with(|| FisherInformation::new(gradient.len()));
fisher.update(gradient, self.config.fisher_decay);
}
pub fn apply_constraint(&self, pattern_id: u64, gradient: &[f32]) -> Vec<f32> {
if let Some(fisher) = self.fisher_info.get(&pattern_id) {
gradient
.iter()
.enumerate()
.map(|(i, &g)| {
let importance = fisher.importance(i);
if importance > 1e-8 {
let penalty = self.lambda * importance;
g / (1.0 + penalty)
} else {
g
}
})
.collect()
} else {
gradient.to_vec()
}
}
pub fn regularization_loss(
&self,
pattern_id: u64,
current_weights: &[f32],
optimal_weights: &[f32],
) -> f32 {
if current_weights.len() != optimal_weights.len() {
return 0.0;
}
if let Some(fisher) = self.fisher_info.get(&pattern_id) {
let mut loss = 0.0f32;
for i in 0..current_weights.len().min(fisher.diagonal.len()) {
let diff = current_weights[i] - optimal_weights[i];
loss += fisher.diagonal[i] * diff * diff;
}
self.lambda * loss / 2.0
} else {
0.0
}
}
pub fn adapt_lambda(&mut self, patterns: &[Pattern]) {
if !self.config.auto_adapt_lambda {
return;
}
let important_count = patterns
.iter()
.filter(|p| p.usage_count >= self.config.min_usage_for_importance)
.count();
let scale = 1.0 + 0.1 * important_count as f32;
self.lambda =
(self.config.lambda * scale).clamp(self.config.min_lambda, self.config.max_lambda);
}
pub fn consolidate_fisher(&mut self) {
if self.fisher_info.len() < 2 {
return;
}
let dim = self
.fisher_info
.values()
.next()
.map(|f| f.diagonal.len())
.unwrap_or(0);
if dim == 0 {
return;
}
let mut consolidated = FisherInformation::new(dim);
let count = self.fisher_info.len() as f32;
for fisher in self.fisher_info.values() {
for (i, &val) in fisher.diagonal.iter().enumerate() {
if i < consolidated.diagonal.len() {
consolidated.diagonal[i] += val / count;
}
}
for (i, &val) in fisher.ema_grad_squared.iter().enumerate() {
if i < consolidated.ema_grad_squared.len() {
consolidated.ema_grad_squared[i] += val / count;
}
}
consolidated.sample_count += fisher.sample_count;
}
self.fisher_info.clear();
self.fisher_info.insert(0, consolidated);
}
pub fn lambda(&self) -> f32 {
self.lambda
}
pub fn set_lambda(&mut self, lambda: f32) {
self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
}
pub fn stats(&self) -> ConsolidatorStats {
ConsolidatorStats {
fisher_entries: self.fisher_info.len(),
current_lambda: self.lambda,
consolidation_count: self.consolidation_count,
total_consolidated: self.total_consolidated,
}
}
}
/// Snapshot of consolidator bookkeeping returned by `PatternConsolidator::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConsolidatorStats {
/// Number of per-pattern Fisher estimates currently stored.
pub fisher_entries: usize,
/// The working EWC lambda at snapshot time.
pub current_lambda: f32,
/// Lifetime count of consolidation passes (not updated in this file).
pub consolidation_count: u64,
/// Lifetime count of consolidated patterns (not updated in this file).
pub total_consolidated: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning_bank::pattern_store::PatternCategory;

    /// Builds a minimal pattern with the given id, embedding, quality, and
    /// usage count.
    fn make_pattern(id: u64, embedding: Vec<f32>, quality: f32, usage: u32) -> Pattern {
        let mut p = Pattern::new(embedding, PatternCategory::General, quality);
        p.id = id;
        p.usage_count = usage;
        p.avg_quality = quality;
        p
    }

    #[test]
    fn test_consolidation_config_default() {
        let config = ConsolidationConfig::default();
        assert_eq!(config.lambda, 2000.0);
        assert!(config.auto_adapt_lambda);
    }

    #[test]
    fn test_fisher_information() {
        let mut fisher = FisherInformation::new(4);
        assert_eq!(fisher.diagonal.len(), 4);
        let gradient = [0.5, 0.3, 0.2, 0.1];
        fisher.update(&gradient, 0.9);
        assert!(fisher.sample_count > 0);
        assert!(fisher.total_importance() > 0.0);
    }

    #[test]
    fn test_importance_score() {
        let pattern = make_pattern(1, vec![0.1; 4], 0.8, 10);
        let score = ImportanceScore::compute(&pattern, None, 86400);
        assert!(score.score > 0.0);
        assert!(score.score <= 1.0);
    }

    #[test]
    fn test_consolidator_creation() {
        let config = ConsolidationConfig::default();
        let consolidator = PatternConsolidator::new(config);
        assert_eq!(consolidator.lambda(), 2000.0);
    }

    #[test]
    fn test_prune_low_quality() {
        let config = ConsolidationConfig {
            min_quality_threshold: 0.5,
            min_usage_for_importance: 5,
            ..Default::default()
        };
        let consolidator = PatternConsolidator::new(config);
        let patterns = vec![
            make_pattern(1, vec![0.1; 4], 0.8, 10),
            make_pattern(2, vec![0.2; 4], 0.3, 2),
            make_pattern(3, vec![0.3; 4], 0.4, 8),
        ];
        let pruned = consolidator.prune_low_quality(&patterns);
        // Only pattern 2 fails both the quality and usage thresholds.
        assert_eq!(pruned.len(), 1);
        assert!(pruned.contains(&2));
    }

    #[test]
    fn test_consolidate_patterns() {
        let config = ConsolidationConfig::default();
        let consolidator = PatternConsolidator::new(config);
        let patterns = vec![
            make_pattern(1, vec![0.1; 4], 0.8, 10),
            make_pattern(2, vec![0.2; 4], 0.1, 1),
            make_pattern(3, vec![0.3; 4], 0.7, 5),
        ];
        let result = consolidator.consolidate_patterns(&patterns).unwrap();
        assert_eq!(result.patterns_before, 3);
        assert!(result.patterns_after <= 3);
    }

    #[test]
    fn test_merge_similar_patterns() {
        let config = ConsolidationConfig {
            merge_similarity_threshold: 0.9,
            ..Default::default()
        };
        let consolidator = PatternConsolidator::new(config);
        let patterns = vec![
            make_pattern(1, vec![1.0, 0.0, 0.0, 0.0], 0.8, 5),
            make_pattern(2, vec![0.99, 0.01, 0.0, 0.0], 0.7, 3),
            make_pattern(3, vec![0.0, 1.0, 0.0, 0.0], 0.9, 10),
        ];
        let merged = consolidator
            .find_mergeable_patterns(&patterns, &[])
            .unwrap();
        // Pattern 2 is absorbed into the near-identical pattern 1; the
        // orthogonal pattern 3 survives.
        assert!(merged.contains(&2));
        assert!(!merged.contains(&1));
        assert!(!merged.contains(&3));
    }

    #[test]
    fn test_ewc_constraint() {
        let config = ConsolidationConfig::default();
        let mut consolidator = PatternConsolidator::new(config);
        consolidator.update_fisher(1, &[1.0, 1.0, 1.0, 1.0]);
        consolidator.update_fisher(1, &[1.0, 1.0, 1.0, 1.0]);
        let gradient = [1.0, 1.0, 1.0, 1.0];
        let constrained = consolidator.apply_constraint(1, &gradient);
        let orig_mag: f32 = gradient.iter().sum();
        let const_mag: f32 = constrained.iter().sum();
        // The EWC penalty can only shrink the gradient.
        assert!(const_mag <= orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = ConsolidationConfig::default();
        let mut consolidator = PatternConsolidator::new(config);
        consolidator.update_fisher(1, &[1.0, 1.0]);
        let optimal = vec![0.0, 0.0];
        let current = vec![1.0, 1.0];
        // FIX: `&current` had been corrupted to the mojibake `¤t`
        // (HTML entity `&curren;` -> ¤), which did not compile.
        let loss = consolidator.regularization_loss(1, &current, &optimal);
        assert!(loss > 0.0);
        let at_optimal = consolidator.regularization_loss(1, &optimal, &optimal);
        assert!(at_optimal < loss);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = ConsolidationConfig {
            lambda: 1000.0,
            min_usage_for_importance: 5,
            auto_adapt_lambda: true,
            ..Default::default()
        };
        let mut consolidator = PatternConsolidator::new(config);
        let initial_lambda = consolidator.lambda();
        let patterns = vec![
            make_pattern(1, vec![0.1; 4], 0.8, 10),
            make_pattern(2, vec![0.2; 4], 0.7, 8),
            make_pattern(3, vec![0.3; 4], 0.9, 15),
        ];
        consolidator.adapt_lambda(&patterns);
        // All three patterns exceed the usage threshold, so lambda grows.
        assert!(consolidator.lambda() >= initial_lambda);
    }

    #[test]
    fn test_consolidate_fisher() {
        let config = ConsolidationConfig::default();
        let mut consolidator = PatternConsolidator::new(config);
        consolidator.update_fisher(1, &[1.0, 0.0]);
        consolidator.update_fisher(2, &[0.0, 1.0]);
        consolidator.update_fisher(3, &[0.5, 0.5]);
        assert_eq!(consolidator.fisher_info.len(), 3);
        consolidator.consolidate_fisher();
        assert_eq!(consolidator.fisher_info.len(), 1);
    }
}