mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
//! Training-loop orchestration ported from mlx-lm
//! [`tuner/trainer.py`](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tuner/trainer.py).
//!
//! ## v1 status — mechanics-only [`train`]
//!
//! The public [`train`] loop wires up the optimizer step, callbacks, eval,
//! and save hooks end-to-end, but does NOT yet compute REAL per-parameter
//! gradients — it dispatches `optimizer.apply_gradients(zeros_like(params),
//! params)` because mlxrs has no `nn::Module` trait yet to bind
//! `params → loss` for [`crate::transforms::value_and_grad`]. The full
//! autograd path arrives once the [`crate::lm::model::Model`] trait grows
//! parameter binding (tracked separately on the M3 roadmap).
//!
//! Callers must explicitly opt in via
//! [`TrainingArgs::acknowledge_no_real_gradients`] = `true` before invoking
//! [`train`] in v1. The flag exists so a future production caller cannot
//! accidentally run multi-hour training jobs against the stub thinking
//! they're getting actual parameter updates — the `Err` returned otherwise
//! points the caller at this v1 limitation. Mechanics-only validation
//! (callbacks fire, optimizer state advances, save hook runs at the right
//! cadence) is the use case this v1 enables; real model fine-tuning is not.
//!
//! ## Periodic-event cadence — OPTIMIZER STEPS (deviation from Python)
//!
//! [`TrainingArgs::iters`] counts MICROBATCH iterations (matching Python
//! `mlx-lm/tuner/trainer.py`), but [`TrainingArgs::steps_per_report`],
//! [`TrainingArgs::steps_per_eval`], and [`TrainingArgs::steps_per_save`]
//! count OPTIMIZER STEPS — they fire only after a complete
//! [`TrainingArgs::grad_accumulation_steps`] window. Python counts these
//! in microbatches, which makes a caller's report frequency silently
//! depend on `grad_accumulation_steps`. mlxrs decouples them so a caller
//! bumping `grad_accumulation_steps` doesn't accidentally inflate their
//! report / eval / save frequency. Total optimizer steps the loop will
//! execute is `iters / grad_accumulation_steps` (floored — any final
//! partial window is dropped).
//!
//! ## Surface
//!
//! - [`TrainingArgs`] — config (Python `@dataclass class TrainingArgs`,
//!   `trainer.py:41..=83`).
//! - [`default_loss`] — token-level masked cross-entropy (Python
//!   `default_loss`, `trainer.py:86..=99`).
//! - [`grad_checkpoint`] — wrap a layer-forward in `mlxrs::transforms::checkpoint`
//!   (Python `grad_checkpoint`, `trainer.py:25..=38`). The Python version
//!   monkey-patches `type(layer).__call__`; the Rust version returns a
//!   wrapped closure (Rust idiom — composition over mutation).
//! - [`iterate_batches`] — length-sorted, padded, optionally-shuffled batch
//!   iterator (Python `iterate_batches`, `trainer.py:102..=173`).
//! - [`evaluate`] — eval-loss helper (Python `evaluate`, `trainer.py:176..=215`).
//! - [`train`] — the main training loop (Python `train`,
//!   `trainer.py:218..=387`). v1 mechanics-only; see the status note above.
//! - [`TrainingCallback`] — progress-reporting hook trait (Python
//!   `TrainingCallback`, `mlx_lm/tuner/callbacks.py`).
//!
//! ## Scope cuts (deviations from Python — see issue #163)
//!
//! - **Distributed training** (`mx.distributed.AllReduce`, `Group.barrier`,
//!   `average_gradients`, `mx.distributed.all_sum`) — out of scope for v1;
//!   single-process training only. Callers running multi-node have to add
//!   per-step `all_sum`/`average_gradients` themselves via the not-yet-
//!   wrapped `mlxrs_sys::mlx_distributed_*` FFI.
//! - **Adapter checkpoint save / final save** (`mx.save_safetensors`) —
//!   delegated to the caller via the [`TrainingCallback::on_save`] hook
//!   (NOT auto-saved by [`train`]). Rust idiom: don't write to disk inside
//!   library code unless the caller explicitly asks; the [`crate::io`]
//!   module exposes the safetensors / GGUF load+save primitives the
//!   caller composes.
//! - **`mx.metal.is_available()` / `mx.set_wired_limit(...)` /
//!   `mx.get_cache_memory()` / `mx.clear_cache()` / `mx.get_peak_memory()`**
//!   — call sites are no-ops in v1 (mlxrs's memory module covers the same
//!   surface but it is not auto-tuned inside `train` — caller does it).
//! - **`mx.compile` + `partial(state=..., inputs=..., outputs=...)`** —
//!   the Python `@mx.compile`'d `step(...)` closure is NOT replicated; the
//!   Rust loop runs an uncompiled value+grad → optimizer step per
//!   iteration. The value+grad half is meant to come from
//!   `crate::transforms::value_and_grad`; in v1 it is stubbed with zero
//!   gradients (see the v1 status section above). The compile-graph
//!   optimization is opt-in and out of scope for the v1 training surface.
//! - **`mx.random.state` thread-state** — out of scope; callers choose their
//!   own seed and pass it through `iterate_batches`'s `shuffle_seed`.

use std::{collections::HashMap, marker::PhantomData, time::Instant};

use smol_str::format_smolstr;

use crate::{
  Array, Dtype, Result,
  error::{
    EmptyInputPayload, Error, InvariantViolationPayload, LengthMismatchPayload, MissingKeyPayload,
    OutOfRangePayload, RankMismatchPayload, ShapePairMismatchPayload,
  },
  lm::{
    cache::KvCache,
    load::Weights,
    model::Model,
    perplexity,
    tuner::{
      datasets::{Dataset, Example},
      optimizers::Optimizer,
    },
  },
  ops::{arithmetic, comparison, logical, reduction},
  transforms,
};

// ─────────────────────────── TrainingArgs ───────────────────────────

