scirs2-neural 0.4.3

Neural network building blocks module for SciRS2 (scirs2-neural) - Minimal Version
Documentation
//! Federated learning client implementation

use crate::error::Result;
use crate::federated::ClientUpdate;
use crate::models::sequential::Sequential;
use scirs2_core::ndarray::prelude::*;

/// Configuration for a federated client.
///
/// Passed by value to `FederatedClient::new`, which stores it inside the client.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Unique client identifier; copied into every `ClientUpdate` the client returns.
    pub client_id: usize,
    /// Number of local training epochs run per call to `train_on_local_data`.
    pub local_epochs: usize,
    /// Local batch size: how many samples each mini-batch of an epoch holds.
    /// Expected to be greater than zero.
    pub batch_size: usize,
    /// Learning rate.
    /// NOTE(review): not read by any code visible in this file — confirm where it is consumed.
    pub learning_rate: f32,
    /// Enable differential privacy: when `true`, weight updates are clipped and
    /// perturbed with Gaussian noise before being returned.
    pub enable_privacy: bool,
    /// Privacy budget (epsilon). Only consulted when `enable_privacy` is `true`;
    /// `None` falls back to `10.0`.
    pub privacy_budget: Option<f64>,
}

/// Federated learning client.
///
/// Holds the client's configuration, a per-round training history and, when
/// differential privacy is enabled, the accountant tracking spent epsilon.
pub struct FederatedClient {
    /// Configuration supplied at construction time.
    config: ClientConfig,
    /// Local model (optional, for stateful clients).
    /// Initialised to `None` by `new` and not read by the code in this file,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    local_model: Option<Sequential<f32>>,
    /// Training history: one entry is appended per `train_on_local_data` call.
    history: Vec<LocalTrainingRound>,
    /// Privacy accountant; `Some` only when the client was created with
    /// `enable_privacy` set.
    privacy_accountant: Option<PrivacyAccountant>,
}

/// Local training round information.
///
/// One record is stored per `train_on_local_data` call and later aggregated
/// by `get_statistics`.
// `round` is written but never read back in this file, so the lint is silenced.
#[allow(dead_code)]
struct LocalTrainingRound {
    /// Zero-based round index (the history length at the time of recording).
    round: usize,
    /// Loss averaged over the local epochs of this round.
    loss: f32,
    /// Accuracy averaged over the local epochs of this round.
    accuracy: f32,
    /// Number of samples in the data set used for this round.
    samples_processed: usize,
}

/// Privacy accountant for differential privacy.
///
/// Created by `FederatedClient::new` when privacy is enabled and updated each
/// time noise is added to a set of weight updates.
struct PrivacyAccountant {
    /// Cumulative epsilon charged so far; starts at `0.0`.
    epsilon_spent: f64,
    /// Delta parameter used when computing the Gaussian noise scale
    /// (set to `1e-5` at construction).
    delta: f64,
    /// Total epsilon budget, taken from `ClientConfig::privacy_budget`
    /// (or `10.0` when that is `None`).
    max_epsilon: f64,
}

impl FederatedClient {
    /// Create a new federated client
    pub fn new(config: ClientConfig) -> Result<Self> {
        let privacy_accountant = if config.enable_privacy {
            Some(PrivacyAccountant {
                epsilon_spent: 0.0,
                delta: 1e-5,
                max_epsilon: config.privacy_budget.unwrap_or(10.0),
            })
        } else {
            None
        };
        Ok(Self {
            config,
            local_model: None,
            history: Vec::new(),
            privacy_accountant,
        })
    }

    /// Train on local data
    pub fn train_on_local_data(
        &mut self,
        global_weights: &[Array2<f32>],
        data: &ArrayView2<f32>,
        labels: &ArrayView1<usize>,
    ) -> Result<ClientUpdate> {
        let num_samples = data.shape()[0];
        let mut total_loss = 0.0;
        let mut correct_predictions = 0;

        // Local training epochs
        for _epoch in 0..self.config.local_epochs {
            let (epoch_loss, epoch_acc) = self.train_epoch(global_weights, data, labels)?;
            total_loss += epoch_loss;
            correct_predictions += (epoch_acc * num_samples as f32) as usize;
        }

        // Calculate weight updates (difference from global weights)
        let weight_updates = self.calculate_weight_updates(global_weights)?;

        // Apply differential privacy if enabled
        let weight_updates = if self.config.enable_privacy {
            self.apply_differential_privacy(weight_updates)?
        } else {
            weight_updates
        };

        let avg_loss = total_loss / self.config.local_epochs as f32;
        let avg_accuracy =
            correct_predictions as f32 / (num_samples * self.config.local_epochs) as f32;

        // Record training round
        self.history.push(LocalTrainingRound {
            round: self.history.len(),
            loss: avg_loss,
            accuracy: avg_accuracy,
            samples_processed: num_samples,
        });

        Ok(ClientUpdate {
            client_id: self.config.client_id,
            weight_updates,
            num_samples,
            loss: avg_loss,
            accuracy: avg_accuracy,
        })
    }

