use crate::error::{NeuralError, Result};
use crate::federated::{AggregationStrategy, ClientUpdate};
use crate::models::sequential::Sequential;
use scirs2_core::ndarray::prelude::*;
use std::sync::{Arc, RwLock};
/// Server-side runtime configuration for a federated training coordinator.
///
/// Built from the shared `FederatedConfig` via the `From` impl below, with
/// server-only fields filled with fixed defaults.
#[derive(Debug, Clone)]
pub struct ServerConfig {
/// Minimum number of client updates required before a round can aggregate.
pub min_clients: usize,
/// Round timeout — NOTE(review): units assumed to be seconds from the
/// default of 300 in the `From` impl; not read in this file, confirm.
pub round_timeout: u64,
/// Aggregation strategy name: "fedavg", "fedprox", or "fedyogi".
/// Any other value falls back to FedAvg (see `FederatedServer::new`).
pub aggregation_strategy: String,
/// When true, weights blend sample counts with per-client accuracy and
/// participation history instead of sample counts alone.
pub adaptive_aggregation: bool,
/// Staleness bound for async updates — NOTE(review): not read anywhere
/// in this file; confirm it is consumed elsewhere.
pub staleness_threshold: usize,
/// Enables asynchronous update handling — NOTE(review): not read anywhere
/// in this file; confirm it is consumed elsewhere.
pub async_updates: bool,
}
impl From<&crate::federated::FederatedConfig> for ServerConfig {
fn from(config: &crate::federated::FederatedConfig) -> Self {
Self {
min_clients: config.min_clients,
round_timeout: 300,
aggregation_strategy: config.aggregation_strategy.clone(),
adaptive_aggregation: false,
async_updates: false,
staleness_threshold: 5,
}
}
}
/// Central coordinator for federated learning rounds: holds the shared
/// global model state, applies the configured aggregation strategy to
/// incoming client updates, and tracks per-client contribution history.
pub struct FederatedServer {
/// Server-side configuration (minimum clients, strategy name, etc.).
config: ServerConfig,
/// Global model snapshot shared behind a reader/writer lock.
global_model_state: Arc<RwLock<ModelState>>,
/// Strategy used to combine client updates each round.
aggregator: Box<dyn AggregationStrategy>,
/// Number of aggregation rounds completed so far.
pub current_round: usize,
/// Bookkeeping of each client's samples, participation, and accuracy.
client_contributions: ClientContributions,
}
/// Snapshot of the global model, guarded by the server's `RwLock`.
#[derive(Clone)]
struct ModelState {
/// Per-layer parameter matrices of the global model.
parameters: Vec<Array2<f32>>,
/// Monotonically increasing version, bumped on every global update.
version: usize,
/// Instant (monotonic clock) of the last successful global update.
last_updated: std::time::Instant,
}
/// Per-client bookkeeping used for adaptive weighting and server statistics.
///
/// All fields are keyed by client id. Derives `Default` so an empty table
/// can be built without repeating the field list, and `Debug`/`Clone` for
/// diagnostics.
#[derive(Debug, Clone, Default)]
struct ClientContributions {
    /// Cumulative number of training samples contributed by each client.
    samples_per_client: std::collections::HashMap<usize, usize>,
    /// Number of aggregation rounds each client has participated in.
    rounds_per_client: std::collections::HashMap<usize, usize>,
    /// Reported accuracy history per client (appended once per round).
    performance_history: std::collections::HashMap<usize, Vec<f32>>,
}
impl FederatedServer {
pub fn new(config: ServerConfig) -> Result<Self> {
let aggregator: Box<dyn AggregationStrategy> = match config.aggregation_strategy.as_str() {
"fedavg" => Box::new(crate::federated::FedAvg::new()),
"fedprox" => Box::new(crate::federated::FedProx::new(0.01)),
"fedyogi" => Box::new(crate::federated::FedYogi::new()),
_ => Box::new(crate::federated::FedAvg::new()),
};
Ok(Self {
config,
global_model_state: Arc::new(RwLock::new(ModelState {
parameters: Vec::new(),
version: 0,
last_updated: std::time::Instant::now(),
})),
aggregator,
current_round: 0,
client_contributions: ClientContributions {
samples_per_client: std::collections::HashMap::new(),
rounds_per_client: std::collections::HashMap::new(),
performance_history: std::collections::HashMap::new(),
},
})
}
pub fn get_model_parameters(&self, _model: &Sequential<f32>) -> Result<Vec<Array2<f32>>> {
let state = self
.global_model_state
.read()
.map_err(|_| NeuralError::InferenceError("RwLock poisoned".to_string()))?;
if state.parameters.is_empty() {
Ok(vec![Array2::zeros((10, 10)); 5])
} else {
Ok(state.parameters.clone())
}
}
pub fn aggregate_updates(&mut self, updates: &[ClientUpdate]) -> Result<AggregatedUpdate> {
if updates.len() < self.config.min_clients {
return Err(NeuralError::InvalidArgument(format!(
"Not enough clients: {} < {}",
updates.len(),
self.config.min_clients
)));
}
for update in updates {
*self
.client_contributions
.samples_per_client
.entry(update.client_id)
.or_insert(0) += update.num_samples;
*self
.client_contributions
.rounds_per_client
.entry(update.client_id)
.or_insert(0) += 1;
self.client_contributions
.performance_history
.entry(update.client_id)
.or_default()
.push(update.accuracy);
}
let weights = if self.config.adaptive_aggregation {
self.calculate_adaptive_weights(updates)
} else {
self.calculate_sample_weights(updates)
};
let aggregated_weights = self.aggregator.aggregate(updates, &weights)?;
self.current_round += 1;
Ok(AggregatedUpdate {
aggregated_weights,
round: self.current_round,
num_clients: updates.len(),
total_samples: updates.iter().map(|u| u.num_samples).sum(),
})
}
pub fn calculate_sample_weights(&self, updates: &[ClientUpdate]) -> Vec<f32> {
let total_samples: usize = updates.iter().map(|u| u.num_samples).sum();
if total_samples == 0 {
return vec![1.0 / updates.len() as f32; updates.len()];
}
updates
.iter()
.map(|u| u.num_samples as f32 / total_samples as f32)
.collect()
}
fn calculate_adaptive_weights(&self, updates: &[ClientUpdate]) -> Vec<f32> {
let mut weights: Vec<f32> = updates
.iter()
.map(|update| {
let sample_weight = (update.num_samples as f32).max(1.0);
let performance_factor = self
.client_contributions
.performance_history
.get(&update.client_id)
.map(|history| {
let recent: Vec<f32> = history.iter().rev().take(5).copied().collect();
if recent.is_empty() {
1.0
} else {
recent.iter().sum::<f32>() / recent.len() as f32
}
})
.unwrap_or(1.0);
let rounds = self
.client_contributions
.rounds_per_client
.get(&update.client_id)
.copied()
.unwrap_or(1) as f32;
let participation_factor = (rounds / self.current_round.max(1) as f32).sqrt();
sample_weight * performance_factor * participation_factor
})
.collect();
let sum: f32 = weights.iter().sum();
if sum > 0.0 {
for w in &mut weights {
*w /= sum;
}
} else {
let equal_weight = 1.0 / weights.len().max(1) as f32;
weights.fill(equal_weight);
}
weights
}
pub fn update_global_model(
&mut self,
_model: &mut Sequential<f32>,
update: &AggregatedUpdate,
) -> Result<()> {
let mut state = self
.global_model_state
.write()
.map_err(|_| NeuralError::InferenceError("RwLock poisoned".to_string()))?;
state.parameters = update.aggregated_weights.clone();
state.version += 1;
state.last_updated = std::time::Instant::now();
Ok(())
}
pub fn get_statistics(&self) -> ServerStatistics {
let total_clients = self.client_contributions.samples_per_client.len();
let active_clients = self
.client_contributions
.rounds_per_client
.values()
.filter(|&&rounds| rounds > self.current_round.saturating_sub(5))
.count();
let total_samples: usize = self.client_contributions.samples_per_client.values().sum();
let avg_rounds_per_client = if total_clients > 0 {
self.client_contributions
.rounds_per_client
.values()
.sum::<usize>() as f32
/ total_clients as f32
} else {
0.0
};
let model_version = self
.global_model_state
.read()
.map(|s| s.version)
.unwrap_or(0);
ServerStatistics {
current_round: self.current_round,
total_clients,
active_clients,
total_samples_processed: total_samples,
average_rounds_per_client: avg_rounds_per_client,
model_version,
}
}
pub fn reset(&mut self) {
self.current_round = 0;
self.client_contributions = ClientContributions {
samples_per_client: std::collections::HashMap::new(),
rounds_per_client: std::collections::HashMap::new(),
performance_history: std::collections::HashMap::new(),
};
if let Ok(mut state) = self.global_model_state.write() {
state.parameters.clear();
state.version = 0;
}
}
}
/// Result of one aggregation round, produced by
/// `FederatedServer::aggregate_updates`.
pub struct AggregatedUpdate {
/// Combined model weights produced by the aggregation strategy.
pub aggregated_weights: Vec<Array2<f32>>,
/// 1-based round number this update belongs to.
pub round: usize,
/// Number of client updates that went into the aggregate.
pub num_clients: usize,
/// Total training samples across all contributing clients.
pub total_samples: usize,
}
/// Snapshot of server-side training progress, returned by
/// `FederatedServer::get_statistics`.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy, and
/// compare snapshots (all fields are plain value types).
#[derive(Debug, Clone, PartialEq)]
pub struct ServerStatistics {
    /// Round counter at the time the snapshot was taken.
    pub current_round: usize,
    /// Total distinct clients that have ever contributed samples.
    pub total_clients: usize,
    /// Clients considered active by the server's participation heuristic.
    pub active_clients: usize,
    /// Sum of all samples reported across all clients and rounds.
    pub total_samples_processed: usize,
    /// Mean number of rounds each known client has participated in.
    pub average_rounds_per_client: f32,
    /// Version of the global model at snapshot time (0 if unreadable).
    pub model_version: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::federated::FederatedConfig;