/// Training-loop configuration. Mirrors Python `tuner.TrainingArgs`
/// (`trainer.py:41..=83`).
///
/// Build one with [`TrainingArgs::new`] (equivalently
/// [`TrainingArgs::default()`]) and override individual fields with the
/// `with_*` builder methods. Fields are private; each has a same-named
/// getter.
#[derive(Debug, Clone)]
pub struct TrainingArgs {
  /// Minibatch size (Python `batch_size`, default `4`).
  batch_size: usize,
  /// Total training iterations, counted in MICROBATCHES (Python `iters`,
  /// default `100`). The optimizer runs `iters / grad_accumulation_steps`
  /// times (floored) — see [`Self::grad_accumulation_steps`].
  iters: usize,
  /// Number of validation batches per eval (Python `val_batches`, default
  /// `25`). `None` uses the entire validation set (Python `-1`).
  val_batches: Option<usize>,
  /// OPTIMIZER steps between training-loss reports (Python
  /// `steps_per_report`, default `10`). NOTE: counts OPTIMIZER steps, not
  /// microbatches like the Python ref — see the "Periodic-event cadence"
  /// section of the module-level doc-comment for the rationale.
  steps_per_report: usize,
  /// OPTIMIZER steps between validations (Python `steps_per_eval`,
  /// default `200`). Counts OPTIMIZER steps (see [`Self::steps_per_report`]
  /// for the deviation from the Python ref).
  steps_per_eval: usize,
  /// OPTIMIZER steps between checkpoint saves (Python `steps_per_save`,
  /// default `100`). Counts OPTIMIZER steps (see [`Self::steps_per_report`]
  /// for the deviation from the Python ref).
  steps_per_save: usize,
  /// Maximum per-example sequence length after padding/truncation (Python
  /// `max_seq_length`, default `2048`).
  max_seq_length: usize,
  /// Save/load path for the trained adapter weights (Python `adapter_file`,
  /// default `adapters.safetensors`).
  adapter_file: String,
  /// Enable gradient checkpointing on the first decoder layer (Python
  /// `grad_checkpoint`, default `false`). Caller wraps the layer via
  /// [`grad_checkpoint`] before training; this flag is informational
  /// (training loop does not auto-wrap).
  grad_checkpoint: bool,
  /// Number of micro-batches accumulated before an optimizer step (Python
  /// `grad_accumulation_steps`, default `1`). The training loop
  /// accumulates the SUM of per-microbatch gradients across one window,
  /// divides by this count (the MEAN), then dispatches to
  /// [`crate::lm::tuner::optimizers::Optimizer::apply_gradients`]. The
  /// final partial window at the end of [`Self::iters`] is DROPPED — no
  /// optimizer call fires for it — so the total optimizer step count is
  /// `iters / grad_accumulation_steps` (floored).
  grad_accumulation_steps: usize,
  /// Cache-clear threshold in bytes (Python `clear_cache_threshold`,
  /// default `0` = disabled). v1 is a no-op (memory management out of
  /// scope), kept for API parity.
  clear_cache_threshold: usize,
  /// Caller-side acknowledgment that [`train`]'s v1 path runs the
  /// optimizer / callback / save mechanics but does NOT compute real
  /// per-parameter gradients (the `nn::Module` trait that binds
  /// `params → loss` for [`crate::transforms::value_and_grad`] is not yet
  /// ported). Default is `false` so a future production caller cannot
  /// accidentally run a long training job thinking the model is being
  /// updated. When `false`, [`train`] returns
  /// [`Error::InvariantViolation`] pointing at this
  /// field; set to `true` to opt into the mechanics-only training path.
  ///
  /// **No Python parity:** this field is mlxrs-specific (Python's
  /// `mx.value_and_grad` works against any callable that closes over
  /// `mx.array` parameters, so the Python trainer has nothing analogous
  /// to fence off).
  acknowledge_no_real_gradients: bool,
}

impl TrainingArgs {
  /// Construct a [`TrainingArgs`] with the Python-default values.
  ///
  /// Equivalent to [`TrainingArgs::default()`]. Use `.with_*` builder
  /// methods to override individual fields. Note that
  /// `acknowledge_no_real_gradients` starts out `false`.
  pub fn new() -> Self {
    Self {
      batch_size: 4,
      iters: 100,
      val_batches: Some(25),
      steps_per_report: 10,
      steps_per_eval: 200,
      steps_per_save: 100,
      max_seq_length: 2048,
      adapter_file: "adapters.safetensors".into(),
      grad_checkpoint: false,
      grad_accumulation_steps: 1,
      clear_cache_threshold: 0,
      // Caller MUST flip this to `true` to opt into the v1 mechanics-only
      // `train()` (see the field doc-comment + module-level v1 status note).
      acknowledge_no_real_gradients: false,
    }
  }

  /// Minibatch size.
  #[inline(always)]
  pub fn batch_size(&self) -> usize {
    self.batch_size
  }

  /// Total training iterations (MICROBATCHES, not optimizer steps).
  #[inline(always)]
  pub fn iters(&self) -> usize {
    self.iters
  }

  /// Number of validation batches per eval (`None` = entire val set).
  #[inline(always)]
  pub fn val_batches(&self) -> Option<usize> {
    self.val_batches
  }

  /// OPTIMIZER steps between training-loss reports.
  #[inline(always)]
  pub fn steps_per_report(&self) -> usize {
    self.steps_per_report
  }

  /// OPTIMIZER steps between validations.
  #[inline(always)]
  pub fn steps_per_eval(&self) -> usize {
    self.steps_per_eval
  }

  /// OPTIMIZER steps between checkpoint saves.
  #[inline(always)]
  pub fn steps_per_save(&self) -> usize {
    self.steps_per_save
  }

  /// Maximum per-example sequence length after padding/truncation.
  #[inline(always)]
  pub fn max_seq_length(&self) -> usize {
    self.max_seq_length
  }

  /// Save/load path for the trained adapter weights.
  #[inline(always)]
  pub fn adapter_file(&self) -> &str {
    &self.adapter_file
  }

  /// Whether gradient checkpointing is enabled (informational flag).
  #[inline(always)]
  pub fn grad_checkpoint(&self) -> bool {
    self.grad_checkpoint
  }

  /// Number of micro-batches accumulated before an optimizer step.
  #[inline(always)]
  pub fn grad_accumulation_steps(&self) -> usize {
    self.grad_accumulation_steps
  }

  /// Cache-clear threshold in bytes (`0` = disabled).
  #[inline(always)]
  pub fn clear_cache_threshold(&self) -> usize {
    self.clear_cache_threshold
  }

  /// Whether the caller has acknowledged the v1 no-real-gradients limitation.
  #[inline(always)]
  pub fn acknowledge_no_real_gradients(&self) -> bool {
    self.acknowledge_no_real_gradients
  }

  /// Set `batch_size`. Returns `self` for chaining.
  #[must_use]
  pub fn with_batch_size(mut self, batch_size: usize) -> Self {
    self.batch_size = batch_size;
    self
  }

  /// Set `iters` (microbatch count). Returns `self` for chaining.
  #[must_use]
  pub fn with_iters(mut self, iters: usize) -> Self {
    self.iters = iters;
    self
  }

  /// Set `val_batches`. Returns `self` for chaining.
  #[must_use]
  pub fn with_val_batches(mut self, val_batches: Option<usize>) -> Self {
    self.val_batches = val_batches;
    self
  }

  /// Set `steps_per_report`. Returns `self` for chaining.
  #[must_use]
  pub fn with_steps_per_report(mut self, steps_per_report: usize) -> Self {
    self.steps_per_report = steps_per_report;
    self
  }

  /// Set `steps_per_eval`. Returns `self` for chaining.
  #[must_use]
  pub fn with_steps_per_eval(mut self, steps_per_eval: usize) -> Self {
    self.steps_per_eval = steps_per_eval;
    self
  }

  /// Set `steps_per_save`. Returns `self` for chaining.
  #[must_use]
  pub fn with_steps_per_save(mut self, steps_per_save: usize) -> Self {
    self.steps_per_save = steps_per_save;
    self
  }

