//! Training step driver and AdamW optimizer (gated on `rocm-hip`).
//!
//! This module tree wires the per-parameter optimizer step
//! (`adamw.rs`), the per-tensor `Parameter` type
//! (`parameter.rs`), the HIP/AdamW bridge (`hip_adamw_bridge.rs`),
//! and the end-to-end `train_step` driver (`step.rs`).
//!
//! The training driver composes a user-supplied forward + loss +
//! backward pass with either `Parameter::adamw_step` (pure-CPU AdamW)
//! or `run_rocm_hip_adamw_step` (HIP AdamW, gated on `rocm-hip`).
//!
//! See `src/bin/train_quality_moe.rs` for the end-to-end 0.7B
//! MoE training runner and `src/training_runner.rs` for the
//! orchestrator that ties the training step to the dataset
//! bridge and the model architecture.
//!
// Phase 2.2 training step driver.
//
// Composes a user-supplied forward + backward pair with the Phase 1
// ROCm/HIP AdamW kernel into a single `train_step` call. The
// underlying kernel contract (fp16 weights, fp32 moments, fp16 grads,
// 1-based step index) is preserved end to end: the user's backward
// closure returns fp32 gradients on the host, and `AdamW::step` rounds
// them to fp16 before launching the kernel.
//
// The fp16 representation is stored as `u16` bit patterns inside
// `Tensor<u16>` to match the storage convention used by the kernels
// in `src/backend/hip_*.rs`.
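//
// Illustrative sketch only: one way the fp32 -> fp16 rounding and the
// `u16` bit-pattern storage described above could look on the host.
// The names `f32_to_f16_bits`, `f16_bits_to_f32`, and `pow2` are
// hypothetical and are not this crate's API. Round-to-nearest-even is
// an assumption here; the conversion inside `AdamW::step` may differ.

```rust
/// Exact power of two as f32 (hypothetical helper; valid for -126..=127).
fn pow2(e: i32) -> f32 {
    f32::from_bits(((e + 127) as u32) << 23)
}

/// Round an fp32 value to an IEEE 754 binary16 bit pattern
/// (assumed rounding mode: round-to-nearest, ties-to-even).
fn f32_to_f16_bits(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xff) as i32;
    let mant = bits & 0x007f_ffff;

    if exp == 0xff {
        // Infinity or NaN; keep NaN as a quiet NaN.
        let nan_bit: u16 = if mant != 0 { 0x0200 } else { 0 };
        return sign | 0x7c00 | nan_bit;
    }

    let half_exp = exp - 127 + 15; // re-bias the exponent for fp16
    if half_exp >= 0x1f {
        return sign | 0x7c00; // too large for fp16: infinity
    }
    if half_exp <= 0 {
        if half_exp < -10 {
            return sign; // below half the smallest subnormal: signed zero
        }
        // Subnormal result: shift the 24-bit significand into place.
        let full = mant | 0x0080_0000;
        let shift = (14 - half_exp) as u32; // 14..=24
        let half_mant = full >> shift;
        let rem = full & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        let round_up = rem > halfway || (rem == halfway && (half_mant & 1) == 1);
        return sign | (half_mant as u16 + round_up as u16);
    }

    // Normal result: keep the top 10 mantissa bits, round on the lower 13.
    let half_mant = (mant >> 13) as u16;
    let rem = mant & 0x1fff;
    let round_up = rem > 0x1000 || (rem == 0x1000 && (half_mant & 1) == 1);
    // A carry out of the mantissa correctly bumps the exponent.
    (sign | ((half_exp as u16) << 10) | half_mant) + round_up as u16
}

/// Widen a binary16 bit pattern back to fp32 (always exact).
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = if bits & 0x8000 != 0 { -1.0f32 } else { 1.0f32 };
    let exp = ((bits >> 10) & 0x1f) as i32;
    let mant = (bits & 0x03ff) as f32;
    let magnitude = match exp {
        0 => mant * pow2(-24), // zero or subnormal
        0x1f => {
            if mant == 0.0 {
                f32::INFINITY
            } else {
                f32::NAN
            }
        }
        _ => (1.0 + mant / 1024.0) * pow2(exp - 15),
    };
    sign * magnitude
}
```

// Under these assumptions, a host-side gradient slice would be packed
// with `grads.iter().map(|g| f32_to_f16_bits(*g)).collect::<Vec<u16>>()`
// before being handed to a `Tensor<u16>`.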
pub use AdamW;
pub use ;
pub use ;