use super::metrics::TrainingMetrics;
use crate::engine::SonaEngine;
use crate::time_compat::SystemTime;
use crate::types::{LearnedPattern, SonaConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Snapshot of an ephemeral agent's session.
///
/// Produced by `EphemeralAgent::export_state` and consumed by
/// `FederatedCoordinator::aggregate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentExport {
    /// Identifier of the agent that produced this export.
    pub agent_id: String,
    /// Every trajectory the agent recorded, in processing order.
    pub trajectories: Vec<TrajectoryExport>,
    /// Summary statistics computed by the agent at export time.
    pub stats: AgentExportStats,
    /// Milliseconds elapsed between agent creation and export.
    pub session_duration_ms: u64,
    /// Export time in milliseconds, as reported by `time_compat::SystemTime`.
    pub timestamp: u64,
}
/// A single recorded trajectory as shipped from an agent to the coordinator.
///
/// Activations are not exported; only the embedding and metadata travel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryExport {
    /// Query embedding that started the trajectory.
    pub embedding: Vec<f32>,
    /// Quality score the agent assigned to the trajectory.
    pub quality: f32,
    /// Optional model route forwarded to the engine's trajectory builder.
    pub route: Option<String>,
    /// Context strings forwarded to the engine's trajectory builder.
    pub context: Vec<String>,
    /// Time the trajectory was recorded, in milliseconds.
    pub timestamp: u64,
}
/// Aggregate numbers describing an agent's session.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentExportStats {
    /// Number of trajectories the agent processed.
    pub total_trajectories: usize,
    /// Mean trajectory quality (0.0 when no trajectories were processed).
    pub avg_quality: f32,
    /// Pattern count reported by the agent engine's `stats().patterns_stored`.
    pub patterns_learned: usize,
}
/// Short-lived learner that collects trajectories locally and exports them
/// for federated aggregation.
pub struct EphemeralAgent {
    /// Identifier reported in exports.
    agent_id: String,
    /// Local learning engine that every processed trajectory is fed into.
    engine: SonaEngine,
    /// Trajectories kept for export, in processing order.
    trajectories: Vec<TrajectoryExport>,
    /// Creation time in milliseconds; basis for uptime and session duration.
    start_time: u64,
    /// Quality of each processed trajectory (pushed and cleared alongside
    /// `trajectories`).
    quality_samples: Vec<f32>,
}
impl EphemeralAgent {
    /// Current time in milliseconds, as reported by `time_compat::SystemTime`.
    fn now_millis() -> u64 {
        SystemTime::now().duration_since_epoch().as_millis() as u64
    }

    /// Creates an agent with its own [`SonaEngine`] built from `config`.
    ///
    /// The creation time is recorded and later used by
    /// [`uptime_seconds`](Self::uptime_seconds) and
    /// [`export_state`](Self::export_state).
    pub fn new(agent_id: impl Into<String>, config: SonaConfig) -> Self {
        Self {
            agent_id: agent_id.into(),
            engine: SonaEngine::with_config(config),
            trajectories: Vec::new(),
            start_time: Self::now_millis(),
            quality_samples: Vec::new(),
        }
    }

    /// Creates an agent with a lightweight configuration for federated use.
    ///
    /// The configuration is: rank-2 micro-LoRA (lr 0.002), rank-8 base LoRA,
    /// trajectory capacity 500 and 25 pattern clusters. `embedding_dim` is set
    /// equal to `hidden_dim`, and the remaining fields use
    /// `SonaConfig::default()`.
    pub fn default_federated(agent_id: impl Into<String>, hidden_dim: usize) -> Self {
        Self::new(
            agent_id,
            SonaConfig {
                hidden_dim,
                embedding_dim: hidden_dim,
                micro_lora_rank: 2,
                base_lora_rank: 8,
                micro_lora_lr: 0.002,
                trajectory_capacity: 500,
                pattern_clusters: 25,
                ..Default::default()
            },
        )
    }

    /// Identifier of this agent.
    pub fn agent_id(&self) -> &str {
        &self.agent_id
    }

    /// Shared access to the local engine.
    pub fn engine(&self) -> &SonaEngine {
        &self.engine
    }

    /// Mutable access to the local engine.
    pub fn engine_mut(&mut self) -> &mut SonaEngine {
        &mut self.engine
    }

