gam-sae 0.3.139

Sparse-autoencoder latent-manifold terms for the gam penalized-likelihood engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
use super::codes::solve_row_codes;
use super::scoring::{TileScorer, top_s_online};
use super::{SparseDictConfig, fit_sparse_dictionary};
use ndarray::{Array2, ArrayView2};

/// Build a deterministic planted two-atom mixture for the trainer tests.
///
/// Returns `(x, atoms)`:
/// * `atoms` is `k × p`. Atom `i` is eigenvector `i % p` of a fixed symmetric
///   integer matrix. For `k <= p` the atoms are therefore orthonormal. For
///   `k > p` they wrap around and repeat, so only `p` distinct directions exist.
/// * `x` is `n × p`. Row `r` is
///   `scale · atoms[r % k] + second_share · scale · atoms[(r % k + 1) % k]`
///   with `scale = 0.7 + 0.01 · ⌊r / k⌋`. That is a dominant atom plus a
///   `second_share`-weighted neighbour, and a pure single atom when
///   `second_share == 0`.
///
/// `k` and `p` are expected to be non-zero. Panics if the eigendecomposition fails.
fn planted(k: usize, p: usize, n: usize, second_share: f32) -> (Array2<f32>, Array2<f32>) {
    use gam_linalg::faer_ndarray::FaerEigh;

    // Fixed integer seed matrix, symmetrised so `eigh` yields an orthonormal basis.
    let seed = Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
        ((i * 7 + j * 3 + 1) % 11) as f64 - 5.0
    });
    let sym = &seed + &seed.t();
    let (_ev, basis) = sym.eigh(faer::Side::Lower).expect("orthonormal seed");

    // Dictionary row `atom` is basis column `atom % p` (wraps when k > p).
    let atoms = Array2::<f32>::from_shape_fn((k, p), |(atom, c)| basis[[c, atom % p]] as f32);

    let x = Array2::<f32>::from_shape_fn((n, p), |(row, c)| {
        let primary = row % k;
        let secondary = (primary + 1) % k;
        // Amplitude grows slowly with each full sweep over the atoms.
        let scale = 0.7 + 0.01 * (row / k) as f32;
        scale * atoms[[primary, c]] + second_share * scale * atoms[[secondary, c]]
    });
    (x, atoms)
}

/// In-sample PCA explained variance of the best rank-`rank` linear subspace.
///
/// The function centres `x` on its own column means, then forms the
/// unnormalised scatter matrix and eigendecomposes it. It returns the share of
/// the total eigenvalue mass carried by the `rank` largest eigenvalues. A block
/// with (numerically) zero total variance returns `1.0`. This is the linear
/// baseline for the in-sample trainer test.
fn pca_ev(x: ArrayView2<'_, f32>, rank: usize) -> f64 {
    use gam_linalg::faer_ndarray::FaerEigh;

    let (rows, dim) = x.dim();

    // Column means, accumulated in f64 one row at a time.
    let mut mu = vec![0.0f64; dim];
    for sample in x.outer_iter() {
        for (m, &v) in mu.iter_mut().zip(sample.iter()) {
            *m += v as f64;
        }
    }
    for m in mu.iter_mut() {
        *m /= rows as f64;
    }

    // Scatter matrix Σᵢ (xᵢ − μ)(xᵢ − μ)ᵀ; the 1/n factor cancels in the ratio below.
    let mut scatter = Array2::<f64>::zeros((dim, dim));
    for sample in x.outer_iter() {
        for a in 0..dim {
            let da = sample[a] as f64 - mu[a];
            for b in 0..dim {
                scatter[[a, b]] += da * (sample[b] as f64 - mu[b]);
            }
        }
    }

    let (evals, _) = scatter.eigh(faer::Side::Lower).expect("pca eig");
    let total: f64 = evals.iter().sum();
    // Rank eigenvalues by value instead of trusting the solver's output order.
    let mut spectrum: Vec<f64> = evals.to_vec();
    spectrum.sort_by(|lhs, rhs| rhs.partial_cmp(lhs).unwrap());
    let captured: f64 = spectrum.iter().take(rank).sum();
    if total <= 1.0e-24 { 1.0 } else { captured / total }
}

