tokitai-operator 0.1.0

Verified DL kernel compiler: formally checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
//! Training step driver and AdamW optimizer (gated on `rocm-hip`).
//!
//! This module tree wires together the per-parameter optimizer step
//! (`adamw.rs`), the per-tensor `Parameter` type
//! (`parameter.rs`), the HIP/AdamW bridge (`hip_adamw_bridge.rs`),
//! and the end-to-end `train_step` driver (`step.rs`).
//!
//! The training driver composes a user-supplied forward + loss +
//! backward pass with either `Parameter::adamw_step` (pure-CPU
//! AdamW) or `run_rocm_hip_adamw_step` (HIP AdamW, gated on
//! `rocm-hip`).
//!
//! See `src/bin/train_quality_moe.rs` for the end-to-end 0.7B-parameter
//! MoE training runner, and `src/training_runner.rs` for the
//! orchestrator that ties the training step to the dataset
//! bridge and the model architecture.
//!
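The composition described above (a user-supplied forward + loss + backward pass followed by a CPU AdamW update) can be sketched as follows. This is a minimal sketch under stated assumptions, not the crate's `AdamW`, `Parameter::adamw_step`, or `train_step`: `AdamWConfig`, `adamw_update`, and `train_step_sketch` are hypothetical names, everything is kept in fp32 (the HIP path keeps weights and gradients in fp16), and the update assumes the standard AdamW rule with bias correction and decoupled weight decay, using a 1-based step index as in the kernel contract.

```rust
/// Hypothetical AdamW hyperparameters (names and defaults are assumptions,
/// not the crate's `AdamW` fields).
#[derive(Clone, Copy, Debug)]
pub struct AdamWConfig {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub eps: f32,
    pub weight_decay: f32,
}

impl Default for AdamWConfig {
    fn default() -> Self {
        Self { lr: 1e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.01 }
    }
}

/// One AdamW update over flat fp32 slices. `step` is 1-based, so the
/// bias corrections `1 - beta^step` are nonzero on the first call.
pub fn adamw_update(
    cfg: &AdamWConfig,
    step: u32,
    w: &mut [f32],
    m: &mut [f32],
    v: &mut [f32],
    g: &[f32],
) {
    assert!(step >= 1, "step index is 1-based");
    assert!(w.len() == g.len() && m.len() == g.len() && v.len() == g.len());
    let bc1 = 1.0 - cfg.beta1.powi(step as i32);
    let bc2 = 1.0 - cfg.beta2.powi(step as i32);
    for i in 0..g.len() {
        // First and second moment estimates (kept in fp32).
        m[i] = cfg.beta1 * m[i] + (1.0 - cfg.beta1) * g[i];
        v[i] = cfg.beta2 * v[i] + (1.0 - cfg.beta2) * g[i] * g[i];
        let m_hat = m[i] / bc1;
        let v_hat = v[i] / bc2;
        // Decoupled weight decay: scales the old weight, not the gradient.
        let decay = cfg.weight_decay * w[i];
        w[i] -= cfg.lr * (m_hat / (v_hat.sqrt() + cfg.eps) + decay);
    }
}

/// Hypothetical driver: `forward_backward` runs forward + loss + backward
/// on the current weights and returns `(loss, grad)`; AdamW then updates
/// the weights in place. Returns the loss measured before the update.
pub fn train_step_sketch<F>(
    cfg: &AdamWConfig,
    step: u32,
    w: &mut [f32],
    m: &mut [f32],
    v: &mut [f32],
    forward_backward: F,
) -> f32
where
    F: FnOnce(&[f32]) -> (f32, Vec<f32>),
{
    let (loss, grad) = forward_backward(&*w);
    adamw_update(cfg, step, w, m, v, &grad);
    loss
}
```

Passing the forward/backward pass as a closure keeps the optimizer independent of the model: the driver only needs a loss value and a gradient with the same length as the weights.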
// Phase 2.2 training step driver.
//
// Composes a user-supplied forward + backward pair with the Phase 1
// ROCm/HIP AdamW kernel into a single `train_step` call. The
// underlying kernel contract (fp16 weights, fp32 moments, fp16
// gradients, 1-based step index) is preserved end to end: gradients
// returned by the user's backward closure are fp32 on the host and
// are rounded to fp16 by `AdamW::step` before the kernel is launched.
//
// The fp16 representation is stored as `u16` bit patterns inside
// `Tensor<u16>` to match the storage convention used by the kernels
// in `src/backend/hip_*.rs`.
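The fp32-to-fp16 gradient rounding and the `u16` bit-pattern storage described above can be illustrated with a standalone conversion pair. This sketch assumes IEEE 754 binary16 with round-to-nearest-even; the crate's own `f32_to_fp16_bits` / `fp16_bits_to_f32` may differ in rounding or NaN handling, and the `_sketch` functions below are hypothetical.

```rust
/// Hypothetical f32 -> IEEE 754 binary16 bit pattern, round-to-nearest-even.
pub fn f32_to_fp16_bits_sketch(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xff) as i32;
    let man = bits & 0x007f_ffff;

    if exp == 0xff {
        // Inf stays Inf; any NaN becomes a quiet NaN.
        let nan_bit: u16 = if man != 0 { 0x0200 } else { 0 };
        return sign | 0x7c00 | nan_bit;
    }
    // Re-bias the exponent from f32 (bias 127) to fp16 (bias 15).
    let e = exp - 127 + 15;
    if e >= 0x1f {
        return sign | 0x7c00; // too large: overflow to +/-Inf
    }
    if e <= 0 {
        if e < -10 {
            return sign; // below half the smallest subnormal: +/-0
        }
        // fp16 subnormal: shift the full 24-bit significand into place.
        let m = man | 0x0080_0000;
        let shift = (14 - e) as u32;
        let half = 1u32 << (shift - 1);
        let rem = m & ((1u32 << shift) - 1);
        let mut h = m >> shift;
        if rem > half || (rem == half && (h & 1) == 1) {
            h += 1; // a carry here correctly yields the smallest normal
        }
        return sign | h as u16;
    }
    // Normal: keep the top 10 mantissa bits, round the dropped 13 to even.
    let mut h = ((e as u32) << 10) | (man >> 13);
    let rem = man & 0x1fff;
    if rem > 0x1000 || (rem == 0x1000 && (h & 1) == 1) {
        h += 1; // a carry into the exponent may produce the next binade or Inf
    }
    sign | h as u16
}

/// Hypothetical binary16 bit pattern -> f32 (exact: every fp16 value fits in f32).
pub fn fp16_bits_to_f32_sketch(h: u16) -> f32 {
    let sign = ((h as u32) & 0x8000) << 16;
    let exp = ((h >> 10) & 0x1f) as u32;
    let man = (h & 0x03ff) as u32;
    match (exp, man) {
        (0, 0) => f32::from_bits(sign), // +/-0
        (0, _) => {
            // Subnormal: value = man * 2^-24, exact in f32.
            let v = man as f32 / 16_777_216.0;
            if sign != 0 { -v } else { v }
        }
        (0x1f, _) => f32::from_bits(sign | 0x7f80_0000 | (man << 13)), // Inf or NaN
        _ => f32::from_bits(sign | ((exp + 127 - 15) << 23) | (man << 13)),
    }
}
```

Every non-NaN fp16 bit pattern survives a round trip through f32 unchanged, because each binary16 value is exactly representable in binary32; only the f32-to-fp16 direction rounds.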

pub mod adamw;
pub mod hip_adamw_bridge;
pub mod parameter;
pub mod step;

pub use adamw::AdamW;
pub use parameter::{Parameter, f32_to_fp16_bits, fp16_bits_to_f32};
pub use step::{BackwardOutput, mse_loss_backward, train_step};