use core::fmt;
/// Opaque identifier of a federated-learning client.
///
/// Wraps an owned `String`; equality and hashing compare the underlying text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientId(String);

impl ClientId {
    /// Creates an identifier from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Borrows the identifier text.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for ClientId {
    /// Writes the raw identifier text.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision flags
    /// (e.g. `{:>12}`) are honored exactly as they would be for a `&str`; a nested
    /// `write!(f, "{}", ..)` would silently drop them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}
/// A participant in federated training together with its resource profile.
///
/// Construct with [`FederatedClient::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct FederatedClient {
    id: ClientId,
    /// Local sample count; [`FederatedClient::weight`] returns it as `f64`.
    num_samples: usize,
    /// Availability score, clamped to `[0.0, 1.0]` by [`FederatedClient::new`].
    availability: f64,
    /// Relative compute capacity; defaults to `1.0` and is stored unvalidated.
    compute_capacity: f64,
    /// Bandwidth in Mbps; defaults to `10.0`.
    bandwidth_mbps: f64,
    /// Data-distribution label; defaults to `DataDistribution::Unknown`.
    data_distribution: DataDistribution,
    /// Privacy budget (epsilon), if one was set.
    privacy_budget: Option<f64>,
}

impl FederatedClient {
    /// Creates a client with default resources: compute `1.0`, bandwidth `10.0`,
    /// `DataDistribution::Unknown`, and no privacy budget.
    ///
    /// `availability` is clamped into `[0.0, 1.0]`; a NaN input becomes `0.0`.
    pub fn new(id: impl Into<String>, num_samples: usize, availability: f64) -> Self {
        // `f64::max`/`f64::min` return the non-NaN operand, which is what maps NaN to 0.0;
        // `f64::clamp` would propagate NaN instead, so it is deliberately not used here.
        let availability = availability.max(0.0).min(1.0);
        FederatedClient {
            id: ClientId::new(id),
            num_samples,
            availability,
            compute_capacity: 1.0,
            bandwidth_mbps: 10.0,
            data_distribution: DataDistribution::Unknown,
            privacy_budget: None,
        }
    }

    /// Builder: sets the relative compute capacity (stored as given).
    pub fn with_compute_capacity(self, capacity: f64) -> Self {
        FederatedClient {
            compute_capacity: capacity,
            ..self
        }
    }

    /// Builder: sets the bandwidth in Mbps (stored as given).
    pub fn with_bandwidth(self, bandwidth_mbps: f64) -> Self {
        FederatedClient {
            bandwidth_mbps,
            ..self
        }
    }

    /// Builder: sets the data-distribution label.
    pub fn with_data_distribution(self, distribution: DataDistribution) -> Self {
        FederatedClient {
            data_distribution: distribution,
            ..self
        }
    }

    /// Builder: attaches a privacy budget `epsilon` (stored as given).
    pub fn with_privacy_budget(self, epsilon: f64) -> Self {
        FederatedClient {
            privacy_budget: Some(epsilon),
            ..self
        }
    }

    /// The client's identifier.
    pub fn id(&self) -> &ClientId {
        &self.id
    }

    /// Number of local samples.
    pub fn num_samples(&self) -> usize {
        self.num_samples
    }

    /// Availability score in `[0.0, 1.0]`.
    pub fn availability(&self) -> f64 {
        self.availability
    }

    /// Relative compute capacity.
    pub fn compute_capacity(&self) -> f64 {
        self.compute_capacity
    }

    /// Bandwidth in Mbps.
    pub fn bandwidth_mbps(&self) -> f64 {
        self.bandwidth_mbps
    }

    /// Data-distribution label.
    pub fn data_distribution(&self) -> &DataDistribution {
        &self.data_distribution
    }

    /// Privacy budget, or `None` when none was configured.
    pub fn privacy_budget(&self) -> Option<f64> {
        self.privacy_budget
    }