  /// Set `max_seq_length`. Returns `self` for chaining.
  #[must_use]
  pub fn with_max_seq_length(mut self, max_seq_length: usize) -> Self {
    self.max_seq_length = max_seq_length;
    self
  }

  /// Set `adapter_file`. Returns `self` for chaining.
  #[must_use]
  pub fn with_adapter_file(mut self, adapter_file: impl Into<String>) -> Self {
    self.adapter_file = adapter_file.into();
    self
  }

  /// Set `grad_checkpoint`. Returns `self` for chaining.
  #[must_use]
  pub fn with_grad_checkpoint(mut self, grad_checkpoint: bool) -> Self {
    self.grad_checkpoint = grad_checkpoint;
    self
  }

  /// Set `grad_accumulation_steps`. Returns `self` for chaining.
  #[must_use]
  pub fn with_grad_accumulation_steps(mut self, grad_accumulation_steps: usize) -> Self {
    self.grad_accumulation_steps = grad_accumulation_steps;
    self
  }

  /// Set `clear_cache_threshold`. Returns `self` for chaining.
  #[must_use]
  pub fn with_clear_cache_threshold(mut self, clear_cache_threshold: usize) -> Self {
    self.clear_cache_threshold = clear_cache_threshold;
    self
  }

  /// Set `acknowledge_no_real_gradients`. Returns `self` for chaining.
  #[must_use]
  pub fn with_acknowledge_no_real_gradients(mut self, acknowledge_no_real_gradients: bool) -> Self {
    self.acknowledge_no_real_gradients = acknowledge_no_real_gradients;
    self
  }
}

impl Default for TrainingArgs {
  /// Same as [`TrainingArgs::new`]: the Python-default configuration.
  fn default() -> Self {
    Self::new()
  }
}

// ─────────────────────────── default_loss ───────────────────────────

/// Token-level masked cross-entropy loss for next-token prediction.
///
/// Mirrors Python `default_loss` (`trainer.py:86..=99`), with an exclusive
/// upper bound on the mask (`steps < length` instead of Python's
/// `steps <= length`) to drop the first padded token from the supervised
/// targets. Matches the masking pattern used by mlx-lm's own DWQ trainer
/// (`mlx_lm/quant/dwq.py:115` — `mx.arange(1, 1 + targets.shape[1]) <
/// lengths[:, 1:]`).
///
/// ```text
/// inputs  = batch[:, :-1]
/// targets = batch[:, 1:]
/// logits  = model(inputs)
/// steps   = arange(1, T+1)
/// mask    = (steps >= lengths[:, 0:1]) & (steps < lengths[:, 1:])
/// ce      = cross_entropy(logits, targets) * mask
/// ntoks   = mask.sum()
/// loss    = ce.astype(float32).sum() / ntoks
/// ```
///
/// `batch` is an `[B, S]` integer-token tensor; `lengths` is `[B, 2]`
/// where each row is `(offset, length)`:
/// - tokens at positions `[0, offset)` are the prompt prefix (excluded from
///   the loss);
/// - tokens at positions `[offset, length)` are the completion (included).
///
/// The shifted target at position `length - 1` corresponds to the FIRST
/// padded slot in the unshifted batch (`batch[:, length]` after pad), so
/// the exclusive upper bound excludes it from the supervised loss — the
/// training signal never asks the model to predict the pad token 0 from
/// the last real completion token.
///
/// Returns `(loss_scalar, ntoks_scalar)` — both 0D `Array`s in f32.
///
/// `model.forward` is called WITHOUT a KV cache (training does a fresh
/// forward per step, unlike inference). A future grad-accumulation
/// micro-batching pass through this fn would re-evaluate the same logits
/// — caller controls invocation count.
pub fn default_loss<M>(model: &M, batch: &Array, lengths: &Array) -> Result<(Array, Array)>
where
  M: Model,
{
  let shape = batch.shape();
  let (_b, s) = match shape.as_slice() {
    [b, s] => (*b, *s),
    other => {
      return Err(Error::RankMismatch(RankMismatchPayload::new(
        "default_loss: batch must be rank-2 [B, S]",
        other.len() as u32,
        other.to_vec(),
      )));
    }
  };
  if s < 2 {
    return Err(Error::OutOfRange(OutOfRangePayload::new(
      "default_loss: batch S",
      "must be >= 2 for next-token prediction",
      format_smolstr!("{s}"),
    )));
  }
  let lengths_shape = lengths.shape();
  let expected_lengths_shape = [shape[0], 2_usize];
  if lengths_shape.as_slice() != expected_lengths_shape {
    return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
      "default_loss: lengths must be [B, 2] = (offset, length)",
      expected_lengths_shape.to_vec(),
      lengths_shape.to_vec(),
    )));
  }
  // inputs = batch[:, :-1], targets = batch[:, 1:]
  let b_dim = shape[0] as i32;
  let s_dim = s as i32;
  let inputs = crate::ops::indexing::slice(batch, &[0, 0], &[b_dim, s_dim - 1], &[1, 1])?;
  let targets = crate::ops::indexing::slice(batch, &[0, 1], &[b_dim, s_dim], &[1, 1])?;
  // Forward — empty cache slice (training does a fresh forward per step).
  let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
  let logits = model.forward(&inputs, &mut cache)?;
  // steps = arange(1, targets.shape[1] + 1) → [1..T]
  let t_dim = targets.shape()[1] as f32;
  let steps = Array::arange::<f32>(1.0, t_dim + 1.0, 1.0)?;
  // mask = (steps >= lengths[:, 0:1]) & (steps < lengths[:, 1:])
  // lengths[:, 0:1] is [B, 1]; lengths[:, 1:] is [B, 1].
  // Exclusive upper bound (`<`) drops the supervised target at
  // `step == length`, which corresponds to the FIRST padded slot in the
  // un-shifted batch. See the function's doc-comment for the off-by-one
  // analysis + mlx-lm DWQ reference.
  let offset = crate::ops::indexing::slice(lengths, &[0, 0], &[b_dim, 1], &[1, 1])?;
  let length = crate::ops::indexing::slice(lengths, &[0, 1], &[b_dim, 2], &[1, 1])?;
  // arange returns f32; cast steps to the same dtype as offset (int)
  // before comparison. Python does the comparison implicitly across
  // f32-int via mlx broadcasting → both promoted to f32.
  let offset_f = offset.astype(Dtype::F32)?;
  let length_f = length.astype(Dtype::F32)?;
  let ge = comparison::greater_equal(&steps, &offset_f)?;
  let lt = comparison::less(&steps, &length_f)?;
  let mask = logical::logical_and(&ge, &lt)?;
  // Cross-entropy (reduction="none") → [B, T]
  let ce = perplexity::cross_entropy_none(&logits, &targets)?;
  // ce * mask
  let mask_f = mask.astype(Dtype::F32)?;
  let ce_masked = arithmetic::multiply(&ce, &mask_f)?;
  // ntoks = mask.sum() (int)
  let mut ntoks = reduction::sum(&mask_f, false)?;
  // Reject zero-supervised-token batches BEFORE the divide rather than
  // producing NaN/Inf downstream (train accumulates `loss * ntoks` and
  // evaluate divides by `total_tokens`; both would silently poison
  // metrics if any batch contained only prompt-only / fully-truncated
  // rows under the exclusive `<` upper bound). The check forces an eval
  // on `ntoks` one division earlier than the caller's `.item::<f32>()`
  // would; the trade-off is an explicit, actionable error vs a silent
  // numerical fault.
  let ntoks_count = ntoks.item::<f32>()?;
  if ntoks_count == 0.0 {
    // The supervised-token set is empty after the length mask: every
    // example in the batch is too short (prompt-only or fully truncated)
    // or has length <= 1. Reject before the divide rather than emitting
    // NaN/Inf downstream; the caller should filter such examples upstream.
    return Err(Error::EmptyInput(EmptyInputPayload::new(
      "default_loss: supervised tokens after the length mask (batch produced 0 supervised tokens)",
    )));
  }
  // loss = ce.astype(f32).sum() / ntoks
  let ce_sum = reduction::sum(&ce_masked.astype(Dtype::F32)?, false)?;
  let loss = arithmetic::divide(&ce_sum, &ntoks)?;
  Ok((loss, ntoks))
}

