transfer_learning/
transfer_learning.rs1use scirs2_core::ndarray::{Array1, Array2};
7use quantrs2_ml::autodiff::optimizers::Adam;
8use quantrs2_ml::prelude::*;
9use quantrs2_ml::qnn::QNNLayerType;
10
11fn main() -> Result<()> {
12 println!("=== Quantum Transfer Learning Demo ===\n");
13
14 println!("1. Loading pre-trained image classifier...");
16 let pretrained = QuantumModelZoo::get_image_classifier()?;
17
18 println!(" Pre-trained model info:");
19 println!(" - Task: {}", pretrained.task_description);
20 println!(
21 " - Original accuracy: {:.2}%",
22 pretrained
23 .performance_metrics
24 .get("accuracy")
25 .unwrap_or(&0.0)
26 * 100.0
27 );
28 println!(" - Number of qubits: {}", pretrained.qnn.num_qubits);
29
30 println!("\n2. Creating new layers for text classification task...");
32 let new_layers = vec![
33 QNNLayerType::VariationalLayer { num_params: 6 },
34 QNNLayerType::MeasurementLayer {
35 measurement_basis: "Pauli-Z".to_string(),
36 },
37 ];
38
39 println!("\n3. Testing different transfer learning strategies:");
41
42 println!("\n a) Fine-tuning strategy (train last 2 layers only)");
44 let mut transfer_finetune = QuantumTransferLearning::new(
45 pretrained.clone(),
46 new_layers.clone(),
47 TransferStrategy::FineTuning {
48 num_trainable_layers: 2,
49 },
50 )?;
51
52 println!(" b) Feature extraction strategy (freeze all pre-trained layers)");
54 let transfer_feature = QuantumTransferLearning::new(
55 pretrained.clone(),
56 new_layers.clone(),
57 TransferStrategy::FeatureExtraction,
58 )?;
59
60 println!(" c) Progressive unfreezing (unfreeze one layer every 5 epochs)");
62 let transfer_progressive = QuantumTransferLearning::new(
63 pretrained.clone(),
64 new_layers.clone(),
65 TransferStrategy::ProgressiveUnfreezing { unfreeze_rate: 5 },
66 )?;
67
68 println!("\n4. Generating synthetic training data...");
70 let num_samples = 50;
71 let num_features = 4;
72 let training_data = Array2::from_shape_fn((num_samples, num_features), |(i, j)| {
73 (i as f64 * 0.1 + j as f64 * 0.2).sin()
74 });
75 let labels = Array1::from_shape_fn(num_samples, |i| if i % 2 == 0 { 0.0 } else { 1.0 });
76
77 println!("\n5. Training with fine-tuning strategy...");
79 let mut optimizer = Adam::new(0.01);
80
81 let result = transfer_finetune.train(
82 &training_data,
83 &labels,
84 &mut optimizer,
85 20, 10, )?;
88
89 println!(" Training complete!");
90 println!(" - Final loss: {:.4}", result.final_loss);
91 println!(" - Accuracy: {:.2}%", result.accuracy * 100.0);
92
93 println!("\n6. Extracting features from pre-trained layers...");
95 let features = transfer_feature.extract_features(&training_data)?;
96 println!(" Extracted feature dimensions: {:?}", features.dim());
97
98 println!("\n7. Available pre-trained models in the zoo:");
100 println!(" - Image classifier (4 qubits, MNIST subset)");
101 println!(" - Chemistry model (6 qubits, molecular energy)");
102
103 let chemistry_model = QuantumModelZoo::get_chemistry_model()?;
105 println!("\n Chemistry model info:");
106 println!(" - Task: {}", chemistry_model.task_description);
107 println!(
108 " - MAE: {:.4}",
109 chemistry_model
110 .performance_metrics
111 .get("mae")
112 .unwrap_or(&0.0)
113 );
114 println!(
115 " - R² score: {:.4}",
116 chemistry_model
117 .performance_metrics
118 .get("r2_score")
119 .unwrap_or(&0.0)
120 );
121
122 println!("\n=== Transfer Learning Demo Complete ===");
123
124 Ok(())
125}
126
127fn print_layer_configs(configs: &[LayerConfig]) {
129 for (i, config) in configs.iter().enumerate() {
130 println!(
131 " Layer {}: frozen={}, lr_multiplier={:.2}",
132 i, config.frozen, config.learning_rate_multiplier
133 );
134 }
135}