use crate::time_compat::Instant;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single learning update: a gradient direction for a query plus the quality
/// score used to scale it (see `LearningSignal::scaled_gradient`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearningSignal {
    /// Embedding of the query this signal was derived from.
    pub query_embedding: Vec<f32>,
    /// Estimated gradient direction. `from_trajectory` fills this via
    /// `estimate_gradient`; `with_gradient` stores the caller's vector as-is.
    pub gradient_estimate: Vec<f32>,
    /// Quality score; multiplies the gradient element-wise in `scaled_gradient`.
    pub quality_score: f32,
    /// Creation time. Skipped by serde, so it is `None` after deserialization.
    #[serde(skip)]
    pub timestamp: Option<Instant>,
    /// Provenance information about where the signal came from.
    pub metadata: SignalMetadata,
}
/// Provenance details attached to a [`LearningSignal`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SignalMetadata {
    /// Id of the source trajectory (0 via `Default`, e.g. in `with_gradient`).
    pub trajectory_id: u64,
    /// Number of steps the source trajectory contained.
    pub step_count: usize,
    /// Model route copied from the source trajectory, if any.
    pub model_route: Option<String>,
    /// Free-form string annotations; `from_trajectory` leaves this empty.
    pub tags: HashMap<String, String>,
}
impl LearningSignal {
    /// Builds a signal from a finished trajectory.
    ///
    /// The gradient comes from `estimate_gradient`, the quality score is the
    /// trajectory's `final_quality`, and the metadata records the trajectory id,
    /// step count and model route (with no tags). The timestamp is set to "now".
    pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
        let metadata = SignalMetadata {
            trajectory_id: trajectory.id,
            step_count: trajectory.steps.len(),
            model_route: trajectory.model_route.clone(),
            tags: HashMap::new(),
        };
        Self {
            query_embedding: trajectory.query_embedding.clone(),
            gradient_estimate: Self::estimate_gradient(trajectory),
            quality_score: trajectory.final_quality,
            timestamp: Some(Instant::now()),
            metadata,
        }
    }

    /// Builds a signal from an explicit embedding, gradient and quality score.
    ///
    /// Metadata is left at its default (empty) value and the timestamp is "now".
    pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
        let timestamp = Some(Instant::now());
        Self {
            metadata: SignalMetadata::default(),
            timestamp,
            quality_score: quality,
            gradient_estimate: gradient,
            query_embedding: embedding,
        }
    }

    /// Estimates a gradient direction from per-step rewards and activations.
    ///
    /// Each step adds `(reward - mean_reward) * activation` element-wise, using at
    /// most `query_embedding.len()` activation values. The accumulated vector is
    /// scaled to unit length when its L2 norm exceeds `1e-8`; otherwise it is
    /// returned unchanged.
    ///
    /// Edge cases:
    /// - With no steps, the query embedding itself is returned (not normalized).
    /// - With exactly one step every advantage is zero, so the result is all zeros.
    fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
        let steps = &trajectory.steps;
        if steps.is_empty() {
            // Fallback: reuse the query embedding as the direction.
            return trajectory.query_embedding.clone();
        }

        // The mean reward serves as the baseline each step's reward is compared to.
        let baseline = steps.iter().map(|s| s.reward).sum::<f32>() / steps.len() as f32;

        let mut direction = vec![0.0f32; trajectory.query_embedding.len()];
        for step in steps {
            let advantage = step.reward - baseline;
            // `zip` stops at the shorter side, so activations beyond the
            // embedding dimension are ignored.
            direction
                .iter_mut()
                .zip(&step.activations)
                .for_each(|(slot, &act)| *slot += advantage * act);
        }

        let magnitude = direction.iter().map(|v| v * v).sum::<f32>().sqrt();
        if magnitude > 1e-8 {
            for v in direction.iter_mut() {
                *v /= magnitude;
            }
        }
        direction
    }

    /// Returns the gradient estimate multiplied element-wise by `quality_score`.
    pub fn scaled_gradient(&self) -> Vec<f32> {
        let scale = self.quality_score;
        self.gradient_estimate.iter().map(|g| g * scale).collect()
    }
}
/// Record of a single query: its embedding, the per-step trace, and the outcome.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryTrajectory {
    /// Caller-supplied identifier (see `QueryTrajectory::new`).
    pub id: u64,
    /// Embedding of the query.
    pub query_embedding: Vec<f32>,
    /// Steps in the order they were appended via `add_step`.
    pub steps: Vec<TrajectoryStep>,
    /// Overall quality, set by `finalize` (0.0 until then).
    pub final_quality: f32,
    /// Latency set by `finalize` (0 until then); the `_us` suffix suggests
    /// microseconds — confirm with callers.
    pub latency_us: u64,
    /// Optional model route; `new` leaves it `None`.
    pub model_route: Option<String>,
    /// Context identifiers; `new` leaves this empty and nothing in this file reads it.
    pub context_ids: Vec<String>,
}
impl QueryTrajectory {
    /// Creates an empty trajectory: no steps, zero quality and latency, no model
    /// route and no context ids. Space for 16 steps is reserved up front.
    pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
        const INITIAL_STEP_CAPACITY: usize = 16;
        Self {
            steps: Vec::with_capacity(INITIAL_STEP_CAPACITY),
            context_ids: vec![],
            model_route: None,
            latency_us: 0,
            final_quality: 0.0,
            query_embedding,
            id,
        }
    }

    /// Appends `step` to the end of the trajectory.
    pub fn add_step(&mut self, step: TrajectoryStep) {
        self.steps.push(step)
    }

    /// Records the final quality score and latency, overwriting earlier values.
    pub fn finalize(&mut self, quality: f32, latency_us: u64) {
        self.latency_us = latency_us;
        self.final_quality = quality;
    }

    /// Sum of every step's reward (zero when there are no steps).
    pub fn total_reward(&self) -> f32 {
        self.steps.iter().map(|step| step.reward).sum::<f32>()
    }

    /// Mean step reward, or 0.0 when the trajectory has no steps.
    pub fn avg_reward(&self) -> f32 {
        match self.steps.len() {
            0 => 0.0,
            count => self.total_reward() / count as f32,
        }
    }
}
/// One step of a [`QueryTrajectory`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Activation values; `LearningSignal::estimate_gradient` reads at most
    /// `query_embedding.len()` of them.
    pub activations: Vec<f32>,
    /// Attention weights recorded for the step (not read elsewhere in this file).
    pub attention_weights: Vec<f32>,
    /// Reward for this step; averaged into the baseline by `estimate_gradient`.
    pub reward: f32,
    /// Caller-supplied step index.
    pub step_idx: usize,
    /// Optional layer name, set via `TrajectoryStep::with_layer`.
    pub layer_name: Option<String>,
}
impl TrajectoryStep {
    /// Creates a step with the given data and no layer name.
    pub fn new(
        activations: Vec<f32>,
        attention_weights: Vec<f32>,
        reward: f32,
        step_idx: usize,
    ) -> Self {
        Self {
            layer_name: None,
            step_idx,
            reward,
            attention_weights,
            activations,
        }
    }