// ─────────────────────────── grad_checkpoint ───────────────────────────

/// Return `f` wrapped so that its intermediate activations are recomputed
/// during the backward pass instead of being kept alive.
///
/// Python's `grad_checkpoint` (`trainer.py:25..=38`) monkey-patches
/// `type(layer).__call__`, which changes every layer of that type at once.
/// This port returns a new closure instead. The caller splices it into the
/// model's forward chain wherever checkpointing is wanted.
///
/// This is a thin wrapper over [`crate::transforms::checkpoint::checkpoint`]
/// (not a re-export); any error that function reports is passed through
/// unchanged.
pub fn grad_checkpoint<F>(f: F) -> Result<impl Fn(&[Array]) -> Result<Vec<Array>>>
where
  F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
{
  let wrapped = transforms::checkpoint::checkpoint(f)?;
  Ok(wrapped)
}

// ─────────────────────────── TrainingCallback ───────────────────────────

/// Hook trait for training-loop progress reporting.
///
/// Mirrors Python `TrainingCallback` (`mlx_lm/tuner/callbacks.py`). Every
/// method has a default no-op body, so an implementor overrides only the
/// hooks it needs.
///
/// The report / eval / save cadences are counted in OPTIMIZER steps (one
/// step per completed [`TrainingArgs::grad_accumulation_steps`] window),
/// not in microbatches.
pub trait TrainingCallback {
  /// Invoked once every [`TrainingArgs::steps_per_report`] optimizer steps
  /// with a summary of the most recent training window.
  fn on_train_loss_report(&mut self, _info: &TrainInfo) {}

  /// Invoked once every [`TrainingArgs::steps_per_eval`] optimizer steps
  /// (and before iteration 1) with a summary of the most recent validation
  /// pass.
  fn on_val_loss_report(&mut self, _info: &ValInfo) {}

  /// Invoked once every [`TrainingArgs::steps_per_save`] optimizer steps
  /// with the current iteration count and the configured
  /// [`TrainingArgs::adapter_file`] path.
  ///
  /// [`train`] does not write checkpoints itself. Override this hook to
  /// persist adapter weights, for example with the save primitives in
  /// [`crate::io`]. The default implementation is a no-op returning
  /// `Ok(())`.
  fn on_save(&mut self, _it: usize, _adapter_file: &str) -> Result<()> {
    Ok(())
  }
}

/// Summary of one training report window, handed to
/// [`TrainingCallback::on_train_loss_report`].
#[derive(Debug, Clone)]
pub struct TrainInfo {
  /// 1-based iteration index when the report fired.
  iteration: usize,
  /// Mean training loss across the latest report window.
  train_loss: f32,
  /// Learning rate the optimizer resolved for this iteration.
  learning_rate: f32,
  /// Throughput in iterations per second over the latest window.
  iterations_per_second: f32,
  /// Throughput in tokens per second over the latest window.
  tokens_per_second: f32,
  /// Running total of tokens trained on so far.
  trained_tokens: usize,
}

impl TrainInfo {
  /// Build a [`TrainInfo`] from its six components, given in field order.
  pub fn new(
    iteration: usize,
    train_loss: f32,
    learning_rate: f32,
    iterations_per_second: f32,
    tokens_per_second: f32,
    trained_tokens: usize,
  ) -> Self {
    TrainInfo {
      trained_tokens,
      tokens_per_second,
      iterations_per_second,
      learning_rate,
      train_loss,
      iteration,
    }
  }

  /// Iteration index (1-based) at which the report fired.
  #[inline(always)]
  pub fn iteration(&self) -> usize {
    self.iteration
  }

  /// Mean training loss across the latest report window.
  #[inline(always)]
  pub fn train_loss(&self) -> f32 {
    self.train_loss
  }

  /// Learning rate the optimizer resolved for this iteration.
  #[inline(always)]
  pub fn learning_rate(&self) -> f32 {
    self.learning_rate
  }

  /// Iterations per second over the latest report window.
  #[inline(always)]
  pub fn iterations_per_second(&self) -> f32 {
    self.iterations_per_second
  }

  /// Tokens per second over the latest report window.
  #[inline(always)]
  pub fn tokens_per_second(&self) -> f32 {
    self.tokens_per_second
  }

  /// Running total of tokens trained on so far.
  #[inline(always)]
  pub fn trained_tokens(&self) -> usize {
    self.trained_tokens
  }
}

/// Summary of one validation pass, handed to
/// [`TrainingCallback::on_val_loss_report`].
#[derive(Debug, Clone)]
pub struct ValInfo {
  /// Iteration index at which this eval fired. For the pre-first-step eval,
  /// this port mirrors Python's `it - 1` convention.
  iteration: usize,
  /// Validation loss averaged over the eval batches.
  val_loss: f32,
  /// Wall-clock duration of the eval, in seconds.
  val_time: f32,
}

impl ValInfo {
  /// Build a [`ValInfo`] from its iteration index, loss, and duration.
  pub fn new(iteration: usize, val_loss: f32, val_time: f32) -> Self {
    ValInfo { val_time, val_loss, iteration }
  }

  /// Iteration index at which this eval fired.
  #[inline(always)]
  pub fn iteration(&self) -> usize {
    self.iteration
  }

  /// Validation loss averaged over the eval batches.
  #[inline(always)]
  pub fn val_loss(&self) -> f32 {
    self.val_loss
  }

  /// Wall-clock duration of the eval, in seconds.
  #[inline(always)]
  pub fn val_time(&self) -> f32 {
    self.val_time
  }
}

/// No-op [`TrainingCallback`] used as the default when the caller doesn't
/// provide one.
///
/// It overrides nothing, so every hook runs the trait's default (no-op)
/// body. The type is zero-sized and `Copy`, so it costs nothing to pass
/// around.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopCallback;

