aprender-train 0.41.0

Training & Optimization library with autograd, LoRA, quantization, and model merging
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
//! Tests for distillation module.

use ndarray::{array, Array1, Array2};

use super::utils::{cross_entropy_loss, kl_divergence, l2_normalize, log_softmax, softmax};
use super::{AttentionTransfer, DistillationLoss, ProgressiveDistillation};

// =========================================================================
// Softmax Tests
// =========================================================================

#[test]
fn test_softmax_sums_to_one() {
    // A probability distribution must sum to one (up to fp32 rounding).
    let total: f32 = softmax(&array![1.0, 2.0, 3.0, 4.0]).sum();
    assert!((total - 1.0).abs() < 1e-5);
}

#[test]
fn test_softmax_all_positive() {
    // Extreme logits on both sides: the smallest probability may underflow to
    // zero, but no probability may ever go negative.
    let probs = softmax(&array![-100.0, 0.0, 100.0]);
    assert!(probs.iter().all(|&p| p >= 0.0));
}

#[test]
fn test_softmax_numerical_stability() {
    // exp(1000) overflows f32, so a naive softmax would produce inf/NaN here;
    // the result must stay finite and still sum to one.
    let logits = array![1000.0, 1001.0, 1002.0];
    let probs = softmax(&logits);
    let all_finite = probs.iter().all(|p| p.is_finite());
    assert!(all_finite);
    let total = probs.sum();
    assert!((total - 1.0).abs() < 1e-5);
}

#[test]
fn test_log_softmax_identity() {
    // exp(log_softmax(x)) must reproduce softmax(x) element by element.
    let logits = array![1.0, 2.0, 3.0];
    let expected = softmax(&logits);
    let recovered = log_softmax(&logits).mapv(f32::exp);

    for (want, got) in expected.iter().zip(recovered.iter()) {
        assert!((want - got).abs() < 1e-5);
    }
}

// =========================================================================
// KL Divergence Tests
// =========================================================================

#[test]
fn test_kl_divergence_zero_for_same() {
    // KL(p || p) is zero: feed the same logits to both sides.
    let logits = array![1.0, 2.0, 3.0];
    let kl = kl_divergence(&log_softmax(&logits), &softmax(&logits));
    assert!(kl.abs() < 1e-5);
}

/// KL divergence between two genuinely different distributions must be finite
/// and strictly positive.
///
/// A `>= 0.0` check alone would also accept an implementation that always
/// returns 0, so this test pins the strict inequality. Here `p` and `q` are
/// mirror images of each other, so they are far apart.
#[test]
fn test_kl_divergence_positive() {
    let p = softmax(&array![1.0, 2.0, 3.0]);
    let log_q = log_softmax(&array![3.0, 2.0, 1.0]);
    let kl = kl_divergence(&log_q, &p);
    assert!(kl.is_finite(), "KL divergence should be finite, got {kl}");
    assert!(
        kl > 0.0,
        "KL divergence between different distributions should be > 0, got {kl}"
    );
}

// =========================================================================
// DistillationLoss Tests
// =========================================================================

#[test]
fn test_distillation_loss_default() {
    // Defaults: temperature 4.0, alpha 0.7.
    let defaults = DistillationLoss::default();
    let (temperature, alpha) = (defaults.temperature, defaults.alpha);
    assert_eq!(temperature, 4.0);
    assert_eq!(alpha, 0.7);
}

/// The blended distillation loss must be finite and strictly positive.
///
/// `forward_single` weights the hard-label cross-entropy by `(1 - alpha)`. At
/// `alpha = 0.5` that term is strictly positive whenever the target's
/// probability is below 1, and the soft (KL) term is non-negative. A `>= 0.0`
/// check would also accept a degenerate zero loss.
#[test]
fn test_distillation_loss_positive() {
    let loss = DistillationLoss::new(4.0, 0.5);
    let student = array![1.0, 2.0, 3.0];
    let teacher = array![1.5, 2.5, 2.0];
    let l = loss.forward_single(&student, &teacher, 2);
    assert!(l.is_finite(), "distillation loss should be finite, got {l}");
    assert!(l > 0.0, "distillation loss should be strictly positive, got {l}");
}

