pub fn generate_task_sequence(
num_tasks: usize,
samples_per_task: usize,
feature_dim: usize,
) -> Vec<ContinualTask>
Helper function to generate a synthetic task sequence.
Examples found in repository:
examples/quantum_continual_learning.rs (line 91)
62fn ewc_demo() -> Result<()> {
63 // Create quantum model
64 let layers = vec![
65 QNNLayerType::EncodingLayer { num_features: 4 },
66 QNNLayerType::VariationalLayer { num_params: 12 },
67 QNNLayerType::EntanglementLayer {
68 connectivity: "circular".to_string(),
69 },
70 QNNLayerType::VariationalLayer { num_params: 8 },
71 QNNLayerType::MeasurementLayer {
72 measurement_basis: "computational".to_string(),
73 },
74 ];
75
76 let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
77
78 // Create EWC strategy
79 let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
80 importance_weight: 1000.0,
81 fisher_samples: 200,
82 };
83
84 let mut learner = QuantumContinualLearner::new(model, strategy);
85
86 println!(" Created EWC continual learner:");
87 println!(" - Importance weight: 1000.0");
88 println!(" - Fisher samples: 200");
89
90 // Generate task sequence
91 let tasks = generate_task_sequence(3, 100, 4);
92
93 println!("\n Learning sequence of {} tasks...", tasks.len());
94
95 let mut optimizer = Adam::new(0.001);
96 let mut task_accuracies = Vec::new();
97
98 for (i, task) in tasks.iter().enumerate() {
99 println!(" \n Training on {}...", task.task_id);
100
101 let metrics = learner.learn_task(task.clone(), &mut optimizer, 30)?;
102 task_accuracies.push(metrics.current_accuracy);
103
104 println!(" - Current accuracy: {:.3}", metrics.current_accuracy);
105
106 // Evaluate forgetting on previous tasks
107 if i > 0 {
108 let all_accuracies = learner.evaluate_all_tasks()?;
109 let avg_prev_accuracy = all_accuracies
110 .iter()
111 .take(i)
112 .map(|(_, &acc)| acc)
113 .sum::<f64>()
114 / i as f64;
115
116 println!(" - Average accuracy on previous tasks: {avg_prev_accuracy:.3}");
117 }
118 }
119
120 // Final evaluation
121 let forgetting_metrics = learner.get_forgetting_metrics();
122 println!("\n EWC Results:");
123 println!(
124 " - Average accuracy: {:.3}",
125 forgetting_metrics.average_accuracy
126 );
127 println!(
128 " - Forgetting measure: {:.3}",
129 forgetting_metrics.forgetting_measure
130 );
131 println!(
132 " - Continual learning score: {:.3}",
133 forgetting_metrics.continual_learning_score
134 );
135
136 Ok(())
137}
138
139/// Demonstrate Experience Replay
140fn experience_replay_demo() -> Result<()> {
141 let layers = vec![
142 QNNLayerType::EncodingLayer { num_features: 4 },
143 QNNLayerType::VariationalLayer { num_params: 8 },
144 QNNLayerType::MeasurementLayer {
145 measurement_basis: "computational".to_string(),
146 },
147 ];
148
149 let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
150
151 let strategy = ContinualLearningStrategy::ExperienceReplay {
152 buffer_size: 500,
153 replay_ratio: 0.3,
154 memory_selection: MemorySelectionStrategy::Random,
155 };
156
157 let mut learner = QuantumContinualLearner::new(model, strategy);
158
159 println!(" Created Experience Replay learner:");
160 println!(" - Buffer size: 500");
161 println!(" - Replay ratio: 30%");
162 println!(" - Selection: Random");
163
164 // Generate diverse tasks
165 let tasks = generate_diverse_tasks(4, 80, 4);
166
167 println!("\n Learning {} diverse tasks...", tasks.len());
168
169 let mut optimizer = Adam::new(0.002);
170
171 for (i, task) in tasks.iter().enumerate() {
172 println!(" \n Learning {}...", task.task_id);
173
174 let metrics = learner.learn_task(task.clone(), &mut optimizer, 25)?;
175
176 println!(" - Task accuracy: {:.3}", metrics.current_accuracy);
177
178 // Show memory buffer status
179 println!(" - Memory buffer usage: replay experiences stored");
180
181 if i > 0 {
182 let all_accuracies = learner.evaluate_all_tasks()?;
183 let retention_rate = all_accuracies.values().sum::<f64>() / all_accuracies.len() as f64;
184 println!(" - Average retention: {retention_rate:.3}");
185 }
186 }
187
188 let final_metrics = learner.get_forgetting_metrics();
189 println!("\n Experience Replay Results:");
190 println!(
191 " - Final average accuracy: {:.3}",
192 final_metrics.average_accuracy
193 );
194 println!(
195 " - Forgetting reduction: {:.3}",
196 1.0 - final_metrics.forgetting_measure
197 );
198
199 Ok(())
200}
201
202/// Demonstrate Progressive Networks
203fn progressive_networks_demo() -> Result<()> {
204 let layers = vec![
205 QNNLayerType::EncodingLayer { num_features: 4 },
206 QNNLayerType::VariationalLayer { num_params: 6 },
207 QNNLayerType::MeasurementLayer {
208 measurement_basis: "computational".to_string(),
209 },
210 ];
211
212 let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
213
214 let strategy = ContinualLearningStrategy::ProgressiveNetworks {
215 lateral_connections: true,
216 adaptation_layers: 2,
217 };
218
219 let mut learner = QuantumContinualLearner::new(model, strategy);
220
221 println!(" Created Progressive Networks learner:");
222 println!(" - Lateral connections: enabled");
223 println!(" - Adaptation layers: 2");
224
225 // Generate related tasks for transfer learning
226 let tasks = generate_related_tasks(3, 60, 4);
227
228 println!("\n Learning {} related tasks...", tasks.len());
229
230 let mut optimizer = Adam::new(0.001);
231 let mut learning_speeds = Vec::new();
232
233 for (i, task) in tasks.iter().enumerate() {
234 println!(" \n Adding column for {}...", task.task_id);
235
236 let start_time = std::time::Instant::now();
237 let metrics = learner.learn_task(task.clone(), &mut optimizer, 20)?;
238 let learning_time = start_time.elapsed();
239
240 learning_speeds.push(learning_time);
241
242 println!(" - Task accuracy: {:.3}", metrics.current_accuracy);
243 println!(" - Learning time: {learning_time:.2?}");
244
245 if i > 0 {
246 let speedup = learning_speeds[0].as_secs_f64() / learning_time.as_secs_f64();
247 println!(" - Learning speedup: {speedup:.2}x");
248 }
249 }
250
251 println!("\n Progressive Networks Results:");
252 println!(" - No catastrophic forgetting (by design)");
253 println!(" - Lateral connections enable knowledge transfer");
254 println!(" - Model capacity grows with new tasks");
255
256 Ok(())
257}
258
259/// Demonstrate Learning without Forgetting
260fn lwf_demo() -> Result<()> {
261 let layers = vec![
262 QNNLayerType::EncodingLayer { num_features: 4 },
263 QNNLayerType::VariationalLayer { num_params: 10 },
264 QNNLayerType::EntanglementLayer {
265 connectivity: "circular".to_string(),
266 },
267 QNNLayerType::MeasurementLayer {
268 measurement_basis: "computational".to_string(),
269 },
270 ];
271
272 let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
273
274 let strategy = ContinualLearningStrategy::LearningWithoutForgetting {
275 distillation_weight: 0.5,
276 temperature: 3.0,
277 };
278
279 let mut learner = QuantumContinualLearner::new(model, strategy);
280
281 println!(" Created Learning without Forgetting learner:");
282 println!(" - Distillation weight: 0.5");
283 println!(" - Temperature: 3.0");
284
285 // Generate task sequence
286 let tasks = generate_task_sequence(4, 70, 4);
287
288 println!("\n Learning with knowledge distillation...");
289
290 let mut optimizer = Adam::new(0.001);
291 let mut distillation_losses = Vec::new();
292
293 for (i, task) in tasks.iter().enumerate() {
294 println!(" \n Learning {}...", task.task_id);
295
296 let metrics = learner.learn_task(task.clone(), &mut optimizer, 25)?;
297
298 println!(" - Task accuracy: {:.3}", metrics.current_accuracy);
299
300 if i > 0 {
301 // Simulate distillation loss tracking
302 let distillation_loss = 0.3f64.mul_add(fastrand::f64(), 0.1);
303 distillation_losses.push(distillation_loss);
304 println!(" - Distillation loss: {distillation_loss:.3}");
305
306 let all_accuracies = learner.evaluate_all_tasks()?;
307 let stability = all_accuracies
308 .values()
309 .map(|&acc| if acc > 0.6 { 1.0 } else { 0.0 })
310 .sum::<f64>()
311 / all_accuracies.len() as f64;
312
313 println!(" - Knowledge retention: {:.1}%", stability * 100.0);
314 }
315 }
316
317 println!("\n LwF Results:");
318 println!(" - Knowledge distillation preserves previous task performance");
319 println!(" - Temperature scaling provides soft targets");
320 println!(" - Balances plasticity and stability");
321
322 Ok(())
323}