/// Held-out reconstruction EV of a frozen dictionary on a block it never trained on.
///
/// Each row of `x_test` is routed top-`s` against `decoder` (one atom per row)
/// with a [`TileScorer`] of tile width `tile`. Its codes are then solved on that
/// active set by `solve_row_codes` with ridge `code_ridge`. The row is rebuilt
/// as `Σⱼ codeⱼ · decoder[atomⱼ]`, with no offset term. The decoder is never
/// modified here; only the per-row codes are new. This is intended to mirror the
/// production held-out path (`ManifoldSAE.reconstruct`).
///
/// Returns `1 − RSS/TSS`, with the TSS centred on `x_test`'s own column means,
/// so memorising the training block does not help. If the TSS is (numerically)
/// zero, the result is `1.0` when the RSS is also zero and `0.0` otherwise.
fn held_out_ev(
    decoder: ArrayView2<'_, f32>,
    x_test: ArrayView2<'_, f32>,
    s: usize,
    tile: usize,
    code_ridge: f32,
) -> f64 {
    let (rows, dim) = x_test.dim();

    // Test-block column means (the TSS origin).
    let mut mu = vec![0.0f64; dim];
    for sample in x_test.outer_iter() {
        for (m, &v) in mu.iter_mut().zip(sample.iter()) {
            *m += v as f64;
        }
    }
    for m in mu.iter_mut() {
        *m /= rows as f64;
    }

    let scorer = TileScorer::new(s, tile);
    let mut rss = 0.0f64;
    let mut tss = 0.0f64;
    for sample in x_test.outer_iter() {
        let active = scorer.route_row(sample, decoder);
        let code = solve_row_codes(sample, decoder, &active, s, code_ridge);

        let mut recon = vec![0.0f64; dim];
        for (j, &atom) in code.indices.iter().enumerate() {
            let weight = code.codes[j] as f64;
            if weight == 0.0 {
                // A zero coefficient contributes nothing; skip reading the atom.
                continue;
            }
            for (r, &d) in recon.iter_mut().zip(decoder.row(atom as usize).iter()) {
                *r += weight * d as f64;
            }
        }

        for ((&v, &r), &m) in sample.iter().zip(recon.iter()).zip(mu.iter()) {
            let v = v as f64;
            rss += (v - r) * (v - r);
            tss += (v - m) * (v - m);
        }
    }

    match (tss <= 1.0e-24, rss <= 1.0e-24) {
        (true, true) => 1.0,
        (true, false) => 0.0,
        (false, _) => 1.0 - rss / tss,
    }
}

