quantum_continual_learning/quantum_continual_learning.rs

//! Quantum Continual Learning Example
//!
//! This example demonstrates various continual learning strategies for quantum neural networks,
//! including Elastic Weight Consolidation, Experience Replay, Progressive Networks, and more.

use scirs2_core::ndarray::{Array1, Array2};
use quantrs2_ml::autodiff::optimizers::Adam;
use quantrs2_ml::prelude::*;
use quantrs2_ml::qnn::QNNLayerType;
use scirs2_core::random::prelude::*;
fn main() -> Result<()> {
    println!("=== Quantum Continual Learning Demo ===\n");

    // Step 1: Elastic Weight Consolidation (EWC)
    println!("1. Elastic Weight Consolidation (EWC)...");
    ewc_demo()?;

    // Step 2: Experience Replay
    println!("\n2. Experience Replay...");
    experience_replay_demo()?;

    // Step 3: Progressive Networks
    println!("\n3. Progressive Networks...");
    progressive_networks_demo()?;

    // Step 4: Learning without Forgetting (LwF)
    println!("\n4. Learning without Forgetting...");
    lwf_demo()?;

    // Step 5: Parameter Isolation
    println!("\n5. Parameter Isolation...");
    parameter_isolation_demo()?;

    // Step 6: Task sequence evaluation
    println!("\n6. Task Sequence Evaluation...");
    task_sequence_demo()?;

    // Step 7: Forgetting analysis
    println!("\n7. Forgetting Analysis...");
    forgetting_analysis_demo()?;

    println!("\n=== Quantum Continual Learning Demo Complete ===");

    Ok(())
}

/// Demonstrate Elastic Weight Consolidation
fn ewc_demo() -> Result<()> {
    // Create quantum model
    let layers = vec![
        QNNLayerType::EncodingLayer { num_features: 4 },
        QNNLayerType::VariationalLayer { num_params: 12 },
        QNNLayerType::EntanglementLayer {
            connectivity: "circular".to_string(),
        },
        QNNLayerType::VariationalLayer { num_params: 8 },
        QNNLayerType::MeasurementLayer {
            measurement_basis: "computational".to_string(),
        },
    ];

    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;

    // Create EWC strategy
    let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
        importance_weight: 1000.0,
        fisher_samples: 200,
    };

    let mut learner = QuantumContinualLearner::new(model, strategy);

    println!("   Created EWC continual learner:");
    println!("   - Importance weight: 1000.0");
    println!("   - Fisher samples: 200");

    // Generate task sequence
    let tasks = generate_task_sequence(3, 100, 4);

    println!("\n   Learning sequence of {} tasks...", tasks.len());

    let mut optimizer = Adam::new(0.001);
    let mut task_accuracies = Vec::new();

    for (i, task) in tasks.iter().enumerate() {
        println!("\n   Training on {}...", task.task_id);

        let metrics = learner.learn_task(task.clone(), &mut optimizer, 30)?;
        task_accuracies.push(metrics.current_accuracy);

        println!("   - Current accuracy: {:.3}", metrics.current_accuracy);

        // Evaluate forgetting on previous tasks
        if i > 0 {
            let all_accuracies = learner.evaluate_all_tasks()?;
            // Average over the previously learned tasks only; map iteration
            // order is arbitrary, so exclude the current task by id rather
            // than relying on `take(i)`.
            let avg_prev_accuracy = all_accuracies
                .iter()
                .filter(|(id, _)| **id != task.task_id)
                .map(|(_, &acc)| acc)
                .sum::<f64>()
                / i as f64;

            println!(
                "   - Average accuracy on previous tasks: {:.3}",
                avg_prev_accuracy
            );
        }
    }

    // Final evaluation
    let forgetting_metrics = learner.get_forgetting_metrics();
    println!("\n   EWC Results:");
    println!(
        "   - Average accuracy: {:.3}",
        forgetting_metrics.average_accuracy
    );
    println!(
        "   - Forgetting measure: {:.3}",
        forgetting_metrics.forgetting_measure
    );
    println!(
        "   - Continual learning score: {:.3}",
        forgetting_metrics.continual_learning_score
    );

    Ok(())
}
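
/// For reference, a minimal sketch of the EWC penalty idea; this is an
/// illustration under assumed names, not the library's internal
/// implementation. The quadratic term
/// `lambda / 2 * sum_i F_i * (theta_i - theta*_i)^2` anchors each parameter
/// to its value after the previous task, weighted by a diagonal Fisher
/// information estimate, which is what `importance_weight` and
/// `fisher_samples` above control.
#[allow(dead_code)]
fn ewc_penalty_sketch(
    params: &Array1<f64>,
    anchor_params: &Array1<f64>,
    fisher_diagonal: &Array1<f64>,
    importance_weight: f64,
) -> f64 {
    0.5 * importance_weight
        * params
            .iter()
            .zip(anchor_params.iter())
            .zip(fisher_diagonal.iter())
            .map(|((&theta, &anchor), &fisher)| fisher * (theta - anchor).powi(2))
            .sum::<f64>()
}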

/// Demonstrate Experience Replay
fn experience_replay_demo() -> Result<()> {
    let layers = vec![
        QNNLayerType::EncodingLayer { num_features: 4 },
        QNNLayerType::VariationalLayer { num_params: 8 },
        QNNLayerType::MeasurementLayer {
            measurement_basis: "computational".to_string(),
        },
    ];

    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;

    let strategy = ContinualLearningStrategy::ExperienceReplay {
        buffer_size: 500,
        replay_ratio: 0.3,
        memory_selection: MemorySelectionStrategy::Random,
    };

    let mut learner = QuantumContinualLearner::new(model, strategy);

    println!("   Created Experience Replay learner:");
    println!("   - Buffer size: 500");
    println!("   - Replay ratio: 30%");
    println!("   - Selection: Random");

    // Generate diverse tasks
    let tasks = generate_diverse_tasks(4, 80, 4);

    println!("\n   Learning {} diverse tasks...", tasks.len());

    let mut optimizer = Adam::new(0.002);

    for (i, task) in tasks.iter().enumerate() {
        println!("\n   Learning {}...", task.task_id);

        let metrics = learner.learn_task(task.clone(), &mut optimizer, 25)?;

        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);

        // Show memory buffer status
        println!("   - Memory buffer usage: replay experiences stored");

        if i > 0 {
            let all_accuracies = learner.evaluate_all_tasks()?;
            let retention_rate = all_accuracies.values().sum::<f64>() / all_accuracies.len() as f64;
            println!("   - Average retention: {:.3}", retention_rate);
        }
    }

    let final_metrics = learner.get_forgetting_metrics();
    println!("\n   Experience Replay Results:");
    println!(
        "   - Final average accuracy: {:.3}",
        final_metrics.average_accuracy
    );
    println!(
        "   - Forgetting reduction: {:.3}",
        1.0 - final_metrics.forgetting_measure
    );

    Ok(())
}
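
/// A minimal sketch of how a replay step could mix buffered and current
/// data under the `replay_ratio` above. Hypothetical helper with an assumed
/// `(input, label)` buffer layout; the library's buffer and selection
/// strategies are richer than this.
#[allow(dead_code)]
fn sample_replay_batch(
    buffer: &[(Array1<f64>, usize)],
    batch_size: usize,
    replay_ratio: f64,
) -> Vec<(Array1<f64>, usize)> {
    // With replay_ratio = 0.3, roughly 30% of each training batch is drawn
    // uniformly at random from memory (cf. MemorySelectionStrategy::Random).
    let num_replay = ((batch_size as f64) * replay_ratio).round() as usize;
    (0..num_replay.min(buffer.len()))
        .map(|_| buffer[fastrand::usize(..buffer.len())].clone())
        .collect()
}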

/// Demonstrate Progressive Networks
fn progressive_networks_demo() -> Result<()> {
    let layers = vec![
        QNNLayerType::EncodingLayer { num_features: 4 },
        QNNLayerType::VariationalLayer { num_params: 6 },
        QNNLayerType::MeasurementLayer {
            measurement_basis: "computational".to_string(),
        },
    ];

    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;

    let strategy = ContinualLearningStrategy::ProgressiveNetworks {
        lateral_connections: true,
        adaptation_layers: 2,
    };

    let mut learner = QuantumContinualLearner::new(model, strategy);

    println!("   Created Progressive Networks learner:");
    println!("   - Lateral connections: enabled");
    println!("   - Adaptation layers: 2");

    // Generate related tasks for transfer learning
    let tasks = generate_related_tasks(3, 60, 4);

    println!("\n   Learning {} related tasks...", tasks.len());

    let mut optimizer = Adam::new(0.001);
    let mut learning_speeds = Vec::new();

    for (i, task) in tasks.iter().enumerate() {
        println!("\n   Adding column for {}...", task.task_id);

        let start_time = std::time::Instant::now();
        let metrics = learner.learn_task(task.clone(), &mut optimizer, 20)?;
        let learning_time = start_time.elapsed();

        learning_speeds.push(learning_time);

        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);
        println!("   - Learning time: {:.2?}", learning_time);

        if i > 0 {
            let speedup = learning_speeds[0].as_secs_f64() / learning_time.as_secs_f64();
            println!("   - Learning speedup: {:.2}x", speedup);
        }
    }

    println!("\n   Progressive Networks Results:");
    println!("   - No catastrophic forgetting (by design)");
    println!("   - Lateral connections enable knowledge transfer");
    println!("   - Model capacity grows with new tasks");

    Ok(())
}
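
/// A minimal sketch of the progressive-networks idea: each task gets a
/// fresh column, and lateral connections feed frozen earlier-column
/// activations into the new column. Assumed shapes and names; the library's
/// actual column wiring may differ.
#[allow(dead_code)]
fn lateral_combine_sketch(
    new_column_activation: &Array1<f64>,
    prev_column_activations: &[Array1<f64>],
    lateral_weights: &[f64],
) -> Array1<f64> {
    // Output = own activation + weighted sum of frozen earlier columns, so
    // old tasks are never overwritten while new tasks can reuse their features.
    let mut combined = new_column_activation.clone();
    for (prev, &weight) in prev_column_activations.iter().zip(lateral_weights) {
        combined.zip_mut_with(prev, |c, &p| *c += weight * p);
    }
    combined
}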

/// Demonstrate Learning without Forgetting
fn lwf_demo() -> Result<()> {
    let layers = vec![
        QNNLayerType::EncodingLayer { num_features: 4 },
        QNNLayerType::VariationalLayer { num_params: 10 },
        QNNLayerType::EntanglementLayer {
            connectivity: "circular".to_string(),
        },
        QNNLayerType::MeasurementLayer {
            measurement_basis: "computational".to_string(),
        },
    ];

    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;

    let strategy = ContinualLearningStrategy::LearningWithoutForgetting {
        distillation_weight: 0.5,
        temperature: 3.0,
    };

    let mut learner = QuantumContinualLearner::new(model, strategy);

    println!("   Created Learning without Forgetting learner:");
    println!("   - Distillation weight: 0.5");
    println!("   - Temperature: 3.0");

    // Generate task sequence
    let tasks = generate_task_sequence(4, 70, 4);

    println!("\n   Learning with knowledge distillation...");

    let mut optimizer = Adam::new(0.001);
    let mut distillation_losses = Vec::new();

    for (i, task) in tasks.iter().enumerate() {
        println!("\n   Learning {}...", task.task_id);

        let metrics = learner.learn_task(task.clone(), &mut optimizer, 25)?;

        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);

        if i > 0 {
            // Simulate distillation loss tracking
            let distillation_loss = 0.1 + 0.3 * fastrand::f64();
            distillation_losses.push(distillation_loss);
            println!("   - Distillation loss: {:.3}", distillation_loss);

            let all_accuracies = learner.evaluate_all_tasks()?;
            let stability = all_accuracies
                .values()
                .map(|&acc| if acc > 0.6 { 1.0 } else { 0.0 })
                .sum::<f64>()
                / all_accuracies.len() as f64;

            println!("   - Knowledge retention: {:.1}%", stability * 100.0);
        }
    }

    println!("\n   LwF Results:");
    println!("   - Knowledge distillation preserves previous task performance");
    println!("   - Temperature scaling provides soft targets");
    println!("   - Balances plasticity and stability");

    Ok(())
}
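
/// A minimal sketch of the soft targets used in knowledge distillation:
/// the previous model's outputs are softened with a temperature T before
/// being matched by the new model. Illustrative only; how the library
/// composes its distillation loss is not shown here.
#[allow(dead_code)]
fn softened_targets_sketch(logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
    // softmax(logits / T): higher T spreads probability mass across classes,
    // exposing the relative scores ("dark knowledge") of the old model.
    let max = logits.fold(f64::NEG_INFINITY, |m, &z| m.max(z));
    let scaled = logits.mapv(|z| ((z - max) / temperature).exp());
    let norm = scaled.sum();
    scaled / norm
}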

/// Demonstrate Parameter Isolation
fn parameter_isolation_demo() -> Result<()> {
    let layers = vec![
        QNNLayerType::EncodingLayer { num_features: 4 },
        QNNLayerType::VariationalLayer { num_params: 16 },
        QNNLayerType::EntanglementLayer {
            connectivity: "full".to_string(),
        },
        QNNLayerType::MeasurementLayer {
            measurement_basis: "computational".to_string(),
        },
    ];

    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;

    let strategy = ContinualLearningStrategy::ParameterIsolation {
        allocation_strategy: ParameterAllocationStrategy::Masking,
        growth_threshold: 0.8,
    };

    let mut learner = QuantumContinualLearner::new(model, strategy);

    println!("   Created Parameter Isolation learner:");
    println!("   - Allocation strategy: Masking");
    println!("   - Growth threshold: 0.8");

    // Generate tasks with different requirements
    let tasks = generate_varying_complexity_tasks(3, 90, 4);

    println!("\n   Learning with parameter isolation...");

    let mut optimizer = Adam::new(0.001);
    let mut parameter_usage = Vec::new();

    for (i, task) in tasks.iter().enumerate() {
        println!("\n   Allocating parameters for {}...", task.task_id);

        let metrics = learner.learn_task(task.clone(), &mut optimizer, 30)?;

        // Simulate parameter usage tracking
        let used_params = 16 * (i + 1) / tasks.len(); // Gradually use more parameters
        parameter_usage.push(used_params);

        println!("   - Task accuracy: {:.3}", metrics.current_accuracy);
        println!("   - Parameters allocated: {}/{}", used_params, 16);
        println!(
            "   - Parameter efficiency: {:.1}%",
            used_params as f64 / 16.0 * 100.0
        );

        if i > 0 {
            let all_accuracies = learner.evaluate_all_tasks()?;
            // Count retained previous tasks by id; map iteration order is
            // arbitrary, so `take(i)` would not reliably select them.
            let interference = 1.0
                - all_accuracies
                    .iter()
                    .filter(|(id, _)| **id != task.task_id)
                    .map(|(_, &acc)| if acc > 0.7 { 1.0 } else { 0.0 })
                    .sum::<f64>()
                    / i as f64;

            println!("   - Task interference: {:.1}%", interference * 100.0);
        }
    }

    println!("\n   Parameter Isolation Results:");
    println!("   - Dedicated parameters prevent interference");
    println!("   - Scalable to many tasks");
    println!("   - Maintains task-specific knowledge");

    Ok(())
}
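
/// A minimal sketch of masking-based parameter isolation: each task owns a
/// binary mask over the shared parameter vector, and updates outside the
/// mask are zeroed. Hypothetical helper, not the library's allocator.
#[allow(dead_code)]
fn apply_task_mask_sketch(gradient: &mut Array1<f64>, task_mask: &Array1<bool>) {
    // Parameters not allocated to the current task receive no update, so
    // earlier tasks' dedicated parameters stay frozen and interference-free.
    for (g, &owned) in gradient.iter_mut().zip(task_mask.iter()) {
        if !owned {
            *g = 0.0;
        }
    }
}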

/// Demonstrate comprehensive task sequence evaluation
fn task_sequence_demo() -> Result<()> {
    println!("   Comprehensive continual learning evaluation...");

    // Compare different strategies
    let strategies = vec![
        (
            "EWC",
            ContinualLearningStrategy::ElasticWeightConsolidation {
                importance_weight: 500.0,
                fisher_samples: 100,
            },
        ),
        (
            "Experience Replay",
            ContinualLearningStrategy::ExperienceReplay {
                buffer_size: 300,
                replay_ratio: 0.2,
                memory_selection: MemorySelectionStrategy::Random,
            },
        ),
        (
            "Quantum Regularization",
            ContinualLearningStrategy::QuantumRegularization {
                entanglement_preservation: 0.1,
                parameter_drift_penalty: 0.5,
            },
        ),
    ];

    // Generate challenging task sequence
    let tasks = generate_challenging_sequence(5, 60, 4);

    println!(
        "\n   Comparing strategies on {} challenging tasks:",
        tasks.len()
    );

    for (strategy_name, strategy) in strategies {
        println!("\n   --- {} ---", strategy_name);

        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
        let mut learner = QuantumContinualLearner::new(model, strategy);
        let mut optimizer = Adam::new(0.001);

        for task in &tasks {
            learner.learn_task(task.clone(), &mut optimizer, 20)?;
        }

        let final_metrics = learner.get_forgetting_metrics();
        println!(
            "   - Average accuracy: {:.3}",
            final_metrics.average_accuracy
        );
        println!(
            "   - Forgetting measure: {:.3}",
            final_metrics.forgetting_measure
        );
        println!(
            "   - CL score: {:.3}",
            final_metrics.continual_learning_score
        );
    }

    Ok(())
}
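
/// A sketch of one common way a continual-learning score can be composed
/// (the library may weight its `continual_learning_score` differently):
/// reward retained average accuracy and penalize measured forgetting.
#[allow(dead_code)]
fn continual_learning_score_sketch(average_accuracy: f64, forgetting_measure: f64) -> f64 {
    (average_accuracy * (1.0 - forgetting_measure)).clamp(0.0, 1.0)
}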

/// Demonstrate forgetting analysis
fn forgetting_analysis_demo() -> Result<()> {
    println!("   Detailed forgetting analysis...");

    let layers = vec![
        QNNLayerType::EncodingLayer { num_features: 4 },
        QNNLayerType::VariationalLayer { num_params: 12 },
        QNNLayerType::MeasurementLayer {
            measurement_basis: "computational".to_string(),
        },
    ];

    let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;

    let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
        importance_weight: 1000.0,
        fisher_samples: 150,
    };

    let mut learner = QuantumContinualLearner::new(model, strategy);

    // Create tasks with increasing difficulty
    let tasks = generate_increasing_difficulty_tasks(4, 80, 4);

    println!("\n   Learning tasks with increasing difficulty...");

    let mut optimizer = Adam::new(0.001);
    let mut accuracy_matrix = Vec::new();

    for (i, task) in tasks.iter().enumerate() {
        println!(
            "\n   Learning {} (difficulty level {})...",
            task.task_id,
            i + 1
        );

        learner.learn_task(task.clone(), &mut optimizer, 25)?;

        // Evaluate on all tasks learned so far
        let all_accuracies = learner.evaluate_all_tasks()?;
        let mut current_row = Vec::new();

        for j in 0..=i {
            let task_id = &tasks[j].task_id;
            let accuracy = all_accuracies.get(task_id).unwrap_or(&0.0);
            current_row.push(*accuracy);
        }

        accuracy_matrix.push(current_row.clone());

        // Print current performance
        for (j, &acc) in current_row.iter().enumerate() {
            println!("   - Task {}: {:.3}", j + 1, acc);
        }
    }

    println!("\n   Forgetting Analysis Results:");

    // Compute backward transfer
    for i in 1..accuracy_matrix.len() {
        for j in 0..i {
            let current_acc = accuracy_matrix[i][j];
            let original_acc = accuracy_matrix[j][j];
            let forgetting = (original_acc - current_acc).max(0.0);

            if forgetting > 0.1 {
                println!(
                    "   - Significant forgetting detected for Task {} after learning Task {}: {:.3}",
                    j + 1,
                    i + 1,
                    forgetting
                );
            }
        }
    }

    // Compute average forgetting
    let mut total_forgetting = 0.0;
    let mut num_comparisons = 0;

    for i in 1..accuracy_matrix.len() {
        for j in 0..i {
            let current_acc = accuracy_matrix[i][j];
            let original_acc = accuracy_matrix[j][j];
            total_forgetting += (original_acc - current_acc).max(0.0);
            num_comparisons += 1;
        }
    }

    let avg_forgetting = if num_comparisons > 0 {
        total_forgetting / num_comparisons as f64
    } else {
        0.0
    };

    println!("   - Average forgetting: {:.3}", avg_forgetting);

    // Compute final average accuracy
    if let Some(final_row) = accuracy_matrix.last() {
        let final_avg = final_row.iter().sum::<f64>() / final_row.len() as f64;
        println!("   - Final average accuracy: {:.3}", final_avg);
        println!(
            "   - Continual learning effectiveness: {:.1}%",
            (1.0 - avg_forgetting) * 100.0
        );
    }

    Ok(())
}
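
/// The analysis above estimates forgetting as a(j, j) - a(i, j), comparing
/// each task's accuracy right after it was learned with its accuracy after
/// later tasks. A common alternative uses the best accuracy ever achieved
/// on the task; a minimal sketch over the same accuracy-matrix layout:
#[allow(dead_code)]
fn max_forgetting_for_task(accuracy_matrix: &[Vec<f64>], task_index: usize) -> f64 {
    let last_row = match accuracy_matrix.last() {
        Some(row) => row,
        None => return 0.0,
    };
    // Best accuracy on `task_index` over all earlier evaluation rounds.
    let best_past = accuracy_matrix[..accuracy_matrix.len() - 1]
        .iter()
        .filter_map(|row| row.get(task_index).copied())
        .fold(f64::NEG_INFINITY, f64::max);
    if !best_past.is_finite() {
        return 0.0; // task was not evaluated before the final round
    }
    (best_past - last_row.get(task_index).copied().unwrap_or(0.0)).max(0.0)
}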

/// Generate diverse tasks with different characteristics
fn generate_diverse_tasks(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        let task_type = match i % 3 {
            0 => "classification",
            1 => "pattern_recognition",
            _ => "feature_detection",
        };

        // Generate task-specific data with different distributions
        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            match i % 3 {
                0 => {
                    // Uniform noise around a task-specific center
                    let center = i as f64 * 0.2;
                    center + 0.2 * (fastrand::f64() - 0.5)
                }
                1 => {
                    // Sinusoidal pattern
                    let freq = (i + 1) as f64;
                    0.5 + 0.3 * (freq * row as f64 * 0.1 + col as f64 * 0.2).sin()
                }
                _ => {
                    // Random with task-specific bias
                    let bias = i as f64 * 0.1;
                    bias + fastrand::f64() * 0.6
                }
            }
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            let features_sum = data.row(row).sum();
            if features_sum > feature_dim as f64 * 0.5 {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("{}_{}", task_type, i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8,
        );

        tasks.push(task);
    }

    tasks
}

/// Generate related tasks for transfer learning
fn generate_related_tasks(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();
    let base_pattern = Array1::from_shape_fn(feature_dim, |i| (i as f64 * 0.3).sin());

    for i in 0..num_tasks {
        // Each task is a variation of the base pattern
        let variation_strength = 0.1 + i as f64 * 0.1;

        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            let base_value = base_pattern[col];
            let variation = variation_strength * (row as f64 * 0.05 + col as f64 * 0.1).cos();
            let noise = 0.05 * (fastrand::f64() - 0.5);
            (base_value + variation + noise).clamp(0.0, 1.0)
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            let correlation = data
                .row(row)
                .iter()
                .zip(base_pattern.iter())
                .map(|(&x, &y)| x * y)
                .sum::<f64>();
            if correlation > 0.5 {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("related_task_{}", i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8,
        );

        tasks.push(task);
    }

    tasks
}

/// Generate tasks with varying complexity
fn generate_varying_complexity_tasks(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        let complexity = (i + 1) as f64; // Increasing complexity

        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            // More complex decision boundaries for later tasks
            let x = row as f64 / samples_per_task as f64;
            let y = col as f64 / feature_dim as f64;

            let value = match i {
                // Simple linear boundary
                0 => {
                    if x > 0.5 {
                        1.0
                    } else {
                        0.0
                    }
                }
                // Circular boundary
                1 => {
                    if x * x + y * y > 0.25 {
                        1.0
                    } else {
                        0.0
                    }
                }
                // Sinusoidal boundary
                2 => {
                    if (x * 4.0).sin() * (y * 4.0).cos() > 0.0 {
                        1.0
                    } else {
                        0.0
                    }
                }
                // Very complex pattern
                _ => {
                    let pattern = (x * 8.0).sin() * (y * 8.0).cos() + (x * y * 16.0).sin();
                    if pattern > 0.0 {
                        1.0
                    } else {
                        0.0
                    }
                }
            };

            value + 0.1 * (fastrand::f64() - 0.5) // Add noise
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            // Complex labeling based on multiple features
            let features = data.row(row);
            let decision_value = features
                .iter()
                .enumerate()
                .map(|(j, &x)| x * (1.0 + j as f64 * complexity * 0.1))
                .sum::<f64>();

            if decision_value > feature_dim as f64 * 0.5 {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("complex_task_{}", i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8,
        );

        tasks.push(task);
    }

    tasks
}

/// Generate challenging task sequence
fn generate_challenging_sequence(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        // Alternating between different types of challenges
        let challenge_type = i % 4;

        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            match challenge_type {
                0 => {
                    // High-frequency patterns
                    let freq = 10.0 + i as f64 * 2.0;
                    0.5 + 0.4 * (freq * row as f64 * 0.01).sin()
                }
                1 => {
                    // Overlapping distributions
                    let center1 = 0.3 + i as f64 * 0.05;
                    let center2 = 0.7 - i as f64 * 0.05;
                    if row % 2 == 0 {
                        center1 + 0.15 * (fastrand::f64() - 0.5)
                    } else {
                        center2 + 0.15 * (fastrand::f64() - 0.5)
                    }
                }
                2 => {
                    // Non-linear patterns
                    let x = row as f64 / samples_per_task as f64;
                    let y = col as f64 / feature_dim as f64;
                    let pattern = (x * x - y * y + i as f64 * 0.1).tanh();
                    0.5 + 0.3 * pattern
                }
                _ => {
                    // Sparse patterns
                    if fastrand::f64() < 0.2 {
                        0.8 + 0.2 * fastrand::f64()
                    } else {
                        0.1 * fastrand::f64()
                    }
                }
            }
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            let features = data.row(row);
            match challenge_type {
                0 => {
                    if features.sum() > feature_dim as f64 * 0.5 {
                        1
                    } else {
                        0
                    }
                }
                1 => {
                    if features[0] > 0.5 {
                        1
                    } else {
                        0
                    }
                }
                2 => {
                    if features
                        .iter()
                        .enumerate()
                        .map(|(j, &x)| x * (j as f64 + 1.0))
                        .sum::<f64>()
                        > 2.0
                    {
                        1
                    } else {
                        0
                    }
                }
                _ => {
                    if features.iter().filter(|&&x| x > 0.5).count() > feature_dim / 2 {
                        1
                    } else {
                        0
                    }
                }
            }
        });

        let task = create_continual_task(
            format!("challenge_{}", i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8,
        );

        tasks.push(task);
    }

    tasks
}

/// Generate tasks with increasing difficulty
fn generate_increasing_difficulty_tasks(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        let difficulty = (i + 1) as f64;
        let noise_level = 0.05 + difficulty * 0.02;
        let pattern_complexity = 1.0 + difficulty * 0.5;

        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            let x = row as f64 / samples_per_task as f64;
            let y = col as f64 / feature_dim as f64;

            // Increasingly complex patterns
            let base_pattern = (x * pattern_complexity * std::f64::consts::PI).sin()
                * (y * pattern_complexity * std::f64::consts::PI).cos();

            let pattern_value = 0.5 + 0.3 * base_pattern;
            let noise = noise_level * (fastrand::f64() - 0.5);

            (pattern_value + noise).clamp(0.0, 1.0)
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            let features = data.row(row);

            // Increasingly complex decision boundaries
            let decision_value = features
                .iter()
                .enumerate()
                .map(|(j, &x)| {
                    let weight = 1.0 + (j as f64 * difficulty * 0.1).sin();
                    x * weight
                })
                .sum::<f64>();

            let threshold = feature_dim as f64 * 0.5 * (1.0 + difficulty * 0.1);
            if decision_value > threshold {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("difficulty_{}", i + 1),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8,
        );

        tasks.push(task);
    }

    tasks
}