quantum_continual_learning/
quantum_continual_learning.rs

//! Quantum Continual Learning Example
//!
//! This example demonstrates various continual learning strategies for quantum neural networks,
//! including Elastic Weight Consolidation, Experience Replay, Progressive Networks, and more.

6use ndarray::{Array1, Array2};
7use quantrs2_ml::autodiff::optimizers::Adam;
8use quantrs2_ml::prelude::*;
9use quantrs2_ml::qnn::QNNLayerType;
10
11fn main() -> Result<()> {
12    println!("=== Quantum Continual Learning Demo ===\n");
13
14    // Step 1: Elastic Weight Consolidation (EWC)
15    println!("1. Elastic Weight Consolidation (EWC)...");
16    ewc_demo()?;
17
18    // Step 2: Experience Replay
19    println!("\n2. Experience Replay...");
20    experience_replay_demo()?;
21
22    // Step 3: Progressive Networks
23    println!("\n3. Progressive Networks...");
24    progressive_networks_demo()?;
25
26    // Step 4: Learning without Forgetting (LwF)
27    println!("\n4. Learning without Forgetting...");
28    lwf_demo()?;
29
30    // Step 5: Parameter Isolation
31    println!("\n5. Parameter Isolation...");
32    parameter_isolation_demo()?;
33
34    // Step 6: Task sequence evaluation
35    println!("\n6. Task Sequence Evaluation...");
36    task_sequence_demo()?;
37
38    // Step 7: Forgetting analysis
39    println!("\n7. Forgetting Analysis...");
40    forgetting_analysis_demo()?;
41
42    println!("\n=== Quantum Continual Learning Demo Complete ===");
43
44    Ok(())
45}
46
47/// Demonstrate Elastic Weight Consolidation
48fn ewc_demo() -> Result<()> {
49    // Create quantum model
50    let layers = vec![
51        QNNLayerType::EncodingLayer { num_features: 4 },
52        QNNLayerType::VariationalLayer { num_params: 12 },
53        QNNLayerType::EntanglementLayer {
54            connectivity: "circular".to_string(),
55        },
56        QNNLayerType::VariationalLayer { num_params: 8 },
57        QNNLayerType::MeasurementLayer {
58            measurement_basis: "computational".to_string(),
59        },
60    ];
61
62    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
63
64    // Create EWC strategy
65    let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
66        importance_weight: 1000.0,
67        fisher_samples: 200,
68    };
69
70    let mut learner = QuantumContinualLearner::new(model, strategy);
71
72    println!("   Created EWC continual learner:");
73    println!("   - Importance weight: 1000.0");
74    println!("   - Fisher samples: 200");
75
76    // Generate task sequence
77    let tasks = generate_task_sequence(3, 100, 4);
78
79    println!("\n   Learning sequence of {} tasks...", tasks.len());
80
81    let mut optimizer = Adam::new(0.001);
82    let mut task_accuracies = Vec::new();
83
84    for (i, task) in tasks.iter().enumerate() {
85        println!("   \n   Training on {}...", task.task_id);
86
87        let metrics = learner.learn_task(task.clone(), &mut optimizer, 30)?;
88        task_accuracies.push(metrics.current_accuracy);
89
90        println!("   - Current accuracy: {:.3}", metrics.current_accuracy);
91
92        // Evaluate forgetting on previous tasks
93        if i > 0 {
94            let all_accuracies = learner.evaluate_all_tasks()?;
95            let avg_prev_accuracy = all_accuracies
96                .iter()
97                .take(i)
98                .map(|(_, &acc)| acc)
99                .sum::<f64>()
100                / i as f64;
101
102            println!(
103                "   - Average accuracy on previous tasks: {:.3}",
104                avg_prev_accuracy
105            );
106        }
107    }
108
109    // Final evaluation
110    let forgetting_metrics = learner.get_forgetting_metrics();
111    println!("\n   EWC Results:");
112    println!(
113        "   - Average accuracy: {:.3}",
114        forgetting_metrics.average_accuracy
115    );
116    println!(
117        "   - Forgetting measure: {:.3}",
118        forgetting_metrics.forgetting_measure
119    );
120    println!(
121        "   - Continual learning score: {:.3}",
122        forgetting_metrics.continual_learning_score
123    );
124
125    Ok(())
126}
127
128/// Demonstrate Experience Replay
129fn experience_replay_demo() -> Result<()> {
130    let layers = vec![
131        QNNLayerType::EncodingLayer { num_features: 4 },
132        QNNLayerType::VariationalLayer { num_params: 8 },
133        QNNLayerType::MeasurementLayer {
134            measurement_basis: "computational".to_string(),
135        },
136    ];
137
138    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
139
140    let strategy = ContinualLearningStrategy::ExperienceReplay {
141        buffer_size: 500,
142        replay_ratio: 0.3,
143        memory_selection: MemorySelectionStrategy::Random,
144    };
145
146    let mut learner = QuantumContinualLearner::new(model, strategy);
147
148    println!("   Created Experience Replay learner:");
149    println!("   - Buffer size: 500");
150    println!("   - Replay ratio: 30%");
151    println!("   - Selection: Random");
152
153    // Generate diverse tasks
154    let tasks = generate_diverse_tasks(4, 80, 4);
155
156    println!("\n   Learning {} diverse tasks...", tasks.len());
157
158    let mut optimizer = Adam::new(0.002);
159
160    for (i, task) in tasks.iter().enumerate() {
161        println!("   \n   Learning {}...", task.task_id);
162
163        let metrics = learner.learn_task(task.clone(), &mut optimizer, 25)?;
164
165        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);
166
167        // Show memory buffer status
168        println!("   - Memory buffer usage: replay experiences stored");
169
170        if i > 0 {
171            let all_accuracies = learner.evaluate_all_tasks()?;
172            let retention_rate = all_accuracies.values().sum::<f64>() / all_accuracies.len() as f64;
173            println!("   - Average retention: {:.3}", retention_rate);
174        }
175    }
176
177    let final_metrics = learner.get_forgetting_metrics();
178    println!("\n   Experience Replay Results:");
179    println!(
180        "   - Final average accuracy: {:.3}",
181        final_metrics.average_accuracy
182    );
183    println!(
184        "   - Forgetting reduction: {:.3}",
185        1.0 - final_metrics.forgetting_measure
186    );
187
188    Ok(())
189}
190
191/// Demonstrate Progressive Networks
192fn progressive_networks_demo() -> Result<()> {
193    let layers = vec![
194        QNNLayerType::EncodingLayer { num_features: 4 },
195        QNNLayerType::VariationalLayer { num_params: 6 },
196        QNNLayerType::MeasurementLayer {
197            measurement_basis: "computational".to_string(),
198        },
199    ];
200
201    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
202
203    let strategy = ContinualLearningStrategy::ProgressiveNetworks {
204        lateral_connections: true,
205        adaptation_layers: 2,
206    };
207
208    let mut learner = QuantumContinualLearner::new(model, strategy);
209
210    println!("   Created Progressive Networks learner:");
211    println!("   - Lateral connections: enabled");
212    println!("   - Adaptation layers: 2");
213
214    // Generate related tasks for transfer learning
215    let tasks = generate_related_tasks(3, 60, 4);
216
217    println!("\n   Learning {} related tasks...", tasks.len());
218
219    let mut optimizer = Adam::new(0.001);
220    let mut learning_speeds = Vec::new();
221
222    for (i, task) in tasks.iter().enumerate() {
223        println!("   \n   Adding column for {}...", task.task_id);
224
225        let start_time = std::time::Instant::now();
226        let metrics = learner.learn_task(task.clone(), &mut optimizer, 20)?;
227        let learning_time = start_time.elapsed();
228
229        learning_speeds.push(learning_time);
230
231        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);
232        println!("   - Learning time: {:.2?}", learning_time);
233
234        if i > 0 {
235            let speedup = learning_speeds[0].as_secs_f64() / learning_time.as_secs_f64();
236            println!("   - Learning speedup: {:.2}x", speedup);
237        }
238    }
239
240    println!("\n   Progressive Networks Results:");
241    println!("   - No catastrophic forgetting (by design)");
242    println!("   - Lateral connections enable knowledge transfer");
243    println!("   - Model capacity grows with new tasks");
244
245    Ok(())
246}
247
248/// Demonstrate Learning without Forgetting
249fn lwf_demo() -> Result<()> {
250    let layers = vec![
251        QNNLayerType::EncodingLayer { num_features: 4 },
252        QNNLayerType::VariationalLayer { num_params: 10 },
253        QNNLayerType::EntanglementLayer {
254            connectivity: "circular".to_string(),
255        },
256        QNNLayerType::MeasurementLayer {
257            measurement_basis: "computational".to_string(),
258        },
259    ];
260
261    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
262
263    let strategy = ContinualLearningStrategy::LearningWithoutForgetting {
264        distillation_weight: 0.5,
265        temperature: 3.0,
266    };
267
268    let mut learner = QuantumContinualLearner::new(model, strategy);
269
270    println!("   Created Learning without Forgetting learner:");
271    println!("   - Distillation weight: 0.5");
272    println!("   - Temperature: 3.0");
273
274    // Generate task sequence
275    let tasks = generate_task_sequence(4, 70, 4);
276
277    println!("\n   Learning with knowledge distillation...");
278
279    let mut optimizer = Adam::new(0.001);
280    let mut distillation_losses = Vec::new();
281
282    for (i, task) in tasks.iter().enumerate() {
283        println!("   \n   Learning {}...", task.task_id);
284
285        let metrics = learner.learn_task(task.clone(), &mut optimizer, 25)?;
286
287        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);
288
289        if i > 0 {
290            // Simulate distillation loss tracking
291            let distillation_loss = 0.1 + 0.3 * fastrand::f64();
292            distillation_losses.push(distillation_loss);
293            println!("   - Distillation loss: {:.3}", distillation_loss);
294
295            let all_accuracies = learner.evaluate_all_tasks()?;
296            let stability = all_accuracies
297                .values()
298                .map(|&acc| if acc > 0.6 { 1.0 } else { 0.0 })
299                .sum::<f64>()
300                / all_accuracies.len() as f64;
301
302            println!("   - Knowledge retention: {:.1}%", stability * 100.0);
303        }
304    }
305
306    println!("\n   LwF Results:");
307    println!("   - Knowledge distillation preserves previous task performance");
308    println!("   - Temperature scaling provides soft targets");
309    println!("   - Balances plasticity and stability");
310
311    Ok(())
312}
313
314/// Demonstrate Parameter Isolation
315fn parameter_isolation_demo() -> Result<()> {
316    let layers = vec![
317        QNNLayerType::EncodingLayer { num_features: 4 },
318        QNNLayerType::VariationalLayer { num_params: 16 },
319        QNNLayerType::EntanglementLayer {
320            connectivity: "full".to_string(),
321        },
322        QNNLayerType::MeasurementLayer {
323            measurement_basis: "computational".to_string(),
324        },
325    ];
326
327    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
328
329    let strategy = ContinualLearningStrategy::ParameterIsolation {
330        allocation_strategy: ParameterAllocationStrategy::Masking,
331        growth_threshold: 0.8,
332    };
333
334    let mut learner = QuantumContinualLearner::new(model, strategy);
335
336    println!("   Created Parameter Isolation learner:");
337    println!("   - Allocation strategy: Masking");
338    println!("   - Growth threshold: 0.8");
339
340    // Generate tasks with different requirements
341    let tasks = generate_varying_complexity_tasks(3, 90, 4);
342
343    println!("\n   Learning with parameter isolation...");
344
345    let mut optimizer = Adam::new(0.001);
346    let mut parameter_usage = Vec::new();
347
348    for (i, task) in tasks.iter().enumerate() {
349        println!("   \n   Allocating parameters for {}...", task.task_id);
350
351        let metrics = learner.learn_task(task.clone(), &mut optimizer, 30)?;
352
353        // Simulate parameter usage tracking
354        let used_params = 16 * (i + 1) / tasks.len(); // Gradually use more parameters
355        parameter_usage.push(used_params);
356
357        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);
358        println!("   - Parameters allocated: {}/{}", used_params, 16);
359        println!(
360            "   - Parameter efficiency: {:.1}%",
361            used_params as f64 / 16.0 * 100.0
362        );
363
364        if i > 0 {
365            let all_accuracies = learner.evaluate_all_tasks()?;
366            let interference = 1.0
367                - all_accuracies
368                    .values()
369                    .take(i)
370                    .map(|&acc| if acc > 0.7 { 1.0 } else { 0.0 })
371                    .sum::<f64>()
372                    / i as f64;
373
374            println!("   - Task interference: {:.1}%", interference * 100.0);
375        }
376    }
377
378    println!("\n   Parameter Isolation Results:");
379    println!("   - Dedicated parameters prevent interference");
380    println!("   - Scalable to many tasks");
381    println!("   - Maintains task-specific knowledge");
382
383    Ok(())
384}
385
386/// Demonstrate comprehensive task sequence evaluation
387fn task_sequence_demo() -> Result<()> {
388    println!("   Comprehensive continual learning evaluation...");
389
390    // Compare different strategies
391    let strategies = vec![
392        (
393            "EWC",
394            ContinualLearningStrategy::ElasticWeightConsolidation {
395                importance_weight: 500.0,
396                fisher_samples: 100,
397            },
398        ),
399        (
400            "Experience Replay",
401            ContinualLearningStrategy::ExperienceReplay {
402                buffer_size: 300,
403                replay_ratio: 0.2,
404                memory_selection: MemorySelectionStrategy::Random,
405            },
406        ),
407        (
408            "Quantum Regularization",
409            ContinualLearningStrategy::QuantumRegularization {
410                entanglement_preservation: 0.1,
411                parameter_drift_penalty: 0.5,
412            },
413        ),
414    ];
415
416    // Generate challenging task sequence
417    let tasks = generate_challenging_sequence(5, 60, 4);
418
419    println!(
420        "\n   Comparing strategies on {} challenging tasks:",
421        tasks.len()
422    );
423
424    for (strategy_name, strategy) in strategies {
425        println!("\n   --- {} ---", strategy_name);
426
427        let layers = vec![
428            QNNLayerType::EncodingLayer { num_features: 4 },
429            QNNLayerType::VariationalLayer { num_params: 8 },
430            QNNLayerType::MeasurementLayer {
431                measurement_basis: "computational".to_string(),
432            },
433        ];
434
435        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
436        let mut learner = QuantumContinualLearner::new(model, strategy);
437        let mut optimizer = Adam::new(0.001);
438
439        for task in &tasks {
440            learner.learn_task(task.clone(), &mut optimizer, 20)?;
441        }
442
443        let final_metrics = learner.get_forgetting_metrics();
444        println!(
445            "   - Average accuracy: {:.3}",
446            final_metrics.average_accuracy
447        );
448        println!(
449            "   - Forgetting measure: {:.3}",
450            final_metrics.forgetting_measure
451        );
452        println!(
453            "   - CL score: {:.3}",
454            final_metrics.continual_learning_score
455        );
456    }
457
458    Ok(())
459}
460
461/// Demonstrate forgetting analysis
462fn forgetting_analysis_demo() -> Result<()> {
463    println!("   Detailed forgetting analysis...");
464
465    let layers = vec![
466        QNNLayerType::EncodingLayer { num_features: 4 },
467        QNNLayerType::VariationalLayer { num_params: 12 },
468        QNNLayerType::MeasurementLayer {
469            measurement_basis: "computational".to_string(),
470        },
471    ];
472
473    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
474
475    let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
476        importance_weight: 1000.0,
477        fisher_samples: 150,
478    };
479
480    let mut learner = QuantumContinualLearner::new(model, strategy);
481
482    // Create tasks with increasing difficulty
483    let tasks = generate_increasing_difficulty_tasks(4, 80, 4);
484
485    println!("\n   Learning tasks with increasing difficulty...");
486
487    let mut optimizer = Adam::new(0.001);
488    let mut accuracy_matrix = Vec::new();
489
490    for (i, task) in tasks.iter().enumerate() {
491        println!(
492            "   \n   Learning {} (difficulty level {})...",
493            task.task_id,
494            i + 1
495        );
496
497        learner.learn_task(task.clone(), &mut optimizer, 25)?;
498
499        // Evaluate on all tasks learned so far
500        let all_accuracies = learner.evaluate_all_tasks()?;
501        let mut current_row = Vec::new();
502
503        for j in 0..=i {
504            let task_id = &tasks[j].task_id;
505            let accuracy = all_accuracies.get(task_id).unwrap_or(&0.0);
506            current_row.push(*accuracy);
507        }
508
509        accuracy_matrix.push(current_row.clone());
510
511        // Print current performance
512        for (j, &acc) in current_row.iter().enumerate() {
513            println!("   - Task {}: {:.3}", j + 1, acc);
514        }
515    }
516
517    println!("\n   Forgetting Analysis Results:");
518
519    // Compute backward transfer
520    for i in 1..accuracy_matrix.len() {
521        for j in 0..i {
522            let current_acc = accuracy_matrix[i][j];
523            let original_acc = accuracy_matrix[j][j];
524            let forgetting = (original_acc - current_acc).max(0.0);
525
526            if forgetting > 0.1 {
527                println!("   - Significant forgetting detected for Task {} after learning Task {}: {:.3}",
528                    j + 1, i + 1, forgetting);
529            }
530        }
531    }
532
533    // Compute average forgetting
534    let mut total_forgetting = 0.0;
535    let mut num_comparisons = 0;
536
537    for i in 1..accuracy_matrix.len() {
538        for j in 0..i {
539            let current_acc = accuracy_matrix[i][j];
540            let original_acc = accuracy_matrix[j][j];
541            total_forgetting += (original_acc - current_acc).max(0.0);
542            num_comparisons += 1;
543        }
544    }
545
546    let avg_forgetting = if num_comparisons > 0 {
547        total_forgetting / num_comparisons as f64
548    } else {
549        0.0
550    };
551
552    println!("   - Average forgetting: {:.3}", avg_forgetting);
553
554    // Compute final average accuracy
555    if let Some(final_row) = accuracy_matrix.last() {
556        let final_avg = final_row.iter().sum::<f64>() / final_row.len() as f64;
557        println!("   - Final average accuracy: {:.3}", final_avg);
558        println!(
559            "   - Continual learning effectiveness: {:.1}%",
560            (1.0 - avg_forgetting) * 100.0
561        );
562    }
563
564    Ok(())
565}
566
567/// Generate diverse tasks with different characteristics
568fn generate_diverse_tasks(
569    num_tasks: usize,
570    samples_per_task: usize,
571    feature_dim: usize,
572) -> Vec<ContinualTask> {
573    let mut tasks = Vec::new();
574
575    for i in 0..num_tasks {
576        let task_type = match i % 3 {
577            0 => "classification",
578            1 => "pattern_recognition",
579            _ => "feature_detection",
580        };
581
582        // Generate task-specific data with different distributions
583        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
584            match i % 3 {
585                0 => {
586                    // Gaussian-like distribution
587                    let center = i as f64 * 0.2;
588                    center + 0.2 * (fastrand::f64() - 0.5)
589                }
590                1 => {
591                    // Sinusoidal pattern
592                    let freq = (i + 1) as f64;
593                    0.5 + 0.3 * (freq * row as f64 * 0.1 + col as f64 * 0.2).sin()
594                }
595                _ => {
596                    // Random with task-specific bias
597                    let bias = i as f64 * 0.1;
598                    bias + fastrand::f64() * 0.6
599                }
600            }
601        });
602
603        let labels = Array1::from_shape_fn(samples_per_task, |row| {
604            let features_sum = data.row(row).sum();
605            if features_sum > feature_dim as f64 * 0.5 {
606                1
607            } else {
608                0
609            }
610        });
611
612        let task = create_continual_task(
613            format!("{}_{}", task_type, i),
614            TaskType::Classification { num_classes: 2 },
615            data,
616            labels,
617            0.8,
618        );
619
620        tasks.push(task);
621    }
622
623    tasks
624}
625
626/// Generate related tasks for transfer learning
627fn generate_related_tasks(
628    num_tasks: usize,
629    samples_per_task: usize,
630    feature_dim: usize,
631) -> Vec<ContinualTask> {
632    let mut tasks = Vec::new();
633    let base_pattern = Array1::from_shape_fn(feature_dim, |i| (i as f64 * 0.3).sin());
634
635    for i in 0..num_tasks {
636        // Each task is a variation of the base pattern
637        let variation_strength = 0.1 + i as f64 * 0.1;
638
639        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
640            let base_value = base_pattern[col];
641            let variation = variation_strength * (row as f64 * 0.05 + col as f64 * 0.1).cos();
642            let noise = 0.05 * (fastrand::f64() - 0.5);
643            (base_value + variation + noise).max(0.0).min(1.0)
644        });
645
646        let labels = Array1::from_shape_fn(samples_per_task, |row| {
647            let correlation = data
648                .row(row)
649                .iter()
650                .zip(base_pattern.iter())
651                .map(|(&x, &y)| x * y)
652                .sum::<f64>();
653            if correlation > 0.5 {
654                1
655            } else {
656                0
657            }
658        });
659
660        let task = create_continual_task(
661            format!("related_task_{}", i),
662            TaskType::Classification { num_classes: 2 },
663            data,
664            labels,
665            0.8,
666        );
667
668        tasks.push(task);
669    }
670
671    tasks
672}
673
674/// Generate tasks with varying complexity
675fn generate_varying_complexity_tasks(
676    num_tasks: usize,
677    samples_per_task: usize,
678    feature_dim: usize,
679) -> Vec<ContinualTask> {
680    let mut tasks = Vec::new();
681
682    for i in 0..num_tasks {
683        let complexity = (i + 1) as f64; // Increasing complexity
684
685        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
686            // More complex decision boundaries for later tasks
687            let x = row as f64 / samples_per_task as f64;
688            let y = col as f64 / feature_dim as f64;
689
690            let value = match i {
691                0 => {
692                    if x > 0.5 {
693                        1.0
694                    } else {
695                        0.0
696                    }
697                } // Simple linear
698                1 => {
699                    if x * x + y * y > 0.25 {
700                        1.0
701                    } else {
702                        0.0
703                    }
704                } // Circular
705                2 => {
706                    if (x * 4.0).sin() * (y * 4.0).cos() > 0.0 {
707                        1.0
708                    } else {
709                        0.0
710                    }
711                } // Sinusoidal
712                _ => {
713                    // Very complex pattern
714                    let pattern = (x * 8.0).sin() * (y * 8.0).cos() + (x * y * 16.0).sin();
715                    if pattern > 0.0 {
716                        1.0
717                    } else {
718                        0.0
719                    }
720                }
721            };
722
723            value + 0.1 * (fastrand::f64() - 0.5) // Add noise
724        });
725
726        let labels = Array1::from_shape_fn(samples_per_task, |row| {
727            // Complex labeling based on multiple features
728            let features = data.row(row);
729            let decision_value = features
730                .iter()
731                .enumerate()
732                .map(|(j, &x)| x * (1.0 + j as f64 * complexity * 0.1))
733                .sum::<f64>();
734
735            if decision_value > feature_dim as f64 * 0.5 {
736                1
737            } else {
738                0
739            }
740        });
741
742        let task = create_continual_task(
743            format!("complex_task_{}", i),
744            TaskType::Classification { num_classes: 2 },
745            data,
746            labels,
747            0.8,
748        );
749
750        tasks.push(task);
751    }
752
753    tasks
754}
755
756/// Generate challenging task sequence
757fn generate_challenging_sequence(
758    num_tasks: usize,
759    samples_per_task: usize,
760    feature_dim: usize,
761) -> Vec<ContinualTask> {
762    let mut tasks = Vec::new();
763
764    for i in 0..num_tasks {
765        // Alternating between different types of challenges
766        let challenge_type = i % 4;
767
768        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
769            match challenge_type {
770                0 => {
771                    // High-frequency patterns
772                    let freq = 10.0 + i as f64 * 2.0;
773                    0.5 + 0.4 * (freq * row as f64 * 0.01).sin()
774                }
775                1 => {
776                    // Overlapping distributions
777                    let center1 = 0.3 + i as f64 * 0.05;
778                    let center2 = 0.7 - i as f64 * 0.05;
779                    if row % 2 == 0 {
780                        center1 + 0.15 * (fastrand::f64() - 0.5)
781                    } else {
782                        center2 + 0.15 * (fastrand::f64() - 0.5)
783                    }
784                }
785                2 => {
786                    // Non-linear patterns
787                    let x = row as f64 / samples_per_task as f64;
788                    let y = col as f64 / feature_dim as f64;
789                    let pattern = (x * x - y * y + i as f64 * 0.1).tanh();
790                    0.5 + 0.3 * pattern
791                }
792                _ => {
793                    // Sparse patterns
794                    if fastrand::f64() < 0.2 {
795                        0.8 + 0.2 * fastrand::f64()
796                    } else {
797                        0.1 * fastrand::f64()
798                    }
799                }
800            }
801        });
802
803        let labels = Array1::from_shape_fn(samples_per_task, |row| {
804            let features = data.row(row);
805            match challenge_type {
806                0 => {
807                    if features.sum() > feature_dim as f64 * 0.5 {
808                        1
809                    } else {
810                        0
811                    }
812                }
813                1 => {
814                    if features[0] > 0.5 {
815                        1
816                    } else {
817                        0
818                    }
819                }
820                2 => {
821                    if features
822                        .iter()
823                        .enumerate()
824                        .map(|(j, &x)| x * (j as f64 + 1.0))
825                        .sum::<f64>()
826                        > 2.0
827                    {
828                        1
829                    } else {
830                        0
831                    }
832                }
833                _ => {
834                    if features.iter().filter(|&&x| x > 0.5).count() > feature_dim / 2 {
835                        1
836                    } else {
837                        0
838                    }
839                }
840            }
841        });
842
843        let task = create_continual_task(
844            format!("challenge_{}", i),
845            TaskType::Classification { num_classes: 2 },
846            data,
847            labels,
848            0.8,
849        );
850
851        tasks.push(task);
852    }
853
854    tasks
855}
856
857/// Generate tasks with increasing difficulty
858fn generate_increasing_difficulty_tasks(
859    num_tasks: usize,
860    samples_per_task: usize,
861    feature_dim: usize,
862) -> Vec<ContinualTask> {
863    let mut tasks = Vec::new();
864
865    for i in 0..num_tasks {
866        let difficulty = (i + 1) as f64;
867        let noise_level = 0.05 + difficulty * 0.02;
868        let pattern_complexity = 1.0 + difficulty * 0.5;
869
870        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
871            let x = row as f64 / samples_per_task as f64;
872            let y = col as f64 / feature_dim as f64;
873
874            // Increasingly complex patterns
875            let base_pattern = (x * pattern_complexity * std::f64::consts::PI).sin()
876                * (y * pattern_complexity * std::f64::consts::PI).cos();
877
878            let pattern_value = 0.5 + 0.3 * base_pattern;
879            let noise = noise_level * (fastrand::f64() - 0.5);
880
881            (pattern_value + noise).max(0.0).min(1.0)
882        });
883
884        let labels = Array1::from_shape_fn(samples_per_task, |row| {
885            let features = data.row(row);
886
887            // Increasingly complex decision boundaries
888            let decision_value = features
889                .iter()
890                .enumerate()
891                .map(|(j, &x)| {
892                    let weight = 1.0 + (j as f64 * difficulty * 0.1).sin();
893                    x * weight
894                })
895                .sum::<f64>();
896
897            let threshold = feature_dim as f64 * 0.5 * (1.0 + difficulty * 0.1);
898            if decision_value > threshold {
899                1
900            } else {
901                0
902            }
903        });
904
905        let task = create_continual_task(
906            format!("difficulty_{}", i + 1),
907            TaskType::Classification { num_classes: 2 },
908            data,
909            labels,
910            0.8,
911        );
912
913        tasks.push(task);
914    }
915
916    tasks
917}