    /// Feeds one trajectory into the local engine and records it for export.
    ///
    /// The engine receives a trajectory built from:
    /// - `embedding`;
    /// - the optional `route`;
    /// - each `context` entry;
    /// - a single step carrying `activations` and `quality`.
    ///
    /// The embedding, quality, route and context (not the activations) are
    /// kept for [`export_state`](Self::export_state).
    pub fn process_trajectory(
        &mut self,
        embedding: Vec<f32>,
        activations: Vec<f32>,
        quality: f32,
        route: Option<String>,
        context: Vec<String>,
    ) {
        let now = Self::now_millis();
        // The embedding is also stored for export, so the engine gets a copy.
        let mut builder = self.engine.begin_trajectory(embedding.clone());
        if let Some(ref r) = route {
            builder.set_model_route(r);
        }
        for ctx in &context {
            builder.add_context(ctx);
        }
        builder.add_step(activations, vec![], quality);
        self.engine.end_trajectory(builder, quality);
        self.trajectories.push(TrajectoryExport {
            embedding,
            quality,
            route,
            context,
            timestamp: now,
        });
        self.quality_samples.push(quality);
    }

    /// Applies the engine's micro-LoRA to `input`, writing into `output`.
    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
        self.engine.apply_micro_lora(input, output);
    }

    /// Number of trajectories recorded since creation or the last [`clear`](Self::clear).
    pub fn trajectory_count(&self) -> usize {
        self.trajectories.len()
    }

    /// Mean quality of recorded trajectories, or `0.0` when there are none.
    pub fn avg_quality(&self) -> f32 {
        if self.quality_samples.is_empty() {
            0.0
        } else {
            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
        }
    }

    /// Runs a learning pass on the local engine and returns its report string.
    pub fn force_learn(&self) -> String {
        self.engine.force_learn()
    }

    /// Convenience wrapper: processes `embedding` as both embedding and
    /// activations, with no route and no context.
    pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
        self.process_trajectory(embedding.clone(), embedding, quality, None, vec![]);
    }

    /// Like [`process_task`](Self::process_task) but tags the trajectory with `route`.
    pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
        self.process_trajectory(
            embedding.clone(),
            embedding,
            quality,
            Some(route.to_string()),
            vec![],
        );
    }

    /// Alias for [`avg_quality`](Self::avg_quality).
    pub fn average_quality(&self) -> f32 {
        self.avg_quality()
    }

    /// Whole seconds elapsed since the agent was created.
    ///
    /// Returns 0 instead of panicking if the wall clock has moved backwards
    /// since creation.
    pub fn uptime_seconds(&self) -> u64 {
        Self::now_millis().saturating_sub(self.start_time) / 1000
    }

    /// Current summary statistics. Unlike [`export_state`](Self::export_state),
    /// this does not trigger a learning pass first.
    pub fn stats(&self) -> AgentExportStats {
        let engine_stats = self.engine.stats();
        AgentExportStats {
            total_trajectories: self.trajectories.len(),
            avg_quality: self.avg_quality(),
            patterns_learned: engine_stats.patterns_stored,
        }
    }

    /// Drops recorded trajectories and quality samples.
    ///
    /// The engine's learned state and the creation time are left untouched.
    pub fn clear(&mut self) {
        self.trajectories.clear();
        self.quality_samples.clear();
    }

    /// Patterns returned by the engine for an empty query with `k = 0`.
    ///
    /// NOTE(review): confirm that `find_patterns` treats `k = 0` as "no limit".
    pub fn get_patterns(&self) -> Vec<LearnedPattern> {
        self.engine.find_patterns(&[], 0)
    }

    /// Runs a learning pass, then packages trajectories and stats for the
    /// coordinator.
    ///
    /// `session_duration_ms` saturates at 0 if the wall clock has moved
    /// backwards since creation.
    pub fn export_state(&self) -> AgentExport {
        let now = Self::now_millis();
        // Learn before reading engine stats so `patterns_learned` includes this pass.
        self.engine.force_learn();
        AgentExport {
            agent_id: self.agent_id.clone(),
            trajectories: self.trajectories.clone(),
            stats: self.stats(),
            session_duration_ms: now.saturating_sub(self.start_time),
            timestamp: now,
        }
    }
}
/// Per-agent record kept by the coordinator after each aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContribution {
    /// Number of trajectories in the agent's export, accepted and rejected alike.
    pub trajectory_count: usize,
    /// Average quality as reported by the agent in its export stats.
    pub avg_quality: f32,
    /// Time the coordinator processed the export, in milliseconds.
    pub timestamp: u64,
    /// Session duration reported by the agent.
    pub session_duration_ms: u64,
}
/// Central aggregator that merges agent exports into a master engine.
pub struct FederatedCoordinator {
    /// Identifier of this coordinator (also used to label its metrics).
    coordinator_id: String,
    /// Engine that accumulates accepted trajectories from all agents.
    master_engine: SonaEngine,
    /// Latest contribution record per agent id.
    contributions: HashMap<String, AgentContribution>,
    /// Minimum (inclusive) trajectory quality accepted during aggregation.
    quality_threshold: f32,
    /// Running count of accepted trajectories.
    total_trajectories: usize,
    /// Agent-count period for automatic consolidation (see `should_consolidate`).
    consolidation_interval: usize,
    /// Quality samples of accepted trajectories.
    metrics: TrainingMetrics,
}
impl FederatedCoordinator {
    /// Current time in milliseconds, as reported by `time_compat::SystemTime`.
    fn now_millis() -> u64 {
        SystemTime::now().duration_since_epoch().as_millis() as u64
    }

