use crate::{Dataset, Result};
use scirs2_core::random::rand_prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use scirs2_core::random::{rng, Random};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tenflowers_core::{Tensor, TensorError};
/// Unique identifier for a federated client (an opaque string).
pub type ClientId = String;
/// Per-client configuration for a federated dataset participant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Unique identifier of this client.
    pub client_id: ClientId,
    /// How this client's local data is distributed relative to the population.
    pub distribution_type: DataDistribution,
    /// Differential-privacy settings applied to this client's releases.
    pub privacy_config: PrivacyConfig,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
}
/// Describes how a client's local data deviates from the global distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataDistribution {
    /// Independent and identically distributed (matches the population).
    Iid,
    /// Label skew: `class_weights` gives the relative weight of each class.
    NonIidClassImbalance { class_weights: Vec<f64> },
    /// Feature (covariate) shift scaled by `shift_factor`.
    NonIidFeatureShift { shift_factor: f64 },
    /// Combined label skew and feature shift.
    NonIidMixed {
        class_weights: Vec<f64>,
        shift_factor: f64,
    },
    /// User-defined strategy identified by name with numeric parameters.
    Custom {
        strategy_name: String,
        parameters: HashMap<String, f64>,
    },
}
/// Differential-privacy parameters for a client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyConfig {
    /// Master switch; when false, values pass through without noise.
    pub enable_dp: bool,
    /// Privacy loss spent per released statistic.
    pub epsilon: f64,
    /// Failure probability for (epsilon, delta)-DP (used by the Gaussian mechanism).
    pub delta: f64,
    /// Noise distribution used to perturb released values.
    pub noise_mechanism: NoiseMechanism,
    /// Total epsilon a client may spend before further queries are rejected.
    pub privacy_budget: f64,
}
/// Noise distribution used by the privacy manager.
///
/// NOTE(review): the per-variant `sensitivity` fields are not read by
/// `PrivacyManager`, which takes sensitivity as a call argument instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseMechanism {
    /// Laplace noise (scale = sensitivity / epsilon).
    Laplace { sensitivity: f64 },
    /// Gaussian noise calibrated with epsilon and delta.
    Gaussian { sensitivity: f64 },
    /// Exponential mechanism (currently approximated with Laplace noise).
    Exponential { sensitivity: f64 },
}
impl Default for PrivacyConfig {
    /// Conservative defaults: DP disabled, epsilon = 1.0, delta = 1e-5,
    /// Gaussian noise with unit sensitivity, and a total budget of 10 epsilon.
    fn default() -> Self {
        Self {
            enable_dp: false,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        }
    }
}
/// A client-local dataset wrapper that applies differential privacy when
/// samples or statistics are released.
#[derive(Debug)]
pub struct FederatedClientDataset<T, D> {
    // Client identity, distribution type, and privacy settings.
    config: ClientConfig,
    // The wrapped local dataset.
    dataset: D,
    // Shared, mutex-guarded noise/budget state (noise draws mutate the RNG).
    privacy_manager: Arc<Mutex<PrivacyManager>>,
    // Basic statistics computed at construction time.
    stats: ClientStats,
    // Marker tying the element type `T` without storing one.
    _phantom: std::marker::PhantomData<T>,
}
/// Tracks the differential-privacy budget and draws calibrated noise.
#[derive(Debug)]
pub struct PrivacyManager {
    // Unspent epsilon; decremented on each noisy release.
    remaining_budget: f64,
    // Seeded RNG so noise is reproducible for a given seed.
    rng: StdRng,
    // Memoized noise scales keyed by mechanism parameters.
    noise_scale_cache: HashMap<String, f64>,
}
/// Summary statistics for a client's local data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
    /// Number of local samples.
    pub sample_count: usize,
    /// Count of samples per class label.
    pub class_distribution: HashMap<String, usize>,
    /// Per-feature summary statistics.
    pub feature_stats: FederatedFeatureStats,
    /// Data-quality indicators.
    pub quality_metrics: QualityMetrics,
}
/// Per-feature summary statistics (one entry per feature dimension).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedFeatureStats {
    /// Mean of each feature.
    pub means: Vec<f64>,
    /// Standard deviation of each feature.
    pub stds: Vec<f64>,
    /// (min, max) range of each feature.
    pub ranges: Vec<(f64, f64)>,
}
/// Coarse data-quality indicators for a client dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Share of values flagged as missing.
    pub missing_percentage: f64,
    /// Share of values flagged as outliers.
    pub outlier_percentage: f64,
    /// Consistency score; higher is better (defaults to 1.0 at construction).
    pub consistency_score: f64,
}
impl PrivacyManager {
pub fn new(config: &PrivacyConfig, seed: u64) -> Self {
Self {
remaining_budget: config.privacy_budget,
rng: StdRng::seed_from_u64(seed),
noise_scale_cache: HashMap::new(),
}
}
pub fn add_noise(
&mut self,
value: f64,
config: &PrivacyConfig,
query_sensitivity: f64,
) -> Result<f64> {
if !config.enable_dp {
return Ok(value);
}
if self.remaining_budget <= 0.0 {
return Err(TensorError::invalid_argument(
"Privacy budget exhausted".to_string(),
));
}
let noise_scale = self.calculate_noise_scale(config, query_sensitivity);
let noise = match &config.noise_mechanism {
NoiseMechanism::Laplace { .. } => {
let scale = noise_scale;
self.sample_laplace(scale)
}
NoiseMechanism::Gaussian { .. } => {
let sigma = noise_scale;
self.sample_gaussian(sigma)
}
NoiseMechanism::Exponential { .. } => {
let scale = noise_scale;
self.sample_laplace(scale)
}
};
self.remaining_budget -= config.epsilon;
Ok(value + noise)
}
pub fn add_noise_tensor<T>(
&mut self,
tensor: &Tensor<T>,
config: &PrivacyConfig,
sensitivity: f64,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
T: From<f64> + Into<f64>,
{
if !config.enable_dp {
return Ok(tensor.clone());
}
let shape = tensor.shape().dims().to_vec();
let mut noisy_data = Vec::new();
if let Some(slice) = tensor.as_slice() {
for value in slice {
let original_value: f64 = value.clone().into();
let noisy_value = self.add_noise(original_value, config, sensitivity)?;
noisy_data.push(T::from(noisy_value));
}
} else {
let value: f64 = tensor.get(&[]).unwrap_or_default().into();
let noisy_value = self.add_noise(value, config, sensitivity)?;
noisy_data.push(T::from(noisy_value));
}
Tensor::from_vec(noisy_data, &shape)
}
fn calculate_noise_scale(&mut self, config: &PrivacyConfig, sensitivity: f64) -> f64 {
let cache_key = format!("{}_{}", config.epsilon, sensitivity);
if let Some(&cached_scale) = self.noise_scale_cache.get(&cache_key) {
return cached_scale;
}
let scale = match &config.noise_mechanism {
NoiseMechanism::Laplace { .. } => sensitivity / config.epsilon,
NoiseMechanism::Gaussian { .. } => {
let factor = (2.0 * (1.25 / config.delta).ln()).sqrt();
factor * sensitivity / config.epsilon
}
NoiseMechanism::Exponential { .. } => sensitivity / config.epsilon,
};
self.noise_scale_cache.insert(cache_key, scale);
scale
}
fn sample_laplace(&mut self, scale: f64) -> f64 {
let u1: f64 = self.rng.random();
let u2: f64 = self.rng.random();
let sign = if u1 < 0.5 { -1.0 } else { 1.0 };
sign * scale * (1.0_f64 - 2.0_f64 * u2.abs()).max(1e-10_f64).ln()
}
fn sample_gaussian(&mut self, sigma: f64) -> f64 {
let u1: f64 = self.rng.random();
let u2: f64 = self.rng.random();
let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
z0 * sigma
}
pub fn can_spend_budget(&self, epsilon: f64) -> bool {
self.remaining_budget >= epsilon
}
pub fn remaining_budget(&self) -> f64 {
self.remaining_budget
}
}
impl<T, D> FederatedClientDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Wraps `dataset` as a federated client governed by `config`.
    ///
    /// NOTE(review): the privacy RNG is seeded with a fixed 42, so every
    /// client draws the same noise stream — consider deriving the seed from
    /// the client id if independent per-client noise is required.
    pub fn new(dataset: D, config: ClientConfig) -> Self {
        let stats = Self::compute_basic_stats(&dataset);
        let privacy_manager =
            Arc::new(Mutex::new(PrivacyManager::new(&config.privacy_config, 42)));
        Self {
            config,
            dataset,
            privacy_manager,
            stats,
            _phantom: std::marker::PhantomData,
        }
    }