impl TrainingCallback for NoopCallback {}

// ─────────────────────────── iterate_batches ───────────────────────────

/// One yielded batch from [`iterate_batches`]:
///
/// - `tokens` — the `[B, max_len_in_batch]` int32 token tensor (padded
///   with `0` past each row's true length, truncated to
///   [`TrainingArgs::max_seq_length`] before padding).
/// - `lengths` — the `[B, 2]` `(offset, length)` per-row metadata used by
///   [`default_loss`] to build the per-token loss mask.
///
/// Both tensors are read through [`Batch::tokens_ref`] and
/// [`Batch::lengths_ref`].
pub struct Batch {
  /// `[B, S]` int32 token tensor.
  tokens: Array,
  /// `[B, 2]` `(offset, length)` per-row metadata.
  lengths: Array,
  // Zero-sized placeholder: it carries no data and has no runtime cost.
  // NOTE(review): this marker does not protect compatibility. `Batch` is
  // not `#[repr(C)]`, so it has no stable ABI. Its fields are private, so
  // code outside this module already cannot build it with a struct
  // literal, and new fields can be added without the marker. Consider
  // dropping it once nothing in this module constructs `Batch` with it.
  _marker: PhantomData<()>,
}

impl Batch {
  /// Construct a [`Batch`] from a token tensor and a lengths tensor.
  ///
  /// No shape validation happens here; presumably callers pass
  /// `[B, S]` tokens and `[B, 2]` lengths as described on [`Batch`].
  pub fn new(tokens: Array, lengths: Array) -> Self {
    Self {
      tokens,
      lengths,
      _marker: PhantomData,
    }
  }

  /// The `[B, S]` int32 token tensor.
  #[inline(always)]
  pub fn tokens_ref(&self) -> &Array {
    &self.tokens
  }

  /// The `[B, 2]` `(offset, length)` per-row metadata tensor.
  #[inline(always)]
  pub fn lengths_ref(&self) -> &Array {
    &self.lengths
  }
}

/// Iterate over `dataset` in length-sorted, padded batches, mirroring the
/// Python `iterate_batches` (`trainer.py:102..=173`) minus distributed
/// sharding.
///
/// - Sorts example indices by token length with a stable sort, like Python
///   `sorted(range(len(dataset)), key=...)`. Equal-length examples keep
///   their dataset order.
/// - Forms `batch_size`-sized groups in length-sorted order and drops any
///   ragged tail (Python
///   `batch_idx = [idx[i:i+batch_size] for i in range(0, len-bs+1, bs)]`).
/// - Each yielded batch follows Python's `pad_to = 32` heuristic
///   (`trainer.py:157..=159`): pad to
///   `1 + 32 * ((max_len_in_batch + 31) / 32)` in integer arithmetic
///   (i.e. `1 + 32·⌈max_len_in_batch / 32⌉`), clamped to `max_seq_length`.
///   NOTE(review): confirm `BatchIter::next` uses exactly this formula.
/// - If `shuffle_seed` is `Some(seed)`, batch groups are shuffled with a
///   deterministic RNG seeded by `seed` (Python `np.random.seed` +
///   `np.random.permutation`).
///
/// Returns an iterator that yields `Result<Batch>`; an error produced
/// mid-iteration is yielded to the caller as an `Err` item.
///
/// The `loop_forever` flag mirrors Python `loop: bool`. When true, the
/// iterator restarts after exhausting all batch groups; [`train`] uses
/// this for the main loop. Eval passes `false` so iteration stops after
/// one pass.
///
/// # Errors
///
/// - [`Error::OutOfRange`] if `batch_size == 0` or if
///   `dataset.len() < batch_size`.
/// - Any error returned by `Dataset::process` while measuring example
///   lengths (every example is processed once up front).
pub fn iterate_batches<'a, D: Dataset + 'a>(
  dataset: &'a D,
  batch_size: usize,
  max_seq_length: usize,
  loop_forever: bool,
  shuffle_seed: Option<u64>,
) -> Result<impl Iterator<Item = Result<Batch>> + 'a> {
  // Guard first: both the grouping below and any later per-batch math
  // would otherwise divide / chunk by zero and panic.
  if batch_size == 0 {
    return Err(Error::OutOfRange(OutOfRangePayload::new(
      "iterate_batches: batch_size",
      "must be >= 1",
      format_smolstr!("{batch_size}"),
    )));
  }
  let n = dataset.len();
  if n < batch_size {
    return Err(Error::OutOfRange(OutOfRangePayload::new(
      "iterate_batches: dataset size",
      "must be >= batch_size",
      format_smolstr!("{} (batch_size={batch_size})", n),
    )));
  }
  // Measure every example once, then stable-sort indices by token length.
  let lens: Vec<usize> = (0..n)
    .map(|i| dataset.process(i).map(|(toks, _)| toks.len()))
    .collect::<Result<_>>()?;
  let mut idx: Vec<usize> = (0..n).collect();
  idx.sort_by_key(|&i| lens[i]);
  // Group into `batch_size` chunks. `chunks_exact` drops the ragged tail,
  // matching Python's `range(0, len - bs + 1, bs)`.
  let batch_idx: Vec<Vec<usize>> = idx
    .chunks_exact(batch_size)
    .map(|group| group.to_vec())
    .collect();
  Ok(BatchIter {
    dataset,
    batch_idx,
    max_seq_length,
    cursor: 0,
    order: Vec::new(),
    loop_forever,
    shuffle_seed,
    rng_state: shuffle_seed,
    first_pass: true,
  })
}

/// Iterator state behind [`iterate_batches`]. It walks the precomputed
/// batch groups in the order held in `order`, building each padded batch
/// on demand via `build_batch`.
struct BatchIter<'a, D: Dataset> {
  /// Dataset the stored example indices refer to.
  dataset: &'a D,
  /// Batch groups. Each inner `Vec` holds `batch_size` dataset indices
  /// taken consecutively from the length-sorted index list.
  batch_idx: Vec<Vec<usize>>,
  /// Upper bound on the padded row width, forwarded to `build_batch`.
  max_seq_length: usize,
  /// Position in `order` of the next group to yield.
  cursor: usize,
  /// Visiting order of `batch_idx` for the current pass (a permutation of
  /// `0..batch_idx.len()`). Empty until the first call to `next`.
  order: Vec<usize>,
  /// When `true`, a new pass starts after the current one is exhausted.
  loop_forever: bool,
  /// Seed supplied by the caller; `Some` enables per-pass shuffling.
  shuffle_seed: Option<u64>,
  /// Seed for the next shuffle. It starts equal to `shuffle_seed` and is
  /// bumped by one after each shuffled pass.
  rng_state: Option<u64>,
  /// `true` until the first pass order has been built.
  first_pass: bool,
}