    /// Creates a coordinator whose master engine is built from `config`.
    ///
    /// Defaults: quality threshold 0.4, consolidation every 50 agents.
    pub fn new(coordinator_id: impl Into<String>, config: SonaConfig) -> Self {
        let id = coordinator_id.into();
        Self {
            coordinator_id: id.clone(),
            master_engine: SonaEngine::with_config(config),
            contributions: HashMap::new(),
            quality_threshold: 0.4,
            total_trajectories: 0,
            consolidation_interval: 50,
            metrics: TrainingMetrics::new(&id),
        }
    }

    /// Creates a coordinator with a high-capacity master configuration.
    ///
    /// The configuration is: rank-2 micro-LoRA, rank-16 base LoRA, trajectory
    /// capacity 50 000, 200 pattern clusters and EWC lambda 2000. The
    /// remaining fields use `SonaConfig::default()`.
    pub fn default_coordinator(coordinator_id: impl Into<String>, hidden_dim: usize) -> Self {
        Self::new(
            coordinator_id,
            SonaConfig {
                hidden_dim,
                embedding_dim: hidden_dim,
                micro_lora_rank: 2,
                base_lora_rank: 16,
                trajectory_capacity: 50000,
                pattern_clusters: 200,
                ewc_lambda: 2000.0,
                ..Default::default()
            },
        )
    }

    /// Identifier of this coordinator.
    pub fn coordinator_id(&self) -> &str {
        &self.coordinator_id
    }

    /// Sets the minimum (inclusive) quality a trajectory needs to be accepted.
    pub fn set_quality_threshold(&mut self, threshold: f32) {
        self.quality_threshold = threshold;
    }

    /// Sets how many distinct agents trigger an automatic consolidation.
    ///
    /// An interval of `0` disables automatic consolidation; use
    /// [`force_consolidate`](Self::force_consolidate) instead.
    pub fn set_consolidation_interval(&mut self, interval: usize) {
        self.consolidation_interval = interval;
    }

    /// Shared access to the master engine.
    pub fn master_engine(&self) -> &SonaEngine {
        &self.master_engine
    }

    /// Merges one agent export into the master engine.
    ///
    /// Behaviour:
    /// - Trajectories with `quality >= quality_threshold` are replayed into
    ///   the master engine (embedding, route and context, but no steps) and
    ///   their quality is sampled into the metrics.
    /// - All other trajectories are counted as rejected, including NaN
    ///   qualities.
    /// - The agent's contribution record is inserted, or overwritten if the
    ///   agent id was seen before. `total_agents` therefore counts distinct
    ///   ids, while `total_trajectories` keeps accumulating.
    /// - A learning pass runs when [`should_consolidate`](Self::should_consolidate)
    ///   says so.
    pub fn aggregate(&mut self, export: AgentExport) -> AggregationResult {
        let mut accepted = 0;
        let mut rejected = 0;
        for traj in &export.trajectories {
            // `>=` (not `!<`) so NaN qualities fall into the rejected branch.
            if traj.quality >= self.quality_threshold {
                let mut builder = self.master_engine.begin_trajectory(traj.embedding.clone());
                if let Some(ref route) = traj.route {
                    builder.set_model_route(route);
                }
                for ctx in &traj.context {
                    builder.add_context(ctx);
                }
                self.master_engine.end_trajectory(builder, traj.quality);
                self.metrics.add_quality_sample(traj.quality);
                accepted += 1;
            } else {
                rejected += 1;
            }
        }
        self.total_trajectories += accepted;

        self.contributions.insert(
            export.agent_id.clone(),
            AgentContribution {
                trajectory_count: export.trajectories.len(),
                avg_quality: export.stats.avg_quality,
                timestamp: Self::now_millis(),
                session_duration_ms: export.session_duration_ms,
            },
        );

        let consolidated = self.should_consolidate();
        if consolidated {
            self.master_engine.force_learn();
        }

        AggregationResult {
            agent_id: export.agent_id,
            trajectories_accepted: accepted,
            trajectories_rejected: rejected,
            consolidated,
            total_agents: self.contributions.len(),
            total_trajectories: self.total_trajectories,
        }
    }

