quantrs2_ml/
federated.rs

//! Quantum federated learning protocols for distributed quantum machine learning.
//!
//! This module implements privacy-preserving distributed training of quantum models,
//! with simplified secure-aggregation protocols and differential-privacy noise mechanisms.
5
6use scirs2_core::ndarray::{Array1, Array2, Array3};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9use std::f64::consts::PI;
10
11use crate::error::{MLError, Result};
12use crate::qnn::QuantumNeuralNetwork;
13use crate::utils::VariationalCircuit;
14use quantrs2_circuit::prelude::*;
15use quantrs2_core::gate::{multi::*, single::*, GateOp};
16
17/// Federated learning client for quantum models
18#[derive(Debug)]
19pub struct QuantumFLClient {
20    /// Client ID
21    client_id: String,
22    /// Local quantum model
23    local_model: QuantumNeuralNetwork,
24    /// Local dataset size
25    dataset_size: usize,
26    /// Privacy budget
27    epsilon: f64,
28    /// Noise scale for differential privacy
29    noise_scale: f64,
30    /// Client-specific parameters
31    local_params: HashMap<String, f64>,
32}
33
34impl QuantumFLClient {
35    /// Create a new federated learning client
36    pub fn new(
37        client_id: String,
38        model_config: &[(String, usize)], // Layer configs
39        dataset_size: usize,
40        epsilon: f64,
41    ) -> Result<Self> {
42        // Create local model based on config
43        let layers = model_config
44            .iter()
45            .map(|(layer_type, size)| match layer_type.as_str() {
46                "encoding" => crate::qnn::QNNLayerType::EncodingLayer {
47                    num_features: *size,
48                },
49                "variational" => crate::qnn::QNNLayerType::VariationalLayer { num_params: *size },
50                "entanglement" => crate::qnn::QNNLayerType::EntanglementLayer {
51                    connectivity: "full".to_string(),
52                },
53                _ => crate::qnn::QNNLayerType::MeasurementLayer {
54                    measurement_basis: "computational".to_string(),
55                },
56            })
57            .collect();
58
59        let local_model = QuantumNeuralNetwork::new(layers, 4, 10, 2)?;
60        let noise_scale = (2.0 * (1.25 / epsilon).ln()).sqrt() / dataset_size as f64;
61
62        Ok(Self {
63            client_id,
64            local_model,
65            dataset_size,
66            epsilon,
67            noise_scale,
68            local_params: HashMap::new(),
69        })
70    }
71
72    /// Train on local data
73    pub fn train_local(
74        &mut self,
75        local_data: &Array2<f64>,
76        local_labels: &Array1<i32>,
77        epochs: usize,
78    ) -> Result<f64> {
79        let mut total_loss = 0.0;
80
81        for _ in 0..epochs {
82            // Simplified training loop
83            for i in 0..local_data.nrows() {
84                let input = local_data.row(i).to_owned();
85                let label = local_labels[i];
86
87                // Forward pass
88                let output = self.local_model.forward(&input)?;
89
90                // Compute loss
91                let loss = self.compute_loss(&output, label)?;
92                total_loss += loss;
93
94                // Backward pass (simplified)
95                self.update_parameters(&input, label, 0.01)?;
96            }
97        }
98
99        // Add differential privacy noise
100        self.add_dp_noise()?;
101
102        Ok(total_loss / (epochs * local_data.nrows()) as f64)
103    }
104
105    /// Compute loss function
106    fn compute_loss(&self, output: &Array1<f64>, label: i32) -> Result<f64> {
107        // Cross-entropy loss for classification
108        let label_idx = label as usize;
109        if label_idx >= output.len() {
110            return Err(MLError::InvalidInput("Label out of bounds".to_string()));
111        }
112
113        Ok(-output[label_idx].ln())
114    }
115
116    /// Update parameters (simplified)
117    fn update_parameters(
118        &mut self,
119        input: &Array1<f64>,
120        label: i32,
121        learning_rate: f64,
122    ) -> Result<()> {
123        // Placeholder parameter update
124        for (key, value) in self.local_params.iter_mut() {
125            *value += learning_rate * fastrand::f64() * 0.1;
126        }
127        Ok(())
128    }
129
130    /// Add differential privacy noise
131    fn add_dp_noise(&mut self) -> Result<()> {
132        for (_, value) in self.local_params.iter_mut() {
133            // Add Gaussian noise scaled by sensitivity and epsilon
134            let noise = self.noise_scale * Self::gaussian_noise();
135            *value += noise;
136        }
137        Ok(())
138    }
139
140    /// Generate Gaussian noise
141    fn gaussian_noise() -> f64 {
142        // Box-Muller transform
143        let u1 = fastrand::f64();
144        let u2 = fastrand::f64();
145        (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
146    }
147
148    /// Get model parameters for aggregation
149    pub fn get_parameters(&self) -> HashMap<String, f64> {
150        self.local_params.clone()
151    }
152
153    /// Update model with aggregated parameters
154    pub fn set_parameters(&mut self, params: HashMap<String, f64>) {
155        self.local_params = params;
156    }
157}
158
159/// Quantum secure aggregation server
160#[derive(Debug)]
161pub struct QuantumFLServer {
162    /// Global model configuration
163    model_config: Vec<(String, usize)>,
164    /// Aggregated parameters
165    global_params: HashMap<String, f64>,
166    /// Client weights for aggregation
167    client_weights: HashMap<String, f64>,
168    /// Secure aggregation protocol
169    aggregation_protocol: SecureAggregationProtocol,
170    /// Byzantine fault tolerance threshold
171    byzantine_threshold: f64,
172}
173
/// Strategy used by the federated server to merge client parameter updates.
///
/// This is a fieldless enum, so it is cheap to copy and can be compared or
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecureAggregationProtocol {
    /// Mean of client parameters, weighted by each client's dataset size.
    FederatedAveraging,
    /// Per-parameter Byzantine-robust (Krum-style) selection across clients.
    SecureMultiparty,
    /// Mean computed over placeholder "encrypted" values (not real encryption).
    HomomorphicEncryption,
    /// Mean over simulated quantum shares, gated by a minimum share count.
    QuantumSecretSharing,
}
185
186impl QuantumFLServer {
187    /// Create a new federated learning server
188    pub fn new(
189        model_config: Vec<(String, usize)>,
190        aggregation_protocol: SecureAggregationProtocol,
191        byzantine_threshold: f64,
192    ) -> Self {
193        Self {
194            model_config,
195            global_params: HashMap::new(),
196            client_weights: HashMap::new(),
197            aggregation_protocol,
198            byzantine_threshold,
199        }
200    }
201
202    /// Aggregate client updates
203    pub fn aggregate_updates(
204        &mut self,
205        client_updates: Vec<(String, HashMap<String, f64>, usize)>, // (client_id, params, dataset_size)
206    ) -> Result<HashMap<String, f64>> {
207        match self.aggregation_protocol {
208            SecureAggregationProtocol::FederatedAveraging => {
209                self.federated_averaging(client_updates)
210            }
211            SecureAggregationProtocol::SecureMultiparty => {
212                self.secure_multiparty_aggregation(client_updates)
213            }
214            SecureAggregationProtocol::HomomorphicEncryption => {
215                self.homomorphic_aggregation(client_updates)
216            }
217            SecureAggregationProtocol::QuantumSecretSharing => {
218                self.quantum_secret_sharing_aggregation(client_updates)
219            }
220        }
221    }
222
223    /// Federated averaging aggregation
224    fn federated_averaging(
225        &mut self,
226        client_updates: Vec<(String, HashMap<String, f64>, usize)>,
227    ) -> Result<HashMap<String, f64>> {
228        let total_samples: usize = client_updates.iter().map(|(_, _, size)| size).sum();
229        let mut aggregated = HashMap::new();
230
231        // Weight by dataset size
232        for (client_id, params, dataset_size) in client_updates {
233            let weight = dataset_size as f64 / total_samples as f64;
234            self.client_weights.insert(client_id.clone(), weight);
235
236            for (param_name, param_value) in params {
237                *aggregated.entry(param_name).or_insert(0.0) += weight * param_value;
238            }
239        }
240
241        self.global_params = aggregated.clone();
242        Ok(aggregated)
243    }
244
245    /// Secure multi-party computation aggregation
246    fn secure_multiparty_aggregation(
247        &mut self,
248        client_updates: Vec<(String, HashMap<String, f64>, usize)>,
249    ) -> Result<HashMap<String, f64>> {
250        // Implement secure aggregation using secret sharing
251        let num_clients = client_updates.len();
252        let mut shares: HashMap<String, Vec<f64>> = HashMap::new();
253
254        // Collect shares for each parameter
255        for (_, params, _) in &client_updates {
256            for (param_name, param_value) in params {
257                shares
258                    .entry(param_name.clone())
259                    .or_insert(Vec::new())
260                    .push(*param_value);
261            }
262        }
263
264        // Aggregate shares with Byzantine fault tolerance
265        let mut aggregated = HashMap::new();
266        for (param_name, param_shares) in shares {
267            let aggregated_value = self.byzantine_robust_aggregation(&param_shares)?;
268            aggregated.insert(param_name, aggregated_value);
269        }
270
271        self.global_params = aggregated.clone();
272        Ok(aggregated)
273    }
274
275    /// Homomorphic encryption aggregation
276    fn homomorphic_aggregation(
277        &mut self,
278        client_updates: Vec<(String, HashMap<String, f64>, usize)>,
279    ) -> Result<HashMap<String, f64>> {
280        // Simplified homomorphic aggregation
281        // In practice, would use actual homomorphic encryption
282
283        let mut encrypted_sum = HashMap::new();
284
285        for (_, params, _) in &client_updates {
286            for (param_name, param_value) in params {
287                // "Encrypt" (simplified)
288                let encrypted = self.homomorphic_encrypt(*param_value)?;
289
290                // Add encrypted values
291                *encrypted_sum.entry(param_name.clone()).or_insert(0.0) += encrypted;
292            }
293        }
294
295        // "Decrypt" aggregated values
296        let mut aggregated = HashMap::new();
297        for (param_name, encrypted_value) in encrypted_sum {
298            let decrypted = self.homomorphic_decrypt(encrypted_value)?;
299            aggregated.insert(param_name, decrypted / client_updates.len() as f64);
300        }
301
302        self.global_params = aggregated.clone();
303        Ok(aggregated)
304    }
305
306    /// Quantum secret sharing aggregation
307    fn quantum_secret_sharing_aggregation(
308        &mut self,
309        client_updates: Vec<(String, HashMap<String, f64>, usize)>,
310    ) -> Result<HashMap<String, f64>> {
311        let num_clients = client_updates.len();
312        let threshold = ((num_clients as f64) * self.byzantine_threshold).ceil() as usize;
313
314        // Create quantum shares
315        let mut quantum_shares: HashMap<String, Vec<QuantumShare>> = HashMap::new();
316
317        for (client_id, params, _) in &client_updates {
318            for (param_name, param_value) in params {
319                let share = self.create_quantum_share(client_id, *param_value)?;
320                quantum_shares
321                    .entry(param_name.clone())
322                    .or_insert(Vec::new())
323                    .push(share);
324            }
325        }
326
327        // Reconstruct from shares
328        let mut aggregated = HashMap::new();
329        for (param_name, shares) in quantum_shares {
330            if shares.len() >= threshold {
331                let reconstructed = self.reconstruct_from_quantum_shares(&shares)?;
332                aggregated.insert(param_name, reconstructed);
333            }
334        }
335
336        self.global_params = aggregated.clone();
337        Ok(aggregated)
338    }
339
340    /// Byzantine-robust aggregation
341    fn byzantine_robust_aggregation(&self, values: &[f64]) -> Result<f64> {
342        if values.is_empty() {
343            return Err(MLError::InvalidInput("No values to aggregate".to_string()));
344        }
345
346        // Krum algorithm for Byzantine robustness
347        let n = values.len();
348        let f = ((n as f64 * self.byzantine_threshold) as usize).min(n / 2);
349
350        // Compute pairwise distances
351        let mut scores = vec![0.0; n];
352        for i in 0..n {
353            let mut distances: Vec<f64> = (0..n)
354                .filter(|&j| j != i)
355                .map(|j| (values[i] - values[j]).abs())
356                .collect();
357            distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
358
359            // Sum of n-f-1 closest values
360            scores[i] = distances.iter().take(n - f - 1).sum();
361        }
362
363        // Select value with minimum score
364        let best_idx = scores
365            .iter()
366            .enumerate()
367            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
368            .map(|(idx, _)| idx)
369            .unwrap_or(0);
370
371        Ok(values[best_idx])
372    }
373
374    /// Simple homomorphic encryption (placeholder)
375    fn homomorphic_encrypt(&self, value: f64) -> Result<f64> {
376        // In practice, use proper homomorphic encryption
377        Ok(value * 1000.0 + fastrand::f64() * 10.0)
378    }
379
380    /// Simple homomorphic decryption (placeholder)
381    fn homomorphic_decrypt(&self, encrypted: f64) -> Result<f64> {
382        // In practice, use proper homomorphic decryption
383        Ok((encrypted - 5.0) / 1000.0)
384    }
385
386    /// Create quantum share
387    fn create_quantum_share(&self, client_id: &str, value: f64) -> Result<QuantumShare> {
388        let num_qubits = 3;
389        let mut circuit = VariationalCircuit::new(num_qubits);
390
391        // Encode value in quantum state
392        circuit.add_gate("RY", vec![0], vec![(value * PI).to_string()]);
393
394        // Create entangled shares
395        circuit.add_gate("H", vec![1], vec![]);
396        circuit.add_gate("CNOT", vec![1, 2], vec![]);
397        circuit.add_gate("CNOT", vec![0, 1], vec![]);
398
399        Ok(QuantumShare {
400            client_id: client_id.to_string(),
401            share_circuit: circuit,
402            share_value: value,
403        })
404    }
405
406    /// Reconstruct from quantum shares
407    fn reconstruct_from_quantum_shares(&self, shares: &[QuantumShare]) -> Result<f64> {
408        // Simplified reconstruction
409        // In practice, would perform quantum state tomography
410        let sum: f64 = shares.iter().map(|s| s.share_value).sum();
411        Ok(sum / shares.len() as f64)
412    }
413}
414
415/// Quantum share for secret sharing
416#[derive(Debug)]
417struct QuantumShare {
418    client_id: String,
419    share_circuit: VariationalCircuit,
420    share_value: f64,
421}
422
423/// Distributed quantum learning coordinator
424#[derive(Debug)]
425pub struct DistributedQuantumLearning {
426    /// Server instance
427    server: QuantumFLServer,
428    /// Client instances
429    clients: HashMap<String, QuantumFLClient>,
430    /// Communication rounds
431    rounds: usize,
432    /// Convergence threshold
433    convergence_threshold: f64,
434}
435
436impl DistributedQuantumLearning {
437    /// Create a new distributed learning system
438    pub fn new(
439        num_clients: usize,
440        model_config: Vec<(String, usize)>,
441        aggregation_protocol: SecureAggregationProtocol,
442        epsilon: f64,
443    ) -> Result<Self> {
444        let server = QuantumFLServer::new(
445            model_config.clone(),
446            aggregation_protocol,
447            0.2, // Byzantine threshold
448        );
449
450        let mut clients = HashMap::new();
451        for i in 0..num_clients {
452            let client_id = format!("client_{}", i);
453            let dataset_size = 100 + fastrand::usize(..900); // Random dataset size
454            let client =
455                QuantumFLClient::new(client_id.clone(), &model_config, dataset_size, epsilon)?;
456            clients.insert(client_id, client);
457        }
458
459        Ok(Self {
460            server,
461            clients,
462            rounds: 0,
463            convergence_threshold: 1e-4,
464        })
465    }
466
467    /// Run federated training
468    pub fn train(
469        &mut self,
470        data_distribution: &HashMap<String, (Array2<f64>, Array1<i32>)>,
471        num_rounds: usize,
472        clients_per_round: usize,
473    ) -> Result<FederatedTrainingResult> {
474        let mut round_losses = Vec::new();
475        let mut convergence_metric = f64::INFINITY;
476
477        for round in 0..num_rounds {
478            self.rounds = round + 1;
479
480            // Select random subset of clients
481            let selected_clients = self.select_clients(clients_per_round);
482
483            // Local training
484            let mut client_updates = Vec::new();
485            let mut round_loss = 0.0;
486
487            for client_id in selected_clients {
488                if let Some(client) = self.clients.get_mut(&client_id) {
489                    if let Some((data, labels)) = data_distribution.get(&client_id) {
490                        // Train locally
491                        let loss = client.train_local(data, labels, 5)?;
492                        round_loss += loss;
493
494                        // Get parameters
495                        let params = client.get_parameters();
496                        let dataset_size = data.nrows();
497                        client_updates.push((client_id.clone(), params, dataset_size));
498                    }
499                }
500            }
501
502            // Aggregate updates
503            let aggregated = self.server.aggregate_updates(client_updates)?;
504
505            // Update all clients with aggregated model
506            for (_, client) in self.clients.iter_mut() {
507                client.set_parameters(aggregated.clone());
508            }
509
510            // Check convergence (skip on first round)
511            if round > 0 {
512                let prev_params = self.server.global_params.clone();
513                convergence_metric = self.compute_convergence(&prev_params, &aggregated)?;
514
515                if convergence_metric < self.convergence_threshold {
516                    round_losses.push(round_loss / clients_per_round as f64);
517                    break;
518                }
519            }
520
521            round_losses.push(round_loss / clients_per_round as f64);
522
523            // Update server's global params
524            self.server.global_params = aggregated.clone();
525        }
526
527        Ok(FederatedTrainingResult {
528            final_model_params: self.server.global_params.clone(),
529            round_losses,
530            num_rounds: self.rounds,
531            converged: convergence_metric < self.convergence_threshold,
532            convergence_metric,
533        })
534    }
535
536    /// Select random clients for training round
537    fn select_clients(&self, num_clients: usize) -> Vec<String> {
538        let all_clients: Vec<String> = self.clients.keys().cloned().collect();
539        let mut selected = Vec::new();
540
541        while selected.len() < num_clients.min(all_clients.len()) {
542            let idx = fastrand::usize(..all_clients.len());
543            let client = all_clients[idx].clone();
544            if !selected.contains(&client) {
545                selected.push(client);
546            }
547        }
548
549        selected
550    }
551
552    /// Compute convergence metric
553    fn compute_convergence(
554        &self,
555        old_params: &HashMap<String, f64>,
556        new_params: &HashMap<String, f64>,
557    ) -> Result<f64> {
558        let mut diff_sum = 0.0;
559        let mut count = 0;
560
561        for (key, new_val) in new_params {
562            if let Some(old_val) = old_params.get(key) {
563                diff_sum += (new_val - old_val).abs();
564                count += 1;
565            }
566        }
567
568        Ok(if count > 0 {
569            diff_sum / count as f64
570        } else {
571            0.0
572        })
573    }
574}
575
/// Summary returned after a federated training run.
#[derive(Debug)]
pub struct FederatedTrainingResult {
    /// Global parameters after the last aggregation.
    pub final_model_params: HashMap<String, f64>,
    /// Recorded loss for each executed round, in order.
    pub round_losses: Vec<f64>,
    /// Number of rounds actually executed.
    pub num_rounds: usize,
    /// `true` if the final convergence metric was below the threshold.
    pub converged: bool,
    /// Convergence metric measured at the end of training.
    pub convergence_metric: f64,
}
590
591/// Privacy-preserving quantum computation
592pub mod privacy {
593    use super::*;
594
595    /// Differential privacy mechanism for quantum circuits
596    #[derive(Debug)]
597    pub struct QuantumDifferentialPrivacy {
598        /// Privacy budget
599        epsilon: f64,
600        /// Sensitivity bound
601        sensitivity: f64,
602        /// Noise mechanism
603        mechanism: NoiseType,
604    }
605
606    #[derive(Debug, Clone)]
607    pub enum NoiseType {
608        Laplace,
609        Gaussian,
610        Quantum,
611    }
612
613    impl QuantumDifferentialPrivacy {
614        /// Create new DP mechanism
615        pub fn new(epsilon: f64, sensitivity: f64, mechanism: NoiseType) -> Self {
616            Self {
617                epsilon,
618                sensitivity,
619                mechanism,
620            }
621        }
622
623        /// Add noise to quantum circuit parameters
624        pub fn add_noise(&self, params: &mut HashMap<String, f64>) -> Result<()> {
625            for (_, value) in params.iter_mut() {
626                let noise = match self.mechanism {
627                    NoiseType::Laplace => self.laplace_noise(),
628                    NoiseType::Gaussian => self.gaussian_noise(),
629                    NoiseType::Quantum => self.quantum_noise()?,
630                };
631                *value += noise;
632            }
633            Ok(())
634        }
635
636        /// Laplace noise
637        fn laplace_noise(&self) -> f64 {
638            let scale = self.sensitivity / self.epsilon;
639            let u = fastrand::f64() - 0.5;
640            -scale * u.signum() * (1.0 - 2.0 * u.abs()).ln()
641        }
642
643        /// Gaussian noise
644        fn gaussian_noise(&self) -> f64 {
645            let scale = self.sensitivity * (2.0 * (1.25 / self.epsilon).ln()).sqrt();
646            QuantumFLClient::gaussian_noise() * scale
647        }
648
649        /// Quantum noise
650        fn quantum_noise(&self) -> Result<f64> {
651            // Implement quantum noise using depolarizing channel
652            let p = (-self.epsilon).exp();
653            Ok(if fastrand::f64() < p {
654                fastrand::f64() * 2.0 - 1.0
655            } else {
656                0.0
657            })
658        }
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Turn `(&str, usize)` pairs into the owned layer configuration the API expects.
    fn layer_config(spec: &[(&str, usize)]) -> Vec<(String, usize)> {
        spec.iter()
            .map(|&(kind, size)| (kind.to_string(), size))
            .collect()
    }

    /// Build a parameter map with entries `w1` and `w2`.
    fn weights(w1: f64, w2: f64) -> HashMap<String, f64> {
        vec![("w1".to_string(), w1), ("w2".to_string(), w2)]
            .into_iter()
            .collect()
    }

    #[test]
    fn test_quantum_fl_client() {
        let config = layer_config(&[("encoding", 4), ("variational", 8), ("measurement", 0)]);

        let mut client = QuantumFLClient::new("client_1".to_string(), &config, 100, 1.0)
            .expect("Failed to create client");

        let features = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];
        let targets = array![0, 1, 0];

        let avg_loss = client
            .train_local(&features, &targets, 1)
            .expect("Training failed");
        assert!(avg_loss >= 0.0);
    }

    #[test]
    fn test_federated_averaging() {
        let mut server = QuantumFLServer::new(
            layer_config(&[("encoding", 4)]),
            SecureAggregationProtocol::FederatedAveraging,
            0.2,
        );

        let updates = vec![
            ("client1".to_string(), weights(0.5, 0.3), 100),
            ("client2".to_string(), weights(0.7, 0.4), 200),
        ];

        let merged = server
            .aggregate_updates(updates)
            .expect("Aggregation failed");

        // Weighted average: w1 = (0.5*100 + 0.7*200)/300 = 0.633...
        let w1 = merged["w1"];
        assert!((w1 - 0.633).abs() < 0.01);
    }

    #[test]
    fn test_byzantine_robust_aggregation() {
        let server = QuantumFLServer::new(vec![], SecureAggregationProtocol::SecureMultiparty, 0.3);

        // Four honest values clustered near 0.5 plus one Byzantine outlier (10.0).
        let observed = [0.5, 0.52, 0.48, 0.51, 10.0];
        let chosen = server
            .byzantine_robust_aggregation(&observed)
            .expect("Byzantine aggregation failed");

        // The outlier must not be selected.
        assert!(chosen < 1.0);
    }

    #[test]
    fn test_differential_privacy() {
        use privacy::*;

        let dp = QuantumDifferentialPrivacy::new(1.0, 0.1, NoiseType::Gaussian);

        let mut noisy: HashMap<String, f64> = vec![
            ("param1".to_string(), 0.5),
            ("param2".to_string(), 0.3),
        ]
        .into_iter()
        .collect();
        let clean = noisy.clone();

        dp.add_noise(&mut noisy).expect("Failed to add noise");

        // Each parameter should have been perturbed.
        for key in ["param1", "param2"].iter() {
            assert_ne!(noisy[*key], clean[*key]);
        }
    }

    #[test]
    fn test_distributed_learning() {
        let config = layer_config(&[("encoding", 4), ("variational", 8)]);

        let mut system = DistributedQuantumLearning::new(
            3, // 3 clients
            config,
            SecureAggregationProtocol::FederatedAveraging,
            1.0,
        )
        .expect("Failed to create distributed learning system");

        // One all-zero dataset per client.
        let data_dist: HashMap<String, (Array2<f64>, Array1<i32>)> = (0..3)
            .map(|idx| {
                (
                    format!("client_{}", idx),
                    (Array2::<f64>::zeros((10, 4)), Array1::<i32>::zeros(10)),
                )
            })
            .collect();

        let outcome = system.train(&data_dist, 2, 2).expect("Training failed");

        assert_eq!(outcome.num_rounds, 2);
        assert_eq!(outcome.round_losses.len(), 2);
    }
}