/// Held-out PCA explained variance: the linear baseline for the held-out test.
///
/// Both the principal basis and the centring offset come from `x_train` alone.
/// The basis is the top `min(rank, p)` eigenvectors of the train scatter matrix.
/// Each `x_test` row is centred on the TRAIN mean, projected onto that basis and
/// reconstructed, so no test data leaks into the model. The returned EV is
/// `1 − RSS/TSS`, with the TSS centred on `x_test`'s own column means. If the
/// TSS is (numerically) zero, the result is `1.0` when the RSS is also zero and
/// `0.0` otherwise.
fn pca_ev_held_out(x_train: ArrayView2<'_, f32>, x_test: ArrayView2<'_, f32>, rank: usize) -> f64 {
    use gam_linalg::faer_ndarray::FaerEigh;

    /// Column means of `block`, accumulated in f64 one row at a time.
    fn column_means(block: ArrayView2<'_, f32>) -> Vec<f64> {
        let mut mu = vec![0.0f64; block.ncols()];
        for sample in block.outer_iter() {
            for (m, &v) in mu.iter_mut().zip(sample.iter()) {
                *m += v as f64;
            }
        }
        let rows = block.nrows() as f64;
        for m in mu.iter_mut() {
            *m /= rows;
        }
        mu
    }

    let mu_train = column_means(x_train);
    let dim = mu_train.len();

    // Train scatter matrix → eigenvectors (the PCA basis candidates).
    let mut scatter = Array2::<f64>::zeros((dim, dim));
    for sample in x_train.outer_iter() {
        for a in 0..dim {
            let da = sample[a] as f64 - mu_train[a];
            for b in 0..dim {
                scatter[[a, b]] += da * (sample[b] as f64 - mu_train[b]);
            }
        }
    }
    let (evals, evecs) = scatter.eigh(faer::Side::Lower).expect("pca eig");

    // Columns of the `min(rank, dim)` largest eigenvalues, largest first.
    let mut order: Vec<usize> = (0..dim).collect();
    order.sort_by(|&a, &b| evals[b].partial_cmp(&evals[a]).unwrap());
    order.truncate(rank.min(dim));
    let basis: Vec<_> = order.iter().map(|&col| evecs.column(col)).collect();

    let mu_test = column_means(x_test);
    let mut rss = 0.0f64;
    let mut tss = 0.0f64;
    for sample in x_test.outer_iter() {
        // Centre on the TRAIN mean: that is the basis's origin.
        let centred: Vec<f64> = sample
            .iter()
            .zip(&mu_train)
            .map(|(&v, &m)| v as f64 - m)
            .collect();
        let mut recon = vec![0.0f64; dim];
        for axis in &basis {
            let coord = centred
                .iter()
                .zip(axis.iter())
                .fold(0.0f64, |acc, (&z, &e)| acc + z * e);
            for (r, &e) in recon.iter_mut().zip(axis.iter()) {
                *r += coord * e;
            }
        }
        for (((&v, &z), &r), &m) in sample.iter().zip(&centred).zip(&recon).zip(&mu_test) {
            rss += (z - r) * (z - r);
            // TSS is centred on the test mean: the variance the baseline must explain.
            let t = v as f64 - m;
            tss += t * t;
        }
    }

    match (tss <= 1.0e-24, rss <= 1.0e-24) {
        (true, true) => 1.0,
        (true, false) => 0.0,
        (false, _) => 1.0 - rss / tss,
    }
}

#[test]
fn online_top_s_recovers_planted_largest_scores() {
    // One row is scored against an identity dictionary, where atom `a` is unit
    // axis `a`. Each atom's score is therefore exactly the row's coordinate `a`.
    // Only three coordinates are non-zero (9, 5, 3), and every other atom scores
    // exactly 0. Those zero scores tie with each other, but the planted top-3
    // atoms by |xᵀd| are still unambiguous.
    let p = 50;
    let k = 50;
    let mut decoder = Array2::<f32>::zeros((k, p));
    for atom in 0..k {
        decoder[[atom, atom]] = 1.0;
    }
    let mut row = ndarray::Array1::<f32>::zeros(p);
    // Strongest atom is index 17 (score 9), then 4 (score 5), then 31 (score 3).
    row[17] = 9.0;
    row[4] = 5.0;
    row[31] = 3.0;
    let picked = top_s_online(row.view(), decoder.view(), 3, 8);
    let want_atoms = [17u32, 4u32, 31u32];
    let want_scores = [9.0f64, 5.0, 3.0];
    assert_eq!(picked.len(), 3);
    for (rank, &(atom, score)) in picked.iter().enumerate() {
        assert_eq!(
            atom, want_atoms[rank],
            "rank {rank}: expected atom {}, got atom {atom} (score {score})",
            want_atoms[rank]
        );
        // With an identity dictionary the dot product is exact, so the reported
        // score must equal the planted coordinate.
        assert!(
            (f64::from(score) - want_scores[rank]).abs() <= 1.0e-6,
            "rank {rank}: atom {atom} reported score {score}, expected {}",
            want_scores[rank]
        );
    }
}