    /// Aggregation weight: the local sample count converted to `f64`.
    pub fn weight(&self) -> f64 {
        self.num_samples as f64
    }
}
/// Coarse label describing how a client's local data is distributed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DataDistribution {
    /// Independent and identically distributed data.
    IID,
    /// Skewed label distribution; the scale of `skew_factor` is defined by the caller.
    LabelSkew { skew_factor: f64 },
    /// Skewed feature distribution; the scale of `skew_factor` is defined by the caller.
    FeatureSkew { skew_factor: f64 },
    /// Skew in the amount of data held.
    QuantitySkew,
    /// Distribution not characterised.
    Unknown,
}
/// Rule used to combine client updates into the global model.
///
/// In this file the value is only stored and reported by `FederatedCoordinator`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    FedAvg,
    FedProx,
    FedAdaptive,
    SecureAggregation,
    WeightedBySize,
    WeightedByPerformance,
}

impl fmt::Display for AggregationStrategy {
    /// Writes the variant name.
    ///
    /// Goes through [`fmt::Formatter::pad`] so width, fill, alignment and precision
    /// flags (e.g. `{:<12}` in a table) apply just as they would for a `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            AggregationStrategy::FedAvg => "FedAvg",
            AggregationStrategy::FedProx => "FedProx",
            AggregationStrategy::FedAdaptive => "FedAdaptive",
            AggregationStrategy::SecureAggregation => "SecureAggregation",
            AggregationStrategy::WeightedBySize => "WeightedBySize",
            AggregationStrategy::WeightedByPerformance => "WeightedByPerformance",
        };
        f.pad(name)
    }
}
/// Result reported by a single client after local training in a given round.
#[derive(Debug, Clone)]
pub struct ClientUpdate {
    client_id: ClientId,
    round: u64,
    /// Number of local steps the client performed.
    num_steps: usize,
    /// Loss reported by the client.
    loss: f64,
    /// Accuracy reported by the client, if any.
    accuracy: Option<f64>,
    /// Free-form key/value annotations in insertion order; duplicate keys are allowed.
    metadata: Vec<(String, String)>,
}

impl ClientUpdate {
    /// Creates an update with no accuracy and no metadata.
    pub fn new(client_id: ClientId, round: u64, num_steps: usize, loss: f64) -> Self {
        ClientUpdate {
            metadata: Vec::new(),
            accuracy: None,
            loss,
            num_steps,
            round,
            client_id,
        }
    }

    /// Builder: records the reported accuracy.
    pub fn with_accuracy(self, accuracy: f64) -> Self {
        ClientUpdate {
            accuracy: Some(accuracy),
            ..self
        }
    }