#[test]
fn test_distillation_loss_zero_alpha() {
    // With alpha = 0 the teacher term carries no weight, so even a wildly
    // different teacher must leave just the hard-label cross-entropy.
    let student = array![1.0, 2.0, 3.0];
    let wild_teacher = array![100.0, 200.0, 300.0];
    let target = 2;

    let hard_only =
        DistillationLoss::new(4.0, 0.0).forward_single(&student, &wild_teacher, target);
    let reference_ce = cross_entropy_loss(&student, target);
    assert!((hard_only - reference_ce).abs() < 0.01);
}

#[test]
fn test_distillation_loss_high_temp() {
    // Identical student and teacher logits leave nothing to distill, so the
    // soft loss should be near zero at both a low and a high temperature.
    let logits = array![1.0, 2.0, 3.0];
    let student = logits.clone();
    let teacher = logits;

    let soft_losses = [1.0_f32, 10.0]
        .map(|temperature| DistillationLoss::new(temperature, 1.0).soft_loss(&student, &teacher));

    for soft in soft_losses {
        assert!(soft.abs() < 0.1);
    }
}

#[test]
fn test_distillation_loss_batch() {
    // Two-example batch: each row holds one example's logits over three classes.
    let two_by_three = |values: Vec<f32>| {
        Array2::from_shape_vec((2, 3), values).expect("operation should succeed")
    };
    let student = two_by_three(vec![1.0, 2.0, 3.0, 2.0, 1.0, 3.0]);
    let teacher = two_by_three(vec![1.5, 2.5, 2.5, 2.5, 1.5, 2.5]);
    let targets = vec![2, 0];

    let batch_loss = DistillationLoss::new(4.0, 0.5).forward(&student, &teacher, &targets);
    assert!(batch_loss >= 0.0);
    assert!(batch_loss.is_finite());
}

// =========================================================================
// ProgressiveDistillation Tests
// =========================================================================

#[test]
fn test_progressive_default() {
    // The default config maps at least one layer and weights hidden loss at 1.0.
    let defaults = ProgressiveDistillation::default();
    let has_mapping = !defaults.layer_mapping.is_empty();
    assert!(has_mapping);
    assert_eq!(defaults.hidden_weight, 1.0);
}

#[test]
fn test_progressive_hidden_loss_zero_for_same() {
    // Identical hidden states on both mapped layers leave nothing to distill.
    let prog = ProgressiveDistillation::new(vec![(0, 0), (1, 1)]);
    let layers = vec![Array2::<f32>::ones((4, 768)); 2];

    let loss = prog.hidden_state_loss(&layers, &layers);
    assert!(loss.abs() < 1e-5);
}

#[test]
fn test_progressive_hidden_loss_positive_for_diff() {
    // Same shape, different values (zeros vs ones): the loss must be non-zero.
    let student_layers = [Array2::<f32>::zeros((4, 768))];
    let teacher_layers = [Array2::<f32>::ones((4, 768))];

    let loss = ProgressiveDistillation::new(vec![(0, 0)])
        .hidden_state_loss(&student_layers, &teacher_layers);
    assert!(loss > 0.0);
}

#[test]
fn test_progressive_with_weight() {
    // `with_weight` sets the hidden-state loss weight.
    let weighted = ProgressiveDistillation::new(vec![(0, 0)]).with_weight(0.5);
    assert_eq!(weighted.hidden_weight, 0.5);
}

#[test]
fn test_progressive_projection_layer_creation() {
    // A 512-wide student projected onto a 768-wide teacher needs a 512x768 matrix.
    let prog = ProgressiveDistillation::new(vec![(0, 0)]).with_projection(512, 768);
    let maybe_proj = prog.projection.as_ref();
    assert!(maybe_proj.is_some());

    let shape = maybe_proj.expect("operation should succeed").dim();
    assert_eq!(shape, (512, 768));
}