#[test]
fn tile_scorer_matches_untiled_brute_force() {
    // A tiled top-4 router (tile width 7 over 37 atoms, so the last tile is ragged)
    // must pick the same atoms, in the same order, as a brute-force full score.
    // The brute force sorts by descending |score| and breaks ties by lower atom index.
    let (k, p) = (37usize, 5usize);
    let decoder = Array2::<f32>::from_shape_fn((k, p), |(atom, c)| {
        (((atom * 3 + c * 5 + 1) % 7) as f32 - 3.0) / 3.0
    });
    let row = ndarray::Array1::<f32>::from_shape_fn(p, |c| c as f32 - 2.0);

    // Reference: score every atom, then order by descending |score|, then by index.
    let mut brute: Vec<(u32, f32)> = decoder
        .outer_iter()
        .enumerate()
        .map(|(a, atom_row)| {
            let score = row
                .iter()
                .zip(atom_row.iter())
                .fold(0.0f32, |acc, (&xv, &dv)| acc + xv * dv);
            (a as u32, score)
        })
        .collect();
    brute.sort_by(|lhs, rhs| {
        rhs.1
            .abs()
            .partial_cmp(&lhs.1.abs())
            .unwrap()
            .then(lhs.0.cmp(&rhs.0))
    });

    let scorer = TileScorer::new(4, 7);
    let tiled = scorer.route_row(row.view(), decoder.view());
    assert_eq!(tiled.len(), 4);
    for (j, (got, want)) in tiled.iter().zip(brute.iter()).enumerate() {
        assert_eq!(got.0, want.0, "tiled top-{j} disagrees with brute force");
    }
}

#[test]
fn sparse_trainer_recovers_planted_dictionary_beats_pca_baseline() {
    // In-sample check on an 8-atom planted mixture in 12 dims (K <= p, so the
    // planted atoms are orthonormal). With two active atoms per row, the trainer
    // must reach EV > 0.95 and at least match the in-sample rank-K PCA baseline.
    let (k, p, n) = (8usize, 12usize, 480usize);
    let (x, _atoms) = planted(k, p, n, 0.2);
    let fit = fit_sparse_dictionary(
        x.view(),
        &SparseDictConfig {
            n_atoms: k,
            active: 2,
            minibatch: 128,
            max_epochs: 40,
            score_tile: 16,
            code_ridge: 1.0e-6,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-9,
        },
    )
    .expect("sparse dictionary fit");
    let ev = fit.explained_variance;
    let baseline = pca_ev(x.view(), k);

    assert!(ev > 0.95, "expected EV > 0.95, got {}", ev);
    assert!(
        ev + 1.0e-6 >= baseline,
        "sparse trainer EV {} must match-or-beat rank-{k} PCA baseline {}",
        ev,
        baseline
    );
}

#[test]
fn sparse_trainer_beats_rank_k_pca_on_held_out_reconstruction() {
    // #1026 MVP ACCEPTANCE (held-out): the trainer's route→sparse-codes→decoder-update
    // must reach HELD-OUT reconstruction EV that matches or beats a rank-K
    // linear/PCA baseline fitted on the SAME train block. Both are scored out of
    // sample, with no leakage.
    //
    // Note on what this data actually exercises:
    // - `planted` builds atom `i` as basis direction `i % p`. With K=64 > p=16 the
    //   dictionary is therefore NOT genuinely over-complete: there are only 16
    //   distinct orthonormal directions. Each row mixes two neighbouring directions.
    // - The "rank-K" PCA baseline is clamped to rank p=16. That is a full-rank
    //   projection, so its held-out EV is 1.0 up to roundoff.
    // Neither method can strictly beat the other here. In practice the second
    // assertion requires the s=2 sparse code to reconstruct unseen rows to within
    // ~1e-4 of EV 1.0.
    let (k, p, n) = (64usize, 16usize, 1600usize);
    let (x, _atoms) = planted(k, p, n, 0.35);

    // Deterministic 80/20 split: every 5th row is held out. Since gcd(5, K) = 1,
    // both blocks see every primary atom.
    let n_test = n / 5;
    let (test_rows, train_rows): (Vec<usize>, Vec<usize>) = (0..n).partition(|&i| i % 5 == 0);
    let x_train = x.select(ndarray::Axis(0), &train_rows);
    let x_test = x.select(ndarray::Axis(0), &test_rows);
    assert_eq!(x_test.nrows(), n_test);

    let s = 2usize;
    let tile = 16usize;
    let code_ridge = 1.0e-6f32;
    let config = SparseDictConfig {
        n_atoms: k,
        active: s,
        minibatch: 256,
        max_epochs: 60,
        score_tile: tile,
        code_ridge,
        decoder_ridge: 1.0e-6,
        tolerance: 1.0e-9,
    };
    // Fit the dictionary on TRAIN ONLY.
    let fit = fit_sparse_dictionary(x_train.view(), &config).expect("held-out trainer fit");

    // Held-out EV: frozen decoder, fresh test-row codes.
    let sparse_out = held_out_ev(fit.decoder.view(), x_test.view(), s, tile, code_ridge);
    // Linear baseline: PCA fitted on train, scored on test. With K > p the rank is
    // clamped to p, a full-rank projection. This is the best any linear autoencoder
    // can do here, and its EV is ~1.0.
    let pca_out = pca_ev_held_out(x_train.view(), x_test.view(), k);

    assert!(
        sparse_out > 0.9,
        "held-out sparse-dictionary EV {sparse_out} should explain the planted held-out block"
    );
    assert!(
        sparse_out + 1.0e-4 >= pca_out,
        "held-out sparse EV {sparse_out} must match-or-beat held-out rank-{k} PCA baseline {pca_out}"
    );
}