    /// Appends one metadata entry; earlier entries with the same key are kept.
    pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let entry = (key.into(), value.into());
        self.metadata.push(entry);
    }

    /// Identifier of the reporting client.
    pub fn client_id(&self) -> &ClientId {
        &self.client_id
    }

    /// Round this update belongs to.
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Number of local steps performed.
    pub fn num_steps(&self) -> usize {
        self.num_steps
    }

    /// Reported loss.
    pub fn loss(&self) -> f64 {
        self.loss
    }

    /// Reported accuracy, or `None` if none was attached.
    pub fn accuracy(&self) -> Option<f64> {
        self.accuracy
    }

    /// All metadata entries in insertion order.
    pub fn metadata(&self) -> &[(String, String)] {
        self.metadata.as_slice()
    }
}
/// Policy that `ClientSelector` applies when picking round participants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientSelectionStrategy {
    /// Intended as random sampling.
    /// NOTE(review): `ClientSelector::select` currently takes clients in input order.
    Random,
    /// Prefer clients with the highest availability.
    ByAvailability,
    /// Prefer clients with the most local samples.
    ByDataSize,
    /// Prefer clients with the highest compute capacity.
    ByComputeCapacity,
    /// Power-of-choice selection over `choices` candidates.
    /// NOTE(review): `ClientSelector::select` does not read `choices` yet.
    PowerOfChoice { choices: usize },
    /// Select every client, regardless of the requested count.
    All,
}
#[derive(Debug, Clone)]
pub struct ClientSelector {
strategy: ClientSelectionStrategy,
}
impl ClientSelector {
pub fn new(strategy: ClientSelectionStrategy) -> Self {
Self { strategy }
}
pub fn select(&self, clients: &[FederatedClient], num_select: usize) -> Vec<ClientId> {
match self.strategy {
ClientSelectionStrategy::Random => {
clients
.iter()
.take(num_select.min(clients.len()))
.map(|c| c.id().clone())
.collect()
}
ClientSelectionStrategy::ByAvailability => {
let mut sorted: Vec<_> = clients.iter().collect();
sorted.sort_by(|a, b| {
b.availability()
.partial_cmp(&a.availability())
.unwrap_or(core::cmp::Ordering::Equal)
});
sorted
.iter()
.take(num_select.min(clients.len()))
.map(|c| c.id().clone())
.collect()
}
ClientSelectionStrategy::ByDataSize => {
let mut sorted: Vec<_> = clients.iter().collect();
sorted.sort_by_key(|c| core::cmp::Reverse(c.num_samples()));
sorted
.iter()
.take(num_select.min(clients.len()))
.map(|c| c.id().clone())
.collect()
}
ClientSelectionStrategy::ByComputeCapacity => {
let mut sorted: Vec<_> = clients.iter().collect();
sorted.sort_by(|a, b| {
b.compute_capacity()
.partial_cmp(&a.compute_capacity())
.unwrap_or(core::cmp::Ordering::Equal)
});
sorted
.iter()
.take(num_select.min(clients.len()))
.map(|c| c.id().clone())
.collect()
}
ClientSelectionStrategy::PowerOfChoice { choices: _ } => {
clients
.iter()
.take(num_select.min(clients.len()))
.map(|c| c.id().clone())
.collect()
}
ClientSelectionStrategy::All => clients.iter().map(|c| c.id().clone()).collect(),
}
}
pub fn strategy(&self) -> ClientSelectionStrategy {
self.strategy
}
}
/// Differential-privacy settings: an `(epsilon, delta)` budget plus clipping and noise knobs.
#[derive(Debug, Clone, Copy)]
pub struct PrivacyParameters {
    epsilon: f64,
    delta: f64,
    /// Defaults to `1.0`.
    clip_norm: f64,
    /// Defaults to `1.0`.
    noise_multiplier: f64,
}

impl PrivacyParameters {
    /// Creates parameters with the given budget; `clip_norm` and `noise_multiplier`
    /// both start at `1.0`. Values are stored as given (not validated).
    pub fn new(epsilon: f64, delta: f64) -> Self {
        Self {
            epsilon,
            delta,
            clip_norm: 1.0,
            noise_multiplier: 1.0,
        }
    }

    /// Builder: sets the clipping norm.
    pub fn with_clip_norm(self, clip_norm: f64) -> Self {
        Self { clip_norm, ..self }
    }

    /// Builder: sets the noise multiplier.
    pub fn with_noise_multiplier(self, noise_multiplier: f64) -> Self {
        Self {
            noise_multiplier,
            ..self
        }
    }

    /// The epsilon budget.
    pub fn epsilon(&self) -> f64 {
        self.epsilon
    }

    /// The delta parameter.
    pub fn delta(&self) -> f64 {
        self.delta
    }

    /// The clipping norm.
    pub fn clip_norm(&self) -> f64 {
        self.clip_norm
    }

    /// The noise multiplier.
    pub fn noise_multiplier(&self) -> f64 {
        self.noise_multiplier
    }