    /// Train for one epoch
    fn train_epoch(
        &self,
        global_weights: &[Array2<f32>],
        data: &ArrayView2<f32>,
        labels: &ArrayView1<usize>,
    ) -> Result<(f32, f32)> {
        let num_samples = data.shape()[0];
        let num_batches = num_samples.div_ceil(self.config.batch_size);
        let mut total_loss = 0.0;
        let mut correct = 0;

        // Shuffle indices
        let mut indices: Vec<usize> = (0..num_samples).collect();
        use scirs2_core::random::rng;
        use scirs2_core::random::seq::SliceRandom;
        let mut rng_inst = rng();
        indices.shuffle(&mut rng_inst);

        for batch_idx in 0..num_batches {
            let start = batch_idx * self.config.batch_size;
            let end = ((batch_idx + 1) * self.config.batch_size).min(num_samples);
            let batch_indices = &indices[start..end];
            let batch_data = self.get_batch_data(data, batch_indices);
            let batch_labels = self.get_batch_labels(labels, batch_indices);

            // Simulate forward pass using global_weights shape info
            let _dummy_output: f32 = global_weights
                .iter()
                .map(|w| w.iter().sum::<f32>())
                .sum::<f32>()
                / global_weights.len().max(1) as f32;

            let batch_loss = self.calculate_loss_from_data(&batch_data, &batch_labels)?;
            total_loss += batch_loss * batch_indices.len() as f32;

            let batch_correct = self.calculate_correct_from_data(&batch_data, &batch_labels)?;
            correct += batch_correct;
        }

        let avg_loss = total_loss / num_samples as f32;
        let accuracy = correct as f32 / num_samples as f32;
        Ok((avg_loss, accuracy))
    }

    /// Gather the rows of `data` named by `indices` into a new owned matrix.
    ///
    /// Row `i` of the result is a copy of row `indices[i]` of `data`, so the
    /// output has `indices.len()` rows and as many columns as `data`.
    ///
    /// # Panics
    ///
    /// Panics if any index is not a valid row of `data`.
    pub fn get_batch_data(&self, data: &ArrayView2<f32>, indices: &[usize]) -> Array2<f32> {
        let mut gathered = Array2::zeros((indices.len(), data.shape()[1]));
        for (mut dst_row, &src) in gathered.rows_mut().into_iter().zip(indices) {
            dst_row.assign(&data.row(src));
        }
        gathered
    }

    /// Get batch labels
    fn get_batch_labels(&self, labels: &ArrayView1<usize>, indices: &[usize]) -> Array1<usize> {
        let batch_size = indices.len();
        let mut batch = Array1::zeros(batch_size);
        for (i, &idx) in indices.iter().enumerate() {
            batch[i] = labels[idx];
        }
        batch
    }

    /// Placeholder loss for a batch (simplified stand-in for cross-entropy).
    ///
    /// The batch contents are ignored: the result is `0.5` divided by the
    /// number of labels, with an empty batch treated as size one.
    fn calculate_loss_from_data(&self, _data: &Array2<f32>, labels: &Array1<usize>) -> Result<f32> {
        let divisor = match labels.len() {
            0 => 1,
            n => n,
        };
        Ok(0.5_f32 / divisor as f32)
    }

    /// Placeholder count of "correct" predictions in a batch (simplified).
    ///
    /// For each row, the position of its largest value (the last one on ties;
    /// incomparable values are treated as equal) is reduced modulo
    /// `label + 1` and compared with the row's label; an empty row predicts
    /// `0`. The number of matches is then capped at `batch_size / 2 + 1`.
    fn calculate_correct_from_data(
        &self,
        data: &Array2<f32>,
        labels: &Array1<usize>,
    ) -> Result<usize> {
        // Upper bound on the reported count for this placeholder metric.
        let cap = labels.len() / 2 + 1;

        let mut hits = 0_usize;
        for (row, &label) in data.rows().into_iter().zip(labels.iter()) {
            let argmax = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx);
            let predicted = match argmax {
                Some(idx) => idx % (label + 1),
                None => 0,
            };
            if predicted == label {
                hits += 1;
            }
        }

