use crate::ewc::{EwcConfig, EwcPlusPlus};
use crate::loops::background::{BackgroundLoop, BackgroundLoopConfig, BackgroundResult};
use crate::loops::instant::InstantLoop;
use crate::lora::{BaseLoRA, MicroLoRA};
use crate::reasoning_bank::{PatternConfig, ReasoningBank};
use crate::types::{QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
/// Orchestrates the two SONA learning loops: the instant (per-inference) loop
/// and the background (batch consolidation) loop, sharing the reasoning bank,
/// EWC state, and base LoRA between them via `Arc<RwLock<...>>`.
pub struct LoopCoordinator {
// Retained for potential later use; values are copied into sub-components at construction.
_config: SonaConfig,
// Buffers trajectories and applies micro-LoRA updates per inference.
instant: InstantLoop,
// Consolidates buffered trajectories into patterns / base LoRA on demand.
background: BackgroundLoop,
// Shared pattern store; also handed to the background loop at construction.
reasoning_bank: Arc<RwLock<ReasoningBank>>,
// Shared EWC++ regularizer state.
ewc: Arc<RwLock<EwcPlusPlus>>,
// Shared slow-adapting LoRA updated by the background loop.
base_lora: Arc<RwLock<BaseLoRA>>,
// Runtime toggles; when false, the corresponding loop is skipped.
instant_enabled: bool,
background_enabled: bool,
}
impl LoopCoordinator {
pub fn new(hidden_dim: usize) -> Self {
Self::with_config(SonaConfig {
hidden_dim,
embedding_dim: hidden_dim,
..Default::default()
})
}
pub fn with_config(config: SonaConfig) -> Self {
let reasoning_bank = Arc::new(RwLock::new(ReasoningBank::new(PatternConfig {
embedding_dim: config.embedding_dim,
k_clusters: config.pattern_clusters,
..Default::default()
})));
let ewc = Arc::new(RwLock::new(EwcPlusPlus::new(EwcConfig {
param_count: config.hidden_dim * config.base_lora_rank * 2,
initial_lambda: config.ewc_lambda,
..Default::default()
})));
let base_lora = Arc::new(RwLock::new(BaseLoRA::new(
config.hidden_dim,
config.base_lora_rank,
12, )));
let instant = InstantLoop::from_sona_config(&config);
let background = BackgroundLoop::new(
BackgroundLoopConfig::from(&config),
reasoning_bank.clone(),
ewc.clone(),
base_lora.clone(),
);
Self {
_config: config,
instant,
background,
reasoning_bank,
ewc,
base_lora,
instant_enabled: true,
background_enabled: true,
}
}
pub fn on_inference(&self, trajectory: QueryTrajectory) {
if self.instant_enabled {
self.instant.on_trajectory(trajectory);
}
}
pub fn next_trajectory_id(&self) -> u64 {
self.instant.next_id()
}
pub fn maybe_run_background(&self) -> Option<BackgroundResult> {
if !self.background_enabled {
return None;
}
if self.background.should_run() {
let trajectories = self.instant.drain_trajectories();
if !trajectories.is_empty() {
return Some(self.background.run_cycle(trajectories, false));
}
}
None
}
pub fn force_background(&self) -> BackgroundResult {
let trajectories = self.instant.drain_trajectories();
self.background.run_cycle(trajectories, true)
}
pub fn flush_instant(&self) {
self.instant.flush();
}
pub fn micro_lora(&self) -> &Arc<RwLock<MicroLoRA>> {
self.instant.micro_lora()
}
pub fn get_micro_lora_weights(&self) -> (Vec<f32>, Vec<f32>) {
let guard = self.instant.micro_lora().read();
let (down, up) = guard.get_weights();
(down.clone(), up.clone())
}
pub fn restore_micro_lora_weights(&self, down: Vec<f32>, up: Vec<f32>) -> Result<(), String> {
self.instant.micro_lora().write().set_weights(down, up)
}
pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
&self.base_lora
}
pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
&self.reasoning_bank
}
pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
&self.ewc
}
pub fn set_instant_enabled(&mut self, enabled: bool) {
self.instant_enabled = enabled;
}
pub fn set_background_enabled(&mut self, enabled: bool) {
self.background_enabled = enabled;
}
pub fn stats(&self) -> CoordinatorStats {
let (buffer_len, dropped, success_rate) = self.instant.buffer_stats();
CoordinatorStats {
trajectories_recorded: buffer_len as u64 + dropped,
trajectories_buffered: buffer_len,
trajectories_dropped: dropped,
buffer_success_rate: success_rate,
patterns_stored: self.reasoning_bank.read().pattern_count(),
patterns_learned: self.reasoning_bank.read().pattern_count(),
ewc_tasks: self.ewc.read().task_count(),
instant_enabled: self.instant_enabled,
background_enabled: self.background_enabled,
}
}
pub fn serialize_state(&self) -> String {
let rb = self.reasoning_bank.read();
let patterns = rb.get_all_patterns();
let ewc = self.ewc.read();
serde_json::json!({
"version": 1,
"patterns": patterns,
"ewc_task_count": ewc.task_count(),
"instant_enabled": self.instant_enabled,
"background_enabled": self.background_enabled,
}).to_string()
}
pub fn load_state(&self, json: &str) -> Result<usize, String> {
let state: serde_json::Value = serde_json::from_str(json)
.map_err(|e| format!("Invalid state JSON: {}", e))?;
let mut loaded = 0;
if let Some(patterns) = state.get("patterns").and_then(|p| p.as_array()) {
let mut rb = self.reasoning_bank.write();
for p in patterns {
if let Ok(pattern) = serde_json::from_value::<crate::LearnedPattern>(p.clone()) {
rb.insert_pattern(pattern);
loaded += 1;
}
}
}
Ok(loaded)
}
}
/// Point-in-time counters exposed by [`LoopCoordinator::stats`].
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct CoordinatorStats {
// Total trajectories ever seen: currently buffered + dropped.
#[cfg_attr(feature = "serde-support", serde(alias = "trajectoriesRecorded"))]
pub trajectories_recorded: u64,
// Trajectories currently waiting in the instant-loop buffer.
pub trajectories_buffered: usize,
// Trajectories discarded by the buffer (e.g. overflow).
pub trajectories_dropped: u64,
// Buffer acceptance rate as reported by the instant loop.
pub buffer_success_rate: f64,
// Patterns currently held in the reasoning bank.
pub patterns_stored: usize,
// NOTE(review): stats() currently fills this with the same value as
// patterns_stored — confirm whether a distinct "learned" counter was intended.
#[cfg_attr(feature = "serde-support", serde(alias = "patternsLearned"))]
pub patterns_learned: usize,
// Number of tasks tracked by EWC++.
pub ewc_tasks: usize,
// Runtime toggles mirrored from the coordinator.
pub instant_enabled: bool,
pub background_enabled: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrajectoryStep;

    /// Builds a minimal finalized trajectory with one step and 256-dim embeddings.
    fn make_trajectory(id: u64) -> QueryTrajectory {
        let mut t = QueryTrajectory::new(id, vec![0.1; 256]);
        t.add_step(TrajectoryStep::new(vec![0.5; 256], vec![], 0.8, 0));
        t.finalize(0.8, 1000);
        t
    }

    #[test]
    fn test_coordinator_creation() {
        let coord = LoopCoordinator::new(256);
        let stats = coord.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }

    #[test]
    fn test_inference_processing() {
        let coord = LoopCoordinator::new(256);
        // `_` instead of an unused `i`: ids come from the coordinator's counter.
        for _ in 0..10 {
            let t = make_trajectory(coord.next_trajectory_id());
            coord.on_inference(t);
        }
        let stats = coord.stats();
        assert_eq!(stats.trajectories_buffered, 10);
    }

    #[test]
    fn test_force_background() {
        let coord = LoopCoordinator::new(256);
        for _ in 0..150 {
            let t = make_trajectory(coord.next_trajectory_id());
            coord.on_inference(t);
        }
        let result = coord.force_background();
        assert_eq!(result.trajectories_processed, 150);
        assert!(result.patterns_extracted > 0);
    }

    #[test]
    fn test_get_micro_lora_weights() {
        let coord = LoopCoordinator::new(256);
        let (down, up) = coord.get_micro_lora_weights();
        // Micro-LoRA rank is 2 for a 256-dim coordinator.
        assert_eq!(down.len(), 256 * 2);
        assert_eq!(up.len(), 2 * 256);
    }

    #[test]
    fn test_restore_micro_lora_weights() {
        let coord = LoopCoordinator::new(256);
        let custom_down = vec![0.5f32; 256 * 2];
        let custom_up = vec![0.3f32; 2 * 256];
        let result = coord.restore_micro_lora_weights(custom_down.clone(), custom_up.clone());
        assert!(result.is_ok());
        let (got_down, got_up) = coord.get_micro_lora_weights();
        assert_eq!(got_down, custom_down);
        assert_eq!(got_up, custom_up);
    }

    #[test]
    fn test_restore_micro_lora_weights_wrong_dim() {
        let coord = LoopCoordinator::new(256);
        let wrong_down = vec![0.5f32; 256 * 3];
        let custom_up = vec![0.3f32; 2 * 256];
        let result = coord.restore_micro_lora_weights(wrong_down, custom_up);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("down_proj dimension mismatch"));
    }
}