/// With a projection configured, mismatched student/teacher widths must be
/// projected and compared, not skipped.
///
/// Without a projection the mismatched pair is skipped and yields exactly 0.0
/// (see `test_progressive_no_projection_skips_mismatched`). That means
/// `loss >= 0.0` alone cannot tell "computed" from "skipped".
#[test]
fn test_progressive_hidden_loss_with_projection() {
    // Student has dim 512, teacher has dim 768
    let mut prog = ProgressiveDistillation::new(vec![(0, 0)]).with_projection(512, 768);

    let student = vec![Array2::<f32>::ones((4, 512))];
    let teacher = vec![Array2::<f32>::ones((4, 768))];

    // With the projection as initialised by `with_projection`, the loss must at
    // least be a valid number.
    let loss = prog.hidden_state_loss(&student, &teacher);
    assert!(
        loss.is_finite() && loss >= 0.0,
        "projected hidden-state loss should be finite and non-negative, got {loss}"
    );

    // Pin the projection to zeros. The projected student is then all zeros
    // while the teacher is all ones, so a loss that is actually computed must be
    // strictly positive. A skipped pair would give 0.0.
    prog.projection
        .as_mut()
        .expect("with_projection should create a projection matrix")
        .fill(0.0);
    let loss = prog.hidden_state_loss(&student, &teacher);
    assert!(
        loss > 0.0,
        "mismatched shapes with a projection must not be skipped (loss={loss})"
    );
}

/// An identity projection must leave the student's hidden states unchanged,
/// so identical student/teacher states give ~0 loss.
#[test]
fn test_progressive_projection_correct_transform() {
    // Square (768 -> 768) projection, so student and teacher widths already match.
    let mut prog = ProgressiveDistillation::new(vec![(0, 0)]).with_projection(768, 768);

    // Overwrite the projection with the identity. Use `expect` rather than
    // `if let`: if `with_projection` ever stopped creating the matrix, an
    // `if let` would silently skip this setup. The test would then still pass
    // through the unprojected equal-shape path.
    let proj = prog
        .projection
        .as_mut()
        .expect("with_projection(768, 768) should create a projection matrix");
    proj.fill(0.0);
    proj.diag_mut().fill(1.0);

    let hidden = Array2::<f32>::from_elem((4, 768), 1.0);
    let student = vec![hidden.clone()];
    let teacher = vec![hidden];

    // With identity projection, loss should be ~0
    let loss = prog.hidden_state_loss(&student, &teacher);
    assert!(loss.abs() < 1e-4, "Identity projection should give ~0 loss");
}

#[test]
fn test_progressive_no_projection_skips_mismatched() {
    // Without a projection, a 512-wide student cannot be compared with a
    // 768-wide teacher. The pair is expected to be skipped, contributing 0.
    let prog = ProgressiveDistillation::new(vec![(0, 0)]);
    let (student_width, teacher_width) = (512, 768);
    let student = vec![Array2::<f32>::ones((4, student_width))];
    let teacher = vec![Array2::<f32>::ones((4, teacher_width))];

    assert_eq!(
        prog.hidden_state_loss(&student, &teacher),
        0.0,
        "Should skip mismatched shapes without projection"
    );
}

// =========================================================================
// AttentionTransfer Tests
// =========================================================================

#[test]
fn test_attention_transfer_default() {
    // The default attention-transfer weight is 0.1.
    assert_eq!(AttentionTransfer::default().weight, 0.1);
}

#[test]
fn test_attention_transfer_zero_for_same() {
    // Matching attention maps should produce (numerically) zero transfer loss.
    let maps = vec![Array2::<f32>::ones((8, 8))];
    let loss = AttentionTransfer::new(1.0).loss(&maps, &maps);
    assert!(loss.abs() < 1e-5);
}

#[test]
fn test_attention_transfer_positive_for_diff() {
    // An all-zero student map compared against an all-one teacher map must be penalised.
    let student_maps = [Array2::<f32>::zeros((8, 8))];
    let teacher_maps = [Array2::<f32>::ones((8, 8))];

    let loss = AttentionTransfer::new(1.0).loss(&student_maps, &teacher_maps);
    assert!(loss > 0.0);
}

// =========================================================================
// L2 Normalize Tests
// =========================================================================