    /// Returns `true` when no privacy budget remains.
    ///
    /// A non-positive epsilon is exhausted. A NaN epsilon is also reported as exhausted,
    /// so a corrupted budget fails closed instead of appearing to have budget left
    /// (`NaN <= 0.0` alone is `false`).
    pub fn is_exhausted(&self) -> bool {
        self.epsilon.is_nan() || self.epsilon <= 0.0
    }
}
/// Compression setting for model updates.
///
/// In this file it is only stored and reported by `FederatedCoordinator`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionTechnique {
    /// No compression.
    None,
    /// Quantization parameterized by a bit width.
    Quantization { bits: u8 },
    /// Sparsification parameterized by `k`.
    Sparsification { k: usize },
    /// Sketch-based compression.
    Sketching,
    /// Low-rank approximation parameterized by `rank`.
    LowRank { rank: usize },
}
/// Bookkeeping for one training round: participation, averaged metrics, cost and timing.
#[derive(Debug, Clone)]
pub struct TrainingRound {
    round: u64,
    /// Number of clients the round was started with.
    num_clients: usize,
    /// Number of clients recorded as completed; not clamped to `num_clients`.
    num_completed: usize,
    avg_loss: f64,
    avg_accuracy: Option<f64>,
    /// Communication cost; units are chosen by the caller.
    communication_cost: usize,
    duration_secs: f64,
}

impl TrainingRound {
    /// Creates round `round` for `num_clients` clients with all metrics zeroed or unset.
    pub fn new(round: u64, num_clients: usize) -> Self {
        TrainingRound {
            round,
            num_clients,
            num_completed: 0,
            avg_loss: 0.0,
            avg_accuracy: None,
            communication_cost: 0,
            duration_secs: 0.0,
        }
    }

    /// Records how many clients completed the round.
    pub fn set_completed(&mut self, num_completed: usize) {
        self.num_completed = num_completed;
    }

    /// Records the average loss.
    pub fn set_avg_loss(&mut self, avg_loss: f64) {
        self.avg_loss = avg_loss;
    }

    /// Records the average accuracy.
    pub fn set_avg_accuracy(&mut self, avg_accuracy: f64) {
        self.avg_accuracy = Some(avg_accuracy);
    }

    /// Records the communication cost.
    pub fn set_communication_cost(&mut self, cost: usize) {
        self.communication_cost = cost;
    }

    /// Records the round duration in seconds.
    pub fn set_duration(&mut self, duration_secs: f64) {
        self.duration_secs = duration_secs;
    }

    /// The round number.
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Number of clients the round was started with.
    pub fn num_clients(&self) -> usize {
        self.num_clients
    }

    /// Number of clients recorded as completed.
    pub fn num_completed(&self) -> usize {
        self.num_completed
    }

    /// Average loss (`0.0` until set).
    pub fn avg_loss(&self) -> f64 {
        self.avg_loss
    }

    /// Average accuracy, if recorded.
    pub fn avg_accuracy(&self) -> Option<f64> {
        self.avg_accuracy
    }

    /// Communication cost (`0` until set).
    pub fn communication_cost(&self) -> usize {
        self.communication_cost
    }

    /// Duration in seconds (`0.0` until set).
    pub fn duration_secs(&self) -> f64 {
        self.duration_secs
    }

    /// Fraction `num_completed / num_clients`, or `0.0` for a round with no clients.
    ///
    /// Not clamped: exceeds `1.0` if more completions than clients were recorded.
    pub fn completion_rate(&self) -> f64 {
        match self.num_clients {
            0 => 0.0,
            total => self.num_completed as f64 / total as f64,
        }
    }
}
/// Summary of how evenly model accuracy is spread across clients.
///
/// All values are supplied by the caller; this type stores them and applies a threshold.
#[derive(Debug, Clone)]
pub struct FairnessMetrics {
    accuracy_variance: f64,
    min_accuracy: f64,
    max_accuracy: f64,
    /// Jain's fairness index as provided to [`FairnessMetrics::new`].
    jains_index: f64,
}

impl FairnessMetrics {
    /// Jain's-index threshold used by [`FairnessMetrics::is_fair`]; the comparison is strict.
    pub const DEFAULT_FAIRNESS_THRESHOLD: f64 = 0.8;

    /// Bundles pre-computed fairness statistics.
    pub fn new(
        accuracy_variance: f64,
        min_accuracy: f64,
        max_accuracy: f64,
        jains_index: f64,
    ) -> Self {
        Self {
            accuracy_variance,
            min_accuracy,
            max_accuracy,
            jains_index,
        }
    }

    /// Variance of per-client accuracy.
    pub fn accuracy_variance(&self) -> f64 {
        self.accuracy_variance
    }