impl<D: Dataset> Iterator for BatchIter<'_, D> {
  type Item = Result<Batch>;

  /// Yield the next padded batch. When the current pass is used up, a new
  /// pass order is built (and reshuffled if a seed was given).
  ///
  /// Returns `None` once a non-looping iterator has finished its single
  /// pass, and on every call after that. It also returns `None` if there
  /// are no batch groups at all.
  fn next(&mut self) -> Option<Self::Item> {
    let pass_exhausted = self.cursor >= self.order.len();
    if pass_exhausted {
      // A non-looping iterator gets exactly one pass.
      let may_start_pass = self.first_pass || self.loop_forever;
      if !may_start_pass {
        return None;
      }
      self.first_pass = false;
      self.cursor = 0;
      let mut fresh_order: Vec<usize> = (0..self.batch_idx.len()).collect();
      // Shuffle only when a seed was requested. Bumping the running seed
      // makes each later pass draw a different permutation.
      if let (Some(_), Some(seed)) = (self.shuffle_seed, self.rng_state) {
        fisher_yates_shuffle(&mut fresh_order, seed);
        self.rng_state = Some(seed.wrapping_add(1));
      }
      self.order = fresh_order;
      if self.order.is_empty() {
        return None;
      }
    }
    let slot = self.order[self.cursor];
    self.cursor += 1;
    let group = &self.batch_idx[slot];
    Some(build_batch(self.dataset, group, self.max_seq_length))
  }
}

/// Materialize one padded batch from the dataset examples at `indices`.
///
/// Every row is zero-padded to `1 + 32·ceil(longest / 32)`, capped at
/// `max_seq_length`. Rows longer than that width are truncated. The
/// companion `[B, 2]` tensor stores each row's `(offset, length)` after
/// truncation, with the offset clamped to the kept length.
fn build_batch<D: Dataset>(dataset: &D, indices: &[usize], max_seq_length: usize) -> Result<Batch> {
  const PAD_TO: usize = 32;
  let examples = indices
    .iter()
    .map(|&i| dataset.process(i))
    .collect::<Result<Vec<Example>>>()?;
  let longest = examples
    .iter()
    .map(|(toks, _)| toks.len())
    .max()
    .unwrap_or(0);
  // One plus the next multiple of PAD_TO, but never wider than the cap.
  let row_width = (1 + PAD_TO * longest.div_ceil(PAD_TO)).min(max_seq_length);
  let rows = examples.len();
  let mut token_buf = vec![0i32; rows * row_width];
  let mut meta_buf = vec![0i32; rows * 2];
  for (row, (toks, offset)) in examples.iter().enumerate() {
    // `row_width <= max_seq_length`, so this also honours the sequence cap.
    let kept = toks.len().min(row_width);
    let start = row * row_width;
    for (dst, &tok) in token_buf[start..start + kept].iter_mut().zip(&toks[..kept]) {
      *dst = tok as i32;
    }
    meta_buf[row * 2] = (*offset).min(kept) as i32;
    meta_buf[row * 2 + 1] = kept as i32;
  }
  let tokens = Array::from_slice::<i32>(&token_buf, &(rows, row_width))?;
  let lengths_arr = Array::from_slice::<i32>(&meta_buf, &(rows, 2usize))?;
  Ok(Batch::new(tokens, lengths_arr))
}

/// Deterministic in-place Fisher-Yates shuffle driven by a SplitMix64 RNG.
///
/// The same `seed` produces the same permutation across runs and across
/// platforms, including 32- vs 64-bit targets. It provides the same
/// "seeded → reproducible" property as Python's
/// `np.random.seed(seed); np.random.permutation(...)`, without pulling in
/// `rand`. The permutation itself is NOT numpy's.
///
/// Slices of length 0 or 1 are left untouched.
fn fisher_yates_shuffle<T>(slice: &mut [T], seed: u64) {
  let mut state = seed;
  for i in (1..slice.len()).rev() {
    // SplitMix64 step.
    state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;
    // Reduce in u64 BEFORE narrowing. `(z as usize) % (i + 1)` keeps only
    // the low 32 bits of `z` on 32-bit targets, which yields a different
    // permutation there. The result is < i + 1, so it always fits in usize.
    let j = (z % (i as u64 + 1)) as usize;
    slice.swap(i, j);
  }
}

// ─────────────────────────── evaluate ───────────────────────────

/// Evaluate `model` on `dataset` for at most `num_batches` batches,
/// returning the token-weighted mean cross-entropy loss.
///
/// Mirrors Python `evaluate` (`trainer.py:176..=215`), without
/// distributed `all_sum`. Each batch's loss + token count is accumulated;
/// the final loss is `total_loss / total_tokens`. `num_batches` of `None`
/// uses the whole eval set (matching Python's `num_batches == -1` sentinel).
///
/// The running sums are kept in `f64`. An `f32` token counter stops
/// representing integers exactly beyond 2^24 tokens, so large eval sets
/// would otherwise drift. The result is narrowed to `f32` at the end.
///
/// # Errors
///
/// - Any error from [`iterate_batches`] (e.g. `batch_size == 0` or a
///   dataset smaller than `batch_size`), from building a batch, or from
///   `loss_fn`.
/// - [`Error::EmptyInput`] if the evaluated batches contain no tokens
///   (including `num_batches == Some(0)`).
pub fn evaluate<M: Model, D: Dataset, F>(
  model: &M,
  dataset: &D,
  batch_size: usize,
  num_batches: Option<usize>,
  max_seq_length: usize,
  mut loss_fn: F,
) -> Result<f32>
where
  F: FnMut(&M, &Array, &Array) -> Result<(Array, Array)>,
{
  let mut total_loss = 0.0_f64;
  let mut total_tokens = 0.0_f64;
  // Eval iterator: NO shuffle, NO loop. One pass over the (length-sorted)
  // batches.
  let batches = iterate_batches(dataset, batch_size, max_seq_length, false, None)?;
  // `take` stops before requesting batch number `cap`, so no extra batch is
  // built (and thrown away) once the cap is reached.
  let cap = num_batches.unwrap_or(usize::MAX);
  for batch in batches.take(cap) {
    let batch = batch?;
    let (mut loss, mut ntoks) = loss_fn(model, batch.tokens_ref(), batch.lengths_ref())?;
    let loss_f = f64::from(loss.item::<f32>()?);
    let ntoks_f = f64::from(ntoks.item::<f32>()?);
    // Token-weighted accumulation: total += per_token_loss · ntoks
    total_loss += loss_f * ntoks_f;
    total_tokens += ntoks_f;
  }
  if total_tokens == 0.0 {
    return Err(Error::EmptyInput(EmptyInputPayload::new(
      "evaluate: eval set (produced no batches with tokens)",
    )));
  }
  Ok((total_loss / total_tokens) as f32)
}

// ─────────────────────────── train ───────────────────────────