#[test]
fn test_l2_normalize_unit_norm() {
    // The only non-zero row is [3, 4] (norm 5). After normalisation, the
    // squared entries of the whole matrix must sum to one.
    let input = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0])
        .expect("operation should succeed");
    let squared_sum = l2_normalize(&input).mapv(|v| v * v).sum();
    assert!((squared_sum.sqrt() - 1.0).abs() < 1e-5);
}

/// Normalising the zero matrix must not divide 0 by 0.
///
/// The result should be the zero matrix itself, with no NaN and no other
/// value appearing.
#[test]
fn test_l2_normalize_zero() {
    let arr = Array2::<f32>::zeros((2, 2));
    let norm = l2_normalize(&arr);
    assert_eq!(norm.dim(), (2, 2), "l2_normalize must preserve the input shape");
    // `x == 0.0` also rules out NaN/inf, which the old finiteness-only check covered.
    assert!(
        norm.iter().all(|&x| x == 0.0),
        "l2_normalize of the zero matrix should return zeros, got {norm:?}"
    );
}

// =========================================================================
// Property-like Tests
// =========================================================================

/// The blended loss must be monotonic in `alpha`.
///
/// `forward_single` combines the terms as `alpha * soft + (1 - alpha) * ce`.
/// For fixed logits and target, the total is therefore an affine, and hence
/// monotonic, function of `alpha`. Which way it moves depends on whether the
/// soft or the hard term is larger, so the direction is taken from the
/// endpoints. Every intermediate step must then move the same way.
#[test]
fn test_distillation_loss_monotonic_in_alpha() {
    let student = array![1.0, 2.0, 3.0];
    let teacher = array![3.0, 2.0, 1.0]; // Very different

    let alphas = [0.0_f32, 0.25, 0.5, 0.75, 1.0];
    let losses: Vec<f32> = alphas
        .iter()
        .map(|&alpha| DistillationLoss::new(4.0, alpha).forward_single(&student, &teacher, 2))
        .collect();

    for (&alpha, &l) in alphas.iter().zip(losses.iter()) {
        assert!(
            l.is_finite() && l >= 0.0,
            "loss at alpha={alpha} should be finite and non-negative, got {l}"
        );
    }

    let increasing = losses[losses.len() - 1] >= losses[0];
    // Allow a little fp32 slack in case the two terms are nearly equal.
    let tolerance = 1e-5_f32;
    for (a, l) in alphas.windows(2).zip(losses.windows(2)) {
        let step = l[1] - l[0];
        let in_direction = if increasing { step >= -tolerance } else { step <= tolerance };
        assert!(
            in_direction,
            "loss not monotonic in alpha: alpha {} -> {} gave {} -> {}",
            a[0], a[1], l[0], l[1]
        );
    }
}

/// Temperature must actually soften the distributions and keep the soft loss
/// well defined.
///
/// The soft loss is KL-based, so it must be finite and non-negative at every
/// temperature. Dividing logits by a larger T flattens the softmax, so the
/// teacher's peak probability must drop when T goes from 1 to 10.
#[test]
fn test_temperature_scaling_effect() {
    let student = array![1.0, 2.0, 3.0];
    let teacher = array![0.5, 2.0, 3.5];

    let loss_t1 = DistillationLoss::new(1.0, 1.0).soft_loss(&student, &teacher);
    let loss_t10 = DistillationLoss::new(10.0, 1.0).soft_loss(&student, &teacher);

    assert!(
        loss_t1.is_finite() && loss_t1 >= 0.0,
        "soft loss at T=1 should be finite and non-negative, got {loss_t1}"
    );
    assert!(
        loss_t10.is_finite() && loss_t10 >= 0.0,
        "soft loss at T=10 should be finite and non-negative, got {loss_t10}"
    );

    // Highest probability the teacher assigns after scaling its logits by 1/T.
    let teacher_peak = |t: f32| {
        softmax(&teacher.mapv(|x| x / t))
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max)
    };
    let (peak_t1, peak_t10) = (teacher_peak(1.0), teacher_peak(10.0));
    assert!(
        peak_t10 < peak_t1,
        "higher temperature should soften the teacher: peak {peak_t1} at T=1 vs {peak_t10} at T=10"
    );
}