    /// True when the number of distinct contributing agents is a multiple of
    /// `consolidation_interval`.
    ///
    /// An interval of 0 disables automatic consolidation. This also avoids the
    /// division-by-zero panic the modulo would otherwise raise.
    fn should_consolidate(&self) -> bool {
        self.consolidation_interval != 0
            && self.contributions.len() % self.consolidation_interval == 0
    }

    /// Runs a learning pass on the master engine and returns its report string.
    pub fn force_consolidate(&self) -> String {
        self.master_engine.force_learn()
    }

    /// At most `k` patterns from the master engine's empty-query lookup.
    ///
    /// NOTE(review): confirm that `find_patterns(&[], 0)` returns all patterns;
    /// otherwise this always yields an empty list.
    pub fn get_initial_patterns(&self, k: usize) -> Vec<LearnedPattern> {
        self.master_engine
            .find_patterns(&[], 0)
            .into_iter()
            .take(k)
            .collect()
    }

    /// Patterns returned by the master engine for an empty query with `k = 0`.
    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
        self.master_engine.find_patterns(&[], 0)
    }

    /// Snapshot of the coordinator's counters, engine pattern count and metrics.
    pub fn stats(&self) -> CoordinatorStats {
        let engine_stats = self.master_engine.stats();
        CoordinatorStats {
            coordinator_id: self.coordinator_id.clone(),
            total_agents: self.contributions.len(),
            total_trajectories: self.total_trajectories,
            patterns_learned: engine_stats.patterns_stored,
            avg_quality: self.metrics.avg_quality(),
            quality_threshold: self.quality_threshold,
        }
    }

    /// Latest contribution record for each agent id.
    pub fn contributions(&self) -> &HashMap<String, AgentContribution> {
        &self.contributions
    }

    /// Quality metrics gathered from accepted trajectories.
    pub fn metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }

    /// Number of distinct agents that have contributed.
    pub fn agent_count(&self) -> usize {
        self.contributions.len()
    }

    /// Running count of accepted trajectories.
    pub fn total_trajectories(&self) -> usize {
        self.total_trajectories
    }

    /// Delegates a `k`-pattern lookup for `query` to the master engine.
    pub fn find_patterns(&self, query: &[f32], k: usize) -> Vec<LearnedPattern> {
        self.master_engine.find_patterns(query, k)
    }

    /// Applies the master engine's micro-LoRA to `input`.
    ///
    /// The output buffer is zero-initialised with the same length as `input`.
    pub fn apply_lora(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0; input.len()];
        self.master_engine.apply_micro_lora(input, &mut output);
        output
    }

    /// Alias for [`force_consolidate`](Self::force_consolidate).
    pub fn consolidate(&self) -> String {
        self.force_consolidate()
    }

    /// Forgets contribution records and resets the accepted-trajectory counter.
    ///
    /// The master engine's learned state and the quality metrics are left
    /// untouched.
    pub fn clear(&mut self) {
        self.contributions.clear();
        self.total_trajectories = 0;
    }
}
/// Outcome of a single `FederatedCoordinator::aggregate` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationResult {
    /// Agent whose export was aggregated.
    pub agent_id: String,
    /// Trajectories at or above the quality threshold.
    pub trajectories_accepted: usize,
    /// Trajectories below the quality threshold.
    pub trajectories_rejected: usize,
    /// Whether an automatic learning pass ran after this aggregation.
    pub consolidated: bool,
    /// Distinct agents seen so far.
    pub total_agents: usize,
    /// Accepted trajectories across all aggregations so far.
    pub total_trajectories: usize,
}
/// Point-in-time summary of a coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorStats {
    /// Identifier of the coordinator.
    pub coordinator_id: String,
    /// Distinct agents that have contributed.
    pub total_agents: usize,
    /// Accepted trajectories.
    pub total_trajectories: usize,
    /// Pattern count reported by the master engine.
    pub patterns_learned: usize,
    /// Average quality from the coordinator's training metrics.
    pub avg_quality: f32,
    /// Current acceptance threshold.
    pub quality_threshold: f32,
}
/// Human-readable one-line summary.
///
/// `quality_threshold` is not included, and `avg_quality` is shown with four
/// decimals.
impl std::fmt::Display for CoordinatorStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            coordinator_id,
            total_agents,
            total_trajectories,
            patterns_learned,
            avg_quality,
            ..
        } = self;
        write!(
            f,
            "Coordinator(id={}, agents={}, trajectories={}, patterns={}, avg_quality={:.4})",
            coordinator_id, total_agents, total_trajectories, patterns_learned, avg_quality
        )
    }
}
/// Deployment topology for a federated setup; `Star` is the default.
///
/// Variant order is kept as-is, since index-based serde formats depend on it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum FederatedTopology {
    /// Default topology (presumably a single central coordinator — confirm).
    #[default]
    Star,
    /// Topology partitioned into `regions` groups.
    Hierarchical { regions: usize },
    /// Topology without a designated coordinator (presumably — confirm).
    PeerToPeer,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-trajectory export with default stats.
    fn single_trajectory_export(agent_id: String, embedding_value: f32) -> AgentExport {
        AgentExport {
            agent_id,
            trajectories: vec![TrajectoryExport {
                embedding: vec![embedding_value; 256],
                quality: 0.8,
                route: None,
                context: vec![],
                timestamp: 0,
            }],
            stats: AgentExportStats::default(),
            session_duration_ms: 1000,
            timestamp: 0,
        }
    }

    #[test]
    fn test_ephemeral_agent_creation() {
        let agent = EphemeralAgent::default_federated("agent-1", 256);
        assert_eq!(agent.agent_id(), "agent-1");
        assert_eq!(agent.trajectory_count(), 0);
    }

    #[test]
    fn test_empty_agent_avg_quality_is_zero() {
        let agent = EphemeralAgent::default_federated("agent-1", 256);
        assert_eq!(agent.avg_quality(), 0.0);
        assert_eq!(agent.average_quality(), 0.0);
    }

    #[test]
    fn test_trajectory_collection() {
        let mut agent = EphemeralAgent::default_federated("agent-1", 256);
        agent.process_trajectory(
            vec![0.1; 256],
            vec![0.5; 256],
            0.8,
            Some("code".into()),
            vec!["file:main.rs".into()],
        );
        assert_eq!(agent.trajectory_count(), 1);
        assert!((agent.avg_quality() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_agent_export() {
        let mut agent = EphemeralAgent::default_federated("agent-1", 256);
        for i in 0..5 {
            agent.process_trajectory(
                vec![i as f32 * 0.1; 256],
                vec![0.5; 256],
                0.7 + i as f32 * 0.05,
                None,
                vec![],
            );
        }
        let export = agent.export_state();
        assert_eq!(export.agent_id, "agent-1");
        assert_eq!(export.trajectories.len(), 5);
        assert!(export.stats.avg_quality > 0.7);
    }

    #[test]
    fn test_coordinator_creation() {
        let coord = FederatedCoordinator::default_coordinator("coord-1", 256);
        assert_eq!(coord.coordinator_id(), "coord-1");
        let stats = coord.stats();
        assert_eq!(stats.total_agents, 0);
        assert_eq!(stats.total_trajectories, 0);
    }

    #[test]
    fn test_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", 256);
        coord.set_quality_threshold(0.5);
        let export = AgentExport {
            agent_id: "agent-1".into(),
            trajectories: vec![
                TrajectoryExport {
                    embedding: vec![0.1; 256],
                    quality: 0.8,
                    route: Some("code".into()),
                    context: vec![],
                    timestamp: 0,
                },
                TrajectoryExport {
                    embedding: vec![0.2; 256],
                    quality: 0.3,
                    route: None,
                    context: vec![],
                    timestamp: 0,
                },
            ],
            stats: AgentExportStats {
                total_trajectories: 2,
                avg_quality: 0.55,
                patterns_learned: 0,
            },
            session_duration_ms: 1000,
            timestamp: 0,
        };
        let result = coord.aggregate(export);
        assert_eq!(result.trajectories_accepted, 1);
        assert_eq!(result.trajectories_rejected, 1);
        assert_eq!(result.total_agents, 1);
    }

    #[test]
    fn test_multi_agent_aggregation() {
        let mut coord = FederatedCoordinator::default_coordinator("coord-1", 256);
        coord.set_consolidation_interval(2);
        for i in 0..3 {
            let export = single_trajectory_export(format!("agent-{}", i), i as f32 * 0.1);
            let result = coord.aggregate(export);
            // With interval 2, only the second distinct agent (i == 1) triggers it.
            assert_eq!(result.consolidated, i == 1, "unexpected consolidation at i={}", i);
        }
        let stats = coord.stats();
        assert_eq!(stats.total_agents, 3);
        assert_eq!(stats.total_trajectories, 3);
    }

    #[test]
    fn test_default_topology_is_star() {
        assert!(matches!(FederatedTopology::default(), FederatedTopology::Star));
    }
}