// kizzasi_model/curriculum.rs
1//! Curriculum Learning for kizzasi-model
2//!
3//! Implements curriculum learning strategies that control the order and
4//! difficulty of training samples presented to the model. This follows the
5//! insight from Bengio et al. (2009) that training on easier examples first
6//! and gradually introducing harder ones can improve convergence and
7//! generalization.
8//!
9//! # Strategies
10//!
11//! - **Competence-based pacing**: linearly increases a competence threshold
12//! each epoch; only samples with difficulty <= competence are included.
13//! - **Self-Paced Learning (SPL)**: uses a soft weight function
14//! `w = max(0, 1 - difficulty/lambda)` where lambda grows over time,
15//! gradually admitting harder samples.
16//! - **Annealing**: linearly interpolates the difficulty gate from an easy
17//! start threshold to a harder end threshold across epochs.
18//!
19//! # Example
20//!
21//! ```rust,ignore
22//! use kizzasi_model::curriculum::{CurriculumScheduler, CurriculumStrategy, CurriculumDataProvider};
23//! use kizzasi_model::training_loop::ArrayDataProvider;
24//! use scirs2_core::ndarray::{Array1, Array2};
25//!
26//! let features = Array2::<f32>::zeros((100, 4));
27//! let targets = Array1::<f32>::zeros(100);
28//! let data = ArrayDataProvider::new(features, targets);
29//!
30//! // Per-sample difficulty scores in [0, 1]
31//! let difficulties = Array1::from_vec(vec![0.1; 100]);
32//! let strategy = CurriculumStrategy::Competence { initial: 0.2, increment: 0.1 };
33//!
//! let mut provider = CurriculumDataProvider::new(data, difficulties, strategy).unwrap();
35//! provider.advance_epoch();
36//! let active = provider.active_indices();
37//! ```
38
39use crate::error::{ModelError, ModelResult};
40use crate::training_loop::{ArrayDataProvider, DataProvider};
41use scirs2_core::ndarray::{Array1, Array2};
42
43// ---------------------------------------------------------------------------
44// CurriculumStrategy
45// ---------------------------------------------------------------------------
46
/// Strategy for controlling which samples are presented during training.
#[derive(Debug, Clone)]
pub enum CurriculumStrategy {
    /// Competence-based pacing.
    ///
    /// The competence threshold starts at `initial` and increases by
    /// `increment` each time `step()` is called, clamped to [0, 1].
    /// A sample is included if its difficulty <= current competence.
    Competence {
        /// Starting competence level in [0, 1].
        initial: f32,
        /// Amount competence grows per epoch.
        increment: f32,
    },

    /// Self-Paced Learning (SPL).
    ///
    /// Uses a soft weight function: `w(d) = max(0, 1 - d / lambda)`.
    /// A sample is included when `w > 0.5`, i.e. `d < lambda / 2`
    /// (strict inequality, unlike the `<=` gate of the other strategies).
    /// `lambda` grows by a multiplicative factor each epoch (default 1.2,
    /// see [`CurriculumScheduler::with_spl_growth`]), starting at the
    /// provided initial value.
    SelfPaced {
        /// Initial pace parameter controlling the difficulty boundary.
        /// Higher lambda admits more samples. Note: lambda grows
        /// multiplicatively, so a zero or negative start never admits
        /// any sample.
        lambda: f32,
    },