/// Run the training loop on `model` + `optimizer` over `train_dataset`,
/// optionally evaluating on `val_dataset` every
/// [`TrainingArgs::steps_per_eval`] OPTIMIZER STEPS.
///
/// Mirrors Python `train` (`trainer.py:218..=387`), with the scope cuts
/// documented in the
/// [scope-cuts module-level note](self#scope-cuts-deviations-from-python),
/// the v1 mechanics-only / no-real-gradients gate documented in the
/// [v1 status module-level note](self#v1-status--mechanics-only-train),
/// and the optimizer-step periodic cadence documented in the
/// [periodic-event cadence note](self#periodic-event-cadence--optimizer-steps-deviation-from-python).
///
/// Per microbatch the loop computes `(loss, grads)` and accumulates
/// `grads` into a running sum across [`TrainingArgs::grad_accumulation_steps`]
/// microbatches; once the window is complete it divides the accumulator
/// by `grad_accumulation_steps` and dispatches the MEAN to
/// [`Optimizer::apply_gradients`]. Any final partial window at the end
/// of [`TrainingArgs::iters`] is dropped. (`grads` is currently
/// `zeros_like(params)` per the v1 mechanics-only note above.)
///
/// ## Parameter handoff
///
/// `params` is a mutable [`Weights`] map (the same flat-key shape mlxrs
/// uses everywhere). The caller owns the parameter map and the optimizer
/// mutates it in place each step. The model is read-only — it consumes the
/// parameters indirectly (e.g. baked into its captured state at load time).
/// This deviates from Python's `model.update(params)` per-step pattern
/// because mlxrs has no `nn.Module` runtime parameter system yet (a future
/// follow-up will introduce a `Module` trait + `update()` hook).
///
/// ## Loss closure
///
/// `loss_fn` takes `(model, tokens, lengths)` and returns
/// `(loss_scalar, ntoks_scalar)`. The defaults are [`default_loss`]; pass
/// a custom closure for specialized losses (label smoothing, KD, etc.).
#[allow(clippy::too_many_arguments)]
pub fn train<M, D, O, L, C>(
  model: &M,
  optimizer: &mut O,
  params: &mut Weights,
  train_dataset: &D,
  val_dataset: Option<&D>,
  args: &TrainingArgs,
  loss_fn: L,
  callback: &mut C,
) -> Result<()>
where
  M: Model,
  D: Dataset,
  O: Optimizer + ?Sized,
  L: Fn(&M, &Array, &Array) -> Result<(Array, Array)>,
  C: TrainingCallback,
{
  if !args.acknowledge_no_real_gradients() {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "train: TrainingArgs::acknowledge_no_real_gradients",
      "must be set to `true` to run the v1 mechanics-only training path",
    )));
  }
  if args.iters() == 0 {
    return Ok(());
  }
  // Validate every interval field used as a modulo divisor. A `0`
  // interval would underflow `it % 0` (panic) the first time the loop
  // tested the periodic-report / eval / save predicate, so reject up
  // front with a clear error instead of letting it crash at iteration 1.
  if args.grad_accumulation_steps() == 0 {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "train: grad_accumulation_steps",
      "must be >= 1",
    )));
  }
  if args.steps_per_report() == 0 {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "train: steps_per_report",
      "must be >= 1",
    )));
  }
  if args.steps_per_eval() == 0 {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "train: steps_per_eval",
      "must be >= 1",
    )));
  }
  if args.steps_per_save() == 0 {
    return Err(Error::InvariantViolation(InvariantViolationPayload::new(
      "train: steps_per_save",
      "must be >= 1",
    )));
  }
  // Total OPTIMIZER steps the loop will execute. Microbatch count is
  // `args.iters()`; one optimizer step per `args.grad_accumulation_steps()`
  // microbatches; any final partial window is DROPPED (no optimizer step
  // for it). The floored division is therefore the right count.
  let total_optim_steps = args.iters() / args.grad_accumulation_steps();
  // Periodic-report window accumulators. `window_steps` is OPTIMIZER
  // STEPS in the current report window, `window_secs` is the cumulative
  // wall-clock time across all microbatches that fed those steps,
  // `window_microbatches` is the per-microbatch count used to denominate
  // the mean train loss (mirrors mlx-lm's per-microbatch loss semantic —
  // dividing by `window_steps` instead would inflate the reported loss
  // by `grad_accumulation_steps×` for every callback / log line / early-
  // stop monitor).
  let mut window_loss = 0.0_f32;
  let mut window_tokens = 0.0_f32;
  let mut window_steps = 0usize;
  let mut window_microbatches = 0usize;
  let mut window_secs = 0.0_f32;
  let mut trained_tokens = 0usize;
  // Gradient-accumulation state. `accumulated_grads` collects the SUM of
  // per-microbatch gradients across one optimizer window, then is divided
  // by `args.grad_accumulation_steps()` (the MEAN) before being dispatched
  // to the optimizer.
  let mut accumulated_grads: Option<Weights> = None;
  let mut accum_count: usize = 0;
  // OPTIMIZER step counter (NOT microbatch counter). Periodic events —
  // train-loss reports, val-loss evals, save hooks — fire on this
  // counter, so the per-event cadence is independent of
  // `grad_accumulation_steps`. Deviation from `mlx-lm/tuner/trainer.py`
  // which counts microbatches (see the v1 status note in the module-level
  // doc-comment); chosen so a caller bumping `grad_accumulation_steps`
  // doesn't accidentally inflate their report / eval / save frequency.
  let mut optim_step: usize = 0;
  // Per-microbatch timing accumulator for the current optimizer window.
  // Folded into `window_secs` and reset every time the optimizer fires.
  let mut window_micro_secs = 0.0_f32;
  let mut iter = iterate_batches(
    train_dataset,
    args.batch_size(),
    args.max_seq_length(),
    true,
    None,
  )?;
  // Pre-loop val — emit BEFORE the first optimizer step (Python
  // trainer.py:286..=317 does this implicitly by checking `it == 1`
  // before its first step body). `iteration: 0` matches the
  // microbatch-based semantics, which fire at `iteration: it - 1 = 0`.
  if let Some(val) = val_dataset
    && total_optim_steps >= 1
  {
    run_val(model, val, args, 0, callback, &loss_fn)?;
  }
  for _microbatch_it in 1..=args.iters() {
    let micro_start = Instant::now();
    let batch = iter.next().ok_or_else(|| {
      Error::InvariantViolation(InvariantViolationPayload::new(
        "train: batch iterator",
        "must never be exhausted (loop=true should never end)",
      ))
    })??;
    // Compute loss + (placeholder) gradients. NOTE: this is the v1
    // mechanics-only path — production code threads `value_and_grad`
    // over a future `nn::Module` trait that binds `params -> loss`. v1
    // ships a no-grad pass-through (`build_zero_grads`) gated by
    // [`TrainingArgs::acknowledge_no_real_gradients`] so the optimizer /
    // callback / save mechanics can be tested end-to-end.
    let (loss_scalar, ntoks_scalar) = (loss_fn)(model, batch.tokens_ref(), batch.lengths_ref())?;
    let mut loss_val = loss_scalar.try_clone()?;
    let mut ntoks_val = ntoks_scalar.try_clone()?;
    let loss_f = loss_val.item::<f32>()?;
    let ntoks_f = ntoks_val.item::<f32>()?;
    let grads: Weights = build_zero_grads(params)?;
    // Accumulate (sum) into the current window.
    accumulated_grads = Some(match accumulated_grads {
      None => grads,
      Some(acc) => add_weights(&acc, &grads)?,
    });
    accum_count += 1;
    window_loss += loss_f;
    window_tokens += ntoks_f;
    window_microbatches += 1;
    trained_tokens += ntoks_f as usize;
    window_micro_secs += micro_start.elapsed().as_secs_f32();
    // Optimizer step fires only when the accumulation window is full.
    // Partial windows at the end of `iters` are DROPPED (no
    // apply_gradients call for them); see the contract documented on
    // [`TrainingArgs::grad_accumulation_steps`] + the v1 status note.
    if accum_count < args.grad_accumulation_steps() {
      continue;
    }
    let avg = divide_weights(
      accumulated_grads
        .as_ref()
        .expect("accumulated_grads must be Some after at least one accum"),
      args.grad_accumulation_steps() as f32,
    )?;
    optimizer.apply_gradients(&avg, params)?;
    optim_step += 1;
    accumulated_grads = None;
    accum_count = 0;
    window_steps += 1;
    window_secs += window_micro_secs;
    window_micro_secs = 0.0;
    let is_last_optim_step = optim_step == total_optim_steps;
    // Periodic train-loss report (cadence in OPTIMIZER STEPS).
    if optim_step.is_multiple_of(args.steps_per_report()) || is_last_optim_step {
      // Mean train loss is denominated by COMPLETED MICROBATCHES, not by
      // optimizer-step count: `window_loss` aggregates one summand per
      // microbatch (line ~767), so dividing by `window_steps` (=
      // window_microbatches / grad_accumulation_steps) inflates the
      // reported loss by `grad_accumulation_steps×`. See trainer module
      // doc note + the regression test
      // `grad_accumulation_steps_4_reports_constant_loss_at_2_not_8`.
      let mean_loss = if window_microbatches > 0 {
        window_loss / (window_microbatches as f32)
      } else {
        0.0
      };
      let it_sec = if window_secs > 0.0 {
        (window_steps as f32) / window_secs
      } else {
        0.0
      };
      let tok_sec = if window_secs > 0.0 {
        window_tokens / window_secs
      } else {
        0.0
      };
      callback.on_train_loss_report(&TrainInfo::new(
        optim_step,
        mean_loss,
        optimizer.learning_rate(),
        it_sec,
        tok_sec,
        trained_tokens,
      ));
      window_loss = 0.0;
      window_tokens = 0.0;
      window_steps = 0;
      window_microbatches = 0;
      window_secs = 0.0;
    }
    // Periodic mid-training eval (cadence in OPTIMIZER STEPS). Fires
    // both on the regular cadence and at the final optimizer step (so
    // the caller always sees an end-of-training validation).
    if let Some(val) = val_dataset
      && (optim_step.is_multiple_of(args.steps_per_eval()) || is_last_optim_step)
    {
      run_val(model, val, args, optim_step, callback, &loss_fn)?;
    }
    // Periodic save hook (cadence in OPTIMIZER STEPS).
    if optim_step.is_multiple_of(args.steps_per_save()) {
      callback.on_save(optim_step, args.adapter_file())?;
    }
  }
  // Final save hook (Python: writes adapters.safetensors at the end).
  // Iteration label is the LAST optimizer step (0 if there were no
  // optimizer steps, e.g. iters < grad_accumulation_steps).
  callback.on_save(optim_step, args.adapter_file())?;
  Ok(())
}

