use crate::error::{NeuralError, Result};
use crate::federated::{AggregationStrategy, ClientUpdate};
use crate::models::sequential::Sequential;
use scirs2_core::ndarray::prelude::*;
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub enum PersonalizationStrategy {
FineTuning {
epochs: usize,
learning_rate: f32,
},
MetaLearning {
inner_steps: usize,
outer_lr: f32,
inner_lr: f32,
MultiTask {
shared_layers: usize,
task_head_sizes: Vec<usize>,
Clustering {
num_clusters: usize,
method: ClusteringMethod,
MixtureOfExperts {
num_experts: usize,
gating_hidden_size: usize,
}
/// Algorithm used to group clients into clusters.
///
/// `Debug, Clone` are required because `PersonalizationStrategy` (which
/// derives both) embeds this type in its `Clustering` variant.
#[derive(Debug, Clone)]
pub enum ClusteringMethod {
    /// K-means over client parameter/label statistics.
    KMeansParameters,
    /// K-means over client loss profiles.
    KMeansLoss,
    /// Agglomerative hierarchical clustering.
    Hierarchical,
    /// Spectral clustering.
    Spectral,
}
pub struct PersonalizedFL {
global_model: Option<Sequential<f32>>,
client_models: HashMap<usize, Sequential<f32>>,
strategy: PersonalizationStrategy,
client_stats: HashMap<usize, ClientStatistics>,
cluster_assignments: HashMap<usize, usize>,
cluster_models: HashMap<usize, Sequential<f32>>,
personalization_history: Vec<PersonalizationRound>,
pub struct ClientStatistics {
pub label_distribution: Vec<f32>,
pub task_performance: HashMap<String, f32>,
pub param_stats: ParameterStatistics,
pub gradient_stats: GradientStatistics,
/// Per-layer statistics over model parameters.
#[derive(Debug, Clone)]
pub struct ParameterStatistics {
    /// L2 norm of each layer's parameters.
    pub layer_norms: Vec<f32>,
    /// Mean of each layer's parameters.
    pub layer_means: Vec<f32>,
    /// Variance of each layer's parameters.
    pub layer_variances: Vec<f32>,
}
/// Per-layer statistics over client gradients.
///
/// Includes `layer_norms` because the statistics builder initializes both
/// fields (`layer_norms` and `global_similarities`) for this type.
#[derive(Debug, Clone)]
pub struct GradientStatistics {
    /// L2 norm of each layer's gradient.
    pub layer_norms: Vec<f32>,
    /// Per-layer similarity between the client gradient and the global update.
    pub global_similarities: Vec<f32>,
}
/// Record of one personalization round's before/after performance.
///
/// `Clone` is required because the coordinator both stores the round in its
/// history and returns it to the caller.
#[derive(Debug, Clone)]
pub struct PersonalizationRound {
    /// Federated round index.
    pub round: usize,
    /// Per-client performance before personalization.
    pub pre_personalization_performance: HashMap<usize, f32>,
    /// Per-client performance after personalization.
    pub post_personalization_performance: HashMap<usize, f32>,
    /// Per-client improvement (post - pre).
    pub improvements: HashMap<usize, f32>,
}
impl PersonalizedFL {
pub fn new(strategy: PersonalizationStrategy) -> Self {
Self {
global_model: None,
client_models: HashMap::new(),
strategy,
client_stats: HashMap::new(),
cluster_assignments: HashMap::new(),
cluster_models: HashMap::new(),
personalization_history: Vec::new(),
}
}
pub fn set_global_model(&mut self, model: Sequential<f32>) {
self.global_model = Some(model);
pub fn personalize_for_client(
&mut self,
client_id: usize,
client_data: &ArrayView2<f32>,
client_labels: &ArrayView1<usize>,
validation_data: Option<(&ArrayView2<f32>, &ArrayView1<usize>)>,
) -> Result<Sequential<f32>> {
match &self.strategy {
PersonalizationStrategy::FineTuning {
epochs,
learning_rate,
} => self.fine_tune_for_client(
client_id,
client_data,
client_labels,
*epochs,
*learning_rate,
),
PersonalizationStrategy::MetaLearning {
inner_steps,
outer_lr,
inner_lr,
} => self.meta_learn_for_client(
*inner_steps,
*outer_lr,
*inner_lr,
PersonalizationStrategy::MultiTask {
shared_layers,
task_head_sizes,
} => self.multi_task_for_client(
*shared_layers,
PersonalizationStrategy::Clustering {
num_clusters,
method,
} => self.cluster_based_personalization(
*num_clusters,
PersonalizationStrategy::MixtureOfExperts {
num_experts,
gating_hidden_size,
} => self.mixture_of_experts_for_client(
*num_experts,
*gating_hidden_size,
fn fine_tune_for_client(
let mut personalized_model = if let Some(existing) = self.client_models.get(&client_id) {
existing.clone()
} else if let Some(ref global) = self.global_model {
global.clone()
} else {
return Err(NeuralError::InvalidArgument(
"No global model available".to_string(),
));
};
for epoch in 0..epochs {
let batch_size = 32.min(client_data.nrows());
let num_batches = (client_data.nrows() + batch_size - 1) / batch_size;
for batch_idx in 0..num_batches {
let start = batch_idx * batch_size;
let end = ((batch_idx + 1) * batch_size).min(client_data.nrows());
let batch_data = client_data.slice(s![start..end, ..]);
let batch_labels = client_labels.slice(s![start..end]);
let _loss = self.compute_loss(&personalized_model, &batch_data, &batch_labels)?;
}
self.client_models
.insert(client_id, personalized_model.clone());
Ok(personalized_model)
fn meta_learn_for_client(
if self.global_model.is_none() {
"No global model for meta-learning".to_string(),
let global_model = self.global_model.as_ref().expect("Operation failed");
let mut adapted_model = global_model.clone();
let split_point = client_data.nrows() / 2;
let support_data = client_data.slice(s![..split_point, ..]);
let support_labels = client_labels.slice(s![..split_point]);
let query_data = client_data.slice(s![split_point.., ..]);
let query_labels = client_labels.slice(s![split_point..]);
for _ in 0..inner_steps {
let loss = self.compute_loss(&adapted_model, &support_data, &support_labels)?;
let query_loss = self.compute_loss(&adapted_model, &query_data, &query_labels)?;
self.client_models.insert(client_id, adapted_model.clone());
Ok(adapted_model)
fn multi_task_for_client(
task_head_sizes: &[usize],
let mut personalized_model = if let Some(ref global) = self.global_model {
Sequential::new()
let epochs = 10;
for _epoch in 0..epochs {
let _loss = self.compute_loss(&personalized_model, client_data, client_labels)?;
fn cluster_based_personalization(
method: &ClusteringMethod,
// Update client statistics
self.update_client_statistics(client_id, client_data, client_labels)?;
if self.cluster_assignments.is_empty() {
self.perform_clustering(num_clusters, method)?;
let cluster_id = self
.cluster_assignments
.get(&client_id)
.copied()
.unwrap_or(0);
let cluster_model = if let Some(model) = self.cluster_models.get(&cluster_id) {
model.clone()
let model = global.clone();
self.cluster_models.insert(cluster_id, model.clone());
model
"No model available for clustering".to_string(),
let personalized_model =
self.fine_tune_for_client(client_id, client_data, client_labels, 5, 0.01)?;
fn mixture_of_experts_for_client(
self.fine_tune_for_client(client_id, client_data, client_labels, 10, 0.01)
fn update_client_statistics(
) -> Result<()> {
let num_classes = client_labels.iter().cloned().max().unwrap_or(0) + 1;
let mut label_counts = vec![0; num_classes];
for &label in client_labels {
if label < num_classes {
label_counts[label] += 1;
let total = label_counts.iter().sum::<usize>() as f32;
let label_distribution: Vec<f32> = label_counts
.iter()
.map(|&count| count as f32 / total)
.collect();
let param_stats = ParameterStatistics {
layer_norms: vec![1.0; 5], layer_means: vec![0.0; 5], layer_variances: vec![1.0; 5], let gradient_stats = GradientStatistics {
layer_norms: vec![0.1; 5], global_similarities: vec![0.8; 5], let stats = ClientStatistics {
label_distribution,
task_performance: HashMap::new(),
param_stats,
gradient_stats,
self.client_stats.insert(client_id, stats);
Ok(())
fn perform_clustering(&mut self, numclusters: usize, method: &ClusteringMethod) -> Result<()> {
let client_ids: Vec<usize> = self.client_stats.keys().cloned().collect();
match method {
ClusteringMethod::KMeansParameters => {
self.kmeans_clustering_parameters(&client_ids, num_clusters)?;
ClusteringMethod::KMeansLoss => {
self.kmeans_clustering_loss(&client_ids, num_clusters)?;
ClusteringMethod::Hierarchical => {
self.hierarchical_clustering(&client_ids, num_clusters)?;
ClusteringMethod::Spectral => {
self.spectral_clustering(&client_ids, num_clusters)?;
fn kmeans_clustering_parameters(
client_ids: &[usize],
use scirs2_core::random::prelude::*;
use scirs2_core::ndarray::ArrayView1;
let mut rng = rng();
for &client_id in client_ids {
let cluster = rng.random_range(0..num_clusters);
self.cluster_assignments.insert(client_id..cluster);
for _iter in 0..10 {
let mut centroids = vec![vec![0.0; 10]; num_clusters]; let mut cluster_counts = vec![0; num_clusters];
for &client_id in client_ids {
if let (Some(cluster), Some(stats)) = (
self.cluster_assignments.get(&client_id),
self.client_stats.get(&client_id),
) {
cluster_counts[*cluster] += 1;
for (i, &val) in stats.label_distribution.iter().enumerate() {
if i < centroids[*cluster].len() {
centroids[*cluster][i] += val;
}
}
}
for (centroid, &count) in centroids.iter_mut().zip(&cluster_counts) {
if count > 0 {
for val in centroid.iter_mut() {
*val /= count as f32;
if let Some(stats) = self.client_stats.get(&client_id) {
let mut best_cluster = 0;
let mut best_distance = f32::INFINITY;
for (cluster_id, centroid) in centroids.iter().enumerate() {
let distance =
self.compute_distribution_distance(&stats.label_distribution, centroid);
if distance < best_distance {
best_distance = distance;
best_cluster = cluster_id;
self.cluster_assignments.insert(client_id, best_cluster);
fn kmeans_clustering_loss(&mut self, client_ids: &[usize], numclusters: usize) -> Result<()> {
self.kmeans_clustering_parameters(client_ids, num_clusters)
fn hierarchical_clustering(&mut self, client_ids: &[usize], numclusters: usize) -> Result<()> {
fn spectral_clustering(&mut self, client_ids: &[usize], numclusters: usize) -> Result<()> {
fn compute_distribution_distance(&self, dist1: &[f32], dist2: &[f32]) -> f32 {
let mut distance = 0.0;
for (p, q) in dist1.iter().zip(dist2.iter()) {
if *p > 0.0 && *q > 0.0 {
distance += p * (p / q).ln();
distance
fn compute_loss(
&self,
model: &Sequential<f32>,
data: &ArrayView2<f32>,
labels: &ArrayView1<usize>,
) -> Result<f32> {
Ok(0.5) pub fn evaluate_personalization(
round: usize,
client_evaluations: &[(usize, f32, f32)], ) -> PersonalizationRound {
let mut pre_performance = HashMap::new();
let mut post_performance = HashMap::new();
let mut improvements = HashMap::new();
for &(client_id, pre_perf, post_perf) in client_evaluations {
pre_performance.insert(client_id, pre_perf);
post_performance.insert(client_id, post_perf);
improvements.insert(client_id, post_perf - pre_perf);
let round_info = PersonalizationRound {
round,
pre_personalization_performance: pre_performance,
post_personalization_performance: post_performance,
improvements,
self.personalization_history.push(round_info.clone());
round_info
pub fn get_personalization_stats(&self) -> PersonalizationStats {
if self.personalization_history.is_empty() {
return PersonalizationStats::default();
let latest_round = self.personalization_history.last().expect("Operation failed");
let avg_improvement = latest_round.improvements.values().sum::<f32>()
/ latest_round.improvements.len() as f32;
let total_clients_personalized = self.client_models.len();
PersonalizationStats {
average_improvement: avg_improvement,
clients_personalized: total_clients_personalized,
total_rounds: self.personalization_history.len(),
cluster_assignments: self.cluster_assignments.clone(),
/// Aggregate view of personalization progress across all rounds.
#[derive(Debug, Default)]
pub struct PersonalizationStats {
    /// Mean per-client improvement in the latest round.
    pub average_improvement: f32,
    /// Number of clients holding a personalized model.
    pub clients_personalized: usize,
    /// Total personalization rounds recorded.
    pub total_rounds: usize,
    /// Snapshot of client id -> cluster id assignments.
    pub cluster_assignments: HashMap<usize, usize>,
}
pub struct PersonalizedAggregation {
global_weight: f32,
personal_weight: f32,
personalizer: PersonalizedFL,
impl PersonalizedAggregation {
pub fn new(
global_weight: f32,
personal_weight: f32,
strategy: PersonalizationStrategy,
) -> Self {
global_weight,
personal_weight,
personalizer: PersonalizedFL::new(strategy),
impl AggregationStrategy for PersonalizedAggregation {
fn aggregate(&mut self, updates: &[ClientUpdate], weights: &[f32]) -> Result<Vec<Array2<f32>>> {
let num_tensors = updates[0].weight_updates.len();
let mut aggregated = Vec::with_capacity(num_tensors);
for tensor_idx in 0..num_tensors {
let shape = updates[0].weight_updates[tensor_idx].shape();
let mut weighted_sum = Array2::zeros((shape[0], shape[1]));
for (update, &weight) in updates.iter().zip(weights.iter()) {
if tensor_idx < update.weight_updates.len() {
weighted_sum = weighted_sum + weight * &update.weight_updates[tensor_idx];
weighted_sum = weighted_sum * (self.global_weight + self.personal_weight);
aggregated.push(weighted_sum);
Ok(aggregated)
fn name(&self) -> &str {
"PersonalizedAggregation"
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_personalized_fl_creation() {
        let strategy = PersonalizationStrategy::FineTuning {
            epochs: 5,
            learning_rate: 0.01,
        };
        let pfl = PersonalizedFL::new(strategy);
        assert_eq!(pfl.client_models.len(), 0);
    }

    #[test]
    fn test_clustering_strategy() {
        let strategy = PersonalizationStrategy::Clustering {
            num_clusters: 3,
            method: ClusteringMethod::KMeansParameters,
        };
        let pfl = PersonalizedFL::new(strategy);
        assert_eq!(pfl.cluster_assignments.len(), 0);
    }

    #[test]
    fn test_personalized_aggregation() {
        let strategy = PersonalizationStrategy::FineTuning {
            epochs: 5,
            learning_rate: 0.01,
        };
        let mut aggregator = PersonalizedAggregation::new(0.7, 0.3, strategy);
        let updates = vec![ClientUpdate {
            client_id: 0,
            weight_updates: vec![Array2::ones((2, 2))],
            num_samples: 100,
            loss: 0.5,
            accuracy: 0.9,
        }];
        let weights = vec![1.0];
        let result = aggregator
            .aggregate(&updates, &weights)
            .expect("Operation failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].shape(), &[2, 2]);
    }
}