ghostflow_nn/
federated.rs

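//! Federated Learning
//!
//! Implements federated learning algorithms:
//! - FedAvg (Federated Averaging)
//! - FedProx (Federated Proximal; simplified to damping the averaged update)
//! - FedAdam (Federated Adam; currently aggregated exactly like FedAvg)
//! - Secure aggregation (toy additive secret sharing; the aggregation itself is a plain mean)
//! - Differential privacy (gradient clipping and additive Gaussian noise)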
9
10use ghostflow_core::Tensor;
11use std::collections::HashMap;
12use rand::Rng;
13
14/// Client in federated learning
15#[derive(Debug, Clone)]
16pub struct FederatedClient {
17    /// Client ID
18    pub id: usize,
19    /// Local model parameters
20    pub parameters: HashMap<String, Tensor>,
21    /// Number of local samples
22    pub num_samples: usize,
23    /// Local learning rate
24    pub learning_rate: f32,
25}
26
27impl FederatedClient {
28    /// Create a new federated client
29    pub fn new(id: usize, parameters: HashMap<String, Tensor>, num_samples: usize) -> Self {
30        FederatedClient {
31            id,
32            parameters,
33            num_samples,
34            learning_rate: 0.01,
35        }
36    }
37    
38    /// Perform local training
39    pub fn local_train(&mut self, num_epochs: usize, _batch_size: usize) {
40        // Simplified local training
41        // In practice, this would train on local data
42        for _ in 0..num_epochs {
43            // Simulate gradient updates
44            for (_, param) in self.parameters.iter_mut() {
45                let grad = Tensor::randn(param.dims()).mul_scalar(0.01);
46                *param = param.sub(&grad).unwrap();
47            }
48        }
49    }
50    
51    /// Get model update (difference from initial parameters)
52    pub fn get_update(&self, initial_params: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
53        let mut updates = HashMap::new();
54        
55        for (name, param) in &self.parameters {
56            if let Some(initial) = initial_params.get(name) {
57                let update = param.sub(initial).unwrap();
58                updates.insert(name.clone(), update);
59            }
60        }
61        
62        updates
63    }
64    
65    /// Set parameters
66    pub fn set_parameters(&mut self, parameters: HashMap<String, Tensor>) {
67        self.parameters = parameters;
68    }
69}
70
71/// Federated server
72pub struct FederatedServer {
73    /// Global model parameters
74    pub global_parameters: HashMap<String, Tensor>,
75    /// Connected clients
76    pub clients: Vec<FederatedClient>,
77    /// Aggregation strategy
78    pub aggregation: AggregationStrategy,
79    /// Current round
80    pub round: usize,
81}
82
/// How the server merges client updates into the global model.
#[derive(Debug, Clone, Copy)]
pub enum AggregationStrategy {
    /// Federated Averaging (FedAvg): the mean of client deltas, weighted by sample count.
    FedAvg,
    /// Federated Proximal (FedProx): `mu` controls how strongly the averaged delta is
    /// damped toward the current global model.
    FedProx { mu: f32 },
    /// Federated Adam with moment decay rates `beta1` and `beta2`.
    ///
    /// NOTE(review): the server currently aggregates this exactly like `FedAvg`; the betas
    /// are not yet used.
    FedAdam { beta1: f32, beta2: f32 },
}
92
93impl FederatedServer {
94    /// Create a new federated server
95    pub fn new(global_parameters: HashMap<String, Tensor>, aggregation: AggregationStrategy) -> Self {
96        FederatedServer {
97            global_parameters,
98            clients: Vec::new(),
99            aggregation,
100            round: 0,
101        }
102    }
103    
104    /// Register a client
105    pub fn register_client(&mut self, client: FederatedClient) {
106        self.clients.push(client);
107    }
108    
109    /// Select clients for training (random sampling)
110    pub fn select_clients(&self, fraction: f32) -> Vec<usize> {
111        let num_clients = (self.clients.len() as f32 * fraction).ceil() as usize;
112        let mut rng = rand::thread_rng();
113        
114        let mut selected = Vec::new();
115        let mut available: Vec<usize> = (0..self.clients.len()).collect();
116        
117        for _ in 0..num_clients.min(available.len()) {
118            let idx = rng.gen_range(0..available.len());
119            selected.push(available.remove(idx));
120        }
121        
122        selected
123    }
124    
125    /// Distribute global model to clients
126    pub fn distribute_model(&mut self, client_ids: &[usize]) {
127        for &id in client_ids {
128            if let Some(client) = self.clients.get_mut(id) {
129                client.set_parameters(self.global_parameters.clone());
130            }
131        }
132    }
133    
134    /// Aggregate client updates using FedAvg
135    fn aggregate_fedavg(&self, client_ids: &[usize]) -> HashMap<String, Tensor> {
136        let mut aggregated = HashMap::new();
137        let mut total_samples = 0;
138        
139        // Collect initial parameters
140        let initial_params = self.global_parameters.clone();
141        
142        // Weighted average based on number of samples
143        for &id in client_ids {
144            if let Some(client) = self.clients.get(id) {
145                total_samples += client.num_samples;
146                let updates = client.get_update(&initial_params);
147                
148                for (name, update) in updates {
149                    let weighted_update = update.mul_scalar(client.num_samples as f32);
150                    
151                    aggregated.entry(name)
152                        .and_modify(|agg: &mut Tensor| *agg = agg.add(&weighted_update).unwrap())
153                        .or_insert(weighted_update);
154                }
155            }
156        }
157        
158        // Normalize by total samples
159        for (_, update) in aggregated.iter_mut() {
160            *update = update.div_scalar(total_samples as f32);
161        }
162        
163        // Apply updates to global parameters
164        let mut new_params = HashMap::new();
165        for (name, param) in &self.global_parameters {
166            if let Some(update) = aggregated.get(name) {
167                new_params.insert(name.clone(), param.add(update).unwrap());
168            } else {
169                new_params.insert(name.clone(), param.clone());
170            }
171        }
172        
173        new_params
174    }
175    
176    /// Aggregate client updates using FedProx
177    fn aggregate_fedprox(&self, client_ids: &[usize], mu: f32) -> HashMap<String, Tensor> {
178        // FedProx adds a proximal term to keep updates close to global model
179        // For simplicity, we use FedAvg with a damping factor
180        let mut updates = self.aggregate_fedavg(client_ids);
181        
182        // Apply proximal damping
183        for (name, update) in updates.iter_mut() {
184            if let Some(global_param) = self.global_parameters.get(name) {
185                let diff = update.sub(global_param).unwrap();
186                let damped = diff.mul_scalar(1.0 / (1.0 + mu));
187                *update = global_param.add(&damped).unwrap();
188            }
189        }
190        
191        updates
192    }
193    
194    /// Aggregate client updates
195    pub fn aggregate(&mut self, client_ids: &[usize]) {
196        let new_params = match self.aggregation {
197            AggregationStrategy::FedAvg => self.aggregate_fedavg(client_ids),
198            AggregationStrategy::FedProx { mu } => self.aggregate_fedprox(client_ids, mu),
199            AggregationStrategy::FedAdam { .. } => {
200                // Simplified - would need momentum buffers
201                self.aggregate_fedavg(client_ids)
202            }
203        };
204        
205        self.global_parameters = new_params;
206        self.round += 1;
207    }
208    
209    /// Run one round of federated learning
210    pub fn train_round(&mut self, client_fraction: f32, local_epochs: usize, batch_size: usize) {
211        // Select clients
212        let selected_clients = self.select_clients(client_fraction);
213        
214        // Distribute model
215        self.distribute_model(&selected_clients);
216        
217        // Local training
218        for &id in &selected_clients {
219            if let Some(client) = self.clients.get_mut(id) {
220                client.local_train(local_epochs, batch_size);
221            }
222        }
223        
224        // Aggregate updates
225        self.aggregate(&selected_clients);
226    }
227    
228    /// Get current round number
229    pub fn current_round(&self) -> usize {
230        self.round
231    }
232}
233
234/// Secure aggregation using secret sharing
235pub struct SecureAggregation {
236    /// Number of clients
237    num_clients: usize,
238    /// Threshold for reconstruction
239    threshold: usize,
240}
241
242impl SecureAggregation {
243    /// Create a new secure aggregation scheme
244    pub fn new(num_clients: usize, threshold: usize) -> Self {
245        SecureAggregation {
246            num_clients,
247            threshold,
248        }
249    }
250    
251    /// Split a value into shares (simplified Shamir's secret sharing)
252    pub fn share(&self, value: f32) -> Vec<f32> {
253        let mut rng = rand::thread_rng();
254        let mut shares = Vec::with_capacity(self.num_clients);
255        
256        // Generate random shares
257        let mut sum = 0.0;
258        for _ in 0..self.num_clients - 1 {
259            let share: f32 = rng.gen_range(-1.0..1.0);
260            shares.push(share);
261            sum += share;
262        }
263        
264        // Last share ensures sum equals value
265        shares.push(value - sum);
266        
267        shares
268    }
269    
270    /// Reconstruct value from shares
271    pub fn reconstruct(&self, shares: &[f32]) -> f32 {
272        if shares.len() < self.threshold {
273            return 0.0;
274        }
275        
276        shares.iter().sum()
277    }
278    
279    /// Securely aggregate client updates
280    pub fn aggregate_secure(&self, client_updates: &[HashMap<String, Tensor>]) -> HashMap<String, Tensor> {
281        let mut aggregated = HashMap::new();
282        
283        // For each parameter
284        if let Some(first_update) = client_updates.first() {
285            for (name, _) in first_update {
286                // Collect all client values for this parameter
287                let mut param_values = Vec::new();
288                for update in client_updates {
289                    if let Some(tensor) = update.get(name) {
290                        param_values.push(tensor.clone());
291                    }
292                }
293                
294                // Simple aggregation (in practice, would use secure protocols)
295                if !param_values.is_empty() {
296                    let mut sum = param_values[0].clone();
297                    for tensor in &param_values[1..] {
298                        sum = sum.add(tensor).unwrap();
299                    }
300                    let avg = sum.div_scalar(param_values.len() as f32);
301                    aggregated.insert(name.clone(), avg);
302                }
303            }
304        }
305        
306        aggregated
307    }
308}
309
310/// Differential privacy for federated learning
311pub struct DifferentialPrivacy {
312    /// Privacy budget (epsilon)
313    pub epsilon: f32,
314    /// Sensitivity
315    pub sensitivity: f32,
316    /// Noise scale
317    pub noise_scale: f32,
318}
319
320impl DifferentialPrivacy {
321    /// Create a new differential privacy mechanism
322    pub fn new(epsilon: f32, sensitivity: f32) -> Self {
323        let noise_scale = sensitivity / epsilon;
324        
325        DifferentialPrivacy {
326            epsilon,
327            sensitivity,
328            noise_scale,
329        }
330    }
331    
332    /// Add Gaussian noise to a tensor
333    pub fn add_noise(&self, tensor: &Tensor) -> Tensor {
334        let noise = Tensor::randn(tensor.dims()).mul_scalar(self.noise_scale);
335        tensor.add(&noise).unwrap()
336    }
337    
338    /// Clip gradients to bound sensitivity
339    pub fn clip_gradients(&self, gradients: &HashMap<String, Tensor>, max_norm: f32) -> HashMap<String, Tensor> {
340        let mut clipped = HashMap::new();
341        
342        // Compute global norm
343        let mut global_norm_sq = 0.0;
344        for (_, grad) in gradients {
345            let data = grad.data_f32();
346            global_norm_sq += data.iter().map(|x| x * x).sum::<f32>();
347        }
348        let global_norm = global_norm_sq.sqrt();
349        
350        // Clip if necessary
351        let clip_factor = if global_norm > max_norm {
352            max_norm / global_norm
353        } else {
354            1.0
355        };
356        
357        for (name, grad) in gradients {
358            let clipped_grad = grad.mul_scalar(clip_factor);
359            clipped.insert(name.clone(), clipped_grad);
360        }
361        
362        clipped
363    }
364    
365    /// Apply differential privacy to aggregated updates
366    pub fn privatize(&self, updates: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
367        let mut private_updates = HashMap::new();
368        
369        for (name, update) in updates {
370            let noisy_update = self.add_noise(update);
371            private_updates.insert(name.clone(), noisy_update);
372        }
373        
374        private_updates
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-entry parameter map `{"weight": ones([2, 2])}`.
    fn unit_params() -> HashMap<String, Tensor> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), Tensor::ones(&[2, 2]));
        params
    }

    #[test]
    fn test_federated_client() {
        let params = unit_params();

        let mut client = FederatedClient::new(0, params.clone(), 100);
        assert_eq!(client.num_samples, 100);

        client.local_train(1, 32);
        // One epoch of simulated training must move the parameters off their start point.
        assert_ne!(
            client.parameters["weight"].data_f32(),
            params["weight"].data_f32()
        );

        let update = client.get_update(&params);
        assert!(update.contains_key("weight"));
    }

    #[test]
    fn test_federated_server() {
        let params = unit_params();

        let mut server = FederatedServer::new(params.clone(), AggregationStrategy::FedAvg);

        for i in 0..5 {
            let client = FederatedClient::new(i, params.clone(), 100);
            server.register_client(client);
        }

        assert_eq!(server.clients.len(), 5);

        let selected = server.select_clients(0.5);
        assert!(selected.len() >= 2 && selected.len() <= 3);
    }

    #[test]
    fn test_train_round_advances_and_stays_finite() {
        for &strategy in &[
            AggregationStrategy::FedAvg,
            AggregationStrategy::FedProx { mu: 0.1 },
        ] {
            let params = unit_params();
            let mut server = FederatedServer::new(params.clone(), strategy);
            for i in 0..4 {
                server.register_client(FederatedClient::new(i, params.clone(), 10 * (i + 1)));
            }

            server.train_round(1.0, 1, 8);

            assert_eq!(server.current_round(), 1);
            assert!(server.global_parameters["weight"]
                .data_f32()
                .iter()
                .all(|x| x.is_finite()));
        }
    }

    #[test]
    fn test_secure_aggregation() {
        let secure_agg = SecureAggregation::new(5, 3);

        let value = 10.0;
        let shares = secure_agg.share(value);
        assert_eq!(shares.len(), 5);

        let reconstructed = secure_agg.reconstruct(&shares);
        assert!((reconstructed - value).abs() < 0.001);
    }

    #[test]
    fn test_aggregate_secure_averages() {
        let secure_agg = SecureAggregation::new(2, 2);

        let mut a = HashMap::new();
        a.insert("w".to_string(), Tensor::ones(&[2, 2]));
        let mut b = HashMap::new();
        b.insert("w".to_string(), Tensor::ones(&[2, 2]).mul_scalar(3.0));

        let avg = secure_agg.aggregate_secure(&[a, b]);
        assert!(avg["w"].data_f32().iter().all(|&x| (x - 2.0).abs() < 1e-6));
    }

    #[test]
    fn test_clip_gradients_scales_to_max_norm() {
        let dp = DifferentialPrivacy::new(1.0, 1.0);

        let mut grads = HashMap::new();
        // Four ones: global L2 norm is 2.
        grads.insert("g".to_string(), Tensor::ones(&[2, 2]));

        let clipped = dp.clip_gradients(&grads, 1.0);
        assert!(clipped["g"].data_f32().iter().all(|&x| (x - 0.5).abs() < 1e-6));

        let untouched = dp.clip_gradients(&grads, 10.0);
        assert!(untouched["g"].data_f32().iter().all(|&x| (x - 1.0).abs() < 1e-6));
    }

    #[test]
    fn test_differential_privacy() {
        let dp = DifferentialPrivacy::new(1.0, 1.0);

        let tensor = Tensor::ones(&[2, 2]);
        let noisy = dp.add_noise(&tensor);

        // Adding random noise should change the values.
        assert_ne!(noisy.data_f32(), tensor.data_f32());
    }
}