    /// Builds a `ClientUpdate` with empty weight updates for test use.
    fn make_update(client_id: usize, num_samples: usize, loss: f32, accuracy: f32) -> ClientUpdate {
        ClientUpdate {
            client_id,
            weight_updates: vec![],
            num_samples,
            loss,
            accuracy,
        }
    }

    #[test]
    fn test_server_creation() {
        // A freshly constructed server starts at round zero.
        let server_config = ServerConfig::from(&FederatedConfig::default());
        let server = FederatedServer::new(server_config).expect("FederatedServer::new failed");
        assert_eq!(server.current_round, 0);
    }

    #[test]
    fn test_sample_weights() {
        let config = ServerConfig::from(&FederatedConfig::default());
        let server = FederatedServer::new(config).expect("FederatedServer::new failed");
        // Two clients contributing 100 and 200 samples -> weights 1/3 and 2/3.
        let updates = vec![
            make_update(0, 100, 0.5, 0.9),
            make_update(1, 200, 0.4, 0.92),
        ];
        let weights = server.calculate_sample_weights(&updates);
        assert_eq!(weights.len(), 2);
        let expected = [1.0_f32 / 3.0, 2.0 / 3.0];
        for (w, e) in weights.iter().zip(expected) {
            assert!((w - e).abs() < 0.001);
        }
    }
}