#[test]
fn fixed_width_sparse_storage_never_dense_and_reconstructs() {
    // Pure single-atom rows (second_share = 0) with one active atom per row.
    // The packed codes must be N×s (here N×1), never N×K. The EV recomputed from
    // `fit.reconstruct()` must agree with the EV the fit reports.
    let (k, p, n) = (6usize, 8usize, 240usize);
    let (x, _atoms) = planted(k, p, n, 0.0);
    let config = SparseDictConfig {
        n_atoms: k,
        active: 1,
        max_epochs: 30,
        score_tile: 4,
        ..SparseDictConfig::new(k)
    };
    let fit = fit_sparse_dictionary(x.view(), &config).expect("fit");

    // Storage is fixed-width N×s, NOT N×K.
    assert_eq!(fit.indices.dim(), (n, 1));
    assert_eq!(fit.codes.dim(), (n, 1));
    assert_eq!(fit.decoder.dim(), (k, p));

    // Recompute EV = 1 − RSS/TSS from the packed-code reconstruction.
    let recon = fit.reconstruct();
    let mut mu = vec![0.0f64; p];
    for sample in x.outer_iter() {
        for (m, &v) in mu.iter_mut().zip(sample.iter()) {
            *m += v as f64;
        }
    }
    for m in mu.iter_mut() {
        *m /= n as f64;
    }
    let (mut rss, mut tss) = (0.0f64, 0.0f64);
    for (sample, rebuilt) in x.outer_iter().zip(recon.outer_iter()) {
        for ((&v, &r), &m) in sample.iter().zip(rebuilt.iter()).zip(&mu) {
            let v = v as f64;
            let resid = v - r as f64;
            rss += resid * resid;
            let dev = v - m;
            tss += dev * dev;
        }
    }
    let recon_ev = 1.0 - rss / tss;
    assert!(
        (recon_ev - fit.explained_variance).abs() < 1.0e-4,
        "packed-code reconstruction EV {recon_ev} disagrees with reported {}",
        fit.explained_variance
    );
}