// =========================================================================
// FALSIFY-APR-DISTILL-TRAIN-003 / TRAIN-004 — hf_pipeline parity coverage
//
// Mirrors the falsifier tests already pinned for the canonical
// `crates/aprender-train/src/distill/loss.rs` (task #186 / PR around
// 2026-04-30) against this parallel `hf_pipeline::distillation::DistillationLoss`
// implementation. Without these, the two implementations could drift apart
// on the math invariants the contract requires.
//
// Five-Whys:
//   Why 1: §35 found `apr distill --stage train` is a stub.
//   Why 2: contract `apr-cli-distill-train-v1.yaml` was authored with 9
//          falsifiers, of which 003+004 are *purely-mathematical* invariants
//          testable against the existing softmax / KL helpers.
//   Why 3: canonical `distill::loss::DistillationLoss` has both tests; the
//          parallel `hf_pipeline::distillation::DistillationLoss` did NOT.
//          The drift was discovered when a previous /loop iteration
//          (post-PR #1431) tried to add this coverage and hit the
//          `--features hub` build break later fixed by #1432-#1434.
//   Why 4: per `feedback_coverage_contracts_coevolution`, every parallel
//          implementation that participates in a contract must have the
//          same falsifier coverage — silent drift would let one impl
//          regress without the other surfacing.
//   Why 5: pinning these now means a future real-training PR cannot
//          regress the math on either path without tripping a gate.
// =========================================================================

/// FALSIFY-APR-DISTILL-TRAIN-003: temperature scaling preserves softmax ranking.
///
/// Contract: for any logits and any T > 0,
/// argmax(softmax(logits / T)) == argmax(logits).
///
/// This is the hf_pipeline counterpart of the canonical
/// `distill::loss::tests::falsify_apr_distill_train_003_t_scaling_preserves_argmax`.
/// A failure means temperature scaling reordered the argmax. That would corrupt
/// the teacher's preference signal during distillation.
#[test]
fn falsify_apr_distill_train_003_t_scaling_preserves_argmax() {
    // Position of the largest value (the last one on ties). Panics on NaN
    // (unordered) or on empty input.
    fn argmax<'a>(values: impl Iterator<Item = &'a f32>) -> usize {
        values
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("logit ordering"))
            .expect("non-empty")
            .0
    }

    let logits: Array1<f32> = array![3.0, 1.0, 0.5, -1.0, 7.0, -3.0, 2.5, 0.0];
    let baseline_argmax = argmax(logits.iter());

    for t in [1.0_f32, 2.0, 3.0, 5.0, 10.0] {
        let scaled: Array1<f32> = logits.mapv(|x| x / t);
        let scaled_argmax = argmax(softmax(&scaled).iter());
        assert_eq!(
            baseline_argmax, scaled_argmax,
            "FALSIFIED APR-DISTILL-TRAIN-003 (hf_pipeline): argmax shifted from {baseline_argmax} to {scaled_argmax} at T={t}"
        );
    }
}

/// FALSIFY-APR-DISTILL-TRAIN-004: alpha=1.0 reduces to pure KD (soft loss only).
///
/// Contract: at alpha = 1.0 the `(1 - alpha) * ce_loss` term vanishes, so the
/// total loss equals the soft (KL) loss.
///
/// On the hf_pipeline path both `forward_single` and `soft_loss` are public.
/// At alpha = 1.0 the two must agree to within fp32 noise, measured as a
/// relative error when the reference is non-negligible.
#[test]
fn falsify_apr_distill_train_004_alpha_one_equals_pure_kd() {
    let target: usize = 3;
    let student: Array1<f32> = array![2.5, 0.7, -1.3, 4.0];
    let teacher: Array1<f32> = array![1.8, 1.1, -0.2, 3.5];

    let alpha_one = DistillationLoss::new(3.0_f32, 1.0);
    let total_at_alpha_one = alpha_one.forward_single(&student, &teacher, target);
    let pure_kd = alpha_one.soft_loss(&student, &teacher);

    // Relative error, falling back to absolute error when the reference is ~0.
    let rel_err = {
        let gap = (total_at_alpha_one - pure_kd).abs();
        let scale = pure_kd.abs();
        if scale > 1e-9 {
            gap / scale
        } else {
            gap
        }
    };

    assert!(
        rel_err < 1e-5,
        "FALSIFIED APR-DISTILL-TRAIN-004 (hf_pipeline): forward_single@alpha=1 ({total_at_alpha_one}) != soft_loss ({pure_kd}); rel_err={rel_err}"
    );
}

