use super::adjustment::{AdjustmentReason, ThresholdAdjustment};
use super::config::{ThresholdConfig, TunerConfig};
use super::error::{SonaTuningError, SonaTuningResult};
use ruvector_sona::{
EwcConfig, EwcPlusPlus, PatternConfig, ReasoningBank, SonaConfig, SonaEngine,
TrajectoryBuilder,
};
use std::collections::VecDeque;
/// Lifecycle state of a `SonaThresholdTuner`, as returned by its `state()` accessor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TunerState {
/// Not assigned by any code visible in this file.
/// NOTE(review): presumably reserved for a tuner that has not been set up yet — confirm.
Uninitialized,
/// Idle state: set by `new`, by `reset`, and at the end of `learn_outcome`
/// and `consolidate_knowledge`.
Ready,
/// Set by `begin_regime`; replaced by `Ready` when `learn_outcome` runs.
TrackingRegime,
/// Set at the start of `consolidate_knowledge` and replaced by `Ready` before it returns.
Consolidating,
}
/// Tracks the active regime's identifier, start time, and a bounded window of
/// its energy samples.
#[derive(Debug)]
pub struct RegimeTracker {
// Identifier of the active regime; `None` until `start_regime` is called and
// again after `end_regime` takes it.
current_regime: Option<String>,
// Energy samples, oldest at the front. Cleared by `start_regime`; not cleared
// by `end_regime`.
energy_history: VecDeque<f32>,
// Length at which `record_energy` evicts the oldest sample before pushing.
max_history: usize,
// Value of `current_time_ms()` captured by `start_regime`; `0` until then.
regime_start_ms: u64,
}
impl RegimeTracker {
    /// Creates an idle tracker whose energy window is evicted once it holds
    /// `max_history` samples.
    pub fn new(max_history: usize) -> Self {
        Self {
            current_regime: None,
            energy_history: VecDeque::with_capacity(max_history),
            max_history,
            regime_start_ms: 0,
        }
    }

    /// Starts a new regime named `regime_id`.
    ///
    /// Any previously recorded samples are discarded, `initial_energy` becomes
    /// the first sample, and the start time is set to the current wall clock.
    pub fn start_regime(&mut self, regime_id: impl Into<String>, initial_energy: f32) {
        self.current_regime = Some(regime_id.into());
        self.energy_history.clear();
        self.energy_history.push_back(initial_energy);
        self.regime_start_ms = current_time_ms();
    }

    /// Appends `energy`, evicting the oldest sample first when the window
    /// already holds `max_history` entries.
    pub fn record_energy(&mut self, energy: f32) {
        if self.energy_history.len() >= self.max_history {
            self.energy_history.pop_front();
        }
        self.energy_history.push_back(energy);
    }

    /// Returns the identifier of the active regime, if any.
    pub fn current_regime(&self) -> Option<&str> {
        self.current_regime.as_deref()
    }

    /// Returns the retained energy samples, oldest first.
    pub fn energy_history(&self) -> &VecDeque<f32> {
        &self.energy_history
    }

    /// Returns the arithmetic mean of the retained samples, or `0.0` when
    /// there are none.
    pub fn average_energy(&self) -> f32 {
        let count = self.energy_history.len();
        if count == 0 {
            return 0.0;
        }
        self.energy_history.iter().sum::<f32>() / count as f32
    }

    /// Returns the mean of the newer half of the window minus the mean of the
    /// older half.
    ///
    /// The window is split at `len / 2`, so with an odd sample count the newer
    /// half holds the extra element. Returns `0.0` with fewer than two samples.
    pub fn energy_trend(&self) -> f32 {
        let count = self.energy_history.len();
        if count < 2 {
            return 0.0;
        }
        let split = count / 2;
        let older_avg = self.energy_history.iter().take(split).sum::<f32>() / split as f32;
        let newer_avg =
            self.energy_history.iter().skip(split).sum::<f32>() / (count - split) as f32;
        newer_avg - older_avg
    }