#[test]
fn route_minibatch_returns_a_valid_top_s() {
    // The batched-GEMM minibatch router must return a genuine top-`s` per row.
    // Every selected atom's score (recomputed exactly in f64) must be within
    // f32-GEMM rounding of the true `s`-th-largest |score| cutoff, and the
    // reported score must match the exact dot product. Where two atoms tie within
    // rounding, the batched and row-at-a-time paths may pick different members of
    // the tie. That is correct (they are interchangeable) and is exactly why the
    // fit is minibatch-invariant rather than bit-identical.
    //
    // The atoms are non-orthogonal unit vectors, but they are NOT all distinct.
    // An atom's row depends on `atom` only through `(atom * 5) % 13`, so atoms
    // congruent mod 13 (e.g. 0, 13, 26, 39) are identical and tie exactly. The
    // cutoff-based check below is what keeps this test valid under such ties.
    let (k, p, n) = (40usize, 11usize, 137usize);
    let mut decoder = Array2::<f32>::zeros((k, p));
    for atom in 0..k {
        for c in 0..p {
            decoder[[atom, c]] = (((atom * 5 + c * 3 + 1) % 13) as f32 - 6.0) / 6.0;
        }
    }
    // Unit-norm the decoder rows (the trainer always routes against unit atoms).
    for mut row in decoder.outer_iter_mut() {
        let nrm: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
        if nrm > 1.0e-12 {
            row.mapv_inplace(|v| v / nrm);
        }
    }
    let mut x = Array2::<f32>::zeros((n, p));
    for row in 0..n {
        for c in 0..p {
            x[[row, c]] = (((row * 7 + c * 2 + 3) % 17) as f32 - 8.0) / 4.0;
        }
    }
    let s = 4usize;
    let scorer = TileScorer::new(s, 7);
    let batched = scorer.route_minibatch(x.view(), decoder.view());
    assert_eq!(batched.len(), n);

    // Exact f64 |score| of one row against one atom.
    let exact_mag = |row: usize, atom: usize| -> f64 {
        let mut acc = 0.0f64;
        for c in 0..p {
            acc += x[[row, c]] as f64 * decoder[[atom, c]] as f64;
        }
        acc.abs()
    };
    const TOL: f64 = 1.0e-5;
    for (i, shortlist) in batched.iter().enumerate() {
        assert_eq!(shortlist.len(), s, "row {i}: shortlist must have width s");
        // Exact |score| of this row against every atom, computed once.
        let mags: Vec<f64> = (0..k).map(|a| exact_mag(i, a)).collect();

        // The shortlist's atoms are distinct.
        let mut seen = std::collections::HashSet::new();
        for &(atom, _) in shortlist {
            assert!(seen.insert(atom), "row {i}: atom {atom} selected twice");
        }
        // Reported scores match the exact dot product.
        for &(atom, score) in shortlist {
            let exact = mags[atom as usize];
            assert!(
                (score.abs() as f64 - exact).abs() <= TOL,
                "row {i}: reported |score| {} for atom {atom} != exact {}",
                score.abs(),
                exact
            );
        }
        // The true s-th-largest |score| cutoff, computed exactly.
        let mut all = mags.clone();
        all.sort_by(|a, b| b.partial_cmp(a).unwrap());
        let cutoff = all[s - 1];
        // Every selected atom must clear the cutoff up to rounding (a valid top-s).
        for &(atom, _) in shortlist {
            let exact = mags[atom as usize];
            assert!(
                exact + TOL >= cutoff,
                "row {i}: selected atom {atom} (|score| {}) is below the top-{s} cutoff {cutoff}",
                exact
            );
        }
        // The shortlist is sorted by descending |score|.
        for w in shortlist.windows(2) {
            assert!(
                w[0].1.abs() + (TOL as f32) >= w[1].1.abs(),
                "row {i}: shortlist not sorted by descending |score|"
            );
        }
    }
}

#[test]
fn fit_is_minibatch_size_invariant() {
    // The minibatch size bounds how many rows share one score block (peak
    // working set); it must not change the solution. The same planted data is
    // fitted one row at a time, 64 rows at a time, and all rows at once. All
    // three fits must report the same explained variance to f32-rounding
    // tolerance (1e-4).
    let (k, p, n) = (8usize, 12usize, 480usize);
    let (x, _atoms) = planted(k, p, n, 0.2);
    let base = SparseDictConfig {
        n_atoms: k,
        active: 2,
        minibatch: 1,
        max_epochs: 40,
        score_tile: 16,
        code_ridge: 1.0e-6,
        decoder_ridge: 1.0e-6,
        tolerance: 1.0e-9,
    };
    // Same config as `base` except for the minibatch size.
    let fit_with = |minibatch: usize, what: &str| {
        fit_sparse_dictionary(x.view(), &SparseDictConfig { minibatch, ..base }).expect(what)
    };

    // All fits run before any comparison.
    let fit_mb1 = fit_with(1, "minibatch=1 fit");
    let fit_mbn = fit_with(n, "minibatch=N fit");
    let fit_mb_mid = fit_with(64, "minibatch=64 fit");

    for (label, other) in [("N", &fit_mbn), ("64", &fit_mb_mid)] {
        assert!(
            (fit_mb1.explained_variance - other.explained_variance).abs() < 1.0e-4,
            "minibatch=1 EV {} vs minibatch={label} EV {} must agree",
            fit_mb1.explained_variance,
            other.explained_variance
        );
    }
}