    /// This client's configuration.
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Locally computed (non-private) dataset statistics.
    pub fn stats(&self) -> &ClientStats {
        &self.stats
    }

    /// Fetches sample `index` with DP noise applied to the features.
    ///
    /// Labels are returned unmodified. When DP is disabled this is
    /// equivalent to `Dataset::get`. Spends budget per feature element.
    pub fn get_private(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)>
    where
        T: From<f64> + Into<f64>,
    {
        let (features, labels) = self.dataset.get(index)?;
        if !self.config.privacy_config.enable_dp {
            return Ok((features, labels));
        }
        let mut privacy_manager = self
            .privacy_manager
            .lock()
            .expect("lock should not be poisoned");
        let noisy_features =
            privacy_manager.add_noise_tensor(&features, &self.config.privacy_config, 1.0)?;
        Ok((noisy_features, labels))
    }

    /// Computes per-feature means and the sample count with DP noise applied.
    ///
    /// # Errors
    /// Propagates dataset access failures, and fails with a budget error once
    /// the privacy budget cannot cover another released statistic (one
    /// epsilon per mean plus one for the count).
    pub fn compute_private_statistics(&self) -> Result<PrivateStats>
    where
        T: From<f64> + Into<f64>,
    {
        let sample_count = self.dataset.len();
        if sample_count == 0 {
            return Ok(PrivateStats {
                sample_count: 0,
                feature_means: Vec::new(),
                class_counts: HashMap::new(),
            });
        }
        // Probe the first sample for dimensionality; rank-0 tensors that
        // expose no slice view are treated as a single feature.
        let (first_features, _) = self.dataset.get(0)?;
        let feature_dim = first_features.as_slice().map_or(1, |s| s.len());
        let mut feature_sums = vec![0.0_f64; feature_dim];
        let mut feature_counts = vec![0_usize; feature_dim];
        for i in 0..sample_count {
            let (features, _) = self.dataset.get(i)?;
            if let Some(slice) = features.as_slice() {
                for (j, value) in slice.iter().enumerate() {
                    feature_sums[j] += value.clone().into();
                    feature_counts[j] += 1;
                }
            } else {
                let value: f64 = features.get(&[]).unwrap_or_default().into();
                feature_sums[0] += value;
                feature_counts[0] += 1;
            }
        }
        let mut private_means = Vec::with_capacity(feature_dim);
        let mut privacy_manager = self
            .privacy_manager
            .lock()
            .expect("lock should not be poisoned");
        for i in 0..feature_dim {
            let mean = if feature_counts[i] > 0 {
                feature_sums[i] / feature_counts[i] as f64
            } else {
                0.0
            };
            let private_mean =
                privacy_manager.add_noise(mean, &self.config.privacy_config, 1.0)?;
            private_means.push(private_mean);
        }
        // Clamp before the cast: a large negative noise draw would otherwise
        // rely on saturating float-to-int semantics to land on zero.
        let noisy_count =
            privacy_manager.add_noise(sample_count as f64, &self.config.privacy_config, 1.0)?;
        Ok(PrivateStats {
            sample_count: noisy_count.max(0.0) as usize,
            feature_means: private_means,
            class_counts: HashMap::new(),
        })
    }

