rustkernel_ml/
federated.rs

1//! Federated Learning kernels.
2//!
3//! This module provides privacy-preserving distributed learning algorithms:
4//! - SecureAggregation - Privacy-preserving model aggregation
5
6use rand::rngs::StdRng;
7use rand::{Rng, SeedableRng, rng};
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use serde::{Deserialize, Serialize};
10
11// ============================================================================
12// Secure Aggregation Kernel
13// ============================================================================
14
/// Configuration for secure aggregation.
///
/// Read by `SecureAggregation::aggregate` and the noise helper it calls. The values
/// produced by the `Default` implementation are listed there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureAggConfig {
    /// Minimum number of participants required.
    ///
    /// `aggregate` returns an empty result when fewer updates than this are supplied,
    /// or when fewer than this remain after validation.
    pub min_participants: usize,
    /// Maximum number of participants.
    ///
    /// NOTE(review): no code visible in this file reads this field — confirm where
    /// (or whether) the limit is enforced.
    pub max_participants: usize,
    /// Privacy budget (differential privacy epsilon).
    ///
    /// The noise standard deviation is divided by this value, so a smaller epsilon
    /// produces larger noise. The value is not validated before use.
    pub epsilon: f64,
    /// Clipping threshold for gradients.
    ///
    /// Upper bound on the L2 norm of each participant's parameter vector; also used
    /// as the sensitivity when the noise scale is computed.
    pub clip_threshold: f64,
    /// Whether to use differential privacy noise.
    ///
    /// When `false`, the aggregate is returned without perturbation and the reported
    /// noise scale is `0.0`.
    pub add_noise: bool,
    /// Seed for reproducibility.
    ///
    /// `Some(seed)` seeds the noise generator deterministically; `None` seeds it from
    /// `rand::rng()`.
    pub seed: Option<u64>,
}
31
32impl Default for SecureAggConfig {
33    fn default() -> Self {
34        Self {
35            min_participants: 3,
36            max_participants: 100,
37            epsilon: 1.0,
38            clip_threshold: 1.0,
39            add_noise: true,
40            seed: None,
41        }
42    }
43}
44
/// A participant's model update (gradient or weights).
///
/// One of these is supplied per participant to `SecureAggregation::aggregate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParticipantUpdate {
    /// Participant identifier.
    ///
    /// Copied verbatim into the included / excluded lists of the result.
    pub participant_id: String,
    /// Model parameters/gradients.
    ///
    /// Must have the same length as the first update passed to `aggregate`;
    /// updates of a different length are excluded from the round.
    pub parameters: Vec<f64>,
    /// Number of local samples used.
    ///
    /// Determines this participant's weight in the average
    /// (`sample_count / total_samples`).
    pub sample_count: usize,
    /// Local loss value (optional).
    ///
    /// Reported losses are averaged into `AggregationResult::average_loss`;
    /// `None` values are skipped.
    pub local_loss: Option<f64>,
}
57
/// Result of secure aggregation.
///
/// When a round is refused (no updates, or too few participants) the parameter
/// vector is empty and `participant_count` is `0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationResult {
    /// Aggregated parameters.
    ///
    /// Sample-count-weighted average of the clipped updates, plus noise when enabled.
    /// Empty when the round was refused.
    pub aggregated_params: Vec<f64>,
    /// Number of participants included.
    pub participant_count: usize,
    /// Total samples across participants.
    ///
    /// Sum of `sample_count` over the included participants.
    pub total_samples: usize,
    /// Average loss if reported.
    ///
    /// `None` when no participant supplied a `local_loss`.
    pub average_loss: Option<f64>,
    /// Privacy guarantee achieved.
    pub privacy_guarantee: PrivacyGuarantee,
    /// Participants that were included.
    pub included_participants: Vec<String>,
    /// Participants that were excluded (if any).
    ///
    /// Lists every supplied participant when the round is refused for having too few
    /// participants.
    pub excluded_participants: Vec<String>,
}
76
/// Privacy guarantee provided.
///
/// Filled in by `SecureAggregation::aggregate` to describe the round it just ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyGuarantee {
    /// Differential privacy epsilon.
    ///
    /// The configured epsilon for a completed round; `f64::INFINITY` when the round
    /// was refused for having too few participants.
    pub epsilon: f64,
    /// Differential privacy delta.
    ///
    /// `1e-5` for a completed round; `1.0` when the round was refused for having too
    /// few participants.
    pub delta: f64,
    /// Whether secure aggregation was used.
    ///
    /// `true` only when an aggregate was actually produced.
    pub secure_aggregation: bool,
    /// Noise scale applied.
    ///
    /// Standard deviation of the Gaussian noise added to each parameter, or `0.0`
    /// when no noise was added.
    pub noise_scale: f64,
}
89
/// Mask for secure aggregation protocol.
///
/// Holds a pseudo-random vector together with the seed it was generated from and the
/// participant it belongs to.
#[derive(Debug, Clone)]
// NOTE(review): nothing visible in this file constructs or reads a `SecureMask`, which
// is presumably why dead-code warnings are silenced — confirm against the rest of the crate.
#[allow(dead_code)]
struct SecureMask {
    // Identifier of the participant this mask belongs to.
    participant_id: String,
    // One mask value per model parameter.
    mask: Vec<f64>,
    // Seed the mask values were generated from.
    seed: u64,
}
98
99#[allow(dead_code)]
100impl SecureMask {
101    fn generate(participant_id: &str, size: usize, seed: u64) -> Self {
102        let mut rng = StdRng::seed_from_u64(seed);
103        let mask: Vec<f64> = (0..size).map(|_| rng.random_range(-1.0..1.0)).collect();
104        Self {
105            participant_id: participant_id.to_string(),
106            mask,
107            seed,
108        }
109    }
110}
111
/// Secure Aggregation kernel.
///
/// Implements privacy-preserving aggregation of model updates from
/// multiple participants: each update is clipped to a bounded L2 norm, the
/// clipped updates are combined into a sample-count-weighted average, and
/// Gaussian noise is optionally added for differential privacy.
///
/// NOTE(review): a `SecureMask` type is defined in this file, but no code visible
/// here applies masks during aggregation — confirm whether pairwise masking is
/// implemented elsewhere before relying on it.
#[derive(Debug, Clone)]
pub struct SecureAggregation {
    // Kernel metadata returned through the `GpuKernel` trait.
    metadata: KernelMetadata,
}
121
122impl Default for SecureAggregation {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl SecureAggregation {
129    /// Create a new Secure Aggregation kernel.
130    #[must_use]
131    pub fn new() -> Self {
132        Self {
133            metadata: KernelMetadata::batch("ml/secure-aggregation", Domain::StatisticalML)
134                .with_description("Privacy-preserving federated model aggregation")
135                .with_throughput(1_000)
136                .with_latency_us(500.0),
137        }
138    }
139
140    /// Aggregate updates from multiple participants.
141    pub fn aggregate(updates: &[ParticipantUpdate], config: &SecureAggConfig) -> AggregationResult {
142        if updates.is_empty() {
143            return AggregationResult {
144                aggregated_params: Vec::new(),
145                participant_count: 0,
146                total_samples: 0,
147                average_loss: None,
148                privacy_guarantee: PrivacyGuarantee {
149                    epsilon: config.epsilon,
150                    delta: 1e-5,
151                    secure_aggregation: false,
152                    noise_scale: 0.0,
153                },
154                included_participants: Vec::new(),
155                excluded_participants: Vec::new(),
156            };
157        }
158
159        // Check minimum participants
160        if updates.len() < config.min_participants {
161            return AggregationResult {
162                aggregated_params: Vec::new(),
163                participant_count: 0,
164                total_samples: 0,
165                average_loss: None,
166                privacy_guarantee: PrivacyGuarantee {
167                    epsilon: f64::INFINITY,
168                    delta: 1.0,
169                    secure_aggregation: false,
170                    noise_scale: 0.0,
171                },
172                included_participants: Vec::new(),
173                excluded_participants: updates.iter().map(|u| u.participant_id.clone()).collect(),
174            };
175        }
176
177        let param_size = updates[0].parameters.len();
178        let mut included = Vec::new();
179        let mut excluded = Vec::new();
180
181        // Clip and validate updates
182        let clipped_updates: Vec<(String, Vec<f64>, usize)> = updates
183            .iter()
184            .filter_map(|u| {
185                if u.parameters.len() != param_size {
186                    excluded.push(u.participant_id.clone());
187                    return None;
188                }
189                included.push(u.participant_id.clone());
190                let clipped = Self::clip_update(&u.parameters, config.clip_threshold);
191                Some((u.participant_id.clone(), clipped, u.sample_count))
192            })
193            .collect();
194
195        if clipped_updates.len() < config.min_participants {
196            return AggregationResult {
197                aggregated_params: Vec::new(),
198                participant_count: 0,
199                total_samples: 0,
200                average_loss: None,
201                privacy_guarantee: PrivacyGuarantee {
202                    epsilon: f64::INFINITY,
203                    delta: 1.0,
204                    secure_aggregation: false,
205                    noise_scale: 0.0,
206                },
207                included_participants: Vec::new(),
208                excluded_participants: updates.iter().map(|u| u.participant_id.clone()).collect(),
209            };
210        }
211
212        // Compute weighted average (FedAvg style)
213        let total_samples: usize = clipped_updates.iter().map(|(_, _, s)| s).sum();
214        let mut aggregated = vec![0.0; param_size];
215
216        for (_, params, sample_count) in &clipped_updates {
217            let weight = *sample_count as f64 / total_samples as f64;
218            for (i, &p) in params.iter().enumerate() {
219                aggregated[i] += p * weight;
220            }
221        }
222
223        // Add differential privacy noise
224        let noise_scale = if config.add_noise {
225            Self::add_dp_noise(&mut aggregated, config)
226        } else {
227            0.0
228        };
229
230        // Compute average loss
231        let average_loss = {
232            let losses: Vec<f64> = updates.iter().filter_map(|u| u.local_loss).collect();
233            if losses.is_empty() {
234                None
235            } else {
236                Some(losses.iter().sum::<f64>() / losses.len() as f64)
237            }
238        };
239
240        AggregationResult {
241            aggregated_params: aggregated,
242            participant_count: clipped_updates.len(),
243            total_samples,
244            average_loss,
245            privacy_guarantee: PrivacyGuarantee {
246                epsilon: config.epsilon,
247                delta: 1e-5,
248                secure_aggregation: true,
249                noise_scale,
250            },
251            included_participants: included,
252            excluded_participants: excluded,
253        }
254    }
255
256    /// Clip update to bound sensitivity.
257    fn clip_update(params: &[f64], threshold: f64) -> Vec<f64> {
258        let norm: f64 = params.iter().map(|x| x * x).sum::<f64>().sqrt();
259        if norm <= threshold {
260            params.to_vec()
261        } else {
262            let scale = threshold / norm;
263            params.iter().map(|&x| x * scale).collect()
264        }
265    }
266
267    /// Add Gaussian noise for differential privacy.
268    fn add_dp_noise(params: &mut [f64], config: &SecureAggConfig) -> f64 {
269        // Gaussian mechanism: sigma = sensitivity * sqrt(2 * ln(1.25/delta)) / epsilon
270        let delta = 1e-5;
271        let sensitivity = config.clip_threshold;
272        let sigma = sensitivity * (2.0 * (1.25_f64 / delta).ln()).sqrt() / config.epsilon;
273
274        let mut rng = match config.seed {
275            Some(seed) => StdRng::seed_from_u64(seed),
276            None => StdRng::from_rng(&mut rng()),
277        };
278
279        for p in params.iter_mut() {
280            // Box-Muller transform for Gaussian noise
281            let u1: f64 = rng.random_range(0.0001..1.0);
282            let u2: f64 = rng.random_range(0.0..1.0);
283            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
284            *p += sigma * z;
285        }
286
287        sigma
288    }
289
290    /// Verify aggregation result (for testing).
291    pub fn verify_aggregation(_updates: &[ParticipantUpdate], result: &AggregationResult) -> bool {
292        // Basic sanity checks
293        if result.participant_count == 0 {
294            return result.aggregated_params.is_empty();
295        }
296
297        if result.aggregated_params.is_empty() {
298            return false;
299        }
300
301        // Check participant counts match
302        result.included_participants.len() == result.participant_count
303    }
304
305    /// Simulate a federated learning round.
306    pub fn simulate_round(
307        _global_model: &[f64],
308        local_updates: &[Vec<f64>],
309        sample_counts: &[usize],
310        config: &SecureAggConfig,
311    ) -> AggregationResult {
312        let updates: Vec<ParticipantUpdate> = local_updates
313            .iter()
314            .zip(sample_counts.iter())
315            .enumerate()
316            .map(|(i, (params, &count))| ParticipantUpdate {
317                participant_id: format!("participant_{}", i),
318                parameters: params.clone(),
319                sample_count: count,
320                local_loss: Some(0.5), // Dummy loss
321            })
322            .collect();
323
324        Self::aggregate(&updates, config)
325    }
326}
327
328impl GpuKernel for SecureAggregation {
329    fn metadata(&self) -> &KernelMetadata {
330        &self.metadata
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a participant update from its four fields.
    fn update(
        id: &str,
        parameters: &[f64],
        sample_count: usize,
        local_loss: Option<f64>,
    ) -> ParticipantUpdate {
        ParticipantUpdate {
            participant_id: id.to_string(),
            parameters: parameters.to_vec(),
            sample_count,
            local_loss,
        }
    }

    /// Deterministic configuration: no noise, and a clip threshold high enough that the
    /// small fixtures used here are never clipped.
    fn exact_config(min_participants: usize) -> SecureAggConfig {
        SecureAggConfig {
            min_participants,
            add_noise: false,
            clip_threshold: 100.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_secure_aggregation_metadata() {
        let kernel = SecureAggregation::new();
        assert_eq!(kernel.metadata().id, "ml/secure-aggregation");
    }

    #[test]
    fn test_basic_aggregation() {
        let updates = vec![
            update("p1", &[1.0, 2.0, 3.0], 100, Some(0.5)),
            update("p2", &[2.0, 3.0, 4.0], 100, Some(0.6)),
            update("p3", &[3.0, 4.0, 5.0], 100, Some(0.7)),
        ];

        let result = SecureAggregation::aggregate(&updates, &exact_config(3));

        assert_eq!(result.participant_count, 3);
        assert_eq!(result.total_samples, 300);
        assert_eq!(result.aggregated_params.len(), 3);

        // Average should be (1+2+3)/3=2, (2+3+4)/3=3, (3+4+5)/3=4
        assert!((result.aggregated_params[0] - 2.0).abs() < 0.01);
        assert!((result.aggregated_params[1] - 3.0).abs() < 0.01);
        assert!((result.aggregated_params[2] - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_weighted_aggregation() {
        let updates = vec![
            update("p1", &[1.0], 100, None), // 1/3 weight
            update("p2", &[4.0], 200, None), // 2/3 weight
            update("p3", &[1.0], 0, None),   // 0 weight
        ];

        let result = SecureAggregation::aggregate(&updates, &exact_config(2));

        // Weighted average: (1*100 + 4*200 + 1*0) / 300 = 900/300 = 3.0
        assert!((result.aggregated_params[0] - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_mismatched_lengths_excluded() {
        let updates = vec![
            update("p1", &[1.0, 2.0], 10, Some(0.1)),
            update("p2", &[1.0, 2.0], 10, Some(0.1)),
            update("p3", &[1.0, 2.0], 10, Some(0.1)),
            update("p4", &[1.0], 10, None), // wrong dimension
        ];

        let result = SecureAggregation::aggregate(&updates, &exact_config(3));

        assert_eq!(result.participant_count, 3);
        assert_eq!(result.total_samples, 30);
        assert_eq!(result.included_participants, vec!["p1", "p2", "p3"]);
        assert_eq!(result.excluded_participants, vec!["p4"]);
        assert_eq!(result.aggregated_params.len(), 2);
    }

    #[test]
    fn test_insufficient_participants() {
        let updates = vec![update("p1", &[1.0], 100, None)];

        let config = SecureAggConfig {
            min_participants: 3,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);

        assert_eq!(result.participant_count, 0);
        assert!(result.aggregated_params.is_empty());
        assert_eq!(result.privacy_guarantee.epsilon, f64::INFINITY);
    }

    #[test]
    fn test_clipping() {
        let params = vec![3.0, 4.0]; // Norm = 5
        let clipped = SecureAggregation::clip_update(&params, 1.0);

        let norm: f64 = clipped.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_dp_noise_added() {
        let updates = vec![
            update("p1", &[1.0, 1.0], 100, None),
            update("p2", &[1.0, 1.0], 100, None),
            update("p3", &[1.0, 1.0], 100, None),
        ];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: true,
            epsilon: 1.0,
            seed: Some(42),
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);

        // sigma = clip_threshold * sqrt(2 * ln(1.25 / 1e-5)) / epsilon, with both set to 1.0
        let expected_sigma = (2.0 * (1.25_f64 / 1e-5).ln()).sqrt();
        assert!((result.privacy_guarantee.noise_scale - expected_sigma).abs() < 1e-9);

        // Without noise every coordinate would be 1/sqrt(2): each update is clipped from
        // norm sqrt(2) down to unit norm. Gaussian noise is unbounded, so only check that
        // the output was actually perturbed.
        let noiseless = std::f64::consts::FRAC_1_SQRT_2;
        assert!(result
            .aggregated_params
            .iter()
            .any(|p| (p - noiseless).abs() > 1e-9));

        // A fixed seed must reproduce the same noisy aggregate.
        let repeat = SecureAggregation::aggregate(&updates, &config);
        assert_eq!(result.aggregated_params, repeat.aggregated_params);
    }

    #[test]
    fn test_empty_updates() {
        let config = SecureAggConfig::default();
        let result = SecureAggregation::aggregate(&[], &config);

        assert!(result.aggregated_params.is_empty());
        assert_eq!(result.participant_count, 0);
    }

    #[test]
    fn test_simulate_round() {
        let global = vec![0.0, 0.0, 0.0];
        let local_updates = vec![
            vec![0.1, 0.2, 0.3],
            vec![0.2, 0.3, 0.4],
            vec![0.3, 0.4, 0.5],
        ];
        let sample_counts = vec![100, 100, 100];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: false,
            ..Default::default()
        };

        let result =
            SecureAggregation::simulate_round(&global, &local_updates, &sample_counts, &config);

        assert_eq!(result.participant_count, 3);
        assert!(result.average_loss.is_some());
    }

    #[test]
    fn test_verify_aggregation() {
        let updates = vec![
            update("p1", &[1.0], 100, None),
            update("p2", &[2.0], 100, None),
            update("p3", &[3.0], 100, None),
        ];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: false,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);
        assert!(SecureAggregation::verify_aggregation(&updates, &result));
    }
}