    /// Returns this step with `layer_name` set to an owned copy of `name`,
    /// replacing any previous value.
    pub fn with_layer(self, name: &str) -> Self {
        Self {
            layer_name: Some(String::from(name)),
            ..self
        }
    }
}
/// A learned cluster of queries, summarized by a centroid and usage statistics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearnedPattern {
    /// Identifier; `merge` keeps the id of the receiving pattern.
    pub id: u64,
    /// Cluster centroid, compared to queries by cosine `similarity`.
    pub centroid: Vec<f32>,
    /// Number of cluster members; used as the weights when merging.
    pub cluster_size: usize,
    /// Accumulated weight; summed by `merge` and scaled by `decay`.
    pub total_weight: f32,
    /// Average quality; combined as a cluster-size-weighted mean by `merge`.
    pub avg_quality: f32,
    /// Creation time in whole seconds since the epoch (set by `new`).
    pub created_at: u64,
    /// Last access time in whole seconds since the epoch (updated by `touch`).
    pub last_accessed: u64,
    /// Number of recorded accesses (incremented by `touch`, summed by `merge`).
    pub access_count: u32,
    /// Category of the pattern.
    pub pattern_type: PatternType,
}
/// Coarse category of a [`LearnedPattern`]; `General` is the default.
///
/// `Hash` is derived alongside `Eq` so the type can key `HashMap`/`HashSet`.
/// Keep the variant order stable: index-based serde formats encode it.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PatternType {
    /// Fallback category, used by `Default`.
    #[default]
    General,
    Reasoning,
    Factual,
    Creative,
    CodeGen,
    Conversational,
}
impl std::fmt::Display for PatternType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PatternType::General => write!(f, "general"),
PatternType::Reasoning => write!(f, "reasoning"),
PatternType::Factual => write!(f, "factual"),
PatternType::Creative => write!(f, "creative"),
PatternType::CodeGen => write!(f, "codegen"),
PatternType::Conversational => write!(f, "conversational"),
}
}
}
impl LearnedPattern {
    /// Creates a single-member pattern around `centroid`.
    ///
    /// Starts with `cluster_size = 1`, `total_weight = 1.0`, `avg_quality = 0.0`,
    /// no recorded accesses, the default [`PatternType`], and both timestamps set
    /// to the current time in whole seconds since the epoch.
    pub fn new(id: u64, centroid: Vec<f32>) -> Self {
        use crate::time_compat::SystemTime;
        let now = SystemTime::now().duration_since_epoch().as_secs();
        Self {
            id,
            centroid,
            cluster_size: 1,
            total_weight: 1.0,
            avg_quality: 0.0,
            created_at: now,
            last_accessed: now,
            access_count: 0,
            pattern_type: PatternType::default(),
        }
    }

