// optirs_core/distributed/fedprox.rs
1// FedProx (Federated Proximal) optimizer implementation
2//
3// FedProx extends FedAvg by adding a proximal term to the local objective function,
4// which helps handle systems heterogeneity (stragglers) and statistical heterogeneity
5// (non-IID data) in federated learning.
6//
7// Reference: Li et al., "Federated Optimization in Heterogeneous Networks" (MLSys 2020)
8
9use crate::error::{OptimError, Result};
10use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand, Zip};
11use scirs2_core::numeric::Float;
12use std::fmt::Debug;
13
14/// Configuration for the FedProx optimizer
15#[derive(Debug, Clone)]
16pub struct FedProxConfig<A: Float> {
17    /// Proximal term coefficient (mu).
18    /// Controls the strength of the proximal regularization.
19    /// When mu=0, FedProx degenerates to FedAvg.
20    pub mu: A,
21    /// Number of local training epochs per communication round
22    pub local_epochs: usize,
23    /// Fraction of clients participating in each round (0.0, 1.0]
24    pub participation_rate: A,
25    /// Total number of clients in the federation
26    pub num_clients: usize,
27}
28
29impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static> FedProxConfig<A> {
30    /// Create a new FedProxConfig with default values
31    pub fn new(num_clients: usize) -> Self {
32        Self {
33            mu: A::from(0.01).unwrap_or_else(|| A::zero()),
34            local_epochs: 5,
35            participation_rate: A::one(),
36            num_clients,
37        }
38    }
39
40    /// Validate the configuration
41    pub fn validate(&self) -> Result<()> {
42        if self.mu < A::zero() {
43            return Err(OptimError::InvalidConfig(
44                "Proximal term coefficient mu must be non-negative".to_string(),
45            ));
46        }
47        if self.local_epochs == 0 {
48            return Err(OptimError::InvalidConfig(
49                "local_epochs must be at least 1".to_string(),
50            ));
51        }
52        if self.participation_rate <= A::zero() || self.participation_rate > A::one() {
53            return Err(OptimError::InvalidConfig(
54                "participation_rate must be in (0.0, 1.0]".to_string(),
55            ));
56        }
57        if self.num_clients == 0 {
58            return Err(OptimError::InvalidConfig(
59                "num_clients must be at least 1".to_string(),
60            ));
61        }
62        Ok(())
63    }
64}
65
66/// Builder for FedProxConfig
67#[derive(Debug)]
68pub struct FedProxConfigBuilder<A: Float> {
69    mu: Option<A>,
70    local_epochs: Option<usize>,
71    participation_rate: Option<A>,
72    num_clients: usize,
73}
74
75impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static> FedProxConfigBuilder<A> {
76    /// Create a new builder with the required number of clients
77    pub fn new(num_clients: usize) -> Self {
78        Self {
79            mu: None,
80            local_epochs: None,
81            participation_rate: None,
82            num_clients,
83        }
84    }
85
86    /// Set the proximal term coefficient
87    pub fn mu(mut self, mu: A) -> Self {
88        self.mu = Some(mu);
89        self
90    }
91
92    /// Set the number of local training epochs
93    pub fn local_epochs(mut self, epochs: usize) -> Self {
94        self.local_epochs = Some(epochs);
95        self
96    }
97
98    /// Set the client participation rate
99    pub fn participation_rate(mut self, rate: A) -> Self {
100        self.participation_rate = Some(rate);
101        self
102    }
103
104    /// Build the configuration, validating all parameters
105    pub fn build(self) -> Result<FedProxConfig<A>> {
106        let config = FedProxConfig {
107            mu: self
108                .mu
109                .unwrap_or_else(|| A::from(0.01).unwrap_or_else(|| A::zero())),
110            local_epochs: self.local_epochs.unwrap_or(5),
111            participation_rate: self.participation_rate.unwrap_or_else(|| A::one()),
112            num_clients: self.num_clients,
113        };
114        config.validate()?;
115        Ok(config)
116    }
117}
118
119/// A client's update submission containing updated parameters and metadata
120#[derive(Debug, Clone)]
121pub struct ClientUpdate<A: Float, D: Dimension> {
122    /// Unique identifier for the client
123    pub client_id: usize,
124    /// Updated model parameters after local training
125    pub parameters: Vec<Array<A, D>>,
126    /// Size of the client's local dataset (used for weighted aggregation)
127    pub data_size: usize,
128}
129
130/// FedProx optimizer for federated learning
131///
132/// FedProx adds a proximal term mu/2 * ||w - w_global||^2 to each client's
133/// local objective, producing a gradient correction of mu * (w - w_global).
134/// This encourages local models to stay close to the global model, improving
135/// convergence under heterogeneous conditions.
136#[derive(Debug)]
137pub struct FedProxOptimizer<
138    A: Float + ScalarOperand + Debug + Send + Sync + 'static,
139    D: Dimension + Send + Sync + Clone,
140> {
141    /// Configuration
142    config: FedProxConfig<A>,
143    /// Stored global model parameters (the "anchor" for proximal term)
144    global_parameters: Option<Vec<Array<A, D>>>,
145    /// Collected client updates for the current round
146    client_updates: Vec<ClientUpdate<A, D>>,
147    /// Number of completed communication rounds
148    round_count: usize,
149}
150
151impl<
152        A: Float + ScalarOperand + Debug + Send + Sync + 'static,
153        D: Dimension + Send + Sync + Clone,
154    > FedProxOptimizer<A, D>
155{
156    /// Create a new FedProxOptimizer with the given configuration
157    pub fn new(config: FedProxConfig<A>) -> Self {
158        Self {
159            config,
160            global_parameters: None,
161            client_updates: Vec::new(),
162            round_count: 0,
163        }
164    }
165
166    /// Convenience method to create a builder
167    pub fn builder(num_clients: usize) -> FedProxConfigBuilder<A> {
168        FedProxConfigBuilder::new(num_clients)
169    }
170
171    /// Store the current global model parameters
172    ///
173    /// This must be called before performing local updates, as the proximal
174    /// term references the global parameters.
175    pub fn set_global_parameters(&mut self, params: &[Array<A, D>]) -> Result<()> {
176        if params.is_empty() {
177            return Err(OptimError::InvalidParameter(
178                "Global parameters cannot be empty".to_string(),
179            ));
180        }
181        self.global_parameters = Some(params.to_vec());
182        // Clear any previous client updates when starting a new round
183        self.client_updates.clear();
184        Ok(())
185    }
186
187    /// Perform a local update step with the FedProx proximal gradient
188    ///
189    /// Computes: w_new = w - lr * (gradient + mu * (w - w_global))
190    ///
191    /// When mu=0, this reduces to standard SGD: w_new = w - lr * gradient
192    pub fn local_update(
193        &self,
194        params: &[Array<A, D>],
195        gradients: &[Array<A, D>],
196        lr: A,
197    ) -> Result<Vec<Array<A, D>>> {
198        if params.len() != gradients.len() {
199            return Err(OptimError::DimensionMismatch(format!(
200                "Parameters length ({}) does not match gradients length ({})",
201                params.len(),
202                gradients.len()
203            )));
204        }
205
206        let proximal_grads = self.compute_proximal_gradient(params)?;
207
208        let mut updated = Vec::with_capacity(params.len());
209        for (i, (param, grad)) in params.iter().zip(gradients.iter()).enumerate() {
210            if param.shape() != grad.shape() {
211                return Err(OptimError::DimensionMismatch(format!(
212                    "Parameter shape {:?} does not match gradient shape {:?} at index {}",
213                    param.shape(),
214                    grad.shape(),
215                    i
216                )));
217            }
218
219            let prox = &proximal_grads[i];
220            // w_new = w - lr * (gradient + proximal_gradient)
221            let mut new_param = param.clone();
222            Zip::from(&mut new_param)
223                .and(grad)
224                .and(prox)
225                .for_each(|w, &g, &p| {
226                    *w = *w - lr * (g + p);
227                });
228            updated.push(new_param);
229        }
230
231        Ok(updated)
232    }
233
234    /// Compute the proximal gradient: mu * (params - global_params)
235    ///
236    /// This gradient term pulls local parameters towards the global model,
237    /// preventing excessive divergence during local training.
238    pub fn compute_proximal_gradient(&self, params: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
239        let global = self.global_parameters.as_ref().ok_or_else(|| {
240            OptimError::InvalidState(
241                "Global parameters not set. Call set_global_parameters first.".to_string(),
242            )
243        })?;
244
245        if params.len() != global.len() {
246            return Err(OptimError::DimensionMismatch(format!(
247                "Local parameters length ({}) does not match global parameters length ({})",
248                params.len(),
249                global.len()
250            )));
251        }
252
253        let mu = self.config.mu;
254        let mut prox_grads = Vec::with_capacity(params.len());
255
256        for (i, (local, global_p)) in params.iter().zip(global.iter()).enumerate() {
257            if local.shape() != global_p.shape() {
258                return Err(OptimError::DimensionMismatch(format!(
259                    "Local param shape {:?} != global param shape {:?} at index {}",
260                    local.shape(),
261                    global_p.shape(),
262                    i
263                )));
264            }
265
266            let mut prox = local.clone();
267            Zip::from(&mut prox).and(global_p).for_each(|l, &g| {
268                *l = mu * (*l - g);
269            });
270            prox_grads.push(prox);
271        }
272
273        Ok(prox_grads)
274    }
275
276    /// Submit a client's updated parameters after local training
277    pub fn submit_client_update(
278        &mut self,
279        client_id: usize,
280        params: &[Array<A, D>],
281        data_size: usize,
282    ) -> Result<()> {
283        if params.is_empty() {
284            return Err(OptimError::InvalidParameter(
285                "Client parameters cannot be empty".to_string(),
286            ));
287        }
288        if data_size == 0 {
289            return Err(OptimError::InvalidParameter(
290                "Client data_size must be positive".to_string(),
291            ));
292        }
293
294        // Validate against global parameters shape if available
295        if let Some(ref global) = self.global_parameters {
296            if params.len() != global.len() {
297                return Err(OptimError::DimensionMismatch(format!(
298                    "Client {} parameter count ({}) does not match global ({})",
299                    client_id,
300                    params.len(),
301                    global.len()
302                )));
303            }
304            for (i, (cp, gp)) in params.iter().zip(global.iter()).enumerate() {
305                if cp.shape() != gp.shape() {
306                    return Err(OptimError::DimensionMismatch(format!(
307                        "Client {} param shape {:?} != global shape {:?} at index {}",
308                        client_id,
309                        cp.shape(),
310                        gp.shape(),
311                        i
312                    )));
313                }
314            }
315        }
316
317        self.client_updates.push(ClientUpdate {
318            client_id,
319            parameters: params.to_vec(),
320            data_size,
321        });
322
323        Ok(())
324    }
325
326    /// Aggregate client updates using weighted averaging (FedAvg-style)
327    ///
328    /// Each client's contribution is weighted by its data_size relative to
329    /// the total data across all participating clients. This produces a new
330    /// global model for the next communication round.
331    pub fn aggregate_updates(&mut self) -> Result<Vec<Array<A, D>>> {
332        if self.client_updates.is_empty() {
333            return Err(OptimError::InvalidState(
334                "No client updates to aggregate".to_string(),
335            ));
336        }
337
338        // Compute total data size across all clients
339        let total_data: usize = self.client_updates.iter().map(|u| u.data_size).sum();
340        if total_data == 0 {
341            return Err(OptimError::InvalidState(
342                "Total data size across clients is zero".to_string(),
343            ));
344        }
345        let total_data_a = A::from(total_data).ok_or_else(|| {
346            OptimError::ComputationError("Cannot convert total data size to float".to_string())
347        })?;
348
349        // Determine number of parameter tensors from first client
350        let num_params = self.client_updates[0].parameters.len();
351
352        // Initialize aggregated parameters with zeros (same shape as first client)
353        let mut aggregated: Vec<Array<A, D>> = self.client_updates[0]
354            .parameters
355            .iter()
356            .map(|p| Array::zeros(p.raw_dim()))
357            .collect();
358
359        // Weighted sum
360        for update in &self.client_updates {
361            if update.parameters.len() != num_params {
362                return Err(OptimError::DimensionMismatch(format!(
363                    "Client {} has {} parameters, expected {}",
364                    update.client_id,
365                    update.parameters.len(),
366                    num_params
367                )));
368            }
369
370            let weight = A::from(update.data_size).ok_or_else(|| {
371                OptimError::ComputationError("Cannot convert client data size to float".to_string())
372            })? / total_data_a;
373
374            for (agg, client_param) in aggregated.iter_mut().zip(update.parameters.iter()) {
375                Zip::from(agg).and(client_param).for_each(|a, &c| {
376                    *a = *a + weight * c;
377                });
378            }
379        }
380
381        // Update global parameters and increment round
382        self.global_parameters = Some(aggregated.clone());
383        self.client_updates.clear();
384        self.round_count += 1;
385
386        Ok(aggregated)
387    }
388
389    /// Get the number of completed communication rounds
390    pub fn get_round_count(&self) -> usize {
391        self.round_count
392    }
393
394    /// Get a reference to the current configuration
395    pub fn get_config(&self) -> &FedProxConfig<A> {
396        &self.config
397    }
398
399    /// Get a reference to the current global parameters, if set
400    pub fn get_global_parameters(&self) -> Option<&Vec<Array<A, D>>> {
401        self.global_parameters.as_ref()
402    }
403
404    /// Get the number of client updates collected so far in this round
405    pub fn get_pending_updates_count(&self) -> usize {
406        self.client_updates.len()
407    }
408}
409
#[cfg(test)]
mod tests {
    // Unit tests covering configuration building/validation, proximal
    // gradient math, local updates, and weighted aggregation.
    use super::*;
    use scirs2_core::ndarray::{Array1, Ix1};