    /// Returns the seconds elapsed since the last `start_regime` call.
    ///
    /// The start time comes from the wall clock, which may step backwards; in
    /// that case the elapsed time saturates at `0.0` instead of underflowing.
    /// Before any regime has been started the start time is `0`, so the result
    /// is measured from the Unix epoch.
    pub fn regime_duration_secs(&self) -> f32 {
        let elapsed_ms = current_time_ms().saturating_sub(self.regime_start_ms);
        elapsed_ms as f32 / 1000.0
    }

    /// Ends the active regime and returns a summary of it.
    ///
    /// Returns `None` when no regime is active. The recorded samples are left
    /// in place, so the statistics accessors keep reporting them afterwards.
    pub fn end_regime(&mut self) -> Option<RegimeSummary> {
        let regime_id = self.current_regime.take()?;
        Some(RegimeSummary {
            regime_id,
            duration_secs: self.regime_duration_secs(),
            average_energy: self.average_energy(),
            energy_trend: self.energy_trend(),
            sample_count: self.energy_history.len(),
        })
    }
}
/// Snapshot of a regime, produced by `RegimeTracker::end_regime`.
#[derive(Debug, Clone)]
pub struct RegimeSummary {
/// Identifier the regime was started with.
pub regime_id: String,
/// Result of `regime_duration_secs()` at the moment the regime ended.
pub duration_secs: f32,
/// Mean of the energy samples retained in the tracker's window.
pub average_energy: f32,
/// Newer-half mean minus older-half mean of the retained samples.
pub energy_trend: f32,
/// Number of samples retained in the window when the regime ended.
pub sample_count: usize,
}
/// Threshold tuner that combines `ruvector_sona` components with a
/// `RegimeTracker` and the currently active `ThresholdConfig`.
pub struct SonaThresholdTuner {
// Engine used to begin/end trajectories and to apply micro-LoRA.
engine: SonaEngine,
// EWC++ instance; `consolidate_knowledge` calls `consolidate_all_tasks` on it.
ewc: EwcPlusPlus,
// Pattern store queried by `find_similar_regime` and counted in `stats`.
reasoning_bank: ReasoningBank,
// Configuration the tuner was built from; also the source of `reset` values.
config: TunerConfig,
// Thresholds currently in effect; replaced by `apply_adjustment` and `reset`.
current_thresholds: ThresholdConfig,
// Energy window and identity of the regime being tracked.
regime_tracker: RegimeTracker,
// Current lifecycle state.
state: TunerState,
// Count of `learn_outcome` calls since the last consolidation.
trajectories_since_consolidation: usize,
}
impl SonaThresholdTuner {
    /// Number of energy samples retained by the tuner's regime tracker.
    const REGIME_HISTORY_CAPACITY: usize = 1000;

    /// Builds a tuner from `config`.
    ///
    /// The engine takes `hidden_dim` and `embedding_dim` from `config`, EWC++
    /// takes `ewc_lambda` as its initial lambda, and the reasoning bank uses
    /// the default pattern configuration. The tuner starts in
    /// [`TunerState::Ready`] with `config.initial_thresholds`.
    pub fn new(config: TunerConfig) -> Self {
        let engine = SonaEngine::with_config(SonaConfig {
            hidden_dim: config.hidden_dim,
            embedding_dim: config.embedding_dim,
            ..Default::default()
        });
        let ewc = EwcPlusPlus::new(EwcConfig {
            initial_lambda: config.ewc_lambda,
            ..Default::default()
        });
        let reasoning_bank = ReasoningBank::new(PatternConfig::default());
        Self {
            engine,
            ewc,
            reasoning_bank,
            current_thresholds: config.initial_thresholds,
            regime_tracker: RegimeTracker::new(Self::REGIME_HISTORY_CAPACITY),
            state: TunerState::Ready,
            trajectories_since_consolidation: 0,
            config,
        }
    }

