1use crate::error::{OptimError, Result};
10use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand, Zip};
11use scirs2_core::numeric::Float;
12use std::fmt::Debug;
13
14#[derive(Debug, Clone)]
16pub struct FedProxConfig<A: Float> {
17 pub mu: A,
21 pub local_epochs: usize,
23 pub participation_rate: A,
25 pub num_clients: usize,
27}
28
29impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static> FedProxConfig<A> {
30 pub fn new(num_clients: usize) -> Self {
32 Self {
33 mu: A::from(0.01).unwrap_or_else(|| A::zero()),
34 local_epochs: 5,
35 participation_rate: A::one(),
36 num_clients,
37 }
38 }
39
40 pub fn validate(&self) -> Result<()> {
42 if self.mu < A::zero() {
43 return Err(OptimError::InvalidConfig(
44 "Proximal term coefficient mu must be non-negative".to_string(),
45 ));
46 }
47 if self.local_epochs == 0 {
48 return Err(OptimError::InvalidConfig(
49 "local_epochs must be at least 1".to_string(),
50 ));
51 }
52 if self.participation_rate <= A::zero() || self.participation_rate > A::one() {
53 return Err(OptimError::InvalidConfig(
54 "participation_rate must be in (0.0, 1.0]".to_string(),
55 ));
56 }
57 if self.num_clients == 0 {
58 return Err(OptimError::InvalidConfig(
59 "num_clients must be at least 1".to_string(),
60 ));
61 }
62 Ok(())
63 }
64}
65
66#[derive(Debug)]
68pub struct FedProxConfigBuilder<A: Float> {
69 mu: Option<A>,
70 local_epochs: Option<usize>,
71 participation_rate: Option<A>,
72 num_clients: usize,
73}
74
75impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static> FedProxConfigBuilder<A> {
76 pub fn new(num_clients: usize) -> Self {
78 Self {
79 mu: None,
80 local_epochs: None,
81 participation_rate: None,
82 num_clients,
83 }
84 }
85
86 pub fn mu(mut self, mu: A) -> Self {
88 self.mu = Some(mu);
89 self
90 }
91
92 pub fn local_epochs(mut self, epochs: usize) -> Self {
94 self.local_epochs = Some(epochs);
95 self
96 }
97
98 pub fn participation_rate(mut self, rate: A) -> Self {
100 self.participation_rate = Some(rate);
101 self
102 }
103
104 pub fn build(self) -> Result<FedProxConfig<A>> {
106 let config = FedProxConfig {
107 mu: self
108 .mu
109 .unwrap_or_else(|| A::from(0.01).unwrap_or_else(|| A::zero())),
110 local_epochs: self.local_epochs.unwrap_or(5),
111 participation_rate: self.participation_rate.unwrap_or_else(|| A::one()),
112 num_clients: self.num_clients,
113 };
114 config.validate()?;
115 Ok(config)
116 }
117}
118
119#[derive(Debug, Clone)]
121pub struct ClientUpdate<A: Float, D: Dimension> {
122 pub client_id: usize,
124 pub parameters: Vec<Array<A, D>>,
126 pub data_size: usize,
128}
129
130#[derive(Debug)]
137pub struct FedProxOptimizer<
138 A: Float + ScalarOperand + Debug + Send + Sync + 'static,
139 D: Dimension + Send + Sync + Clone,
140> {
141 config: FedProxConfig<A>,
143 global_parameters: Option<Vec<Array<A, D>>>,
145 client_updates: Vec<ClientUpdate<A, D>>,
147 round_count: usize,
149}
150
151impl<
152 A: Float + ScalarOperand + Debug + Send + Sync + 'static,
153 D: Dimension + Send + Sync + Clone,
154 > FedProxOptimizer<A, D>
155{
156 pub fn new(config: FedProxConfig<A>) -> Self {
158 Self {
159 config,
160 global_parameters: None,
161 client_updates: Vec::new(),
162 round_count: 0,
163 }
164 }
165
166 pub fn builder(num_clients: usize) -> FedProxConfigBuilder<A> {
168 FedProxConfigBuilder::new(num_clients)
169 }
170
171 pub fn set_global_parameters(&mut self, params: &[Array<A, D>]) -> Result<()> {
176 if params.is_empty() {
177 return Err(OptimError::InvalidParameter(
178 "Global parameters cannot be empty".to_string(),
179 ));
180 }
181 self.global_parameters = Some(params.to_vec());
182 self.client_updates.clear();
184 Ok(())
185 }
186
187 pub fn local_update(
193 &self,
194 params: &[Array<A, D>],
195 gradients: &[Array<A, D>],
196 lr: A,
197 ) -> Result<Vec<Array<A, D>>> {
198 if params.len() != gradients.len() {
199 return Err(OptimError::DimensionMismatch(format!(
200 "Parameters length ({}) does not match gradients length ({})",
201 params.len(),
202 gradients.len()
203 )));
204 }
205
206 let proximal_grads = self.compute_proximal_gradient(params)?;
207
208 let mut updated = Vec::with_capacity(params.len());
209 for (i, (param, grad)) in params.iter().zip(gradients.iter()).enumerate() {
210 if param.shape() != grad.shape() {
211 return Err(OptimError::DimensionMismatch(format!(
212 "Parameter shape {:?} does not match gradient shape {:?} at index {}",
213 param.shape(),
214 grad.shape(),
215 i
216 )));
217 }
218
219 let prox = &proximal_grads[i];
220 let mut new_param = param.clone();
222 Zip::from(&mut new_param)
223 .and(grad)
224 .and(prox)
225 .for_each(|w, &g, &p| {
226 *w = *w - lr * (g + p);
227 });
228 updated.push(new_param);
229 }
230
231 Ok(updated)
232 }
233
234 pub fn compute_proximal_gradient(&self, params: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
239 let global = self.global_parameters.as_ref().ok_or_else(|| {
240 OptimError::InvalidState(
241 "Global parameters not set. Call set_global_parameters first.".to_string(),
242 )
243 })?;
244
245 if params.len() != global.len() {
246 return Err(OptimError::DimensionMismatch(format!(
247 "Local parameters length ({}) does not match global parameters length ({})",
248 params.len(),
249 global.len()
250 )));
251 }
252
253 let mu = self.config.mu;
254 let mut prox_grads = Vec::with_capacity(params.len());
255
256 for (i, (local, global_p)) in params.iter().zip(global.iter()).enumerate() {
257 if local.shape() != global_p.shape() {
258 return Err(OptimError::DimensionMismatch(format!(
259 "Local param shape {:?} != global param shape {:?} at index {}",
260 local.shape(),
261 global_p.shape(),
262 i
263 )));
264 }
265
266 let mut prox = local.clone();
267 Zip::from(&mut prox).and(global_p).for_each(|l, &g| {
268 *l = mu * (*l - g);
269 });
270 prox_grads.push(prox);
271 }
272
273 Ok(prox_grads)
274 }
275
276 pub fn submit_client_update(
278 &mut self,
279 client_id: usize,
280 params: &[Array<A, D>],
281 data_size: usize,
282 ) -> Result<()> {
283 if params.is_empty() {
284 return Err(OptimError::InvalidParameter(
285 "Client parameters cannot be empty".to_string(),
286 ));
287 }
288 if data_size == 0 {
289 return Err(OptimError::InvalidParameter(
290 "Client data_size must be positive".to_string(),
291 ));
292 }
293
294 if let Some(ref global) = self.global_parameters {
296 if params.len() != global.len() {
297 return Err(OptimError::DimensionMismatch(format!(
298 "Client {} parameter count ({}) does not match global ({})",
299 client_id,
300 params.len(),
301 global.len()
302 )));
303 }
304 for (i, (cp, gp)) in params.iter().zip(global.iter()).enumerate() {
305 if cp.shape() != gp.shape() {
306 return Err(OptimError::DimensionMismatch(format!(
307 "Client {} param shape {:?} != global shape {:?} at index {}",
308 client_id,
309 cp.shape(),
310 gp.shape(),
311 i
312 )));
313 }
314 }
315 }
316
317 self.client_updates.push(ClientUpdate {
318 client_id,
319 parameters: params.to_vec(),
320 data_size,
321 });
322
323 Ok(())
324 }
325
326 pub fn aggregate_updates(&mut self) -> Result<Vec<Array<A, D>>> {
332 if self.client_updates.is_empty() {
333 return Err(OptimError::InvalidState(
334 "No client updates to aggregate".to_string(),
335 ));
336 }
337
338 let total_data: usize = self.client_updates.iter().map(|u| u.data_size).sum();
340 if total_data == 0 {
341 return Err(OptimError::InvalidState(
342 "Total data size across clients is zero".to_string(),
343 ));
344 }
345 let total_data_a = A::from(total_data).ok_or_else(|| {
346 OptimError::ComputationError("Cannot convert total data size to float".to_string())
347 })?;
348
349 let num_params = self.client_updates[0].parameters.len();
351
352 let mut aggregated: Vec<Array<A, D>> = self.client_updates[0]
354 .parameters
355 .iter()
356 .map(|p| Array::zeros(p.raw_dim()))
357 .collect();
358
359 for update in &self.client_updates {
361 if update.parameters.len() != num_params {
362 return Err(OptimError::DimensionMismatch(format!(
363 "Client {} has {} parameters, expected {}",
364 update.client_id,
365 update.parameters.len(),
366 num_params
367 )));
368 }
369
370 let weight = A::from(update.data_size).ok_or_else(|| {
371 OptimError::ComputationError("Cannot convert client data size to float".to_string())
372 })? / total_data_a;
373
374 for (agg, client_param) in aggregated.iter_mut().zip(update.parameters.iter()) {
375 Zip::from(agg).and(client_param).for_each(|a, &c| {
376 *a = *a + weight * c;
377 });
378 }
379 }
380
381 self.global_parameters = Some(aggregated.clone());
383 self.client_updates.clear();
384 self.round_count += 1;
385
386 Ok(aggregated)
387 }
388
389 pub fn get_round_count(&self) -> usize {
391 self.round_count
392 }
393
394 pub fn get_config(&self) -> &FedProxConfig<A> {
396 &self.config
397 }
398
399 pub fn get_global_parameters(&self) -> Option<&Vec<Array<A, D>>> {
401 self.global_parameters.as_ref()
402 }
403
404 pub fn get_pending_updates_count(&self) -> usize {
406 self.client_updates.len()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Ix1};

    /// Builder defaults, explicit overrides and rejection of invalid values.
    #[test]
    fn test_fedprox_config_builder() {
        // Defaults.
        let config: FedProxConfig<f64> =
            FedProxConfigBuilder::new(10).build().expect("build failed");
        assert_eq!(config.num_clients, 10);
        assert!((config.mu - 0.01).abs() < 1e-10);
        assert_eq!(config.local_epochs, 5);
        assert!((config.participation_rate - 1.0).abs() < 1e-10);

        // Explicit overrides.
        let config: FedProxConfig<f64> = FedProxConfigBuilder::new(20)
            .mu(0.1)
            .local_epochs(10)
            .participation_rate(0.5)
            .build()
            .expect("build failed");
        assert_eq!(config.num_clients, 20);
        assert!((config.mu - 0.1).abs() < 1e-10);
        assert_eq!(config.local_epochs, 10);
        assert!((config.participation_rate - 0.5).abs() < 1e-10);

        // Negative mu is rejected.
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).mu(-0.1).build();
        assert!(result.is_err());

        // participation_rate must be strictly positive...
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).participation_rate(0.0).build();
        assert!(result.is_err());

        // ...and at most 1.
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).participation_rate(1.5).build();
        assert!(result.is_err());

        // Zero local epochs is rejected.
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).local_epochs(0).build();
        assert!(result.is_err());
    }

    /// Global parameters are stored as given; an empty slice is refused.
    #[test]
    fn test_fedprox_set_global_parameters() {
        let config: FedProxConfig<f64> = FedProxConfig::new(3);
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        let params = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        assert!(optimizer.set_global_parameters(&params).is_ok());
        assert!(optimizer.get_global_parameters().is_some());

        let stored = optimizer
            .get_global_parameters()
            .expect("should have params");
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].len(), 3);
        assert_eq!(stored[1].len(), 2);

        let empty: Vec<Array1<f64>> = vec![];
        assert!(optimizer.set_global_parameters(&empty).is_err());
    }

    /// `local_update` applies `w - lr * (g + mu * (w - w_global))`.
    #[test]
    fn test_local_update_with_proximal_term() {
        let config: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
            .mu(0.1)
            .build()
            .expect("build failed");
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        let global = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        optimizer
            .set_global_parameters(&global)
            .expect("set params failed");

        let local = vec![Array1::from_vec(vec![1.5, 2.5, 3.5])];
        let grads = vec![Array1::from_vec(vec![0.1, 0.2, 0.3])];
        let lr: f64 = 0.01;

        let updated = optimizer
            .local_update(&local, &grads, lr)
            .expect("local_update failed");

        // 1.5 - 0.01 * (0.1 + 0.1 * (1.5 - 1.0)) = 1.4985
        assert!((updated[0][0] - 1.4985).abs() < 1e-10);

        // 2.5 - 0.01 * (0.2 + 0.1 * (2.5 - 2.0)) = 2.4975
        assert!((updated[0][1] - 2.4975).abs() < 1e-10);

        // 3.5 - 0.01 * (0.3 + 0.1 * (3.5 - 3.0)) = 3.4965
        assert!((updated[0][2] - 3.4965).abs() < 1e-10);
    }

    /// Proximal gradient values and its error paths.
    #[test]
    fn test_proximal_gradient_computation() {
        let config: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
            .mu(0.5)
            .build()
            .expect("build failed");
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        let global = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        optimizer
            .set_global_parameters(&global)
            .expect("set params failed");

        let local = vec![Array1::from_vec(vec![2.0, 4.0, 6.0])];
        let prox = optimizer
            .compute_proximal_gradient(&local)
            .expect("proximal gradient failed");

        // mu * (local - global) = 0.5 * [1, 2, 3]
        assert!((prox[0][0] - 0.5).abs() < 1e-10);
        assert!((prox[0][1] - 1.0).abs() < 1e-10);
        assert!((prox[0][2] - 1.5).abs() < 1e-10);

        // No global parameters set -> error.
        let config2: FedProxConfig<f64> = FedProxConfig::new(2);
        let optimizer2: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config2);
        assert!(optimizer2.compute_proximal_gradient(&local).is_err());

        // Parameter count differs from the global model -> error.
        let mismatched = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0]),
        ];
        assert!(optimizer.compute_proximal_gradient(&mismatched).is_err());
    }

    /// Aggregation is a data-size-weighted average and resets pending state.
    #[test]
    fn test_aggregate_updates_weighted() {
        let config: FedProxConfig<f64> = FedProxConfig::new(3);
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        let global = vec![Array1::from_vec(vec![0.0, 0.0])];
        optimizer
            .set_global_parameters(&global)
            .expect("set params failed");

        optimizer
            .submit_client_update(0, &[Array1::from_vec(vec![2.0, 4.0])], 100)
            .expect("submit failed");

        optimizer
            .submit_client_update(1, &[Array1::from_vec(vec![4.0, 6.0])], 300)
            .expect("submit failed");

        assert_eq!(optimizer.get_pending_updates_count(), 2);

        let aggregated = optimizer.aggregate_updates().expect("aggregate failed");

        // Weights are 100/400 and 300/400:
        // 0.25 * 2 + 0.75 * 4 = 3.5, 0.25 * 4 + 0.75 * 6 = 5.5
        assert!((aggregated[0][0] - 3.5).abs() < 1e-10);
        assert!((aggregated[0][1] - 5.5).abs() < 1e-10);

        assert_eq!(optimizer.get_round_count(), 1);
        assert_eq!(optimizer.get_pending_updates_count(), 0);

        // Nothing is pending any more, so a second aggregation fails.
        assert!(optimizer.aggregate_updates().is_err());
    }

    /// With `mu = 0` the proximal term vanishes and the update is a plain
    /// gradient step.
    #[test]
    fn test_fedprox_mu_zero_is_fedavg() {
        let config_prox: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
            .mu(0.0)
            .build()
            .expect("build failed");
        let mut optimizer_prox: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config_prox);

        let global = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        optimizer_prox
            .set_global_parameters(&global)
            .expect("set params failed");

        let local = vec![Array1::from_vec(vec![5.0, 10.0, 15.0])];
        let prox_grad = optimizer_prox
            .compute_proximal_gradient(&local)
            .expect("proximal gradient failed");
        for val in prox_grad[0].iter() {
            assert!(
                val.abs() < 1e-15,
                "Proximal gradient should be zero when mu=0"
            );
        }

        let grads = vec![Array1::from_vec(vec![0.1, 0.2, 0.3])];
        let lr: f64 = 0.01;
        let updated = optimizer_prox
            .local_update(&local, &grads, lr)
            .expect("local_update failed");

        // Plain step: w - lr * g.
        assert!((updated[0][0] - 4.999).abs() < 1e-10);
        assert!((updated[0][1] - 9.998).abs() < 1e-10);
        assert!((updated[0][2] - 14.997).abs() < 1e-10);

        optimizer_prox
            .submit_client_update(0, &[Array1::from_vec(vec![2.0, 3.0, 4.0])], 200)
            .expect("submit failed");
        optimizer_prox
            .submit_client_update(1, &[Array1::from_vec(vec![4.0, 5.0, 6.0])], 200)
            .expect("submit failed");

        let agg = optimizer_prox
            .aggregate_updates()
            .expect("aggregate failed");

        // Equal data sizes -> simple mean of the two updates.
        assert!((agg[0][0] - 3.0).abs() < 1e-10);
        assert!((agg[0][1] - 4.0).abs() < 1e-10);
        assert!((agg[0][2] - 5.0).abs() < 1e-10);
    }
}