    /// Returns a new pattern combining `self` and `other`.
    ///
    /// The centroid and `avg_quality` are means weighted by cluster size. Sizes,
    /// weights and access counts are summed. The earliest `created_at` and the
    /// latest `last_accessed` are kept. `id` and `pattern_type` come from `self`.
    ///
    /// Edge cases:
    /// - If both cluster sizes are zero (possible for deserialized or hand-built
    ///   values), the inputs are weighted equally instead of producing NaN from
    ///   `0 / 0`.
    /// - Cluster size and access count saturate instead of overflowing.
    /// - If the centroids differ in length, the result has the shorter length.
    pub fn merge(&self, other: &Self) -> Self {
        let total_size = self.cluster_size.saturating_add(other.cluster_size);
        let (w1, w2): (f32, f32) = if total_size == 0 {
            (0.5, 0.5)
        } else {
            (
                self.cluster_size as f32 / total_size as f32,
                other.cluster_size as f32 / total_size as f32,
            )
        };
        let centroid: Vec<f32> = self
            .centroid
            .iter()
            .zip(&other.centroid)
            .map(|(&a, &b)| a * w1 + b * w2)
            .collect();
        Self {
            id: self.id,
            centroid,
            cluster_size: total_size,
            total_weight: self.total_weight + other.total_weight,
            avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
            created_at: self.created_at.min(other.created_at),
            last_accessed: self.last_accessed.max(other.last_accessed),
            access_count: self.access_count.saturating_add(other.access_count),
            pattern_type: self.pattern_type.clone(),
        }
    }

    /// Multiplies `total_weight` by `factor`.
    pub fn decay(&mut self, factor: f32) {
        self.total_weight *= factor;
    }

    /// Records an access.
    ///
    /// Increments `access_count`, saturating at `u32::MAX` rather than panicking
    /// on overflow. Sets `last_accessed` to the current time in seconds since the
    /// epoch.
    pub fn touch(&mut self) {
        use crate::time_compat::SystemTime;
        self.access_count = self.access_count.saturating_add(1);
        self.last_accessed = SystemTime::now().duration_since_epoch().as_secs();
    }

