// kizzasi_model/curriculum.rs
1//! Curriculum Learning for kizzasi-model
2//!
3//! Implements curriculum learning strategies that control the order and
4//! difficulty of training samples presented to the model. This follows the
5//! insight from Bengio et al. (2009) that training on easier examples first
6//! and gradually introducing harder ones can improve convergence and
7//! generalization.
8//!
9//! # Strategies
10//!
11//! - **Competence-based pacing**: linearly increases a competence threshold
12//!   each epoch; only samples with difficulty <= competence are included.
13//! - **Self-Paced Learning (SPL)**: uses a soft weight function
14//!   `w = max(0, 1 - difficulty/lambda)` where lambda grows over time,
15//!   gradually admitting harder samples.
16//! - **Annealing**: linearly interpolates the difficulty gate from an easy
17//!   start threshold to a harder end threshold across epochs.
18//!
19//! # Example
20//!
21//! ```rust,ignore
22//! use kizzasi_model::curriculum::{CurriculumScheduler, CurriculumStrategy, CurriculumDataProvider};
23//! use kizzasi_model::training_loop::ArrayDataProvider;
24//! use scirs2_core::ndarray::{Array1, Array2};
25//!
26//! let features = Array2::<f32>::zeros((100, 4));
27//! let targets = Array1::<f32>::zeros(100);
28//! let data = ArrayDataProvider::new(features, targets);
29//!
30//! // Per-sample difficulty scores in [0, 1]
31//! let difficulties = Array1::from_vec(vec![0.1; 100]);
32//! let strategy = CurriculumStrategy::Competence { initial: 0.2, increment: 0.1 };
33//!
//! let mut provider = CurriculumDataProvider::new(data, difficulties, strategy)
//!     .expect("difficulties length matches sample count");
35//! provider.advance_epoch();
36//! let active = provider.active_indices();
37//! ```
38
39use crate::error::{ModelError, ModelResult};
40use crate::training_loop::{ArrayDataProvider, DataProvider};
41use scirs2_core::ndarray::{Array1, Array2};
42
43// ---------------------------------------------------------------------------
44// CurriculumStrategy
45// ---------------------------------------------------------------------------
46
/// Strategy for controlling which samples are presented during training.
#[derive(Debug, Clone)]
pub enum CurriculumStrategy {
    /// Competence-based pacing.
    ///
    /// The competence threshold starts at `initial` and increases by
    /// `increment` each time `step()` is called, clamped to [0, 1].
    /// A sample is included if its difficulty <= current competence.
    Competence {
        /// Starting competence level in [0, 1].
        initial: f32,
        /// Amount competence grows per epoch.
        increment: f32,
    },

    /// Self-Paced Learning (SPL).
    ///
    /// Uses a soft weight function: `w(d) = max(0, 1 - d / lambda)`.
    /// A sample is included when `w > 0.5`, i.e. `d < lambda / 2`.
    /// `lambda` grows by a multiplicative factor each epoch (default 1.2;
    /// see [`CurriculumScheduler::with_spl_growth`]), starting at the
    /// provided initial value.
    SelfPaced {
        /// Initial pace parameter controlling the difficulty boundary.
        /// Higher lambda admits more samples.
        lambda: f32,
    },