    /// Lowest per-client accuracy.
    pub fn min_accuracy(&self) -> f64 {
        self.min_accuracy
    }

    /// Highest per-client accuracy.
    pub fn max_accuracy(&self) -> f64 {
        self.max_accuracy
    }

    /// Jain's fairness index.
    pub fn jains_index(&self) -> f64 {
        self.jains_index
    }

    /// Returns `true` when Jain's index is strictly above
    /// [`FairnessMetrics::DEFAULT_FAIRNESS_THRESHOLD`] (`0.8`).
    pub fn is_fair(&self) -> bool {
        self.is_fair_with_threshold(Self::DEFAULT_FAIRNESS_THRESHOLD)
    }

    /// Returns `true` when Jain's index is strictly greater than `threshold`.
    ///
    /// A NaN index or NaN threshold yields `false`.
    pub fn is_fair_with_threshold(&self, threshold: f64) -> bool {
        self.jains_index > threshold
    }
}
/// Server-side bookkeeping for a sequence of federated training rounds.
#[derive(Debug, Clone)]
pub struct FederatedCoordinator {
    /// Number of the most recently started round; `0` before any round.
    current_round: u64,
    strategy: AggregationStrategy,
    privacy: Option<PrivacyParameters>,
    compression: CompressionTechnique,
    /// One entry per started round, in start order.
    rounds: Vec<TrainingRound>,
}

impl FederatedCoordinator {
    /// Creates a coordinator with no privacy settings, no compression, and no rounds.
    pub fn new(strategy: AggregationStrategy) -> Self {
        FederatedCoordinator {
            strategy,
            current_round: 0,
            privacy: None,
            compression: CompressionTechnique::None,
            rounds: Vec::new(),
        }
    }

    /// Builder: attaches differential-privacy parameters.
    pub fn with_privacy(self, privacy: PrivacyParameters) -> Self {
        FederatedCoordinator {
            privacy: Some(privacy),
            ..self
        }
    }

    /// Builder: sets the compression technique.
    pub fn with_compression(self, compression: CompressionTechnique) -> Self {
        FederatedCoordinator {
            compression,
            ..self
        }
    }

    /// Starts a new round for `num_clients` clients, records it, and returns its number.
    /// Round numbers start at `1`.
    pub fn start_round(&mut self, num_clients: usize) -> u64 {
        let next = self.current_round + 1;
        self.current_round = next;
        self.rounds.push(TrainingRound::new(next, num_clients));
        next
    }

    /// Records `avg_loss` and `num_completed` on the most recently started round.
    ///
    /// Does nothing if no round has been started. Calling it again overwrites the values.
    pub fn complete_round(&mut self, avg_loss: f64, num_completed: usize) {
        if let Some(latest) = self.rounds.last_mut() {
            latest.set_avg_loss(avg_loss);
            latest.set_completed(num_completed);
        }
    }

    /// Number of the most recently started round (`0` before any round).
    pub fn current_round(&self) -> u64 {
        self.current_round
    }

    /// The configured aggregation strategy.
    pub fn strategy(&self) -> AggregationStrategy {
        self.strategy
    }

    /// The configured privacy parameters, if any.
    pub fn privacy(&self) -> Option<&PrivacyParameters> {
        self.privacy.as_ref()
    }

    /// The configured compression technique.
    pub fn compression(&self) -> CompressionTechnique {
        self.compression
    }

    /// All recorded rounds in start order.
    pub fn rounds(&self) -> &[TrainingRound] {
        &self.rounds
    }