    #[test]
    fn test_fedprox_config_builder() {
        // Test default configuration
        let config: FedProxConfig<f64> =
            FedProxConfigBuilder::new(10).build().expect("build failed");
        assert_eq!(config.num_clients, 10);
        assert!((config.mu - 0.01).abs() < 1e-10);
        assert_eq!(config.local_epochs, 5);
        assert!((config.participation_rate - 1.0).abs() < 1e-10);

        // Test custom configuration
        let config: FedProxConfig<f64> = FedProxConfigBuilder::new(20)
            .mu(0.1)
            .local_epochs(10)
            .participation_rate(0.5)
            .build()
            .expect("build failed");
        assert_eq!(config.num_clients, 20);
        assert!((config.mu - 0.1).abs() < 1e-10);
        assert_eq!(config.local_epochs, 10);
        assert!((config.participation_rate - 0.5).abs() < 1e-10);

        // Test invalid mu
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).mu(-0.1).build();
        assert!(result.is_err());

        // Test invalid participation_rate (both below and above the (0, 1] range)
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).participation_rate(0.0).build();
        assert!(result.is_err());

        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).participation_rate(1.5).build();
        assert!(result.is_err());

        // Test invalid local_epochs
        let result: std::result::Result<FedProxConfig<f64>, _> =
            FedProxConfigBuilder::new(5).local_epochs(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_fedprox_set_global_parameters() {
        let config: FedProxConfig<f64> = FedProxConfig::new(3);
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        // Setting valid parameters (two tensors of different lengths)
        let params = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        assert!(optimizer.set_global_parameters(&params).is_ok());
        assert!(optimizer.get_global_parameters().is_some());

        // Verify stored parameters match
        let stored = optimizer
            .get_global_parameters()
            .expect("should have params");
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].len(), 3);
        assert_eq!(stored[1].len(), 2);

        // Setting empty parameters should fail
        let empty: Vec<Array1<f64>> = vec![];
        assert!(optimizer.set_global_parameters(&empty).is_err());
    }

    #[test]
    fn test_local_update_with_proximal_term() {
        let config: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
            .mu(0.1)
            .build()
            .expect("build failed");
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        // Set global parameters
        let global = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        optimizer
            .set_global_parameters(&global)
            .expect("set params failed");

        // Local parameters diverge from global (each element is +0.5)
        let local = vec![Array1::from_vec(vec![1.5, 2.5, 3.5])];
        let grads = vec![Array1::from_vec(vec![0.1, 0.2, 0.3])];
        let lr: f64 = 0.01;

        let updated = optimizer
            .local_update(&local, &grads, lr)
            .expect("local_update failed");

        // Expected: w_new = w - lr * (grad + mu * (w - w_global))
        // For element 0: 1.5 - 0.01 * (0.1 + 0.1 * (1.5 - 1.0))
        //              = 1.5 - 0.01 * (0.1 + 0.05)
        //              = 1.5 - 0.0015 = 1.4985
        assert!((updated[0][0] - 1.4985).abs() < 1e-10);

        // For element 1: 2.5 - 0.01 * (0.2 + 0.1 * (2.5 - 2.0))
        //              = 2.5 - 0.01 * (0.2 + 0.05)
        //              = 2.5 - 0.0025 = 2.4975
        assert!((updated[0][1] - 2.4975).abs() < 1e-10);

        // For element 2: 3.5 - 0.01 * (0.3 + 0.1 * (3.5 - 3.0))
        //              = 3.5 - 0.01 * (0.3 + 0.05)
        //              = 3.5 - 0.0035 = 3.4965
        assert!((updated[0][2] - 3.4965).abs() < 1e-10);
    }

    #[test]
    fn test_proximal_gradient_computation() {
        let config: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
            .mu(0.5)
            .build()
            .expect("build failed");
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        let global = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        optimizer
            .set_global_parameters(&global)
            .expect("set params failed");

        // Local parameters are exactly double the global ones
        let local = vec![Array1::from_vec(vec![2.0, 4.0, 6.0])];
        let prox = optimizer
            .compute_proximal_gradient(&local)
            .expect("proximal gradient failed");

        // Expected: mu * (local - global) = 0.5 * (local - global)
        // [0.5*(2-1), 0.5*(4-2), 0.5*(6-3)] = [0.5, 1.0, 1.5]
        assert!((prox[0][0] - 0.5).abs() < 1e-10);
        assert!((prox[0][1] - 1.0).abs() < 1e-10);
        assert!((prox[0][2] - 1.5).abs() < 1e-10);

        // Test error when global parameters not set
        let config2: FedProxConfig<f64> = FedProxConfig::new(2);
        let optimizer2: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config2);
        assert!(optimizer2.compute_proximal_gradient(&local).is_err());

        // Test dimension mismatch (two tensors supplied, global has one)
        let mismatched = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0]),
        ];
        assert!(optimizer.compute_proximal_gradient(&mismatched).is_err());
    }

    #[test]
    fn test_aggregate_updates_weighted() {
        let config: FedProxConfig<f64> = FedProxConfig::new(3);
        let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

        let global = vec![Array1::from_vec(vec![0.0, 0.0])];
        optimizer
            .set_global_parameters(&global)
            .expect("set params failed");

        // Client 0: params=[2.0, 4.0], data_size=100
        optimizer
            .submit_client_update(0, &[Array1::from_vec(vec![2.0, 4.0])], 100)
            .expect("submit failed");

        // Client 1: params=[4.0, 6.0], data_size=300
        optimizer
            .submit_client_update(1, &[Array1::from_vec(vec![4.0, 6.0])], 300)
            .expect("submit failed");

        assert_eq!(optimizer.get_pending_updates_count(), 2);

        let aggregated = optimizer.aggregate_updates().expect("aggregate failed");

        // Weighted average: (100/400)*[2,4] + (300/400)*[4,6]
        //                 = 0.25*[2,4] + 0.75*[4,6]
        //                 = [0.5+3.0, 1.0+4.5]
        //                 = [3.5, 5.5]
        assert!((aggregated[0][0] - 3.5).abs() < 1e-10);
        assert!((aggregated[0][1] - 5.5).abs() < 1e-10);

        // Round count should have incremented
        assert_eq!(optimizer.get_round_count(), 1);
        // Client updates should be cleared
        assert_eq!(optimizer.get_pending_updates_count(), 0);

        // Aggregating with no updates should fail
        assert!(optimizer.aggregate_updates().is_err());
    }

    #[test]
    fn test_fedprox_mu_zero_is_fedavg() {
        // When mu=0, FedProx should behave identically to FedAvg
        // (the proximal gradient becomes zero)
        let config_prox: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
            .mu(0.0)
            .build()
            .expect("build failed");
        let mut optimizer_prox: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config_prox);

        let global = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        optimizer_prox
            .set_global_parameters(&global)
            .expect("set params failed");

        // Verify proximal gradient is zero when mu=0
        let local = vec![Array1::from_vec(vec![5.0, 10.0, 15.0])];
        let prox_grad = optimizer_prox
            .compute_proximal_gradient(&local)
            .expect("proximal gradient failed");
        for val in prox_grad[0].iter() {
            assert!(
                val.abs() < 1e-15,
                "Proximal gradient should be zero when mu=0"
            );
        }

        // Verify local update with mu=0 is plain SGD
        let grads = vec![Array1::from_vec(vec![0.1, 0.2, 0.3])];
        let lr: f64 = 0.01;
        let updated = optimizer_prox
            .local_update(&local, &grads, lr)
            .expect("local_update failed");

        // Plain SGD: w_new = w - lr * grad
        // [5.0 - 0.01*0.1, 10.0 - 0.01*0.2, 15.0 - 0.01*0.3]
        // = [4.999, 9.998, 14.997]
        assert!((updated[0][0] - 4.999).abs() < 1e-10);
        assert!((updated[0][1] - 9.998).abs() < 1e-10);
        assert!((updated[0][2] - 14.997).abs() < 1e-10);

        // Verify aggregation with mu=0 gives same weighted average as FedAvg
        optimizer_prox
            .submit_client_update(0, &[Array1::from_vec(vec![2.0, 3.0, 4.0])], 200)
            .expect("submit failed");
        optimizer_prox
            .submit_client_update(1, &[Array1::from_vec(vec![4.0, 5.0, 6.0])], 200)
            .expect("submit failed");

        let agg = optimizer_prox
            .aggregate_updates()
            .expect("aggregate failed");

        // Equal weights (200 each) => arithmetic mean
        // [(2+4)/2, (3+5)/2, (4+6)/2] = [3.0, 4.0, 5.0]
        assert!((agg[0][0] - 3.0).abs() < 1e-10);
        assert!((agg[0][1] - 4.0).abs() < 1e-10);
        assert!((agg[0][2] - 5.0).abs() < 1e-10);
    }
}
659}