1use scirs2_core::ndarray::{Array1, Array2, Array3};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9use std::f64::consts::PI;
10
11use crate::error::{MLError, Result};
12use crate::qnn::QuantumNeuralNetwork;
13use crate::utils::VariationalCircuit;
14use quantrs2_circuit::prelude::*;
15use quantrs2_core::gate::{multi::*, single::*, GateOp};
16
17#[derive(Debug)]
19pub struct QuantumFLClient {
20 client_id: String,
22 local_model: QuantumNeuralNetwork,
24 dataset_size: usize,
26 epsilon: f64,
28 noise_scale: f64,
30 local_params: HashMap<String, f64>,
32}
33
34impl QuantumFLClient {
35 pub fn new(
37 client_id: String,
38 model_config: &[(String, usize)], dataset_size: usize,
40 epsilon: f64,
41 ) -> Result<Self> {
42 let layers = model_config
44 .iter()
45 .map(|(layer_type, size)| match layer_type.as_str() {
46 "encoding" => crate::qnn::QNNLayerType::EncodingLayer {
47 num_features: *size,
48 },
49 "variational" => crate::qnn::QNNLayerType::VariationalLayer { num_params: *size },
50 "entanglement" => crate::qnn::QNNLayerType::EntanglementLayer {
51 connectivity: "full".to_string(),
52 },
53 _ => crate::qnn::QNNLayerType::MeasurementLayer {
54 measurement_basis: "computational".to_string(),
55 },
56 })
57 .collect();
58
59 let local_model = QuantumNeuralNetwork::new(layers, 4, 10, 2)?;
60 let noise_scale = (2.0 * (1.25 / epsilon).ln()).sqrt() / dataset_size as f64;
61
62 Ok(Self {
63 client_id,
64 local_model,
65 dataset_size,
66 epsilon,
67 noise_scale,
68 local_params: HashMap::new(),
69 })
70 }
71
72 pub fn train_local(
74 &mut self,
75 local_data: &Array2<f64>,
76 local_labels: &Array1<i32>,
77 epochs: usize,
78 ) -> Result<f64> {
79 let mut total_loss = 0.0;
80
81 for _ in 0..epochs {
82 for i in 0..local_data.nrows() {
84 let input = local_data.row(i).to_owned();
85 let label = local_labels[i];
86
87 let output = self.local_model.forward(&input)?;
89
90 let loss = self.compute_loss(&output, label)?;
92 total_loss += loss;
93
94 self.update_parameters(&input, label, 0.01)?;
96 }
97 }
98
99 self.add_dp_noise()?;
101
102 Ok(total_loss / (epochs * local_data.nrows()) as f64)
103 }
104
105 fn compute_loss(&self, output: &Array1<f64>, label: i32) -> Result<f64> {
107 let label_idx = label as usize;
109 if label_idx >= output.len() {
110 return Err(MLError::InvalidInput("Label out of bounds".to_string()));
111 }
112
113 Ok(-output[label_idx].ln())
114 }
115
116 fn update_parameters(
118 &mut self,
119 input: &Array1<f64>,
120 label: i32,
121 learning_rate: f64,
122 ) -> Result<()> {
123 for (key, value) in self.local_params.iter_mut() {
125 *value += learning_rate * fastrand::f64() * 0.1;
126 }
127 Ok(())
128 }
129
130 fn add_dp_noise(&mut self) -> Result<()> {
132 for (_, value) in self.local_params.iter_mut() {
133 let noise = self.noise_scale * Self::gaussian_noise();
135 *value += noise;
136 }
137 Ok(())
138 }
139
140 fn gaussian_noise() -> f64 {
142 let u1 = fastrand::f64();
144 let u2 = fastrand::f64();
145 (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
146 }
147
148 pub fn get_parameters(&self) -> HashMap<String, f64> {
150 self.local_params.clone()
151 }
152
153 pub fn set_parameters(&mut self, params: HashMap<String, f64>) {
155 self.local_params = params;
156 }
157}
158
159#[derive(Debug)]
161pub struct QuantumFLServer {
162 model_config: Vec<(String, usize)>,
164 global_params: HashMap<String, f64>,
166 client_weights: HashMap<String, f64>,
168 aggregation_protocol: SecureAggregationProtocol,
170 byzantine_threshold: f64,
172}
173
/// Aggregation strategy used by the federated-learning server.
///
/// The enum carries no data, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecureAggregationProtocol {
    /// Sample-size-weighted averaging of client parameters.
    FederatedAveraging,
    /// Per-parameter Byzantine-robust selection over the submitted values.
    SecureMultiparty,
    /// Averaging carried out on masked ("encrypted") values.
    HomomorphicEncryption,
    /// Threshold reconstruction from per-client quantum shares.
    QuantumSecretSharing,
}
185
186impl QuantumFLServer {
187 pub fn new(
189 model_config: Vec<(String, usize)>,
190 aggregation_protocol: SecureAggregationProtocol,
191 byzantine_threshold: f64,
192 ) -> Self {
193 Self {
194 model_config,
195 global_params: HashMap::new(),
196 client_weights: HashMap::new(),
197 aggregation_protocol,
198 byzantine_threshold,
199 }
200 }
201
202 pub fn aggregate_updates(
204 &mut self,
205 client_updates: Vec<(String, HashMap<String, f64>, usize)>, ) -> Result<HashMap<String, f64>> {
207 match self.aggregation_protocol {
208 SecureAggregationProtocol::FederatedAveraging => {
209 self.federated_averaging(client_updates)
210 }
211 SecureAggregationProtocol::SecureMultiparty => {
212 self.secure_multiparty_aggregation(client_updates)
213 }
214 SecureAggregationProtocol::HomomorphicEncryption => {
215 self.homomorphic_aggregation(client_updates)
216 }
217 SecureAggregationProtocol::QuantumSecretSharing => {
218 self.quantum_secret_sharing_aggregation(client_updates)
219 }
220 }
221 }
222
223 fn federated_averaging(
225 &mut self,
226 client_updates: Vec<(String, HashMap<String, f64>, usize)>,
227 ) -> Result<HashMap<String, f64>> {
228 let total_samples: usize = client_updates.iter().map(|(_, _, size)| size).sum();
229 let mut aggregated = HashMap::new();
230
231 for (client_id, params, dataset_size) in client_updates {
233 let weight = dataset_size as f64 / total_samples as f64;
234 self.client_weights.insert(client_id.clone(), weight);
235
236 for (param_name, param_value) in params {
237 *aggregated.entry(param_name).or_insert(0.0) += weight * param_value;
238 }
239 }
240
241 self.global_params = aggregated.clone();
242 Ok(aggregated)
243 }
244
245 fn secure_multiparty_aggregation(
247 &mut self,
248 client_updates: Vec<(String, HashMap<String, f64>, usize)>,
249 ) -> Result<HashMap<String, f64>> {
250 let num_clients = client_updates.len();
252 let mut shares: HashMap<String, Vec<f64>> = HashMap::new();
253
254 for (_, params, _) in &client_updates {
256 for (param_name, param_value) in params {
257 shares
258 .entry(param_name.clone())
259 .or_insert(Vec::new())
260 .push(*param_value);
261 }
262 }
263
264 let mut aggregated = HashMap::new();
266 for (param_name, param_shares) in shares {
267 let aggregated_value = self.byzantine_robust_aggregation(¶m_shares)?;
268 aggregated.insert(param_name, aggregated_value);
269 }
270
271 self.global_params = aggregated.clone();
272 Ok(aggregated)
273 }
274
275 fn homomorphic_aggregation(
277 &mut self,
278 client_updates: Vec<(String, HashMap<String, f64>, usize)>,
279 ) -> Result<HashMap<String, f64>> {
280 let mut encrypted_sum = HashMap::new();
284
285 for (_, params, _) in &client_updates {
286 for (param_name, param_value) in params {
287 let encrypted = self.homomorphic_encrypt(*param_value)?;
289
290 *encrypted_sum.entry(param_name.clone()).or_insert(0.0) += encrypted;
292 }
293 }
294
295 let mut aggregated = HashMap::new();
297 for (param_name, encrypted_value) in encrypted_sum {
298 let decrypted = self.homomorphic_decrypt(encrypted_value)?;
299 aggregated.insert(param_name, decrypted / client_updates.len() as f64);
300 }
301
302 self.global_params = aggregated.clone();
303 Ok(aggregated)
304 }
305
306 fn quantum_secret_sharing_aggregation(
308 &mut self,
309 client_updates: Vec<(String, HashMap<String, f64>, usize)>,
310 ) -> Result<HashMap<String, f64>> {
311 let num_clients = client_updates.len();
312 let threshold = ((num_clients as f64) * self.byzantine_threshold).ceil() as usize;
313
314 let mut quantum_shares: HashMap<String, Vec<QuantumShare>> = HashMap::new();
316
317 for (client_id, params, _) in &client_updates {
318 for (param_name, param_value) in params {
319 let share = self.create_quantum_share(client_id, *param_value)?;
320 quantum_shares
321 .entry(param_name.clone())
322 .or_insert(Vec::new())
323 .push(share);
324 }
325 }
326
327 let mut aggregated = HashMap::new();
329 for (param_name, shares) in quantum_shares {
330 if shares.len() >= threshold {
331 let reconstructed = self.reconstruct_from_quantum_shares(&shares)?;
332 aggregated.insert(param_name, reconstructed);
333 }
334 }
335
336 self.global_params = aggregated.clone();
337 Ok(aggregated)
338 }
339
340 fn byzantine_robust_aggregation(&self, values: &[f64]) -> Result<f64> {
342 if values.is_empty() {
343 return Err(MLError::InvalidInput("No values to aggregate".to_string()));
344 }
345
346 let n = values.len();
348 let f = ((n as f64 * self.byzantine_threshold) as usize).min(n / 2);
349
350 let mut scores = vec![0.0; n];
352 for i in 0..n {
353 let mut distances: Vec<f64> = (0..n)
354 .filter(|&j| j != i)
355 .map(|j| (values[i] - values[j]).abs())
356 .collect();
357 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
358
359 scores[i] = distances.iter().take(n - f - 1).sum();
361 }
362
363 let best_idx = scores
365 .iter()
366 .enumerate()
367 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
368 .map(|(idx, _)| idx)
369 .unwrap();
370
371 Ok(values[best_idx])
372 }
373
374 fn homomorphic_encrypt(&self, value: f64) -> Result<f64> {
376 Ok(value * 1000.0 + fastrand::f64() * 10.0)
378 }
379
380 fn homomorphic_decrypt(&self, encrypted: f64) -> Result<f64> {
382 Ok((encrypted - 5.0) / 1000.0)
384 }
385
386 fn create_quantum_share(&self, client_id: &str, value: f64) -> Result<QuantumShare> {
388 let num_qubits = 3;
389 let mut circuit = VariationalCircuit::new(num_qubits);
390
391 circuit.add_gate("RY", vec![0], vec![(value * PI).to_string()]);
393
394 circuit.add_gate("H", vec![1], vec![]);
396 circuit.add_gate("CNOT", vec![1, 2], vec![]);
397 circuit.add_gate("CNOT", vec![0, 1], vec![]);
398
399 Ok(QuantumShare {
400 client_id: client_id.to_string(),
401 share_circuit: circuit,
402 share_value: value,
403 })
404 }
405
406 fn reconstruct_from_quantum_shares(&self, shares: &[QuantumShare]) -> Result<f64> {
408 let sum: f64 = shares.iter().map(|s| s.share_value).sum();
411 Ok(sum / shares.len() as f64)
412 }
413}
414
415#[derive(Debug)]
417struct QuantumShare {
418 client_id: String,
419 share_circuit: VariationalCircuit,
420 share_value: f64,
421}
422
423#[derive(Debug)]
425pub struct DistributedQuantumLearning {
426 server: QuantumFLServer,
428 clients: HashMap<String, QuantumFLClient>,
430 rounds: usize,
432 convergence_threshold: f64,
434}
435
436impl DistributedQuantumLearning {
437 pub fn new(
439 num_clients: usize,
440 model_config: Vec<(String, usize)>,
441 aggregation_protocol: SecureAggregationProtocol,
442 epsilon: f64,
443 ) -> Result<Self> {
444 let server = QuantumFLServer::new(
445 model_config.clone(),
446 aggregation_protocol,
447 0.2, );
449
450 let mut clients = HashMap::new();
451 for i in 0..num_clients {
452 let client_id = format!("client_{}", i);
453 let dataset_size = 100 + fastrand::usize(..900); let client =
455 QuantumFLClient::new(client_id.clone(), &model_config, dataset_size, epsilon)?;
456 clients.insert(client_id, client);
457 }
458
459 Ok(Self {
460 server,
461 clients,
462 rounds: 0,
463 convergence_threshold: 1e-4,
464 })
465 }
466
467 pub fn train(
469 &mut self,
470 data_distribution: &HashMap<String, (Array2<f64>, Array1<i32>)>,
471 num_rounds: usize,
472 clients_per_round: usize,
473 ) -> Result<FederatedTrainingResult> {
474 let mut round_losses = Vec::new();
475 let mut convergence_metric = f64::INFINITY;
476
477 for round in 0..num_rounds {
478 self.rounds = round + 1;
479
480 let selected_clients = self.select_clients(clients_per_round);
482
483 let mut client_updates = Vec::new();
485 let mut round_loss = 0.0;
486
487 for client_id in selected_clients {
488 if let Some(client) = self.clients.get_mut(&client_id) {
489 if let Some((data, labels)) = data_distribution.get(&client_id) {
490 let loss = client.train_local(data, labels, 5)?;
492 round_loss += loss;
493
494 let params = client.get_parameters();
496 let dataset_size = data.nrows();
497 client_updates.push((client_id.clone(), params, dataset_size));
498 }
499 }
500 }
501
502 let aggregated = self.server.aggregate_updates(client_updates)?;
504
505 for (_, client) in self.clients.iter_mut() {
507 client.set_parameters(aggregated.clone());
508 }
509
510 if round > 0 {
512 let prev_params = self.server.global_params.clone();
513 convergence_metric = self.compute_convergence(&prev_params, &aggregated)?;
514
515 if convergence_metric < self.convergence_threshold {
516 round_losses.push(round_loss / clients_per_round as f64);
517 break;
518 }
519 }
520
521 round_losses.push(round_loss / clients_per_round as f64);
522
523 self.server.global_params = aggregated.clone();
525 }
526
527 Ok(FederatedTrainingResult {
528 final_model_params: self.server.global_params.clone(),
529 round_losses,
530 num_rounds: self.rounds,
531 converged: convergence_metric < self.convergence_threshold,
532 convergence_metric,
533 })
534 }
535
536 fn select_clients(&self, num_clients: usize) -> Vec<String> {
538 let all_clients: Vec<String> = self.clients.keys().cloned().collect();
539 let mut selected = Vec::new();
540
541 while selected.len() < num_clients.min(all_clients.len()) {
542 let idx = fastrand::usize(..all_clients.len());
543 let client = all_clients[idx].clone();
544 if !selected.contains(&client) {
545 selected.push(client);
546 }
547 }
548
549 selected
550 }
551
552 fn compute_convergence(
554 &self,
555 old_params: &HashMap<String, f64>,
556 new_params: &HashMap<String, f64>,
557 ) -> Result<f64> {
558 let mut diff_sum = 0.0;
559 let mut count = 0;
560
561 for (key, new_val) in new_params {
562 if let Some(old_val) = old_params.get(key) {
563 diff_sum += (new_val - old_val).abs();
564 count += 1;
565 }
566 }
567
568 Ok(if count > 0 {
569 diff_sum / count as f64
570 } else {
571 0.0
572 })
573 }
574}
575
/// Summary returned by a federated training run.
///
/// The type is plain data, so it can be cloned and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct FederatedTrainingResult {
    /// Global parameters held by the server when training ended.
    pub final_model_params: HashMap<String, f64>,
    /// Average client loss recorded for each completed round.
    pub round_losses: Vec<f64>,
    /// Number of rounds that were started.
    pub num_rounds: usize,
    /// Whether the convergence metric fell below the threshold.
    pub converged: bool,
    /// Last computed convergence metric (infinity if never computed).
    pub convergence_metric: f64,
}
590
591pub mod privacy {
593 use super::*;
594
595 #[derive(Debug)]
597 pub struct QuantumDifferentialPrivacy {
598 epsilon: f64,
600 sensitivity: f64,
602 mechanism: NoiseType,
604 }
605
/// Noise mechanism selector.
///
/// The enum carries no data, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NoiseType {
    /// Laplace-distributed noise with scale `sensitivity / epsilon`.
    Laplace,
    /// Gaussian noise scaled by the sensitivity and a function of epsilon.
    Gaussian,
    /// With probability `exp(-epsilon)`, a uniform value in `[-1, 1)`;
    /// otherwise zero.
    Quantum,
}
612
613 impl QuantumDifferentialPrivacy {
614 pub fn new(epsilon: f64, sensitivity: f64, mechanism: NoiseType) -> Self {
616 Self {
617 epsilon,
618 sensitivity,
619 mechanism,
620 }
621 }
622
623 pub fn add_noise(&self, params: &mut HashMap<String, f64>) -> Result<()> {
625 for (_, value) in params.iter_mut() {
626 let noise = match self.mechanism {
627 NoiseType::Laplace => self.laplace_noise(),
628 NoiseType::Gaussian => self.gaussian_noise(),
629 NoiseType::Quantum => self.quantum_noise()?,
630 };
631 *value += noise;
632 }
633 Ok(())
634 }
635
636 fn laplace_noise(&self) -> f64 {
638 let scale = self.sensitivity / self.epsilon;
639 let u = fastrand::f64() - 0.5;
640 -scale * u.signum() * (1.0 - 2.0 * u.abs()).ln()
641 }
642
643 fn gaussian_noise(&self) -> f64 {
645 let scale = self.sensitivity * (2.0 * (1.25 / self.epsilon).ln()).sqrt();
646 QuantumFLClient::gaussian_noise() * scale
647 }
648
649 fn quantum_noise(&self) -> Result<f64> {
651 let p = (-self.epsilon).exp();
653 Ok(if fastrand::f64() < p {
654 fastrand::f64() * 2.0 - 1.0
655 } else {
656 0.0
657 })
658 }
659 }
660}
661
662#[cfg(test)]
663mod tests {
664 use super::*;
665 use scirs2_core::ndarray::array;
666
/// A client can be built from a layer configuration and complete one epoch
/// of local training with a non-negative loss.
#[test]
fn test_quantum_fl_client() {
    let config: Vec<(String, usize)> = [("encoding", 4), ("variational", 8), ("measurement", 0)]
        .iter()
        .map(|&(kind, size)| (kind.to_string(), size))
        .collect();

    let mut client = QuantumFLClient::new("client_1".to_string(), &config, 100, 1.0).unwrap();

    let data = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];
    let labels = array![0, 1, 0];

    let loss = client.train_local(&data, &labels, 1).unwrap();
    assert!(loss >= 0.0);
}
683
/// Federated averaging weights each client by its share of the samples.
///
/// With sizes 100 and 200 the weights are 1/3 and 2/3, so
/// `w1 = (0.5 * 100 + 0.7 * 200) / 300` and
/// `w2 = (0.3 * 100 + 0.4 * 200) / 300`.
#[test]
fn test_federated_averaging() {
    let config = vec![("encoding".to_string(), 4)];
    let mut server =
        QuantumFLServer::new(config, SecureAggregationProtocol::FederatedAveraging, 0.2);

    let mut params1 = HashMap::new();
    params1.insert("w1".to_string(), 0.5);
    params1.insert("w2".to_string(), 0.3);

    let mut params2 = HashMap::new();
    params2.insert("w1".to_string(), 0.7);
    params2.insert("w2".to_string(), 0.4);

    let updates = vec![
        ("client1".to_string(), params1, 100),
        ("client2".to_string(), params2, 200),
    ];

    let aggregated = server.aggregate_updates(updates).unwrap();

    // Both parameters must be present and correctly weighted.
    assert_eq!(aggregated.len(), 2);
    assert!((aggregated["w1"] - 0.633).abs() < 0.01);
    assert!((aggregated["w2"] - 0.367).abs() < 0.01);
}
708
/// A single extreme outlier must not be chosen as the aggregate.
#[test]
fn test_byzantine_robust_aggregation() {
    let server = QuantumFLServer::new(
        Vec::new(),
        SecureAggregationProtocol::SecureMultiparty,
        0.3,
    );

    // Four honest values clustered near 0.5 and one outlier.
    let submissions = [0.5, 0.52, 0.48, 0.51, 10.0];
    let robust_value = server.byzantine_robust_aggregation(&submissions).unwrap();

    assert!(robust_value < 1.0);
}
720
/// Gaussian noise changes every parameter it is applied to.
#[test]
fn test_differential_privacy() {
    use privacy::*;

    let dp = QuantumDifferentialPrivacy::new(1.0, 0.1, NoiseType::Gaussian);

    let original: HashMap<String, f64> = vec![
        ("param1".to_string(), 0.5),
        ("param2".to_string(), 0.3),
    ]
    .into_iter()
    .collect();

    let mut params = original.clone();
    dp.add_noise(&mut params).unwrap();

    assert_ne!(params["param1"], original["param1"]);
    assert_ne!(params["param2"], original["param2"]);
}
738
/// Two rounds of training across three clients report two rounds and two
/// recorded losses.
#[test]
fn test_distributed_learning() {
    let config = vec![("encoding".to_string(), 4), ("variational".to_string(), 8)];

    let mut system = DistributedQuantumLearning::new(
        3,
        config,
        SecureAggregationProtocol::FederatedAveraging,
        1.0,
    )
    .unwrap();

    // Every client gets an all-zero 10 x 4 data set with all-zero labels.
    let data_dist: HashMap<String, (Array2<f64>, Array1<i32>)> = (0..3)
        .map(|i| {
            (
                format!("client_{}", i),
                (Array2::zeros((10, 4)), Array1::zeros(10)),
            )
        })
        .collect();

    let result = system.train(&data_dist, 2, 2).unwrap();

    assert_eq!(result.num_rounds, 2);
    assert_eq!(result.round_losses.len(), 2);
}
764}