    /// Annealing strategy.
    ///
    /// Linearly interpolates the difficulty gate from `start_difficulty`
    /// to `end` over a fixed ramp of 100 epochs. The gate value at epoch
    /// `e` is:
    ///
    /// `gate = start + (end - start) * min(1, e / 100)`
    ///
    /// so after 100 calls to `step()` the gate sits fully at `end`.
    Annealing {
        /// Starting difficulty threshold (easy end).
        start_difficulty: f32,
        /// Ending difficulty threshold (hard end, typically 1.0).
        end: f32,
    },
}
89
90// ---------------------------------------------------------------------------
91// CurriculumScheduler
92// ---------------------------------------------------------------------------
93
/// Drives the curriculum pacing across training epochs.
///
/// Each call to [`step()`](CurriculumScheduler::step) advances the internal
/// epoch counter and recomputes the current competence / difficulty gate.
#[derive(Debug, Clone)]
pub struct CurriculumScheduler {
    /// The strategy whose parameters drive the pacing.
    strategy: CurriculumStrategy,
    /// Current competence level in [0, 1] — determines which samples pass.
    current_competence: f32,
    /// Number of times `step()` has been called.
    epoch: usize,
    /// For SPL: current lambda (grows multiplicatively each epoch).
    /// Unused (0.0) for the other strategies.
    spl_lambda: f32,
    /// For SPL: multiplicative growth factor for lambda each epoch
    /// (default 1.2; override via `with_spl_growth`).
    spl_growth: f32,
}
110
111impl CurriculumScheduler {
112    /// Create a new scheduler for the given strategy.
113    ///
114    /// The initial competence is derived from the strategy parameters:
115    /// - Competence: `initial`
116    /// - SelfPaced: derived from `lambda` → `lambda / 2`
117    /// - Annealing: `start_difficulty`
118    pub fn new(strategy: CurriculumStrategy) -> Self {
119        let (initial_competence, spl_lambda) = match &strategy {
120            CurriculumStrategy::Competence { initial, .. } => (*initial, 0.0),
121            CurriculumStrategy::SelfPaced { lambda } => {
122                // SPL: include sample if difficulty < lambda/2
123                // So effective competence = lambda/2 clamped to [0,1]
124                ((*lambda * 0.5).clamp(0.0, 1.0), *lambda)
125            }
126            CurriculumStrategy::Annealing {
127                start_difficulty, ..
128            } => (*start_difficulty, 0.0),
129        };
130
131        Self {
132            strategy,
133            current_competence: initial_competence.clamp(0.0, 1.0),
134            epoch: 0,
135            spl_lambda,
136            // Lambda grows by 20% each epoch by default — a reasonable pace
137            // that ensures convergence to including all samples.
138            spl_growth: 1.2,
139        }
140    }
141
142    /// Override the SPL growth factor (default 1.2).
143    ///
144    /// Only meaningful for [`CurriculumStrategy::SelfPaced`].
145    pub fn with_spl_growth(mut self, growth: f32) -> Self {
146        self.spl_growth = growth.max(1.0);
147        self
148    }
149
150    /// Advance one epoch and return the updated competence in [0, 1].
151    pub fn step(&mut self) -> f32 {
152        self.epoch += 1;
153
154        match &self.strategy {
155            CurriculumStrategy::Competence {
156                initial, increment, ..
157            } => {
158                // Linear ramp: competence = initial + epoch * increment
159                self.current_competence =
160                    (*initial + self.epoch as f32 * *increment).clamp(0.0, 1.0);
161            }
162            CurriculumStrategy::SelfPaced { .. } => {
163                // Grow lambda each epoch
164                self.spl_lambda *= self.spl_growth;
165                // Effective competence boundary = lambda / 2, clamped
166                self.current_competence = (self.spl_lambda * 0.5).clamp(0.0, 1.0);
167            }
168            CurriculumStrategy::Annealing {
169                start_difficulty,
170                end,
171            } => {
172                // We use a fixed ramp of 100 epochs for the interpolation.
173                // After 100 epochs the gate is fully at `end`.
174                let ramp_epochs = 100.0_f32;
175                let t = (self.epoch as f32 / ramp_epochs).clamp(0.0, 1.0);
176                self.current_competence =
177                    (*start_difficulty + (*end - *start_difficulty) * t).clamp(0.0, 1.0);
178            }
179        }
180
181        self.current_competence
182    }
183
184    /// Return the current competence level without advancing.
185    pub fn current_competence(&self) -> f32 {
186        self.current_competence
187    }
188
189    /// Return how many epochs have been stepped.
190    pub fn epoch(&self) -> usize {
191        self.epoch
192    }
193
194    /// Whether a sample with the given difficulty should be included at the
195    /// current competence level.
196    pub fn should_include(&self, difficulty: f32) -> bool {
197        match &self.strategy {
198            CurriculumStrategy::Competence { .. } | CurriculumStrategy::Annealing { .. } => {
199                difficulty <= self.current_competence
200            }
201            CurriculumStrategy::SelfPaced { .. } => {
202                // SPL weight: w = max(0, 1 - difficulty / lambda)
203                // Include if w > 0.5, i.e. difficulty < lambda / 2
204                if self.spl_lambda <= 0.0 {
205                    return false;
206                }
207                let w = (1.0 - difficulty / self.spl_lambda).max(0.0);
208                w > 0.5
209            }
210        }
211    }
212
213    /// Return the indices of samples whose difficulty passes the current gate.
214    pub fn filter_indices(&self, difficulties: &[f32]) -> Vec<usize> {
215        difficulties
216            .iter()
217            .enumerate()
218            .filter(|(_, &d)| self.should_include(d))
219            .map(|(i, _)| i)
220            .collect()
221    }
222
223    /// Compute SPL weight for a given difficulty.
224    ///
225    /// Returns a value in [0, 1] where 1 = easy/fully included and 0 = too hard.
226    /// For non-SPL strategies this returns 1.0 if included, 0.0 otherwise.
227    pub fn spl_weight(&self, difficulty: f32) -> f32 {
228        match &self.strategy {
229            CurriculumStrategy::SelfPaced { .. } => {
230                if self.spl_lambda <= 0.0 {
231                    return 0.0;
232                }
233                (1.0 - difficulty / self.spl_lambda).clamp(0.0, 1.0)
234            }
235            _ => {
236                if self.should_include(difficulty) {
237                    1.0
238                } else {
239                    0.0
240                }
241            }
242        }
243    }
244}
245
246// ---------------------------------------------------------------------------
247// CurriculumDataProvider
248// ---------------------------------------------------------------------------
249
/// A data provider that wraps [`ArrayDataProvider`] and filters samples
/// according to a [`CurriculumScheduler`].
///
/// Each sample has an associated difficulty score in [0, 1]. The scheduler
/// controls which samples are visible at each epoch, starting with easy
/// samples and progressively including harder ones.
pub struct CurriculumDataProvider {
    /// The underlying data.
    inner: ArrayDataProvider,
    /// Per-sample difficulty scores; length equals `inner.num_samples()`
    /// (validated in `new`).
    difficulties: Array1<f32>,
    /// The curriculum scheduler driving the pacing.
    scheduler: CurriculumScheduler,
    /// Cached active indices (recomputed on `advance_epoch`).
    cached_active: Vec<usize>,
}
266
267impl CurriculumDataProvider {
268    /// Create a new curriculum data provider.
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if `difficulties.len() != data.num_samples()`.
273    pub fn new(
274        data: ArrayDataProvider,
275        difficulties: Array1<f32>,
276        strategy: CurriculumStrategy,
277    ) -> ModelResult<Self> {
278        if difficulties.len() != data.num_samples() {
279            return Err(ModelError::dimension_mismatch(
280                "CurriculumDataProvider::new",
281                data.num_samples(),
282                difficulties.len(),
283            ));
284        }
285
286        let scheduler = CurriculumScheduler::new(strategy);
287
288        // Compute initial active indices.
289        let cached_active = scheduler.filter_indices(difficulties.as_slice().unwrap_or(&[]));
290
291        Ok(Self {
292            inner: data,
293            difficulties,
294            scheduler,
295            cached_active,
296        })
297    }
298
299    /// Advance one epoch: steps the scheduler and recomputes active indices.
300    pub fn advance_epoch(&mut self) {
301        self.scheduler.step();
302        self.recompute_active();
303    }
304
305    /// Recompute the cached active indices from the current scheduler state.
306    fn recompute_active(&mut self) {
307        let diff_slice = self.difficulties.as_slice().unwrap_or(&[]);
308        self.cached_active = self.scheduler.filter_indices(diff_slice);
309    }
310
311    /// Return the indices of currently active (included) samples.
312    pub fn active_indices(&self) -> Vec<usize> {
313        self.cached_active.clone()
314    }
315
316    /// Number of currently active samples.
317    pub fn active_count(&self) -> usize {
318        self.cached_active.len()
319    }
320
321    /// Total samples in the underlying dataset (not just active ones).
322    pub fn total_samples(&self) -> usize {
323        self.inner.num_samples()
324    }
325
326    /// Fraction of the dataset currently active.
327    pub fn active_fraction(&self) -> f32 {
328        if self.inner.num_samples() == 0 {
329            return 0.0;
330        }
331        self.cached_active.len() as f32 / self.inner.num_samples() as f32
332    }
333
334    /// Access the scheduler (read-only).
335    pub fn scheduler(&self) -> &CurriculumScheduler {
336        &self.scheduler
337    }
338
339    /// Access the difficulty scores.
340    pub fn difficulties(&self) -> &Array1<f32> {
341        &self.difficulties
342    }
343
344    /// Get SPL weights for all active samples.
345    ///
346    /// Returns a vector of `(sample_index, weight)` pairs for all currently
347    /// active samples.
348    pub fn active_weights(&self) -> Vec<(usize, f32)> {
349        self.cached_active
350            .iter()
351            .map(|&idx| {
352                let d = self.difficulties[idx];
353                (idx, self.scheduler.spl_weight(d))
354            })
355            .collect()
356    }
357}
358
359impl DataProvider for CurriculumDataProvider {
360    fn num_samples(&self) -> usize {
361        // Report the number of *active* samples so training loops
362        // naturally iterate over the curriculum-filtered subset.
363        self.cached_active.len()
364    }
365
366    fn num_features(&self) -> usize {
367        self.inner.num_features()
368    }
369
370    fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>) {
371        // `indices` are positions within the *active* set.
372        // Map them to the original dataset indices.
373        let mapped: Vec<usize> = indices
374            .iter()
375            .map(|&i| {
376                if i < self.cached_active.len() {
377                    self.cached_active[i]
378                } else {
379                    // Clamp to last active index to avoid panic.
380                    self.cached_active.last().copied().unwrap_or(0)
381                }
382            })
383            .collect();
384
385        self.inner.get_batch(&mapped)
386    }
387
388    fn shuffle_indices(&self, rng_seed: u64) -> Vec<usize> {
389        // Shuffle over the active set only.
390        let n = self.cached_active.len();
391        let mut indices: Vec<usize> = (0..n).collect();
392        let mut state = rng_seed.wrapping_add(1);
393        for i in (1..n).rev() {
394            state = state
395                .wrapping_mul(6_364_136_223_846_793_005)
396                .wrapping_add(1_442_695_040_888_963_407);
397            let j = (state >> 33) as usize % (i + 1);
398            indices.swap(i, j);
399        }
400        indices
401    }
402}
403
404// ---------------------------------------------------------------------------
405// Difficulty estimators
406// ---------------------------------------------------------------------------
407
408/// Estimate per-sample difficulty from loss values.
409///
410/// Given a vector of per-sample losses, normalises them to [0, 1] by
411/// mapping the minimum loss to 0 and the maximum to 1. Samples with
412/// higher loss are considered harder.
413///
414/// Returns an `Array1<f32>` of difficulty scores.
415pub fn estimate_difficulty_from_loss(losses: &[f32]) -> ModelResult<Array1<f32>> {
416    if losses.is_empty() {
417        return Err(ModelError::invalid_config(
418            "Cannot estimate difficulty from empty loss vector",
419        ));
420    }
421
422    let min_loss = losses.iter().copied().fold(f32::INFINITY, f32::min);
423    let max_loss = losses.iter().copied().fold(f32::NEG_INFINITY, f32::max);
424
425    let range = max_loss - min_loss;
426
427    let difficulties = if range.abs() < f32::EPSILON {
428        // All losses are the same — assign uniform difficulty 0.5.
429        Array1::from_elem(losses.len(), 0.5)
430    } else {
431        Array1::from_vec(
432            losses
433                .iter()
434                .map(|&l| ((l - min_loss) / range).clamp(0.0, 1.0))
435                .collect(),
436        )
437    };
438
439    Ok(difficulties)
440}
441
442/// Estimate difficulty based on feature variance.
443///
444/// Samples whose features have higher variance (further from the dataset
445/// mean) are considered harder. Normalised to [0, 1].
446pub fn estimate_difficulty_from_variance(features: &Array2<f32>) -> ModelResult<Array1<f32>> {
447    let n = features.nrows();
448    if n == 0 {
449        return Err(ModelError::invalid_config(
450            "Cannot estimate difficulty from empty feature matrix",
451        ));
452    }
453
454    let ncols = features.ncols();
455
456    // Compute per-sample variance of features.
457    let mut variances = Vec::with_capacity(n);
458    for row in features.rows() {
459        let mean: f32 = row.iter().sum::<f32>() / ncols.max(1) as f32;
460        let var: f32 =
461            row.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / ncols.max(1) as f32;
462        variances.push(var);
463    }
464
465    estimate_difficulty_from_loss(&variances)
466}
467
468// ---------------------------------------------------------------------------
469// Tests
470// ---------------------------------------------------------------------------
471
472#[cfg(test)]
473mod tests {
474    use super::*;
475    use scirs2_core::ndarray::{Array1, Array2};
476
477    /// Helper: create an ArrayDataProvider with n samples, nf features.
478    fn make_provider(n: usize, nf: usize) -> ArrayDataProvider {
479        let features = Array2::<f32>::zeros((n, nf));
480        let targets = Array1::<f32>::zeros(n);
481        ArrayDataProvider::new(features, targets)
482    }
483
    // 1. Competence-based pacing: gate grows linearly as
    //    initial + epoch * increment, clamped to [0, 1].
    #[test]
    fn test_curriculum_competence_pacing() {
        let strategy = CurriculumStrategy::Competence {
            initial: 0.2,
            increment: 0.1,
        };
        let mut sched = CurriculumScheduler::new(strategy);

        // Before any step: competence = 0.2
        assert!((sched.current_competence() - 0.2).abs() < 1e-6);

        // After step 1: competence = 0.2 + 0.1 = 0.3
        let c1 = sched.step();
        assert!((c1 - 0.3).abs() < 1e-6);

        // After step 2: 0.4
        let c2 = sched.step();
        assert!((c2 - 0.4).abs() < 1e-6);

        // Filter test: difficulties [0.1, 0.35, 0.5, 0.9]
        let difficulties = [0.1, 0.35, 0.5, 0.9];
        let indices = sched.filter_indices(&difficulties);
        // Competence = 0.4 → should include 0.1 and 0.35
        assert_eq!(indices, vec![0, 1]);

        // Step a lot — should clamp at 1.0
        for _ in 0..20 {
            sched.step();
        }
        assert!((sched.current_competence() - 1.0).abs() < 1e-6);
        // Now all should be included (every difficulty <= 1.0)
        let all = sched.filter_indices(&difficulties);
        assert_eq!(all.len(), 4);
    }
519
    // 2. Self-Paced Learning: inclusion boundary is lambda / 2 and lambda
    //    grows by the default 1.2x factor on each step.
    #[test]
    fn test_curriculum_self_paced() {
        let strategy = CurriculumStrategy::SelfPaced { lambda: 0.4 };
        let mut sched = CurriculumScheduler::new(strategy);

        // Initial: lambda = 0.4, boundary = 0.2
        // Samples below 0.2 are included
        let difficulties = [0.1, 0.19, 0.3, 0.5, 0.8];

        let easy_first = sched.filter_indices(&difficulties);
        // 0.1 and 0.19 are < 0.2
        assert_eq!(easy_first, vec![0, 1]);

        // After several steps, lambda grows and more are included
        for _ in 0..5 {
            sched.step();
        }

        let more = sched.filter_indices(&difficulties);
        // Lambda has grown: 0.4 * 1.2^5 = ~0.995
        // Boundary = ~0.497, so 0.1, 0.19, 0.3 should now be included
        assert!(
            more.len() >= 3,
            "expected >= 3 included, got {}",
            more.len()
        );

        // After many steps, all should be included (lambda large enough)
        for _ in 0..20 {
            sched.step();
        }
        let all = sched.filter_indices(&difficulties);
        assert_eq!(all.len(), difficulties.len());
    }
555
556    // 3. Annealing
557    #[test]
558    fn test_curriculum_annealing() {
559        let strategy = CurriculumStrategy::Annealing {
560            start_difficulty: 0.1,
561            end: 1.0,
562        };
563        let mut sched = CurriculumScheduler::new(strategy);
564
565        // Initial gate = start_difficulty = 0.1
566        assert!((sched.current_competence() - 0.1).abs() < 1e-6);
567
568        // After 50 steps (halfway through 100 ramp): gate = 0.1 + 0.9*0.5 = 0.55
569        for _ in 0..50 {
570            sched.step();
571        }
572        assert!(
573            (sched.current_competence() - 0.55).abs() < 0.02,
574            "expected ~0.55, got {}",
575            sched.current_competence()
576        );
577
578        // After 100 steps: gate = 1.0
579        for _ in 0..50 {
580            sched.step();
581        }
582        assert!(
583            (sched.current_competence() - 1.0).abs() < 1e-6,
584            "expected 1.0, got {}",
585            sched.current_competence()
586        );
587
588        // All samples should be included at gate = 1.0
589        let diffs = [0.0, 0.3, 0.7, 1.0];
590        let included = sched.filter_indices(&diffs);
591        assert_eq!(included.len(), 4);
592    }
593
    // 4. CurriculumDataProvider: the active set grows monotonically as the
    //    competence gate rises by 0.2 per epoch.
    #[test]
    fn test_curriculum_provider_active_indices() {
        let provider = make_provider(10, 2);
        // Difficulties: 0.0, 0.1, 0.2, ..., 0.9
        let difficulties = Array1::from_vec((0..10).map(|i| i as f32 * 0.1).collect());

        let strategy = CurriculumStrategy::Competence {
            initial: 0.0,
            increment: 0.2,
        };

        let mut cp = CurriculumDataProvider::new(provider, difficulties, strategy)
            .expect("construction should succeed");

        // Initial competence = 0.0 → only difficulty == 0.0 passes (index 0)
        let a0 = cp.active_indices();
        assert_eq!(a0, vec![0]);

        // After 1 epoch: competence = 0.2 → indices 0,1,2 (difficulties 0.0, 0.1, 0.2)
        cp.advance_epoch();
        let a1 = cp.active_indices();
        assert_eq!(a1, vec![0, 1, 2]);

        // After 2 epochs: competence = 0.4 → indices 0..=4
        cp.advance_epoch();
        let a2 = cp.active_indices();
        assert_eq!(a2, vec![0, 1, 2, 3, 4]);

        // Set grows monotonically
        assert!(a2.len() > a1.len());
        assert!(a1.len() > a0.len());
    }
627
    // 5. CurriculumDataProvider implements DataProvider correctly:
    //    num_samples reports the active subset and get_batch maps
    //    active-set positions back to original sample rows.
    #[test]
    fn test_curriculum_provider_implements_data_provider() {
        let features = Array2::from_shape_vec(
            (5, 2),
            vec![
                1.0, 2.0, // easy
                3.0, 4.0, // easy
                5.0, 6.0, // medium
                7.0, 8.0, // hard
                9.0, 10.0, // hard
            ],
        )
        .expect("shape ok");
        let targets = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let data = ArrayDataProvider::new(features, targets);

        let difficulties = Array1::from_vec(vec![0.1, 0.15, 0.5, 0.8, 0.95]);
        let strategy = CurriculumStrategy::Competence {
            initial: 0.2,
            increment: 0.1,
        };

        let cp = CurriculumDataProvider::new(data, difficulties, strategy)
            .expect("construction should succeed");

        // At initial competence 0.2: indices 0,1 are active (difficulties 0.1, 0.15)
        assert_eq!(cp.num_samples(), 2);
        assert_eq!(cp.num_features(), 2);

        // get_batch with active-set indices [0, 1] → original indices [0, 1]
        let (feat, tgt) = cp.get_batch(&[0, 1]);
        assert_eq!(feat.shape(), &[2, 2]);
        assert_eq!(tgt.len(), 2);

        // Verify actual values match original samples 0 and 1
        assert!((feat[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((feat[[1, 0]] - 3.0).abs() < 1e-6);
        assert!((tgt[0] - 10.0).abs() < 1e-6);
        assert!((tgt[1] - 20.0).abs() < 1e-6);
    }
669}