/// Run one validation pass and dispatch `on_val_loss_report` with the
/// matching [`ValInfo`]. Centralized so the train loop's pre-loop
/// (iteration 0) and per-step (iteration = `optim_step`) val call sites
/// share one body.
fn run_val<M, D, L, C>(
  model: &M,
  val: &D,
  args: &TrainingArgs,
  iteration: usize,
  callback: &mut C,
  loss_fn: &L,
) -> Result<()>
where
  M: Model,
  D: Dataset,
  L: Fn(&M, &Array, &Array) -> Result<(Array, Array)>,
  C: TrainingCallback,
{
  let started = Instant::now();
  // `&L` is itself callable (`&F: Fn` whenever `F: Fn`), so it can be given
  // to `evaluate` as-is instead of through a forwarding closure.
  let val_loss = evaluate(
    model,
    val,
    args.batch_size(),
    args.val_batches(),
    args.max_seq_length(),
    loss_fn,
  )?;
  let report = ValInfo::new(iteration, val_loss, started.elapsed().as_secs_f32());
  callback.on_val_loss_report(&report);
  Ok(())
}

/// Build a `Weights` with `zeros_like` of each entry's `Array`. Used by
/// the v1 [`train`] loop's pass-through gradient path — production
/// integration replaces this with `value_and_grad(loss_closure)` over a
/// future `Module` trait that maps `params → loss`.
fn build_zero_grads(params: &Weights) -> Result<Weights> {
  // One zero tensor per parameter, keyed identically. The first
  // `zeros_like` failure aborts the whole map.
  params
    .iter()
    .map(|(key, value)| -> Result<_> {
      let zeros = crate::ops::misc::zeros_like(value)?;
      Ok((key.clone(), zeros))
    })
    .collect()
}

/// Element-wise sum of two parameter-keyed gradient maps. `a` and `b`
/// must share the same key set (the trainer always builds them from the
/// same `params`, so this is an internal invariant — any missing key is
/// reported via [`Error::MissingKey`] rather than silently dropped).
fn add_weights(a: &Weights, b: &Weights) -> Result<Weights> {
  if a.len() != b.len() {
    return Err(Error::LengthMismatch(LengthMismatchPayload::new(
      "trainer::add_weights: lhs vs rhs key counts",
      a.len(),
      b.len(),
    )));
  }
  let mut out: Weights = HashMap::with_capacity(a.len());
  for (key, lhs) in a {
    let Some(rhs) = b.get(key) else {
      return Err(Error::MissingKey(MissingKeyPayload::new(
        "trainer::add_weights: key missing from rhs",
        key.as_str(),
      )));
    };
    out.insert(key.clone(), arithmetic::add(lhs, rhs)?);
  }
  Ok(out)
}

/// Scalar-divide every entry in a parameter-keyed gradient map by
/// `divisor`. Used by the [`train`] gradient-accumulation path to
/// average the per-microbatch summed gradients before dispatching to
/// the optimizer.
fn divide_weights(w: &Weights, divisor: f32) -> Result<Weights> {
  // Build the 0-d divisor once and broadcast it against every entry.
  let denom = Array::full::<f32>(&[0i32; 0], divisor)?;
  w.iter()
    .map(|(key, value)| -> Result<_> {
      let quotient = arithmetic::divide(value, &denom)?;
      Ok((key.clone(), quotient))
    })
    .collect()
}

#[cfg(test)]
mod tests;