// quantrs2_core/qml/quantum_federated.rs

//! Quantum Federated Learning
//!
//! This module implements federated learning for quantum machine learning,
//! enabling privacy-preserving distributed training across multiple quantum
//! devices without sharing raw quantum data.
//!
//! # Theoretical Background
//!
//! Quantum Federated Learning extends classical federated learning to quantum
//! computing, allowing multiple parties to collaboratively train quantum models
//! while keeping their quantum data private. This is crucial for applications
//! in healthcare, finance, and defense, where quantum data privacy is paramount.
//!
//! # Key Features
//!
//! - **Distributed Quantum Training**: Train across multiple quantum computers
//! - **Privacy-Preserving Aggregation**: Clients send only circuit parameters,
//!   never raw quantum data, to the server for aggregation
//! - **Differential Privacy**: Noise injection for formal privacy guarantees
//! - **Byzantine-Robust Aggregation**: Defense against malicious participants
//!   (coordinate-wise median, trimmed mean, Krum)
//! - **Adaptive Communication**: Only a sampled fraction of clients exchanges
//!   parameters in each round, reducing communication
//!
//! # References
//!
//! - "Federated Learning with Quantum Computing" (2023)
//! - "Privacy-Preserving Quantum Machine Learning" (2024)
//! - "Distributed Quantum Neural Networks" (2024)
27
28use crate::{
29    error::{QuantRS2Error, QuantRS2Result},
30    gate::GateOp,
31    qubit::QubitId,
32};
33use scirs2_core::ndarray::{Array1, Array2, Axis};
34use scirs2_core::random::prelude::*;
35use scirs2_core::Complex64;
36use std::collections::HashMap;
37use std::f64::consts::PI;
38
39/// Configuration for quantum federated learning
40#[derive(Debug, Clone)]
41pub struct QuantumFederatedConfig {
42    /// Number of qubits in the quantum model
43    pub num_qubits: usize,
44    /// Circuit depth
45    pub circuit_depth: usize,
46    /// Number of clients
47    pub num_clients: usize,
48    /// Fraction of clients selected per round
49    pub client_fraction: f64,
50    /// Number of local training epochs
51    pub local_epochs: usize,
52    /// Local learning rate
53    pub local_lr: f64,
54    /// Aggregation strategy
55    pub aggregation: AggregationStrategy,
56    /// Differential privacy epsilon (0.0 = no DP)
57    pub dp_epsilon: f64,
58    /// Differential privacy delta
59    pub dp_delta: f64,
60}
61
62impl Default for QuantumFederatedConfig {
63    fn default() -> Self {
64        Self {
65            num_qubits: 4,
66            circuit_depth: 3,
67            num_clients: 10,
68            client_fraction: 0.3,
69            local_epochs: 5,
70            local_lr: 0.01,
71            aggregation: AggregationStrategy::FedAvg,
72            dp_epsilon: 1.0,
73            dp_delta: 1e-5,
74        }
75    }
76}
77
/// How the server merges the parameter matrices returned by clients.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregationStrategy {
    /// Plain federated averaging (FedAvg): unweighted mean of client parameters.
    FedAvg,
    /// Mean of client parameters weighted by each client's dataset size.
    WeightedAvg,
    /// Coordinate-wise median; robust to a minority of Byzantine clients.
    Median,
    /// Coordinate-wise mean after discarding the extremes; Byzantine-robust.
    TrimmedMean,
    /// Krum: adopt the single update closest to its neighbours; Byzantine-robust.
    Krum,
}
92
93/// Quantum federated client
94#[derive(Debug, Clone)]
95pub struct QuantumFederatedClient {
96    /// Client ID
97    id: usize,
98    /// Local quantum circuit parameters
99    params: Array2<f64>,
100    /// Number of qubits
101    num_qubits: usize,
102    /// Circuit depth
103    depth: usize,
104    /// Local dataset size
105    dataset_size: usize,
106}
107
108impl QuantumFederatedClient {
109    /// Create new federated client
110    pub fn new(id: usize, num_qubits: usize, depth: usize, dataset_size: usize) -> Self {
111        let mut rng = thread_rng();
112        let params = Array2::from_shape_fn((depth, num_qubits * 3), |_| rng.gen_range(-PI..PI));
113
114        Self {
115            id,
116            params,
117            num_qubits,
118            depth,
119            dataset_size,
120        }
121    }
122
123    /// Local training on client's quantum data
124    pub fn train_local(
125        &mut self,
126        data: &[Array1<Complex64>],
127        labels: &[usize],
128        epochs: usize,
129        lr: f64,
130    ) -> QuantRS2Result<f64> {
131        let mut total_loss = 0.0;
132
133        for _ in 0..epochs {
134            let loss = self.compute_loss(data, labels)?;
135            total_loss += loss;
136
137            // Compute gradients using parameter-shift rule
138            let gradients = self.compute_gradients(data, labels)?;
139
140            // Update parameters
141            self.params = &self.params - &(gradients * lr);
142        }
143
144        Ok(total_loss / epochs as f64)
145    }
146
147    /// Compute loss on local data
148    fn compute_loss(&self, data: &[Array1<Complex64>], labels: &[usize]) -> QuantRS2Result<f64> {
149        let mut total_loss = 0.0;
150
151        for (state, &label) in data.iter().zip(labels.iter()) {
152            let output = self.forward(state)?;
153
154            // Cross-entropy loss
155            total_loss -= output[label].ln();
156        }
157
158        Ok(total_loss / data.len() as f64)
159    }
160
161    /// Forward pass through quantum circuit
162    fn forward(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
163        let mut encoded = state.clone();
164
165        // Apply parameterized quantum circuit
166        for layer in 0..self.depth {
167            for q in 0..self.num_qubits {
168                let rx = self.params[[layer, q * 3]];
169                let ry = self.params[[layer, q * 3 + 1]];
170                let rz = self.params[[layer, q * 3 + 2]];
171
172                encoded = self.apply_rotation(&encoded, q, rx, ry, rz)?;
173            }
174
175            // Entangling layer
176            for q in 0..self.num_qubits - 1 {
177                encoded = self.apply_cnot(&encoded, q, q + 1)?;
178            }
179        }
180
181        // Measure Pauli-Z expectations
182        let mut expectations = Array1::zeros(2); // Binary classification
183        expectations[0] = self.pauli_z_expectation(&encoded, 0)?;
184        expectations[1] = 1.0 - expectations[0];
185
186        // Softmax
187        let max_exp = expectations
188            .iter()
189            .copied()
190            .fold(f64::NEG_INFINITY, f64::max);
191        let mut probs = Array1::zeros(2);
192        let mut sum = 0.0;
193
194        for i in 0..2 {
195            probs[i] = (expectations[i] - max_exp).exp();
196            sum += probs[i];
197        }
198
199        for i in 0..2 {
200            probs[i] /= sum;
201        }
202
203        Ok(probs)
204    }
205
206    /// Compute gradients using parameter-shift rule
207    fn compute_gradients(
208        &self,
209        data: &[Array1<Complex64>],
210        labels: &[usize],
211    ) -> QuantRS2Result<Array2<f64>> {
212        let epsilon = PI / 2.0; // Parameter-shift rule
213        let mut gradients = Array2::zeros(self.params.dim());
214
215        for i in 0..self.params.shape()[0] {
216            for j in 0..self.params.shape()[1] {
217                // Shift parameter forward
218                let mut client_plus = self.clone();
219                client_plus.params[[i, j]] += epsilon;
220                let loss_plus = client_plus.compute_loss(data, labels)?;
221
222                // Shift parameter backward
223                let mut client_minus = self.clone();
224                client_minus.params[[i, j]] -= epsilon;
225                let loss_minus = client_minus.compute_loss(data, labels)?;
226
227                // Parameter-shift gradient
228                gradients[[i, j]] = (loss_plus - loss_minus) / 2.0;
229            }
230        }
231
232        Ok(gradients)
233    }
234
235    /// Get model parameters
236    pub const fn get_params(&self) -> &Array2<f64> {
237        &self.params
238    }
239
240    /// Set model parameters
241    pub fn set_params(&mut self, params: Array2<f64>) {
242        self.params = params;
243    }
244
245    /// Get dataset size
246    pub const fn dataset_size(&self) -> usize {
247        self.dataset_size
248    }
249
250    // Helper methods
251    fn apply_rotation(
252        &self,
253        state: &Array1<Complex64>,
254        qubit: usize,
255        rx: f64,
256        ry: f64,
257        rz: f64,
258    ) -> QuantRS2Result<Array1<Complex64>> {
259        let mut result = state.clone();
260        result = self.apply_rz_gate(&result, qubit, rz)?;
261        result = self.apply_ry_gate(&result, qubit, ry)?;
262        result = self.apply_rx_gate(&result, qubit, rx)?;
263        Ok(result)
264    }
265
266    fn apply_rx_gate(
267        &self,
268        state: &Array1<Complex64>,
269        qubit: usize,
270        angle: f64,
271    ) -> QuantRS2Result<Array1<Complex64>> {
272        let dim = state.len();
273        let mut new_state = Array1::zeros(dim);
274        let cos_half = Complex64::new((angle / 2.0).cos(), 0.0);
275        let sin_half = Complex64::new(0.0, -(angle / 2.0).sin());
276
277        for i in 0..dim {
278            let j = i ^ (1 << qubit);
279            new_state[i] = state[i] * cos_half + state[j] * sin_half;
280        }
281
282        Ok(new_state)
283    }
284
285    fn apply_ry_gate(
286        &self,
287        state: &Array1<Complex64>,
288        qubit: usize,
289        angle: f64,
290    ) -> QuantRS2Result<Array1<Complex64>> {
291        let dim = state.len();
292        let mut new_state = Array1::zeros(dim);
293        let cos_half = (angle / 2.0).cos();
294        let sin_half = (angle / 2.0).sin();
295
296        for i in 0..dim {
297            let bit = (i >> qubit) & 1;
298            let j = i ^ (1 << qubit);
299            if bit == 0 {
300                new_state[i] = state[i] * cos_half - state[j] * sin_half;
301            } else {
302                new_state[i] = state[i] * cos_half + state[j] * sin_half;
303            }
304        }
305
306        Ok(new_state)
307    }
308
309    fn apply_rz_gate(
310        &self,
311        state: &Array1<Complex64>,
312        qubit: usize,
313        angle: f64,
314    ) -> QuantRS2Result<Array1<Complex64>> {
315        let dim = state.len();
316        let mut new_state = state.clone();
317        let phase = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
318
319        for i in 0..dim {
320            let bit = (i >> qubit) & 1;
321            new_state[i] = if bit == 1 {
322                new_state[i] * phase
323            } else {
324                new_state[i] * phase.conj()
325            };
326        }
327
328        Ok(new_state)
329    }
330
331    fn apply_cnot(
332        &self,
333        state: &Array1<Complex64>,
334        control: usize,
335        target: usize,
336    ) -> QuantRS2Result<Array1<Complex64>> {
337        let dim = state.len();
338        let mut new_state = state.clone();
339
340        for i in 0..dim {
341            let control_bit = (i >> control) & 1;
342            if control_bit == 1 {
343                let j = i ^ (1 << target);
344                if i < j {
345                    let temp = new_state[i];
346                    new_state[i] = new_state[j];
347                    new_state[j] = temp;
348                }
349            }
350        }
351
352        Ok(new_state)
353    }
354
355    fn pauli_z_expectation(&self, state: &Array1<Complex64>, qubit: usize) -> QuantRS2Result<f64> {
356        let dim = state.len();
357        let mut expectation = 0.0;
358
359        for i in 0..dim {
360            let bit = (i >> qubit) & 1;
361            let sign = if bit == 0 { 1.0 } else { -1.0 };
362            expectation += sign * state[i].norm_sqr();
363        }
364
365        // Map from [-1, 1] to [0, 1]
366        Ok(f64::midpoint(expectation, 1.0))
367    }
368}
369
370/// Quantum federated learning server
371#[derive(Debug)]
372pub struct QuantumFederatedServer {
373    /// Configuration
374    config: QuantumFederatedConfig,
375    /// Global model parameters
376    global_params: Array2<f64>,
377    /// Clients
378    clients: Vec<QuantumFederatedClient>,
379    /// Training history
380    history: Vec<f64>,
381}
382
383impl QuantumFederatedServer {
384    /// Create new federated server
385    pub fn new(config: QuantumFederatedConfig) -> Self {
386        let mut rng = thread_rng();
387
388        // Initialize global model
389        let global_params =
390            Array2::from_shape_fn((config.circuit_depth, config.num_qubits * 3), |_| {
391                rng.gen_range(-PI..PI)
392            });
393
394        // Create clients
395        let mut clients = Vec::with_capacity(config.num_clients);
396        for i in 0..config.num_clients {
397            let dataset_size = rng.gen_range(50..200);
398            clients.push(QuantumFederatedClient::new(
399                i,
400                config.num_qubits,
401                config.circuit_depth,
402                dataset_size,
403            ));
404        }
405
406        Self {
407            config,
408            global_params,
409            clients,
410            history: Vec::new(),
411        }
412    }
413
414    /// Run one federated learning round
415    pub fn train_round(
416        &mut self,
417        client_data: &HashMap<usize, (Vec<Array1<Complex64>>, Vec<usize>)>,
418    ) -> QuantRS2Result<f64> {
419        // Select clients for this round
420        let num_selected =
421            (self.config.num_clients as f64 * self.config.client_fraction).ceil() as usize;
422        let selected_clients = self.select_clients(num_selected);
423
424        // Distribute global model to selected clients
425        for &client_id in &selected_clients {
426            self.clients[client_id].set_params(self.global_params.clone());
427        }
428
429        // Local training on each client
430        let mut client_updates = Vec::new();
431        let mut client_weights = Vec::new();
432        let mut avg_loss = 0.0;
433
434        for &client_id in &selected_clients {
435            if let Some((data, labels)) = client_data.get(&client_id) {
436                let loss = self.clients[client_id].train_local(
437                    data,
438                    labels,
439                    self.config.local_epochs,
440                    self.config.local_lr,
441                )?;
442
443                avg_loss += loss;
444
445                client_updates.push(self.clients[client_id].get_params().clone());
446                client_weights.push(self.clients[client_id].dataset_size() as f64);
447            }
448        }
449
450        avg_loss /= selected_clients.len() as f64;
451        self.history.push(avg_loss);
452
453        // Aggregate client updates
454        self.aggregate_updates(&client_updates, &client_weights)?;
455
456        Ok(avg_loss)
457    }
458
459    /// Select clients for training round
460    fn select_clients(&self, num_selected: usize) -> Vec<usize> {
461        let mut rng = thread_rng();
462        let mut clients: Vec<usize> = (0..self.config.num_clients).collect();
463
464        // Shuffle and select
465        for i in (1..clients.len()).rev() {
466            let j = rng.gen_range(0..=i);
467            clients.swap(i, j);
468        }
469
470        clients.truncate(num_selected);
471        clients
472    }
473
474    /// Aggregate client updates
475    fn aggregate_updates(
476        &mut self,
477        updates: &[Array2<f64>],
478        weights: &[f64],
479    ) -> QuantRS2Result<()> {
480        match self.config.aggregation {
481            AggregationStrategy::FedAvg => {
482                self.federated_averaging(updates)?;
483            }
484            AggregationStrategy::WeightedAvg => {
485                self.weighted_averaging(updates, weights)?;
486            }
487            AggregationStrategy::Median => {
488                self.median_aggregation(updates)?;
489            }
490            AggregationStrategy::TrimmedMean => {
491                self.trimmed_mean_aggregation(updates, 0.1)?;
492            }
493            AggregationStrategy::Krum => {
494                self.krum_aggregation(updates)?;
495            }
496        }
497
498        // Apply differential privacy if enabled
499        if self.config.dp_epsilon > 0.0 {
500            self.apply_differential_privacy()?;
501        }
502
503        Ok(())
504    }
505
506    /// Federated averaging (FedAvg)
507    fn federated_averaging(&mut self, updates: &[Array2<f64>]) -> QuantRS2Result<()> {
508        let mut avg_params = Array2::zeros(self.global_params.dim());
509
510        for update in updates {
511            avg_params = avg_params + update;
512        }
513
514        avg_params = avg_params / (updates.len() as f64);
515        self.global_params = avg_params;
516
517        Ok(())
518    }
519
520    /// Weighted averaging by dataset size
521    fn weighted_averaging(
522        &mut self,
523        updates: &[Array2<f64>],
524        weights: &[f64],
525    ) -> QuantRS2Result<()> {
526        let total_weight: f64 = weights.iter().sum();
527        let mut weighted_params = Array2::zeros(self.global_params.dim());
528
529        for (update, &weight) in updates.iter().zip(weights.iter()) {
530            weighted_params = weighted_params + update * (weight / total_weight);
531        }
532
533        self.global_params = weighted_params;
534        Ok(())
535    }
536
537    /// Median aggregation (coordinate-wise median)
538    fn median_aggregation(&mut self, updates: &[Array2<f64>]) -> QuantRS2Result<()> {
539        let shape = self.global_params.dim();
540        let mut median_params = Array2::zeros(shape);
541
542        for i in 0..shape.0 {
543            for j in 0..shape.1 {
544                let mut values: Vec<f64> = updates.iter().map(|u| u[[i, j]]).collect();
545                values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
546
547                median_params[[i, j]] = if values.len() % 2 == 0 {
548                    f64::midpoint(values[values.len() / 2 - 1], values[values.len() / 2])
549                } else {
550                    values[values.len() / 2]
551                };
552            }
553        }
554
555        self.global_params = median_params;
556        Ok(())
557    }
558
559    /// Trimmed mean aggregation
560    fn trimmed_mean_aggregation(
561        &mut self,
562        updates: &[Array2<f64>],
563        trim_ratio: f64,
564    ) -> QuantRS2Result<()> {
565        let shape = self.global_params.dim();
566        let mut trimmed_params = Array2::zeros(shape);
567        let trim_count = (updates.len() as f64 * trim_ratio).floor() as usize;
568
569        for i in 0..shape.0 {
570            for j in 0..shape.1 {
571                let mut values: Vec<f64> = updates.iter().map(|u| u[[i, j]]).collect();
572                values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
573
574                // Trim extremes
575                let trimmed: Vec<f64> = values[trim_count..values.len() - trim_count].to_vec();
576                trimmed_params[[i, j]] = trimmed.iter().sum::<f64>() / trimmed.len() as f64;
577            }
578        }
579
580        self.global_params = trimmed_params;
581        Ok(())
582    }
583
584    /// Krum aggregation (Byzantine-robust)
585    fn krum_aggregation(&mut self, updates: &[Array2<f64>]) -> QuantRS2Result<()> {
586        let n = updates.len();
587        let f = (n - 1) / 2; // Maximum Byzantine clients
588        let n_minus_f_minus_2 = n - f - 2;
589
590        // Compute pairwise distances
591        let mut scores = vec![0.0; n];
592
593        for i in 0..n {
594            let mut distances: Vec<(usize, f64)> = Vec::new();
595
596            for j in 0..n {
597                if i != j {
598                    let diff = &updates[i] - &updates[j];
599                    let dist: f64 = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
600                    distances.push((j, dist));
601                }
602            }
603
604            // Sort by distance and sum closest n-f-2
605            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
606            scores[i] = distances
607                .iter()
608                .take(n_minus_f_minus_2)
609                .map(|(_, d)| d)
610                .sum();
611        }
612
613        // Select client with minimum score
614        let best_client = scores
615            .iter()
616            .enumerate()
617            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
618            .map(|(idx, _)| idx)
619            .unwrap_or(0);
620
621        self.global_params.clone_from(&updates[best_client]);
622        Ok(())
623    }
624
625    /// Apply differential privacy to global model
626    fn apply_differential_privacy(&mut self) -> QuantRS2Result<()> {
627        let mut rng = thread_rng();
628
629        // Compute noise scale based on DP parameters
630        let sensitivity = 1.0; // L2 sensitivity
631        let noise_scale = sensitivity / self.config.dp_epsilon;
632
633        // Add Gaussian noise to parameters
634        for i in 0..self.global_params.shape()[0] {
635            for j in 0..self.global_params.shape()[1] {
636                let noise = rng.gen_range(-1.0..1.0) * noise_scale;
637                self.global_params[[i, j]] += noise;
638            }
639        }
640
641        Ok(())
642    }
643
644    /// Get global model parameters
645    pub const fn get_global_params(&self) -> &Array2<f64> {
646        &self.global_params
647    }
648
649    /// Get training history
650    pub fn history(&self) -> &[f64] {
651        &self.history
652    }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration with DP disabled, for fast deterministic-shape tests.
    fn small_config(
        num_qubits: usize,
        num_clients: usize,
        client_fraction: f64,
        aggregation: AggregationStrategy,
    ) -> QuantumFederatedConfig {
        QuantumFederatedConfig {
            num_qubits,
            circuit_depth: 1,
            num_clients,
            client_fraction,
            local_epochs: 1,
            local_lr: 0.01,
            aggregation,
            dp_epsilon: 0.0,
            dp_delta: 1e-5,
        }
    }

    /// Two-qubit computational basis state with amplitude 1 at `index`.
    fn basis_state(index: usize) -> Array1<Complex64> {
        Array1::from_shape_fn(4, |i| {
            if i == index {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        })
    }

    fn assert_all_close(params: &Array2<f64>, expected: f64) {
        assert!(
            params.iter().all(|&v| (v - expected).abs() < 1e-12),
            "expected every entry to be {expected}, got {params:?}"
        );
    }

    #[test]
    fn test_federated_client() {
        let client = QuantumFederatedClient::new(0, 2, 2, 100);

        let probs = client
            .forward(&basis_state(0))
            .expect("Failed to forward through client");
        assert_eq!(probs.len(), 2);

        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_federated_server() {
        let config = QuantumFederatedConfig {
            num_qubits: 2,
            circuit_depth: 2,
            num_clients: 5,
            client_fraction: 0.6,
            local_epochs: 2,
            local_lr: 0.01,
            aggregation: AggregationStrategy::FedAvg,
            dp_epsilon: 0.0,
            dp_delta: 1e-5,
        };

        let server = QuantumFederatedServer::new(config);
        assert_eq!(server.clients.len(), 5);
    }

    #[test]
    fn test_select_clients_returns_distinct_ids() {
        let server = QuantumFederatedServer::new(small_config(1, 6, 0.5, AggregationStrategy::FedAvg));
        let selected = server.select_clients(4);
        assert_eq!(selected.len(), 4);
        assert!(selected.iter().all(|&id| id < 6));

        let mut unique = selected.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), selected.len());
    }

    #[test]
    fn test_median_aggregation_ignores_outlier() {
        let mut server =
            QuantumFederatedServer::new(small_config(1, 3, 1.0, AggregationStrategy::Median));
        let shape = server.get_global_params().dim();
        let updates = vec![
            Array2::from_elem(shape, 1.0),
            Array2::from_elem(shape, 2.0),
            Array2::from_elem(shape, 10.0),
        ];

        server
            .median_aggregation(&updates)
            .expect("median aggregation failed");
        assert_all_close(server.get_global_params(), 2.0);
    }

    #[test]
    fn test_trimmed_mean_drops_extremes() {
        let mut server =
            QuantumFederatedServer::new(small_config(1, 10, 1.0, AggregationStrategy::TrimmedMean));
        let shape = server.get_global_params().dim();
        let mut updates: Vec<Array2<f64>> =
            (0..9).map(|v| Array2::from_elem(shape, f64::from(v))).collect();
        updates.push(Array2::from_elem(shape, 1000.0));

        server
            .trimmed_mean_aggregation(&updates, 0.1)
            .expect("trimmed-mean aggregation failed");
        // Values 0..=8 plus an outlier; dropping one from each end leaves 1..=8.
        assert_all_close(server.get_global_params(), 4.5);
    }

    #[test]
    fn test_weighted_averaging_uses_weights() {
        let mut server =
            QuantumFederatedServer::new(small_config(1, 2, 1.0, AggregationStrategy::WeightedAvg));
        let shape = server.get_global_params().dim();
        let updates = vec![Array2::from_elem(shape, 0.0), Array2::from_elem(shape, 4.0)];

        server
            .weighted_averaging(&updates, &[1.0, 3.0])
            .expect("weighted averaging failed");
        assert_all_close(server.get_global_params(), 3.0);
    }

    #[test]
    fn test_train_round_records_history() {
        let mut server =
            QuantumFederatedServer::new(small_config(2, 2, 1.0, AggregationStrategy::FedAvg));

        let mut client_data = HashMap::new();
        for id in 0..2 {
            client_data.insert(id, (vec![basis_state(0), basis_state(1)], vec![0, 1]));
        }

        let loss = server
            .train_round(&client_data)
            .expect("training round failed");
        assert!(loss.is_finite());
        assert_eq!(server.history().len(), 1);
        assert!(server.get_global_params().iter().all(|v| v.is_finite()));
    }
}