    /// Summarises all rounds.
    ///
    /// The average completion rate is the unweighted mean of per-round rates (`0.0` with
    /// no rounds). Total communication is the sum of per-round costs.
    pub fn statistics(&self) -> CoordinatorStatistics {
        let history = &self.rounds;
        let total_rounds = history.len();
        let avg_completion_rate = match total_rounds {
            0 => 0.0,
            n => history.iter().map(TrainingRound::completion_rate).sum::<f64>() / n as f64,
        };
        let total_communication: usize = history
            .iter()
            .map(TrainingRound::communication_cost)
            .sum();
        CoordinatorStatistics {
            total_rounds,
            avg_completion_rate,
            total_communication,
        }
    }
}
/// Aggregate figures across all rounds tracked by a `FederatedCoordinator`.
#[derive(Clone, Debug)]
pub struct CoordinatorStatistics {
    /// Number of rounds started.
    pub total_rounds: usize,
    /// Unweighted mean of per-round completion rates.
    pub avg_completion_rate: f64,
    /// Sum of per-round communication costs.
    pub total_communication: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three clients with sample counts 1000/500/800 and availabilities 0.8/0.6/0.9.
    fn sample_clients() -> Vec<FederatedClient> {
        vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
            FederatedClient::new("client_3", 800, 0.9),
        ]
    }

    /// Borrows selected ids as `&str` for compact comparisons.
    fn ids(selected: &[ClientId]) -> Vec<&str> {
        selected.iter().map(ClientId::as_str).collect()
    }

    #[test]
    fn test_client_id() {
        let id = ClientId::new("client_1");
        assert_eq!(id.as_str(), "client_1");
        assert_eq!(format!("{}", id), "client_1");
    }

    #[test]
    fn test_federated_client_creation() {
        let client = FederatedClient::new("client_1", 1000, 0.8);
        assert_eq!(client.id().as_str(), "client_1");
        assert_eq!(client.num_samples(), 1000);
        assert_eq!(client.availability(), 0.8);
        assert_eq!(client.weight(), 1000.0);
    }

    #[test]
    fn test_client_default_resources() {
        let client = FederatedClient::new("client_1", 10, 0.5);
        assert_eq!(client.compute_capacity(), 1.0);
        assert_eq!(client.bandwidth_mbps(), 10.0);
        assert_eq!(client.data_distribution(), &DataDistribution::Unknown);
        assert_eq!(client.privacy_budget(), None);
    }

    #[test]
    fn test_client_with_builder() {
        let client = FederatedClient::new("client_1", 1000, 0.8)
            .with_compute_capacity(2.0)
            .with_bandwidth(50.0)
            .with_data_distribution(DataDistribution::QuantitySkew)
            .with_privacy_budget(1.0);
        assert_eq!(client.compute_capacity(), 2.0);
        assert_eq!(client.bandwidth_mbps(), 50.0);
        assert_eq!(client.data_distribution(), &DataDistribution::QuantitySkew);
        assert_eq!(client.privacy_budget(), Some(1.0));
    }

    #[test]
    fn test_data_distribution() {
        let iid = DataDistribution::IID;
        let label_skew = DataDistribution::LabelSkew { skew_factor: 0.5 };
        let feature_skew = DataDistribution::FeatureSkew { skew_factor: 0.5 };
        assert_eq!(iid, DataDistribution::IID);
        assert_ne!(iid, label_skew);
        assert_ne!(label_skew, feature_skew);
    }

    #[test]
    fn test_aggregation_strategy_display() {
        assert_eq!(format!("{}", AggregationStrategy::FedAvg), "FedAvg");
        assert_eq!(format!("{}", AggregationStrategy::FedProx), "FedProx");
        assert_eq!(
            format!("{}", AggregationStrategy::SecureAggregation),
            "SecureAggregation"
        );
    }

    #[test]
    fn test_client_update() {
        let id = ClientId::new("client_1");
        let mut update = ClientUpdate::new(id.clone(), 5, 100, 0.5).with_accuracy(0.85);
        update.add_metadata("dataset", "mnist");
        assert_eq!(update.client_id(), &id);
        assert_eq!(update.round(), 5);
        assert_eq!(update.num_steps(), 100);
        assert_eq!(update.loss(), 0.5);
        assert_eq!(update.accuracy(), Some(0.85));
        assert_eq!(update.metadata().len(), 1);
        assert_eq!(
            update.metadata()[0],
            ("dataset".to_string(), "mnist".to_string())
        );
    }

    #[test]
    fn test_client_selector_random() {
        let clients = sample_clients();
        let selector = ClientSelector::new(ClientSelectionStrategy::Random);
        let selected = selector.select(&clients, 2);
        assert_eq!(selected.len(), 2);
    }

    #[test]
    fn test_client_selector_by_data_size() {
        let clients = sample_clients();
        let selector = ClientSelector::new(ClientSelectionStrategy::ByDataSize);
        let selected = selector.select(&clients, 2);
        assert_eq!(ids(&selected), ["client_1", "client_3"]);
    }

    #[test]
    fn test_client_selector_by_availability() {
        let clients = sample_clients();
        let selector = ClientSelector::new(ClientSelectionStrategy::ByAvailability);
        let selected = selector.select(&clients, 2);
        assert_eq!(ids(&selected), ["client_3", "client_1"]);
    }

    #[test]
    fn test_client_selector_by_compute_capacity() {
        let clients = vec![
            FederatedClient::new("slow", 100, 0.5),
            FederatedClient::new("fast", 100, 0.5).with_compute_capacity(3.0),
            FederatedClient::new("medium", 100, 0.5).with_compute_capacity(2.0),
        ];
        let selector = ClientSelector::new(ClientSelectionStrategy::ByComputeCapacity);
        let selected = selector.select(&clients, 2);
        assert_eq!(ids(&selected), ["fast", "medium"]);
    }

    #[test]
    fn test_client_selector_caps_at_client_count() {
        let clients = sample_clients();
        for &strategy in &[
            ClientSelectionStrategy::Random,
            ClientSelectionStrategy::ByAvailability,
            ClientSelectionStrategy::ByDataSize,
            ClientSelectionStrategy::ByComputeCapacity,
            ClientSelectionStrategy::PowerOfChoice { choices: 2 },
        ] {
            let selector = ClientSelector::new(strategy);
            assert_eq!(selector.select(&clients, 10).len(), 3);
            assert!(selector.select(&clients, 0).is_empty());
            assert_eq!(selector.strategy(), strategy);
        }
    }

    #[test]
    fn test_client_selector_all() {
        let clients = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
        ];
        let selector = ClientSelector::new(ClientSelectionStrategy::All);
        assert_eq!(selector.select(&clients, 10).len(), 2);
        // `All` ignores the requested count entirely.
        assert_eq!(selector.select(&clients, 0).len(), 2);
    }

    #[test]
    fn test_privacy_parameters() {
        let privacy = PrivacyParameters::new(1.0, 1e-5)
            .with_clip_norm(2.0)
            .with_noise_multiplier(0.5);
        assert_eq!(privacy.epsilon(), 1.0);
        assert_eq!(privacy.delta(), 1e-5);
        assert_eq!(privacy.clip_norm(), 2.0);
        assert_eq!(privacy.noise_multiplier(), 0.5);
        assert!(!privacy.is_exhausted());
    }

    #[test]
    fn test_privacy_exhausted_at_non_positive_epsilon() {
        assert!(PrivacyParameters::new(0.0, 1e-5).is_exhausted());
        assert!(PrivacyParameters::new(-1.0, 1e-5).is_exhausted());
    }

    #[test]
    fn test_compression_techniques() {
        let _none = CompressionTechnique::None;
        let _quant = CompressionTechnique::Quantization { bits: 8 };
        let _sparse = CompressionTechnique::Sparsification { k: 100 };
        let _sketch = CompressionTechnique::Sketching;
        let _low_rank = CompressionTechnique::LowRank { rank: 10 };
    }

    #[test]
    fn test_training_round() {
        let mut round = TrainingRound::new(1, 10);
        round.set_completed(8);
        round.set_avg_loss(0.5);
        round.set_avg_accuracy(0.85);
        round.set_communication_cost(1024);
        round.set_duration(60.0);
        assert_eq!(round.round(), 1);
        assert_eq!(round.num_clients(), 10);
        assert_eq!(round.num_completed(), 8);
        assert_eq!(round.avg_loss(), 0.5);
        assert_eq!(round.avg_accuracy(), Some(0.85));
        assert_eq!(round.communication_cost(), 1024);
        assert_eq!(round.duration_secs(), 60.0);
        assert_eq!(round.completion_rate(), 0.8);
    }

    #[test]
    fn test_completion_rate_without_clients() {
        assert_eq!(TrainingRound::new(1, 0).completion_rate(), 0.0);
    }

    #[test]
    fn test_fairness_metrics() {
        let metrics = FairnessMetrics::new(0.01, 0.80, 0.90, 0.85);
        assert_eq!(metrics.accuracy_variance(), 0.01);
        assert_eq!(metrics.min_accuracy(), 0.80);
        assert_eq!(metrics.max_accuracy(), 0.90);
        assert_eq!(metrics.jains_index(), 0.85);
        assert!(metrics.is_fair());
    }

    #[test]
    fn test_fairness_threshold_is_strict() {
        assert!(!FairnessMetrics::new(0.0, 0.5, 0.5, 0.8).is_fair());
    }

    #[test]
    fn test_federated_coordinator() {
        let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedAvg)
            .with_privacy(PrivacyParameters::new(1.0, 1e-5))
            .with_compression(CompressionTechnique::Quantization { bits: 8 });
        assert_eq!(coordinator.current_round(), 0);
        assert_eq!(coordinator.strategy(), AggregationStrategy::FedAvg);
        assert_eq!(
            coordinator.compression(),
            CompressionTechnique::Quantization { bits: 8 }
        );
        assert_eq!(coordinator.privacy().map(|p| p.epsilon()), Some(1.0));

        let round1 = coordinator.start_round(10);
        assert_eq!(round1, 1);
        coordinator.complete_round(0.5, 8);

        let recorded = &coordinator.rounds()[0];
        assert_eq!(recorded.round(), 1);
        assert_eq!(recorded.avg_loss(), 0.5);
        assert_eq!(recorded.num_completed(), 8);

        let stats = coordinator.statistics();
        assert_eq!(stats.total_rounds, 1);
        assert_eq!(stats.avg_completion_rate, 0.8);
    }

    #[test]
    fn test_coordinator_statistics_empty() {
        let coordinator = FederatedCoordinator::new(AggregationStrategy::FedProx);
        let stats = coordinator.statistics();
        assert_eq!(stats.total_rounds, 0);
        assert_eq!(stats.avg_completion_rate, 0.0);
        assert_eq!(stats.total_communication, 0);
    }

    #[test]
    fn test_complete_round_before_start_is_noop() {
        let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedAvg);
        coordinator.complete_round(0.5, 3);
        assert_eq!(coordinator.current_round(), 0);
        assert!(coordinator.rounds().is_empty());
    }

    #[test]
    fn test_coordinator_multiple_rounds() {
        let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedAvg);
        coordinator.start_round(10);
        coordinator.complete_round(0.6, 9);
        coordinator.start_round(10);
        coordinator.complete_round(0.4, 8);
        coordinator.start_round(10);
        coordinator.complete_round(0.3, 10);
        assert_eq!(coordinator.current_round(), 3);
        assert_eq!(coordinator.rounds().len(), 3);
        let stats = coordinator.statistics();
        assert_eq!(stats.total_rounds, 3);
        assert!(stats.avg_completion_rate > 0.8);
    }

    #[test]
    fn test_client_selection_strategies() {
        let _random = ClientSelectionStrategy::Random;
        let _avail = ClientSelectionStrategy::ByAvailability;
        let _size = ClientSelectionStrategy::ByDataSize;
        let _compute = ClientSelectionStrategy::ByComputeCapacity;
        let _power = ClientSelectionStrategy::PowerOfChoice { choices: 3 };
        let _all = ClientSelectionStrategy::All;
    }

    #[test]
    fn test_availability_clamping() {
        let above = FederatedClient::new("c1", 100, 1.5);
        let below = FederatedClient::new("c2", 100, -0.1);
        let nan = FederatedClient::new("c3", 100, f64::NAN);
        assert_eq!(above.availability(), 1.0);
        assert_eq!(below.availability(), 0.0);
        assert_eq!(nan.availability(), 0.0);
    }
}