use crate::ewc::EwcPlusPlus;
use crate::lora::BaseLoRA;
use crate::reasoning_bank::ReasoningBank;
use crate::time_compat::Instant;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;
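
/// Tuning parameters for the background consolidation loop.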
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum batch size before an unforced cycle will run.
    pub min_trajectories: usize,
    /// Learning rate used when folding pattern gradients into the base LoRA.
    pub base_lora_lr: f32,
    /// EWC regularization strength (lambda).
    pub ewc_lambda: f32,
    /// Minimum wall-clock time between extraction cycles.
    pub extraction_interval: Duration,
}

impl Default for BackgroundLoopConfig {
    fn default() -> Self {
        Self {
            min_trajectories: 100,
            base_lora_lr: 0.0001,
            ewc_lambda: 1000.0,
            extraction_interval: Duration::from_secs(3600),
        }
    }
}

impl From<&SonaConfig> for BackgroundLoopConfig {
    fn from(config: &SonaConfig) -> Self {
        Self {
            // Not sourced from SonaConfig; keep the built-in default.
            min_trajectories: 100,
            base_lora_lr: config.base_lora_lr,
            ewc_lambda: config.ewc_lambda,
            extraction_interval: Duration::from_millis(config.background_interval_ms),
        }
    }
}
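
/// Summary of a single background cycle: what was processed and how it ended.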
#[derive(Debug)]
pub struct BackgroundResult {
    pub trajectories_processed: usize,
    pub patterns_extracted: usize,
    pub ewc_updated: bool,
    pub elapsed: Duration,
    pub status: String,
}

impl BackgroundResult {
    /// Result for a cycle that bailed out before doing any work.
    fn skipped(reason: &str) -> Self {
        Self {
            trajectories_processed: 0,
            patterns_extracted: 0,
            ewc_updated: false,
            elapsed: Duration::ZERO,
            status: format!("skipped: {}", reason),
        }
    }
}
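
/// Periodic consolidation loop: folds finished trajectories into the
/// reasoning bank, extracts recurring patterns, and applies EWC-constrained
/// updates to the shared base LoRA weights.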
pub struct BackgroundLoop {
    config: BackgroundLoopConfig,
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    ewc: Arc<RwLock<EwcPlusPlus>>,
    base_lora: Arc<RwLock<BaseLoRA>>,
    last_extraction: RwLock<Instant>,
}

impl BackgroundLoop {
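    /// Wire the loop to the shared stores it consolidates into.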
    pub fn new(
        config: BackgroundLoopConfig,
        reasoning_bank: Arc<RwLock<ReasoningBank>>,
        ewc: Arc<RwLock<EwcPlusPlus>>,
        base_lora: Arc<RwLock<BaseLoRA>>,
    ) -> Self {
        Self {
            config,
            reasoning_bank,
            ewc,
            base_lora,
            last_extraction: RwLock::new(Instant::now()),
        }
    }
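
    /// True once `extraction_interval` has elapsed since the last completed cycle.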
    pub fn should_run(&self) -> bool {
        self.last_extraction.read().elapsed() >= self.config.extraction_interval
    }
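
    /// Run one consolidation cycle over `trajectories`: bank the trajectories,
    /// extract patterns, derive an EWC-constrained pseudo-gradient, and fold it
    /// into the base LoRA. `force` bypasses the `min_trajectories` threshold.
    ///
    /// A minimal usage sketch; the `::new()` constructors below are
    /// illustrative placeholders, not necessarily these types' real signatures:
    ///
    /// ```ignore
    /// // `pending` is a Vec<QueryTrajectory> collected by the caller.
    /// let bg = BackgroundLoop::new(
    ///     BackgroundLoopConfig::default(),
    ///     Arc::new(RwLock::new(ReasoningBank::new())),
    ///     Arc::new(RwLock::new(EwcPlusPlus::new())),
    ///     Arc::new(RwLock::new(BaseLoRA::new())),
    /// );
    /// if bg.should_run() {
    ///     let result = bg.run_cycle(pending, false);
    ///     println!("{} ({} patterns)", result.status, result.patterns_extracted);
    /// }
    /// ```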
    pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>, force: bool) -> BackgroundResult {
        // An empty batch is a no-op even when forced, so check it first
        // to report the more precise skip reason.
        if trajectories.is_empty() {
            return BackgroundResult::skipped("no trajectories to process");
        }
        if !force && trajectories.len() < self.config.min_trajectories {
            return BackgroundResult::skipped(&format!(
                "insufficient trajectories ({} < {} minimum, use forceLearn to bypass)",
                trajectories.len(),
                self.config.min_trajectories
            ));
        }
        let start = Instant::now();
        // Step 1: fold the new trajectories into the reasoning bank.
        {
            let mut bank = self.reasoning_bank.write();
            for trajectory in &trajectories {
                bank.add_trajectory(trajectory);
            }
        }
        // Step 2: re-cluster the bank into reusable patterns. Locked
        // separately so readers can interleave between the two phases.
        let patterns = {
            let mut bank = self.reasoning_bank.write();
            bank.extract_patterns()
        };
        // Step 3: collapse the patterns into a single pseudo-gradient and
        // project it through the EWC constraints so previously consolidated
        // tasks are protected from interference.
        let gradients = self.compute_pattern_gradients(&patterns);
        let constrained_gradients = {
            let ewc = self.ewc.read();
            ewc.apply_constraints(&gradients)
        };
        // Boundary detection runs on the raw gradients so a genuine
        // distribution shift is not masked by the constraints.
        let task_boundary = {
            let ewc = self.ewc.read();
            ewc.detect_task_boundary(&gradients)
        };
        if task_boundary {
            let mut ewc = self.ewc.write();
            ewc.start_new_task();
        }
        // Step 4: refresh the Fisher information and nudge the base LoRA.
        {
            let mut ewc = self.ewc.write();
            ewc.update_fisher(&constrained_gradients);
        }
        self.update_base_lora(&constrained_gradients);
        *self.last_extraction.write() = Instant::now();
        BackgroundResult {
            trajectories_processed: trajectories.len(),
            patterns_extracted: patterns.len(),
            ewc_updated: true,
            elapsed: start.elapsed(),
            status: "completed".to_string(),
        }
    }
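
    /// Collapse extracted patterns into one pseudo-gradient: the
    /// quality-and-size weighted average of the pattern centroids.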
    fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
        if patterns.is_empty() {
            return Vec::new();
        }
        // Size the gradient off the first centroid; longer centroids are
        // truncated and shorter ones simply contribute fewer components.
        let dim = patterns[0].centroid.len();
        let mut gradient = vec![0.0f32; dim];
        let mut total_weight = 0.0f32;
        for pattern in patterns {
            // Weight each pattern by average quality and cluster size, so
            // large, high-quality clusters dominate the update direction.
            let weight = pattern.avg_quality * pattern.cluster_size as f32;
            for (i, &v) in pattern.centroid.iter().enumerate() {
                if i < dim {
                    gradient[i] += v * weight;
                }
            }
            total_weight += weight;
        }
        // Normalize back to a weighted mean.
        if total_weight > 0.0 {
            for g in &mut gradient {
                *g /= total_weight;
            }
        }
        gradient
    }
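
    /// Slice the gradient evenly across LoRA layers and apply a scaled
    /// step to each layer's up-projection weights.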
    fn update_base_lora(&self, gradients: &[f32]) {
        let mut lora = self.base_lora.write();
        let num_layers = lora.num_layers();
        if num_layers == 0 || gradients.is_empty() {
            return;
        }
        // Integer division: trailing remainder components are dropped, and
        // fewer gradients than layers means no update at all.
        let per_layer = gradients.len() / num_layers;
        for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
            let start = layer_idx * per_layer;
            let end = (start + per_layer).min(gradients.len());
            for (i, &grad) in gradients[start..end].iter().enumerate() {
                if i < layer.up_proj.len() {
                    layer.up_proj[i] += grad * self.config.base_lora_lr;
                }
            }
        }
    }
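
    // Shared-handle accessors for callers that need direct access to the stores.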
    pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
        &self.reasoning_bank
    }

    pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
        &self.ewc
    }

    pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
        &self.base_lora
    }
}