    /// Builds a tuner from `TunerConfig::default()`.
    pub fn default_tuner() -> Self {
        Self::new(TunerConfig::default())
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> TunerState {
        self.state
    }

    /// Returns the thresholds currently in effect.
    pub fn current_thresholds(&self) -> &ThresholdConfig {
        &self.current_thresholds
    }

    /// Starts tracking a new regime described by `energy_trace`.
    ///
    /// The trace is zero-padded or truncated to `embedding_dim` and handed to
    /// the engine to begin a trajectory. The regime tracker is restarted with
    /// the last trace value as its first sample, and the state becomes
    /// [`TunerState::TrackingRegime`].
    ///
    /// # Errors
    ///
    /// Returns a trajectory error when `energy_trace` is empty.
    pub fn begin_regime(&mut self, energy_trace: &[f32]) -> SonaTuningResult<TrajectoryBuilder> {
        let latest_energy = match energy_trace.last() {
            Some(&energy) => energy,
            None => return Err(SonaTuningError::trajectory("empty energy trace")),
        };
        let embedding = self.energy_embedding(energy_trace);
        let builder = self.engine.begin_trajectory(embedding);
        // The id is handed over by value; `start_regime` accepts `impl Into<String>`.
        let regime_id = format!("regime_{}", current_time_ms());
        self.regime_tracker.start_regime(regime_id, latest_energy);
        self.state = TunerState::TrackingRegime;
        Ok(builder)
    }

    /// Records one energy sample for the regime being tracked.
    pub fn record_energy(&mut self, energy: f32) {
        self.regime_tracker.record_energy(energy);
    }

    /// Finishes the trajectory started by [`Self::begin_regime`] and learns
    /// from its outcome.
    ///
    /// The engine is told the trajectory ended with `success_score`, the
    /// regime is closed, the consolidation counter is incremented and the
    /// state returns to [`TunerState::Ready`]. Scores above `0.8` are passed
    /// to the success-pattern hook, and consolidation runs once the counter
    /// reaches `config.auto_consolidate_after`.
    ///
    /// Returns `Some(adjustment)` only when the score is above `0.9` and a
    /// regime was active. Both threshold arguments given to the adjustment
    /// are the current thresholds.
    /// NOTE(review): that presumably means the adjustment proposes no change — confirm.
    ///
    /// # Errors
    ///
    /// Propagates errors from the success-pattern hook and from
    /// [`Self::consolidate_knowledge`].
    pub fn learn_outcome(
        &mut self,
        builder: TrajectoryBuilder,
        success_score: f32,
    ) -> SonaTuningResult<Option<ThresholdAdjustment>> {
        self.engine.end_trajectory(builder, success_score);
        let summary = self.regime_tracker.end_regime();
        self.trajectories_since_consolidation += 1;
        self.state = TunerState::Ready;

        if success_score > 0.8 {
            self.store_success_pattern(success_score)?;
        }
        if self.trajectories_since_consolidation >= self.config.auto_consolidate_after {
            self.consolidate_knowledge()?;
        }

        let adjustment = match summary {
            Some(summary) if success_score > 0.9 => Some(ThresholdAdjustment::new(
                &self.current_thresholds,
                self.current_thresholds,
                AdjustmentReason::BackgroundLearning {
                    samples: summary.sample_count,
                },
                success_score,
            )),
            _ => None,
        };
        Ok(adjustment)
    }

    /// Hook for persisting a successful pattern; currently does nothing and
    /// always returns `Ok(())`.
    fn store_success_pattern(&mut self, _score: f32) -> SonaTuningResult<()> {
        Ok(())
    }

    /// Zero-pads or truncates `energies` to a vector of `embedding_dim` values.
    fn energy_embedding(&self, energies: &[f32]) -> Vec<f32> {
        let dim = self.config.embedding_dim;
        let mut embedding = vec![0.0_f32; dim];
        let used = energies.len().min(dim);
        embedding[..used].copy_from_slice(&energies[..used]);
        embedding
    }

    /// Encodes `config` into the leading slots of a zeroed `embedding_dim`
    /// vector: reflex, retrieval, heavy, then the persistence window in minutes.
    ///
    /// When `embedding_dim` is smaller than four, only the slots that exist
    /// are written rather than panicking on an out-of-range index.
    fn threshold_to_embedding(&self, config: &ThresholdConfig) -> Vec<f32> {
        let leading = [
            config.reflex,
            config.retrieval,
            config.heavy,
            config.persistence_window_secs as f32 / 60.0,
        ];
        let mut embedding = vec![0.0; self.config.embedding_dim];
        for (slot, value) in embedding.iter_mut().zip(leading.iter()) {
            *slot = *value;
        }
        embedding
    }

    /// Decodes the first four values of `embedding` into a `ThresholdConfig`.
    ///
    /// The three thresholds are clamped to `[0.0, 1.0]`. The fourth value is
    /// read as minutes, converted to seconds, rounded to the nearest whole
    /// second and floored at one second. Rounding (rather than truncating)
    /// keeps float error in the minutes/seconds conversion from dropping a
    /// second on a round trip.
    ///
    /// Returns `None` when fewer than four values are supplied or the decoded
    /// configuration fails `is_valid()`.
    fn embedding_to_threshold(&self, embedding: &[f32]) -> Option<ThresholdConfig> {
        let (reflex, retrieval, heavy, window_minutes) = match embedding {
            &[reflex, retrieval, heavy, window_minutes, ..] => {
                (reflex, retrieval, heavy, window_minutes)
            }
            _ => return None,
        };
        let config = ThresholdConfig {
            reflex: reflex.clamp(0.0, 1.0),
            retrieval: retrieval.clamp(0.0, 1.0),
            heavy: heavy.clamp(0.0, 1.0),
            persistence_window_secs: (window_minutes * 60.0).round().max(1.0) as u64,
        };
        if config.is_valid() {
            Some(config)
        } else {
            None
        }
    }

    /// Looks up the stored pattern closest to `current_energy` and decodes its
    /// centroid into thresholds.
    ///
    /// The query is `current_energy` zero-padded or truncated to
    /// `embedding_dim`. Returns `None` when the reasoning bank yields no
    /// pattern or the centroid does not decode to a valid configuration.
    pub fn find_similar_regime(&self, current_energy: &[f32]) -> Option<ThresholdConfig> {
        let query = self.energy_embedding(current_energy);
        let similar = self.reasoning_bank.find_similar(&query, 1);
        match similar.first() {
            Some(pattern) => self.embedding_to_threshold(&pattern.centroid),
            None => None,
        }
    }

    /// Produces an immediate adjustment for an energy spike.
    ///
    /// The engine's micro-LoRA is applied to an input filled with
    /// `energy_spike`, but the output buffer is not read here.
    /// NOTE(review): confirm whether the call is needed for an internal effect
    /// of `apply_micro_lora` or the output was meant to feed the adjustment.
    ///
    /// The returned adjustment comes from
    /// `ThresholdAdjustment::for_energy_spike` and is not applied
    /// automatically; see [`Self::apply_adjustment`].
    pub fn instant_adapt(&mut self, energy_spike: f32) -> ThresholdAdjustment {
        let dim = self.config.embedding_dim;
        let input = vec![energy_spike; dim];
        let mut output = vec![0.0; dim];
        self.engine.apply_micro_lora(&input, &mut output);
        ThresholdAdjustment::for_energy_spike(&self.current_thresholds, energy_spike)
    }

    /// Adopts `adjustment.new_thresholds` when they pass `is_valid()`;
    /// otherwise leaves the current thresholds untouched.
    pub fn apply_adjustment(&mut self, adjustment: &ThresholdAdjustment) {
        if adjustment.new_thresholds.is_valid() {
            self.current_thresholds = adjustment.new_thresholds;
        }
    }

    /// Consolidates all EWC++ tasks and resets the trajectory counter.
    ///
    /// The state is [`TunerState::Consolidating`] for the duration of the call
    /// and [`TunerState::Ready`] afterwards. Currently always returns `Ok(())`.
    pub fn consolidate_knowledge(&mut self) -> SonaTuningResult<()> {
        self.state = TunerState::Consolidating;
        self.ewc.consolidate_all_tasks();
        self.trajectories_since_consolidation = 0;
        self.state = TunerState::Ready;
        Ok(())
    }

    /// Returns a snapshot of the tuner's state, thresholds, stored pattern
    /// count, consolidation counter and regime energy statistics.
    pub fn stats(&self) -> TunerStats {
        TunerStats {
            state: self.state,
            current_thresholds: self.current_thresholds,
            patterns_stored: self.reasoning_bank.pattern_count(),
            trajectories_since_consolidation: self.trajectories_since_consolidation,
            regime_average_energy: self.regime_tracker.average_energy(),
            regime_energy_trend: self.regime_tracker.energy_trend(),
        }
    }

    /// Restores the initial thresholds, replaces the regime tracker with a
    /// fresh one, zeroes the consolidation counter and returns to
    /// [`TunerState::Ready`].
    ///
    /// The engine, EWC++ instance and reasoning bank are not touched.
    pub fn reset(&mut self) {
        self.current_thresholds = self.config.initial_thresholds;
        self.regime_tracker = RegimeTracker::new(Self::REGIME_HISTORY_CAPACITY);
        self.trajectories_since_consolidation = 0;
        self.state = TunerState::Ready;
    }
}
/// Manual `Debug` output showing only the tuner's state, current thresholds
/// and stored pattern count; the remaining fields are omitted.
impl std::fmt::Debug for SonaThresholdTuner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SonaThresholdTuner");
        out.field("state", &self.state);
        out.field("current_thresholds", &self.current_thresholds);
        out.field("patterns_stored", &self.reasoning_bank.pattern_count());
        out.finish()
    }
}
/// Point-in-time snapshot returned by `SonaThresholdTuner::stats`.
#[derive(Debug, Clone, Copy)]
pub struct TunerStats {
/// Lifecycle state at the time of the snapshot.
pub state: TunerState,
/// Thresholds in effect at the time of the snapshot.
pub current_thresholds: ThresholdConfig,
/// Value of the reasoning bank's `pattern_count()`.
pub patterns_stored: usize,
/// Count of `learn_outcome` calls since the last consolidation.
pub trajectories_since_consolidation: usize,
/// Mean of the energy samples held by the regime tracker.
pub regime_average_energy: f32,
/// Newer-half mean minus older-half mean of the tracker's energy samples.
pub regime_energy_trend: f32,
}
/// Returns the wall-clock time in milliseconds since the Unix epoch, or `0`
/// when the system clock reports a time before the epoch.
fn current_time_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured tuner starts out in the `Ready` state.
    #[test]
    fn test_tuner_creation() {
        let fresh = SonaThresholdTuner::default_tuner();
        assert_eq!(fresh.state(), TunerState::Ready);
    }

    /// Rising samples yield the regime id, a mean above the first sample and
    /// a positive trend.
    #[test]
    fn test_regime_tracker() {
        let mut regime = RegimeTracker::new(100);
        regime.start_regime("test", 0.5);
        for &sample in &[0.6_f32, 0.7] {
            regime.record_energy(sample);
        }

        assert_eq!(regime.current_regime(), Some("test"));
        assert!(regime.average_energy() > 0.5);
        assert!(regime.energy_trend() > 0.0);
    }

    /// An energy spike proposes a lower reflex threshold and is flagged urgent.
    #[test]
    fn test_instant_adapt() {
        let mut tuner = SonaThresholdTuner::default_tuner();
        let before = *tuner.current_thresholds();

        let proposed = tuner.instant_adapt(0.5);

        assert!(proposed.new_thresholds.reflex < before.reflex);
        assert!(proposed.urgent);
    }

    /// Encoding the default thresholds and decoding them again preserves the
    /// reflex threshold.
    #[test]
    fn test_threshold_embedding_roundtrip() {
        let tuner = SonaThresholdTuner::default_tuner();
        let expected = ThresholdConfig::default();

        let encoded = tuner.threshold_to_embedding(&expected);
        let decoded = tuner.embedding_to_threshold(&encoded);

        assert!(decoded.is_some());
        let decoded = decoded.unwrap();
        assert!((decoded.reflex - expected.reflex).abs() < 0.001);
    }
}