    /// Builds placeholder statistics: only `sample_count` is populated; the
    /// class distribution, feature stats, and quality metrics are defaults.
    fn compute_basic_stats(dataset: &D) -> ClientStats {
        ClientStats {
            sample_count: dataset.len(),
            class_distribution: HashMap::new(),
            feature_stats: FederatedFeatureStats {
                means: Vec::new(),
                stds: Vec::new(),
                ranges: Vec::new(),
            },
            quality_metrics: QualityMetrics {
                missing_percentage: 0.0,
                outlier_percentage: 0.0,
                consistency_score: 1.0,
            },
        }
    }
}
// Transparent pass-through to the wrapped dataset (no privacy applied here).
impl<T, D> Dataset<T> for FederatedClientDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    fn len(&self) -> usize {
        self.dataset.len()
    }
    /// Raw (non-private) access; use `get_private` for noise-protected samples.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        self.dataset.get(index)
    }
}
/// Aggregate statistics released under differential privacy.
#[derive(Debug, Clone)]
pub struct PrivateStats {
    /// Noisy sample count.
    pub sample_count: usize,
    /// Noisy per-feature means.
    pub feature_means: Vec<f64>,
    /// Per-class sample counts (may be empty when labels are not tracked).
    pub class_counts: HashMap<String, usize>,
}
/// Splits a dataset into per-client shards using a seeded strategy.
#[derive(Debug)]
pub struct FederatedPartitioner {
    // Number of client shards to produce.
    num_clients: usize,
    // Algorithm used to assign sample indices to clients.
    strategy: PartitioningStrategy,
    // Seeded RNG so partitions are reproducible.
    rng: StdRng,
}
/// How samples are distributed among federated clients.
#[derive(Debug, Clone)]
pub enum PartitioningStrategy {
    /// Shuffle and split as evenly as possible.
    Uniform,
    /// Skewed client sizes controlled by `alpha` (Dirichlet-style).
    Dirichlet { alpha: f64 },
    /// Each client holds samples from a limited set of classes.
    ClassBased { classes_per_client: usize },
    /// Client sizes vary around the mean by up to `size_variance`.
    QuantityBased { size_variance: f64 },
}
impl FederatedPartitioner {
pub fn new(num_clients: usize, strategy: PartitioningStrategy, seed: u64) -> Self {
Self {
num_clients,
strategy,
rng: StdRng::seed_from_u64(seed),
}
}
pub fn partition<T, D>(
&mut self,
dataset: D,
) -> Result<Vec<FederatedClientDataset<T, ClientIndexedDataset<T, D>>>>
where
D: Dataset<T> + Clone,
T: Clone + Default + Send + Sync + 'static,
{
let total_samples = dataset.len();
let client_assignments = self.generate_client_assignments(total_samples)?;
let mut client_datasets = Vec::new();
for (client_idx, indices) in client_assignments.into_iter().enumerate() {
let client_id = format!("client_{client_idx}");
let client_dataset = ClientIndexedDataset::new(dataset.clone(), indices);
let config = ClientConfig {
client_id: client_id.clone(),
distribution_type: self.get_distribution_type_for_client(client_idx),
privacy_config: PrivacyConfig::default(),
metadata: HashMap::new(),
};
let federated_client = FederatedClientDataset::new(client_dataset, config);
client_datasets.push(federated_client);
}
Ok(client_datasets)
}
fn generate_client_assignments(&mut self, total_samples: usize) -> Result<Vec<Vec<usize>>> {
match &self.strategy {
PartitioningStrategy::Uniform => self.uniform_partition(total_samples),
PartitioningStrategy::Dirichlet { alpha } => {
self.dirichlet_partition(total_samples, *alpha)
}
PartitioningStrategy::ClassBased {
classes_per_client: _,
} => {
self.uniform_partition(total_samples)
}
PartitioningStrategy::QuantityBased { size_variance } => {
self.quantity_based_partition(total_samples, *size_variance)
}
}
}
fn uniform_partition(&mut self, total_samples: usize) -> Result<Vec<Vec<usize>>> {
let mut indices: Vec<usize> = (0..total_samples).collect();
for i in (1..indices.len()).rev() {
let j = self.rng.random_range(0..i);
indices.swap(i, j);
}
let base_size = total_samples / self.num_clients;
let remainder = total_samples % self.num_clients;
let mut client_assignments = Vec::new();
let mut start_idx = 0;
for i in 0..self.num_clients {
let client_size = base_size + if i < remainder { 1 } else { 0 };
let end_idx = start_idx + client_size;
client_assignments.push(indices[start_idx..end_idx].to_vec());
start_idx = end_idx;
}
Ok(client_assignments)
}
fn dirichlet_partition(&mut self, total_samples: usize, alpha: f64) -> Result<Vec<Vec<usize>>> {
let mut proportions = Vec::new();
let mut sum = 0.0;
for _ in 0..self.num_clients {
let prop = self.rng.random::<f64>() * alpha + 0.1; proportions.push(prop);
sum += prop;
}
for prop in &mut proportions {
*prop /= sum;
}
let mut client_assignments = Vec::new();
let mut assigned_samples = 0;
for (i, &proportion) in proportions.iter().enumerate() {
let client_samples = if i == self.num_clients - 1 {
total_samples - assigned_samples
} else {
(total_samples as f64 * proportion) as usize
};
let indices: Vec<usize> =
(assigned_samples..assigned_samples + client_samples).collect();
client_assignments.push(indices);
assigned_samples += client_samples;
}
Ok(client_assignments)
}
fn quantity_based_partition(
&mut self,
total_samples: usize,
size_variance: f64,
) -> Result<Vec<Vec<usize>>> {
let base_size = total_samples as f64 / self.num_clients as f64;
let mut client_sizes = Vec::new();
let mut total_assigned = 0;
for i in 0..self.num_clients {
let variance_factor = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * size_variance;
let client_size = if i == self.num_clients - 1 {
total_samples - total_assigned
} else {
((base_size * variance_factor) as usize).min(total_samples - total_assigned)
};
client_sizes.push(client_size);
total_assigned += client_size;
if total_assigned >= total_samples {
break;
}
}
let mut client_assignments = Vec::new();
let mut start_idx = 0;
for &size in &client_sizes {
let end_idx = (start_idx + size).min(total_samples);
let indices: Vec<usize> = (start_idx..end_idx).collect();
client_assignments.push(indices);
start_idx = end_idx;
}
Ok(client_assignments)
}
fn get_distribution_type_for_client(&self, _client_idx: usize) -> DataDistribution {
match &self.strategy {
PartitioningStrategy::Uniform => DataDistribution::Iid,
PartitioningStrategy::Dirichlet { alpha } => DataDistribution::NonIidClassImbalance {
class_weights: vec![*alpha, 1.0 - alpha],
},
PartitioningStrategy::ClassBased { .. } => DataDistribution::NonIidClassImbalance {
class_weights: vec![0.8, 0.2],
},
PartitioningStrategy::QuantityBased { .. } => DataDistribution::Iid,
}
}
}
/// A view over a base dataset restricted to a fixed list of sample indices.
#[derive(Debug, Clone)]
pub struct ClientIndexedDataset<T, D> {
    // The base dataset being viewed.
    dataset: D,
    // Indices into `dataset` owned by this client, in access order.
    indices: Vec<usize>,
    // Marker for the element type `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T, D> ClientIndexedDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Creates a view over `dataset` limited to `indices`.
    ///
    /// NOTE(review): indices are not validated here — out-of-range entries
    /// only surface when the base dataset is accessed; confirm callers pass
    /// valid indices.
    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
        Self {
            dataset,
            indices,
            _phantom: std::marker::PhantomData,
        }
    }
    /// The wrapped base dataset.
    pub fn inner(&self) -> &D {
        &self.dataset
    }
    /// The index mapping into the base dataset.
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }
}
// Dataset view: length and element access are routed through the stored
// index mapping into the base dataset.
impl<T, D> Dataset<T> for ClientIndexedDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    fn len(&self) -> usize {
        self.indices.len()
    }

    /// Translates `index` through this client's index list and fetches the
    /// corresponding sample from the base dataset.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        let mapped = self.indices.get(index).copied().ok_or_else(|| {
            TensorError::invalid_argument(format!(
                "Index {} out of bounds for client dataset of length {}",
                index,
                self.indices.len()
            ))
        })?;
        self.dataset.get(mapped)
    }
}
/// Combines per-client private statistics into a global summary.
#[derive(Debug)]
pub struct FederatedAggregator {
    // Aggregation algorithm to apply.
    strategy: AggregationStrategy,
    // Per-client weights registered via `set_client_weight`.
    client_weights: HashMap<ClientId, f64>,
}
/// Strategies for combining statistics across clients.
#[derive(Debug, Clone)]
pub enum AggregationStrategy {
    /// Plain unweighted average.
    Average,
    /// Weighted by each client's sample count.
    WeightedBySize,
    /// Weighted by a data-quality score.
    WeightedByQuality,
    /// Median-based aggregation.
    Median,
    /// Trimmed-mean aggregation with the given trim fraction.
    TrimmedMean { trim_fraction: f64 },
}
impl FederatedAggregator {
pub fn new(strategy: AggregationStrategy) -> Self {
Self {
strategy,
client_weights: HashMap::new(),
}
}
pub fn set_client_weight(&mut self, client_id: ClientId, weight: f64) {
self.client_weights.insert(client_id, weight);
}
pub fn aggregate_statistics(
&self,
client_stats: Vec<(ClientId, PrivateStats)>,
) -> Result<PrivateStats> {
if client_stats.is_empty() {
return Err(TensorError::invalid_argument(
"No client statistics provided".to_string(),
));
}
match &self.strategy {
AggregationStrategy::Average => self.average_statistics(client_stats),
AggregationStrategy::WeightedBySize => {
self.weighted_statistics(client_stats, |stats| stats.sample_count as f64)
}
AggregationStrategy::WeightedByQuality => {
self.weighted_statistics(client_stats, |_| 1.0)
} AggregationStrategy::Median => self.median_statistics(client_stats),
AggregationStrategy::TrimmedMean { trim_fraction } => {
self.trimmed_mean_statistics(client_stats, *trim_fraction)
}
}
}
fn average_statistics(
&self,
client_stats: Vec<(ClientId, PrivateStats)>,
) -> Result<PrivateStats> {
let num_clients = client_stats.len() as f64;
let mut total_samples = 0;
let mut aggregated_means = Vec::new();
let mut aggregated_class_counts = HashMap::new();
if let Some((_, first_stats)) = client_stats.first() {
aggregated_means.resize(first_stats.feature_means.len(), 0.0);
}
for (_, stats) in &client_stats {
total_samples += stats.sample_count;
for (i, &mean) in stats.feature_means.iter().enumerate() {
if i < aggregated_means.len() {
aggregated_means[i] += mean / num_clients;
}
}
for (class, &count) in &stats.class_counts {
*aggregated_class_counts.entry(class.clone()).or_insert(0) += count;
}
}
Ok(PrivateStats {
sample_count: total_samples,
feature_means: aggregated_means,
class_counts: aggregated_class_counts,
})
}
fn weighted_statistics<F>(
&self,
client_stats: Vec<(ClientId, PrivateStats)>,
weight_fn: F,
) -> Result<PrivateStats>
where
F: Fn(&PrivateStats) -> f64,
{
#[allow(unused_assignments)]
let mut total_weight = 0.0;
let mut total_samples = 0;
let mut aggregated_means = Vec::new();
let mut aggregated_class_counts = HashMap::new();
let weights: Vec<f64> = client_stats
.iter()
.map(|(_, stats)| weight_fn(stats))
.collect();
total_weight = weights.iter().sum();
if let Some((_, first_stats)) = client_stats.first() {
aggregated_means.resize(first_stats.feature_means.len(), 0.0);
}
for ((_, stats), weight) in client_stats.iter().zip(weights.iter()) {
let normalized_weight = weight / total_weight;
total_samples += stats.sample_count;
for (i, &mean) in stats.feature_means.iter().enumerate() {
if i < aggregated_means.len() {
aggregated_means[i] += mean * normalized_weight;
}
}
for (class, &count) in &stats.class_counts {
let weighted_count = (count as f64 * normalized_weight) as usize;
*aggregated_class_counts.entry(class.clone()).or_insert(0) += weighted_count;
}
}
Ok(PrivateStats {
sample_count: total_samples,
feature_means: aggregated_means,
class_counts: aggregated_class_counts,
})
}
fn median_statistics(
&self,
client_stats: Vec<(ClientId, PrivateStats)>,
) -> Result<PrivateStats> {
self.average_statistics(client_stats)
}
fn trimmed_mean_statistics(
&self,
client_stats: Vec<(ClientId, PrivateStats)>,
_trim_fraction: f64,
) -> Result<PrivateStats> {
self.average_statistics(client_stats)
}
}
/// Extension methods that turn any `Dataset` into federated client views.
pub trait FederatedDatasetExt<T>: Dataset<T> + Sized
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Wraps this dataset as a single federated client with `config`.
    fn federated_client(self, config: ClientConfig) -> FederatedClientDataset<T, Self> {
        FederatedClientDataset::new(self, config)
    }
    /// Splits this dataset into `num_clients` federated clients using
    /// `strategy`; `seed` makes the partitioning deterministic.
    fn partition_federated(
        self,
        num_clients: usize,
        strategy: PartitioningStrategy,
        seed: u64,
    ) -> Result<Vec<FederatedClientDataset<T, ClientIndexedDataset<T, Self>>>>
    where
        Self: Clone,
    {
        let mut partitioner = FederatedPartitioner::new(num_clients, strategy, seed);
        partitioner.partition(self)
    }
}
// Blanket impl: every Dataset automatically gains the federated extension API.
impl<T, D: Dataset<T>> FederatedDatasetExt<T> for D where T: Clone + Default + Send + Sync + 'static {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;

    // Noise is applied and budget is spent when DP is enabled.
    #[test]
    fn test_privacy_manager() {
        let config = PrivacyConfig {
            enable_dp: true,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        };
        let mut privacy_manager = PrivacyManager::new(&config, 42);
        assert_eq!(privacy_manager.remaining_budget(), 10.0);
        assert!(privacy_manager.can_spend_budget(1.0));
        let noisy_value = privacy_manager
            .add_noise(5.0, &config, 1.0)
            .expect("test: operation should succeed");
        assert!(privacy_manager.remaining_budget() < 10.0);
        // Continuous noise: an exact zero draw is effectively impossible.
        assert_ne!(noisy_value, 5.0);
    }

    // Wrapping a dataset preserves length and exposes config/stats.
    #[test]
    fn test_federated_client_dataset() {
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let labels_data = vec![0.0, 1.0, 0.0];
        let features =
            Tensor::from_vec(features_data, &[3, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[3]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let config = ClientConfig {
            client_id: "test_client".to_string(),
            distribution_type: DataDistribution::Iid,
            privacy_config: PrivacyConfig::default(),
            metadata: HashMap::new(),
        };
        let federated_dataset = FederatedClientDataset::new(dataset, config);
        assert_eq!(federated_dataset.len(), 3);
        assert_eq!(federated_dataset.config().client_id, "test_client");
        assert_eq!(federated_dataset.stats().sample_count, 3);
    }

    // Uniform partitioning covers all samples across the requested clients.
    #[test]
    fn test_federated_partitioner() {
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let labels_data = vec![0.0, 1.0, 0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[4, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[4]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let mut partitioner = FederatedPartitioner::new(2, PartitioningStrategy::Uniform, 42);
        let client_datasets = partitioner
            .partition(dataset)
            .expect("test: operation should succeed");
        assert_eq!(client_datasets.len(), 2);
        let total_samples: usize = client_datasets.iter().map(|d| d.len()).sum();
        assert_eq!(total_samples, 4);
    }

    // An index view maps view positions to the underlying dataset's rows.
    #[test]
    fn test_client_indexed_dataset() {
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let labels_data = vec![0.0, 1.0, 0.0];
        let features =
            Tensor::from_vec(features_data, &[3, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[3]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let indices = vec![0, 2];
        let client_dataset = ClientIndexedDataset::new(dataset, indices);
        assert_eq!(client_dataset.len(), 2);
        assert_eq!(client_dataset.indices(), &[0, 2]);
        // View position 0 -> base row 0.
        let (features, labels) = client_dataset.get(0).expect("index should be in bounds");
        let features_slice = features.as_slice().expect("tensor should be contiguous");
        assert_eq!(features_slice, &[1.0, 2.0]);
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);
        // View position 1 -> base row 2.
        let (features, labels) = client_dataset.get(1).expect("index should be in bounds");
        let features_slice = features.as_slice().expect("tensor should be contiguous");
        assert_eq!(features_slice, &[5.0, 6.0]);
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);
    }

    // Average aggregation sums counts and averages feature means.
    #[test]
    fn test_federated_aggregator() {
        let aggregator = FederatedAggregator::new(AggregationStrategy::Average);
        let client_stats = vec![
            (
                "client1".to_string(),
                PrivateStats {
                    sample_count: 100,
                    feature_means: vec![1.0, 2.0],
                    class_counts: HashMap::new(),
                },
            ),
            (
                "client2".to_string(),
                PrivateStats {
                    sample_count: 200,
                    feature_means: vec![3.0, 4.0],
                    class_counts: HashMap::new(),
                },
            ),
        ];
        let aggregated = aggregator
            .aggregate_statistics(client_stats)
            .expect("test: operation should succeed");
        assert_eq!(aggregated.sample_count, 300);
        assert_eq!(aggregated.feature_means, vec![2.0, 3.0]);
    }

    // PrivacyConfig round-trips through serde JSON.
    #[test]
    fn test_privacy_config_serialization() {
        let config = PrivacyConfig {
            enable_dp: true,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        };
        let json = serde_json::to_string(&config).expect("test: serialization should succeed");
        let deserialized: PrivacyConfig =
            serde_json::from_str(&json).expect("test: JSON parsing should succeed");
        assert!(deserialized.enable_dp);
        assert_eq!(deserialized.epsilon, 1.0);
    }

    // The blanket extension trait exposes both entry points on any Dataset.
    #[test]
    fn test_extension_trait() {
        let features_data = vec![1.0, 2.0, 3.0, 4.0];
        let labels_data = vec![0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[2, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[2]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let config = ClientConfig {
            client_id: "test_client".to_string(),
            distribution_type: DataDistribution::Iid,
            privacy_config: PrivacyConfig::default(),
            metadata: HashMap::new(),
        };
        let federated_dataset = dataset.federated_client(config);
        assert_eq!(federated_dataset.len(), 2);
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let labels_data = vec![0.0, 1.0, 0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[4, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[4]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let client_datasets = dataset
            .partition_federated(2, PartitioningStrategy::Uniform, 42)
            .expect("test: operation should succeed");
        assert_eq!(client_datasets.len(), 2);
    }
}