    /// Annealing strategy.
    ///
    /// Linearly interpolates the difficulty gate from `start_difficulty`
    /// to `end` over successive epochs. The gate value at epoch `e` is:
    ///
    /// `gate = start + (end - start) * min(1, e / 100)`
    ///
    /// i.e. the scheduler uses a fixed 100-epoch ramp; after 100 calls
    /// to `step()` the gate sits fully at `end`.
    Annealing {
        /// Starting difficulty threshold (easy end).
        start_difficulty: f32,
        /// Ending difficulty threshold (hard end, typically 1.0).
        end: f32,
    },
}
89
90// ---------------------------------------------------------------------------
91// CurriculumScheduler
92// ---------------------------------------------------------------------------
93
/// Drives the curriculum pacing across training epochs.
///
/// Each call to [`step()`](CurriculumScheduler::step) advances the internal
/// epoch counter and recomputes the current competence / difficulty gate.
#[derive(Debug, Clone)]
pub struct CurriculumScheduler {
    /// The pacing strategy in effect; fixed at construction time.
    strategy: CurriculumStrategy,
    /// Current competence level in [0, 1] — determines which samples pass.
    current_competence: f32,
    /// Number of times `step()` has been called.
    epoch: usize,
    /// For SPL: current lambda (grows each epoch; 0.0 for other strategies).
    spl_lambda: f32,
    /// For SPL: multiplicative growth factor for lambda each epoch
    /// (default 1.2; override via `with_spl_growth`).
    spl_growth: f32,
}
110
111impl CurriculumScheduler {
112 /// Create a new scheduler for the given strategy.
113 ///
114 /// The initial competence is derived from the strategy parameters:
115 /// - Competence: `initial`
116 /// - SelfPaced: derived from `lambda` → `lambda / 2`
117 /// - Annealing: `start_difficulty`
118 pub fn new(strategy: CurriculumStrategy) -> Self {
119 let (initial_competence, spl_lambda) = match &strategy {
120 CurriculumStrategy::Competence { initial, .. } => (*initial, 0.0),
121 CurriculumStrategy::SelfPaced { lambda } => {
122 // SPL: include sample if difficulty < lambda/2
123 // So effective competence = lambda/2 clamped to [0,1]
124 ((*lambda * 0.5).clamp(0.0, 1.0), *lambda)
125 }
126 CurriculumStrategy::Annealing {
127 start_difficulty, ..
128 } => (*start_difficulty, 0.0),
129 };
130
131 Self {
132 strategy,
133 current_competence: initial_competence.clamp(0.0, 1.0),
134 epoch: 0,
135 spl_lambda,
136 // Lambda grows by 20% each epoch by default — a reasonable pace
137 // that ensures convergence to including all samples.
138 spl_growth: 1.2,
139 }
140 }
141
142 /// Override the SPL growth factor (default 1.2).
143 ///
144 /// Only meaningful for [`CurriculumStrategy::SelfPaced`].
145 pub fn with_spl_growth(mut self, growth: f32) -> Self {
146 self.spl_growth = growth.max(1.0);
147 self
148 }
149
150 /// Advance one epoch and return the updated competence in [0, 1].
151 pub fn step(&mut self) -> f32 {
152 self.epoch += 1;
153
154 match &self.strategy {
155 CurriculumStrategy::Competence {
156 initial, increment, ..
157 } => {
158 // Linear ramp: competence = initial + epoch * increment
159 self.current_competence =
160 (*initial + self.epoch as f32 * *increment).clamp(0.0, 1.0);
161 }
162 CurriculumStrategy::SelfPaced { .. } => {
163 // Grow lambda each epoch
164 self.spl_lambda *= self.spl_growth;
165 // Effective competence boundary = lambda / 2, clamped
166 self.current_competence = (self.spl_lambda * 0.5).clamp(0.0, 1.0);
167 }
168 CurriculumStrategy::Annealing {
169 start_difficulty,
170 end,
171 } => {
172 // We use a fixed ramp of 100 epochs for the interpolation.
173 // After 100 epochs the gate is fully at `end`.
174 let ramp_epochs = 100.0_f32;
175 let t = (self.epoch as f32 / ramp_epochs).clamp(0.0, 1.0);
176 self.current_competence =
177 (*start_difficulty + (*end - *start_difficulty) * t).clamp(0.0, 1.0);
178 }
179 }
180
181 self.current_competence
182 }
183
184 /// Return the current competence level without advancing.
185 pub fn current_competence(&self) -> f32 {
186 self.current_competence
187 }
188
189 /// Return how many epochs have been stepped.
190 pub fn epoch(&self) -> usize {
191 self.epoch
192 }
193
194 /// Whether a sample with the given difficulty should be included at the
195 /// current competence level.
196 pub fn should_include(&self, difficulty: f32) -> bool {
197 match &self.strategy {
198 CurriculumStrategy::Competence { .. } | CurriculumStrategy::Annealing { .. } => {
199 difficulty <= self.current_competence
200 }
201 CurriculumStrategy::SelfPaced { .. } => {
202 // SPL weight: w = max(0, 1 - difficulty / lambda)
203 // Include if w > 0.5, i.e. difficulty < lambda / 2
204 if self.spl_lambda <= 0.0 {
205 return false;
206 }
207 let w = (1.0 - difficulty / self.spl_lambda).max(0.0);
208 w > 0.5
209 }
210 }
211 }
212
213 /// Return the indices of samples whose difficulty passes the current gate.
214 pub fn filter_indices(&self, difficulties: &[f32]) -> Vec<usize> {
215 difficulties
216 .iter()
217 .enumerate()
218 .filter(|(_, &d)| self.should_include(d))
219 .map(|(i, _)| i)
220 .collect()
221 }
222
223 /// Compute SPL weight for a given difficulty.
224 ///
225 /// Returns a value in [0, 1] where 1 = easy/fully included and 0 = too hard.
226 /// For non-SPL strategies this returns 1.0 if included, 0.0 otherwise.
227 pub fn spl_weight(&self, difficulty: f32) -> f32 {
228 match &self.strategy {
229 CurriculumStrategy::SelfPaced { .. } => {
230 if self.spl_lambda <= 0.0 {
231 return 0.0;
232 }
233 (1.0 - difficulty / self.spl_lambda).clamp(0.0, 1.0)
234 }
235 _ => {
236 if self.should_include(difficulty) {
237 1.0
238 } else {
239 0.0
240 }
241 }
242 }
243 }
244}
245
246// ---------------------------------------------------------------------------
247// CurriculumDataProvider
248// ---------------------------------------------------------------------------
249
/// A data provider that wraps [`ArrayDataProvider`] and filters samples
/// according to a [`CurriculumScheduler`].
///
/// Each sample has an associated difficulty score in [0, 1]. The scheduler
/// controls which samples are visible at each epoch, starting with easy
/// samples and progressively including harder ones.
pub struct CurriculumDataProvider {
    /// The underlying data.
    inner: ArrayDataProvider,
    /// Per-sample difficulty scores, length == `inner.num_samples()`
    /// (validated in `new`).
    difficulties: Array1<f32>,
    /// The curriculum scheduler driving the pacing.
    scheduler: CurriculumScheduler,
    /// Cached active indices (recomputed on `advance_epoch`); produced by an
    /// ascending index scan, so always sorted.
    cached_active: Vec<usize>,
}
266
267impl CurriculumDataProvider {
268 /// Create a new curriculum data provider.
269 ///
270 /// # Errors
271 ///
272 /// Returns an error if `difficulties.len() != data.num_samples()`.
273 pub fn new(
274 data: ArrayDataProvider,
275 difficulties: Array1<f32>,
276 strategy: CurriculumStrategy,
277 ) -> ModelResult<Self> {
278 if difficulties.len() != data.num_samples() {
279 return Err(ModelError::dimension_mismatch(
280 "CurriculumDataProvider::new",
281 data.num_samples(),
282 difficulties.len(),
283 ));
284 }
285
286 let scheduler = CurriculumScheduler::new(strategy);
287
288 // Compute initial active indices.
289 let cached_active = scheduler.filter_indices(difficulties.as_slice().unwrap_or(&[]));
290
291 Ok(Self {
292 inner: data,
293 difficulties,
294 scheduler,
295 cached_active,
296 })
297 }
298
299 /// Advance one epoch: steps the scheduler and recomputes active indices.
300 pub fn advance_epoch(&mut self) {
301 self.scheduler.step();
302 self.recompute_active();
303 }
304
305 /// Recompute the cached active indices from the current scheduler state.
306 fn recompute_active(&mut self) {
307 let diff_slice = self.difficulties.as_slice().unwrap_or(&[]);
308 self.cached_active = self.scheduler.filter_indices(diff_slice);
309 }
310
311 /// Return the indices of currently active (included) samples.
312 pub fn active_indices(&self) -> Vec<usize> {
313 self.cached_active.clone()
314 }
315
316 /// Number of currently active samples.
317 pub fn active_count(&self) -> usize {
318 self.cached_active.len()
319 }
320
321 /// Total samples in the underlying dataset (not just active ones).
322 pub fn total_samples(&self) -> usize {
323 self.inner.num_samples()
324 }
325
326 /// Fraction of the dataset currently active.
327 pub fn active_fraction(&self) -> f32 {
328 if self.inner.num_samples() == 0 {
329 return 0.0;
330 }
331 self.cached_active.len() as f32 / self.inner.num_samples() as f32
332 }
333
334 /// Access the scheduler (read-only).
335 pub fn scheduler(&self) -> &CurriculumScheduler {
336 &self.scheduler
337 }
338
339 /// Access the difficulty scores.
340 pub fn difficulties(&self) -> &Array1<f32> {
341 &self.difficulties
342 }
343
344 /// Get SPL weights for all active samples.
345 ///
346 /// Returns a vector of `(sample_index, weight)` pairs for all currently
347 /// active samples.
348 pub fn active_weights(&self) -> Vec<(usize, f32)> {
349 self.cached_active
350 .iter()
351 .map(|&idx| {
352 let d = self.difficulties[idx];
353 (idx, self.scheduler.spl_weight(d))
354 })
355 .collect()
356 }
357}
358
359impl DataProvider for CurriculumDataProvider {
360 fn num_samples(&self) -> usize {
361 // Report the number of *active* samples so training loops
362 // naturally iterate over the curriculum-filtered subset.
363 self.cached_active.len()
364 }
365
366 fn num_features(&self) -> usize {
367 self.inner.num_features()
368 }
369
370 fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>) {
371 // `indices` are positions within the *active* set.
372 // Map them to the original dataset indices.
373 let mapped: Vec<usize> = indices
374 .iter()
375 .map(|&i| {
376 if i < self.cached_active.len() {
377 self.cached_active[i]
378 } else {
379 // Clamp to last active index to avoid panic.
380 self.cached_active.last().copied().unwrap_or(0)
381 }
382 })
383 .collect();
384
385 self.inner.get_batch(&mapped)
386 }
387
388 fn shuffle_indices(&self, rng_seed: u64) -> Vec<usize> {
389 // Shuffle over the active set only.
390 let n = self.cached_active.len();
391 let mut indices: Vec<usize> = (0..n).collect();
392 let mut state = rng_seed.wrapping_add(1);
393 for i in (1..n).rev() {
394 state = state
395 .wrapping_mul(6_364_136_223_846_793_005)
396 .wrapping_add(1_442_695_040_888_963_407);
397 let j = (state >> 33) as usize % (i + 1);
398 indices.swap(i, j);
399 }
400 indices
401 }
402}
403
404// ---------------------------------------------------------------------------
405// Difficulty estimators
406// ---------------------------------------------------------------------------
407
408/// Estimate per-sample difficulty from loss values.
409///
410/// Given a vector of per-sample losses, normalises them to [0, 1] by
411/// mapping the minimum loss to 0 and the maximum to 1. Samples with
412/// higher loss are considered harder.
413///
414/// Returns an `Array1<f32>` of difficulty scores.
415pub fn estimate_difficulty_from_loss(losses: &[f32]) -> ModelResult<Array1<f32>> {
416 if losses.is_empty() {
417 return Err(ModelError::invalid_config(
418 "Cannot estimate difficulty from empty loss vector",
419 ));
420 }
421
422 let min_loss = losses.iter().copied().fold(f32::INFINITY, f32::min);
423 let max_loss = losses.iter().copied().fold(f32::NEG_INFINITY, f32::max);
424
425 let range = max_loss - min_loss;
426
427 let difficulties = if range.abs() < f32::EPSILON {
428 // All losses are the same — assign uniform difficulty 0.5.
429 Array1::from_elem(losses.len(), 0.5)
430 } else {
431 Array1::from_vec(
432 losses
433 .iter()
434 .map(|&l| ((l - min_loss) / range).clamp(0.0, 1.0))
435 .collect(),
436 )
437 };
438
439 Ok(difficulties)
440}
441
442/// Estimate difficulty based on feature variance.
443///
444/// Samples whose features have higher variance (further from the dataset
445/// mean) are considered harder. Normalised to [0, 1].
446pub fn estimate_difficulty_from_variance(features: &Array2<f32>) -> ModelResult<Array1<f32>> {
447 let n = features.nrows();
448 if n == 0 {
449 return Err(ModelError::invalid_config(
450 "Cannot estimate difficulty from empty feature matrix",
451 ));
452 }
453
454 let ncols = features.ncols();
455
456 // Compute per-sample variance of features.
457 let mut variances = Vec::with_capacity(n);
458 for row in features.rows() {
459 let mean: f32 = row.iter().sum::<f32>() / ncols.max(1) as f32;
460 let var: f32 =
461 row.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / ncols.max(1) as f32;
462 variances.push(var);
463 }
464
465 estimate_difficulty_from_loss(&variances)
466}
467
468// ---------------------------------------------------------------------------
469// Tests
470// ---------------------------------------------------------------------------
471
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Helper: create an ArrayDataProvider with n samples, nf features.
    /// Data is all zeros — the curriculum logic only reads difficulties.
    fn make_provider(n: usize, nf: usize) -> ArrayDataProvider {
        let features = Array2::<f32>::zeros((n, nf));
        let targets = Array1::<f32>::zeros(n);
        ArrayDataProvider::new(features, targets)
    }

    // 1. Competence-based pacing: linear ramp, <= gate, clamp at 1.0.
    #[test]
    fn test_curriculum_competence_pacing() {
        let strategy = CurriculumStrategy::Competence {
            initial: 0.2,
            increment: 0.1,
        };
        let mut sched = CurriculumScheduler::new(strategy);

        // Before any step: competence = 0.2
        assert!((sched.current_competence() - 0.2).abs() < 1e-6);

        // After step 1: competence = 0.2 + 0.1 = 0.3
        let c1 = sched.step();
        assert!((c1 - 0.3).abs() < 1e-6);

        // After step 2: 0.4
        let c2 = sched.step();
        assert!((c2 - 0.4).abs() < 1e-6);

        // Filter test: difficulties [0.1, 0.35, 0.5, 0.9]
        let difficulties = [0.1, 0.35, 0.5, 0.9];
        let indices = sched.filter_indices(&difficulties);
        // Competence = 0.4 → should include 0.1 and 0.35 (inclusive gate)
        assert_eq!(indices, vec![0, 1]);

        // Step a lot — should clamp at 1.0
        for _ in 0..20 {
            sched.step();
        }
        assert!((sched.current_competence() - 1.0).abs() < 1e-6);
        // Now all should be included
        let all = sched.filter_indices(&difficulties);
        assert_eq!(all.len(), 4);
    }

    // 2. Self-Paced Learning: strict < lambda/2 gate, multiplicative growth.
    #[test]
    fn test_curriculum_self_paced() {
        let strategy = CurriculumStrategy::SelfPaced { lambda: 0.4 };
        let mut sched = CurriculumScheduler::new(strategy);

        // Initial: lambda = 0.4, boundary = 0.2
        // Samples below 0.2 are included
        let difficulties = [0.1, 0.19, 0.3, 0.5, 0.8];

        let easy_first = sched.filter_indices(&difficulties);
        // 0.1 and 0.19 are < 0.2
        assert_eq!(easy_first, vec![0, 1]);

        // After several steps, lambda grows and more are included
        for _ in 0..5 {
            sched.step();
        }

        let more = sched.filter_indices(&difficulties);
        // Lambda has grown: 0.4 * 1.2^5 = ~0.995
        // Boundary = ~0.497, so 0.1, 0.19, 0.3 should now be included
        assert!(
            more.len() >= 3,
            "expected >= 3 included, got {}",
            more.len()
        );

        // After many steps, all should be included (lambda large enough)
        for _ in 0..20 {
            sched.step();
        }
        let all = sched.filter_indices(&difficulties);
        assert_eq!(all.len(), difficulties.len());
    }

    // 3. Annealing: linear interpolation over the fixed 100-epoch ramp.
    #[test]
    fn test_curriculum_annealing() {
        let strategy = CurriculumStrategy::Annealing {
            start_difficulty: 0.1,
            end: 1.0,
        };
        let mut sched = CurriculumScheduler::new(strategy);

        // Initial gate = start_difficulty = 0.1
        assert!((sched.current_competence() - 0.1).abs() < 1e-6);

        // After 50 steps (halfway through 100 ramp): gate = 0.1 + 0.9*0.5 = 0.55
        for _ in 0..50 {
            sched.step();
        }
        assert!(
            (sched.current_competence() - 0.55).abs() < 0.02,
            "expected ~0.55, got {}",
            sched.current_competence()
        );

        // After 100 steps: gate = 1.0
        for _ in 0..50 {
            sched.step();
        }
        assert!(
            (sched.current_competence() - 1.0).abs() < 1e-6,
            "expected 1.0, got {}",
            sched.current_competence()
        );

        // All samples should be included at gate = 1.0
        let diffs = [0.0, 0.3, 0.7, 1.0];
        let included = sched.filter_indices(&diffs);
        assert_eq!(included.len(), 4);
    }

    // 4. CurriculumDataProvider: active set grows monotonically with epochs.
    #[test]
    fn test_curriculum_provider_active_indices() {
        let provider = make_provider(10, 2);
        // Difficulties: 0.0, 0.1, 0.2, ..., 0.9
        let difficulties = Array1::from_vec((0..10).map(|i| i as f32 * 0.1).collect());

        let strategy = CurriculumStrategy::Competence {
            initial: 0.0,
            increment: 0.2,
        };

        let mut cp = CurriculumDataProvider::new(provider, difficulties, strategy)
            .expect("construction should succeed");

        // Initial competence = 0.0 → only difficulty == 0.0 passes (index 0)
        let a0 = cp.active_indices();
        assert_eq!(a0, vec![0]);

        // After 1 epoch: competence = 0.2 → indices 0,1,2 (difficulties 0.0, 0.1, 0.2)
        cp.advance_epoch();
        let a1 = cp.active_indices();
        assert_eq!(a1, vec![0, 1, 2]);

        // After 2 epochs: competence = 0.4 → indices 0..=4
        cp.advance_epoch();
        let a2 = cp.active_indices();
        assert_eq!(a2, vec![0, 1, 2, 3, 4]);

        // Set grows monotonically
        assert!(a2.len() > a1.len());
        assert!(a1.len() > a0.len());
    }

    // 5. CurriculumDataProvider implements DataProvider: num_samples reports
    //    the active subset, and get_batch maps active positions back to the
    //    original rows.
    #[test]
    fn test_curriculum_provider_implements_data_provider() {
        let features = Array2::from_shape_vec(
            (5, 2),
            vec![
                1.0, 2.0, // easy
                3.0, 4.0, // easy
                5.0, 6.0, // medium
                7.0, 8.0, // hard
                9.0, 10.0, // hard
            ],
        )
        .expect("shape ok");
        let targets = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let data = ArrayDataProvider::new(features, targets);

        let difficulties = Array1::from_vec(vec![0.1, 0.15, 0.5, 0.8, 0.95]);
        let strategy = CurriculumStrategy::Competence {
            initial: 0.2,
            increment: 0.1,
        };

        let cp = CurriculumDataProvider::new(data, difficulties, strategy)
            .expect("construction should succeed");

        // At initial competence 0.2: indices 0,1 are active (difficulties 0.1, 0.15)
        assert_eq!(cp.num_samples(), 2);
        assert_eq!(cp.num_features(), 2);

        // get_batch with active-set indices [0, 1] → original indices [0, 1]
        let (feat, tgt) = cp.get_batch(&[0, 1]);
        assert_eq!(feat.shape(), &[2, 2]);
        assert_eq!(tgt.len(), 2);

        // Verify actual values match original samples 0 and 1
        assert!((feat[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((feat[[1, 0]] - 3.0).abs() < 1e-6);
        assert!((tgt[0] - 10.0).abs() < 1e-6);
        assert!((tgt[1] - 20.0).abs() < 1e-6);
    }
}
669}