#[test]
fn scales_to_large_k_without_dense_n_by_k() {
    // The dictionary width K=2000 far exceeds the 8 planted atoms. The fit must
    // still explain the low-rank signal (EV > 0.9) and keep its code storage at
    // N×s = N×1. A dense N×K code matrix would be 240 × 2000 floats. This test
    // checks the fixed-width shape only; it does not instrument allocations.
    let n = 240usize;
    let k = 2000usize;
    let (x, _atoms) = planted(8, 10, n, 0.1);
    let fit = fit_sparse_dictionary(
        x.view(),
        &SparseDictConfig {
            n_atoms: k,
            active: 1,
            max_epochs: 6,
            score_tile: 256,
            ..SparseDictConfig::new(k)
        },
    )
    .expect("large-K fit");

    assert_eq!(fit.indices.dim(), (n, 1));
    let ev = fit.explained_variance;
    assert!(
        ev > 0.9,
        "large-K trainer should still explain the low-rank signal; got {}",
        ev
    );
}

/// #1026: a large-K `fit_sparse_dictionary` must be bit-for-bit reproducible.
///
/// Each minibatch × K route block is sized above the device break-even, so
/// under the ambient (default `Auto`) residency mode a CUDA host is expected to
/// route on the GPU. This test does not assert which path ran. It fits twice and
/// requires the decoder, indices and codes to be identical to the bit.
///
/// The process-wide `set_gpu_mode` is deliberately left alone. It is
/// first-writer-wins, so pinning it here would affect every other test in the
/// binary. GPU-route == CPU-route equivalence is instead locked at the routing
/// primitive by `scoring_gpu::tests::device_route_minibatch_matches_cpu_top_s_online`.
/// That test runs under `GpuMode::Required`, asserts `ScoreBlockPath::Device`,
/// and checks the routed top-`s` support against the CPU `top_s_online` oracle.
/// Routing is the only mode-dependent step of the fit, so bit-identical routing
/// implies a mode-invariant alternating-minimisation trajectory.
#[test]
fn large_k_fit_routes_on_gpu_above_breakeven_and_is_reproducible() {
    // 512 rows × 4096 atoms = 2,097,152 scores per minibatch. That is above
    // DEVICE_SCORE_BLOCK_MIN_ELEMS (1 << 20), so a CUDA host takes the GPU route.
    // p = 48 stands in for a residual-stream width.
    let k = 4096usize;
    let (x, _atoms) = planted(8, 48, 1536, 0.1);
    let config = SparseDictConfig {
        n_atoms: k,
        active: 2,
        minibatch: 512,
        max_epochs: 4,
        score_tile: 1024,
        ..SparseDictConfig::new(k)
    };

    let first = fit_sparse_dictionary(x.view(), &config).expect("large-K fit");
    let second = fit_sparse_dictionary(x.view(), &config).expect("large-K fit (rerun)");

    assert_eq!(
        first.decoder, second.decoder,
        "[#1026] sparse-dict fit is non-deterministic across runs (GPU route must \
         be bit-reproducible)"
    );
    assert_eq!(first.indices, second.indices);
    assert_eq!(first.codes, second.codes);
    assert!(
        first.explained_variance > 0.9,
        "[#1026] large-K fit should explain the low-rank signal; got {}",
        first.explained_variance
    );
}