use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Hyperparameters for a FedProx optimization run.
///
/// All fields are public and are not checked at construction time; call
/// [`FedProxConfig::validate`] to enforce the ranges documented on each field.
#[derive(Debug, Clone)]
pub struct FedProxConfig<A: Float> {
/// Proximal term coefficient; `validate` rejects negative values.
pub mu: A,
/// Number of local epochs; `validate` requires at least 1.
pub local_epochs: usize,
/// Client participation rate; `validate` requires a value in `(0.0, 1.0]`.
pub participation_rate: A,
/// Number of clients; `validate` requires at least 1.
pub num_clients: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static> FedProxConfig<A> {
    /// Creates a configuration for `num_clients` clients with default hyperparameters.
    ///
    /// Defaults: `mu = 0.01` (zero if `0.01` cannot be represented in `A`),
    /// `local_epochs = 5`, `participation_rate = 1`.
    ///
    /// The returned value is not validated; in particular `num_clients == 0` is
    /// accepted here and only rejected by [`validate`](Self::validate).
    pub fn new(num_clients: usize) -> Self {
        Self {
            mu: A::from(0.01).unwrap_or_else(A::zero),
            local_epochs: 5,
            participation_rate: A::one(),
            num_clients,
        }
    }

    /// Checks that every field lies in its permitted range.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when
    /// * `mu` is negative, NaN or infinite,
    /// * `local_epochs` is zero,
    /// * `participation_rate` is NaN or outside `(0.0, 1.0]`,
    /// * `num_clients` is zero.
    pub fn validate(&self) -> Result<()> {
        // Negative check first so that inputs rejected before (including -inf)
        // keep reporting the same message.
        if self.mu < A::zero() {
            return Err(OptimError::InvalidConfig(
                "Proximal term coefficient mu must be non-negative".to_string(),
            ));
        }
        // NaN compares false against everything, so it slips past the check
        // above; +inf would turn every proximal gradient into inf/NaN.
        if !self.mu.is_finite() {
            return Err(OptimError::InvalidConfig(
                "Proximal term coefficient mu must be finite".to_string(),
            ));
        }
        if self.local_epochs == 0 {
            return Err(OptimError::InvalidConfig(
                "local_epochs must be at least 1".to_string(),
            ));
        }
        // Phrased as a positive range test so that NaN (for which both
        // comparisons are false) is rejected as well.
        let rate_in_range =
            self.participation_rate > A::zero() && self.participation_rate <= A::one();
        if !rate_in_range {
            return Err(OptimError::InvalidConfig(
                "participation_rate must be in (0.0, 1.0]".to_string(),
            ));
        }
        if self.num_clients == 0 {
            return Err(OptimError::InvalidConfig(
                "num_clients must be at least 1".to_string(),
            ));
        }
        Ok(())
    }
}
/// Builder for [`FedProxConfig`].
///
/// Optional settings left unset are replaced by defaults when `build` is called.
#[derive(Debug)]
pub struct FedProxConfigBuilder<A: Float> {
// `None` means "use the default chosen by `build`".
mu: Option<A>,
local_epochs: Option<usize>,
participation_rate: Option<A>,
// Always supplied up front through `new`.
num_clients: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static> FedProxConfigBuilder<A> {
pub fn new(num_clients: usize) -> Self {
Self {
mu: None,
local_epochs: None,
participation_rate: None,
num_clients,
}
}
pub fn mu(mut self, mu: A) -> Self {
self.mu = Some(mu);
self
}
pub fn local_epochs(mut self, epochs: usize) -> Self {
self.local_epochs = Some(epochs);
self
}
pub fn participation_rate(mut self, rate: A) -> Self {
self.participation_rate = Some(rate);
self
}
pub fn build(self) -> Result<FedProxConfig<A>> {
let config = FedProxConfig {
mu: self
.mu
.unwrap_or_else(|| A::from(0.01).unwrap_or_else(|| A::zero())),
local_epochs: self.local_epochs.unwrap_or(5),
participation_rate: self.participation_rate.unwrap_or_else(|| A::one()),
num_clients: self.num_clients,
};
config.validate()?;
Ok(config)
}
}
/// One client's contribution to an aggregation round.
#[derive(Debug, Clone)]
pub struct ClientUpdate<A: Float, D: Dimension> {
/// Caller-supplied identifier; used in this file only to label error messages.
pub client_id: usize,
/// The client's parameter arrays, one entry per parameter array of the model.
pub parameters: Vec<Array<A, D>>,
/// Size of the client's data; aggregation weights each client by
/// `data_size / total data_size`.
pub data_size: usize,
}
/// State for a FedProx optimizer: configuration, current global parameters,
/// client updates awaiting aggregation, and a counter of completed rounds.
#[derive(Debug)]
pub struct FedProxOptimizer<
A: Float + ScalarOperand + Debug + Send + Sync + 'static,
D: Dimension + Send + Sync + Clone,
> {
// Hyperparameters; of these, only `mu` is read by the update methods in this file.
config: FedProxConfig<A>,
// `None` until `set_global_parameters` or a successful `aggregate_updates`.
global_parameters: Option<Vec<Array<A, D>>>,
// Updates submitted since the last aggregation or global-parameter reset.
client_updates: Vec<ClientUpdate<A, D>>,
// Incremented once per successful `aggregate_updates`.
round_count: usize,
}
impl<
A: Float + ScalarOperand + Debug + Send + Sync + 'static,
D: Dimension + Send + Sync + Clone,
> FedProxOptimizer<A, D>
{
/// Creates an optimizer with no global parameters, no pending client updates
/// and a round count of zero.
///
/// `config` is stored as given; this constructor does not call `validate` on it.
pub fn new(config: FedProxConfig<A>) -> Self {
Self {
config,
global_parameters: None,
client_updates: Vec::new(),
round_count: 0,
}
}
/// Convenience entry point returning a [`FedProxConfigBuilder`] for `num_clients` clients.
///
/// Note that this builds a configuration, not an optimizer; pass the built
/// configuration to `new`.
pub fn builder(num_clients: usize) -> FedProxConfigBuilder<A> {
FedProxConfigBuilder::new(num_clients)
}
/// Replaces the global parameters with a copy of `params` and discards any
/// client updates that are still pending.
///
/// # Errors
///
/// Returns `OptimError::InvalidParameter` if `params` is empty; the optimizer
/// state is left untouched in that case.
pub fn set_global_parameters(&mut self, params: &[Array<A, D>]) -> Result<()> {
    match params {
        [] => Err(OptimError::InvalidParameter(
            "Global parameters cannot be empty".to_string(),
        )),
        tensors => {
            // Pending updates are dropped together with the old global parameters.
            self.client_updates.clear();
            self.global_parameters = Some(tensors.to_vec());
            Ok(())
        }
    }
}
/// Performs one local step and returns the updated parameters.
///
/// For every element: `w_new = w - lr * (g + mu * (w - w_global))`, where the
/// `mu * (w - w_global)` part comes from `compute_proximal_gradient`.
/// The inputs are not modified.
///
/// # Errors
///
/// * `OptimError::DimensionMismatch` if `params` and `gradients` differ in
///   length, or a parameter and its gradient differ in shape.
/// * Any error from `compute_proximal_gradient` (global parameters unset, or
///   `params` inconsistent with them).
pub fn local_update(
    &self,
    params: &[Array<A, D>],
    gradients: &[Array<A, D>],
    lr: A,
) -> Result<Vec<Array<A, D>>> {
    if params.len() != gradients.len() {
        return Err(OptimError::DimensionMismatch(format!(
            "Parameters length ({}) does not match gradients length ({})",
            params.len(),
            gradients.len()
        )));
    }

    // One proximal term per parameter array, already shape-checked against `params`.
    let proximal_terms = self.compute_proximal_gradient(params)?;

    params
        .iter()
        .zip(gradients)
        .zip(&proximal_terms)
        .enumerate()
        .map(|(index, ((weights, gradient), proximal))| {
            if weights.shape() != gradient.shape() {
                return Err(OptimError::DimensionMismatch(format!(
                    "Parameter shape {:?} does not match gradient shape {:?} at index {}",
                    weights.shape(),
                    gradient.shape(),
                    index
                )));
            }
            let mut stepped = weights.clone();
            Zip::from(&mut stepped)
                .and(gradient)
                .and(proximal)
                .for_each(|w, &g, &p| {
                    *w = *w - lr * (g + p);
                });
            Ok(stepped)
        })
        .collect()
}
/// Computes the proximal gradient `mu * (local - global)` for every parameter array.
///
/// # Errors
///
/// * `OptimError::InvalidState` if no global parameters have been set.
/// * `OptimError::DimensionMismatch` if `params` differs from the global
///   parameters in length, or any pair of arrays differs in shape.
pub fn compute_proximal_gradient(&self, params: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
    let global = match self.global_parameters.as_deref() {
        Some(reference) => reference,
        None => {
            return Err(OptimError::InvalidState(
                "Global parameters not set. Call set_global_parameters first.".to_string(),
            ))
        }
    };
    if global.len() != params.len() {
        return Err(OptimError::DimensionMismatch(format!(
            "Local parameters length ({}) does not match global parameters length ({})",
            params.len(),
            global.len()
        )));
    }

    let coefficient = self.config.mu;
    params
        .iter()
        .zip(global)
        .enumerate()
        .map(|(index, (local, anchor))| {
            if local.shape() != anchor.shape() {
                return Err(OptimError::DimensionMismatch(format!(
                    "Local param shape {:?} != global param shape {:?} at index {}",
                    local.shape(),
                    anchor.shape(),
                    index
                )));
            }
            // Start from a copy of the local values and overwrite in place.
            let mut term = local.clone();
            Zip::from(&mut term).and(anchor).for_each(|value, &reference| {
                *value = coefficient * (*value - reference);
            });
            Ok(term)
        })
        .collect()
}
/// Queues a copy of a client's parameters for the next aggregation.
///
/// When global parameters are set, the update must match them in length and
/// in the shape of every array; when they are not set, no such comparison is
/// made. `client_id` is not checked for uniqueness or against `num_clients`.
///
/// # Errors
///
/// * `OptimError::InvalidParameter` if `params` is empty or `data_size` is zero.
/// * `OptimError::DimensionMismatch` if the update is inconsistent with the
///   current global parameters.
pub fn submit_client_update(
    &mut self,
    client_id: usize,
    params: &[Array<A, D>],
    data_size: usize,
) -> Result<()> {
    if params.is_empty() {
        return Err(OptimError::InvalidParameter(
            "Client parameters cannot be empty".to_string(),
        ));
    }
    if data_size == 0 {
        return Err(OptimError::InvalidParameter(
            "Client data_size must be positive".to_string(),
        ));
    }

    if let Some(global) = self.global_parameters.as_deref() {
        if global.len() != params.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Client {} parameter count ({}) does not match global ({})",
                client_id,
                params.len(),
                global.len()
            )));
        }
        // Report the first array whose shape disagrees with the global model.
        let first_mismatch = params
            .iter()
            .zip(global)
            .position(|(client_param, global_param)| client_param.shape() != global_param.shape());
        if let Some(index) = first_mismatch {
            return Err(OptimError::DimensionMismatch(format!(
                "Client {} param shape {:?} != global shape {:?} at index {}",
                client_id,
                params[index].shape(),
                global[index].shape(),
                index
            )));
        }
    }

    let update = ClientUpdate {
        client_id,
        data_size,
        parameters: params.to_vec(),
    };
    self.client_updates.push(update);
    Ok(())
}
/// Combines all pending client updates into new global parameters.
///
/// Each client's parameters are weighted by `data_size / total data_size`
/// and summed. On success the result becomes the new global parameters, the
/// pending updates are cleared, the round counter is incremented, and a copy
/// of the aggregated parameters is returned. On error no state is modified.
///
/// # Errors
///
/// * `OptimError::InvalidState` if there are no pending updates or the total
///   data size is zero.
/// * `OptimError::ComputationError` if the total data size overflows `usize`
///   or a data size cannot be converted to `A`.
/// * `OptimError::DimensionMismatch` if a client's parameter count or any
///   array shape differs from the first pending update's.
pub fn aggregate_updates(&mut self) -> Result<Vec<Array<A, D>>> {
    // The first pending update defines the expected layout for this round.
    let reference = match self.client_updates.first() {
        Some(first) => &first.parameters,
        None => {
            return Err(OptimError::InvalidState(
                "No client updates to aggregate".to_string(),
            ))
        }
    };

    // Checked sum: a plain `sum()` panics in debug builds and silently wraps
    // in release builds when the data sizes overflow `usize`.
    let total_data = self
        .client_updates
        .iter()
        .try_fold(0usize, |acc, update| acc.checked_add(update.data_size))
        .ok_or_else(|| {
            OptimError::ComputationError(
                "Total data size across clients overflows usize".to_string(),
            )
        })?;
    if total_data == 0 {
        return Err(OptimError::InvalidState(
            "Total data size across clients is zero".to_string(),
        ));
    }
    let total_data_a = A::from(total_data).ok_or_else(|| {
        OptimError::ComputationError("Cannot convert total data size to float".to_string())
    })?;

    let mut aggregated: Vec<Array<A, D>> = reference
        .iter()
        .map(|p| Array::zeros(p.raw_dim()))
        .collect();

    for update in &self.client_updates {
        if update.parameters.len() != reference.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Client {} has {} parameters, expected {}",
                update.client_id,
                update.parameters.len(),
                reference.len()
            )));
        }
        // Updates submitted while no global parameters were set were never
        // shape-checked; `Zip::and` below would panic on a mismatch, so turn
        // that into a recoverable error here.
        for (index, (client_param, expected)) in
            update.parameters.iter().zip(reference.iter()).enumerate()
        {
            if client_param.shape() != expected.shape() {
                return Err(OptimError::DimensionMismatch(format!(
                    "Client {} param shape {:?} != expected shape {:?} at index {}",
                    update.client_id,
                    client_param.shape(),
                    expected.shape(),
                    index
                )));
            }
        }

        let weight = A::from(update.data_size).ok_or_else(|| {
            OptimError::ComputationError("Cannot convert client data size to float".to_string())
        })? / total_data_a;
        for (agg, client_param) in aggregated.iter_mut().zip(update.parameters.iter()) {
            Zip::from(agg).and(client_param).for_each(|a, &c| {
                *a = *a + weight * c;
            });
        }
    }

    // Commit only after every update has been validated and accumulated.
    self.global_parameters = Some(aggregated.clone());
    self.client_updates.clear();
    self.round_count += 1;
    Ok(aggregated)
}
/// Returns the number of successful `aggregate_updates` calls so far.
pub fn get_round_count(&self) -> usize {
self.round_count
}
/// Returns a reference to the configuration this optimizer was created with.
pub fn get_config(&self) -> &FedProxConfig<A> {
&self.config
}
/// Returns the current global parameters, or `None` if they have not been set
/// by `set_global_parameters` or produced by `aggregate_updates` yet.
pub fn get_global_parameters(&self) -> Option<&Vec<Array<A, D>>> {
self.global_parameters.as_ref()
}
/// Returns how many client updates are queued for the next aggregation.
pub fn get_pending_updates_count(&self) -> usize {
self.client_updates.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::{Array1, Ix1};
#[test]
fn test_fedprox_config_builder() {
    // Shared tolerance for every floating-point comparison below.
    let approx = |actual: f64, expected: f64| (actual - expected).abs() < 1e-10;

    // Nothing overridden: the builder must fall back to the defaults.
    let defaults: FedProxConfig<f64> =
        FedProxConfigBuilder::new(10).build().expect("build failed");
    assert_eq!(defaults.num_clients, 10);
    assert!(approx(defaults.mu, 0.01));
    assert_eq!(defaults.local_epochs, 5);
    assert!(approx(defaults.participation_rate, 1.0));

    // Every optional setting overridden.
    let custom: FedProxConfig<f64> = FedProxConfigBuilder::new(20)
        .mu(0.1)
        .local_epochs(10)
        .participation_rate(0.5)
        .build()
        .expect("build failed");
    assert_eq!(custom.num_clients, 20);
    assert!(approx(custom.mu, 0.1));
    assert_eq!(custom.local_epochs, 10);
    assert!(approx(custom.participation_rate, 0.5));

    // Each of these carries exactly one out-of-range setting and must be rejected.
    let invalid_builders: Vec<FedProxConfigBuilder<f64>> = vec![
        FedProxConfigBuilder::new(5).mu(-0.1),
        FedProxConfigBuilder::new(5).participation_rate(0.0),
        FedProxConfigBuilder::new(5).participation_rate(1.5),
        FedProxConfigBuilder::new(5).local_epochs(0),
    ];
    for builder in invalid_builders {
        assert!(builder.build().is_err());
    }
}
#[test]
fn test_fedprox_set_global_parameters() {
    let config: FedProxConfig<f64> = FedProxConfig::new(3);
    let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);
    let params = vec![
        Array1::from_vec(vec![1.0, 2.0, 3.0]),
        Array1::from_vec(vec![4.0, 5.0]),
    ];
    // Fixed: the argument had been corrupted to `¶ms` (an entity-decoded
    // `&params`), which is not valid Rust.
    assert!(optimizer.set_global_parameters(&params).is_ok());
    assert!(optimizer.get_global_parameters().is_some());

    // The stored copy keeps both arrays and their lengths.
    let stored = optimizer
        .get_global_parameters()
        .expect("should have params");
    assert_eq!(stored.len(), 2);
    assert_eq!(stored[0].len(), 3);
    assert_eq!(stored[1].len(), 2);

    // An empty parameter list is rejected.
    let empty: Vec<Array1<f64>> = vec![];
    assert!(optimizer.set_global_parameters(&empty).is_err());
}
#[test]
fn test_local_update_with_proximal_term() {
    let config: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
        .mu(0.1)
        .build()
        .expect("build failed");
    let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

    let global_model = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
    optimizer
        .set_global_parameters(&global_model)
        .expect("set params failed");

    let local_model = vec![Array1::from_vec(vec![1.5, 2.5, 3.5])];
    let gradients = vec![Array1::from_vec(vec![0.1, 0.2, 0.3])];
    let step: f64 = 0.01;
    let updated = optimizer
        .local_update(&local_model, &gradients, step)
        .expect("local_update failed");

    // Proximal term is 0.1 * (local - global) = 0.05 for every element, so
    // e.g. element 0 becomes 1.5 - 0.01 * (0.1 + 0.05) = 1.4985.
    let expected = [1.4985, 2.4975, 3.4965];
    for (index, &want) in expected.iter().enumerate() {
        assert!((updated[0][index] - want).abs() < 1e-10);
    }
}
#[test]
fn test_proximal_gradient_computation() {
    let config: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
        .mu(0.5)
        .build()
        .expect("build failed");
    let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(config);

    let global_model = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
    optimizer
        .set_global_parameters(&global_model)
        .expect("set params failed");

    // 0.5 * ([2, 4, 6] - [1, 2, 3]) = [0.5, 1.0, 1.5]
    let local_model = vec![Array1::from_vec(vec![2.0, 4.0, 6.0])];
    let proximal = optimizer
        .compute_proximal_gradient(&local_model)
        .expect("proximal gradient failed");
    let expected = [0.5, 1.0, 1.5];
    for (index, &want) in expected.iter().enumerate() {
        assert!((proximal[0][index] - want).abs() < 1e-10);
    }

    // Without global parameters the computation must fail.
    let fresh: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(FedProxConfig::new(2));
    assert!(fresh.compute_proximal_gradient(&local_model).is_err());

    // Two local arrays against a single global array: length mismatch.
    let wrong_length = vec![
        Array1::from_vec(vec![1.0, 2.0]),
        Array1::from_vec(vec![3.0]),
    ];
    assert!(optimizer.compute_proximal_gradient(&wrong_length).is_err());
}
#[test]
fn test_aggregate_updates_weighted() {
    let mut server: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(FedProxConfig::new(3));
    let initial = vec![Array1::from_vec(vec![0.0, 0.0])];
    server
        .set_global_parameters(&initial)
        .expect("set params failed");

    // Two clients holding 100 and 300 samples: weights 0.25 and 0.75.
    let first_client = [Array1::from_vec(vec![2.0, 4.0])];
    let second_client = [Array1::from_vec(vec![4.0, 6.0])];
    server
        .submit_client_update(0, &first_client, 100)
        .expect("submit failed");
    server
        .submit_client_update(1, &second_client, 300)
        .expect("submit failed");
    assert_eq!(server.get_pending_updates_count(), 2);

    // 0.25 * [2, 4] + 0.75 * [4, 6] = [3.5, 5.5]
    let merged = server.aggregate_updates().expect("aggregate failed");
    let expected = [3.5, 5.5];
    for (index, &want) in expected.iter().enumerate() {
        assert!((merged[0][index] - want).abs() < 1e-10);
    }

    // Aggregation advances the round and empties the queue, so a second
    // aggregation has nothing to work on.
    assert_eq!(server.get_round_count(), 1);
    assert_eq!(server.get_pending_updates_count(), 0);
    assert!(server.aggregate_updates().is_err());
}
#[test]
fn test_fedprox_mu_zero_is_fedavg() {
    let zero_mu: FedProxConfig<f64> = FedProxConfigBuilder::new(2)
        .mu(0.0)
        .build()
        .expect("build failed");
    let mut optimizer: FedProxOptimizer<f64, Ix1> = FedProxOptimizer::new(zero_mu);

    let global_model = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
    optimizer
        .set_global_parameters(&global_model)
        .expect("set params failed");

    // With mu = 0 the proximal term vanishes however far local is from global.
    let local_model = vec![Array1::from_vec(vec![5.0, 10.0, 15.0])];
    let proximal = optimizer
        .compute_proximal_gradient(&local_model)
        .expect("proximal gradient failed");
    assert!(
        proximal[0].iter().all(|value| value.abs() < 1e-15),
        "Proximal gradient should be zero when mu=0"
    );

    // The local step therefore reduces to plain gradient descent: w - lr * g.
    let gradients = vec![Array1::from_vec(vec![0.1, 0.2, 0.3])];
    let step: f64 = 0.01;
    let stepped = optimizer
        .local_update(&local_model, &gradients, step)
        .expect("local_update failed");
    let expected_step = [4.999, 9.998, 14.997];
    for (index, &want) in expected_step.iter().enumerate() {
        assert!((stepped[0][index] - want).abs() < 1e-10);
    }

    // Equal data sizes give a plain average of the two clients.
    let first_client = [Array1::from_vec(vec![2.0, 3.0, 4.0])];
    let second_client = [Array1::from_vec(vec![4.0, 5.0, 6.0])];
    optimizer
        .submit_client_update(0, &first_client, 200)
        .expect("submit failed");
    optimizer
        .submit_client_update(1, &second_client, 200)
        .expect("submit failed");
    let averaged = optimizer
        .aggregate_updates()
        .expect("aggregate failed");
    let expected_average = [3.0, 4.0, 5.0];
    for (index, &want) in expected_average.iter().enumerate() {
        assert!((averaged[0][index] - want).abs() < 1e-10);
    }
}
}