    /// Returns `true` only when all three conditions hold:
    /// - `avg_quality < min_quality`
    /// - `access_count < min_accesses`
    /// - more than `max_age_secs` seconds have passed since `last_accessed`
    ///
    /// A `last_accessed` value in the future counts as age zero.
    pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
        use crate::time_compat::SystemTime;
        let now = SystemTime::now().duration_since_epoch().as_secs();
        let age = now.saturating_sub(self.last_accessed);
        self.avg_quality < min_quality && self.access_count < min_accesses && age > max_age_secs
    }

    /// Cosine similarity between the centroid and `query`.
    ///
    /// Returns 0.0 when the lengths differ or either vector's norm is at most `1e-8`.
    pub fn similarity(&self, query: &[f32]) -> f32 {
        if self.centroid.len() != query.len() {
            return 0.0;
        }
        let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
        let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a > 1e-8 && norm_b > 1e-8 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }
}
/// Tuning parameters; see the preset constructors on `SonaConfig` and its `Default`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SonaConfig {
    /// Hidden dimension (256 in every preset in this file).
    pub hidden_dim: usize,
    /// Embedding dimension (256 in every preset in this file).
    pub embedding_dim: usize,
    /// Rank of the "micro" LoRA adapter.
    pub micro_lora_rank: usize,
    /// Rank of the "base" LoRA adapter.
    pub base_lora_rank: usize,
    /// Learning rate for the micro LoRA adapter.
    pub micro_lora_lr: f32,
    /// Learning rate for the base LoRA adapter.
    pub base_lora_lr: f32,
    /// EWC penalty strength (presumably Elastic Weight Consolidation — confirm).
    pub ewc_lambda: f32,
    /// Number of pattern clusters.
    pub pattern_clusters: usize,
    /// Capacity for stored trajectories.
    pub trajectory_capacity: usize,
    /// Background interval in milliseconds.
    pub background_interval_ms: u64,
    /// Quality threshold.
    pub quality_threshold: f32,
    /// Whether SIMD code paths are enabled.
    pub enable_simd: bool,
}
impl Default for SonaConfig {
    /// Balanced general-purpose settings.
    ///
    /// - Micro/base LoRA ranks 2/8, with learning rates 0.002/0.0001.
    /// - EWC lambda 2000 and 100 pattern clusters.
    /// - Trajectory capacity of 10 000.
    /// - One-hour (3 600 000 ms) background interval.
    /// - Quality threshold 0.3, SIMD enabled.
    fn default() -> Self {
        let dim = 256;
        Self {
            hidden_dim: dim,
            embedding_dim: dim,
            micro_lora_rank: 2,
            base_lora_rank: 8,
            micro_lora_lr: 0.002,
            base_lora_lr: 0.0001,
            ewc_lambda: 2_000.0,
            pattern_clusters: 100,
            trajectory_capacity: 10_000,
            background_interval_ms: 3_600_000, // one hour
            quality_threshold: 0.3,
            enable_simd: true,
        }
    }
}
impl SonaConfig {
    /// Throughput-oriented preset.
    ///
    /// - Base LoRA rank 4; micro learning rate 0.0005.
    /// - Trajectory capacity 5 000.
    /// - Two-hour background interval.
    /// - Quality threshold 0.4.
    pub fn max_throughput() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 2,
            base_lora_rank: 4,
            micro_lora_lr: 0.0005,
            base_lora_lr: 0.0001,
            ewc_lambda: 2_000.0,
            pattern_clusters: 100,
            trajectory_capacity: 5_000,
            background_interval_ms: 7_200_000, // two hours
            quality_threshold: 0.4,
            enable_simd: true,
        }
    }

    /// Quality-oriented preset.
    ///
    /// - Base LoRA rank 16; base learning rate 0.001.
    /// - Trajectory capacity 20 000.
    /// - 30-minute background interval.
    /// - Quality threshold 0.2.
    pub fn max_quality() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 2,
            base_lora_rank: 16,
            micro_lora_lr: 0.002,
            base_lora_lr: 0.001,
            ewc_lambda: 2_000.0,
            pattern_clusters: 100,
            trajectory_capacity: 20_000,
            background_interval_ms: 1_800_000, // thirty minutes
            quality_threshold: 0.2,
            enable_simd: true,
        }
    }

    /// Small-footprint preset.
    ///
    /// - Micro/base LoRA ranks 1/4.
    /// - EWC lambda 1000 and 50 pattern clusters.
    /// - Trajectory capacity of only 200.
    /// - Quality threshold 0.5.
    pub fn edge_deployment() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 1,
            base_lora_rank: 4,
            micro_lora_lr: 0.001,
            base_lora_lr: 0.0001,
            ewc_lambda: 1_000.0,
            pattern_clusters: 50,
            trajectory_capacity: 200,
            background_interval_ms: 3_600_000, // one hour
            quality_threshold: 0.5,
            enable_simd: true,
        }
    }

    /// Batch preset: identical to `SonaConfig::default()` except for a micro LoRA
    /// learning rate of 0.001 instead of 0.002.
    pub fn batch_processing() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 2,
            base_lora_rank: 8,
            micro_lora_lr: 0.001,
            base_lora_lr: 0.0001,
            ewc_lambda: 2_000.0,
            pattern_clusters: 100,
            trajectory_capacity: 10_000,
            background_interval_ms: 3_600_000, // one hour
            quality_threshold: 0.3,
            enable_simd: true,
        }
    }

    /// Preset for short-lived instances.
    ///
    /// - Base LoRA rank 4.
    /// - EWC lambda 1000 and 50 pattern clusters.
    /// - Trajectory capacity 500.
    /// - One-minute background interval.
    pub fn for_ephemeral() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 2,
            base_lora_rank: 4,
            micro_lora_lr: 0.002,
            base_lora_lr: 0.0001,
            ewc_lambda: 1_000.0,
            pattern_clusters: 50,
            trajectory_capacity: 500,
            background_interval_ms: 60_000, // one minute
            quality_threshold: 0.3,
            enable_simd: true,
        }
    }

    /// Coordinator preset.
    ///
    /// - Base LoRA rank 16; base learning rate 0.0005.
    /// - 200 pattern clusters.
    /// - Trajectory capacity 50 000.
    /// - Five-minute background interval.
    /// - Quality threshold 0.4.
    pub fn for_coordinator() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 2,
            base_lora_rank: 16,
            micro_lora_lr: 0.001,
            base_lora_lr: 0.0005,
            ewc_lambda: 2_000.0,
            pattern_clusters: 200,
            trajectory_capacity: 50_000,
            background_interval_ms: 300_000, // five minutes
            quality_threshold: 0.4,
            enable_simd: true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a size-10, weight-5.0 `General` pattern with the given bookkeeping.
    fn make_pattern(
        id: u64,
        centroid: Vec<f32>,
        avg_quality: f32,
        created_at: u64,
        last_accessed: u64,
        access_count: u32,
    ) -> LearnedPattern {
        LearnedPattern {
            id,
            centroid,
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality,
            created_at,
            last_accessed,
            access_count,
            pattern_type: PatternType::General,
        }
    }

    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        let step = TrajectoryStep::new(vec![0.5, 0.3, 0.2], vec![0.4, 0.4, 0.2], 0.8, 0);
        trajectory.add_step(step);
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);
        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    #[test]
    fn test_pattern_merge() {
        let left = make_pattern(1, vec![1.0, 0.0], 0.8, 100, 200, 5);
        let right = make_pattern(2, vec![0.0, 1.0], 0.9, 150, 250, 3);
        let merged = left.merge(&right);

        let close = |a: f32, b: f32| (a - b).abs() < 1e-6;
        assert_eq!(merged.cluster_size, 20);
        assert!(close(merged.centroid[0], 0.5));
        assert!(close(merged.centroid[1], 0.5));
        assert!(close(merged.avg_quality, 0.85));
    }

    #[test]
    fn test_pattern_similarity() {
        let unit_x = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);
        assert!((unit_x.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(unit_x.similarity(&[0.0, 1.0, 0.0]).abs() < 1e-6);
    }

    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        for (idx, &reward) in [0.5, 0.7, 0.9].iter().enumerate() {
            trajectory.add_step(TrajectoryStep::new(vec![], vec![], reward, idx));
        }
        assert!((trajectory.total_reward() - 2.1).abs() < 1e-6);
        assert!((trajectory.avg_reward() - 0.7).abs() < 1e-6);
    }
}