use crate::error::Result;
use crate::federated::ClientUpdate;
use crate::models::sequential::Sequential;
use scirs2_core::ndarray::prelude::*;
/// Configuration for a single federated-learning client.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Identifier reported back to the server with each update.
    pub client_id: usize,
    /// Number of local training epochs per federated round.
    pub local_epochs: usize,
    /// Mini-batch size for local training.
    pub batch_size: usize,
    /// Local learning rate. NOTE(review): not read anywhere in this
    /// file's simulated trainer — confirm intended use before relying on it.
    pub learning_rate: f32,
    /// Whether to apply differential privacy to outgoing weight updates.
    pub enable_privacy: bool,
    /// Total epsilon budget; `None` falls back to a default of 10.0.
    pub privacy_budget: Option<f64>,
}
/// A participant in federated training: holds configuration, per-round
/// training history, and (optionally) a differential-privacy accountant.
pub struct FederatedClient {
    /// Static configuration for this client.
    config: ClientConfig,
    /// Local copy of the model; set to `None` at construction and never
    /// assigned in this file (training is simulated).
    #[allow(dead_code)]
    local_model: Option<Sequential<f32>>,
    /// Per-round training results, in round order.
    history: Vec<LocalTrainingRound>,
    /// Present only when `config.enable_privacy` is true.
    privacy_accountant: Option<PrivacyAccountant>,
}
/// Summary of one completed round of local training.
#[allow(dead_code)]
struct LocalTrainingRound {
    /// Zero-based round index (position in the client's history).
    round: usize,
    /// Average training loss over the round's epochs.
    loss: f32,
    /// Average training accuracy over the round's epochs.
    accuracy: f32,
    /// Number of local samples seen this round.
    samples_processed: usize,
}
/// Tracks differential-privacy budget consumption for a client.
struct PrivacyAccountant {
    /// Total epsilon consumed so far.
    epsilon_spent: f64,
    /// Delta parameter of the (epsilon, delta)-DP guarantee (fixed at
    /// 1e-5 by `FederatedClient::new`).
    delta: f64,
    /// Maximum epsilon this client is allowed to spend.
    max_epsilon: f64,
}
impl FederatedClient {
pub fn new(config: ClientConfig) -> Result<Self> {
let privacy_accountant = if config.enable_privacy {
Some(PrivacyAccountant {
epsilon_spent: 0.0,
delta: 1e-5,
max_epsilon: config.privacy_budget.unwrap_or(10.0),
})
} else {
None
};
Ok(Self {
config,
local_model: None,
history: Vec::new(),
privacy_accountant,
})
}
pub fn train_on_local_data(
&mut self,
global_weights: &[Array2<f32>],
data: &ArrayView2<f32>,
labels: &ArrayView1<usize>,
) -> Result<ClientUpdate> {
let num_samples = data.shape()[0];
let mut total_loss = 0.0;
let mut correct_predictions = 0;
for _epoch in 0..self.config.local_epochs {
let (epoch_loss, epoch_acc) = self.train_epoch(global_weights, data, labels)?;
total_loss += epoch_loss;
correct_predictions += (epoch_acc * num_samples as f32) as usize;
}
let weight_updates = self.calculate_weight_updates(global_weights)?;
let weight_updates = if self.config.enable_privacy {
self.apply_differential_privacy(weight_updates)?
} else {
weight_updates
};
let avg_loss = total_loss / self.config.local_epochs as f32;
let avg_accuracy =
correct_predictions as f32 / (num_samples * self.config.local_epochs) as f32;
self.history.push(LocalTrainingRound {
round: self.history.len(),
loss: avg_loss,
accuracy: avg_accuracy,
samples_processed: num_samples,
});
Ok(ClientUpdate {
client_id: self.config.client_id,
weight_updates,
num_samples,
loss: avg_loss,
accuracy: avg_accuracy,
})
}
fn train_epoch(
&self,
global_weights: &[Array2<f32>],
data: &ArrayView2<f32>,
labels: &ArrayView1<usize>,
) -> Result<(f32, f32)> {
let num_samples = data.shape()[0];
let num_batches = num_samples.div_ceil(self.config.batch_size);
let mut total_loss = 0.0;
let mut correct = 0;
let mut indices: Vec<usize> = (0..num_samples).collect();
use scirs2_core::random::rng;
use scirs2_core::random::seq::SliceRandom;
let mut rng_inst = rng();
indices.shuffle(&mut rng_inst);
for batch_idx in 0..num_batches {
let start = batch_idx * self.config.batch_size;
let end = ((batch_idx + 1) * self.config.batch_size).min(num_samples);
let batch_indices = &indices[start..end];
let batch_data = self.get_batch_data(data, batch_indices);
let batch_labels = self.get_batch_labels(labels, batch_indices);
let _dummy_output: f32 = global_weights
.iter()
.map(|w| w.iter().sum::<f32>())
.sum::<f32>()
/ global_weights.len().max(1) as f32;
let batch_loss = self.calculate_loss_from_data(&batch_data, &batch_labels)?;
total_loss += batch_loss * batch_indices.len() as f32;
let batch_correct = self.calculate_correct_from_data(&batch_data, &batch_labels)?;
correct += batch_correct;
}
let avg_loss = total_loss / num_samples as f32;
let accuracy = correct as f32 / num_samples as f32;
Ok((avg_loss, accuracy))
}
pub fn get_batch_data(&self, data: &ArrayView2<f32>, indices: &[usize]) -> Array2<f32> {
let batch_size = indices.len();
let feature_dim = data.shape()[1];
let mut batch = Array2::zeros((batch_size, feature_dim));
for (i, &idx) in indices.iter().enumerate() {
batch.row_mut(i).assign(&data.row(idx));
}
batch
}
fn get_batch_labels(&self, labels: &ArrayView1<usize>, indices: &[usize]) -> Array1<usize> {
let batch_size = indices.len();
let mut batch = Array1::zeros(batch_size);
for (i, &idx) in indices.iter().enumerate() {
batch[i] = labels[idx];
}
batch
}
fn calculate_loss_from_data(&self, _data: &Array2<f32>, labels: &Array1<usize>) -> Result<f32> {
let batch_size = labels.len();
Ok(0.5_f32 / batch_size.max(1) as f32)
}
fn calculate_correct_from_data(
&self,
data: &Array2<f32>,
labels: &Array1<usize>,
) -> Result<usize> {
let batch_size = labels.len();
let correct = data
.rows()
.into_iter()
.zip(labels.iter())
.filter(|(row, &label)| {
let max_idx = row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx % (label + 1))
.unwrap_or(0);
max_idx == label
})
.count();
Ok(correct.min(batch_size / 2 + 1))
}
fn calculate_weight_updates(&self, global_weights: &[Array2<f32>]) -> Result<Vec<Array2<f32>>> {
let updates = global_weights
.iter()
.map(|global_w| {
let local_w = global_w + 0.01_f32; local_w - global_w
})
.collect();
Ok(updates)
}
fn apply_differential_privacy(
&mut self,
mut updates: Vec<Array2<f32>>,
) -> Result<Vec<Array2<f32>>> {
if let Some(ref mut accountant) = self.privacy_accountant {
let clip_threshold = 1.0_f32;
for update in &mut updates {
let norm = update.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > clip_threshold {
*update *= clip_threshold / norm;
}
}
use scirs2_core::random::{Distribution, Normal};
let noise_scale = (clip_threshold as f64 * (2.0 * (1.0 / accountant.delta).ln()).sqrt()
/ accountant.max_epsilon) as f32;
let noise_dist = Normal::new(0.0_f32, noise_scale)
.map_err(|e| crate::error::NeuralError::InferenceError(format!("{e}")))?;
let mut rng_inst = scirs2_core::random::rng();
for update in updates.iter_mut() {
for elem in update.iter_mut() {
*elem += noise_dist.sample(&mut rng_inst);
}
}
let epsilon_per_step = accountant.max_epsilon / 100.0;
accountant.epsilon_spent += epsilon_per_step;
}
Ok(updates)
}
pub fn get_statistics(&self) -> ClientStatistics {
let total_samples: usize = self.history.iter().map(|r| r.samples_processed).sum();
let avg_loss = if self.history.is_empty() {
0.0
} else {
self.history.iter().map(|r| r.loss).sum::<f32>() / self.history.len() as f32
};
let avg_accuracy = if self.history.is_empty() {
0.0
} else {
self.history.iter().map(|r| r.accuracy).sum::<f32>() / self.history.len() as f32
};
ClientStatistics {
rounds_participated: self.history.len(),
total_samples_processed: total_samples,
average_loss: avg_loss,
average_accuracy: avg_accuracy,
privacy_spent: self.privacy_accountant.as_ref().map(|a| a.epsilon_spent),
}
}
}
/// Aggregate statistics over all rounds a client has participated in.
pub struct ClientStatistics {
    /// Number of local training rounds recorded in the history.
    pub rounds_participated: usize,
    /// Total samples processed across all rounds.
    pub total_samples_processed: usize,
    /// Mean of the per-round average losses (0.0 when no rounds yet).
    pub average_loss: f32,
    /// Mean of the per-round average accuracies (0.0 when no rounds yet).
    pub average_accuracy: f32,
    /// Epsilon spent so far; `None` when privacy is disabled.
    pub privacy_spent: Option<f64>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-private test config with the given epoch/batch settings.
    fn make_config(local_epochs: usize, batch_size: usize) -> ClientConfig {
        ClientConfig {
            client_id: 0,
            local_epochs,
            batch_size,
            learning_rate: 0.01,
            enable_privacy: false,
            privacy_budget: None,
        }
    }

    #[test]
    fn test_client_creation() {
        let client =
            FederatedClient::new(make_config(5, 32)).expect("FederatedClient::new failed");
        assert_eq!(client.config.client_id, 0);
    }

    #[test]
    fn test_batch_extraction() {
        let client =
            FederatedClient::new(make_config(1, 2)).expect("FederatedClient::new failed");
        // 4 samples x 3 features, filled with 1.0..=12.0 row-major.
        let values: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let data = Array2::from_shape_vec((4, 3), values).expect("from_shape_vec failed");
        let batch = client.get_batch_data(&data.view(), &[1, 3]);
        assert_eq!(batch.shape(), &[2, 3]);
        assert!((batch[[0, 0]] - 4.0).abs() < 1e-5);
        assert!((batch[[1, 0]] - 10.0).abs() < 1e-5);
    }
}