//! Quantum Meta-Learning
//!
//! This module implements meta-learning algorithms for quantum machine learning,
//! enabling rapid adaptation to new tasks with minimal training data.
//!
//! # Theoretical Background
//!
//! Quantum meta-learning extends classical meta-learning (MAML, Reptile) to
//! quantum neural networks. The goal is to learn an initialization of quantum
//! circuit parameters that can quickly adapt to new tasks through fine-tuning.
//!
//! # Key Algorithms
//!
//! - **Quantum MAML**: Model-Agnostic Meta-Learning for quantum circuits
//! - **Quantum Reptile**: First-order approximation of MAML
//! - **Quantum ProtoNets**: Prototype networks using quantum metric learning
//! - **Quantum Matching Networks**: Attention-based few-shot learning
//!
//! # Applications
//!
//! - Few-shot quantum classification
//! - Fast quantum state tomography
//! - Adaptive quantum control
//! - Quantum drug discovery with limited data
//!
//! # References
//!
//! - "Meta-Learning for Quantum Neural Networks"
//! - "Few-Shot Learning with Quantum Classifiers"
//! - "Quantum Model-Agnostic Meta-Learning"
use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    qubit::QubitId,
};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::Complex64;
use std::f64::consts::PI;
41
/// Configuration for quantum meta-learning
///
/// Groups the hyperparameters shared by the meta-learners in this module
/// (`QuantumMAML`, `QuantumReptile`). Follows the standard N-way / K-shot
/// episodic setup from few-shot learning.
#[derive(Debug, Clone)]
pub struct QuantumMetaLearningConfig {
    /// Number of qubits in the variational circuit
    pub num_qubits: usize,
    /// Circuit depth (number of rotation + entangling layers)
    pub circuit_depth: usize,
    /// Inner loop learning rate (per-task adaptation steps)
    pub inner_lr: f64,
    /// Outer loop learning rate (meta-learning update across tasks)
    pub outer_lr: f64,
    /// Number of gradient steps in the inner adaptation loop
    pub inner_steps: usize,
    /// Number of support examples per class ("K" in K-shot)
    pub n_support: usize,
    /// Number of query examples per class (used for the meta-objective)
    pub n_query: usize,
    /// Number of classes per task ("N" in N-way)
    pub n_way: usize,
    /// Meta-batch size (number of tasks per meta-training step)
    pub meta_batch_size: usize,
}
64
65impl Default for QuantumMetaLearningConfig {
66    fn default() -> Self {
67        Self {
68            num_qubits: 4,
69            circuit_depth: 4,
70            inner_lr: 0.01,
71            outer_lr: 0.001,
72            inner_steps: 5,
73            n_support: 5,
74            n_query: 15,
75            n_way: 2,
76            meta_batch_size: 4,
77        }
78    }
79}
80
/// Quantum task for meta-learning
///
/// One few-shot episode: a labelled support set used for task adaptation and
/// a labelled query set used for evaluation / the meta-objective. States are
/// dense state vectors (length 2^num_qubits for the generators in this file).
#[derive(Debug, Clone)]
pub struct QuantumTask {
    /// Support set states (adaptation examples)
    pub support_states: Vec<Array1<Complex64>>,
    /// Support set labels, parallel to `support_states` (class indices)
    pub support_labels: Vec<usize>,
    /// Query set states (evaluation examples)
    pub query_states: Vec<Array1<Complex64>>,
    /// Query set labels, parallel to `query_states`
    pub query_labels: Vec<usize>,
}
93
94impl QuantumTask {
95    /// Create new quantum task
96    pub const fn new(
97        support_states: Vec<Array1<Complex64>>,
98        support_labels: Vec<usize>,
99        query_states: Vec<Array1<Complex64>>,
100        query_labels: Vec<usize>,
101    ) -> Self {
102        Self {
103            support_states,
104            support_labels,
105            query_states,
106            query_labels,
107        }
108    }
109
110    /// Generate random task for testing
111    pub fn random(num_qubits: usize, n_way: usize, n_support: usize, n_query: usize) -> Self {
112        let mut rng = thread_rng();
113        let dim = 1 << num_qubits;
114
115        let mut support_states = Vec::new();
116        let mut support_labels = Vec::new();
117        let mut query_states = Vec::new();
118        let mut query_labels = Vec::new();
119
120        for class in 0..n_way {
121            // Generate class prototype
122            let mut prototype = Array1::from_shape_fn(dim, |_| {
123                Complex64::new(rng.random_range(-1.0..1.0), rng.random_range(-1.0..1.0))
124            });
125            let norm: f64 = prototype.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
126            for i in 0..dim {
127                prototype[i] = prototype[i] / norm;
128            }
129
130            // Generate support examples
131            for _ in 0..n_support {
132                let mut state = prototype.clone();
133                // Add small noise
134                for i in 0..dim {
135                    state[i] = state[i]
136                        + Complex64::new(rng.random_range(-0.1..0.1), rng.random_range(-0.1..0.1));
137                }
138                let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
139                for i in 0..dim {
140                    state[i] = state[i] / norm;
141                }
142                support_states.push(state);
143                support_labels.push(class);
144            }
145
146            // Generate query examples
147            for _ in 0..n_query {
148                let mut state = prototype.clone();
149                // Add small noise
150                for i in 0..dim {
151                    state[i] = state[i]
152                        + Complex64::new(rng.random_range(-0.1..0.1), rng.random_range(-0.1..0.1));
153                }
154                let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
155                for i in 0..dim {
156                    state[i] = state[i] / norm;
157                }
158                query_states.push(state);
159                query_labels.push(class);
160            }
161        }
162
163        Self {
164            support_states,
165            support_labels,
166            query_states,
167            query_labels,
168        }
169    }
170}
171
/// Quantum circuit for meta-learning
///
/// A layered variational classifier: each layer applies per-qubit rotations
/// followed by a chain of CNOTs; per-qubit Pauli-Z expectations are then fed
/// through a linear readout and softmax to produce class probabilities.
#[derive(Debug, Clone)]
pub struct QuantumMetaCircuit {
    /// Number of qubits
    num_qubits: usize,
    /// Circuit depth (number of rotation + entangling layers)
    depth: usize,
    /// Number of output classes
    num_classes: usize,
    /// Circuit parameters, shape (depth, num_qubits * 3):
    /// per layer and qubit, the (rx, ry, rz) rotation angles
    params: Array2<f64>,
    /// Readout weights, shape (num_classes, num_qubits)
    readout_weights: Array2<f64>,
}
186
187impl QuantumMetaCircuit {
188    /// Create new quantum meta circuit
189    pub fn new(num_qubits: usize, depth: usize, num_classes: usize) -> Self {
190        let mut rng = thread_rng();
191
192        let params = Array2::from_shape_fn((depth, num_qubits * 3), |_| rng.random_range(-PI..PI));
193
194        let scale = (2.0 / num_qubits as f64).sqrt();
195        let readout_weights = Array2::from_shape_fn((num_classes, num_qubits), |_| {
196            rng.random_range(-scale..scale)
197        });
198
199        Self {
200            num_qubits,
201            depth,
202            num_classes,
203            params,
204            readout_weights,
205        }
206    }
207
208    /// Forward pass
209    pub fn forward(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
210        // Apply quantum circuit
211        let mut encoded = state.clone();
212
213        for layer in 0..self.depth {
214            // Rotation gates
215            for q in 0..self.num_qubits {
216                let rx = self.params[[layer, q * 3]];
217                let ry = self.params[[layer, q * 3 + 1]];
218                let rz = self.params[[layer, q * 3 + 2]];
219
220                encoded = self.apply_rotation(&encoded, q, rx, ry, rz)?;
221            }
222
223            // Entangling gates
224            for q in 0..self.num_qubits - 1 {
225                encoded = self.apply_cnot(&encoded, q, q + 1)?;
226            }
227        }
228
229        // Measure expectations and classify
230        let mut expectations = Array1::zeros(self.num_qubits);
231        for q in 0..self.num_qubits {
232            expectations[q] = self.pauli_z_expectation(&encoded, q)?;
233        }
234
235        // Linear readout
236        let mut logits = Array1::zeros(self.num_classes);
237        for i in 0..self.num_classes {
238            for j in 0..self.num_qubits {
239                logits[i] += self.readout_weights[[i, j]] * expectations[j];
240            }
241        }
242
243        // Softmax
244        let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
245        let mut probs = Array1::zeros(self.num_classes);
246        let mut sum_exp = 0.0;
247
248        for i in 0..self.num_classes {
249            probs[i] = (logits[i] - max_logit).exp();
250            sum_exp += probs[i];
251        }
252
253        for i in 0..self.num_classes {
254            probs[i] /= sum_exp;
255        }
256
257        Ok(probs)
258    }
259
260    /// Compute loss
261    pub fn compute_loss(
262        &self,
263        states: &[Array1<Complex64>],
264        labels: &[usize],
265    ) -> QuantRS2Result<f64> {
266        let mut total_loss = 0.0;
267
268        for (state, &label) in states.iter().zip(labels.iter()) {
269            let probs = self.forward(state)?;
270            // Cross-entropy loss
271            total_loss -= probs[label].ln();
272        }
273
274        Ok(total_loss / states.len() as f64)
275    }
276
277    /// Compute gradients (simplified with finite differences)
278    pub fn compute_gradients(
279        &self,
280        states: &[Array1<Complex64>],
281        labels: &[usize],
282    ) -> QuantRS2Result<(Array2<f64>, Array2<f64>)> {
283        let epsilon = 1e-4;
284
285        // Gradients for circuit parameters
286        let mut param_grads = Array2::zeros(self.params.dim());
287
288        for i in 0..self.params.shape()[0] {
289            for j in 0..self.params.shape()[1] {
290                let mut circuit_plus = self.clone();
291                circuit_plus.params[[i, j]] += epsilon;
292                let loss_plus = circuit_plus.compute_loss(states, labels)?;
293
294                let mut circuit_minus = self.clone();
295                circuit_minus.params[[i, j]] -= epsilon;
296                let loss_minus = circuit_minus.compute_loss(states, labels)?;
297
298                param_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
299            }
300        }
301
302        // Gradients for readout weights
303        let mut readout_grads = Array2::zeros(self.readout_weights.dim());
304
305        for i in 0..self.readout_weights.shape()[0] {
306            for j in 0..self.readout_weights.shape()[1] {
307                let mut circuit_plus = self.clone();
308                circuit_plus.readout_weights[[i, j]] += epsilon;
309                let loss_plus = circuit_plus.compute_loss(states, labels)?;
310
311                let mut circuit_minus = self.clone();
312                circuit_minus.readout_weights[[i, j]] -= epsilon;
313                let loss_minus = circuit_minus.compute_loss(states, labels)?;
314
315                readout_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
316            }
317        }
318
319        Ok((param_grads, readout_grads))
320    }
321
322    /// Update parameters
323    pub fn update_params(
324        &mut self,
325        param_grads: &Array2<f64>,
326        readout_grads: &Array2<f64>,
327        lr: f64,
328    ) {
329        self.params = &self.params - &(param_grads * lr);
330        self.readout_weights = &self.readout_weights - &(readout_grads * lr);
331    }
332
333    /// Helper methods
334    fn apply_rotation(
335        &self,
336        state: &Array1<Complex64>,
337        qubit: usize,
338        rx: f64,
339        ry: f64,
340        rz: f64,
341    ) -> QuantRS2Result<Array1<Complex64>> {
342        let mut result = state.clone();
343        result = self.apply_rz_gate(&result, qubit, rz)?;
344        result = self.apply_ry_gate(&result, qubit, ry)?;
345        result = self.apply_rx_gate(&result, qubit, rx)?;
346        Ok(result)
347    }
348
349    fn apply_rx_gate(
350        &self,
351        state: &Array1<Complex64>,
352        qubit: usize,
353        angle: f64,
354    ) -> QuantRS2Result<Array1<Complex64>> {
355        let dim = state.len();
356        let mut new_state = Array1::zeros(dim);
357        let cos_half = Complex64::new((angle / 2.0).cos(), 0.0);
358        let sin_half = Complex64::new(0.0, -(angle / 2.0).sin());
359
360        for i in 0..dim {
361            let j = i ^ (1 << qubit);
362            new_state[i] = state[i] * cos_half + state[j] * sin_half;
363        }
364
365        Ok(new_state)
366    }
367
368    fn apply_ry_gate(
369        &self,
370        state: &Array1<Complex64>,
371        qubit: usize,
372        angle: f64,
373    ) -> QuantRS2Result<Array1<Complex64>> {
374        let dim = state.len();
375        let mut new_state = Array1::zeros(dim);
376        let cos_half = (angle / 2.0).cos();
377        let sin_half = (angle / 2.0).sin();
378
379        for i in 0..dim {
380            let bit = (i >> qubit) & 1;
381            let j = i ^ (1 << qubit);
382            if bit == 0 {
383                new_state[i] = state[i] * cos_half - state[j] * sin_half;
384            } else {
385                new_state[i] = state[i] * cos_half + state[j] * sin_half;
386            }
387        }
388
389        Ok(new_state)
390    }
391
392    fn apply_rz_gate(
393        &self,
394        state: &Array1<Complex64>,
395        qubit: usize,
396        angle: f64,
397    ) -> QuantRS2Result<Array1<Complex64>> {
398        let dim = state.len();
399        let mut new_state = state.clone();
400        let phase = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
401
402        for i in 0..dim {
403            let bit = (i >> qubit) & 1;
404            new_state[i] = if bit == 1 {
405                new_state[i] * phase
406            } else {
407                new_state[i] * phase.conj()
408            };
409        }
410
411        Ok(new_state)
412    }
413
414    fn apply_cnot(
415        &self,
416        state: &Array1<Complex64>,
417        control: usize,
418        target: usize,
419    ) -> QuantRS2Result<Array1<Complex64>> {
420        let dim = state.len();
421        let mut new_state = state.clone();
422
423        for i in 0..dim {
424            let control_bit = (i >> control) & 1;
425            if control_bit == 1 {
426                let j = i ^ (1 << target);
427                if i < j {
428                    let temp = new_state[i];
429                    new_state[i] = new_state[j];
430                    new_state[j] = temp;
431                }
432            }
433        }
434
435        Ok(new_state)
436    }
437
438    fn pauli_z_expectation(&self, state: &Array1<Complex64>, qubit: usize) -> QuantRS2Result<f64> {
439        let dim = state.len();
440        let mut expectation = 0.0;
441
442        for i in 0..dim {
443            let bit = (i >> qubit) & 1;
444            let sign = if bit == 0 { 1.0 } else { -1.0 };
445            expectation += sign * state[i].norm_sqr();
446        }
447
448        Ok(expectation)
449    }
450}
451
/// Quantum MAML (Model-Agnostic Meta-Learning)
///
/// Learns an initialization of `QuantumMetaCircuit` parameters that adapts to
/// new tasks in a few inner-loop gradient steps on a small support set.
#[derive(Debug, Clone)]
pub struct QuantumMAML {
    /// Hyperparameters (learning rates, inner steps, episode sizes)
    config: QuantumMetaLearningConfig,
    /// Meta-model (the shared initialization that is meta-trained)
    meta_model: QuantumMetaCircuit,
}
460
461impl QuantumMAML {
462    /// Create new Quantum MAML
463    pub fn new(config: QuantumMetaLearningConfig) -> Self {
464        let meta_model =
465            QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
466
467        Self { config, meta_model }
468    }
469
470    /// Meta-training step
471    pub fn meta_train_step(&mut self, tasks: &[QuantumTask]) -> QuantRS2Result<f64> {
472        let mut meta_param_grads = Array2::zeros(self.meta_model.params.dim());
473        let mut meta_readout_grads = Array2::zeros(self.meta_model.readout_weights.dim());
474        let mut total_loss = 0.0;
475
476        for task in tasks {
477            // Inner loop: adapt to task
478            let mut adapted_model = self.meta_model.clone();
479
480            for _ in 0..self.config.inner_steps {
481                let (param_grads, readout_grads) =
482                    adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
483
484                adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
485            }
486
487            // Compute loss on query set
488            let query_loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
489            total_loss += query_loss;
490
491            // Compute meta-gradients
492            let (param_grads, readout_grads) =
493                adapted_model.compute_gradients(&task.query_states, &task.query_labels)?;
494
495            meta_param_grads = meta_param_grads + param_grads;
496            meta_readout_grads = meta_readout_grads + readout_grads;
497        }
498
499        // Average gradients
500        meta_param_grads = meta_param_grads / (tasks.len() as f64);
501        meta_readout_grads = meta_readout_grads / (tasks.len() as f64);
502
503        // Update meta-model
504        self.meta_model
505            .update_params(&meta_param_grads, &meta_readout_grads, self.config.outer_lr);
506
507        Ok(total_loss / tasks.len() as f64)
508    }
509
510    /// Adapt to new task
511    pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
512        let mut adapted_model = self.meta_model.clone();
513
514        for _ in 0..self.config.inner_steps {
515            let (param_grads, readout_grads) =
516                adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
517
518            adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
519        }
520
521        Ok(adapted_model)
522    }
523
524    /// Evaluate on new task
525    pub fn evaluate(&self, task: &QuantumTask) -> QuantRS2Result<f64> {
526        let adapted_model = self.adapt(task)?;
527
528        let mut correct = 0;
529        for (state, &label) in task.query_states.iter().zip(task.query_labels.iter()) {
530            let probs = adapted_model.forward(state)?;
531            let mut max_prob = f64::NEG_INFINITY;
532            let mut predicted = 0;
533
534            for (i, &prob) in probs.iter().enumerate() {
535                if prob > max_prob {
536                    max_prob = prob;
537                    predicted = i;
538                }
539            }
540
541            if predicted == label {
542                correct += 1;
543            }
544        }
545
546        Ok(correct as f64 / task.query_states.len() as f64)
547    }
548
549    /// Get meta-model
550    pub const fn meta_model(&self) -> &QuantumMetaCircuit {
551        &self.meta_model
552    }
553}
554
/// Quantum Reptile (simpler first-order MAML)
///
/// Instead of meta-gradients, Reptile moves the meta-parameters a fraction of
/// the way towards the parameters obtained by adapting to each task.
#[derive(Debug, Clone)]
pub struct QuantumReptile {
    /// Hyperparameters (learning rates, inner steps, episode sizes)
    config: QuantumMetaLearningConfig,
    /// Meta-model (the shared initialization that is meta-trained)
    meta_model: QuantumMetaCircuit,
}
563
564impl QuantumReptile {
565    /// Create new Quantum Reptile
566    pub fn new(config: QuantumMetaLearningConfig) -> Self {
567        let meta_model =
568            QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
569
570        Self { config, meta_model }
571    }
572
573    /// Meta-training step
574    pub fn meta_train_step(&mut self, task: &QuantumTask) -> QuantRS2Result<f64> {
575        // Adapt to task
576        let mut adapted_model = self.meta_model.clone();
577
578        for _ in 0..self.config.inner_steps {
579            let (param_grads, readout_grads) =
580                adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
581
582            adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
583        }
584
585        // Compute loss
586        let loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
587
588        // Update meta-model towards adapted model
589        let param_diff = &adapted_model.params - &self.meta_model.params;
590        let readout_diff = &adapted_model.readout_weights - &self.meta_model.readout_weights;
591
592        self.meta_model.params = &self.meta_model.params + &(param_diff * self.config.outer_lr);
593        self.meta_model.readout_weights =
594            &self.meta_model.readout_weights + &(readout_diff * self.config.outer_lr);
595
596        Ok(loss)
597    }
598
599    /// Adapt to new task (same as MAML)
600    pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
601        let mut adapted_model = self.meta_model.clone();
602
603        for _ in 0..self.config.inner_steps {
604            let (param_grads, readout_grads) =
605                adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
606
607            adapted_model.update_params(&param_grads, &readout_grads, self.config.inner_lr);
608        }
609
610        Ok(adapted_model)
611    }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_meta_circuit() {
        // 3 qubits, depth 2, 2 classes; input is the |000⟩ basis state.
        let circuit = QuantumMetaCircuit::new(3, 2, 2);
        let state = Array1::from_shape_fn(8, |i| {
            if i == 0 {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        });

        let probs = circuit
            .forward(&state)
            .expect("forward pass should succeed");

        // One probability per class, summing to 1 (softmax output).
        assert_eq!(probs.len(), 2);
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_quantum_maml() {
        let config = QuantumMetaLearningConfig {
            num_qubits: 2,
            circuit_depth: 2,
            inner_lr: 0.01,
            outer_lr: 0.001,
            inner_steps: 3,
            n_support: 2,
            n_query: 5,
            n_way: 2,
            meta_batch_size: 2,
        };
        let maml = QuantumMAML::new(config.clone());

        // Random 2-way episode matching the configuration.
        let task = QuantumTask::random(
            config.num_qubits,
            config.n_way,
            config.n_support,
            config.n_query,
        );

        // Adaptation should succeed, and the adapted model should emit one
        // probability per class.
        let adapted = maml.adapt(&task).expect("MAML adaptation should succeed");
        let probs = adapted
            .forward(&task.query_states[0])
            .expect("adapted model forward pass should succeed");
        assert_eq!(probs.len(), config.n_way);
    }
}