1use crate::{
29 error::{QuantRS2Error, QuantRS2Result},
30 gate::GateOp,
31 qubit::QubitId,
32};
33use scirs2_core::ndarray::{Array1, Array2, Axis};
34use scirs2_core::random::prelude::*;
35use scirs2_core::Complex64;
36use std::collections::HashMap;
37use std::f64::consts::PI;
38
39#[derive(Debug, Clone)]
41pub struct QuantumFederatedConfig {
42 pub num_qubits: usize,
44 pub circuit_depth: usize,
46 pub num_clients: usize,
48 pub client_fraction: f64,
50 pub local_epochs: usize,
52 pub local_lr: f64,
54 pub aggregation: AggregationStrategy,
56 pub dp_epsilon: f64,
58 pub dp_delta: f64,
60}
61
62impl Default for QuantumFederatedConfig {
63 fn default() -> Self {
64 Self {
65 num_qubits: 4,
66 circuit_depth: 3,
67 num_clients: 10,
68 client_fraction: 0.3,
69 local_epochs: 5,
70 local_lr: 0.01,
71 aggregation: AggregationStrategy::FedAvg,
72 dp_epsilon: 1.0,
73 dp_delta: 1e-5,
74 }
75 }
76}
77
/// Rule used by the server to merge the parameters reported by its clients.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Unweighted mean of all client parameters.
    FedAvg,
    /// Mean weighted by each client's dataset size.
    WeightedAvg,
    /// Coordinate-wise median.
    Median,
    /// Coordinate-wise mean after discarding the extreme values.
    TrimmedMean,
    /// Krum selection of a single client's parameters.
    Krum,
}
92
93#[derive(Debug, Clone)]
95pub struct QuantumFederatedClient {
96 id: usize,
98 params: Array2<f64>,
100 num_qubits: usize,
102 depth: usize,
104 dataset_size: usize,
106}
107
108impl QuantumFederatedClient {
109 pub fn new(id: usize, num_qubits: usize, depth: usize, dataset_size: usize) -> Self {
111 let mut rng = thread_rng();
112 let params = Array2::from_shape_fn((depth, num_qubits * 3), |_| rng.gen_range(-PI..PI));
113
114 Self {
115 id,
116 params,
117 num_qubits,
118 depth,
119 dataset_size,
120 }
121 }
122
123 pub fn train_local(
125 &mut self,
126 data: &[Array1<Complex64>],
127 labels: &[usize],
128 epochs: usize,
129 lr: f64,
130 ) -> QuantRS2Result<f64> {
131 let mut total_loss = 0.0;
132
133 for _ in 0..epochs {
134 let loss = self.compute_loss(data, labels)?;
135 total_loss += loss;
136
137 let gradients = self.compute_gradients(data, labels)?;
139
140 self.params = &self.params - &(gradients * lr);
142 }
143
144 Ok(total_loss / epochs as f64)
145 }
146
147 fn compute_loss(&self, data: &[Array1<Complex64>], labels: &[usize]) -> QuantRS2Result<f64> {
149 let mut total_loss = 0.0;
150
151 for (state, &label) in data.iter().zip(labels.iter()) {
152 let output = self.forward(state)?;
153
154 total_loss -= output[label].ln();
156 }
157
158 Ok(total_loss / data.len() as f64)
159 }
160
161 fn forward(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
163 let mut encoded = state.clone();
164
165 for layer in 0..self.depth {
167 for q in 0..self.num_qubits {
168 let rx = self.params[[layer, q * 3]];
169 let ry = self.params[[layer, q * 3 + 1]];
170 let rz = self.params[[layer, q * 3 + 2]];
171
172 encoded = self.apply_rotation(&encoded, q, rx, ry, rz)?;
173 }
174
175 for q in 0..self.num_qubits - 1 {
177 encoded = self.apply_cnot(&encoded, q, q + 1)?;
178 }
179 }
180
181 let mut expectations = Array1::zeros(2); expectations[0] = self.pauli_z_expectation(&encoded, 0)?;
184 expectations[1] = 1.0 - expectations[0];
185
186 let max_exp = expectations
188 .iter()
189 .copied()
190 .fold(f64::NEG_INFINITY, f64::max);
191 let mut probs = Array1::zeros(2);
192 let mut sum = 0.0;
193
194 for i in 0..2 {
195 probs[i] = (expectations[i] - max_exp).exp();
196 sum += probs[i];
197 }
198
199 for i in 0..2 {
200 probs[i] /= sum;
201 }
202
203 Ok(probs)
204 }
205
206 fn compute_gradients(
208 &self,
209 data: &[Array1<Complex64>],
210 labels: &[usize],
211 ) -> QuantRS2Result<Array2<f64>> {
212 let epsilon = PI / 2.0; let mut gradients = Array2::zeros(self.params.dim());
214
215 for i in 0..self.params.shape()[0] {
216 for j in 0..self.params.shape()[1] {
217 let mut client_plus = self.clone();
219 client_plus.params[[i, j]] += epsilon;
220 let loss_plus = client_plus.compute_loss(data, labels)?;
221
222 let mut client_minus = self.clone();
224 client_minus.params[[i, j]] -= epsilon;
225 let loss_minus = client_minus.compute_loss(data, labels)?;
226
227 gradients[[i, j]] = (loss_plus - loss_minus) / 2.0;
229 }
230 }
231
232 Ok(gradients)
233 }
234
235 pub const fn get_params(&self) -> &Array2<f64> {
237 &self.params
238 }
239
240 pub fn set_params(&mut self, params: Array2<f64>) {
242 self.params = params;
243 }
244
245 pub const fn dataset_size(&self) -> usize {
247 self.dataset_size
248 }
249
250 fn apply_rotation(
252 &self,
253 state: &Array1<Complex64>,
254 qubit: usize,
255 rx: f64,
256 ry: f64,
257 rz: f64,
258 ) -> QuantRS2Result<Array1<Complex64>> {
259 let mut result = state.clone();
260 result = self.apply_rz_gate(&result, qubit, rz)?;
261 result = self.apply_ry_gate(&result, qubit, ry)?;
262 result = self.apply_rx_gate(&result, qubit, rx)?;
263 Ok(result)
264 }
265
266 fn apply_rx_gate(
267 &self,
268 state: &Array1<Complex64>,
269 qubit: usize,
270 angle: f64,
271 ) -> QuantRS2Result<Array1<Complex64>> {
272 let dim = state.len();
273 let mut new_state = Array1::zeros(dim);
274 let cos_half = Complex64::new((angle / 2.0).cos(), 0.0);
275 let sin_half = Complex64::new(0.0, -(angle / 2.0).sin());
276
277 for i in 0..dim {
278 let j = i ^ (1 << qubit);
279 new_state[i] = state[i] * cos_half + state[j] * sin_half;
280 }
281
282 Ok(new_state)
283 }
284
285 fn apply_ry_gate(
286 &self,
287 state: &Array1<Complex64>,
288 qubit: usize,
289 angle: f64,
290 ) -> QuantRS2Result<Array1<Complex64>> {
291 let dim = state.len();
292 let mut new_state = Array1::zeros(dim);
293 let cos_half = (angle / 2.0).cos();
294 let sin_half = (angle / 2.0).sin();
295
296 for i in 0..dim {
297 let bit = (i >> qubit) & 1;
298 let j = i ^ (1 << qubit);
299 if bit == 0 {
300 new_state[i] = state[i] * cos_half - state[j] * sin_half;
301 } else {
302 new_state[i] = state[i] * cos_half + state[j] * sin_half;
303 }
304 }
305
306 Ok(new_state)
307 }
308
309 fn apply_rz_gate(
310 &self,
311 state: &Array1<Complex64>,
312 qubit: usize,
313 angle: f64,
314 ) -> QuantRS2Result<Array1<Complex64>> {
315 let dim = state.len();
316 let mut new_state = state.clone();
317 let phase = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
318
319 for i in 0..dim {
320 let bit = (i >> qubit) & 1;
321 new_state[i] = if bit == 1 {
322 new_state[i] * phase
323 } else {
324 new_state[i] * phase.conj()
325 };
326 }
327
328 Ok(new_state)
329 }
330
331 fn apply_cnot(
332 &self,
333 state: &Array1<Complex64>,
334 control: usize,
335 target: usize,
336 ) -> QuantRS2Result<Array1<Complex64>> {
337 let dim = state.len();
338 let mut new_state = state.clone();
339
340 for i in 0..dim {
341 let control_bit = (i >> control) & 1;
342 if control_bit == 1 {
343 let j = i ^ (1 << target);
344 if i < j {
345 let temp = new_state[i];
346 new_state[i] = new_state[j];
347 new_state[j] = temp;
348 }
349 }
350 }
351
352 Ok(new_state)
353 }
354
355 fn pauli_z_expectation(&self, state: &Array1<Complex64>, qubit: usize) -> QuantRS2Result<f64> {
356 let dim = state.len();
357 let mut expectation = 0.0;
358
359 for i in 0..dim {
360 let bit = (i >> qubit) & 1;
361 let sign = if bit == 0 { 1.0 } else { -1.0 };
362 expectation += sign * state[i].norm_sqr();
363 }
364
365 Ok(f64::midpoint(expectation, 1.0))
367 }
368}
369
370#[derive(Debug)]
372pub struct QuantumFederatedServer {
373 config: QuantumFederatedConfig,
375 global_params: Array2<f64>,
377 clients: Vec<QuantumFederatedClient>,
379 history: Vec<f64>,
381}
382
383impl QuantumFederatedServer {
384 pub fn new(config: QuantumFederatedConfig) -> Self {
386 let mut rng = thread_rng();
387
388 let global_params =
390 Array2::from_shape_fn((config.circuit_depth, config.num_qubits * 3), |_| {
391 rng.gen_range(-PI..PI)
392 });
393
394 let mut clients = Vec::with_capacity(config.num_clients);
396 for i in 0..config.num_clients {
397 let dataset_size = rng.gen_range(50..200);
398 clients.push(QuantumFederatedClient::new(
399 i,
400 config.num_qubits,
401 config.circuit_depth,
402 dataset_size,
403 ));
404 }
405
406 Self {
407 config,
408 global_params,
409 clients,
410 history: Vec::new(),
411 }
412 }
413
414 pub fn train_round(
416 &mut self,
417 client_data: &HashMap<usize, (Vec<Array1<Complex64>>, Vec<usize>)>,
418 ) -> QuantRS2Result<f64> {
419 let num_selected =
421 (self.config.num_clients as f64 * self.config.client_fraction).ceil() as usize;
422 let selected_clients = self.select_clients(num_selected);
423
424 for &client_id in &selected_clients {
426 self.clients[client_id].set_params(self.global_params.clone());
427 }
428
429 let mut client_updates = Vec::new();
431 let mut client_weights = Vec::new();
432 let mut avg_loss = 0.0;
433
434 for &client_id in &selected_clients {
435 if let Some((data, labels)) = client_data.get(&client_id) {
436 let loss = self.clients[client_id].train_local(
437 data,
438 labels,
439 self.config.local_epochs,
440 self.config.local_lr,
441 )?;
442
443 avg_loss += loss;
444
445 client_updates.push(self.clients[client_id].get_params().clone());
446 client_weights.push(self.clients[client_id].dataset_size() as f64);
447 }
448 }
449
450 avg_loss /= selected_clients.len() as f64;
451 self.history.push(avg_loss);
452
453 self.aggregate_updates(&client_updates, &client_weights)?;
455
456 Ok(avg_loss)
457 }
458
459 fn select_clients(&self, num_selected: usize) -> Vec<usize> {
461 let mut rng = thread_rng();
462 let mut clients: Vec<usize> = (0..self.config.num_clients).collect();
463
464 for i in (1..clients.len()).rev() {
466 let j = rng.gen_range(0..=i);
467 clients.swap(i, j);
468 }
469
470 clients.truncate(num_selected);
471 clients
472 }
473
474 fn aggregate_updates(
476 &mut self,
477 updates: &[Array2<f64>],
478 weights: &[f64],
479 ) -> QuantRS2Result<()> {
480 match self.config.aggregation {
481 AggregationStrategy::FedAvg => {
482 self.federated_averaging(updates)?;
483 }
484 AggregationStrategy::WeightedAvg => {
485 self.weighted_averaging(updates, weights)?;
486 }
487 AggregationStrategy::Median => {
488 self.median_aggregation(updates)?;
489 }
490 AggregationStrategy::TrimmedMean => {
491 self.trimmed_mean_aggregation(updates, 0.1)?;
492 }
493 AggregationStrategy::Krum => {
494 self.krum_aggregation(updates)?;
495 }
496 }
497
498 if self.config.dp_epsilon > 0.0 {
500 self.apply_differential_privacy()?;
501 }
502
503 Ok(())
504 }
505
506 fn federated_averaging(&mut self, updates: &[Array2<f64>]) -> QuantRS2Result<()> {
508 let mut avg_params = Array2::zeros(self.global_params.dim());
509
510 for update in updates {
511 avg_params = avg_params + update;
512 }
513
514 avg_params = avg_params / (updates.len() as f64);
515 self.global_params = avg_params;
516
517 Ok(())
518 }
519
520 fn weighted_averaging(
522 &mut self,
523 updates: &[Array2<f64>],
524 weights: &[f64],
525 ) -> QuantRS2Result<()> {
526 let total_weight: f64 = weights.iter().sum();
527 let mut weighted_params = Array2::zeros(self.global_params.dim());
528
529 for (update, &weight) in updates.iter().zip(weights.iter()) {
530 weighted_params = weighted_params + update * (weight / total_weight);
531 }
532
533 self.global_params = weighted_params;
534 Ok(())
535 }
536
537 fn median_aggregation(&mut self, updates: &[Array2<f64>]) -> QuantRS2Result<()> {
539 let shape = self.global_params.dim();
540 let mut median_params = Array2::zeros(shape);
541
542 for i in 0..shape.0 {
543 for j in 0..shape.1 {
544 let mut values: Vec<f64> = updates.iter().map(|u| u[[i, j]]).collect();
545 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
546
547 median_params[[i, j]] = if values.len() % 2 == 0 {
548 f64::midpoint(values[values.len() / 2 - 1], values[values.len() / 2])
549 } else {
550 values[values.len() / 2]
551 };
552 }
553 }
554
555 self.global_params = median_params;
556 Ok(())
557 }
558
559 fn trimmed_mean_aggregation(
561 &mut self,
562 updates: &[Array2<f64>],
563 trim_ratio: f64,
564 ) -> QuantRS2Result<()> {
565 let shape = self.global_params.dim();
566 let mut trimmed_params = Array2::zeros(shape);
567 let trim_count = (updates.len() as f64 * trim_ratio).floor() as usize;
568
569 for i in 0..shape.0 {
570 for j in 0..shape.1 {
571 let mut values: Vec<f64> = updates.iter().map(|u| u[[i, j]]).collect();
572 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
573
574 let trimmed: Vec<f64> = values[trim_count..values.len() - trim_count].to_vec();
576 trimmed_params[[i, j]] = trimmed.iter().sum::<f64>() / trimmed.len() as f64;
577 }
578 }
579
580 self.global_params = trimmed_params;
581 Ok(())
582 }
583
584 fn krum_aggregation(&mut self, updates: &[Array2<f64>]) -> QuantRS2Result<()> {
586 let n = updates.len();
587 let f = (n - 1) / 2; let n_minus_f_minus_2 = n - f - 2;
589
590 let mut scores = vec![0.0; n];
592
593 for i in 0..n {
594 let mut distances: Vec<(usize, f64)> = Vec::new();
595
596 for j in 0..n {
597 if i != j {
598 let diff = &updates[i] - &updates[j];
599 let dist: f64 = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
600 distances.push((j, dist));
601 }
602 }
603
604 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
606 scores[i] = distances
607 .iter()
608 .take(n_minus_f_minus_2)
609 .map(|(_, d)| d)
610 .sum();
611 }
612
613 let best_client = scores
615 .iter()
616 .enumerate()
617 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
618 .map(|(idx, _)| idx)
619 .unwrap_or(0);
620
621 self.global_params.clone_from(&updates[best_client]);
622 Ok(())
623 }
624
625 fn apply_differential_privacy(&mut self) -> QuantRS2Result<()> {
627 let mut rng = thread_rng();
628
629 let sensitivity = 1.0; let noise_scale = sensitivity / self.config.dp_epsilon;
632
633 for i in 0..self.global_params.shape()[0] {
635 for j in 0..self.global_params.shape()[1] {
636 let noise = rng.gen_range(-1.0..1.0) * noise_scale;
637 self.global_params[[i, j]] += noise;
638 }
639 }
640
641 Ok(())
642 }
643
644 pub const fn get_global_params(&self) -> &Array2<f64> {
646 &self.global_params
647 }
648
649 pub fn history(&self) -> &[f64] {
651 &self.history
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// The forward pass of a fresh client yields two probabilities summing to one.
    #[test]
    fn test_federated_client() {
        let client = QuantumFederatedClient::new(0, 2, 2, 100);

        // Two-qubit basis state |00>: amplitude 1 at index 0, 0 elsewhere.
        let mut amplitudes = vec![Complex64::new(0.0, 0.0); 4];
        amplitudes[0] = Complex64::new(1.0, 0.0);
        let state = Array1::from_vec(amplitudes);

        let probs = client
            .forward(&state)
            .expect("Failed to forward through client");
        assert_eq!(probs.len(), 2);

        let total = probs.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// The server creates exactly as many clients as the configuration asks for.
    #[test]
    fn test_federated_server() {
        let server = QuantumFederatedServer::new(QuantumFederatedConfig {
            num_qubits: 2,
            circuit_depth: 2,
            num_clients: 5,
            client_fraction: 0.6,
            local_epochs: 2,
            local_lr: 0.01,
            aggregation: AggregationStrategy::FedAvg,
            dp_epsilon: 0.0,
            dp_delta: 1e-5,
        });

        assert_eq!(server.clients.len(), 5);
    }
}