quantrs2_core/qml/
quantum_meta_learning.rs

1//! Quantum Meta-Learning
2//!
3//! This module implements meta-learning algorithms for quantum machine learning,
4//! enabling rapid adaptation to new tasks with minimal training data.
5//!
6//! # Theoretical Background
7//!
8//! Quantum meta-learning extends classical meta-learning (MAML, Reptile) to
9//! quantum neural networks. The goal is to learn an initialization of quantum
10//! circuit parameters that can quickly adapt to new tasks through fine-tuning.
11//!
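//! # Key Algorithms
//!
//! - **Quantum MAML**: Model-Agnostic Meta-Learning for quantum circuits,
//!   provided by `QuantumMAML` using a first-order meta-gradient
//! - **Quantum Reptile**: First-order approximation of MAML, provided by
//!   `QuantumReptile`
//! - **Quantum ProtoNets**: Prototype networks using quantum metric learning
//!   (no implementation appears in this file)
//! - **Quantum Matching Networks**: Attention-based few-shot learning
//!   (no implementation appears in this file)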
18//!
19//! # Applications
20//!
21//! - Few-shot quantum classification
22//! - Fast quantum state tomography
23//! - Adaptive quantum control
24//! - Quantum drug discovery with limited data
25//!
26//! # References
27//!
28//! - "Meta-Learning for Quantum Neural Networks" (2023)
29//! - "Few-Shot Learning with Quantum Classifiers" (2024)
30//! - "Quantum Model-Agnostic Meta-Learning" (2024)
31
32use crate::{
33    error::{QuantRS2Error, QuantRS2Result},
34    gate::GateOp,
35    qubit::QubitId,
36};
37use scirs2_core::ndarray::{Array1, Array2, Axis};
38use scirs2_core::random::prelude::*;
39use scirs2_core::Complex64;
40use std::f64::consts::PI;
41
/// Hyper-parameters for the quantum meta-learning algorithms in this module.
///
/// The meta-learners read `num_qubits`, `circuit_depth` and `n_way` when they
/// build their circuit, and `inner_lr`, `outer_lr` and `inner_steps` while
/// training. `n_support`, `n_query` and `meta_batch_size` are not read by the
/// learners defined in this file; they describe how callers are expected to
/// size the tasks and task batches they pass in.
///
/// Use struct-update syntax to override only a few fields:
/// `QuantumMetaLearningConfig { n_way: 5, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantumMetaLearningConfig {
    /// Number of qubits of the variational circuit
    pub num_qubits: usize,
    /// Number of layers of the variational circuit
    pub circuit_depth: usize,
    /// Learning rate of the inner (task adaptation) loop
    pub inner_lr: f64,
    /// Learning rate of the outer (meta) update
    pub outer_lr: f64,
    /// Number of gradient steps taken in the inner loop
    pub inner_steps: usize,
    /// Number of support examples per class
    pub n_support: usize,
    /// Number of query examples per class
    pub n_query: usize,
    /// Number of classes per task; also the number of circuit outputs
    pub n_way: usize,
    /// Meta-batch size (number of tasks per meta-update)
    pub meta_batch_size: usize,
}

impl Default for QuantumMetaLearningConfig {
    /// Small 2-way, 5-shot setup on a 4-qubit, depth-4 circuit.
    fn default() -> Self {
        Self {
            num_qubits: 4,
            circuit_depth: 4,
            inner_lr: 0.01,
            outer_lr: 0.001,
            inner_steps: 5,
            n_support: 5,
            n_query: 15,
            n_way: 2,
            meta_batch_size: 4,
        }
    }
}
80
81/// Quantum task for meta-learning
82#[derive(Debug, Clone)]
83pub struct QuantumTask {
84    /// Support set states
85    pub support_states: Vec<Array1<Complex64>>,
86    /// Support set labels
87    pub support_labels: Vec<usize>,
88    /// Query set states
89    pub query_states: Vec<Array1<Complex64>>,
90    /// Query set labels
91    pub query_labels: Vec<usize>,
92}
93
94impl QuantumTask {
95    /// Create new quantum task
96    pub fn new(
97        support_states: Vec<Array1<Complex64>>,
98        support_labels: Vec<usize>,
99        query_states: Vec<Array1<Complex64>>,
100        query_labels: Vec<usize>,
101    ) -> Self {
102        Self {
103            support_states,
104            support_labels,
105            query_states,
106            query_labels,
107        }
108    }
109
110    /// Generate random task for testing
111    pub fn random(num_qubits: usize, n_way: usize, n_support: usize, n_query: usize) -> Self {
112        let mut rng = thread_rng();
113        let dim = 1 << num_qubits;
114
115        let mut support_states = Vec::new();
116        let mut support_labels = Vec::new();
117        let mut query_states = Vec::new();
118        let mut query_labels = Vec::new();
119
120        for class in 0..n_way {
121            // Generate class prototype
122            let mut prototype = Array1::from_shape_fn(dim, |_| {
123                Complex64::new(rng.gen_range(-1.0..1.0), rng.gen_range(-1.0..1.0))
124            });
125            let norm: f64 = prototype.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
126            for i in 0..dim {
127                prototype[i] = prototype[i] / norm;
128            }
129
130            // Generate support examples
131            for _ in 0..n_support {
132                let mut state = prototype.clone();
133                // Add small noise
134                for i in 0..dim {
135                    state[i] = state[i]
136                        + Complex64::new(rng.gen_range(-0.1..0.1), rng.gen_range(-0.1..0.1));
137                }
138                let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
139                for i in 0..dim {
140                    state[i] = state[i] / norm;
141                }
142                support_states.push(state);
143                support_labels.push(class);
144            }
145
146            // Generate query examples
147            for _ in 0..n_query {
148                let mut state = prototype.clone();
149                // Add small noise
150                for i in 0..dim {
151                    state[i] = state[i]
152                        + Complex64::new(rng.gen_range(-0.1..0.1), rng.gen_range(-0.1..0.1));
153                }
154                let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
155                for i in 0..dim {
156                    state[i] = state[i] / norm;
157                }
158                query_states.push(state);
159                query_labels.push(class);
160            }
161        }
162
163        Self {
164            support_states,
165            support_labels,
166            query_states,
167            query_labels,
168        }
169    }
170}
171
172/// Quantum circuit for meta-learning
173#[derive(Debug, Clone)]
174pub struct QuantumMetaCircuit {
175    /// Number of qubits
176    num_qubits: usize,
177    /// Circuit depth
178    depth: usize,
179    /// Number of output classes
180    num_classes: usize,
181    /// Circuit parameters
182    params: Array2<f64>,
183    /// Readout weights
184    readout_weights: Array2<f64>,
185}
186
187impl QuantumMetaCircuit {
188    /// Create new quantum meta circuit
189    pub fn new(num_qubits: usize, depth: usize, num_classes: usize) -> Self {
190        let mut rng = thread_rng();
191
192        let params = Array2::from_shape_fn((depth, num_qubits * 3), |_| rng.gen_range(-PI..PI));
193
194        let scale = (2.0 / num_qubits as f64).sqrt();
195        let readout_weights =
196            Array2::from_shape_fn((num_classes, num_qubits), |_| rng.gen_range(-scale..scale));
197
198        Self {
199            num_qubits,
200            depth,
201            num_classes,
202            params,
203            readout_weights,
204        }
205    }
206
207    /// Forward pass
208    pub fn forward(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
209        // Apply quantum circuit
210        let mut encoded = state.clone();
211
212        for layer in 0..self.depth {
213            // Rotation gates
214            for q in 0..self.num_qubits {
215                let rx = self.params[[layer, q * 3]];
216                let ry = self.params[[layer, q * 3 + 1]];
217                let rz = self.params[[layer, q * 3 + 2]];
218
219                encoded = self.apply_rotation(&encoded, q, rx, ry, rz)?;
220            }
221
222            // Entangling gates
223            for q in 0..self.num_qubits - 1 {
224                encoded = self.apply_cnot(&encoded, q, q + 1)?;
225            }
226        }
227
228        // Measure expectations and classify
229        let mut expectations = Array1::zeros(self.num_qubits);
230        for q in 0..self.num_qubits {
231            expectations[q] = self.pauli_z_expectation(&encoded, q)?;
232        }
233
234        // Linear readout
235        let mut logits = Array1::zeros(self.num_classes);
236        for i in 0..self.num_classes {
237            for j in 0..self.num_qubits {
238                logits[i] += self.readout_weights[[i, j]] * expectations[j];
239            }
240        }
241
242        // Softmax
243        let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
244        let mut probs = Array1::zeros(self.num_classes);
245        let mut sum_exp = 0.0;
246
247        for i in 0..self.num_classes {
248            probs[i] = (logits[i] - max_logit).exp();
249            sum_exp += probs[i];
250        }
251
252        for i in 0..self.num_classes {
253            probs[i] /= sum_exp;
254        }
255
256        Ok(probs)
257    }
258
259    /// Compute loss
260    pub fn compute_loss(
261        &self,
262        states: &[Array1<Complex64>],
263        labels: &[usize],
264    ) -> QuantRS2Result<f64> {
265        let mut total_loss = 0.0;
266
267        for (state, &label) in states.iter().zip(labels.iter()) {
268            let probs = self.forward(state)?;
269            // Cross-entropy loss
270            total_loss -= probs[label].ln();
271        }
272
273        Ok(total_loss / states.len() as f64)
274    }
275
276    /// Compute gradients (simplified with finite differences)
277    pub fn compute_gradients(
278        &self,
279        states: &[Array1<Complex64>],
280        labels: &[usize],
281    ) -> QuantRS2Result<(Array2<f64>, Array2<f64>)> {
282        let epsilon = 1e-4;
283
284        // Gradients for circuit parameters
285        let mut param_grads = Array2::zeros(self.params.dim());
286
287        for i in 0..self.params.shape()[0] {
288            for j in 0..self.params.shape()[1] {
289                let mut circuit_plus = self.clone();
290                circuit_plus.params[[i, j]] += epsilon;
291                let loss_plus = circuit_plus.compute_loss(states, labels)?;
292
293                let mut circuit_minus = self.clone();
294                circuit_minus.params[[i, j]] -= epsilon;
295                let loss_minus = circuit_minus.compute_loss(states, labels)?;
296
297                param_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
298            }
299        }
300
301        // Gradients for readout weights
302        let mut readout_grads = Array2::zeros(self.readout_weights.dim());
303
304        for i in 0..self.readout_weights.shape()[0] {
305            for j in 0..self.readout_weights.shape()[1] {
306                let mut circuit_plus = self.clone();
307                circuit_plus.readout_weights[[i, j]] += epsilon;
308                let loss_plus = circuit_plus.compute_loss(states, labels)?;
309
310                let mut circuit_minus = self.clone();
311                circuit_minus.readout_weights[[i, j]] -= epsilon;
312                let loss_minus = circuit_minus.compute_loss(states, labels)?;
313
314                readout_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
315            }
316        }
317
318        Ok((param_grads, readout_grads))
319    }
320
321    /// Update parameters
322    pub fn update_params(
323        &mut self,
324        param_grads: &Array2<f64>,
325        readout_grads: &Array2<f64>,
326        lr: f64,
327    ) {
328        self.params = &self.params - &(param_grads * lr);
329        self.readout_weights = &self.readout_weights - &(readout_grads * lr);
330    }
331
332    /// Helper methods
333    fn apply_rotation(
334        &self,
335        state: &Array1<Complex64>,
336        qubit: usize,
337        rx: f64,
338        ry: f64,
339        rz: f64,
340    ) -> QuantRS2Result<Array1<Complex64>> {
341        let mut result = state.clone();
342        result = self.apply_rz_gate(&result, qubit, rz)?;
343        result = self.apply_ry_gate(&result, qubit, ry)?;
344        result = self.apply_rx_gate(&result, qubit, rx)?;
345        Ok(result)
346    }
347
348    fn apply_rx_gate(
349        &self,
350        state: &Array1<Complex64>,
351        qubit: usize,
352        angle: f64,
353    ) -> QuantRS2Result<Array1<Complex64>> {
354        let dim = state.len();
355        let mut new_state = Array1::zeros(dim);
356        let cos_half = Complex64::new((angle / 2.0).cos(), 0.0);
357        let sin_half = Complex64::new(0.0, -(angle / 2.0).sin());
358
359        for i in 0..dim {
360            let j = i ^ (1 << qubit);
361            new_state[i] = state[i] * cos_half + state[j] * sin_half;
362        }
363
364        Ok(new_state)
365    }
366
367    fn apply_ry_gate(
368        &self,
369        state: &Array1<Complex64>,
370        qubit: usize,
371        angle: f64,
372    ) -> QuantRS2Result<Array1<Complex64>> {
373        let dim = state.len();
374        let mut new_state = Array1::zeros(dim);
375        let cos_half = (angle / 2.0).cos();
376        let sin_half = (angle / 2.0).sin();
377
378        for i in 0..dim {
379            let bit = (i >> qubit) & 1;
380            let j = i ^ (1 << qubit);
381            if bit == 0 {
382                new_state[i] = state[i] * cos_half - state[j] * sin_half;
383            } else {
384                new_state[i] = state[i] * cos_half + state[j] * sin_half;
385            }
386        }
387
388        Ok(new_state)
389    }
390
391    fn apply_rz_gate(
392        &self,
393        state: &Array1<Complex64>,
394        qubit: usize,
395        angle: f64,
396    ) -> QuantRS2Result<Array1<Complex64>> {
397        let dim = state.len();
398        let mut new_state = state.clone();
399        let phase = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
400
401        for i in 0..dim {
402            let bit = (i >> qubit) & 1;
403            new_state[i] = if bit == 1 {
404                new_state[i] * phase
405            } else {
406                new_state[i] * phase.conj()
407            };
408        }
409
410        Ok(new_state)
411    }
412
413    fn apply_cnot(
414        &self,
415        state: &Array1<Complex64>,
416        control: usize,
417        target: usize,
418    ) -> QuantRS2Result<Array1<Complex64>> {
419        let dim = state.len();
420        let mut new_state = state.clone();
421
422        for i in 0..dim {
423            let control_bit = (i >> control) & 1;
424            if control_bit == 1 {
425                let j = i ^ (1 << target);
426                if i < j {
427                    let temp = new_state[i];
428                    new_state[i] = new_state[j];
429                    new_state[j] = temp;
430                }
431            }
432        }
433
434        Ok(new_state)
435    }
436
437    fn pauli_z_expectation(&self, state: &Array1<Complex64>, qubit: usize) -> QuantRS2Result<f64> {
438        let dim = state.len();
439        let mut expectation = 0.0;
440
441        for i in 0..dim {
442            let bit = (i >> qubit) & 1;
443            let sign = if bit == 0 { 1.0 } else { -1.0 };
444            expectation += sign * state[i].norm_sqr();
445        }
446
447        Ok(expectation)
448    }
449}
450
451/// Quantum MAML (Model-Agnostic Meta-Learning)
452#[derive(Debug, Clone)]
453pub struct QuantumMAML {
454    /// Configuration
455    config: QuantumMetaLearningConfig,
456    /// Meta-model (initialization point)
457    meta_model: QuantumMetaCircuit,
458}
459
460impl QuantumMAML {
461    /// Create new Quantum MAML
462    pub fn new(config: QuantumMetaLearningConfig) -> Self {
463        let meta_model =
464            QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
465
466        Self { config, meta_model }
467    }
468
469    /// Meta-training step
470    pub fn meta_train_step(&mut self, tasks: &[QuantumTask]) -> QuantRS2Result<f64> {
471        let mut meta_param_grads = Array2::zeros(self.meta_model.params.dim());
472        let mut meta_readout_grads = Array2::zeros(self.meta_model.readout_weights.dim());
473        let mut total_loss = 0.0;
474
475        for task in tasks {
476            // Inner loop: adapt to task
477            let mut adapted_model = self.meta_model.clone();
478
479            for _ in 0..self.config.inner_steps {
480                let (param_grads, readout_grads) =
481                    adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
482
483                adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
484            }
485
486            // Compute loss on query set
487            let query_loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
488            total_loss += query_loss;
489
490            // Compute meta-gradients
491            let (param_grads, readout_grads) =
492                adapted_model.compute_gradients(&task.query_states, &task.query_labels)?;
493
494            meta_param_grads = meta_param_grads + param_grads;
495            meta_readout_grads = meta_readout_grads + readout_grads;
496        }
497
498        // Average gradients
499        meta_param_grads = meta_param_grads / (tasks.len() as f64);
500        meta_readout_grads = meta_readout_grads / (tasks.len() as f64);
501
502        // Update meta-model
503        self.meta_model
504            .update_params(&meta_param_grads, &meta_readout_grads, self.config.outer_lr);
505
506        Ok(total_loss / tasks.len() as f64)
507    }
508
509    /// Adapt to new task
510    pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
511        let mut adapted_model = self.meta_model.clone();
512
513        for _ in 0..self.config.inner_steps {
514            let (param_grads, readout_grads) =
515                adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
516
517            adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
518        }
519
520        Ok(adapted_model)
521    }
522
523    /// Evaluate on new task
524    pub fn evaluate(&self, task: &QuantumTask) -> QuantRS2Result<f64> {
525        let adapted_model = self.adapt(task)?;
526
527        let mut correct = 0;
528        for (state, &label) in task.query_states.iter().zip(task.query_labels.iter()) {
529            let probs = adapted_model.forward(state)?;
530            let mut max_prob = f64::NEG_INFINITY;
531            let mut predicted = 0;
532
533            for (i, &prob) in probs.iter().enumerate() {
534                if prob > max_prob {
535                    max_prob = prob;
536                    predicted = i;
537                }
538            }
539
540            if predicted == label {
541                correct += 1;
542            }
543        }
544
545        Ok(correct as f64 / task.query_states.len() as f64)
546    }
547
548    /// Get meta-model
549    pub fn meta_model(&self) -> &QuantumMetaCircuit {
550        &self.meta_model
551    }
552}
553
554/// Quantum Reptile (simpler first-order MAML)
555#[derive(Debug, Clone)]
556pub struct QuantumReptile {
557    /// Configuration
558    config: QuantumMetaLearningConfig,
559    /// Meta-model
560    meta_model: QuantumMetaCircuit,
561}
562
563impl QuantumReptile {
564    /// Create new Quantum Reptile
565    pub fn new(config: QuantumMetaLearningConfig) -> Self {
566        let meta_model =
567            QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
568
569        Self { config, meta_model }
570    }
571
572    /// Meta-training step
573    pub fn meta_train_step(&mut self, task: &QuantumTask) -> QuantRS2Result<f64> {
574        // Adapt to task
575        let mut adapted_model = self.meta_model.clone();
576
577        for _ in 0..self.config.inner_steps {
578            let (param_grads, readout_grads) =
579                adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
580
581            adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
582        }
583
584        // Compute loss
585        let loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
586
587        // Update meta-model towards adapted model
588        let param_diff = &adapted_model.params - &self.meta_model.params;
589        let readout_diff = &adapted_model.readout_weights - &self.meta_model.readout_weights;
590
591        self.meta_model.params = &self.meta_model.params + &(param_diff * self.config.outer_lr);
592        self.meta_model.readout_weights =
593            &self.meta_model.readout_weights + &(readout_diff * self.config.outer_lr);
594
595        Ok(loss)
596    }
597
598    /// Adapt to new task (same as MAML)
599    pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
600        let mut adapted_model = self.meta_model.clone();
601
602        for _ in 0..self.config.inner_steps {
603            let (param_grads, readout_grads) =
604                adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
605
606            adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
607        }
608
609        Ok(adapted_model)
610    }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// The forward pass yields one probability per class, summing to one.
    #[test]
    fn test_quantum_meta_circuit() {
        let circuit = QuantumMetaCircuit::new(3, 2, 2);

        // |000>: unit amplitude on basis state 0, zero on the other seven.
        let mut state = Array1::from_elem(8, Complex64::new(0.0, 0.0));
        state[0] = Complex64::new(1.0, 0.0);

        let probs = circuit.forward(&state).unwrap();
        assert_eq!(probs.len(), 2);

        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// Adapting to a random task yields a model with `n_way` outputs.
    #[test]
    fn test_quantum_maml() {
        let config = QuantumMetaLearningConfig {
            num_qubits: 2,
            circuit_depth: 2,
            inner_lr: 0.01,
            outer_lr: 0.001,
            inner_steps: 3,
            n_support: 2,
            n_query: 5,
            n_way: 2,
            meta_batch_size: 2,
        };

        let task = QuantumTask::random(
            config.num_qubits,
            config.n_way,
            config.n_support,
            config.n_query,
        );
        let learner = QuantumMAML::new(config.clone());

        let adapted = learner.adapt(&task).unwrap();
        let first_query = &task.query_states[0];
        let probs = adapted.forward(first_query).unwrap();

        assert_eq!(probs.len(), config.n_way);
    }
}