use crate::backends::llm_policy::OptimizerKind;
use anyhow::{Context, Result, bail};
use serde_json::json;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::time::Instant;
use wide::f32x8;
use super::kernel;
use super::profiling::{NullProfiler, ProfilerSink};
use super::tensor::Tensor1D;
use super::weights::Weights;
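/// Hyperparameters for the RWKV7 graph. `hidden_size` must equal
/// `num_heads * head_dim`, and the current kernels fix `head_dim` at 64;
/// see [`Config::validate`]. The `*_low_rank` fields are the ranks of the
/// decay/a/v/g LoRA adapters.
///
/// A minimal sketch of building a custom config (hypothetical usage):
///
/// ```ignore
/// let cfg = Config { num_layers: 6, ..Config::default() };
/// cfg.validate()?;
/// ```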
#[derive(Debug, Clone)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_layers: usize,
pub num_heads: usize,
pub head_dim: usize,
pub intermediate_size: usize,
pub layer_norm_eps: f32,
pub group_norm_eps: f32,
pub decay_low_rank: usize,
pub a_low_rank: usize,
pub v_low_rank: usize,
pub g_low_rank: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 256,
hidden_size: 256,
num_layers: 12,
num_heads: 4,
head_dim: 64,
intermediate_size: 1024,
layer_norm_eps: 1e-5,
group_norm_eps: 64e-5,
decay_low_rank: 32,
a_low_rank: 32,
v_low_rank: 32,
g_low_rank: 64,
}
}
}
impl Config {
pub fn validate(&self) -> Result<()> {
if self.vocab_size == 0 {
bail!("rwkv7 vocab_size must be > 0");
}
if self.head_dim != 64 {
bail!("rwkv7 head_dim must be 64 for current kernels");
}
if self.hidden_size != self.num_heads * self.head_dim {
bail!(
"rwkv7 hidden_size must equal num_heads * head_dim ({} != {} * {})",
self.hidden_size,
self.num_heads,
self.head_dim
);
}
if self.num_layers == 0 {
bail!("rwkv7 num_layers must be > 0");
}
if self.intermediate_size == 0 {
bail!("rwkv7 intermediate_size must be > 0");
}
Ok(())
}
}
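/// Per-layer recurrent state: previous-token activations for the attention
/// and FFN token-shift mixes, plus the `num_heads * head_dim * head_dim`
/// WKV state matrix.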
#[derive(Clone)]
pub struct LayerState {
pub att_x_prev: Tensor1D,
pub att_state: Tensor1D,
pub ffn_x_prev: Tensor1D,
}
impl LayerState {
fn new(cfg: &Config) -> Self {
let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
Self {
att_x_prev: Tensor1D::zeros(cfg.hidden_size),
att_state: Tensor1D::zeros(state_size),
ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
}
}
}
#[derive(Clone)]
pub struct State {
pub layers: Vec<LayerState>,
pub v_first: Tensor1D,
pub v_first_set: bool,
}
impl State {
pub fn new(cfg: &Config) -> Self {
Self {
layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
v_first: Tensor1D::zeros(cfg.hidden_size),
v_first_set: false,
}
}
pub fn reset(&mut self) {
self.v_first_set = false;
self.v_first.zero();
for layer in &mut self.layers {
layer.att_x_prev.zero();
layer.att_state.zero();
layer.ffn_x_prev.zero();
}
}
}
#[derive(Clone)]
struct AttentionWeights {
x_r: Tensor1D,
x_w: Tensor1D,
x_k: Tensor1D,
x_v: Tensor1D,
x_a: Tensor1D,
x_g: Tensor1D,
rkv_proj: Tensor1D,
o_proj: Tensor1D,
w1: Tensor1D,
w2: Tensor1D,
w0: Tensor1D,
a1: Tensor1D,
a2: Tensor1D,
a0: Tensor1D,
v1: Option<Tensor1D>,
v2: Option<Tensor1D>,
v0: Option<Tensor1D>,
g1: Tensor1D,
g2: Tensor1D,
k_k: Tensor1D,
k_a: Tensor1D,
r_k: Tensor1D,
g_norm_w: Tensor1D,
g_norm_b: Tensor1D,
}
#[derive(Clone)]
struct FfnWeights {
x_k: Tensor1D,
key_w: Tensor1D,
value_w: Tensor1D,
}
#[derive(Clone)]
struct BlockWeights {
pre_norm_w: Option<Tensor1D>,
pre_norm_b: Option<Tensor1D>,
attn_norm_w: Tensor1D,
attn_norm_b: Tensor1D,
ffn_norm_w: Tensor1D,
ffn_norm_b: Tensor1D,
attn: AttentionWeights,
ffn: FfnWeights,
}
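/// Full RWKV7 parameter set. The r/k/v projections of each block are packed
/// into a single contiguous `rkv_proj` buffer (split back into separate
/// tensors on save), and `lm_head` is stored separately from the embeddings.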
#[derive(Clone)]
pub struct Model {
cfg: Config,
embeddings: Tensor1D,
ln_out_w: Tensor1D,
ln_out_b: Tensor1D,
lm_head: Tensor1D,
blocks: Vec<BlockWeights>,
}
#[derive(Clone)]
struct AdamTensorState {
m: Tensor1D,
v: Tensor1D,
}
impl AdamTensorState {
#[inline]
fn new(len: usize) -> Self {
Self {
m: Tensor1D::zeros(len),
v: Tensor1D::zeros(len),
}
}
}
#[derive(Clone)]
struct AttentionAdamState {
x_r: AdamTensorState,
x_w: AdamTensorState,
x_k: AdamTensorState,
x_v: AdamTensorState,
x_a: AdamTensorState,
x_g: AdamTensorState,
rkv_proj: AdamTensorState,
o_proj: AdamTensorState,
w1: AdamTensorState,
w2: AdamTensorState,
w0: AdamTensorState,
a1: AdamTensorState,
a2: AdamTensorState,
a0: AdamTensorState,
v1: Option<AdamTensorState>,
v2: Option<AdamTensorState>,
v0: Option<AdamTensorState>,
g1: AdamTensorState,
g2: AdamTensorState,
k_k: AdamTensorState,
k_a: AdamTensorState,
r_k: AdamTensorState,
g_norm_w: AdamTensorState,
g_norm_b: AdamTensorState,
}
#[derive(Clone)]
struct FfnAdamState {
x_k: AdamTensorState,
key_w: AdamTensorState,
value_w: AdamTensorState,
}
#[derive(Clone)]
struct BlockAdamState {
pre_norm_w: Option<AdamTensorState>,
pre_norm_b: Option<AdamTensorState>,
attn_norm_w: AdamTensorState,
attn_norm_b: AdamTensorState,
ffn_norm_w: AdamTensorState,
ffn_norm_b: AdamTensorState,
attn: AttentionAdamState,
ffn: FfnAdamState,
}
#[derive(Clone)]
pub struct FullAdamState {
embeddings: AdamTensorState,
ln_out_w: AdamTensorState,
ln_out_b: AdamTensorState,
lm_head: AdamTensorState,
blocks: Vec<BlockAdamState>,
}
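/// Selects which parameter groups receive updates during training. A sketch
/// of training only the output head and bias (hypothetical usage):
///
/// ```ignore
/// let scope = TrainScopeMask { head: true, bias: true, ..Default::default() };
/// assert!(!scope.trains_non_head_params());
/// ```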
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TrainScopeMask {
pub embed: bool,
pub pre_norm: bool,
pub attn_norm: bool,
pub ffn_norm: bool,
pub attn: bool,
pub ffn: bool,
pub head: bool,
pub bias: bool,
}
impl TrainScopeMask {
#[inline]
pub fn all() -> Self {
Self {
embed: true,
pre_norm: true,
attn_norm: true,
ffn_norm: true,
attn: true,
ffn: true,
head: true,
bias: true,
}
}
#[inline]
pub fn trains_non_head_params(&self) -> bool {
self.embed || self.pre_norm || self.attn_norm || self.ffn_norm || self.attn || self.ffn
}
#[inline]
pub fn trains_any_params(&self) -> bool {
self.trains_non_head_params() || self.head || self.bias
}
}
#[derive(Clone)]
struct AttentionGradState {
x_r: Tensor1D,
x_w: Tensor1D,
x_k: Tensor1D,
x_v: Tensor1D,
x_a: Tensor1D,
x_g: Tensor1D,
rkv_proj: Tensor1D,
o_proj: Tensor1D,
w1: Tensor1D,
w2: Tensor1D,
w0: Tensor1D,
a1: Tensor1D,
a2: Tensor1D,
a0: Tensor1D,
v1: Option<Tensor1D>,
v2: Option<Tensor1D>,
v0: Option<Tensor1D>,
g1: Tensor1D,
g2: Tensor1D,
k_k: Tensor1D,
k_a: Tensor1D,
r_k: Tensor1D,
g_norm_w: Tensor1D,
g_norm_b: Tensor1D,
}
#[derive(Clone)]
struct FfnGradState {
x_k: Tensor1D,
key_w: Tensor1D,
value_w: Tensor1D,
}
#[derive(Clone)]
struct BlockGradState {
pre_norm_w: Option<Tensor1D>,
pre_norm_b: Option<Tensor1D>,
attn_norm_w: Tensor1D,
attn_norm_b: Tensor1D,
ffn_norm_w: Tensor1D,
ffn_norm_b: Tensor1D,
attn: AttentionGradState,
ffn: FfnGradState,
}
#[derive(Clone)]
struct FullGradState {
embeddings: Tensor1D,
ln_out_w: Tensor1D,
ln_out_b: Tensor1D,
lm_head: Tensor1D,
blocks: Vec<BlockGradState>,
}
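/// Per-step Adam constants shared by every tensor update in one optimizer
/// step: `bias_corr1 = 1 - b1^t` and `bias_corr2 = 1 - b2^t`. Assuming the
/// standard bias-corrected Adam form, each element is updated as
/// `m = b1*m + (1 - b1)*g`, `v = b2*v + (1 - b2)*g^2`, then
/// `w += lr * (m / bias_corr1) / (sqrt(v / bias_corr2) + eps)`, where the
/// accumulated `g` already points downhill (the logit gradient is stored as
/// `target - p`).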
struct AdamStep {
lr: f32,
clip: f32,
b1: f32,
b2: f32,
eps: f32,
bias_corr1: f32,
bias_corr2: f32,
}
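/// Forward activations of one layer captured for a single token so the
/// backward pass can be replayed without recomputing the forward step.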
#[derive(Clone)]
struct LayerTrainTrace {
x_in: Tensor1D,
x_after_pre: Tensor1D,
attn_norm: Tensor1D,
att_x_prev_old: Tensor1D,
ffn_x_prev_old: Tensor1D,
att_state_old: Tensor1D,
xr: Tensor1D,
xw: Tensor1D,
xk: Tensor1D,
xv: Tensor1D,
xa: Tensor1D,
xg: Tensor1D,
r: Tensor1D,
k_pre: Tensor1D,
k: Tensor1D,
v_pre: Tensor1D,
v: Tensor1D,
nu: Tensor1D,
w_hidden: Tensor1D,
w_pre: Tensor1D,
w_sigmoid: Tensor1D,
w_decay: Tensor1D,
a_hidden: Tensor1D,
a: Tensor1D,
g_hidden: Tensor1D,
g: Tensor1D,
kk_pre: Tensor1D,
kk: Tensor1D,
y_wkv: Tensor1D,
y_gn: Tensor1D,
alpha: Tensor1D,
y_head: Tensor1D,
y_gate: Tensor1D,
att_out: Tensor1D,
x_after_attn: Tensor1D,
ffn_norm: Tensor1D,
ffn_xk: Tensor1D,
ffn_pre: Tensor1D,
ffn_k: Tensor1D,
ffn_out: Tensor1D,
x_out: Tensor1D,
v_hidden: Tensor1D,
uses_v_residual: bool,
}
impl LayerTrainTrace {
fn new(cfg: &Config) -> Self {
let c = cfg.hidden_size;
let i = cfg.intermediate_size;
let state = cfg.num_heads * cfg.head_dim * cfg.head_dim;
Self {
x_in: Tensor1D::zeros(c),
x_after_pre: Tensor1D::zeros(c),
attn_norm: Tensor1D::zeros(c),
att_x_prev_old: Tensor1D::zeros(c),
ffn_x_prev_old: Tensor1D::zeros(c),
att_state_old: Tensor1D::zeros(state),
xr: Tensor1D::zeros(c),
xw: Tensor1D::zeros(c),
xk: Tensor1D::zeros(c),
xv: Tensor1D::zeros(c),
xa: Tensor1D::zeros(c),
xg: Tensor1D::zeros(c),
r: Tensor1D::zeros(c),
k_pre: Tensor1D::zeros(c),
k: Tensor1D::zeros(c),
v_pre: Tensor1D::zeros(c),
v: Tensor1D::zeros(c),
nu: Tensor1D::zeros(c),
w_hidden: Tensor1D::zeros(cfg.decay_low_rank),
w_pre: Tensor1D::zeros(c),
w_sigmoid: Tensor1D::zeros(c),
w_decay: Tensor1D::zeros(c),
a_hidden: Tensor1D::zeros(cfg.a_low_rank),
a: Tensor1D::zeros(c),
g_hidden: Tensor1D::zeros(cfg.g_low_rank),
g: Tensor1D::zeros(c),
kk_pre: Tensor1D::zeros(c),
kk: Tensor1D::zeros(c),
y_wkv: Tensor1D::zeros(c),
y_gn: Tensor1D::zeros(c),
alpha: Tensor1D::zeros(cfg.num_heads),
y_head: Tensor1D::zeros(c),
y_gate: Tensor1D::zeros(c),
att_out: Tensor1D::zeros(c),
x_after_attn: Tensor1D::zeros(c),
ffn_norm: Tensor1D::zeros(c),
ffn_xk: Tensor1D::zeros(c),
ffn_pre: Tensor1D::zeros(i),
ffn_k: Tensor1D::zeros(i),
ffn_out: Tensor1D::zeros(c),
x_out: Tensor1D::zeros(c),
v_hidden: Tensor1D::zeros(cfg.v_low_rank.max(1)),
uses_v_residual: false,
}
}
}
#[derive(Clone)]
struct TokenTrainTrace {
token: usize,
x: Tensor1D,
x_normed: Tensor1D,
v_first: Tensor1D,
layers: Vec<LayerTrainTrace>,
}
impl TokenTrainTrace {
fn from_scratch(scratch: &ScratchBuffers) -> Self {
Self {
token: scratch.train_token,
x: scratch.x.clone(),
x_normed: scratch.x_normed.clone(),
v_first: scratch.train_v_first.clone(),
layers: scratch.train_trace_layers.clone(),
}
}
}
#[derive(Clone)]
struct LayerRecurrentGradState {
att_x_prev: Tensor1D,
att_state: Tensor1D,
ffn_x_prev: Tensor1D,
}
impl LayerRecurrentGradState {
fn new(cfg: &Config) -> Self {
let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
Self {
att_x_prev: Tensor1D::zeros(cfg.hidden_size),
att_state: Tensor1D::zeros(state_size),
ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
}
}
}
#[derive(Clone)]
struct RecurrentGradState {
layers: Vec<LayerRecurrentGradState>,
}
impl RecurrentGradState {
fn new(cfg: &Config) -> Self {
Self {
layers: (0..cfg.num_layers)
.map(|_| LayerRecurrentGradState::new(cfg))
.collect(),
}
}
fn zero(&mut self) {
for layer in &mut self.layers {
layer.att_x_prev.zero();
layer.att_state.zero();
layer.ffn_x_prev.zero();
}
}
}
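/// Working buffers sized once from the config so the per-token forward and
/// backward passes run without allocating.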
#[derive(Clone)]
pub struct ScratchBuffers {
// Forward-pass activations.
x: Tensor1D,
x_normed: Tensor1D,
xr: Tensor1D,
xw: Tensor1D,
xk: Tensor1D,
xv: Tensor1D,
xa: Tensor1D,
xg: Tensor1D,
r: Tensor1D,
k: Tensor1D,
v: Tensor1D,
w_lora_tmp: Tensor1D,
w_decay: Tensor1D,
a: Tensor1D,
g: Tensor1D,
kk: Tensor1D,
y: Tensor1D,
att_out: Tensor1D,
ffn_k: Tensor1D,
ffn_out: Tensor1D,
logits: Tensor1D,
// Backward-pass accumulators.
grad_x: Tensor1D,
grad_x2: Tensor1D,
grad_x3: Tensor1D,
grad_x4: Tensor1D,
grad_x5: Tensor1D,
grad_x6: Tensor1D,
grad_v_first: Tensor1D,
grad_param: Tensor1D,
grad_param2: Tensor1D,
grad_saved: Tensor1D,
grad_ffn: Tensor1D,
grad_ffn2: Tensor1D,
grad_low_rank: Tensor1D,
grad_low_rank2: Tensor1D,
grad_att_state: Tensor1D,
grad_logits: Tensor1D,
train_trace_layers: Vec<LayerTrainTrace>,
train_token: usize,
train_v_first: Tensor1D,
train_trace_valid: bool,
capture_train_trace: bool,
}
impl ScratchBuffers {
pub fn new(cfg: &Config) -> Self {
let c = cfg.hidden_size;
let i = cfg.intermediate_size;
let v = cfg.vocab_size;
let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
let d_rank = cfg
.decay_low_rank
.max(cfg.a_low_rank)
.max(cfg.v_low_rank)
.max(cfg.g_low_rank)
.max(64);
let mut train_trace_layers = Vec::with_capacity(cfg.num_layers);
for _ in 0..cfg.num_layers {
train_trace_layers.push(LayerTrainTrace::new(cfg));
}
Self {
x: Tensor1D::zeros(c),
x_normed: Tensor1D::zeros(c),
xr: Tensor1D::zeros(c),
xw: Tensor1D::zeros(c),
xk: Tensor1D::zeros(c),
xv: Tensor1D::zeros(c),
xa: Tensor1D::zeros(c),
xg: Tensor1D::zeros(c),
r: Tensor1D::zeros(c),
k: Tensor1D::zeros(c),
v: Tensor1D::zeros(c),
w_lora_tmp: Tensor1D::zeros(d_rank),
w_decay: Tensor1D::zeros(c),
a: Tensor1D::zeros(c),
g: Tensor1D::zeros(c),
kk: Tensor1D::zeros(c),
y: Tensor1D::zeros(c),
att_out: Tensor1D::zeros(c),
ffn_k: Tensor1D::zeros(i),
ffn_out: Tensor1D::zeros(c),
logits: Tensor1D::zeros(v),
grad_x: Tensor1D::zeros(c),
grad_x2: Tensor1D::zeros(c),
grad_x3: Tensor1D::zeros(c),
grad_x4: Tensor1D::zeros(c),
grad_x5: Tensor1D::zeros(c),
grad_x6: Tensor1D::zeros(c),
grad_v_first: Tensor1D::zeros(c),
grad_param: Tensor1D::zeros(c),
grad_param2: Tensor1D::zeros(c),
grad_saved: Tensor1D::zeros(c),
grad_ffn: Tensor1D::zeros(i),
grad_ffn2: Tensor1D::zeros(i),
grad_low_rank: Tensor1D::zeros(d_rank),
grad_low_rank2: Tensor1D::zeros(d_rank),
grad_att_state: Tensor1D::zeros(state_size),
grad_logits: Tensor1D::zeros(v),
train_trace_layers,
train_token: 0,
train_v_first: Tensor1D::zeros(c),
train_trace_valid: false,
capture_train_trace: false,
}
}
#[inline]
pub fn lm_head_input(&self) -> &[f32] {
self.x_normed.as_slice()
}
#[inline]
pub fn logits(&self) -> &[f32] {
self.logits.as_slice()
}
#[inline]
pub fn set_lm_head_input(&mut self, value: &[f32]) {
self.x_normed.as_mut_slice().copy_from_slice(value);
}
#[inline]
pub fn set_capture_train_trace(&mut self, enabled: bool) {
self.capture_train_trace = enabled;
if !enabled {
self.train_trace_valid = false;
}
}
#[inline]
pub fn has_train_trace(&self) -> bool {
self.train_trace_valid
}
}
impl Model {
fn tensor_from(weights: &Weights, name: &str) -> Result<Tensor1D> {
Ok(Tensor1D::from_vec(weights.require(name)?.data().to_vec()))
}
fn optional_tensor_from(weights: &Weights, name: &str) -> Option<Tensor1D> {
weights
.get(name)
.map(|tensor| Tensor1D::from_vec(tensor.data().to_vec()))
}
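/// Loads a model from a safetensors checkpoint, inferring the config from
/// tensor shapes: vocab and hidden size from the embedding matrix, the layer
/// count by probing `model.layers.{i}.attn.r_proj.weight`, and the LoRA
/// ranks from the first layer that carries each adapter.
///
/// Hypothetical usage:
///
/// ```ignore
/// let model = Model::load("rwkv7.safetensors")?;
/// let mut state = model.new_state();
/// ```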
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let weights = Weights::load(path.as_ref()).context("Failed to load model weights")?;
let emb = weights.require("model.embeddings.weight")?;
let vocab_size = emb.shape()[0];
let hidden_size = emb.shape()[1];
// Current kernels fix head_dim at 64 (see Config::validate).
let num_heads = hidden_size / 64;
let head_dim = 64;
let mut num_layers = 0;
while weights
.get(&format!("model.layers.{}.attn.r_proj.weight", num_layers))
.is_some()
{
num_layers += 1;
}
let ffn_key = weights.require("model.layers.0.ffn.key.weight")?;
let intermediate_size = ffn_key.shape()[0];
let w1 = weights.require("model.layers.0.attn.w_lora.lora.0.weight")?;
let decay_low_rank = w1.shape()[0];
let a1 = weights.require("model.layers.0.attn.a_lora.lora.0.weight")?;
let a_low_rank = a1.shape()[0];
let g1 = weights.require("model.layers.0.attn.g_lora.lora.0.weight")?;
let g_low_rank = g1.shape()[0];
// Layer 0 has no v-residual LoRA; probe layer 1 and fall back to 32 if absent.
let v_low_rank = if num_layers > 1 {
if let Some(v1) = weights.get("model.layers.1.attn.v_lora.lora.0.weight") {
v1.shape()[0]
} else {
32
}
} else {
32
};
let cfg = Config {
vocab_size,
hidden_size,
num_layers,
num_heads,
head_dim,
intermediate_size,
layer_norm_eps: 1e-5,
group_norm_eps: 64e-5,
decay_low_rank,
a_low_rank,
v_low_rank,
g_low_rank,
};
let embeddings = Self::tensor_from(&weights, "model.embeddings.weight")?;
let ln_out_w = Self::tensor_from(&weights, "model.norm.weight")?;
let ln_out_b = Self::tensor_from(&weights, "model.norm.bias")?;
let lm_head = Self::tensor_from(&weights, "lm_head.weight")?;
let mut blocks = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let prefix = format!("model.layers.{}", i);
let (pre_norm_w, pre_norm_b) = if i == 0 {
(
Some(Self::tensor_from(
&weights,
&format!("{}.pre_norm.weight", prefix),
)?),
Some(Self::tensor_from(
&weights,
&format!("{}.pre_norm.bias", prefix),
)?),
)
} else {
(None, None)
};
let attn_norm_w = Self::tensor_from(&weights, &format!("{}.attn_norm.weight", prefix))?;
let attn_norm_b = Self::tensor_from(&weights, &format!("{}.attn_norm.bias", prefix))?;
let ffn_norm_w = Self::tensor_from(&weights, &format!("{}.ffn_norm.weight", prefix))?;
let ffn_norm_b = Self::tensor_from(&weights, &format!("{}.ffn_norm.bias", prefix))?;
let r_proj_data = weights
.require(&format!("{}.attn.r_proj.weight", prefix))?
.data();
let k_proj_data = weights
.require(&format!("{}.attn.k_proj.weight", prefix))?
.data();
let v_proj_data = weights
.require(&format!("{}.attn.v_proj.weight", prefix))?
.data();
let proj_size = hidden_size * hidden_size;
let mut rkv_proj = Tensor1D::zeros(3 * proj_size);
rkv_proj.as_mut_slice()[0..proj_size].copy_from_slice(r_proj_data);
rkv_proj.as_mut_slice()[proj_size..2 * proj_size].copy_from_slice(k_proj_data);
rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size].copy_from_slice(v_proj_data);
let attn = AttentionWeights {
x_r: Self::tensor_from(&weights, &format!("{}.attn.x_r", prefix))?,
x_w: Self::tensor_from(&weights, &format!("{}.attn.x_w", prefix))?,
x_k: Self::tensor_from(&weights, &format!("{}.attn.x_k", prefix))?,
x_v: Self::tensor_from(&weights, &format!("{}.attn.x_v", prefix))?,
x_a: Self::tensor_from(&weights, &format!("{}.attn.x_a", prefix))?,
x_g: Self::tensor_from(&weights, &format!("{}.attn.x_g", prefix))?,
rkv_proj,
o_proj: Self::tensor_from(&weights, &format!("{}.attn.o_proj.weight", prefix))?,
w1: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.0.weight", prefix))?,
w2: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.weight", prefix))?,
w0: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.bias", prefix))?,
a1: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.0.weight", prefix))?,
a2: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.weight", prefix))?,
a0: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.bias", prefix))?,
v1: Self::optional_tensor_from(
&weights,
&format!("{}.attn.v_lora.lora.0.weight", prefix),
),
v2: Self::optional_tensor_from(
&weights,
&format!("{}.attn.v_lora.lora.2.weight", prefix),
),
v0: Self::optional_tensor_from(
&weights,
&format!("{}.attn.v_lora.lora.2.bias", prefix),
),
g1: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.0.weight", prefix))?,
g2: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.2.weight", prefix))?,
k_k: Self::tensor_from(&weights, &format!("{}.attn.k_k", prefix))?,
k_a: Self::tensor_from(&weights, &format!("{}.attn.k_a", prefix))?,
r_k: Self::tensor_from(&weights, &format!("{}.attn.r_k", prefix))?,
g_norm_w: Self::tensor_from(&weights, &format!("{}.attn.g_norm.weight", prefix))?,
g_norm_b: Self::tensor_from(&weights, &format!("{}.attn.g_norm.bias", prefix))?,
};
let ffn = FfnWeights {
x_k: Self::tensor_from(&weights, &format!("{}.ffn.x_k", prefix))?,
key_w: Self::tensor_from(&weights, &format!("{}.ffn.key.weight", prefix))?,
value_w: Self::tensor_from(&weights, &format!("{}.ffn.value.weight", prefix))?,
};
blocks.push(BlockWeights {
pre_norm_w,
pre_norm_b,
attn_norm_w,
attn_norm_b,
ffn_norm_w,
ffn_norm_b,
attn,
ffn,
});
}
Ok(Self {
cfg,
embeddings,
ln_out_w,
ln_out_b,
lm_head,
blocks,
})
}
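/// Builds a randomly initialized model for training from scratch: matrices
/// are drawn uniformly with scale 0.02, norm weights start at 1 with zero
/// bias, token-shift mixes are centered at 0.5, and layer 0 omits the
/// v-residual LoRA (its `v1`/`v2`/`v0` are `None`).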
pub fn new_random(cfg: Config, seed: u64) -> Result<Self> {
cfg.validate()?;
let mut rng = RwkvRng::new(seed);
let c = cfg.hidden_size;
let v = cfg.vocab_size;
let i = cfg.intermediate_size;
let d_w = cfg.decay_low_rank;
let d_a = cfg.a_low_rank;
let d_v = cfg.v_low_rank;
let d_g = cfg.g_low_rank;
let mut embeddings = Tensor1D::zeros(v * c);
init_uniform(&mut embeddings, &mut rng, 0.02);
let mut ln_out_w = Tensor1D::zeros(c);
let mut ln_out_b = Tensor1D::zeros(c);
init_const(&mut ln_out_w, 1.0);
init_const(&mut ln_out_b, 0.0);
let mut lm_head = Tensor1D::zeros(v * c);
init_uniform(&mut lm_head, &mut rng, 0.02);
let mut blocks = Vec::with_capacity(cfg.num_layers);
for layer_idx in 0..cfg.num_layers {
let (pre_norm_w, pre_norm_b) = if layer_idx == 0 {
let mut w = Tensor1D::zeros(c);
let mut b = Tensor1D::zeros(c);
init_const(&mut w, 1.0);
init_const(&mut b, 0.0);
(Some(w), Some(b))
} else {
(None, None)
};
let mut attn_norm_w = Tensor1D::zeros(c);
let mut attn_norm_b = Tensor1D::zeros(c);
init_const(&mut attn_norm_w, 1.0);
init_const(&mut attn_norm_b, 0.0);
let mut ffn_norm_w = Tensor1D::zeros(c);
let mut ffn_norm_b = Tensor1D::zeros(c);
init_const(&mut ffn_norm_w, 1.0);
init_const(&mut ffn_norm_b, 0.0);
let mut rkv_proj = Tensor1D::zeros(3 * c * c);
init_uniform(&mut rkv_proj, &mut rng, 0.02);
let mut o_proj = Tensor1D::zeros(c * c);
init_uniform(&mut o_proj, &mut rng, 0.02);
let mut w1 = Tensor1D::zeros(d_w * c);
let mut w2 = Tensor1D::zeros(c * d_w);
let mut w0 = Tensor1D::zeros(c);
init_uniform(&mut w1, &mut rng, 0.02);
init_uniform(&mut w2, &mut rng, 0.02);
init_const(&mut w0, 0.0);
let mut a1 = Tensor1D::zeros(d_a * c);
let mut a2 = Tensor1D::zeros(c * d_a);
let mut a0 = Tensor1D::zeros(c);
init_uniform(&mut a1, &mut rng, 0.02);
init_uniform(&mut a2, &mut rng, 0.02);
init_const(&mut a0, 0.0);
let (v1, v2, v0) = if layer_idx == 0 {
(None, None, None)
} else {
let mut v1 = Tensor1D::zeros(d_v * c);
let mut v2 = Tensor1D::zeros(c * d_v);
let mut v0 = Tensor1D::zeros(c);
init_uniform(&mut v1, &mut rng, 0.02);
init_uniform(&mut v2, &mut rng, 0.02);
init_const(&mut v0, 0.0);
(Some(v1), Some(v2), Some(v0))
};
let mut g1 = Tensor1D::zeros(d_g * c);
let mut g2 = Tensor1D::zeros(c * d_g);
init_uniform(&mut g1, &mut rng, 0.02);
init_uniform(&mut g2, &mut rng, 0.02);
let mut x_r = Tensor1D::zeros(c);
let mut x_w = Tensor1D::zeros(c);
let mut x_k = Tensor1D::zeros(c);
let mut x_v = Tensor1D::zeros(c);
let mut x_a = Tensor1D::zeros(c);
let mut x_g = Tensor1D::zeros(c);
init_centered(&mut x_r, &mut rng, 0.5, 0.02);
init_centered(&mut x_w, &mut rng, 0.5, 0.02);
init_centered(&mut x_k, &mut rng, 0.5, 0.02);
init_centered(&mut x_v, &mut rng, 0.5, 0.02);
init_centered(&mut x_a, &mut rng, 0.5, 0.02);
init_centered(&mut x_g, &mut rng, 0.5, 0.02);
let mut k_k = Tensor1D::zeros(c);
let mut k_a = Tensor1D::zeros(c);
let mut r_k = Tensor1D::zeros(c);
init_const(&mut k_k, 1.0);
init_const(&mut k_a, 1.0);
init_const(&mut r_k, 1.0);
let mut g_norm_w = Tensor1D::zeros(c);
let mut g_norm_b = Tensor1D::zeros(c);
init_const(&mut g_norm_w, 1.0);
init_const(&mut g_norm_b, 0.0);
let attn = AttentionWeights {
x_r,
x_w,
x_k,
x_v,
x_a,
x_g,
rkv_proj,
o_proj,
w1,
w2,
w0,
a1,
a2,
a0,
v1,
v2,
v0,
g1,
g2,
k_k,
k_a,
r_k,
g_norm_w,
g_norm_b,
};
let mut ffn_x_k = Tensor1D::zeros(c);
init_centered(&mut ffn_x_k, &mut rng, 0.5, 0.02);
let mut key_w = Tensor1D::zeros(i * c);
let mut value_w = Tensor1D::zeros(c * i);
init_uniform(&mut key_w, &mut rng, 0.02);
init_uniform(&mut value_w, &mut rng, 0.02);
let ffn = FfnWeights {
x_k: ffn_x_k,
key_w,
value_w,
};
blocks.push(BlockWeights {
pre_norm_w,
pre_norm_b,
attn_norm_w,
attn_norm_b,
ffn_norm_w,
ffn_norm_b,
attn,
ffn,
});
}
Ok(Self {
cfg,
embeddings,
ln_out_w,
ln_out_b,
lm_head,
blocks,
})
}
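/// Serializes the weights in safetensors layout: an 8-byte little-endian
/// header length, a JSON header mapping tensor names to dtype, shape, and
/// byte offsets, then the raw F32 payload in name-sorted order. The packed
/// `rkv_proj` is split back into `r_proj`/`k_proj`/`v_proj` so the file
/// round-trips through [`Model::load`].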
pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<()> {
#[derive(Clone)]
struct TensorRec {
name: String,
shape: Vec<usize>,
data: Vec<f32>,
}
let c = self.cfg.hidden_size;
let v = self.cfg.vocab_size;
let i = self.cfg.intermediate_size;
let d_w = self.cfg.decay_low_rank;
let d_a = self.cfg.a_low_rank;
let d_v = self.cfg.v_low_rank;
let d_g = self.cfg.g_low_rank;
let mut recs = Vec::<TensorRec>::new();
let push = |recs: &mut Vec<TensorRec>, name: String, shape: Vec<usize>, src: &Tensor1D| {
recs.push(TensorRec {
name,
shape,
data: src.as_slice().to_vec(),
});
};
push(
&mut recs,
"model.embeddings.weight".to_string(),
vec![v, c],
&self.embeddings,
);
push(
&mut recs,
"model.norm.weight".to_string(),
vec![c],
&self.ln_out_w,
);
push(
&mut recs,
"model.norm.bias".to_string(),
vec![c],
&self.ln_out_b,
);
push(
&mut recs,
"lm_head.weight".to_string(),
vec![v, c],
&self.lm_head,
);
for (idx, b) in self.blocks.iter().enumerate() {
let pfx = format!("model.layers.{idx}");
if let (Some(w), Some(bias)) = (&b.pre_norm_w, &b.pre_norm_b) {
push(&mut recs, format!("{pfx}.pre_norm.weight"), vec![c], w);
push(&mut recs, format!("{pfx}.pre_norm.bias"), vec![c], bias);
}
push(
&mut recs,
format!("{pfx}.attn_norm.weight"),
vec![c],
&b.attn_norm_w,
);
push(
&mut recs,
format!("{pfx}.attn_norm.bias"),
vec![c],
&b.attn_norm_b,
);
push(
&mut recs,
format!("{pfx}.ffn_norm.weight"),
vec![c],
&b.ffn_norm_w,
);
push(
&mut recs,
format!("{pfx}.ffn_norm.bias"),
vec![c],
&b.ffn_norm_b,
);
let proj = b.attn.rkv_proj.as_slice();
let proj_size = c * c;
recs.push(TensorRec {
name: format!("{pfx}.attn.r_proj.weight"),
shape: vec![c, c],
data: proj[0..proj_size].to_vec(),
});
recs.push(TensorRec {
name: format!("{pfx}.attn.k_proj.weight"),
shape: vec![c, c],
data: proj[proj_size..2 * proj_size].to_vec(),
});
recs.push(TensorRec {
name: format!("{pfx}.attn.v_proj.weight"),
shape: vec![c, c],
data: proj[2 * proj_size..3 * proj_size].to_vec(),
});
push(
&mut recs,
format!("{pfx}.attn.o_proj.weight"),
vec![c, c],
&b.attn.o_proj,
);
push(&mut recs, format!("{pfx}.attn.x_r"), vec![c], &b.attn.x_r);
push(&mut recs, format!("{pfx}.attn.x_w"), vec![c], &b.attn.x_w);
push(&mut recs, format!("{pfx}.attn.x_k"), vec![c], &b.attn.x_k);
push(&mut recs, format!("{pfx}.attn.x_v"), vec![c], &b.attn.x_v);
push(&mut recs, format!("{pfx}.attn.x_a"), vec![c], &b.attn.x_a);
push(&mut recs, format!("{pfx}.attn.x_g"), vec![c], &b.attn.x_g);
push(
&mut recs,
format!("{pfx}.attn.w_lora.lora.0.weight"),
vec![d_w, c],
&b.attn.w1,
);
push(
&mut recs,
format!("{pfx}.attn.w_lora.lora.2.weight"),
vec![c, d_w],
&b.attn.w2,
);
push(
&mut recs,
format!("{pfx}.attn.w_lora.lora.2.bias"),
vec![c],
&b.attn.w0,
);
push(
&mut recs,
format!("{pfx}.attn.a_lora.lora.0.weight"),
vec![d_a, c],
&b.attn.a1,
);
push(
&mut recs,
format!("{pfx}.attn.a_lora.lora.2.weight"),
vec![c, d_a],
&b.attn.a2,
);
push(
&mut recs,
format!("{pfx}.attn.a_lora.lora.2.bias"),
vec![c],
&b.attn.a0,
);
if let Some(v1) = &b.attn.v1 {
push(
&mut recs,
format!("{pfx}.attn.v_lora.lora.0.weight"),
vec![d_v, c],
v1,
);
}
if let Some(v2) = &b.attn.v2 {
push(
&mut recs,
format!("{pfx}.attn.v_lora.lora.2.weight"),
vec![c, d_v],
v2,
);
}
if let Some(v0) = &b.attn.v0 {
push(
&mut recs,
format!("{pfx}.attn.v_lora.lora.2.bias"),
vec![c],
v0,
);
}
push(
&mut recs,
format!("{pfx}.attn.g_lora.lora.0.weight"),
vec![d_g, c],
&b.attn.g1,
);
push(
&mut recs,
format!("{pfx}.attn.g_lora.lora.2.weight"),
vec![c, d_g],
&b.attn.g2,
);
push(&mut recs, format!("{pfx}.attn.k_k"), vec![c], &b.attn.k_k);
push(&mut recs, format!("{pfx}.attn.k_a"), vec![c], &b.attn.k_a);
push(&mut recs, format!("{pfx}.attn.r_k"), vec![c], &b.attn.r_k);
push(
&mut recs,
format!("{pfx}.attn.g_norm.weight"),
vec![c],
&b.attn.g_norm_w,
);
push(
&mut recs,
format!("{pfx}.attn.g_norm.bias"),
vec![c],
&b.attn.g_norm_b,
);
push(&mut recs, format!("{pfx}.ffn.x_k"), vec![c], &b.ffn.x_k);
push(
&mut recs,
format!("{pfx}.ffn.key.weight"),
vec![i, c],
&b.ffn.key_w,
);
push(
&mut recs,
format!("{pfx}.ffn.value.weight"),
vec![c, i],
&b.ffn.value_w,
);
}
recs.sort_by(|a, b| a.name.cmp(&b.name));
let mut offset = 0usize;
let mut header = serde_json::Map::new();
header.insert("__metadata__".to_string(), json!({}));
for rec in &recs {
let bytes = rec.data.len() * 4;
header.insert(
rec.name.clone(),
json!({
"dtype": "F32",
"shape": rec.shape,
"data_offsets": [offset, offset + bytes]
}),
);
offset += bytes;
}
let header_bytes = serde_json::to_vec(&header)?;
let mut f = File::create(path.as_ref())?;
f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
f.write_all(&header_bytes)?;
for rec in &recs {
for v in &rec.data {
f.write_all(&v.to_le_bytes())?;
}
}
Ok(())
}
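/// Allocates zeroed first- and second-moment buffers matching every
/// trainable tensor in the model, including the optional pre-norm and
/// v-residual LoRA tensors where present.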
pub fn new_full_adam_state(&self) -> FullAdamState {
let mut blocks = Vec::with_capacity(self.blocks.len());
for b in &self.blocks {
blocks.push(BlockAdamState {
pre_norm_w: b.pre_norm_w.as_ref().map(|t| AdamTensorState::new(t.len())),
pre_norm_b: b.pre_norm_b.as_ref().map(|t| AdamTensorState::new(t.len())),
attn_norm_w: AdamTensorState::new(b.attn_norm_w.len()),
attn_norm_b: AdamTensorState::new(b.attn_norm_b.len()),
ffn_norm_w: AdamTensorState::new(b.ffn_norm_w.len()),
ffn_norm_b: AdamTensorState::new(b.ffn_norm_b.len()),
attn: AttentionAdamState {
x_r: AdamTensorState::new(b.attn.x_r.len()),
x_w: AdamTensorState::new(b.attn.x_w.len()),
x_k: AdamTensorState::new(b.attn.x_k.len()),
x_v: AdamTensorState::new(b.attn.x_v.len()),
x_a: AdamTensorState::new(b.attn.x_a.len()),
x_g: AdamTensorState::new(b.attn.x_g.len()),
rkv_proj: AdamTensorState::new(b.attn.rkv_proj.len()),
o_proj: AdamTensorState::new(b.attn.o_proj.len()),
w1: AdamTensorState::new(b.attn.w1.len()),
w2: AdamTensorState::new(b.attn.w2.len()),
w0: AdamTensorState::new(b.attn.w0.len()),
a1: AdamTensorState::new(b.attn.a1.len()),
a2: AdamTensorState::new(b.attn.a2.len()),
a0: AdamTensorState::new(b.attn.a0.len()),
v1: b.attn.v1.as_ref().map(|t| AdamTensorState::new(t.len())),
v2: b.attn.v2.as_ref().map(|t| AdamTensorState::new(t.len())),
v0: b.attn.v0.as_ref().map(|t| AdamTensorState::new(t.len())),
g1: AdamTensorState::new(b.attn.g1.len()),
g2: AdamTensorState::new(b.attn.g2.len()),
k_k: AdamTensorState::new(b.attn.k_k.len()),
k_a: AdamTensorState::new(b.attn.k_a.len()),
r_k: AdamTensorState::new(b.attn.r_k.len()),
g_norm_w: AdamTensorState::new(b.attn.g_norm_w.len()),
g_norm_b: AdamTensorState::new(b.attn.g_norm_b.len()),
},
ffn: FfnAdamState {
x_k: AdamTensorState::new(b.ffn.x_k.len()),
key_w: AdamTensorState::new(b.ffn.key_w.len()),
value_w: AdamTensorState::new(b.ffn.value_w.len()),
},
});
}
FullAdamState {
embeddings: AdamTensorState::new(self.embeddings.len()),
ln_out_w: AdamTensorState::new(self.ln_out_w.len()),
ln_out_b: AdamTensorState::new(self.ln_out_b.len()),
lm_head: AdamTensorState::new(self.lm_head.len()),
blocks,
}
}
fn new_full_grad_state(&self) -> FullGradState {
let mut blocks = Vec::with_capacity(self.blocks.len());
for b in &self.blocks {
blocks.push(BlockGradState {
pre_norm_w: b.pre_norm_w.as_ref().map(|t| Tensor1D::zeros(t.len())),
pre_norm_b: b.pre_norm_b.as_ref().map(|t| Tensor1D::zeros(t.len())),
attn_norm_w: Tensor1D::zeros(b.attn_norm_w.len()),
attn_norm_b: Tensor1D::zeros(b.attn_norm_b.len()),
ffn_norm_w: Tensor1D::zeros(b.ffn_norm_w.len()),
ffn_norm_b: Tensor1D::zeros(b.ffn_norm_b.len()),
attn: AttentionGradState {
x_r: Tensor1D::zeros(b.attn.x_r.len()),
x_w: Tensor1D::zeros(b.attn.x_w.len()),
x_k: Tensor1D::zeros(b.attn.x_k.len()),
x_v: Tensor1D::zeros(b.attn.x_v.len()),
x_a: Tensor1D::zeros(b.attn.x_a.len()),
x_g: Tensor1D::zeros(b.attn.x_g.len()),
rkv_proj: Tensor1D::zeros(b.attn.rkv_proj.len()),
o_proj: Tensor1D::zeros(b.attn.o_proj.len()),
w1: Tensor1D::zeros(b.attn.w1.len()),
w2: Tensor1D::zeros(b.attn.w2.len()),
w0: Tensor1D::zeros(b.attn.w0.len()),
a1: Tensor1D::zeros(b.attn.a1.len()),
a2: Tensor1D::zeros(b.attn.a2.len()),
a0: Tensor1D::zeros(b.attn.a0.len()),
v1: b.attn.v1.as_ref().map(|t| Tensor1D::zeros(t.len())),
v2: b.attn.v2.as_ref().map(|t| Tensor1D::zeros(t.len())),
v0: b.attn.v0.as_ref().map(|t| Tensor1D::zeros(t.len())),
g1: Tensor1D::zeros(b.attn.g1.len()),
g2: Tensor1D::zeros(b.attn.g2.len()),
k_k: Tensor1D::zeros(b.attn.k_k.len()),
k_a: Tensor1D::zeros(b.attn.k_a.len()),
r_k: Tensor1D::zeros(b.attn.r_k.len()),
g_norm_w: Tensor1D::zeros(b.attn.g_norm_w.len()),
g_norm_b: Tensor1D::zeros(b.attn.g_norm_b.len()),
},
ffn: FfnGradState {
x_k: Tensor1D::zeros(b.ffn.x_k.len()),
key_w: Tensor1D::zeros(b.ffn.key_w.len()),
value_w: Tensor1D::zeros(b.ffn.value_w.len()),
},
});
}
FullGradState {
embeddings: Tensor1D::zeros(self.embeddings.len()),
ln_out_w: Tensor1D::zeros(self.ln_out_w.len()),
ln_out_b: Tensor1D::zeros(self.ln_out_b.len()),
lm_head: Tensor1D::zeros(self.lm_head.len()),
blocks,
}
}
fn new_recurrent_grad_state(&self) -> RecurrentGradState {
RecurrentGradState::new(&self.cfg)
}
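/// Saves the optimizer moments next to their shapes: each tensor `name` is
/// written as `name.m` and `name.v` under an `opt.` prefix, mirroring the
/// weight-file layout of [`Model::save_safetensors`].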
pub fn save_full_adam_safetensors<P: AsRef<Path>>(
&self,
adam: &FullAdamState,
path: P,
) -> Result<()> {
#[derive(Clone)]
struct TensorRec {
name: String,
shape: Vec<usize>,
data: Vec<f32>,
}
let c = self.cfg.hidden_size;
let i = self.cfg.intermediate_size;
let v = self.cfg.vocab_size;
let h = self.cfg.num_heads;
let n = self.cfg.head_dim;
let d_w = self.cfg.decay_low_rank;
let d_a = self.cfg.a_low_rank;
let d_v = self.cfg.v_low_rank;
let d_g = self.cfg.g_low_rank;
let mut recs = Vec::<TensorRec>::new();
let mut push_state = |name: &str, shape: Vec<usize>, st: &AdamTensorState| {
recs.push(TensorRec {
name: format!("{name}.m"),
shape: shape.clone(),
data: st.m.as_slice().to_vec(),
});
recs.push(TensorRec {
name: format!("{name}.v"),
shape,
data: st.v.as_slice().to_vec(),
});
};
push_state("opt.model.embeddings.weight", vec![v, c], &adam.embeddings);
push_state("opt.model.norm.weight", vec![c], &adam.ln_out_w);
push_state("opt.model.norm.bias", vec![c], &adam.ln_out_b);
push_state("opt.lm_head.weight", vec![v, c], &adam.lm_head);
for (idx, b) in adam.blocks.iter().enumerate() {
let p = format!("opt.model.layers.{idx}");
if let Some(st) = &b.pre_norm_w {
push_state(&format!("{p}.pre_norm.weight"), vec![c], st);
}
if let Some(st) = &b.pre_norm_b {
push_state(&format!("{p}.pre_norm.bias"), vec![c], st);
}
push_state(&format!("{p}.attn_norm.weight"), vec![c], &b.attn_norm_w);
push_state(&format!("{p}.attn_norm.bias"), vec![c], &b.attn_norm_b);
push_state(&format!("{p}.ffn_norm.weight"), vec![c], &b.ffn_norm_w);
push_state(&format!("{p}.ffn_norm.bias"), vec![c], &b.ffn_norm_b);
push_state(&format!("{p}.attn.x_r"), vec![c], &b.attn.x_r);
push_state(&format!("{p}.attn.x_w"), vec![c], &b.attn.x_w);
push_state(&format!("{p}.attn.x_k"), vec![c], &b.attn.x_k);
push_state(&format!("{p}.attn.x_v"), vec![c], &b.attn.x_v);
push_state(&format!("{p}.attn.x_a"), vec![c], &b.attn.x_a);
push_state(&format!("{p}.attn.x_g"), vec![c], &b.attn.x_g);
push_state(
&format!("{p}.attn.rkv_proj"),
vec![3, c, c],
&b.attn.rkv_proj,
);
push_state(
&format!("{p}.attn.o_proj.weight"),
vec![c, c],
&b.attn.o_proj,
);
push_state(
&format!("{p}.attn.w_lora.lora.0.weight"),
vec![d_w, c],
&b.attn.w1,
);
push_state(
&format!("{p}.attn.w_lora.lora.2.weight"),
vec![c, d_w],
&b.attn.w2,
);
push_state(&format!("{p}.attn.w_lora.lora.2.bias"), vec![c], &b.attn.w0);
push_state(
&format!("{p}.attn.a_lora.lora.0.weight"),
vec![d_a, c],
&b.attn.a1,
);
push_state(
&format!("{p}.attn.a_lora.lora.2.weight"),
vec![c, d_a],
&b.attn.a2,
);
push_state(&format!("{p}.attn.a_lora.lora.2.bias"), vec![c], &b.attn.a0);
if let Some(st) = &b.attn.v1 {
push_state(&format!("{p}.attn.v_lora.lora.0.weight"), vec![d_v, c], st);
}
if let Some(st) = &b.attn.v2 {
push_state(&format!("{p}.attn.v_lora.lora.2.weight"), vec![c, d_v], st);
}
if let Some(st) = &b.attn.v0 {
push_state(&format!("{p}.attn.v_lora.lora.2.bias"), vec![c], st);
}
push_state(
&format!("{p}.attn.g_lora.lora.0.weight"),
vec![d_g, c],
&b.attn.g1,
);
push_state(
&format!("{p}.attn.g_lora.lora.2.weight"),
vec![c, d_g],
&b.attn.g2,
);
push_state(&format!("{p}.attn.k_k"), vec![c], &b.attn.k_k);
push_state(&format!("{p}.attn.k_a"), vec![c], &b.attn.k_a);
push_state(&format!("{p}.attn.r_k"), vec![h, n], &b.attn.r_k);
push_state(
&format!("{p}.attn.g_norm.weight"),
vec![c],
&b.attn.g_norm_w,
);
push_state(&format!("{p}.attn.g_norm.bias"), vec![c], &b.attn.g_norm_b);
push_state(&format!("{p}.ffn.x_k"), vec![c], &b.ffn.x_k);
push_state(&format!("{p}.ffn.key.weight"), vec![i, c], &b.ffn.key_w);
push_state(&format!("{p}.ffn.value.weight"), vec![c, i], &b.ffn.value_w);
}
recs.sort_by(|a, b| a.name.cmp(&b.name));
let mut offset = 0usize;
let mut header = serde_json::Map::new();
header.insert("__metadata__".to_string(), json!({}));
for rec in &recs {
let bytes = rec.data.len() * 4;
header.insert(
rec.name.clone(),
json!({
"dtype": "F32",
"shape": rec.shape,
"data_offsets": [offset, offset + bytes],
}),
);
offset += bytes;
}
let header_bytes = serde_json::to_vec(&header)?;
let mut f = File::create(path)?;
f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
f.write_all(&header_bytes)?;
for rec in &recs {
for v in &rec.data {
f.write_all(&v.to_le_bytes())?;
}
}
Ok(())
}
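/// Restores optimizer moments written by `save_full_adam_safetensors`,
/// failing with a descriptive error if any tensor is missing or its length
/// does not match the current model.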
pub fn load_full_adam_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<FullAdamState> {
let weights = Weights::load(path.as_ref()).with_context(|| {
format!(
"failed to load optimizer moments from {}",
path.as_ref().display()
)
})?;
let mut adam = self.new_full_adam_state();
let load_state = |name: &str, st: &mut AdamTensorState| -> Result<()> {
let m_name = format!("{name}.m");
let v_name = format!("{name}.v");
let m_t = weights
.require(&m_name)
.with_context(|| format!("missing optimizer tensor '{m_name}'"))?;
let v_t = weights
.require(&v_name)
.with_context(|| format!("missing optimizer tensor '{v_name}'"))?;
if m_t.data().len() != st.m.len() {
bail!(
"optimizer tensor '{}' len {} != expected {}",
m_name,
m_t.data().len(),
st.m.len()
);
}
if v_t.data().len() != st.v.len() {
bail!(
"optimizer tensor '{}' len {} != expected {}",
v_name,
v_t.data().len(),
st.v.len()
);
}
st.m.as_mut_slice().copy_from_slice(m_t.data());
st.v.as_mut_slice().copy_from_slice(v_t.data());
Ok(())
};
let c = self.cfg.hidden_size;
let i = self.cfg.intermediate_size;
let v = self.cfg.vocab_size;
let h = self.cfg.num_heads;
let n = self.cfg.head_dim;
// Shapes are implied by tensor lengths on load; keep the bindings only to
// mirror the save path without unused-variable warnings.
let _ = (c, i, v, h, n);
load_state("opt.model.embeddings.weight", &mut adam.embeddings)?;
load_state("opt.model.norm.weight", &mut adam.ln_out_w)?;
load_state("opt.model.norm.bias", &mut adam.ln_out_b)?;
load_state("opt.lm_head.weight", &mut adam.lm_head)?;
for (idx, b) in adam.blocks.iter_mut().enumerate() {
let p = format!("opt.model.layers.{idx}");
if let Some(st) = b.pre_norm_w.as_mut() {
load_state(&format!("{p}.pre_norm.weight"), st)?;
}
if let Some(st) = b.pre_norm_b.as_mut() {
load_state(&format!("{p}.pre_norm.bias"), st)?;
}
load_state(&format!("{p}.attn_norm.weight"), &mut b.attn_norm_w)?;
load_state(&format!("{p}.attn_norm.bias"), &mut b.attn_norm_b)?;
load_state(&format!("{p}.ffn_norm.weight"), &mut b.ffn_norm_w)?;
load_state(&format!("{p}.ffn_norm.bias"), &mut b.ffn_norm_b)?;
load_state(&format!("{p}.attn.x_r"), &mut b.attn.x_r)?;
load_state(&format!("{p}.attn.x_w"), &mut b.attn.x_w)?;
load_state(&format!("{p}.attn.x_k"), &mut b.attn.x_k)?;
load_state(&format!("{p}.attn.x_v"), &mut b.attn.x_v)?;
load_state(&format!("{p}.attn.x_a"), &mut b.attn.x_a)?;
load_state(&format!("{p}.attn.x_g"), &mut b.attn.x_g)?;
load_state(&format!("{p}.attn.rkv_proj"), &mut b.attn.rkv_proj)?;
load_state(&format!("{p}.attn.o_proj.weight"), &mut b.attn.o_proj)?;
load_state(&format!("{p}.attn.w_lora.lora.0.weight"), &mut b.attn.w1)?;
load_state(&format!("{p}.attn.w_lora.lora.2.weight"), &mut b.attn.w2)?;
load_state(&format!("{p}.attn.w_lora.lora.2.bias"), &mut b.attn.w0)?;
load_state(&format!("{p}.attn.a_lora.lora.0.weight"), &mut b.attn.a1)?;
load_state(&format!("{p}.attn.a_lora.lora.2.weight"), &mut b.attn.a2)?;
load_state(&format!("{p}.attn.a_lora.lora.2.bias"), &mut b.attn.a0)?;
if let Some(st) = b.attn.v1.as_mut() {
load_state(&format!("{p}.attn.v_lora.lora.0.weight"), st)?;
}
if let Some(st) = b.attn.v2.as_mut() {
load_state(&format!("{p}.attn.v_lora.lora.2.weight"), st)?;
}
if let Some(st) = b.attn.v0.as_mut() {
load_state(&format!("{p}.attn.v_lora.lora.2.bias"), st)?;
}
load_state(&format!("{p}.attn.g_lora.lora.0.weight"), &mut b.attn.g1)?;
load_state(&format!("{p}.attn.g_lora.lora.2.weight"), &mut b.attn.g2)?;
load_state(&format!("{p}.attn.k_k"), &mut b.attn.k_k)?;
load_state(&format!("{p}.attn.k_a"), &mut b.attn.k_a)?;
load_state(&format!("{p}.attn.r_k"), &mut b.attn.r_k)?;
load_state(&format!("{p}.attn.g_norm.weight"), &mut b.attn.g_norm_w)?;
load_state(&format!("{p}.attn.g_norm.bias"), &mut b.attn.g_norm_b)?;
load_state(&format!("{p}.ffn.x_k"), &mut b.ffn.x_k)?;
load_state(&format!("{p}.ffn.key.weight"), &mut b.ffn.key_w)?;
load_state(&format!("{p}.ffn.value.weight"), &mut b.ffn.value_w)?;
}
Ok(adam)
}
pub fn config(&self) -> &Config {
&self.cfg
}
pub fn new_state(&self) -> State {
State::new(&self.cfg)
}
#[inline]
pub fn lm_head_weights(&self) -> &[f32] {
self.lm_head.as_slice()
}
#[inline]
pub fn lm_head_weights_mut(&mut self) -> &mut [f32] {
self.lm_head.as_mut_slice()
}
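/// Applies one optimizer step to every parameter group selected by `scope`.
/// For Adam, `adam_t` is incremented first and the bias corrections derived
/// from it; SGD ignores the moment buffers. The `out_bias*` arguments update
/// an optional caller-owned output bias alongside the model parameters.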
#[allow(clippy::too_many_arguments)]
fn apply_full_gradients(
&mut self,
grads: &FullGradState,
scope: TrainScopeMask,
optimizer: OptimizerKind,
lr: f32,
clip: f32,
adam_t: &mut usize,
model_adam: Option<&mut FullAdamState>,
out_bias: Option<&mut [f32]>,
out_bias_grad: Option<&[f32]>,
out_bias_adam_m: Option<&mut [f32]>,
out_bias_adam_v: Option<&mut [f32]>,
) -> Result<()> {
let mut adam_step = None::<AdamStep>;
let mut model_adam = model_adam;
if matches!(optimizer, OptimizerKind::Adam) {
*adam_t = adam_t.saturating_add(1);
let t = (*adam_t).max(1) as i32;
let b1 = 0.9f32;
let b2 = 0.999f32;
adam_step = Some(AdamStep {
lr,
clip: clip.max(0.0),
b1,
b2,
eps: 1e-8,
bias_corr1: 1.0 - b1.powi(t),
bias_corr2: 1.0 - b2.powi(t),
});
if scope.trains_non_head_params() && model_adam.is_none() {
bail!("rwkv Adam full-training state is missing");
}
}
if scope.bias
&& let (Some(bias), Some(grad)) = (out_bias, out_bias_grad)
{
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(bias, grad, lr, clip),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let Some(m) = out_bias_adam_m else {
bail!("rwkv Adam output-bias state is missing (m)");
};
let Some(v) = out_bias_adam_v else {
bail!("rwkv Adam output-bias state is missing (v)");
};
apply_adam_vec_update_raw(bias, grad, m, v, cfg);
}
}
}
if scope.head {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
self.lm_head.as_mut_slice(),
grads.lm_head.as_slice(),
lr,
clip,
);
sgd_vec_update(
self.ln_out_w.as_mut_slice(),
grads.ln_out_w.as_slice(),
lr,
clip,
);
sgd_vec_update(
self.ln_out_b.as_mut_slice(),
grads.ln_out_b.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = model_adam.as_mut().expect("adam state exists");
apply_adam_vec_update(
self.lm_head.as_mut_slice(),
grads.lm_head.as_slice(),
&mut adam.lm_head,
cfg,
);
apply_adam_vec_update(
self.ln_out_w.as_mut_slice(),
grads.ln_out_w.as_slice(),
&mut adam.ln_out_w,
cfg,
);
apply_adam_vec_update(
self.ln_out_b.as_mut_slice(),
grads.ln_out_b.as_slice(),
&mut adam.ln_out_b,
cfg,
);
}
}
}
for layer_idx in 0..self.cfg.num_layers {
let block = &mut self.blocks[layer_idx];
let grad = &grads.blocks[layer_idx];
match optimizer {
OptimizerKind::Sgd => {
if scope.ffn {
sgd_vec_update(
block.ffn.x_k.as_mut_slice(),
grad.ffn.x_k.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.ffn.key_w.as_mut_slice(),
grad.ffn.key_w.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.ffn.value_w.as_mut_slice(),
grad.ffn.value_w.as_slice(),
lr,
clip,
);
}
if scope.ffn_norm {
sgd_vec_update(
block.ffn_norm_w.as_mut_slice(),
grad.ffn_norm_w.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.ffn_norm_b.as_mut_slice(),
grad.ffn_norm_b.as_slice(),
lr,
clip,
);
}
if scope.attn {
sgd_vec_update(
block.attn.o_proj.as_mut_slice(),
grad.attn.o_proj.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.r_k.as_mut_slice(),
grad.attn.r_k.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.g_norm_w.as_mut_slice(),
grad.attn.g_norm_w.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.g_norm_b.as_mut_slice(),
grad.attn.g_norm_b.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.k_a.as_mut_slice(),
grad.attn.k_a.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.k_k.as_mut_slice(),
grad.attn.k_k.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.rkv_proj.as_mut_slice(),
grad.attn.rkv_proj.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.w0.as_mut_slice(),
grad.attn.w0.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.w2.as_mut_slice(),
grad.attn.w2.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.w1.as_mut_slice(),
grad.attn.w1.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.a0.as_mut_slice(),
grad.attn.a0.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.a2.as_mut_slice(),
grad.attn.a2.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.a1.as_mut_slice(),
grad.attn.a1.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.g2.as_mut_slice(),
grad.attn.g2.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.g1.as_mut_slice(),
grad.attn.g1.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.x_r.as_mut_slice(),
grad.attn.x_r.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.x_w.as_mut_slice(),
grad.attn.x_w.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.x_k.as_mut_slice(),
grad.attn.x_k.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.x_v.as_mut_slice(),
grad.attn.x_v.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.x_a.as_mut_slice(),
grad.attn.x_a.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.x_g.as_mut_slice(),
grad.attn.x_g.as_slice(),
lr,
clip,
);
if let (Some(v1), Some(gv1)) =
(block.attn.v1.as_mut(), grad.attn.v1.as_ref())
{
sgd_vec_update(v1.as_mut_slice(), gv1.as_slice(), lr, clip);
}
if let (Some(v2), Some(gv2)) =
(block.attn.v2.as_mut(), grad.attn.v2.as_ref())
{
sgd_vec_update(v2.as_mut_slice(), gv2.as_slice(), lr, clip);
}
if let (Some(v0), Some(gv0)) =
(block.attn.v0.as_mut(), grad.attn.v0.as_ref())
{
sgd_vec_update(v0.as_mut_slice(), gv0.as_slice(), lr, clip);
}
}
if scope.attn_norm {
sgd_vec_update(
block.attn_norm_w.as_mut_slice(),
grad.attn_norm_w.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn_norm_b.as_mut_slice(),
grad.attn_norm_b.as_slice(),
lr,
clip,
);
}
if scope.pre_norm
&& let (Some(w), Some(gw)) =
(block.pre_norm_w.as_mut(), grad.pre_norm_w.as_ref())
{
sgd_vec_update(w.as_mut_slice(), gw.as_slice(), lr, clip);
}
if scope.pre_norm
&& let (Some(b), Some(gb)) =
(block.pre_norm_b.as_mut(), grad.pre_norm_b.as_ref())
{
sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
}
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
if scope.ffn {
apply_adam_vec_update(
block.ffn.x_k.as_mut_slice(),
grad.ffn.x_k.as_slice(),
&mut adam.ffn.x_k,
cfg,
);
apply_adam_vec_update(
block.ffn.key_w.as_mut_slice(),
grad.ffn.key_w.as_slice(),
&mut adam.ffn.key_w,
cfg,
);
apply_adam_vec_update(
block.ffn.value_w.as_mut_slice(),
grad.ffn.value_w.as_slice(),
&mut adam.ffn.value_w,
cfg,
);
}
if scope.ffn_norm {
apply_adam_vec_update(
block.ffn_norm_w.as_mut_slice(),
grad.ffn_norm_w.as_slice(),
&mut adam.ffn_norm_w,
cfg,
);
apply_adam_vec_update(
block.ffn_norm_b.as_mut_slice(),
grad.ffn_norm_b.as_slice(),
&mut adam.ffn_norm_b,
cfg,
);
}
if scope.attn {
apply_adam_vec_update(
block.attn.o_proj.as_mut_slice(),
grad.attn.o_proj.as_slice(),
&mut adam.attn.o_proj,
cfg,
);
apply_adam_vec_update(
block.attn.r_k.as_mut_slice(),
grad.attn.r_k.as_slice(),
&mut adam.attn.r_k,
cfg,
);
apply_adam_vec_update(
block.attn.g_norm_w.as_mut_slice(),
grad.attn.g_norm_w.as_slice(),
&mut adam.attn.g_norm_w,
cfg,
);
apply_adam_vec_update(
block.attn.g_norm_b.as_mut_slice(),
grad.attn.g_norm_b.as_slice(),
&mut adam.attn.g_norm_b,
cfg,
);
apply_adam_vec_update(
block.attn.k_a.as_mut_slice(),
grad.attn.k_a.as_slice(),
&mut adam.attn.k_a,
cfg,
);
apply_adam_vec_update(
block.attn.k_k.as_mut_slice(),
grad.attn.k_k.as_slice(),
&mut adam.attn.k_k,
cfg,
);
apply_adam_vec_update(
block.attn.rkv_proj.as_mut_slice(),
grad.attn.rkv_proj.as_slice(),
&mut adam.attn.rkv_proj,
cfg,
);
apply_adam_vec_update(
block.attn.w0.as_mut_slice(),
grad.attn.w0.as_slice(),
&mut adam.attn.w0,
cfg,
);
apply_adam_vec_update(
block.attn.w2.as_mut_slice(),
grad.attn.w2.as_slice(),
&mut adam.attn.w2,
cfg,
);
apply_adam_vec_update(
block.attn.w1.as_mut_slice(),
grad.attn.w1.as_slice(),
&mut adam.attn.w1,
cfg,
);
apply_adam_vec_update(
block.attn.a0.as_mut_slice(),
grad.attn.a0.as_slice(),
&mut adam.attn.a0,
cfg,
);
apply_adam_vec_update(
block.attn.a2.as_mut_slice(),
grad.attn.a2.as_slice(),
&mut adam.attn.a2,
cfg,
);
apply_adam_vec_update(
block.attn.a1.as_mut_slice(),
grad.attn.a1.as_slice(),
&mut adam.attn.a1,
cfg,
);
apply_adam_vec_update(
block.attn.g2.as_mut_slice(),
grad.attn.g2.as_slice(),
&mut adam.attn.g2,
cfg,
);
apply_adam_vec_update(
block.attn.g1.as_mut_slice(),
grad.attn.g1.as_slice(),
&mut adam.attn.g1,
cfg,
);
apply_adam_vec_update(
block.attn.x_r.as_mut_slice(),
grad.attn.x_r.as_slice(),
&mut adam.attn.x_r,
cfg,
);
apply_adam_vec_update(
block.attn.x_w.as_mut_slice(),
grad.attn.x_w.as_slice(),
&mut adam.attn.x_w,
cfg,
);
apply_adam_vec_update(
block.attn.x_k.as_mut_slice(),
grad.attn.x_k.as_slice(),
&mut adam.attn.x_k,
cfg,
);
apply_adam_vec_update(
block.attn.x_v.as_mut_slice(),
grad.attn.x_v.as_slice(),
&mut adam.attn.x_v,
cfg,
);
apply_adam_vec_update(
block.attn.x_a.as_mut_slice(),
grad.attn.x_a.as_slice(),
&mut adam.attn.x_a,
cfg,
);
apply_adam_vec_update(
block.attn.x_g.as_mut_slice(),
grad.attn.x_g.as_slice(),
&mut adam.attn.x_g,
cfg,
);
if let (Some(v1), Some(gv1), Some(av1)) = (
block.attn.v1.as_mut(),
grad.attn.v1.as_ref(),
adam.attn.v1.as_mut(),
) {
apply_adam_vec_update(v1.as_mut_slice(), gv1.as_slice(), av1, cfg);
}
if let (Some(v2), Some(gv2), Some(av2)) = (
block.attn.v2.as_mut(),
grad.attn.v2.as_ref(),
adam.attn.v2.as_mut(),
) {
apply_adam_vec_update(v2.as_mut_slice(), gv2.as_slice(), av2, cfg);
}
if let (Some(v0), Some(gv0), Some(av0)) = (
block.attn.v0.as_mut(),
grad.attn.v0.as_ref(),
adam.attn.v0.as_mut(),
) {
apply_adam_vec_update(v0.as_mut_slice(), gv0.as_slice(), av0, cfg);
}
}
if scope.attn_norm {
apply_adam_vec_update(
block.attn_norm_w.as_mut_slice(),
grad.attn_norm_w.as_slice(),
&mut adam.attn_norm_w,
cfg,
);
apply_adam_vec_update(
block.attn_norm_b.as_mut_slice(),
grad.attn_norm_b.as_slice(),
&mut adam.attn_norm_b,
cfg,
);
}
if scope.pre_norm
&& let (Some(w), Some(gw), Some(aw)) = (
block.pre_norm_w.as_mut(),
grad.pre_norm_w.as_ref(),
adam.pre_norm_w.as_mut(),
)
{
apply_adam_vec_update(w.as_mut_slice(), gw.as_slice(), aw, cfg);
}
if scope.pre_norm
&& let (Some(b), Some(gb), Some(ab)) = (
block.pre_norm_b.as_mut(),
grad.pre_norm_b.as_ref(),
adam.pre_norm_b.as_mut(),
)
{
apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
}
}
}
}
if scope.embed {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
self.embeddings.as_mut_slice(),
grads.embeddings.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = model_adam.as_mut().expect("adam state exists");
apply_adam_vec_update(
self.embeddings.as_mut_slice(),
grads.embeddings.as_slice(),
&mut adam.embeddings,
cfg,
);
}
}
}
Ok(())
}
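/// Backpropagates one token through the recorded trace, accumulating
/// parameter gradients into `grads`. `pdf` is the predicted distribution and
/// `symbol` the observed byte; the logit gradient is stored as
/// `(target - p) * grad_scale`, the negative cross-entropy gradient, so
/// applying the accumulated step increases the log-likelihood of `symbol`.
/// `future` carries gradients flowing backward in time from later tokens of
/// the same sequence.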
#[allow(clippy::needless_range_loop)]
fn accumulate_token_step_gradients(
&self,
scratch: &mut ScratchBuffers,
trace: &TokenTrainTrace,
state_new: &State,
symbol: u8,
pdf: &[f64],
grad_scale: f32,
scope: TrainScopeMask,
grads: &mut FullGradState,
out_bias_grad: Option<&mut [f32]>,
future: &mut RecurrentGradState,
) -> Result<()> {
let c = self.cfg.hidden_size;
let h = self.cfg.num_heads;
let n = self.cfg.head_dim;
let i = self.cfg.intermediate_size;
let d_w = self.cfg.decay_low_rank;
let d_a = self.cfg.a_low_rank;
let d_v = self.cfg.v_low_rank;
let d_g = self.cfg.g_low_rank;
let vocab = self.cfg.vocab_size.min(pdf.len());
if vocab == 0 {
return Ok(());
}
scratch.grad_logits.zero();
for idx in 0..vocab {
let p = pdf[idx].clamp(1e-12, 1.0) as f32;
let target = if idx == symbol as usize { 1.0 } else { 0.0 };
// Negative cross-entropy gradient of the softmax logits: d(-loss)/d(logit) = target - p.
scratch.grad_logits[idx] = (target - p) * grad_scale;
}
if scope.bias
&& let Some(bias_grad) = out_bias_grad
{
add_vec_grad(
&mut bias_grad[0..vocab],
&scratch.grad_logits.as_slice()[0..vocab],
);
}
scratch.grad_x.zero();
if scope.head {
add_outer_grad(
grads.lm_head.as_mut_slice(),
vocab,
c,
&scratch.grad_logits.as_slice()[0..vocab],
trace.x_normed.as_slice(),
);
}
// Backprop through the LM head: grad_x += lm_head^T * grad_logits.
for row in 0..vocab {
let g = scratch.grad_logits[row];
if g == 0.0 {
continue;
}
let row_off = row * c;
for col in 0..c {
scratch.grad_x[col] += self.lm_head[row_off + col] * g;
}
}
let needs_backprop = scope.trains_non_head_params() || scope.head;
if !needs_backprop {
return Ok(());
}
layer_norm_backward(
trace.x.as_slice(),
self.ln_out_w.as_slice(),
scratch.grad_x.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x2.as_mut_slice(),
scratch.grad_x3.as_mut_slice(),
scratch.grad_x4.as_mut_slice(),
);
if scope.head {
add_vec_grad(grads.ln_out_w.as_mut_slice(), scratch.grad_x3.as_slice());
add_vec_grad(grads.ln_out_b.as_mut_slice(), scratch.grad_x4.as_slice());
}
scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
scratch.grad_v_first.zero();
for layer_idx in (0..self.cfg.num_layers).rev() {
let tr = &trace.layers[layer_idx];
let block = &self.blocks[layer_idx];
let block_grads = &mut grads.blocks[layer_idx];
let future_layer = &mut future.layers[layer_idx];
scratch.grad_x2.copy_from_slice(scratch.grad_x.as_slice());
scratch.grad_x3.copy_from_slice(scratch.grad_x.as_slice());
unsafe {
kernel::gemv_t_avx(
block.ffn.value_w.as_ptr(),
scratch.grad_x3.as_ptr(),
scratch.grad_ffn.as_mut_ptr(),
c,
i,
);
}
if scope.ffn {
add_outer_grad(
block_grads.ffn.value_w.as_mut_slice(),
c,
i,
scratch.grad_x3.as_slice(),
tr.ffn_k.as_slice(),
);
}
for col in 0..i {
let pre = tr.ffn_pre[col];
// Backward through squared ReLU: d(relu(x)^2)/dx = 2 * relu(x).
scratch.grad_ffn2[col] = if pre > 0.0 {
scratch.grad_ffn[col] * (2.0 * pre)
} else {
0.0
};
}
unsafe {
kernel::gemv_t_avx(
block.ffn.key_w.as_ptr(),
scratch.grad_ffn2.as_ptr(),
scratch.grad_x4.as_mut_ptr(),
i,
c,
);
}
if scope.ffn {
add_outer_grad(
block_grads.ffn.key_w.as_mut_slice(),
i,
c,
scratch.grad_ffn2.as_slice(),
tr.ffn_xk.as_slice(),
);
}
scratch
.grad_x5
.copy_from_slice(future_layer.ffn_x_prev.as_slice());
future_layer.ffn_x_prev.zero();
for col in 0..c {
let g = scratch.grad_x4[col];
let mix = block.ffn.x_k[col];
let base = tr.ffn_norm[col];
let prev = tr.ffn_x_prev_old[col];
// Token-shift mix xk = base + mix * (prev - base): split the gradient between
// the current activation, the previous token, and the mix weight.
scratch.grad_x5[col] += g * (1.0 - mix);
future_layer.ffn_x_prev[col] = g * mix;
scratch.grad_param[col] = g * (prev - base);
}
if scope.ffn {
add_vec_grad(
block_grads.ffn.x_k.as_mut_slice(),
scratch.grad_param.as_slice(),
);
}
layer_norm_backward(
tr.x_after_attn.as_slice(),
block.ffn_norm_w.as_slice(),
scratch.grad_x5.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x4.as_mut_slice(),
scratch.grad_x3.as_mut_slice(),
scratch.grad_x6.as_mut_slice(),
);
if scope.ffn_norm {
add_vec_grad(
block_grads.ffn_norm_w.as_mut_slice(),
scratch.grad_x3.as_slice(),
);
add_vec_grad(
block_grads.ffn_norm_b.as_mut_slice(),
scratch.grad_x6.as_slice(),
);
}
for col in 0..c {
scratch.grad_x2[col] += scratch.grad_x4[col];
}
scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
scratch.grad_x3.copy_from_slice(scratch.grad_x2.as_slice());
unsafe {
kernel::gemv_t_avx(
block.attn.o_proj.as_ptr(),
scratch.grad_x3.as_ptr(),
scratch.grad_x4.as_mut_ptr(),
c,
c,
);
}
if scope.attn {
add_outer_grad(
block_grads.attn.o_proj.as_mut_slice(),
c,
c,
scratch.grad_x3.as_slice(),
tr.y_gate.as_slice(),
);
}
for col in 0..c {
let gy = scratch.grad_x4[col];
scratch.grad_saved[col] = gy * tr.y_head[col];
scratch.grad_x4[col] = gy * tr.g[col];
}
scratch.grad_x2.zero();
scratch.grad_x3.zero();
scratch.grad_x6.zero();
scratch.grad_param.zero();
// Backward through the per-head bonus term y_head = y_gn + alpha * v,
// where alpha = sum_j r[j] * k[j] * r_k[j] within each head.
for head_idx in 0..h {
let off = head_idx * n;
let mut g_alpha = 0.0f32;
for j in 0..n {
let g = scratch.grad_x4[off + j];
g_alpha += g * tr.v[off + j];
scratch.grad_x6[off + j] += g * tr.alpha[head_idx];
}
for j in 0..n {
let idx = off + j;
let rk = block.attn.r_k[idx];
let rv = tr.r[idx];
let kv = tr.k[idx];
let g = g_alpha * rk;
scratch.grad_x2[idx] += g * kv;
scratch.grad_x3[idx] += g * rv;
scratch.grad_param[idx] += g_alpha * rv * kv;
}
}
if scope.attn {
add_vec_grad(
block_grads.attn.r_k.as_mut_slice(),
scratch.grad_param.as_slice(),
);
}
scratch.grad_x5.as_mut_slice()[0..c].copy_from_slice(&scratch.grad_x4.as_slice()[0..c]);
group_norm_backward(
tr.y_wkv.as_slice(),
block.attn.g_norm_w.as_slice(),
scratch.grad_x5.as_slice(),
h,
n,
self.cfg.group_norm_eps,
scratch.grad_x4.as_mut_slice(),
scratch.grad_param.as_mut_slice(),
scratch.grad_param2.as_mut_slice(),
);
if scope.attn {
add_vec_grad(
block_grads.attn.g_norm_w.as_mut_slice(),
scratch.grad_param.as_slice(),
);
add_vec_grad(
block_grads.attn.g_norm_b.as_mut_slice(),
scratch.grad_param2.as_slice(),
);
}
scratch.grad_param.zero();
scratch.grad_x5.zero();
scratch.grad_param2.zero();
scratch
.grad_att_state
.copy_from_slice(future_layer.att_state.as_slice());
future_layer.att_state.zero();
let s_old = tr.att_state_old.as_slice();
let s_new = state_new.layers[layer_idx].att_state.as_slice();
// Backward through the per-head WKV state update (decay via w, in-context
// removal via kk and a, and the v * k^T write), one n x n state per head.
for head_idx in 0..h {
let off = head_idx * n;
let s_off = head_idx * n * n;
let grad_y = &scratch.grad_x4.as_slice()[off..off + n];
let r_head = &tr.r.as_slice()[off..off + n];
let k_head = &tr.k.as_slice()[off..off + n];
let kk_head = &tr.kk.as_slice()[off..off + n];
let a_head = &tr.a.as_slice()[off..off + n];
let v_head = &tr.v.as_slice()[off..off + n];
let w_head = &tr.w_decay.as_slice()[off..off + n];
unsafe {
kernel::gemv_t_avx(
s_new.as_ptr().add(s_off),
grad_y.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
n,
n,
);
}
for j in 0..n {
scratch.grad_x2[off + j] += scratch.grad_low_rank[j];
}
let g_state = &mut scratch.grad_att_state.as_mut_slice()[s_off..s_off + n * n];
for irow in 0..n {
let gy = grad_y[irow];
let row_off = irow * n;
for j in 0..n {
g_state[row_off + j] += gy * r_head[j];
}
}
unsafe {
kernel::gemv_avx(
s_old.as_ptr().add(s_off),
kk_head.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
n,
n,
);
}
let u = &scratch.grad_low_rank.as_slice()[0..n];
for j in 0..n {
let mut grad_w = 0.0f32;
let mut grad_k = 0.0f32;
let mut grad_b = 0.0f32;
for irow in 0..n {
let g = g_state[irow * n + j];
grad_w += g * s_old[s_off + irow * n + j];
grad_k += g * v_head[irow];
grad_b -= g * u[irow];
future_layer.att_state[s_off + irow * n + j] = g * w_head[j];
}
scratch.grad_param[off + j] += grad_w;
scratch.grad_x3[off + j] += grad_k;
scratch.grad_param2[off + j] += grad_b * a_head[j];
scratch.grad_x5[off + j] += grad_b * kk_head[j];
}
for irow in 0..n {
let mut grad_u = 0.0f32;
for j in 0..n {
grad_u -= g_state[irow * n + j] * kk_head[j] * a_head[j];
}
scratch.grad_low_rank2[irow] = grad_u;
let row_off = irow * n;
for j in 0..n {
future_layer.att_state[s_off + row_off + j] += grad_u * kk_head[j];
}
}
unsafe {
kernel::gemv_t_avx(
s_old.as_ptr().add(s_off),
scratch.grad_low_rank2.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
n,
n,
);
}
for j in 0..n {
scratch.grad_param2[off + j] += scratch.grad_low_rank[j];
}
for irow in 0..n {
let mut grad_v = 0.0f32;
for j in 0..n {
grad_v += g_state[irow * n + j] * k_head[j];
}
scratch.grad_x6[off + irow] += grad_v;
}
}
for col in 0..c {
let gk = scratch.grad_x3[col];
let scale = 1.0 + (tr.a[col] - 1.0) * block.attn.k_a[col];
let d_scale = gk * tr.k_pre[col];
scratch.grad_x3[col] = gk * scale;
scratch.grad_x5[col] += d_scale * block.attn.k_a[col];
scratch.grad_param[col] = d_scale * (tr.a[col] - 1.0);
}
for head_idx in 0..h {
let off = head_idx * n;
l2_normalize_backward(
&tr.kk_pre.as_slice()[off..off + n],
&tr.kk.as_slice()[off..off + n],
&scratch.grad_param2.as_slice()[off..off + n],
1e-12,
&mut scratch.grad_x4.as_mut_slice()[off..off + n],
);
}
for col in 0..c {
let g = scratch.grad_x4[col];
scratch.grad_x3[col] += g * block.attn.k_k[col];
scratch.grad_param2[col] = g * tr.k_pre[col];
}
if scope.attn {
add_vec_grad(
block_grads.attn.k_a.as_mut_slice(),
scratch.grad_param.as_slice(),
);
add_vec_grad(
block_grads.attn.k_k.as_mut_slice(),
scratch.grad_param2.as_slice(),
);
}
scratch
.grad_param2
.copy_from_slice(scratch.grad_x6.as_slice());
if layer_idx == 0 {
for col in 0..c {
scratch.grad_x6[col] += scratch.grad_v_first[col];
}
} else if tr.uses_v_residual
&& block.attn.v1.is_some()
&& block.attn.v2.is_some()
&& block.attn.v0.is_some()
{
let v1 = block.attn.v1.as_ref().expect("v1");
let v2 = block.attn.v2.as_ref().expect("v2");
for col in 0..c {
let gv = scratch.grad_param2[col];
let nu = tr.nu[col];
scratch.grad_x6[col] = gv * (1.0 - nu);
scratch.grad_x3[col] = gv * (trace.v_first[col] - tr.v_pre[col]);
scratch.grad_v_first[col] += gv * nu;
}
for col in 0..c {
let nu = tr.nu[col];
scratch.grad_x3[col] *= nu * (1.0 - nu);
}
if scope.attn {
add_vec_grad(
block_grads
.attn
.v0
.as_mut()
.expect("grad v0")
.as_mut_slice(),
scratch.grad_x3.as_slice(),
);
add_outer_grad(
block_grads
.attn
.v2
.as_mut()
.expect("grad v2")
.as_mut_slice(),
c,
d_v,
scratch.grad_x3.as_slice(),
&tr.v_hidden.as_slice()[0..d_v],
);
}
unsafe {
kernel::gemv_t_avx(
v2.as_ptr(),
scratch.grad_x3.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_v,
);
}
if scope.attn {
add_outer_grad(
block_grads
.attn
.v1
.as_mut()
.expect("grad v1")
.as_mut_slice(),
d_v,
c,
&scratch.grad_low_rank.as_slice()[0..d_v],
tr.xv.as_slice(),
);
}
for col in 0..c {
let mut acc = 0.0f32;
for row in 0..d_v {
acc += v1[row * c + col] * scratch.grad_low_rank[row];
}
scratch.grad_x4[col] += acc;
}
}
let proj_size = c * c;
if scope.attn {
add_outer_grad(
&mut block_grads.attn.rkv_proj.as_mut_slice()[0..proj_size],
c,
c,
scratch.grad_x2.as_slice(),
tr.xr.as_slice(),
);
add_outer_grad(
&mut block_grads.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
c,
c,
scratch.grad_x3.as_slice(),
tr.xk.as_slice(),
);
add_outer_grad(
&mut block_grads.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
c,
c,
scratch.grad_x6.as_slice(),
tr.xv.as_slice(),
);
}
let proj = block.attn.rkv_proj.as_slice();
unsafe {
kernel::gemv_t_avx(
proj.as_ptr(),
scratch.grad_x2.as_ptr(),
scratch.grad_param.as_mut_ptr(),
c,
c,
);
kernel::gemv_t_avx(
proj.as_ptr().add(proj_size),
scratch.grad_x3.as_ptr(),
scratch.grad_param2.as_mut_ptr(),
c,
c,
);
kernel::gemv_t_avx(
proj.as_ptr().add(2 * proj_size),
scratch.grad_x6.as_ptr(),
scratch.grad_x4.as_mut_ptr(),
c,
c,
);
}
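// Decay backward: w_decay = exp(-sigmoid(w_pre) / sqrt(e)), so we chain
// through the exp (factor -inv_sqrt_e * w_decay) and then the sigmoid.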
let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
for col in 0..c {
let sig = tr.w_sigmoid[col];
let d_sig = scratch.grad_param[col] * (-inv_sqrt_e) * tr.w_decay[col];
scratch.grad_param[col] = d_sig * sig * (1.0 - sig);
}
if scope.attn {
add_vec_grad(
block_grads.attn.w0.as_mut_slice(),
scratch.grad_param.as_slice(),
);
add_outer_grad(
block_grads.attn.w2.as_mut_slice(),
c,
d_w,
scratch.grad_param.as_slice(),
&tr.w_hidden.as_slice()[0..d_w],
);
}
unsafe {
kernel::gemv_t_avx(
block.attn.w2.as_ptr(),
scratch.grad_param.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_w,
);
}
for col in 0..d_w {
let t = tr.w_hidden[col];
scratch.grad_low_rank[col] *= 1.0 - t * t;
}
if scope.attn {
add_outer_grad(
block_grads.attn.w1.as_mut_slice(),
d_w,
c,
&scratch.grad_low_rank.as_slice()[0..d_w],
tr.xw.as_slice(),
);
}
unsafe {
kernel::gemv_t_avx(
block.attn.w1.as_ptr(),
scratch.grad_low_rank.as_ptr(),
scratch.grad_x6.as_mut_ptr(),
d_w,
c,
);
}
for col in 0..c {
let a = tr.a[col];
scratch.grad_x5[col] *= a * (1.0 - a);
}
if scope.attn {
add_vec_grad(
block_grads.attn.a0.as_mut_slice(),
scratch.grad_x5.as_slice(),
);
add_outer_grad(
block_grads.attn.a2.as_mut_slice(),
c,
d_a,
scratch.grad_x5.as_slice(),
&tr.a_hidden.as_slice()[0..d_a],
);
}
unsafe {
kernel::gemv_t_avx(
block.attn.a2.as_ptr(),
scratch.grad_x5.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_a,
);
}
if scope.attn {
add_outer_grad(
block_grads.attn.a1.as_mut_slice(),
d_a,
c,
&scratch.grad_low_rank.as_slice()[0..d_a],
tr.xa.as_slice(),
);
}
unsafe {
kernel::gemv_t_avx(
block.attn.a1.as_ptr(),
scratch.grad_low_rank.as_ptr(),
scratch.grad_x5.as_mut_ptr(),
d_a,
c,
);
}
if scope.attn {
add_outer_grad(
block_grads.attn.g2.as_mut_slice(),
c,
d_g,
scratch.grad_saved.as_slice(),
&tr.g_hidden.as_slice()[0..d_g],
);
}
unsafe {
kernel::gemv_t_avx(
block.attn.g2.as_ptr(),
scratch.grad_saved.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_g,
);
}
for col in 0..d_g {
let sig = tr.g_hidden[col];
scratch.grad_low_rank2[col] = scratch.grad_low_rank[col] * sig * (1.0 - sig);
}
if scope.attn {
add_outer_grad(
block_grads.attn.g1.as_mut_slice(),
d_g,
c,
&scratch.grad_low_rank2.as_slice()[0..d_g],
tr.xg.as_slice(),
);
}
unsafe {
kernel::gemv_t_avx(
block.attn.g1.as_ptr(),
scratch.grad_low_rank2.as_ptr(),
scratch.grad_saved.as_mut_ptr(),
d_g,
c,
);
}
scratch
.grad_x3
.copy_from_slice(future_layer.att_x_prev.as_slice());
future_layer.att_x_prev.zero();
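// Backward through the six token-shift mixes (r, w, k, v, a, g). Each one
// contributes to the normed-input gradient, the mix-vector gradient, and the
// previous timestep's x via `future_layer.att_x_prev`.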
for col in 0..c {
let g = scratch.grad_param[col];
let mix = block.attn.x_r[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
future_layer.att_x_prev[col] += g * mix;
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
add_vec_grad(
block_grads.attn.x_r.as_mut_slice(),
scratch.grad_x2.as_slice(),
);
}
for col in 0..c {
let g = scratch.grad_x6[col];
let mix = block.attn.x_w[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
future_layer.att_x_prev[col] += g * mix;
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
add_vec_grad(
block_grads.attn.x_w.as_mut_slice(),
scratch.grad_x2.as_slice(),
);
}
for col in 0..c {
let g = scratch.grad_param2[col];
let mix = block.attn.x_k[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
future_layer.att_x_prev[col] += g * mix;
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
add_vec_grad(
block_grads.attn.x_k.as_mut_slice(),
scratch.grad_x2.as_slice(),
);
}
for col in 0..c {
let g = scratch.grad_x4[col];
let mix = block.attn.x_v[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
future_layer.att_x_prev[col] += g * mix;
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
add_vec_grad(
block_grads.attn.x_v.as_mut_slice(),
scratch.grad_x2.as_slice(),
);
}
for col in 0..c {
let g = scratch.grad_x5[col];
let mix = block.attn.x_a[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
future_layer.att_x_prev[col] += g * mix;
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
add_vec_grad(
block_grads.attn.x_a.as_mut_slice(),
scratch.grad_x2.as_slice(),
);
}
for col in 0..c {
let g = scratch.grad_saved[col];
let mix = block.attn.x_g[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
future_layer.att_x_prev[col] += g * mix;
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
add_vec_grad(
block_grads.attn.x_g.as_mut_slice(),
scratch.grad_x2.as_slice(),
);
}
layer_norm_backward(
tr.x_after_pre.as_slice(),
block.attn_norm_w.as_slice(),
scratch.grad_x3.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x2.as_mut_slice(),
scratch.grad_x4.as_mut_slice(),
scratch.grad_x5.as_mut_slice(),
);
if scope.attn_norm {
add_vec_grad(
block_grads.attn_norm_w.as_mut_slice(),
scratch.grad_x4.as_slice(),
);
add_vec_grad(
block_grads.attn_norm_b.as_mut_slice(),
scratch.grad_x5.as_slice(),
);
}
for col in 0..c {
scratch.grad_x[col] += scratch.grad_x2[col];
}
if layer_idx == 0
&& let (Some(w), Some(_b)) = (&block.pre_norm_w, &block.pre_norm_b)
{
layer_norm_backward(
tr.x_in.as_slice(),
w.as_slice(),
scratch.grad_x.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x2.as_mut_slice(),
scratch.grad_x3.as_mut_slice(),
scratch.grad_x4.as_mut_slice(),
);
if scope.pre_norm {
add_vec_grad(
block_grads
.pre_norm_w
.as_mut()
.expect("grad pre_norm_w")
.as_mut_slice(),
scratch.grad_x3.as_slice(),
);
add_vec_grad(
block_grads
.pre_norm_b
.as_mut()
.expect("grad pre_norm_b")
.as_mut_slice(),
scratch.grad_x4.as_slice(),
);
}
scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
}
}
if scope.embed {
let token_idx = trace.token.min(self.cfg.vocab_size.saturating_sub(1));
let off = token_idx * c;
add_vec_grad(
&mut grads.embeddings.as_mut_slice()[off..off + c],
scratch.grad_x.as_slice(),
);
}
Ok(())
}
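/// Trains over a token segment with truncated BPTT. The segment is first
/// replayed forward, checkpointing the state every `replay_chunk` steps;
/// chunks are then revisited in reverse order, re-running the forward pass
/// with trace capture and backpropagating each token through
/// `accumulate_token_step_gradients`. Gradients are averaged over the
/// segment, applied once via `apply_full_gradients`, and the segment is
/// replayed a final time with the updated weights to produce `live_state_out`.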
#[allow(clippy::too_many_arguments)]
pub fn online_train_segment_tbptt(
&mut self,
scratch: &mut ScratchBuffers,
start_state: &State,
steps: &[(u32, u8)],
scope: TrainScopeMask,
optimizer: OptimizerKind,
lr: f32,
clip: f32,
replay_chunk: usize,
adam_t: &mut usize,
model_adam: Option<&mut FullAdamState>,
out_bias: Option<&mut [f32]>,
out_bias_adam_m: Option<&mut [f32]>,
out_bias_adam_v: Option<&mut [f32]>,
live_state_out: &mut State,
) -> Result<()> {
if steps.is_empty() {
*live_state_out = start_state.clone();
return Ok(());
}
let grad_scale = 1.0f32 / (steps.len() as f32);
let chunk = replay_chunk.max(1).min(steps.len().max(1));
let mut grads = self.new_full_grad_state();
let mut recurrent = self.new_recurrent_grad_state();
recurrent.zero();
let mut bias_grad = out_bias.as_deref().map(|b| vec![0.0f32; b.len()]);
{
let mut checkpoints = Vec::<State>::new();
let mut checkpoint_state = start_state.clone();
scratch.set_capture_train_trace(false);
for chunk_start in (0..steps.len()).step_by(chunk) {
checkpoints.push(checkpoint_state.clone());
let chunk_end = (chunk_start + chunk).min(steps.len());
for &(input_token, _) in &steps[chunk_start..chunk_end] {
self.forward(scratch, input_token, &mut checkpoint_state);
}
}
for chunk_idx in (0..checkpoints.len()).rev() {
let chunk_start = chunk_idx * chunk;
let chunk_end = (chunk_start + chunk).min(steps.len());
let mut state = checkpoints[chunk_idx].clone();
let mut step_states = Vec::<State>::with_capacity(chunk_end - chunk_start + 1);
let mut step_traces =
Vec::<TokenTrainTrace>::with_capacity(chunk_end - chunk_start);
let mut step_pdfs =
Vec::<Vec<f64>>::with_capacity(chunk_end.saturating_sub(chunk_start));
step_states.push(state.clone());
for &(input_token, _) in &steps[chunk_start..chunk_end] {
scratch.set_capture_train_trace(true);
let logits = self.forward(scratch, input_token, &mut state);
let mut pdf = vec![0.0f64; self.cfg.vocab_size];
super::super::softmax_pdf_floor_with_bias(
logits,
out_bias.as_deref(),
&mut pdf,
);
step_pdfs.push(pdf);
step_traces.push(TokenTrainTrace::from_scratch(scratch));
step_states.push(state.clone());
}
for local_idx in (0..step_traces.len()).rev() {
let (_, target_symbol) = steps[chunk_start + local_idx];
self.accumulate_token_step_gradients(
scratch,
&step_traces[local_idx],
&step_states[local_idx + 1],
target_symbol,
&step_pdfs[local_idx],
grad_scale,
scope,
&mut grads,
bias_grad.as_deref_mut(),
&mut recurrent,
)?;
}
}
}
self.apply_full_gradients(
&grads,
scope,
optimizer,
lr,
clip,
adam_t,
model_adam,
out_bias,
bias_grad.as_deref(),
out_bias_adam_m,
out_bias_adam_v,
)?;
scratch.set_capture_train_trace(false);
*live_state_out = start_state.clone();
for &(input_token, _) in steps {
self.forward(scratch, input_token, live_state_out);
}
Ok(())
}
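/// Single-token online training step (BPTT window of 1). Unlike the segment
/// trainer above, gradients are not accumulated: each parameter update (SGD
/// or Adam) is applied in place as the backward pass walks the layers in
/// reverse, and no gradient flows into the recurrent state of earlier
/// timesteps.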
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
pub fn online_train_step_bptt1(
&mut self,
scratch: &mut ScratchBuffers,
state: &State,
symbol: u8,
pdf: &[f64],
scope: TrainScopeMask,
optimizer: OptimizerKind,
lr: f32,
clip: f32,
adam_t: &mut usize,
model_adam: Option<&mut FullAdamState>,
out_bias: Option<&mut [f32]>,
out_bias_adam_m: Option<&mut [f32]>,
out_bias_adam_v: Option<&mut [f32]>,
) -> Result<()> {
if !scope.trains_any_params() {
return Ok(());
}
if scope.trains_non_head_params() && !scratch.train_trace_valid {
bail!("rwkv full training trace is missing; run one forward step first");
}
let c = self.cfg.hidden_size;
let h = self.cfg.num_heads;
let n = self.cfg.head_dim;
let i = self.cfg.intermediate_size;
let d_w = self.cfg.decay_low_rank;
let d_a = self.cfg.a_low_rank;
let d_v = self.cfg.v_low_rank;
let d_g = self.cfg.g_low_rank;
let vocab = self.cfg.vocab_size.min(pdf.len());
if vocab == 0 {
return Ok(());
}
let mut adam_step = None::<AdamStep>;
let mut model_adam = model_adam;
if matches!(optimizer, OptimizerKind::Adam) {
*adam_t = adam_t.saturating_add(1);
let t = (*adam_t).max(1) as i32;
let b1 = 0.9f32;
let b2 = 0.999f32;
adam_step = Some(AdamStep {
lr,
clip: clip.max(0.0),
b1,
b2,
eps: 1e-8,
bias_corr1: 1.0 - b1.powi(t),
bias_corr2: 1.0 - b2.powi(t),
});
if scope.trains_non_head_params() && model_adam.is_none() {
bail!("rwkv Adam full-training state is missing");
}
}
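// Gradient of log p(target) w.r.t. the logits under softmax:
// one_hot(target) - p. The updates below add `lr * grad`, so this is
// gradient ascent on the log-likelihood.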
scratch.grad_logits.zero();
for idx in 0..vocab {
let p = pdf[idx].clamp(1e-12, 1.0) as f32;
let target = if idx == symbol as usize { 1.0 } else { 0.0 };
let mut g = target - p;
if clip > 0.0 {
g = g.clamp(-clip, clip);
}
scratch.grad_logits[idx] = g;
}
if scope.bias
&& let Some(bias) = out_bias
{
match optimizer {
OptimizerKind::Sgd => {
for idx in 0..bias.len().min(vocab) {
bias[idx] += lr * scratch.grad_logits[idx];
}
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let Some(m) = out_bias_adam_m else {
bail!("rwkv Adam output-bias state is missing (m)");
};
let Some(vv) = out_bias_adam_v else {
bail!("rwkv Adam output-bias state is missing (v)");
};
let n = bias.len().min(vocab);
apply_adam_vec_update_raw(
&mut bias[0..n],
&scratch.grad_logits.as_slice()[0..n],
&mut m[0..n],
&mut vv[0..n],
cfg,
);
}
}
}
scratch.grad_x.zero();
if scope.head {
match optimizer {
OptimizerKind::Sgd => {
fused_sgd_head_backward_update(
self.lm_head.as_mut_slice(),
vocab,
c,
&scratch.grad_logits.as_slice()[0..vocab],
scratch.x_normed.as_slice(),
scratch.grad_x.as_mut_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = model_adam.as_mut().expect("adam state exists");
fused_adam_head_backward_update(
self.lm_head.as_mut_slice(),
vocab,
c,
&scratch.grad_logits.as_slice()[0..vocab],
scratch.x_normed.as_slice(),
scratch.grad_x.as_mut_slice(),
adam.lm_head.m.as_mut_slice(),
adam.lm_head.v.as_mut_slice(),
cfg,
);
}
}
} else {
for row in 0..vocab {
let g = scratch.grad_logits[row];
if g == 0.0 {
continue;
}
let row_off = row * c;
for col in 0..c {
scratch.grad_x[col] += self.lm_head[row_off + col] * g;
}
}
}
let needs_backprop = scope.trains_non_head_params() || scope.head;
if !needs_backprop {
return Ok(());
}
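// Backprop through the final LayerNorm; grad_x3/grad_x4 receive the
// ln_out weight/bias gradients.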
layer_norm_backward(
scratch.x.as_slice(),
self.ln_out_w.as_slice(),
scratch.grad_x.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x2.as_mut_slice(),
scratch.grad_x3.as_mut_slice(),
scratch.grad_x4.as_mut_slice(),
);
if scope.head {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
self.ln_out_w.as_mut_slice(),
scratch.grad_x3.as_slice(),
lr,
clip,
);
sgd_vec_update(
self.ln_out_b.as_mut_slice(),
scratch.grad_x4.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = model_adam.as_mut().expect("adam state exists");
apply_adam_vec_update(
self.ln_out_w.as_mut_slice(),
scratch.grad_x3.as_slice(),
&mut adam.ln_out_w,
cfg,
);
apply_adam_vec_update(
self.ln_out_b.as_mut_slice(),
scratch.grad_x4.as_slice(),
&mut adam.ln_out_b,
cfg,
);
}
}
}
scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
scratch.grad_v_first.zero();
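// Walk the blocks in reverse: FFN backward first, then attention, updating
// parameters in place as gradients become available.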
for layer_idx in (0..self.cfg.num_layers).rev() {
let tr = &scratch.train_trace_layers[layer_idx];
let block = &mut self.blocks[layer_idx];
scratch.grad_x2.copy_from_slice(scratch.grad_x.as_slice());
scratch.grad_x3.copy_from_slice(scratch.grad_x.as_slice());
unsafe {
kernel::gemv_t_avx(
block.ffn.value_w.as_ptr(),
scratch.grad_x3.as_ptr(),
scratch.grad_ffn.as_mut_ptr(),
c,
i,
);
}
if scope.ffn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.ffn.value_w.as_mut_slice(),
c,
i,
scratch.grad_x3.as_slice(),
tr.ffn_k.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.ffn.value_w.as_mut_slice(),
c,
i,
scratch.grad_x3.as_slice(),
tr.ffn_k.as_slice(),
&mut adam.ffn.value_w,
cfg,
);
}
}
}
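// ReLU^2 activation backward: d/dx max(x, 0)^2 = 2x for x > 0, else 0.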
for col in 0..i {
let pre = tr.ffn_pre[col];
scratch.grad_ffn2[col] = if pre > 0.0 {
scratch.grad_ffn[col] * (2.0 * pre)
} else {
0.0
};
}
unsafe {
kernel::gemv_t_avx(
block.ffn.key_w.as_ptr(),
scratch.grad_ffn2.as_ptr(),
scratch.grad_x4.as_mut_ptr(),
i,
c,
);
}
if scope.ffn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.ffn.key_w.as_mut_slice(),
i,
c,
scratch.grad_ffn2.as_slice(),
tr.ffn_xk.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.ffn.key_w.as_mut_slice(),
i,
c,
scratch.grad_ffn2.as_slice(),
tr.ffn_xk.as_slice(),
&mut adam.ffn.key_w,
cfg,
);
}
}
}
for col in 0..c {
let g = scratch.grad_x4[col];
let mix = block.ffn.x_k[col];
let base = tr.ffn_norm[col];
let prev = tr.ffn_x_prev_old[col];
scratch.grad_x5[col] = g * (1.0 - mix);
scratch.grad_param[col] = g * (prev - base);
}
if scope.ffn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.ffn.x_k.as_mut_slice(),
scratch.grad_param.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.ffn.x_k.as_mut_slice(),
scratch.grad_param.as_slice(),
&mut adam.ffn.x_k,
cfg,
);
}
}
}
layer_norm_backward(
tr.x_after_attn.as_slice(),
block.ffn_norm_w.as_slice(),
scratch.grad_x5.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x4.as_mut_slice(),
scratch.grad_x3.as_mut_slice(),
scratch.grad_x6.as_mut_slice(),
);
if scope.ffn_norm {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
block.ffn_norm_w.as_mut_slice(),
scratch.grad_x3.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.ffn_norm_b.as_mut_slice(),
scratch.grad_x6.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.ffn_norm_w.as_mut_slice(),
scratch.grad_x3.as_slice(),
&mut adam.ffn_norm_w,
cfg,
);
apply_adam_vec_update(
block.ffn_norm_b.as_mut_slice(),
scratch.grad_x6.as_slice(),
&mut adam.ffn_norm_b,
cfg,
);
}
}
}
for col in 0..c {
scratch.grad_x2[col] += scratch.grad_x4[col];
}
scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
scratch.grad_x3.copy_from_slice(scratch.grad_x2.as_slice());
unsafe {
kernel::gemv_t_avx(
block.attn.o_proj.as_ptr(),
scratch.grad_x3.as_ptr(),
scratch.grad_x4.as_mut_ptr(),
c,
c,
);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.o_proj.as_mut_slice(),
c,
c,
scratch.grad_x3.as_slice(),
tr.y_gate.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.o_proj.as_mut_slice(),
c,
c,
scratch.grad_x3.as_slice(),
tr.y_gate.as_slice(),
&mut adam.attn.o_proj,
cfg,
);
}
}
}
for col in 0..c {
let gy = scratch.grad_x4[col];
scratch.grad_saved[col] = gy * tr.y_head[col];
scratch.grad_x4[col] = gy * tr.g[col];
}
scratch.grad_x2.zero();
scratch.grad_x3.zero();
scratch.grad_x6.zero();
scratch.grad_param.zero();
for head_idx in 0..h {
let off = head_idx * n;
let mut g_alpha = 0.0f32;
for j in 0..n {
let g = scratch.grad_x4[off + j];
g_alpha += g * tr.v[off + j];
scratch.grad_x6[off + j] += g * tr.alpha[head_idx];
}
for j in 0..n {
let idx = off + j;
let rk = block.attn.r_k[idx];
let rv = tr.r[idx];
let kv = tr.k[idx];
let g = g_alpha * rk;
scratch.grad_x2[idx] += g * kv;
scratch.grad_x3[idx] += g * rv;
scratch.grad_param[idx] += g_alpha * rv * kv;
}
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.r_k.as_mut_slice(),
scratch.grad_param.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.r_k.as_mut_slice(),
scratch.grad_param.as_slice(),
&mut adam.attn.r_k,
cfg,
);
}
}
}
scratch.grad_x5.as_mut_slice()[0..c].copy_from_slice(&scratch.grad_x4.as_slice()[0..c]);
group_norm_backward(
tr.y_wkv.as_slice(),
block.attn.g_norm_w.as_slice(),
scratch.grad_x5.as_slice(),
h,
n,
self.cfg.group_norm_eps,
scratch.grad_x4.as_mut_slice(),
scratch.grad_param.as_mut_slice(),
scratch.grad_param2.as_mut_slice(),
);
if scope.attn {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
block.attn.g_norm_w.as_mut_slice(),
scratch.grad_param.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.g_norm_b.as_mut_slice(),
scratch.grad_param2.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.g_norm_w.as_mut_slice(),
scratch.grad_param.as_slice(),
&mut adam.attn.g_norm_w,
cfg,
);
apply_adam_vec_update(
block.attn.g_norm_b.as_mut_slice(),
scratch.grad_param2.as_slice(),
&mut adam.attn.g_norm_b,
cfg,
);
}
}
}
scratch.grad_param.zero();
scratch.grad_x5.zero();
scratch.grad_param2.zero();
let s_old = tr.att_state_old.as_slice();
let s_new = state.layers[layer_idx].att_state.as_slice();
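// One-token backward through the WKV state update. The incoming state
// `s_old` is treated as a constant (no gradient is carried to earlier
// timesteps), which is the BPTT-1 approximation.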
for head_idx in 0..h {
let off = head_idx * n;
let s_off = head_idx * n * n;
let grad_y = &scratch.grad_x4.as_slice()[off..off + n];
let r_head = &tr.r.as_slice()[off..off + n];
let k_head = &tr.k.as_slice()[off..off + n];
let kk_head = &tr.kk.as_slice()[off..off + n];
let a_head = &tr.a.as_slice()[off..off + n];
let v_head = &tr.v.as_slice()[off..off + n];
unsafe {
kernel::gemv_t_avx(
s_new.as_ptr().add(s_off),
grad_y.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
n,
n,
);
kernel::gemv_t_avx(
s_old.as_ptr().add(s_off),
grad_y.as_ptr(),
scratch.grad_low_rank2.as_mut_ptr(),
n,
n,
);
}
for j in 0..n {
let idx = off + j;
scratch.grad_x2[idx] += scratch.grad_low_rank[j];
scratch.grad_param[idx] += r_head[j] * scratch.grad_low_rank2[j];
}
unsafe {
kernel::gemv_avx(
s_old.as_ptr().add(s_off),
kk_head.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
n,
n,
);
}
let mut dot_gv = 0.0f32;
let mut dot_rk = 0.0f32;
let mut dot_r_kka = 0.0f32;
let mut sum_gy_u = 0.0f32;
for j in 0..n {
dot_gv += grad_y[j] * v_head[j];
dot_rk += r_head[j] * k_head[j];
dot_r_kka += r_head[j] * kk_head[j] * a_head[j];
sum_gy_u += grad_y[j] * scratch.grad_low_rank[j];
}
for j in 0..n {
let idx = off + j;
scratch.grad_x3[idx] += r_head[j] * dot_gv;
scratch.grad_x6[idx] += grad_y[j] * dot_rk;
scratch.grad_x5[idx] -= sum_gy_u * r_head[j] * kk_head[j];
scratch.grad_low_rank[j] = -grad_y[j] * dot_r_kka;
}
unsafe {
kernel::gemv_t_avx(
s_old.as_ptr().add(s_off),
scratch.grad_low_rank.as_ptr(),
scratch.grad_low_rank2.as_mut_ptr(),
n,
n,
);
}
for j in 0..n {
let idx = off + j;
scratch.grad_param2[idx] +=
scratch.grad_low_rank2[j] - sum_gy_u * r_head[j] * a_head[j];
}
}
for col in 0..c {
let gk = scratch.grad_x3[col];
let scale = 1.0 + (tr.a[col] - 1.0) * block.attn.k_a[col];
let d_scale = gk * tr.k_pre[col];
scratch.grad_x3[col] = gk * scale;
scratch.grad_x5[col] += d_scale * block.attn.k_a[col];
scratch.grad_param[col] = d_scale * (tr.a[col] - 1.0);
}
for head_idx in 0..h {
let off = head_idx * n;
l2_normalize_backward(
&tr.kk_pre.as_slice()[off..off + n],
&tr.kk.as_slice()[off..off + n],
&scratch.grad_param2.as_slice()[off..off + n],
1e-12,
&mut scratch.grad_x4.as_mut_slice()[off..off + n],
);
}
for col in 0..c {
let g = scratch.grad_x4[col];
scratch.grad_x3[col] += g * block.attn.k_k[col];
scratch.grad_param2[col] = g * tr.k_pre[col];
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
block.attn.k_a.as_mut_slice(),
scratch.grad_param.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn.k_k.as_mut_slice(),
scratch.grad_param2.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.k_a.as_mut_slice(),
scratch.grad_param.as_slice(),
&mut adam.attn.k_a,
cfg,
);
apply_adam_vec_update(
block.attn.k_k.as_mut_slice(),
scratch.grad_param2.as_slice(),
&mut adam.attn.k_k,
cfg,
);
}
}
}
scratch
.grad_param2
.copy_from_slice(scratch.grad_x6.as_slice());
if layer_idx == 0 {
for col in 0..c {
scratch.grad_x6[col] += scratch.grad_v_first[col];
}
} else if tr.uses_v_residual
&& let (Some(v1), Some(v2), Some(v0)) =
(&mut block.attn.v1, &mut block.attn.v2, &mut block.attn.v0)
{
for col in 0..c {
let gv = scratch.grad_param2[col];
let nu = tr.nu[col];
scratch.grad_x6[col] = gv * (1.0 - nu);
scratch.grad_x3[col] = gv * (scratch.train_v_first[col] - tr.v_pre[col]);
scratch.grad_v_first[col] += gv * nu;
}
for col in 0..c {
let nu = tr.nu[col];
scratch.grad_x3[col] *= nu * (1.0 - nu);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(v0.as_mut_slice(), scratch.grad_x3.as_slice(), lr, clip)
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
[layer_idx];
apply_adam_vec_update(
v0.as_mut_slice(),
scratch.grad_x3.as_slice(),
adam.attn.v0.as_mut().expect("adam v0 state"),
cfg,
);
}
}
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
v2.as_mut_slice(),
c,
d_v,
scratch.grad_x3.as_slice(),
&tr.v_hidden.as_slice()[0..d_v],
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
[layer_idx];
apply_adam_outer_update(
v2.as_mut_slice(),
c,
d_v,
scratch.grad_x3.as_slice(),
&tr.v_hidden.as_slice()[0..d_v],
adam.attn.v2.as_mut().expect("adam v2 state"),
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
v2.as_ptr(),
scratch.grad_x3.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_v,
);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
v1.as_mut_slice(),
d_v,
c,
&scratch.grad_low_rank.as_slice()[0..d_v],
tr.xv.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
[layer_idx];
apply_adam_outer_update(
v1.as_mut_slice(),
d_v,
c,
&scratch.grad_low_rank.as_slice()[0..d_v],
tr.xv.as_slice(),
adam.attn.v1.as_mut().expect("adam v1 state"),
cfg,
);
}
}
}
for col in 0..c {
let mut acc = 0.0f32;
for row in 0..d_v {
acc += v1[row * c + col] * scratch.grad_low_rank[row];
}
scratch.grad_x4[col] += acc;
}
}
let proj_size = c * c;
if scope.attn {
match optimizer {
OptimizerKind::Sgd => {
sgd_outer_update(
&mut block.attn.rkv_proj.as_mut_slice()[0..proj_size],
c,
c,
scratch.grad_x2.as_slice(),
tr.xr.as_slice(),
lr,
clip,
);
sgd_outer_update(
&mut block.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
c,
c,
scratch.grad_x3.as_slice(),
tr.xk.as_slice(),
lr,
clip,
);
sgd_outer_update(
&mut block.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
c,
c,
scratch.grad_x6.as_slice(),
tr.xv.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update_raw(
&mut block.attn.rkv_proj.as_mut_slice()[0..proj_size],
c,
c,
scratch.grad_x2.as_slice(),
tr.xr.as_slice(),
&mut adam.attn.rkv_proj.m.as_mut_slice()[0..proj_size],
&mut adam.attn.rkv_proj.v.as_mut_slice()[0..proj_size],
cfg,
);
apply_adam_outer_update_raw(
&mut block.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
c,
c,
scratch.grad_x3.as_slice(),
tr.xk.as_slice(),
&mut adam.attn.rkv_proj.m.as_mut_slice()[proj_size..2 * proj_size],
&mut adam.attn.rkv_proj.v.as_mut_slice()[proj_size..2 * proj_size],
cfg,
);
apply_adam_outer_update_raw(
&mut block.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
c,
c,
scratch.grad_x6.as_slice(),
tr.xv.as_slice(),
&mut adam.attn.rkv_proj.m.as_mut_slice()[2 * proj_size..3 * proj_size],
&mut adam.attn.rkv_proj.v.as_mut_slice()[2 * proj_size..3 * proj_size],
cfg,
);
}
}
}
let proj = block.attn.rkv_proj.as_slice();
unsafe {
kernel::gemv_t_avx(
proj.as_ptr(),
scratch.grad_x2.as_ptr(),
scratch.grad_param.as_mut_ptr(),
c,
c,
);
kernel::gemv_t_avx(
proj.as_ptr().add(proj_size),
scratch.grad_x3.as_ptr(),
scratch.grad_param2.as_mut_ptr(),
c,
c,
);
kernel::gemv_t_avx(
proj.as_ptr().add(2 * proj_size),
scratch.grad_x6.as_ptr(),
scratch.grad_x4.as_mut_ptr(),
c,
c,
);
}
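// Decay backward: w_decay = exp(-sigmoid(w_pre) / sqrt(e)); chain through
// exp and sigmoid as in `accumulate_token_step_gradients`.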
let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
for col in 0..c {
let sig = tr.w_sigmoid[col];
let d_sig = scratch.grad_param[col] * (-inv_sqrt_e) * tr.w_decay[col];
scratch.grad_param[col] = d_sig * sig * (1.0 - sig);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.w0.as_mut_slice(),
scratch.grad_param.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.w0.as_mut_slice(),
scratch.grad_param.as_slice(),
&mut adam.attn.w0,
cfg,
);
}
}
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.w2.as_mut_slice(),
c,
d_w,
scratch.grad_param.as_slice(),
&tr.w_hidden.as_slice()[0..d_w],
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.w2.as_mut_slice(),
c,
d_w,
scratch.grad_param.as_slice(),
&tr.w_hidden.as_slice()[0..d_w],
&mut adam.attn.w2,
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
block.attn.w2.as_ptr(),
scratch.grad_param.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_w,
);
}
for col in 0..d_w {
let t = tr.w_hidden[col];
scratch.grad_low_rank[col] *= 1.0 - t * t;
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.w1.as_mut_slice(),
d_w,
c,
&scratch.grad_low_rank.as_slice()[0..d_w],
tr.xw.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.w1.as_mut_slice(),
d_w,
c,
&scratch.grad_low_rank.as_slice()[0..d_w],
tr.xw.as_slice(),
&mut adam.attn.w1,
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
block.attn.w1.as_ptr(),
scratch.grad_low_rank.as_ptr(),
scratch.grad_x6.as_mut_ptr(),
d_w,
c,
);
}
for col in 0..c {
let a = tr.a[col];
scratch.grad_x5[col] *= a * (1.0 - a);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.a0.as_mut_slice(),
scratch.grad_x5.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.a0.as_mut_slice(),
scratch.grad_x5.as_slice(),
&mut adam.attn.a0,
cfg,
);
}
}
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.a2.as_mut_slice(),
c,
d_a,
scratch.grad_x5.as_slice(),
&tr.a_hidden.as_slice()[0..d_a],
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.a2.as_mut_slice(),
c,
d_a,
scratch.grad_x5.as_slice(),
&tr.a_hidden.as_slice()[0..d_a],
&mut adam.attn.a2,
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
block.attn.a2.as_ptr(),
scratch.grad_x5.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_a,
);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.a1.as_mut_slice(),
d_a,
c,
&scratch.grad_low_rank.as_slice()[0..d_a],
tr.xa.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.a1.as_mut_slice(),
d_a,
c,
&scratch.grad_low_rank.as_slice()[0..d_a],
tr.xa.as_slice(),
&mut adam.attn.a1,
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
block.attn.a1.as_ptr(),
scratch.grad_low_rank.as_ptr(),
scratch.grad_x5.as_mut_ptr(),
d_a,
c,
);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.g2.as_mut_slice(),
c,
d_g,
scratch.grad_saved.as_slice(),
&tr.g_hidden.as_slice()[0..d_g],
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.g2.as_mut_slice(),
c,
d_g,
scratch.grad_saved.as_slice(),
&tr.g_hidden.as_slice()[0..d_g],
&mut adam.attn.g2,
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
block.attn.g2.as_ptr(),
scratch.grad_saved.as_ptr(),
scratch.grad_low_rank.as_mut_ptr(),
c,
d_g,
);
}
for col in 0..d_g {
let sig = tr.g_hidden[col];
scratch.grad_low_rank2[col] = scratch.grad_low_rank[col] * sig * (1.0 - sig);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_outer_update(
block.attn.g1.as_mut_slice(),
d_g,
c,
&scratch.grad_low_rank2.as_slice()[0..d_g],
tr.xg.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_outer_update(
block.attn.g1.as_mut_slice(),
d_g,
c,
&scratch.grad_low_rank2.as_slice()[0..d_g],
tr.xg.as_slice(),
&mut adam.attn.g1,
cfg,
);
}
}
}
unsafe {
kernel::gemv_t_avx(
block.attn.g1.as_ptr(),
scratch.grad_low_rank2.as_ptr(),
scratch.grad_saved.as_mut_ptr(),
d_g,
c,
);
}
scratch.grad_x3.zero();
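// Token-shift backward; with a one-step window the gradient to the
// previous timestep's x is dropped.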
for col in 0..c {
let g = scratch.grad_param[col];
let mix = block.attn.x_r[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.x_r.as_mut_slice(),
scratch.grad_x2.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.x_r.as_mut_slice(),
scratch.grad_x2.as_slice(),
&mut adam.attn.x_r,
cfg,
);
}
}
}
for col in 0..c {
let g = scratch.grad_x6[col];
let mix = block.attn.x_w[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.x_w.as_mut_slice(),
scratch.grad_x2.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.x_w.as_mut_slice(),
scratch.grad_x2.as_slice(),
&mut adam.attn.x_w,
cfg,
);
}
}
}
for col in 0..c {
let g = scratch.grad_param2[col];
let mix = block.attn.x_k[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.x_k.as_mut_slice(),
scratch.grad_x2.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.x_k.as_mut_slice(),
scratch.grad_x2.as_slice(),
&mut adam.attn.x_k,
cfg,
);
}
}
}
for col in 0..c {
let g = scratch.grad_x4[col];
let mix = block.attn.x_v[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.x_v.as_mut_slice(),
scratch.grad_x2.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.x_v.as_mut_slice(),
scratch.grad_x2.as_slice(),
&mut adam.attn.x_v,
cfg,
);
}
}
}
for col in 0..c {
let g = scratch.grad_x5[col];
let mix = block.attn.x_a[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.x_a.as_mut_slice(),
scratch.grad_x2.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.x_a.as_mut_slice(),
scratch.grad_x2.as_slice(),
&mut adam.attn.x_a,
cfg,
);
}
}
}
for col in 0..c {
let g = scratch.grad_saved[col];
let mix = block.attn.x_g[col];
let base = tr.attn_norm[col];
let prev = tr.att_x_prev_old[col];
scratch.grad_x3[col] += g * (1.0 - mix);
scratch.grad_x2[col] = g * (prev - base);
}
if scope.attn {
match optimizer {
OptimizerKind::Sgd => sgd_vec_update(
block.attn.x_g.as_mut_slice(),
scratch.grad_x2.as_slice(),
lr,
clip,
),
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn.x_g.as_mut_slice(),
scratch.grad_x2.as_slice(),
&mut adam.attn.x_g,
cfg,
);
}
}
}
layer_norm_backward(
tr.x_after_pre.as_slice(),
block.attn_norm_w.as_slice(),
scratch.grad_x3.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x2.as_mut_slice(),
scratch.grad_x4.as_mut_slice(),
scratch.grad_x5.as_mut_slice(),
);
if scope.attn_norm {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(
block.attn_norm_w.as_mut_slice(),
scratch.grad_x4.as_slice(),
lr,
clip,
);
sgd_vec_update(
block.attn_norm_b.as_mut_slice(),
scratch.grad_x5.as_slice(),
lr,
clip,
);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam =
&mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
apply_adam_vec_update(
block.attn_norm_w.as_mut_slice(),
scratch.grad_x4.as_slice(),
&mut adam.attn_norm_w,
cfg,
);
apply_adam_vec_update(
block.attn_norm_b.as_mut_slice(),
scratch.grad_x5.as_slice(),
&mut adam.attn_norm_b,
cfg,
);
}
}
}
for col in 0..c {
scratch.grad_x[col] += scratch.grad_x2[col];
}
if layer_idx == 0
&& let (Some(w), Some(b)) = (&mut block.pre_norm_w, &mut block.pre_norm_b)
{
layer_norm_backward(
tr.x_in.as_slice(),
w.as_slice(),
scratch.grad_x.as_slice(),
self.cfg.layer_norm_eps,
scratch.grad_x2.as_mut_slice(),
scratch.grad_x3.as_mut_slice(),
scratch.grad_x4.as_mut_slice(),
);
if scope.pre_norm {
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(w.as_mut_slice(), scratch.grad_x3.as_slice(), lr, clip);
sgd_vec_update(b.as_mut_slice(), scratch.grad_x4.as_slice(), lr, clip);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
[layer_idx];
apply_adam_vec_update(
w.as_mut_slice(),
scratch.grad_x3.as_slice(),
adam.pre_norm_w.as_mut().expect("adam pre_norm_w"),
cfg,
);
apply_adam_vec_update(
b.as_mut_slice(),
scratch.grad_x4.as_slice(),
adam.pre_norm_b.as_mut().expect("adam pre_norm_b"),
cfg,
);
}
}
}
scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
}
}
if scope.embed {
let token_idx = scratch
.train_token
.min(self.cfg.vocab_size.saturating_sub(1));
let off = token_idx * c;
let row = &mut self.embeddings.as_mut_slice()[off..off + c];
match optimizer {
OptimizerKind::Sgd => {
sgd_vec_update(row, scratch.grad_x.as_slice(), lr, clip);
}
OptimizerKind::Adam => {
let cfg = adam_step.as_ref().expect("adam cfg initialized");
let adam = model_adam.as_mut().expect("adam state exists");
let m = &mut adam.embeddings.m.as_mut_slice()[off..off + c];
let v = &mut adam.embeddings.v.as_mut_slice()[off..off + c];
apply_adam_vec_update_raw(row, scratch.grad_x.as_slice(), m, v, cfg);
}
}
}
Ok(())
}
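/// Runs one token through the model, updating `state` in place and
/// returning the logits slice held in `scratch`.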
#[inline(never)]
pub fn forward<'a>(
&'a self,
scratch: &'a mut ScratchBuffers,
token: u32,
state: &mut State,
) -> &'a [f32] {
let mut sink = NullProfiler;
self.forward_with_sink(scratch, token, state, &mut sink)
}
#[inline(never)]
pub fn forward_with_profiler<'a, S: ProfilerSink>(
&'a self,
scratch: &'a mut ScratchBuffers,
token: u32,
state: &mut State,
profiler: &mut S,
) -> &'a [f32] {
self.forward_with_sink(scratch, token, state, profiler)
}
#[inline(never)]
fn forward_with_sink<'a, S: ProfilerSink>(
&'a self,
scratch: &'a mut ScratchBuffers,
token: u32,
state: &mut State,
profiler: &mut S,
) -> &'a [f32] {
if scratch.capture_train_trace {
self.forward_with_sink_impl::<true, S>(scratch, token, state, profiler)
} else {
self.forward_with_sink_impl::<false, S>(scratch, token, state, profiler)
}
}
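// Monomorphized on CAPTURE so the trace-recording branches compile out of
// the inference path.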
fn forward_with_sink_impl<'a, const CAPTURE: bool, S: ProfilerSink>(
&'a self,
scratch: &'a mut ScratchBuffers,
token: u32,
state: &mut State,
profiler: &mut S,
) -> &'a [f32] {
let c = self.cfg.hidden_size;
let _h = self.cfg.num_heads;
let _n = self.cfg.head_dim;
let num_layers = self.cfg.num_layers;
let token_idx = (token as usize).min(self.cfg.vocab_size.saturating_sub(1));
let emb_offset = token_idx * c;
let emb_slice = &self.embeddings.as_slice()[emb_offset..emb_offset + c];
scratch.x.as_mut_slice().copy_from_slice(emb_slice);
if CAPTURE {
scratch.train_token = token_idx;
scratch.train_trace_valid = true;
} else {
scratch.train_trace_valid = false;
}
profiler.begin_token();
unsafe {
for layer_idx in 0..num_layers {
if CAPTURE {
scratch.train_trace_layers[layer_idx]
.x_in
.copy_from(&scratch.x);
}
if let (Some(w), Some(b)) = (
&self.blocks[layer_idx].pre_norm_w,
&self.blocks[layer_idx].pre_norm_b,
) {
kernel::layer_norm_avx(
scratch.x.as_ptr(),
w.as_ptr(),
b.as_ptr(),
scratch.x.as_mut_ptr(),
c,
self.cfg.layer_norm_eps,
);
}
if CAPTURE {
scratch.train_trace_layers[layer_idx]
.x_after_pre
.copy_from(&scratch.x);
}
kernel::layer_norm_avx(
scratch.x.as_ptr(),
self.blocks[layer_idx].attn_norm_w.as_ptr(),
self.blocks[layer_idx].attn_norm_b.as_ptr(),
scratch.x_normed.as_mut_ptr(),
c,
self.cfg.layer_norm_eps,
);
if CAPTURE {
scratch.train_trace_layers[layer_idx]
.attn_norm
.copy_from(&scratch.x_normed);
}
let trace_ptr = if CAPTURE {
&mut scratch.train_trace_layers[layer_idx] as *mut LayerTrainTrace
} else {
std::ptr::null_mut()
};
if S::ENABLED {
let attn_start = Instant::now();
self.attention_forward_impl::<CAPTURE>(scratch, layer_idx, state, trace_ptr);
profiler.record_attention(layer_idx, attn_start.elapsed());
} else {
self.attention_forward_impl::<CAPTURE>(scratch, layer_idx, state, trace_ptr);
}
kernel::add_avx(
scratch.x.as_ptr(),
scratch.att_out.as_ptr(),
scratch.x.as_mut_ptr(),
c,
);
if CAPTURE {
scratch.train_trace_layers[layer_idx]
.x_after_attn
.copy_from(&scratch.x);
}
kernel::layer_norm_avx(
scratch.x.as_ptr(),
self.blocks[layer_idx].ffn_norm_w.as_ptr(),
self.blocks[layer_idx].ffn_norm_b.as_ptr(),
scratch.x_normed.as_mut_ptr(),
c,
self.cfg.layer_norm_eps,
);
if CAPTURE {
scratch.train_trace_layers[layer_idx]
.ffn_norm
.copy_from(&scratch.x_normed);
}
if S::ENABLED {
let ffn_start = Instant::now();
self.ffn_forward_impl::<CAPTURE>(
scratch,
layer_idx,
&mut state.layers[layer_idx],
trace_ptr,
);
profiler.record_ffn(layer_idx, ffn_start.elapsed());
} else {
self.ffn_forward_impl::<CAPTURE>(
scratch,
layer_idx,
&mut state.layers[layer_idx],
trace_ptr,
);
}
kernel::add_avx(
scratch.x.as_ptr(),
scratch.ffn_out.as_ptr(),
scratch.x.as_mut_ptr(),
c,
);
if CAPTURE {
scratch.train_trace_layers[layer_idx]
.x_out
.copy_from(&scratch.x);
}
}
kernel::layer_norm_avx(
scratch.x.as_ptr(),
self.ln_out_w.as_ptr(),
self.ln_out_b.as_ptr(),
scratch.x_normed.as_mut_ptr(),
c,
self.cfg.layer_norm_eps,
);
kernel::gemv_avx(
self.lm_head.as_ptr(),
scratch.x_normed.as_ptr(),
scratch.logits.as_mut_ptr(),
self.cfg.vocab_size,
c,
);
}
if CAPTURE {
scratch.train_v_first.copy_from(&state.v_first);
}
scratch.logits.as_slice()
}
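/// RWKV7 time-mix (attention) forward for one layer: token shift, fused
/// r/k/v projection, low-rank decay/a/g paths, WKV state update, group
/// norm, per-head bonus, and the output gate/projection. When CAPTURE is
/// set, intermediates are recorded into the layer's training trace.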
#[inline(always)]
unsafe fn attention_forward_impl<const CAPTURE: bool>(
&self,
scratch: &mut ScratchBuffers,
layer_idx: usize,
state: &mut State,
trace: *mut LayerTrainTrace,
) {
let attn = &self.blocks[layer_idx].attn;
let layer_state = &mut state.layers[layer_idx];
let c = self.cfg.hidden_size;
let h = self.cfg.num_heads;
let n = self.cfg.head_dim;
let d_w = self.cfg.decay_low_rank;
let d_a = self.cfg.a_low_rank;
let d_g = self.cfg.g_low_rank;
if CAPTURE {
let tr = &mut *trace;
tr.att_x_prev_old.copy_from(&layer_state.att_x_prev);
tr.att_state_old.copy_from(&layer_state.att_state);
}
kernel::token_shift_multi6_avx(
scratch.x_normed.as_ptr(),
layer_state.att_x_prev.as_ptr(),
attn.x_r.as_ptr(),
attn.x_w.as_ptr(),
attn.x_k.as_ptr(),
attn.x_v.as_ptr(),
attn.x_a.as_ptr(),
attn.x_g.as_ptr(),
scratch.xr.as_mut_ptr(),
scratch.xw.as_mut_ptr(),
scratch.xk.as_mut_ptr(),
scratch.xv.as_mut_ptr(),
scratch.xa.as_mut_ptr(),
scratch.xg.as_mut_ptr(),
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.xr.copy_from(&scratch.xr);
tr.xw.copy_from(&scratch.xw);
tr.xk.copy_from(&scratch.xk);
tr.xv.copy_from(&scratch.xv);
tr.xa.copy_from(&scratch.xa);
tr.xg.copy_from(&scratch.xg);
}
kernel::copy(
scratch.x_normed.as_ptr(),
layer_state.att_x_prev.as_mut_ptr(),
c,
);
let proj_size = c * c;
kernel::gemv_avx(
attn.rkv_proj.as_ptr(),
scratch.xr.as_ptr(),
scratch.r.as_mut_ptr(),
c,
c,
);
kernel::gemv_avx(
attn.rkv_proj.as_ptr().add(proj_size),
scratch.xk.as_ptr(),
scratch.k.as_mut_ptr(),
c,
c,
);
kernel::gemv_avx(
attn.rkv_proj.as_ptr().add(2 * proj_size),
scratch.xv.as_ptr(),
scratch.v.as_mut_ptr(),
c,
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.r.copy_from(&scratch.r);
tr.k_pre.copy_from(&scratch.k);
tr.v_pre.copy_from(&scratch.v);
}
kernel::gemv_avx(
attn.w1.as_ptr(),
scratch.xw.as_ptr(),
scratch.w_lora_tmp.as_mut_ptr(),
d_w,
c,
);
kernel::tanh_avx(
scratch.w_lora_tmp.as_ptr(),
scratch.w_lora_tmp.as_mut_ptr(),
d_w,
);
if CAPTURE {
let tr = &mut *trace;
tr.w_hidden.as_mut_slice()[0..d_w]
.copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_w]);
}
kernel::gemv_avx(
attn.w2.as_ptr(),
scratch.w_lora_tmp.as_ptr(),
scratch.w_decay.as_mut_ptr(),
c,
d_w,
);
kernel::add_avx(
scratch.w_decay.as_ptr(),
attn.w0.as_ptr(),
scratch.w_decay.as_mut_ptr(),
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.w_pre.copy_from(&scratch.w_decay);
}
let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
kernel::sigmoid_exp_neg_scaled_avx(
scratch.w_decay.as_ptr(),
scratch.w_decay.as_mut_ptr(),
if CAPTURE {
(*trace).w_sigmoid.as_mut_ptr()
} else {
std::ptr::null_mut()
},
inv_sqrt_e,
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.w_decay.copy_from(&scratch.w_decay);
}
kernel::gemv_avx(
attn.a1.as_ptr(),
scratch.xa.as_ptr(),
scratch.w_lora_tmp.as_mut_ptr(),
d_a,
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.a_hidden.as_mut_slice()[0..d_a]
.copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_a]);
}
kernel::gemv_avx(
attn.a2.as_ptr(),
scratch.w_lora_tmp.as_ptr(),
scratch.a.as_mut_ptr(),
c,
d_a,
);
kernel::add_avx(
scratch.a.as_ptr(),
attn.a0.as_ptr(),
scratch.a.as_mut_ptr(),
c,
);
kernel::sigmoid_avx(scratch.a.as_ptr(), scratch.a.as_mut_ptr(), c);
if CAPTURE {
let tr = &mut *trace;
tr.a.copy_from(&scratch.a);
}
kernel::gemv_avx(
attn.g1.as_ptr(),
scratch.xg.as_ptr(),
scratch.w_lora_tmp.as_mut_ptr(),
d_g,
c,
);
kernel::sigmoid_avx(
scratch.w_lora_tmp.as_ptr(),
scratch.w_lora_tmp.as_mut_ptr(),
d_g,
);
if CAPTURE {
let tr = &mut *trace;
tr.g_hidden.as_mut_slice()[0..d_g]
.copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_g]);
}
kernel::gemv_avx(
attn.g2.as_ptr(),
scratch.w_lora_tmp.as_ptr(),
scratch.g.as_mut_ptr(),
c,
d_g,
);
if CAPTURE {
let tr = &mut *trace;
tr.g.copy_from(&scratch.g);
}
if layer_idx == 0 {
state.v_first.copy_from(&scratch.v);
state.v_first_set = true;
if CAPTURE {
let tr = &mut *trace;
tr.uses_v_residual = false;
tr.v.copy_from(&scratch.v);
}
} else if state.v_first_set
&& let (Some(v1), Some(v2), Some(v0)) = (&attn.v1, &attn.v2, &attn.v0)
{
let d_v = self.cfg.v_low_rank;
kernel::gemv_avx(
v1.as_ptr(),
scratch.xv.as_ptr(),
scratch.w_lora_tmp.as_mut_ptr(),
d_v,
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.v_hidden.as_mut_slice()[0..d_v]
.copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_v]);
}
kernel::gemv_avx(
v2.as_ptr(),
scratch.w_lora_tmp.as_ptr(),
scratch.att_out.as_mut_ptr(),
c,
d_v,
);
kernel::add_avx(
scratch.att_out.as_ptr(),
v0.as_ptr(),
scratch.att_out.as_mut_ptr(),
c,
);
kernel::sigmoid_avx(scratch.att_out.as_ptr(), scratch.att_out.as_mut_ptr(), c);
if CAPTURE {
let tr = &mut *trace;
tr.uses_v_residual = true;
tr.nu.copy_from(&scratch.att_out);
}
for i in 0..c {
let nu = scratch.att_out[i];
scratch.v[i] += (state.v_first[i] - scratch.v[i]) * nu;
}
if CAPTURE {
let tr = &mut *trace;
tr.v.copy_from(&scratch.v);
}
} else if CAPTURE {
let tr = &mut *trace;
tr.uses_v_residual = false;
tr.v.copy_from(&scratch.v);
}
kernel::mul_avx(
scratch.k.as_ptr(),
attn.k_k.as_ptr(),
scratch.kk.as_mut_ptr(),
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.kk_pre.copy_from(&scratch.kk);
}
for head in 0..h {
let offset = head * n;
kernel::l2_normalize_avx(
scratch.kk.as_ptr().add(offset),
scratch.kk.as_mut_ptr().add(offset),
n,
1e-12,
);
}
if CAPTURE {
let tr = &mut *trace;
tr.kk.copy_from(&scratch.kk);
}
for i in 0..c {
let scale = 1.0 + (scratch.a[i] - 1.0) * attn.k_a[i];
scratch.k[i] *= scale;
}
if CAPTURE {
let tr = &mut *trace;
tr.k.copy_from(&scratch.k);
}
kernel::rwkv7_wkv_update_avx(
layer_state.att_state.as_mut_ptr(),
scratch.w_decay.as_ptr(),
scratch.k.as_ptr(),
scratch.v.as_ptr(),
scratch.kk.as_ptr(),
scratch.a.as_ptr(),
scratch.r.as_ptr(),
scratch.y.as_mut_ptr(),
h,
n,
);
if CAPTURE {
let tr = &mut *trace;
tr.y_wkv.copy_from(&scratch.y);
}
kernel::group_norm_avx(
scratch.y.as_ptr(),
attn.g_norm_w.as_ptr(),
attn.g_norm_b.as_ptr(),
scratch.y.as_mut_ptr(),
h,
n,
self.cfg.group_norm_eps,
);
if CAPTURE {
let tr = &mut *trace;
tr.y_gn.copy_from(&scratch.y);
}
for head in 0..h {
let offset = head * n;
let mut alpha = 0.0f32;
for j in 0..n {
alpha += scratch.r[offset + j] * scratch.k[offset + j] * attn.r_k[head * n + j];
}
if CAPTURE {
let tr = &mut *trace;
tr.alpha[head] = alpha;
}
for j in 0..n {
scratch.y[offset + j] += alpha * scratch.v[offset + j];
}
}
if CAPTURE {
let tr = &mut *trace;
tr.y_head.copy_from(&scratch.y);
}
kernel::mul_avx(
scratch.y.as_ptr(),
scratch.g.as_ptr(),
scratch.y.as_mut_ptr(),
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.y_gate.copy_from(&scratch.y);
}
kernel::gemv_avx(
attn.o_proj.as_ptr(),
scratch.y.as_ptr(),
scratch.att_out.as_mut_ptr(),
c,
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.att_out.copy_from(&scratch.att_out);
}
}
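/// Channel-mix (FFN) forward for one layer: token shift, key projection,
/// ReLU^2 activation, value projection.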
#[inline(always)]
unsafe fn ffn_forward_impl<const CAPTURE: bool>(
&self,
scratch: &mut ScratchBuffers,
layer_idx: usize,
layer_state: &mut LayerState,
trace: *mut LayerTrainTrace,
) {
let ffn = &self.blocks[layer_idx].ffn;
let c = self.cfg.hidden_size;
let i = self.cfg.intermediate_size;
if CAPTURE {
let tr = &mut *trace;
tr.ffn_x_prev_old.copy_from(&layer_state.ffn_x_prev);
}
kernel::token_shift_avx(
scratch.x_normed.as_ptr(),
layer_state.ffn_x_prev.as_ptr(),
ffn.x_k.as_ptr(),
scratch.xk.as_mut_ptr(),
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.ffn_xk.copy_from(&scratch.xk);
}
kernel::copy(
scratch.x_normed.as_ptr(),
layer_state.ffn_x_prev.as_mut_ptr(),
c,
);
kernel::gemv_avx(
ffn.key_w.as_ptr(),
scratch.xk.as_ptr(),
scratch.ffn_k.as_mut_ptr(),
i,
c,
);
if CAPTURE {
let tr = &mut *trace;
tr.ffn_pre.copy_from(&scratch.ffn_k);
}
kernel::relu_squared_avx(scratch.ffn_k.as_ptr(), scratch.ffn_k.as_mut_ptr(), i);
if CAPTURE {
let tr = &mut *trace;
tr.ffn_k.copy_from(&scratch.ffn_k);
}
kernel::gemv_avx(
ffn.value_w.as_ptr(),
scratch.ffn_k.as_ptr(),
scratch.ffn_out.as_mut_ptr(),
c,
i,
);
if CAPTURE {
let tr = &mut *trace;
tr.ffn_out.copy_from(&scratch.ffn_out);
}
}
}
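/// LayerNorm backward over a single vector; writes gradients for the
/// input, weight, and bias. Lengths are clamped to the shortest slice.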
#[allow(clippy::needless_range_loop)]
fn layer_norm_backward(
input: &[f32],
weight: &[f32],
grad_out: &[f32],
eps: f32,
grad_input: &mut [f32],
grad_weight: &mut [f32],
grad_bias: &mut [f32],
) {
let n = input
.len()
.min(weight.len())
.min(grad_out.len())
.min(grad_input.len())
.min(grad_weight.len())
.min(grad_bias.len());
if n == 0 {
return;
}
let nf = n as f32;
let mut mean = 0.0f32;
for &x in &input[0..n] {
mean += x;
}
mean /= nf;
let mut var = 0.0f32;
for &x in &input[0..n] {
let d = x - mean;
var += d * d;
}
var /= nf;
let inv_std = (var + eps).sqrt().recip();
let mut sum_gw = 0.0f32;
let mut sum_gw_xhat = 0.0f32;
for i in 0..n {
let xhat = (input[i] - mean) * inv_std;
let gw = grad_out[i] * weight[i];
grad_weight[i] = grad_out[i] * xhat;
grad_bias[i] = grad_out[i];
sum_gw += gw;
sum_gw_xhat += gw * xhat;
}
for i in 0..n {
let xhat = (input[i] - mean) * inv_std;
let gw = grad_out[i] * weight[i];
grad_input[i] = (gw * nf - sum_gw - xhat * sum_gw_xhat) * inv_std / nf;
}
}
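/// GroupNorm backward: the LayerNorm backward applied independently to each
/// contiguous group of `group_size` channels.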
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn group_norm_backward(
input: &[f32],
weight: &[f32],
grad_out: &[f32],
num_groups: usize,
group_size: usize,
eps: f32,
grad_input: &mut [f32],
grad_weight: &mut [f32],
grad_bias: &mut [f32],
) {
let c = input
.len()
.min(weight.len())
.min(grad_out.len())
.min(grad_input.len())
.min(grad_weight.len())
.min(grad_bias.len());
if c == 0 || num_groups == 0 || group_size == 0 {
return;
}
grad_input[0..c].fill(0.0);
grad_weight[0..c].fill(0.0);
grad_bias[0..c].fill(0.0);
let g = num_groups.min(c / group_size);
let n = group_size as f32;
for group in 0..g {
let off = group * group_size;
let end = (off + group_size).min(c);
let len = end - off;
if len == 0 {
continue;
}
let mut mean = 0.0f32;
for idx in off..end {
mean += input[idx];
}
mean /= len as f32;
let mut var = 0.0f32;
for idx in off..end {
let d = input[idx] - mean;
var += d * d;
}
var /= len as f32;
let inv_std = (var + eps).sqrt().recip();
let mut sum_gw = 0.0f32;
let mut sum_gw_xhat = 0.0f32;
for idx in off..end {
let xhat = (input[idx] - mean) * inv_std;
let gw = grad_out[idx] * weight[idx];
grad_weight[idx] += grad_out[idx] * xhat;
grad_bias[idx] += grad_out[idx];
sum_gw += gw;
sum_gw_xhat += gw * xhat;
}
for idx in off..end {
let xhat = (input[idx] - mean) * inv_std;
let gw = grad_out[idx] * weight[idx];
grad_input[idx] = (gw * n - sum_gw - xhat * sum_gw_xhat) * inv_std / n;
}
}
}
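/// Backward of y = x / max(||x||, min_norm). Above the floor this is the
/// usual projection (grad - y * <grad, y>) / ||x||; at or below the floor
/// the norm is constant, so the gradient is simply grad / min_norm.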
fn l2_normalize_backward(
x: &[f32],
y: &[f32],
grad_out: &[f32],
min_norm: f32,
grad_input: &mut [f32],
) {
let n = x
.len()
.min(y.len())
.min(grad_out.len())
.min(grad_input.len());
if n == 0 {
return;
}
let mut norm_sq = 0.0f32;
for &v in &x[0..n] {
norm_sq += v * v;
}
let norm_raw = norm_sq.sqrt();
if norm_raw <= min_norm {
let inv = min_norm.recip();
for i in 0..n {
grad_input[i] = grad_out[i] * inv;
}
return;
}
let norm = norm_raw;
let mut dot = 0.0f32;
for i in 0..n {
dot += grad_out[i] * y[i];
}
let inv = norm.recip();
for i in 0..n {
grad_input[i] = (grad_out[i] - y[i] * dot) * inv;
}
}
#[inline(always)]
fn add_vec_grad(dst: &mut [f32], src: &[f32]) {
let n = dst.len().min(src.len());
for i in 0..n {
dst[i] += src[i];
}
}
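/// Accumulates the rank-1 outer product `left ⊗ right` into `dst`,
/// clamping all indices to the available lengths.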
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn add_outer_grad(dst: &mut [f32], rows: usize, cols: usize, left: &[f32], right: &[f32]) {
let rows = rows.min(left.len());
let cols = cols.min(right.len());
let n = dst.len();
if rows == 0 || cols == 0 || n == 0 {
return;
}
for r in 0..rows {
let g = left[r];
if g == 0.0 {
continue;
}
let off = r * cols;
if off >= n {
break;
}
let row_cols = cols.min(n - off);
for c in 0..row_cols {
dst[off + c] += g * right[c];
}
}
}
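// Sign convention: callers accumulate gradients as descent directions (the
// finite-difference test below negates dL/dtheta before comparing), so every
// optimizer update in this file is `param += lr * step` rather than `-=`.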
#[inline(always)]
fn sgd_vec_update(param: &mut [f32], grad: &[f32], lr: f32, clip: f32) {
let n = param.len().min(grad.len());
if n == 0 {
return;
}
if clip > 0.0 {
for i in 0..n {
param[i] += lr * grad[i].clamp(-clip, clip);
}
} else {
for i in 0..n {
param[i] += lr * grad[i];
}
}
}
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn sgd_outer_update(
param: &mut [f32],
rows: usize,
cols: usize,
left: &[f32],
right: &[f32],
lr: f32,
clip: f32,
) {
let rows = rows.min(left.len());
let cols = cols.min(right.len());
let n = param.len();
if rows == 0 || cols == 0 || n == 0 {
return;
}
for r in 0..rows {
let g = left[r];
let off = r * cols;
if off >= n {
break;
}
let row_cols = cols.min(n - off);
if clip > 0.0 {
for c in 0..row_cols {
param[off + c] += lr * (g * right[c]).clamp(-clip, clip);
}
} else {
for c in 0..row_cols {
param[off + c] += lr * g * right[c];
}
}
}
}
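/// Fused LM-head backward + SGD update. For each weight row this both
/// accumulates `grad_input[col] += W_old[row][col] * left[row]` (backprop
/// deliberately uses the pre-update weights) and applies the rank-1 update
/// `W[row][col] += lr * left[row] * right[col]`. The unclipped path runs
/// eight columns at a time via `f32x8`; clipping falls back to scalar code.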
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn fused_sgd_head_backward_update(
param: &mut [f32],
rows: usize,
cols: usize,
left: &[f32],
right: &[f32],
grad_input: &mut [f32],
lr: f32,
clip: f32,
) {
let rows = rows.min(left.len());
let cols = cols.min(right.len()).min(grad_input.len());
let n = param.len();
if rows == 0 || cols == 0 || n == 0 {
return;
}
let do_clip = clip > 0.0;
let lr8 = f32x8::splat(lr);
for row in 0..rows {
let g = left[row];
if g == 0.0 {
continue;
}
let off = row * cols;
if off >= n {
break;
}
let row_cols = cols.min(n - off);
if do_clip {
for col in 0..row_cols {
let idx = off + col;
let w_old = param[idx];
grad_input[col] += w_old * g;
param[idx] = w_old + lr * (g * right[col]).clamp(-clip, clip);
}
continue;
}
let mut col = 0usize;
unsafe {
let g8 = f32x8::splat(g);
while col + 8 <= row_cols {
let idx = off + col;
let wv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
let giv = grad_input
.as_ptr()
.add(col)
.cast::<f32x8>()
.read_unaligned();
grad_input
.as_mut_ptr()
.add(col)
.cast::<f32x8>()
.write_unaligned(giv + wv * g8);
param
.as_mut_ptr()
.add(idx)
.cast::<f32x8>()
.write_unaligned(wv + (g8 * rv) * lr8);
col += 8;
}
}
while col < row_cols {
let idx = off + col;
let w_old = param[idx];
grad_input[col] += w_old * g;
param[idx] = w_old + lr * g * right[col];
col += 1;
}
}
}
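/// Adam update for a flat parameter vector; clamps every slice to a common
/// length, then delegates to the raw kernel below.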
#[inline(always)]
fn apply_adam_vec_update(
param: &mut [f32],
grad: &[f32],
adam: &mut AdamTensorState,
step: &AdamStep,
) {
let n = param
.len()
.min(grad.len())
.min(adam.m.len())
.min(adam.v.len());
if n == 0 {
return;
}
apply_adam_vec_update_raw(
&mut param[0..n],
&grad[0..n],
&mut adam.m.as_mut_slice()[0..n],
&mut adam.v.as_mut_slice()[0..n],
step,
);
}
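/// Core Adam step:
/// `m = b1*m + (1 - b1)*g`, `v = b2*v + (1 - b2)*g^2`,
/// `param += lr * (m / bias_corr1) / (sqrt(v / bias_corr2) + eps)`.
/// The `+=` matches the descent-direction gradient convention noted above.
/// With clipping enabled the loop stays scalar; without it, eight lanes are
/// processed per `f32x8` iteration, followed by a scalar tail.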
#[inline(always)]
fn apply_adam_vec_update_raw(
param: &mut [f32],
grad: &[f32],
m: &mut [f32],
v: &mut [f32],
step: &AdamStep,
) {
let n = param.len().min(grad.len()).min(m.len()).min(v.len());
if n == 0 {
return;
}
let b1 = step.b1;
let b2 = step.b2;
let one_b1 = 1.0 - b1;
let one_b2 = 1.0 - b2;
let inv_bc1 = 1.0 / step.bias_corr1;
let inv_bc2 = 1.0 / step.bias_corr2;
let do_clip = step.clip > 0.0;
let clip = step.clip;
if do_clip {
for idx in 0..n {
let g = grad[idx].clamp(-clip, clip);
let mm = b1 * m[idx] + one_b1 * g;
let vv = b2 * v[idx] + one_b2 * g * g;
m[idx] = mm;
v[idx] = vv;
let m_hat = mm * inv_bc1;
let v_hat = vv * inv_bc2;
param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
}
return;
}
let mut idx = 0usize;
unsafe {
let b1v = f32x8::splat(b1);
let b2v = f32x8::splat(b2);
let one_b1v = f32x8::splat(one_b1);
let one_b2v = f32x8::splat(one_b2);
let inv_bc1v = f32x8::splat(inv_bc1);
let inv_bc2v = f32x8::splat(inv_bc2);
let lrv = f32x8::splat(step.lr);
let epsv = f32x8::splat(step.eps);
while idx + 8 <= n {
let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let mm = mv * b1v + gv * one_b1v;
let vv2 = vv * b2v + (gv * gv) * one_b2v;
m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
param
.as_mut_ptr()
.add(idx)
.cast::<f32x8>()
.write_unaligned(pv + upd);
idx += 8;
}
}
while idx < n {
let g = grad[idx];
let mm = b1 * m[idx] + one_b1 * g;
let vv = b2 * v[idx] + one_b2 * g * g;
m[idx] = mm;
v[idx] = vv;
let m_hat = mm * inv_bc1;
let v_hat = vv * inv_bc2;
param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
idx += 1;
}
}
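/// Fused LM-head backward + Adam update: the same pre-update backprop
/// accumulation as the SGD variant above, with the rank-1 gradient
/// `left[row] * right[col]` fed through the Adam moment updates instead of
/// a plain scaled step.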
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn fused_adam_head_backward_update(
param: &mut [f32],
rows: usize,
cols: usize,
left: &[f32],
right: &[f32],
grad_input: &mut [f32],
m: &mut [f32],
v: &mut [f32],
step: &AdamStep,
) {
let rows = rows.min(left.len());
let cols = cols.min(right.len()).min(grad_input.len());
let n = param.len().min(m.len()).min(v.len());
if rows == 0 || cols == 0 || n == 0 {
return;
}
let b1 = step.b1;
let b2 = step.b2;
let one_b1 = 1.0 - b1;
let one_b2 = 1.0 - b2;
let inv_bc1 = 1.0 / step.bias_corr1;
let inv_bc2 = 1.0 / step.bias_corr2;
let do_clip = step.clip > 0.0;
let clip = step.clip;
let b1v = f32x8::splat(b1);
let b2v = f32x8::splat(b2);
let one_b1v = f32x8::splat(one_b1);
let one_b2v = f32x8::splat(one_b2);
let inv_bc1v = f32x8::splat(inv_bc1);
let inv_bc2v = f32x8::splat(inv_bc2);
let epsv = f32x8::splat(step.eps);
let lrv = f32x8::splat(step.lr);
for row in 0..rows {
let g = left[row];
if g == 0.0 {
continue;
}
let off = row * cols;
if off >= n {
break;
}
let row_cols = cols.min(n - off);
if do_clip {
for col in 0..row_cols {
let idx = off + col;
let w_old = param[idx];
grad_input[col] += w_old * g;
let gg = (g * right[col]).clamp(-clip, clip);
let mm = b1 * m[idx] + one_b1 * gg;
let vv = b2 * v[idx] + one_b2 * gg * gg;
m[idx] = mm;
v[idx] = vv;
let m_hat = mm * inv_bc1;
let v_hat = vv * inv_bc2;
param[idx] = w_old + step.lr * m_hat / (v_hat.sqrt() + step.eps);
}
continue;
}
let mut col = 0usize;
unsafe {
let g8 = f32x8::splat(g);
while col + 8 <= row_cols {
let idx = off + col;
let wv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
let giv = grad_input
.as_ptr()
.add(col)
.cast::<f32x8>()
.read_unaligned();
grad_input
.as_mut_ptr()
.add(col)
.cast::<f32x8>()
.write_unaligned(giv + wv * g8);
let gv = g8 * rv;
let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let mm = mv * b1v + gv * one_b1v;
let vv2 = vv * b2v + (gv * gv) * one_b2v;
m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
param
.as_mut_ptr()
.add(idx)
.cast::<f32x8>()
.write_unaligned(wv + upd);
col += 8;
}
}
while col < row_cols {
let idx = off + col;
let w_old = param[idx];
grad_input[col] += w_old * g;
let gg = g * right[col];
let mm = b1 * m[idx] + one_b1 * gg;
let vv = b2 * v[idx] + one_b2 * gg * gg;
m[idx] = mm;
v[idx] = vv;
let m_hat = mm * inv_bc1;
let v_hat = vv * inv_bc2;
param[idx] = w_old + step.lr * m_hat / (v_hat.sqrt() + step.eps);
col += 1;
}
}
}
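/// Adam update for an outer-product gradient, clamping to the shared length
/// of the parameter and Adam-state buffers before delegating to the raw
/// kernel below.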
#[inline(always)]
fn apply_adam_outer_update(
param: &mut [f32],
rows: usize,
cols: usize,
left: &[f32],
right: &[f32],
adam: &mut AdamTensorState,
step: &AdamStep,
) {
let n = param.len().min(adam.m.len()).min(adam.v.len());
if n == 0 {
return;
}
apply_adam_outer_update_raw(
&mut param[0..n],
rows,
cols,
left,
right,
&mut adam.m.as_mut_slice()[0..n],
&mut adam.v.as_mut_slice()[0..n],
step,
);
}
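/// Core rank-1 Adam update. Unlike the fused head variants, rows with
/// `left[row] == 0.0` are not skipped: Adam's moment decay still changes
/// `m`, `v`, and the parameters even when the instantaneous gradient is
/// zero.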
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn apply_adam_outer_update_raw(
param: &mut [f32],
rows: usize,
cols: usize,
left: &[f32],
right: &[f32],
m: &mut [f32],
v: &mut [f32],
step: &AdamStep,
) {
let rows = rows.min(left.len());
let cols = cols.min(right.len());
let n = param.len().min(m.len()).min(v.len());
if rows == 0 || cols == 0 || n == 0 {
return;
}
let b1 = step.b1;
let b2 = step.b2;
let one_b1 = 1.0 - b1;
let one_b2 = 1.0 - b2;
let inv_bc1 = 1.0 / step.bias_corr1;
let inv_bc2 = 1.0 / step.bias_corr2;
let do_clip = step.clip > 0.0;
let clip = step.clip;
let b1v = f32x8::splat(b1);
let b2v = f32x8::splat(b2);
let one_b1v = f32x8::splat(one_b1);
let one_b2v = f32x8::splat(one_b2);
let inv_bc1v = f32x8::splat(inv_bc1);
let inv_bc2v = f32x8::splat(inv_bc2);
let epsv = f32x8::splat(step.eps);
let lrv = f32x8::splat(step.lr);
for row in 0..rows {
let g_row = left[row];
let off = row * cols;
if off >= n {
break;
}
let row_cols = (n - off).min(cols);
if do_clip {
for col in 0..row_cols {
let idx = off + col;
let g = (g_row * right[col]).clamp(-clip, clip);
let mm = b1 * m[idx] + one_b1 * g;
let vv = b2 * v[idx] + one_b2 * g * g;
m[idx] = mm;
v[idx] = vv;
let m_hat = mm * inv_bc1;
let v_hat = vv * inv_bc2;
param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
}
continue;
}
let mut col = 0usize;
unsafe {
let g8 = f32x8::splat(g_row);
while col + 8 <= row_cols {
let idx = off + col;
let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
let gv = g8 * rv;
let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let mm = mv * b1v + gv * one_b1v;
let vv2 = vv * b2v + (gv * gv) * one_b2v;
m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
param
.as_mut_ptr()
.add(idx)
.cast::<f32x8>()
.write_unaligned(pv + upd);
col += 8;
}
}
while col < row_cols {
let idx = off + col;
let g = g_row * right[col];
let mm = b1 * m[idx] + one_b1 * g;
let vv = b2 * v[idx] + one_b2 * g * g;
m[idx] = mm;
v[idx] = vv;
let m_hat = mm * inv_bc1;
let v_hat = vv * inv_bc2;
param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
col += 1;
}
}
}
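/// Tiny 64-bit LCG (the multiplier from Knuth's MMIX generator, increment 1,
/// seeded with a golden-ratio XOR) used only for deterministic weight
/// initialization. Each output takes the high 32 bits of the state, and
/// `next_f32` maps that to roughly [0, 1].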
struct RwkvRng {
state: u64,
}
impl RwkvRng {
fn new(seed: u64) -> Self {
Self {
state: seed ^ 0x9E37_79B9_7F4A_7C15,
}
}
#[inline]
fn next_u32(&mut self) -> u32 {
self.state = self
.state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1);
(self.state >> 32) as u32
}
#[inline]
fn next_f32(&mut self) -> f32 {
let v = self.next_u32() as f32;
v * (1.0 / (u32::MAX as f32))
}
}
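// Init helpers: `init_uniform` draws from U(-scale, scale), `init_centered`
// from U(center - scale, center + scale), `init_const` fills a constant.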
#[inline]
fn init_uniform(t: &mut Tensor1D, rng: &mut RwkvRng, scale: f32) {
let s = t.as_mut_slice();
for v in s {
let r = rng.next_f32() - 0.5;
*v = r * 2.0 * scale;
}
}
#[inline]
fn init_centered(t: &mut Tensor1D, rng: &mut RwkvRng, center: f32, scale: f32) {
let s = t.as_mut_slice();
for v in s {
let r = rng.next_f32() - 0.5;
*v = center + r * 2.0 * scale;
}
}
#[inline]
fn init_const(t: &mut Tensor1D, value: f32) {
t.as_mut_slice().fill(value);
}
#[cfg(test)]
mod tests {
use super::*;
fn test_cfg() -> Config {
Config {
vocab_size: 256,
hidden_size: 64,
num_layers: 1,
num_heads: 1,
head_dim: 64,
intermediate_size: 64,
layer_norm_eps: 1e-5,
group_norm_eps: 64e-5,
decay_low_rank: 8,
a_low_rank: 8,
v_low_rank: 8,
g_low_rank: 8,
}
}
fn softmax_loss(logits: &[f32], target: u8) -> f64 {
let max_logit = logits
.iter()
.copied()
.fold(f32::NEG_INFINITY, |a, b| a.max(b));
let mut sum = 0.0f64;
for &z in logits {
sum += ((z - max_logit) as f64).exp();
}
let p = ((logits[target as usize] - max_logit) as f64).exp() / sum.max(1e-300);
-p.max(1e-300).ln()
}
fn segment_loss(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> f64 {
if steps.is_empty() {
return 0.0;
}
let mut scratch = ScratchBuffers::new(cfg);
let mut state = model.new_state();
let mut loss = 0.0f64;
for &(input, target) in steps {
let logits = model.forward(&mut scratch, input, &mut state);
loss += softmax_loss(logits, target);
}
loss / (steps.len() as f64)
}
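/// Test helper: runs a traced forward over the whole segment, snapshotting
/// the recurrent state after every token, then walks the tokens in reverse,
/// accumulating per-token gradients (scaled by 1/len) into a fresh
/// `FullGradState` -- the same shape of computation as the TBPTT trainer.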
fn segment_grads(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> FullGradState {
let mut scratch = ScratchBuffers::new(cfg);
let mut state = model.new_state();
let mut states = Vec::with_capacity(steps.len() + 1);
let mut traces = Vec::with_capacity(steps.len());
let mut pdfs = Vec::with_capacity(steps.len());
states.push(state.clone());
for &(input, _) in steps {
scratch.set_capture_train_trace(true);
let logits = model.forward(&mut scratch, input, &mut state);
let mut pdf = vec![0.0f64; cfg.vocab_size];
super::super::super::softmax_pdf_floor_with_bias(logits, None, &mut pdf);
pdfs.push(pdf);
traces.push(TokenTrainTrace::from_scratch(&scratch));
states.push(state.clone());
}
let mut grads = model.new_full_grad_state();
let mut recurrent = model.new_recurrent_grad_state();
let scope = TrainScopeMask {
embed: true,
pre_norm: true,
attn_norm: true,
ffn_norm: true,
attn: true,
ffn: true,
head: true,
bias: false,
};
let grad_scale = 1.0f32 / (steps.len() as f32);
for idx in (0..steps.len()).rev() {
model
.accumulate_token_step_gradients(
&mut scratch,
&traces[idx],
&states[idx + 1],
steps[idx].1,
&pdfs[idx],
grad_scale,
scope,
&mut grads,
None,
&mut recurrent,
)
.expect("segment gradient accumulation");
}
grads
}
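/// A handful of scalar weights sampled across the model (embedding, norms,
/// output projection, what appear to be the K- and V-block offsets inside
/// the fused `rkv_proj`, FFN key) so the finite-difference check touches
/// every major parameter family.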
#[derive(Clone, Copy, Debug)]
enum Probe {
Embed,
LnOutW,
AttnNormW,
OProj,
KProj,
VProj,
FfnKey,
}
fn probe_value(model: &Model, probe: Probe) -> f32 {
match probe {
Probe::Embed => model.embeddings[7],
Probe::LnOutW => model.ln_out_w[5],
Probe::AttnNormW => model.blocks[0].attn_norm_w[9],
Probe::OProj => model.blocks[0].attn.o_proj[23],
Probe::KProj => model.blocks[0].attn.rkv_proj[64 * 64 + 17],
Probe::VProj => model.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29],
Probe::FfnKey => model.blocks[0].ffn.key_w[11],
}
}
fn set_probe(model: &mut Model, probe: Probe, value: f32) {
match probe {
Probe::Embed => model.embeddings[7] = value,
Probe::LnOutW => model.ln_out_w[5] = value,
Probe::AttnNormW => model.blocks[0].attn_norm_w[9] = value,
Probe::OProj => model.blocks[0].attn.o_proj[23] = value,
Probe::KProj => model.blocks[0].attn.rkv_proj[64 * 64 + 17] = value,
Probe::VProj => model.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29] = value,
Probe::FfnKey => model.blocks[0].ffn.key_w[11] = value,
}
}
fn probe_grad(grads: &FullGradState, probe: Probe) -> f32 {
match probe {
Probe::Embed => grads.embeddings[7],
Probe::LnOutW => grads.ln_out_w[5],
Probe::AttnNormW => grads.blocks[0].attn_norm_w[9],
Probe::OProj => grads.blocks[0].attn.o_proj[23],
Probe::KProj => grads.blocks[0].attn.rkv_proj[64 * 64 + 17],
Probe::VProj => grads.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29],
Probe::FfnKey => grads.blocks[0].ffn.key_w[11],
}
}
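/// Position-weighted sum used as a cheap, order-sensitive checksum for the
/// snapshot test.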
fn weighted_checksum(data: &[f32]) -> f64 {
data.iter()
.enumerate()
.map(|(i, &v)| (i as f64 + 1.0) * (v as f64))
.sum()
}
#[test]
fn test_config_default() {
let cfg = Config::default();
assert_eq!(cfg.vocab_size, 256);
assert_eq!(cfg.hidden_size, 256);
assert_eq!(cfg.num_layers, 12);
assert_eq!(cfg.num_heads, 4);
assert_eq!(cfg.head_dim, 64);
}
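// Deterministic snapshot: a fixed seed and token stream must reproduce the
// recorded logit/state checksums within a small float tolerance.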
#[test]
fn test_forward_deterministic_snapshot() {
let cfg = Config {
vocab_size: 256,
hidden_size: 64,
num_layers: 2,
num_heads: 1,
head_dim: 64,
intermediate_size: 128,
layer_norm_eps: 1e-5,
group_norm_eps: 64e-5,
decay_low_rank: 16,
a_low_rank: 16,
v_low_rank: 16,
g_low_rank: 32,
};
cfg.validate().expect("valid test config");
let model = Model::new_random(cfg.clone(), 0x1234_5678_9ABC_DEF0).expect("random model");
let mut state = model.new_state();
let mut scratch = ScratchBuffers::new(&cfg);
let tokens = [0u32, 1, 7, 42, 255, 3, 128, 64, 17, 99];
let mut probes = Vec::new();
let mut last_logits = vec![0.0; 8];
for &token in &tokens {
let logits = model.forward(&mut scratch, token, &mut state);
probes.push(logits[0]);
probes.push(logits[1]);
probes.push(logits[2]);
probes.push(logits[42]);
probes.push(logits[127]);
probes.push(logits[255]);
last_logits.copy_from_slice(&logits[0..8]);
}
let probe_checksum = weighted_checksum(&probes);
let last_logits_checksum = weighted_checksum(&last_logits);
let state_att_checksum = weighted_checksum(state.layers[0].att_state.as_slice());
let state_prev_checksum = weighted_checksum(state.layers[1].att_x_prev.as_slice());
let v_first_checksum = weighted_checksum(state.v_first.as_slice());
let expected_probe_checksum = 25.674_967_924_598_604_f64;
let expected_last_logits_checksum = 0.679_873_816_668_987_3_f64;
let expected_state_att_checksum = 129.962_464_237_222_32_f64;
let expected_state_prev_checksum = -231.326_208_570_972_08_f64;
let expected_v_first_checksum = -1.921_361_377_462_744_7_f64;
let tol = 2e-4_f64;
assert!(
(probe_checksum - expected_probe_checksum).abs() <= tol,
"probe_checksum={probe_checksum}"
);
assert!(
(last_logits_checksum - expected_last_logits_checksum).abs() <= tol,
"last_logits_checksum={last_logits_checksum}"
);
assert!(
(state_att_checksum - expected_state_att_checksum).abs() <= tol,
"state_att_checksum={state_att_checksum}"
);
assert!(
(state_prev_checksum - expected_state_prev_checksum).abs() <= tol,
"state_prev_checksum={state_prev_checksum}"
);
assert!(
(v_first_checksum - expected_v_first_checksum).abs() <= tol,
"v_first_checksum={v_first_checksum}"
);
}
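// Trace capture must be observationally pure: logits and every state tensor
// must match bit-for-bit with capture on and off.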
#[test]
fn traced_and_untraced_forward_match_exactly() {
let cfg = Config {
vocab_size: 256,
hidden_size: 64,
num_layers: 2,
num_heads: 1,
head_dim: 64,
intermediate_size: 128,
layer_norm_eps: 1e-5,
group_norm_eps: 64e-5,
decay_low_rank: 16,
a_low_rank: 16,
v_low_rank: 16,
g_low_rank: 32,
};
cfg.validate().expect("valid test config");
let model = Model::new_random(cfg.clone(), 0xCAFEBABE).expect("random model");
let mut traced_state = model.new_state();
let mut plain_state = model.new_state();
let mut traced_scratch = ScratchBuffers::new(&cfg);
let mut plain_scratch = ScratchBuffers::new(&cfg);
traced_scratch.set_capture_train_trace(true);
plain_scratch.set_capture_train_trace(false);
let tokens = [3u32, 19, 77, 120, 255, 5, 88, 13, 144, 1, 200];
for &token in &tokens {
let traced_logits = model
.forward(&mut traced_scratch, token, &mut traced_state)
.to_vec();
let plain_logits = model
.forward(&mut plain_scratch, token, &mut plain_state)
.to_vec();
for (a, b) in traced_logits.iter().zip(plain_logits.iter()) {
assert_eq!(a.to_bits(), b.to_bits());
}
assert_eq!(traced_state.v_first_set, plain_state.v_first_set);
for (&a, &b) in traced_state
.v_first
.as_slice()
.iter()
.zip(plain_state.v_first.as_slice())
{
assert_eq!(a.to_bits(), b.to_bits());
}
for (tr_layer, plain_layer) in traced_state.layers.iter().zip(plain_state.layers.iter())
{
for (&a, &b) in tr_layer
.att_x_prev
.as_slice()
.iter()
.zip(plain_layer.att_x_prev.as_slice())
{
assert_eq!(a.to_bits(), b.to_bits());
}
for (&a, &b) in tr_layer
.att_state
.as_slice()
.iter()
.zip(plain_layer.att_state.as_slice())
{
assert_eq!(a.to_bits(), b.to_bits());
}
for (&a, &b) in tr_layer
.ffn_x_prev
.as_slice()
.iter()
.zip(plain_layer.ffn_x_prev.as_slice())
{
assert_eq!(a.to_bits(), b.to_bits());
}
}
}
}
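// Central-difference gradient check: numeric = -(L(p+e) - L(p-e)) / (2e),
// negated to match the descent-direction convention, compared under a
// relative tolerance scaled by the larger gradient magnitude.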
#[test]
fn tbptt_segment_gradients_match_finite_difference() {
let cfg = test_cfg();
cfg.validate().expect("valid test config");
let model = Model::new_random(cfg.clone(), 0xD00D_F00D).expect("random model");
let steps = [(0u32, 1u8), (1, 2), (2, 3)];
let grads = segment_grads(&model, &cfg, &steps);
let eps = 1e-3f32;
for probe in [
Probe::Embed,
Probe::LnOutW,
Probe::AttnNormW,
Probe::OProj,
Probe::KProj,
Probe::VProj,
Probe::FfnKey,
] {
let analytic = probe_grad(&grads, probe);
let mut plus = model.clone();
let base = probe_value(&plus, probe);
set_probe(&mut plus, probe, base + eps);
let loss_plus = segment_loss(&plus, &cfg, &steps);
let mut minus = model.clone();
set_probe(&mut minus, probe, base - eps);
let loss_minus = segment_loss(&minus, &cfg, &steps);
let numeric = -((loss_plus - loss_minus) / (2.0 * eps as f64)) as f32;
let tol = 5e-2f32.max(analytic.abs().max(numeric.abs()) * 8e-2);
assert!(
(analytic - numeric).abs() <= tol,
"probe={probe:?} analytic={analytic} numeric={numeric} tol={tol}"
);
}
}
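// End-to-end sanity check: one SGD TBPTT step over a fixed segment must
// strictly reduce the mean segment loss.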
#[test]
fn tbptt_sgd_step_reduces_mean_segment_loss() {
let cfg = test_cfg();
cfg.validate().expect("valid test config");
let mut model = Model::new_random(cfg.clone(), 0x1234_5678).expect("random model");
let steps = [(0u32, 1u8), (1, 2), (2, 3), (3, 4)];
let before = segment_loss(&model, &cfg, &steps);
let mut scratch = ScratchBuffers::new(&cfg);
let start_state = model.new_state();
let mut live_state = model.new_state();
let mut adam_t = 0usize;
let scope = TrainScopeMask {
embed: true,
pre_norm: true,
attn_norm: true,
ffn_norm: true,
attn: true,
ffn: true,
head: true,
bias: false,
};
model
.online_train_segment_tbptt(
&mut scratch,
&start_state,
&steps,
scope,
OptimizerKind::Sgd,
1e-3,
0.0,
2,
&mut adam_t,
None,
None,
None,
None,
&mut live_state,
)
.expect("tbptt sgd step");
let after = segment_loss(&model, &cfg, &steps);
assert!(
after < before,
"expected SGD TBPTT step to reduce mean loss: before={before} after={after}"
);
}
}