        Ok(hits.min(cap))
    }

    /// Produce simulated weight updates, one matrix per global weight matrix.
    ///
    /// Each "local" weight matrix is simulated as the global one shifted by
    /// `0.01`; the returned update is the local matrix minus the global one.
    fn calculate_weight_updates(&self, global_weights: &[Array2<f32>]) -> Result<Vec<Array2<f32>>> {
        let mut deltas = Vec::with_capacity(global_weights.len());
        for global_w in global_weights {
            // Simulated local training result.
            let simulated_local = global_w + 0.01_f32;
            deltas.push(simulated_local - global_w);
        }
        Ok(deltas)
    }

    /// Apply differential privacy to weight updates
    fn apply_differential_privacy(
        &mut self,
        mut updates: Vec<Array2<f32>>,
    ) -> Result<Vec<Array2<f32>>> {
        if let Some(ref mut accountant) = self.privacy_accountant {
            // Clip gradients
            let clip_threshold = 1.0_f32;
            for update in &mut updates {
                let norm = update.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > clip_threshold {
                    *update *= clip_threshold / norm;
                }
            }
            // Add Gaussian noise
            use scirs2_core::random::{Distribution, Normal};
            let noise_scale = (clip_threshold as f64 * (2.0 * (1.0 / accountant.delta).ln()).sqrt()
                / accountant.max_epsilon) as f32;
            let noise_dist = Normal::new(0.0_f32, noise_scale)
                .map_err(|e| crate::error::NeuralError::InferenceError(format!("{e}")))?;
            let mut rng_inst = scirs2_core::random::rng();
            for update in updates.iter_mut() {
                for elem in update.iter_mut() {
                    *elem += noise_dist.sample(&mut rng_inst);
                }
            }
            // Update privacy budget
            let epsilon_per_step = accountant.max_epsilon / 100.0;
            accountant.epsilon_spent += epsilon_per_step;
        }
        Ok(updates)
    }

    /// Summarise the client's training history.
    ///
    /// Reports the number of recorded rounds, the total samples processed,
    /// the mean loss and mean accuracy across rounds (both `0.0` when no
    /// round has been recorded) and, if a privacy accountant exists, the
    /// epsilon spent so far.
    pub fn get_statistics(&self) -> ClientStatistics {
        let rounds = self.history.len();

        let (average_loss, average_accuracy) = if rounds == 0 {
            (0.0, 0.0)
        } else {
            let denom = rounds as f32;
            let loss_sum: f32 = self.history.iter().map(|r| r.loss).sum();
            let accuracy_sum: f32 = self.history.iter().map(|r| r.accuracy).sum();
            (loss_sum / denom, accuracy_sum / denom)
        };

        let privacy_spent = match self.privacy_accountant.as_ref() {
            Some(accountant) => Some(accountant.epsilon_spent),
            None => None,
        };

        ClientStatistics {
            rounds_participated: rounds,
            total_samples_processed: self.history.iter().map(|r| r.samples_processed).sum(),
            average_loss,
            average_accuracy,
            privacy_spent,
        }
    }
}

/// Client statistics: an aggregate view of a client's training history.
#[derive(Debug, Clone, PartialEq)]
pub struct ClientStatistics {
    /// Number of training rounds recorded for the client.
    pub rounds_participated: usize,
    /// Sum of the sample counts of all recorded rounds.
    pub total_samples_processed: usize,
    /// Mean of the per-round losses (`0.0` when no round was recorded).
    pub average_loss: f32,
    /// Mean of the per-round accuracies (`0.0` when no round was recorded).
    pub average_accuracy: f32,
    /// Epsilon spent so far, or `None` when the client has no privacy accountant.
    pub privacy_spent: Option<f64>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a privacy-disabled configuration for client 0 with the given
    /// epoch count and batch size.
    fn plain_config(local_epochs: usize, batch_size: usize) -> ClientConfig {
        ClientConfig {
            client_id: 0,
            local_epochs,
            batch_size,
            learning_rate: 0.01,
            enable_privacy: false,
            privacy_budget: None,
        }
    }

    /// A client built from a valid configuration keeps that configuration.
    #[test]
    fn test_client_creation() {
        let client =
            FederatedClient::new(plain_config(5, 32)).expect("FederatedClient::new failed");
        assert_eq!(client.config.client_id, 0);
    }

    /// Batch extraction copies exactly the requested rows, in request order.
    #[test]
    fn test_batch_extraction() {
        let client =
            FederatedClient::new(plain_config(1, 2)).expect("FederatedClient::new failed");

        // 4x3 matrix holding 1.0..=12.0 in row-major order.
        let values: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let data = Array2::from_shape_vec((4, 3), values).expect("from_shape_vec failed");

        let picked: [usize; 2] = [1, 3];
        let batch = client.get_batch_data(&data.view(), &picked);

        assert_eq!(batch.shape(), &[2, 3]);
        assert!((batch[[0, 0]] - 4.0).abs() < 1e-5);
        assert!((batch[[1, 0]] - 10.0).abs() < 1e-5);
    }
}