1use ghostflow_core::Tensor;
11use std::collections::HashMap;
12use rand::Rng;
13
14#[derive(Debug, Clone)]
16pub struct FederatedClient {
17 pub id: usize,
19 pub parameters: HashMap<String, Tensor>,
21 pub num_samples: usize,
23 pub learning_rate: f32,
25}
26
27impl FederatedClient {
28 pub fn new(id: usize, parameters: HashMap<String, Tensor>, num_samples: usize) -> Self {
30 FederatedClient {
31 id,
32 parameters,
33 num_samples,
34 learning_rate: 0.01,
35 }
36 }
37
38 pub fn local_train(&mut self, num_epochs: usize, batch_size: usize) {
40 for _ in 0..num_epochs {
43 for (_, param) in self.parameters.iter_mut() {
45 let grad = Tensor::randn(param.dims()).mul_scalar(0.01);
46 *param = param.sub(&grad).unwrap();
47 }
48 }
49 }
50
51 pub fn get_update(&self, initial_params: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
53 let mut updates = HashMap::new();
54
55 for (name, param) in &self.parameters {
56 if let Some(initial) = initial_params.get(name) {
57 let update = param.sub(initial).unwrap();
58 updates.insert(name.clone(), update);
59 }
60 }
61
62 updates
63 }
64
65 pub fn set_parameters(&mut self, parameters: HashMap<String, Tensor>) {
67 self.parameters = parameters;
68 }
69}
70
71pub struct FederatedServer {
73 pub global_parameters: HashMap<String, Tensor>,
75 pub clients: Vec<FederatedClient>,
77 pub aggregation: AggregationStrategy,
79 pub round: usize,
81}
82
/// How the server combines client updates into the next global model.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregationStrategy {
    /// Average of client updates weighted by each client's `num_samples`.
    FedAvg,
    /// FedAvg result damped toward the current global model by a factor `1 / (1 + mu)`.
    FedProx { mu: f32 },
    /// Adam-style server optimiser hyper-parameters.
    ///
    /// NOTE: the server currently aggregates this variant exactly like `FedAvg`;
    /// `beta1` and `beta2` are not used yet.
    FedAdam { beta1: f32, beta2: f32 },
}
92
93impl FederatedServer {
94 pub fn new(global_parameters: HashMap<String, Tensor>, aggregation: AggregationStrategy) -> Self {
96 FederatedServer {
97 global_parameters,
98 clients: Vec::new(),
99 aggregation,
100 round: 0,
101 }
102 }
103
104 pub fn register_client(&mut self, client: FederatedClient) {
106 self.clients.push(client);
107 }
108
109 pub fn select_clients(&self, fraction: f32) -> Vec<usize> {
111 let num_clients = (self.clients.len() as f32 * fraction).ceil() as usize;
112 let mut rng = rand::thread_rng();
113
114 let mut selected = Vec::new();
115 let mut available: Vec<usize> = (0..self.clients.len()).collect();
116
117 for _ in 0..num_clients.min(available.len()) {
118 let idx = rng.gen_range(0..available.len());
119 selected.push(available.remove(idx));
120 }
121
122 selected
123 }
124
125 pub fn distribute_model(&mut self, client_ids: &[usize]) {
127 for &id in client_ids {
128 if let Some(client) = self.clients.get_mut(id) {
129 client.set_parameters(self.global_parameters.clone());
130 }
131 }
132 }
133
134 fn aggregate_fedavg(&self, client_ids: &[usize]) -> HashMap<String, Tensor> {
136 let mut aggregated = HashMap::new();
137 let mut total_samples = 0;
138
139 let initial_params = self.global_parameters.clone();
141
142 for &id in client_ids {
144 if let Some(client) = self.clients.get(id) {
145 total_samples += client.num_samples;
146 let updates = client.get_update(&initial_params);
147
148 for (name, update) in updates {
149 let weighted_update = update.mul_scalar(client.num_samples as f32);
150
151 aggregated.entry(name)
152 .and_modify(|agg: &mut Tensor| *agg = agg.add(&weighted_update).unwrap())
153 .or_insert(weighted_update);
154 }
155 }
156 }
157
158 for (_, update) in aggregated.iter_mut() {
160 *update = update.div_scalar(total_samples as f32);
161 }
162
163 let mut new_params = HashMap::new();
165 for (name, param) in &self.global_parameters {
166 if let Some(update) = aggregated.get(name) {
167 new_params.insert(name.clone(), param.add(update).unwrap());
168 } else {
169 new_params.insert(name.clone(), param.clone());
170 }
171 }
172
173 new_params
174 }
175
176 fn aggregate_fedprox(&self, client_ids: &[usize], mu: f32) -> HashMap<String, Tensor> {
178 let mut updates = self.aggregate_fedavg(client_ids);
181
182 for (name, update) in updates.iter_mut() {
184 if let Some(global_param) = self.global_parameters.get(name) {
185 let diff = update.sub(global_param).unwrap();
186 let damped = diff.mul_scalar(1.0 / (1.0 + mu));
187 *update = global_param.add(&damped).unwrap();
188 }
189 }
190
191 updates
192 }
193
194 pub fn aggregate(&mut self, client_ids: &[usize]) {
196 let new_params = match self.aggregation {
197 AggregationStrategy::FedAvg => self.aggregate_fedavg(client_ids),
198 AggregationStrategy::FedProx { mu } => self.aggregate_fedprox(client_ids, mu),
199 AggregationStrategy::FedAdam { .. } => {
200 self.aggregate_fedavg(client_ids)
202 }
203 };
204
205 self.global_parameters = new_params;
206 self.round += 1;
207 }
208
209 pub fn train_round(&mut self, client_fraction: f32, local_epochs: usize, batch_size: usize) {
211 let selected_clients = self.select_clients(client_fraction);
213
214 self.distribute_model(&selected_clients);
216
217 for &id in &selected_clients {
219 if let Some(client) = self.clients.get_mut(id) {
220 client.local_train(local_epochs, batch_size);
221 }
222 }
223
224 self.aggregate(&selected_clients);
226 }
227
228 pub fn current_round(&self) -> usize {
230 self.round
231 }
232}
233
/// Additive secret-sharing helper plus an averaging routine for client updates.
#[derive(Debug, Clone)]
pub struct SecureAggregation {
    /// Number of shares each value is split into.
    num_clients: usize,
    /// Minimum number of shares `reconstruct` accepts before summing.
    threshold: usize,
}
241
242impl SecureAggregation {
243 pub fn new(num_clients: usize, threshold: usize) -> Self {
245 SecureAggregation {
246 num_clients,
247 threshold,
248 }
249 }
250
251 pub fn share(&self, value: f32) -> Vec<f32> {
253 let mut rng = rand::thread_rng();
254 let mut shares = Vec::with_capacity(self.num_clients);
255
256 let mut sum = 0.0;
258 for _ in 0..self.num_clients - 1 {
259 let share: f32 = rng.gen_range(-1.0..1.0);
260 shares.push(share);
261 sum += share;
262 }
263
264 shares.push(value - sum);
266
267 shares
268 }
269
270 pub fn reconstruct(&self, shares: &[f32]) -> f32 {
272 if shares.len() < self.threshold {
273 return 0.0;
274 }
275
276 shares.iter().sum()
277 }
278
279 pub fn aggregate_secure(&self, client_updates: &[HashMap<String, Tensor>]) -> HashMap<String, Tensor> {
281 let mut aggregated = HashMap::new();
282
283 if let Some(first_update) = client_updates.first() {
285 for (name, _) in first_update {
286 let mut param_values = Vec::new();
288 for update in client_updates {
289 if let Some(tensor) = update.get(name) {
290 param_values.push(tensor.clone());
291 }
292 }
293
294 if !param_values.is_empty() {
296 let mut sum = param_values[0].clone();
297 for tensor in ¶m_values[1..] {
298 sum = sum.add(tensor).unwrap();
299 }
300 let avg = sum.div_scalar(param_values.len() as f32);
301 aggregated.insert(name.clone(), avg);
302 }
303 }
304 }
305
306 aggregated
307 }
308}
309
/// Noise-adding and gradient-clipping utilities for privatising model updates.
#[derive(Debug, Clone, PartialEq)]
pub struct DifferentialPrivacy {
    /// Privacy budget the noise scale is derived from.
    pub epsilon: f32,
    /// Sensitivity bound the noise scale is derived from.
    pub sensitivity: f32,
    /// Multiplier applied to `Tensor::randn` noise in `add_noise`; `new` sets it to
    /// `sensitivity / epsilon`.
    pub noise_scale: f32,
}
319
320impl DifferentialPrivacy {
321 pub fn new(epsilon: f32, sensitivity: f32) -> Self {
323 let noise_scale = sensitivity / epsilon;
324
325 DifferentialPrivacy {
326 epsilon,
327 sensitivity,
328 noise_scale,
329 }
330 }
331
332 pub fn add_noise(&self, tensor: &Tensor) -> Tensor {
334 let noise = Tensor::randn(tensor.dims()).mul_scalar(self.noise_scale);
335 tensor.add(&noise).unwrap()
336 }
337
338 pub fn clip_gradients(&self, gradients: &HashMap<String, Tensor>, max_norm: f32) -> HashMap<String, Tensor> {
340 let mut clipped = HashMap::new();
341
342 let mut global_norm_sq = 0.0;
344 for (_, grad) in gradients {
345 let data = grad.data_f32();
346 global_norm_sq += data.iter().map(|x| x * x).sum::<f32>();
347 }
348 let global_norm = global_norm_sq.sqrt();
349
350 let clip_factor = if global_norm > max_norm {
352 max_norm / global_norm
353 } else {
354 1.0
355 };
356
357 for (name, grad) in gradients {
358 let clipped_grad = grad.mul_scalar(clip_factor);
359 clipped.insert(name.clone(), clipped_grad);
360 }
361
362 clipped
363 }
364
365 pub fn privatize(&self, updates: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
367 let mut private_updates = HashMap::new();
368
369 for (name, update) in updates {
370 let noisy_update = self.add_noise(update);
371 private_updates.insert(name.clone(), noisy_update);
372 }
373
374 private_updates
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// A single 2x2 all-ones parameter named "weight".
    fn unit_params() -> HashMap<String, Tensor> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), Tensor::ones(&[2, 2]));
        params
    }

    #[test]
    fn test_federated_client() {
        let mut client = FederatedClient::new(0, unit_params(), 100);
        assert_eq!(client.num_samples, 100);
        assert_eq!(client.learning_rate, 0.01);

        let initial = client.parameters.clone();
        client.local_train(1, 32);

        // Training must keep the parameter set intact so a delta can be computed.
        let update = client.get_update(&initial);
        assert_eq!(update.len(), 1);
        assert!(update.contains_key("weight"));
    }

    #[test]
    fn test_federated_server() {
        let params = unit_params();
        let mut server = FederatedServer::new(params.clone(), AggregationStrategy::FedAvg);

        for i in 0..5 {
            let client = FederatedClient::new(i, params.clone(), 100);
            server.register_client(client);
        }

        assert_eq!(server.clients.len(), 5);

        let selected = server.select_clients(0.5);
        assert!(selected.len() >= 2 && selected.len() <= 3);
    }

    #[test]
    fn test_select_clients_is_distinct_and_bounded() {
        let mut server = FederatedServer::new(unit_params(), AggregationStrategy::FedAvg);
        for i in 0..5 {
            server.register_client(FederatedClient::new(i, unit_params(), 100));
        }

        let mut everyone = server.select_clients(1.0);
        everyone.sort_unstable();
        assert_eq!(everyone, vec![0, 1, 2, 3, 4]);

        assert!(server.select_clients(0.0).is_empty());
        assert_eq!(server.select_clients(2.0).len(), 5);
    }

    #[test]
    fn test_fedavg_of_untrained_clients_keeps_global_model() {
        let mut server = FederatedServer::new(unit_params(), AggregationStrategy::FedAvg);
        for i in 0..3 {
            server.register_client(FederatedClient::new(i, unit_params(), 10 * (i + 1)));
        }

        server.distribute_model(&[0, 1, 2]);
        server.aggregate(&[0, 1, 2]);

        assert_eq!(server.current_round(), 1);
        for &v in server.global_parameters["weight"].data_f32().iter() {
            assert!((v - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_secure_aggregation() {
        let secure_agg = SecureAggregation::new(5, 3);

        let value = 10.0;
        let shares = secure_agg.share(value);
        assert_eq!(shares.len(), 5);

        let reconstructed = secure_agg.reconstruct(&shares);
        assert!((reconstructed - value).abs() < 0.001);
    }

    #[test]
    fn test_aggregate_secure_averages_updates() {
        let secure_agg = SecureAggregation::new(2, 2);

        let mut low = HashMap::new();
        low.insert("w".to_string(), Tensor::ones(&[2, 2]));
        let mut high = HashMap::new();
        high.insert("w".to_string(), Tensor::ones(&[2, 2]).mul_scalar(3.0));

        let averaged = secure_agg.aggregate_secure(&[low, high]);
        for &v in averaged["w"].data_f32().iter() {
            assert!((v - 2.0).abs() < 1e-6);
        }

        assert!(secure_agg.aggregate_secure(&[]).is_empty());
    }

    #[test]
    fn test_differential_privacy() {
        let dp = DifferentialPrivacy::new(1.0, 1.0);

        let tensor = Tensor::ones(&[2, 2]);
        let noisy = dp.add_noise(&tensor);

        assert_ne!(noisy.data_f32(), tensor.data_f32());
    }

    #[test]
    fn test_clip_gradients_rescales_to_max_norm() {
        let dp = DifferentialPrivacy::new(1.0, 1.0);

        // Four ones: global L2 norm is 2.
        let mut grads = HashMap::new();
        grads.insert("g".to_string(), Tensor::ones(&[2, 2]));

        let clipped = dp.clip_gradients(&grads, 1.0);
        for &v in clipped["g"].data_f32().iter() {
            assert!((v - 0.5).abs() < 1e-6);
        }

        let untouched = dp.clip_gradients(&grads, 10.0);
        for &v in untouched["g"].data_f32().iter() {
            assert!((v - 1.0).abs() < 1e-6);
        }
    }
}