use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::federated_learning_v2_backup::types::*;
use trustformers_core::{Result, CoreError, Tensor};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedTrainingConfig {
pub num_rounds: u32,
pub min_participants_per_round: u32,
pub max_participants_per_round: u32,
pub participant_selection: ParticipantSelectionStrategy,
pub model_averaging: ModelAveragingStrategy,
pub convergence_criteria: ConvergenceCriteria,
pub early_stopping: EarlyStoppingConfig,
pub round_timeout_seconds: u64,
pub validation_frequency: u32,
pub checkpoint_frequency: u32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParticipantSelectionStrategy {
Random,
RoundRobin,
DataQuality,
DeviceCapabilities,
TrustBased,
Hybrid,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelAveragingStrategy {
FedAvg,
WeightedFedAvg,
FedProx,
Scaffold,
FedNova,
AdaptiveFedOpt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceCriteria {
pub target_accuracy: f64,
pub loss_improvement_threshold: f64,
pub min_improvement_rounds: u32,
pub max_rounds_without_improvement: u32,
pub gradient_norm_threshold: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
pub enabled: bool,
pub patience: u32,
pub min_delta: f64,
pub monitor_metric: MonitorMetric,
pub mode: ImprovementMode,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MonitorMetric {
ValidationLoss,
ValidationAccuracy,
TrainingLoss,
ModelStability,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImprovementMode {
Min,
Max,
Auto,
}
#[derive(Debug, Clone)]
pub struct FederatedTrainingState {
pub current_round: u32,
pub global_model: Option<Tensor>,
pub round_statistics: Vec<RoundStatistics>,
pub convergence_metrics: ConvergenceMetrics,
pub training_progress: TrainingProgress,
}
#[derive(Debug)]
pub struct FederatedTrainingCoordinator {
config: FederatedTrainingConfig,
training_state: FederatedTrainingState,
participant_pool: HashMap<String, ParticipantInfo>,
selected_participants: Vec<String>,
round_start_time: Option<u64>,
best_model: Option<Tensor>,
best_metric_value: f64,
rounds_without_improvement: u32,
}
impl FederatedTrainingCoordinator {
pub fn new(config: FederatedTrainingConfig) -> Self {
Self {
config,
training_state: FederatedTrainingState::default(),
participant_pool: HashMap::new(),
selected_participants: Vec::new(),
round_start_time: None,
best_model: None,
best_metric_value: 0.0,
rounds_without_improvement: 0,
}
}
pub fn add_participant(&mut self, participant: ParticipantInfo) -> Result<()> {
self.participant_pool.insert(participant.id.clone(), participant);
Ok(())
}
pub fn remove_participant(&mut self, participant_id: &str) -> Result<()> {
self.participant_pool.remove(participant_id);
self.selected_participants.retain(|id| id != participant_id);
Ok(())
}
pub fn start_round(&mut self) -> Result<Vec<String>> {
if self.training_state.current_round >= self.config.num_rounds {
return Err(TrustformersError::InvalidConfiguration("Maximum rounds reached".to_string()).into());
}
if self.check_convergence()? {
self.training_state.training_progress = TrainingProgress::Converged;
return Ok(Vec::new().into());
}
if self.should_stop_early()? {
self.training_state.training_progress = TrainingProgress::Stopped;
return Ok(Vec::new());
}
self.selected_participants = self.select_participants()?;
if self.selected_participants.len() < self.config.min_participants_per_round as usize {
return Err(TrustformersError::InvalidConfiguration("Insufficient participants for round".to_string()).into());
}
self.training_state.current_round += 1;
self.training_state.training_progress = TrainingProgress::Training;
self.round_start_time = Some(self.get_current_timestamp().into());
Ok(self.selected_participants.clone())
}
pub fn complete_round(&mut self, participant_updates: HashMap<String, Tensor>) -> Result<Tensor> {
if participant_updates.is_empty() {
return Err(TrustformersError::InvalidConfiguration("No participant updates received".to_string()).into());
}
let aggregated_model = self.aggregate_updates(participant_updates)?;
self.training_state.global_model = Some(aggregated_model.clone().into());
let round_duration = self.round_start_time
.map(|start| self.get_current_timestamp() - start)
.unwrap_or(0);
let round_stats = RoundStatistics {
round: self.training_state.current_round,
participants: self.selected_participants.len() as u32,
duration_seconds: round_duration,
avg_update_quality: 0.8, communication_overhead_mb: 10.0, privacy_budget_consumed: 0.1, };
self.training_state.round_statistics.push(round_stats);
self.update_convergence_metrics(&aggregated_model)?;
self.check_improvement()?;
self.round_start_time = None;
Ok(aggregated_model)
}
fn select_participants(&self) -> Result<Vec<String>> {
let available_participants: Vec<_> = self.participant_pool.keys().cloned().collect();
if available_participants.len() < self.config.min_participants_per_round as usize {
return Err(TrustformersError::InvalidConfiguration("Insufficient available participants".to_string()).into());
}
let num_to_select = (self.config.max_participants_per_round as usize)
.min(available_participants.len().into());
let selected = match self.config.participant_selection {
ParticipantSelectionStrategy::Random => {
self.select_random_participants(&available_participants, num_to_select)
}
ParticipantSelectionStrategy::RoundRobin => {
self.select_round_robin_participants(&available_participants, num_to_select)
}
ParticipantSelectionStrategy::DataQuality => {
self.select_by_data_quality(&available_participants, num_to_select)
}
ParticipantSelectionStrategy::DeviceCapabilities => {
self.select_by_device_capabilities(&available_participants, num_to_select)
}
ParticipantSelectionStrategy::TrustBased => {
self.select_by_trust_scores(&available_participants, num_to_select)
}
ParticipantSelectionStrategy::Hybrid => {
self.select_hybrid(&available_participants, num_to_select)
}
};
Ok(selected)
}
fn select_random_participants(&self, participants: &[String], num_to_select: usize) -> Vec<String> {
participants.iter().take(num_to_select).cloned().collect()
}
fn select_round_robin_participants(&self, participants: &[String], num_to_select: usize) -> Vec<String> {
let start_index = (self.training_state.current_round as usize) % participants.len();
let mut selected = Vec::new();
for i in 0..num_to_select {
let index = (start_index + i) % participants.len();
selected.push(participants[index].clone());
}
selected
}
fn select_by_data_quality(&self, participants: &[String], num_to_select: usize) -> Vec<String> {
let mut scored_participants: Vec<_> = participants
.iter()
.map(|id| {
let participant = &self.participant_pool[id];
let quality_score = if participant.participation_history.is_empty() {
0.5
} else {
participant.participation_history.iter()
.map(|record| record.update_quality)
.sum::<f64>() / participant.participation_history.len() as f64
};
(id.clone(), quality_score)
})
.collect();
scored_participants.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored_participants.into_iter().take(num_to_select).map(|(id, _)| id).collect()
}
fn select_by_device_capabilities(&self, participants: &[String], num_to_select: usize) -> Vec<String> {
let mut scored_participants: Vec<_> = participants
.iter()
.map(|id| {
let participant = &self.participant_pool[id];
let capability_score = match participant.device_capabilities.compute_capability {
ComputeCapability::VeryHigh => 4.0,
ComputeCapability::High => 3.0,
ComputeCapability::Medium => 2.0,
ComputeCapability::Low => 1.0,
};
(id.clone(), capability_score)
})
.collect();
scored_participants.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored_participants.into_iter().take(num_to_select).map(|(id, _)| id).collect()
}
fn select_by_trust_scores(&self, participants: &[String], num_to_select: usize) -> Vec<String> {
let mut scored_participants: Vec<_> = participants
.iter()
.map(|id| {
let participant = &self.participant_pool[id];
(id.clone(), participant.trust_score)
})
.collect();
scored_participants.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored_participants.into_iter().take(num_to_select).map(|(id, _)| id).collect()
}
fn select_hybrid(&self, participants: &[String], num_to_select: usize) -> Vec<String> {
let mut scored_participants: Vec<_> = participants
.iter()
.map(|id| {
let participant = &self.participant_pool[id];
let quality_score = if participant.participation_history.is_empty() {
0.5
} else {
participant.participation_history.iter()
.map(|record| record.update_quality)
.sum::<f64>() / participant.participation_history.len() as f64
};
let capability_score = match participant.device_capabilities.compute_capability {
ComputeCapability::VeryHigh => 1.0,
ComputeCapability::High => 0.8,
ComputeCapability::Medium => 0.6,
ComputeCapability::Low => 0.4,
};
let combined_score = 0.4 * participant.trust_score + 0.3 * quality_score + 0.3 * capability_score;
(id.clone(), combined_score)
})
.collect();
scored_participants.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored_participants.into_iter().take(num_to_select).map(|(id, _)| id).collect()
}
fn aggregate_updates(&self, updates: HashMap<String, Tensor>) -> Result<Tensor> {
if updates.is_empty() {
return Err(TrustformersError::InvalidConfiguration("No updates to aggregate".to_string()).into());
}
match self.config.model_averaging {
ModelAveragingStrategy::FedAvg => self.federated_averaging(&updates),
ModelAveragingStrategy::WeightedFedAvg => self.weighted_federated_averaging(&updates),
ModelAveragingStrategy::FedProx => self.federated_proximal(&updates),
ModelAveragingStrategy::Scaffold => self.scaffold_aggregation(&updates),
ModelAveragingStrategy::FedNova => self.fed_nova_aggregation(&updates),
ModelAveragingStrategy::AdaptiveFedOpt => self.adaptive_fed_opt(&updates),
}
}
fn federated_averaging(&self, updates: &HashMap<String, Tensor>) -> Result<Tensor> {
let first_update = updates.values().next()
.ok_or_else(|| TrustformersError::other("No updates provided".to_string()))?;
let mut aggregated_data = vec![0.0f32; first_update.data()?.len()];
for update in updates.values() {
let data = update.data()?;
for (i, &value) in data.iter().enumerate() {
aggregated_data[i] += value;
}
}
let num_updates = updates.len() as f32;
for value in &mut aggregated_data {
*value /= num_updates;
}
Tensor::from_vec(aggregated_data, first_update.shape())
}
fn weighted_federated_averaging(&self, updates: &HashMap<String, Tensor>) -> Result<Tensor> {
let first_update = updates.values().next()
.ok_or_else(|| TrustformersError::other("No updates provided".to_string()))?;
let mut aggregated_data = vec![0.0f32; first_update.data()?.len()];
let mut total_weight = 0.0;
for (participant_id, update) in updates {
let participant = &self.participant_pool[participant_id];
let weight = participant.trust_score; total_weight += weight;
let data = update.data()?;
for (i, &value) in data.iter().enumerate() {
aggregated_data[i] += value * weight as f32;
}
}
if total_weight > 0.0 {
for value in &mut aggregated_data {
*value /= total_weight as f32;
}
}
Tensor::from_vec(aggregated_data, first_update.shape())
}
fn federated_proximal(&self, updates: &HashMap<String, Tensor>) -> Result<Tensor> {
self.federated_averaging(updates)
}
fn scaffold_aggregation(&self, updates: &HashMap<String, Tensor>) -> Result<Tensor> {
self.federated_averaging(updates)
}
fn fed_nova_aggregation(&self, updates: &HashMap<String, Tensor>) -> Result<Tensor> {
self.weighted_federated_averaging(updates)
}
fn adaptive_fed_opt(&self, updates: &HashMap<String, Tensor>) -> Result<Tensor> {
self.weighted_federated_averaging(updates)
}
fn update_convergence_metrics(&mut self, model: &Tensor) -> Result<()> {
let data = model.data()?;
let gradient_norm: f64 = data.iter().map(|&x| (x as f64).powi(2)).sum::<f64>().sqrt();
self.training_state.convergence_metrics.gradient_norm = gradient_norm;
self.training_state.convergence_metrics.current_loss = gradient_norm; self.training_state.convergence_metrics.current_accuracy = (1.0 / (1.0 + gradient_norm)).min(1.0);
if self.training_state.convergence_metrics.current_accuracy > self.training_state.convergence_metrics.best_accuracy {
self.training_state.convergence_metrics.best_accuracy = self.training_state.convergence_metrics.current_accuracy;
}
Ok(())
}
fn check_convergence(&self) -> Result<bool> {
let metrics = &self.training_state.convergence_metrics;
if metrics.current_accuracy >= self.config.convergence_criteria.target_accuracy {
return Ok(true);
}
if metrics.gradient_norm <= self.config.convergence_criteria.gradient_norm_threshold {
return Ok(true);
}
if self.rounds_without_improvement >= self.config.convergence_criteria.max_rounds_without_improvement {
return Ok(true);
}
Ok(false)
}
fn check_improvement(&mut self) -> Result<()> {
let current_metric = match self.config.early_stopping.monitor_metric {
MonitorMetric::ValidationAccuracy => self.training_state.convergence_metrics.current_accuracy,
MonitorMetric::ValidationLoss => -self.training_state.convergence_metrics.current_loss, MonitorMetric::TrainingLoss => -self.training_state.convergence_metrics.current_loss,
MonitorMetric::ModelStability => self.training_state.convergence_metrics.stability_score,
};
let improvement = match self.config.early_stopping.mode {
ImprovementMode::Max => current_metric > self.best_metric_value + self.config.early_stopping.min_delta,
ImprovementMode::Min => current_metric < self.best_metric_value - self.config.early_stopping.min_delta,
ImprovementMode::Auto => {
match self.config.early_stopping.monitor_metric {
MonitorMetric::ValidationAccuracy => current_metric > self.best_metric_value + self.config.early_stopping.min_delta,
_ => current_metric < self.best_metric_value - self.config.early_stopping.min_delta,
}
}
};
if improvement {
self.best_metric_value = current_metric;
self.rounds_without_improvement = 0;
if let Some(model) = &self.training_state.global_model {
self.best_model = Some(model.clone().into());
}
} else {
self.rounds_without_improvement += 1;
}
Ok(())
}
fn should_stop_early(&self) -> Result<bool> {
if !self.config.early_stopping.enabled {
return Ok(false);
}
Ok(self.rounds_without_improvement >= self.config.early_stopping.patience)
}
pub fn get_training_state(&self) -> &FederatedTrainingState {
&self.training_state
}
pub fn get_best_model(&self) -> Option<&Tensor> {
self.best_model.as_ref()
}
fn get_current_timestamp(&self) -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
}
impl Default for FederatedTrainingConfig {
fn default() -> Self {
Self {
num_rounds: 100,
min_participants_per_round: 2,
max_participants_per_round: 100,
participant_selection: ParticipantSelectionStrategy::Hybrid,
model_averaging: ModelAveragingStrategy::WeightedFedAvg,
convergence_criteria: ConvergenceCriteria::default(),
early_stopping: EarlyStoppingConfig::default(),
round_timeout_seconds: 3600,
validation_frequency: 5,
checkpoint_frequency: 10,
}
}
}
impl Default for ConvergenceCriteria {
fn default() -> Self {
Self {
target_accuracy: 0.95,
loss_improvement_threshold: 0.001,
min_improvement_rounds: 5,
max_rounds_without_improvement: 20,
gradient_norm_threshold: 0.01,
}
}
}
impl Default for EarlyStoppingConfig {
fn default() -> Self {
Self {
enabled: true,
patience: 10,
min_delta: 0.001,
monitor_metric: MonitorMetric::ValidationAccuracy,
mode: ImprovementMode::Auto,
}
}
}
impl Default for FederatedTrainingState {
fn default() -> Self {
Self {
current_round: 0,
global_model: None,
round_statistics: Vec::new(),
convergence_metrics: ConvergenceMetrics {
current_accuracy: 0.0,
best_accuracy: 0.0,
current_loss: f64::INFINITY,
loss_reduction_rate: 0.0,
gradient_norm: 0.0,
stability_score: 0.0,
},
training_progress: TrainingProgress::Initializing,
}
}
}