/// FALSIFY-APR-DISTILL-TRAIN-004 dual: alpha=0.0 reduces to pure CE.
///
/// This is the mirror image of the alpha = 1 check. With alpha = 0 the soft
/// (KL) term must carry no weight. What remains must be exactly
/// `cross_entropy_loss` of the student logits at the target label.
/// `forward_single` scales the KL term by alpha and the CE term by
/// (1 - alpha). Swapping those two coefficients makes this test fail.
#[test]
fn falsify_apr_distill_train_004_alpha_zero_equals_pure_ce() {
    let target: usize = 3;
    let student: Array1<f32> = array![2.5, 0.7, -1.3, 4.0];
    let teacher: Array1<f32> = array![1.8, 1.1, -0.2, 3.5];

    let total_at_alpha_zero =
        DistillationLoss::new(3.0, 0.0).forward_single(&student, &teacher, target);
    let pure_ce = cross_entropy_loss(&student, target);

    // Relative error, falling back to absolute error when the reference is ~0.
    let rel_err = {
        let gap = (total_at_alpha_zero - pure_ce).abs();
        let scale = pure_ce.abs();
        if scale > 1e-9 {
            gap / scale
        } else {
            gap
        }
    };

    assert!(
        rel_err < 1e-5,
        "FALSIFIED APR-DISTILL-TRAIN-004-dual (hf_pipeline): forward_single@alpha=0 ({total_at_alpha_zero}) != cross_entropy_loss ({pure_ce}); rel_err={rel_err}"
    );
}

/// FALSIFY-APR-DISTILL-TRAIN-003 support: `softmax` and `log_softmax` agree
/// under temperature scaling.
///
/// The TRAIN-003 argmax invariant is checked through `softmax`, while KL-based
/// losses consume `log_softmax`. The two hf_pipeline helpers must therefore
/// agree: `exp(log_softmax(x / T)) == softmax(x / T)` within fp32 noise, for
/// every temperature the TRAIN-003 falsifier exercises. If either helper's
/// max-subtraction trick regresses, this test fails on the hf_pipeline path.
///
/// This test does not compare against the canonical `distill::loss`
/// implementation. That module carries its own copies of the falsifiers.
#[test]
fn falsify_apr_distill_train_003_log_softmax_consistency() {
    let logits: Array1<f32> = array![3.0, 1.0, 0.5, 7.0];

    // T = 1 reproduces the unscaled check, since x / 1.0 == x exactly.
    for &t in &[1.0_f32, 2.0, 3.0, 5.0, 10.0] {
        let scaled: Array1<f32> = logits.mapv(|x| x / t);
        let probs = softmax(&scaled);
        let log_probs_exp: Array1<f32> = log_softmax(&scaled).mapv(f32::exp);

        for (i, (p, le)) in probs.iter().zip(log_probs_exp.iter()).enumerate() {
            assert!(
                (p - le).abs() < 1e-5,
                "softmax/log_softmax inconsistency at i={i}, T={t}: softmax={p}, exp(log_softmax)={le}"
            );
        }
    }

    // Also pin l2_normalize on a single-row matrix: [3, 4] has norm 5, so the
    // normalised row must have unit norm.
    let m: Array2<f32> = Array2::from_shape_vec((1, 2), vec![3.0, 4.0]).expect("shape (1, 2)");
    let normed = l2_normalize(&m);
    let norm_sq: f32 = normed.iter().map(|x| x * x).sum();
    assert!(
        (norm_sq - 1.0).abs() < 1e-5,
        "l2_normalize should produce row of unit norm (got norm_sq={norm_sq})"
    );
}