optirs-core 0.3.1

OptiRS core optimization algorithms and utilities
Documentation
use std::fmt::Debug;
// Secure Aggregation Module
//
// This module implements secure aggregation protocols for federated learning,
// including secure multi-party computation techniques to aggregate client updates
// while preserving individual client privacy.

use super::super::moment_accountant::MomentsAccountant;
use super::super::{AccountingMethod, DifferentialPrivacyConfig, NoiseMechanism, PrivacyBudget};
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use scirs2_core::random::{Random, Rng as SCRRng};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

/// Secure aggregation configuration
#[derive(Debug, Clone)]
pub struct SecureAggregationConfig {
    /// Enable secure aggregation
    pub enabled: bool,

    /// Minimum number of clients for aggregation
    pub min_clients: usize,

    /// Maximum number of dropouts tolerated
    pub max_dropouts: usize,

    /// Masking vector dimension
    pub masking_dimension: usize,

    /// Random seed sharing method
    pub seed_sharing: SeedSharingMethod,

    /// Quantization bits for compressed aggregation
    pub quantization_bits: Option<u8>,

    /// Enable differential privacy on aggregated result
    pub aggregate_dp: bool,
}

/// Seed sharing methods for secure aggregation
#[derive(Debug, Clone, Copy)]
pub enum SeedSharingMethod {
    /// Shamir secret sharing
    ShamirSecretSharing,

    /// Threshold encryption
    ThresholdEncryption,

    /// Distributed key generation
    DistributedKeyGeneration,
}

/// Secure aggregation protocol implementation
pub struct SecureAggregator<T: Float + Debug + Send + Sync + 'static> {
    config: SecureAggregationConfig,
    client_masks: HashMap<String, Array1<T>>,
    shared_randomness: Arc<std::sync::Mutex<u64>>,
    aggregation_threshold: usize,
    round_keys: Vec<u64>,
}

/// Secure aggregation plan
#[derive(Debug, Clone)]
pub struct SecureAggregationPlan {
    pub round_seed: u64,
    pub participating_clients: Vec<String>,
    pub min_threshold: usize,
    pub masking_enabled: bool,
}

impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
    SecureAggregator<T>
{
    pub fn new(config: SecureAggregationConfig) -> Result<Self> {
        let min_clients = config.min_clients;
        Ok(Self {
            config,
            client_masks: HashMap::new(),
            shared_randomness: Arc::new(std::sync::Mutex::new(0u64)),
            aggregation_threshold: min_clients,
            round_keys: Vec::new(),
        })
    }

    pub fn prepare_round(&mut self, selectedclients: &[String]) -> Result<SecureAggregationPlan> {
        // Generate round-specific keys
        let mut seed = self.shared_randomness.lock().expect("lock poisoned");
        *seed = seed.wrapping_add(1);
        let round_seed = *seed;
        self.round_keys.push(round_seed);

        // Generate client masks (simplified)
        self.client_masks.clear();
        for clientid in selectedclients.iter() {
            let mut client_rng = Random::default();
            let mask_size = self.config.masking_dimension;

            let mask = Array1::from_iter(
                (0..mask_size)
                    .map(|_| T::from(client_rng.gen_range(-1.0..1.0)).expect("unwrap failed")),
            );

            self.client_masks.insert(clientid.clone(), mask);
        }

        Ok(SecureAggregationPlan {
            round_seed,
            participating_clients: selectedclients.to_vec(),
            min_threshold: self.config.min_clients,
            masking_enabled: true,
        })
    }

    pub fn aggregate_with_masks(
        &self,
        clientupdates: &HashMap<String, Array1<T>>,
        _aggregation_plan: &SecureAggregationPlan,
    ) -> Result<Array1<T>> {
        if clientupdates.len() < self.aggregation_threshold {
            return Err(OptimError::InvalidConfig(
                "Insufficient clients for secure aggregation".to_string(),
            ));
        }

        // Simplified secure aggregation (in practice, would use more sophisticated protocols)
        let first_update = clientupdates.values().next().expect("unwrap failed");
        let mut aggregated = Array1::zeros(first_update.len());

        for (clientid, update) in clientupdates {
            if let Some(mask) = self.client_masks.get(clientid) {
                // Apply mask (simplified - real implementation would be more complex)
                let masked_update = if update.len() == mask.len() {
                    update + mask
                } else {
                    update.clone() // Fallback if dimensions don't match
                };
                aggregated = aggregated + masked_update;
            } else {
                aggregated = aggregated + update;
            }
        }

        // Remove aggregated masks (simplified)
        let num_clients = T::from(clientupdates.len()).expect("unwrap failed");
        aggregated = aggregated / num_clients;

        Ok(aggregated)
    }

    /// Get current configuration
    pub fn config(&self) -> &SecureAggregationConfig {
        &self.config
    }

    /// Get aggregation threshold
    pub fn aggregation_threshold(&self) -> usize {
        self.aggregation_threshold
    }

    /// Check if secure aggregation is enabled
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
}

impl Default for SecureAggregationConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            min_clients: 10,
            max_dropouts: 5,
            masking_dimension: 1000,
            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
            quantization_bits: None,
            aggregate_dp: true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_secure_aggregation_config() {
        let config = SecureAggregationConfig {
            enabled: true,
            min_clients: 5,
            max_dropouts: 2,
            masking_dimension: 100,
            seed_sharing: SeedSharingMethod::ShamirSecretSharing,
            quantization_bits: Some(8),
            aggregate_dp: true,
        };

        assert!(config.enabled);
        assert_eq!(config.min_clients, 5);
        assert_eq!(config.max_dropouts, 2);
    }

    #[test]
    fn test_secure_aggregator_creation() {
        let config = SecureAggregationConfig::default();
        let aggregator = SecureAggregator::<f64>::new(config.clone());

        assert!(aggregator.is_ok());
        let agg = aggregator.expect("unwrap failed");
        assert_eq!(agg.aggregation_threshold(), config.min_clients);
        assert!(agg.is_enabled());
    }

    #[test]
    fn test_secure_aggregation_plan() {
        let config = SecureAggregationConfig::default();
        let mut aggregator = SecureAggregator::<f64>::new(config).expect("unwrap failed");

        let clients = vec!["client1".to_string(), "client2".to_string()];
        let plan = aggregator.prepare_round(&clients);

        assert!(plan.is_ok());
        let plan = plan.expect("unwrap failed");
        assert_eq!(plan.participating_clients.len(), 2);
        assert!(plan.masking_enabled);
    }
}