use crate::loss::pcgrad;
use ndarray::{Array1, Array2};
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use rand_distr::StandardNormal;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
/// Dimensions shared by every layer of a stack; used by `MambaLayerParams::random` and
/// `MambaLayerParams::zero_residual` to size each layer's tensors.
///
/// NOTE(review): this type is embedded in bincode checkpoints, so reordering fields
/// likely changes the on-disk format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct LayerSpec {
/// Model width: row count of most per-layer tensors and the length of `d_skip`.
pub d_model: usize,
/// State size: `a_log` has `d_state` columns, `x_proj_w` has `2 * d_state + 1`.
pub d_state: usize,
/// Number of columns of `conv1d_w` (presumably the 1-D convolution kernel width).
pub d_conv: usize,
}
/// AdamW-style optimizer hyper-parameters.
///
/// NOTE(review): in this file only `lr` is read (as a plain SGD step size for the
/// back-propagated gradient in `train_step_cadenced`); `beta1`, `beta2`, `eps` and
/// `weight_decay` are stored but unused here — confirm whether another module uses them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdamWConfig {
/// Learning rate applied to back-propagated gradients.
pub lr: f32,
/// First-moment decay (not read in this file).
pub beta1: f32,
/// Second-moment decay (not read in this file).
pub beta2: f32,
/// Denominator stabiliser (not read in this file).
pub eps: f32,
/// Decoupled weight-decay coefficient (not read in this file).
pub weight_decay: f32,
}
impl Default for AdamWConfig {
fn default() -> Self {
Self {
lr: 1e-4,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.01,
}
}
}
/// Selects the layers whose parameters are held fixed during training.
///
/// Indices refer to the *expanded* layer stack (the selection is resolved after
/// expansion in `ExperimentalTrainer::from_base`). `resolve_freeze_indices` drops
/// out-of-range indices and removes duplicates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FreezeSelection {
/// Freeze layers `0..n` (clamped to the number of layers).
FirstN(usize),
/// Freeze exactly these layer indices.
Indices(Vec<usize>),
}
/// Where `expand_layers_in_place` puts the newly created (all-zero) layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExpansionPlacement {
/// After all existing layers.
Append,
/// Before all existing layers.
Prepend,
/// As one contiguous block starting at `index` (clamped to the base layer count).
InsertAt { index: usize },
/// At these exact final indices; must list exactly as many unique indices, all
/// `< target_num_layers`, as there are new layers (otherwise expansion panics).
SpecificPositions(Vec<usize>),
}
/// Target size and placement for growing a layer stack.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpansionConfig {
/// Total number of layers after expansion; must be >= the base layer count.
pub target_num_layers: usize,
/// Where the new layers are inserted.
pub placement: ExpansionPlacement,
}
/// Full configuration of an `ExperimentalTrainer` run.
///
/// NOTE(review): the config is stored inside the trainer and therefore inside bincode
/// checkpoints; reordering fields likely changes the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentalTrainerConfig {
/// Vocabulary size. NOTE(review): not read by any code in this file (the embedding
/// shape comes from the base `TrainerParams`) — confirm whether callers rely on it.
pub vocab_size: usize,
/// Per-layer dimensions, used to size layers created by expansion.
pub layer_spec: LayerSpec,
/// Final layer count and placement of the new layers.
pub expansion: ExpansionConfig,
/// Which layers (indices into the expanded stack) are held fixed.
pub freeze_selection: FreezeSelection,
/// When true, the embedding is snapshotted at construction and restored after every step.
pub freeze_embedding: bool,
/// Step size of the forward-forward update applied to each layer's `d_skip`.
pub ff_lr: f32,
/// NOTE(review): not read by any code in this file; the FF hinge margin is hard-coded
/// to `1.0` in `MambaLayerParams::ff_loss_pair`. Confirm the intended use.
pub ff_threshold: f32,
/// Optimizer settings; the cadenced training step only reads `adamw.lr`.
pub adamw: AdamWConfig,
}
/// Parameters of one layer. Shapes (as created by `MambaLayerParams::random`) are given
/// in terms of the layer's `LayerSpec`.
///
/// The training loops in this file only update `d_skip`; the other tensors are only
/// initialised, measured (`l2_norm`) and checkpointed here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MambaLayerParams {
/// `(d_model, d_state)`.
pub a_log: Array2<f32>,
/// `(d_model)`; also the `theta` vector of the forward-forward goodness.
pub d_skip: Array1<f32>,
/// `(d_model, 2 * d_state + 1)`.
pub x_proj_w: Array2<f32>,
/// `(1, d_model)`.
pub dt_proj_w: Array2<f32>,
/// `(d_model)`.
pub dt_proj_b: Array1<f32>,
/// `(d_model, d_conv)`.
pub conv1d_w: Array2<f32>,
/// `(d_model)`.
pub conv1d_b: Array1<f32>,
/// `(d_model, d_model)`.
pub out_proj_w: Array2<f32>,
}
impl MambaLayerParams {
    /// Standard deviation of the Gaussian used by [`Self::random`].
    const INIT_STD: f32 = 0.02;
    /// Stabiliser added to the activation std and to `||theta||` in the FF goodness.
    const FF_EPS: f32 = 1e-5;

    /// Builds a layer whose tensors are sampled i.i.d. from `N(0, 0.02²)`.
    ///
    /// Tensors are drawn in field-declaration order, so a given `rng` state always
    /// reproduces the same layer.
    pub fn random(spec: LayerSpec, rng: &mut StdRng) -> Self {
        Self {
            a_log: random_matrix(spec.d_model, spec.d_state, rng, Self::INIT_STD),
            d_skip: random_vector(spec.d_model, rng, Self::INIT_STD),
            x_proj_w: random_matrix(spec.d_model, spec.d_state * 2 + 1, rng, Self::INIT_STD),
            dt_proj_w: random_matrix(1, spec.d_model, rng, Self::INIT_STD),
            dt_proj_b: random_vector(spec.d_model, rng, Self::INIT_STD),
            conv1d_w: random_matrix(spec.d_model, spec.d_conv, rng, Self::INIT_STD),
            conv1d_b: random_vector(spec.d_model, rng, Self::INIT_STD),
            out_proj_w: random_matrix(spec.d_model, spec.d_model, rng, Self::INIT_STD),
        }
    }

    /// Builds a layer with the same shapes as [`Self::random`] but every tensor zero.
    ///
    /// Used for layers added by expansion (presumably so a new layer starts out
    /// contributing nothing — confirm against the forward pass).
    pub fn zero_residual(spec: LayerSpec) -> Self {
        Self {
            a_log: Array2::zeros((spec.d_model, spec.d_state)),
            d_skip: Array1::zeros(spec.d_model),
            x_proj_w: Array2::zeros((spec.d_model, spec.d_state * 2 + 1)),
            dt_proj_w: Array2::zeros((1, spec.d_model)),
            dt_proj_b: Array1::zeros(spec.d_model),
            conv1d_w: Array2::zeros((spec.d_model, spec.d_conv)),
            conv1d_b: Array1::zeros(spec.d_model),
            out_proj_w: Array2::zeros((spec.d_model, spec.d_model)),
        }
    }

    /// Euclidean norm of all of the layer's parameters viewed as one flat vector.
    pub fn l2_norm(&self) -> f32 {
        let per_tensor = [
            Self::sum_of_squares(&self.a_log),
            Self::sum_of_squares(&self.d_skip),
            Self::sum_of_squares(&self.x_proj_w),
            Self::sum_of_squares(&self.dt_proj_w),
            Self::sum_of_squares(&self.dt_proj_b),
            Self::sum_of_squares(&self.conv1d_w),
            Self::sum_of_squares(&self.conv1d_b),
            Self::sum_of_squares(&self.out_proj_w),
        ];
        per_tensor.iter().sum::<f32>().sqrt()
    }

    /// Forward-forward goodness of activation `h` along direction `theta`:
    ///
    /// `g = Σ_i (h_i / (std(h) + ε)) · (θ_i / (‖θ‖ + ε))`, with `ε = 1e-5`.
    ///
    /// `h` is scaled by its (population) standard deviation but not mean-centred.
    /// If the lengths differ, the sum runs over the shorter of the two.
    pub fn ff_goodness_activation(h: &Array1<f32>, theta: &Array1<f32>) -> f32 {
        let h_scale = Self::activation_scale(h);
        let theta_scale = Self::sum_of_squares(theta).sqrt() + Self::FF_EPS;
        h.iter()
            .zip(theta.iter())
            .map(|(hi, ti)| (hi / h_scale) * (ti / theta_scale))
            .sum::<f32>()
    }

    /// Hinge loss `max(0, 1 - (g_pos - g_neg))` on the goodness of a positive and a
    /// negative activation (margin fixed at `1.0`).
    pub fn ff_loss_pair(h_pos: &Array1<f32>, h_neg: &Array1<f32>, theta: &Array1<f32>) -> f32 {
        let g_pos = Self::ff_goodness_activation(h_pos, theta);
        let g_neg = Self::ff_goodness_activation(h_neg, theta);
        (1.0_f32 - (g_pos - g_neg)).max(0.0)
    }

    /// Gradient of [`Self::ff_loss_pair`] with respect to `theta`.
    ///
    /// With `d = h_pos / s_pos - h_neg / s_neg` (`s = std + ε`) and `N = ‖θ‖ + ε`, the
    /// loss is `1 - d·θ / N` while positive, so
    ///
    /// `∇θ L = -(d - (d·θ / N) · θ / ‖θ‖) / N`.
    ///
    /// The second (radial) term accounts for `θ` being normalised; dropping it would keep
    /// growing `‖θ‖` along `d` without changing the normalised goodness. At `θ = 0` the
    /// radial term is taken as `0`. Returns zeros when the loss is already `<= 0`.
    ///
    /// Near-constant activations make `s` tiny and the gradient correspondingly large.
    ///
    /// # Panics
    ///
    /// Panics unless `h_pos`, `h_neg` and `theta` all have the same length.
    pub fn ff_grad_theta(
        h_pos: &Array1<f32>,
        h_neg: &Array1<f32>,
        theta: &Array1<f32>,
    ) -> Array1<f32> {
        assert!(
            h_pos.len() == theta.len() && h_neg.len() == theta.len(),
            "ff_grad_theta: length mismatch (h_pos {}, h_neg {}, theta {})",
            h_pos.len(),
            h_neg.len(),
            theta.len()
        );
        if Self::ff_loss_pair(h_pos, h_neg, theta) <= 0.0 {
            return Array1::zeros(theta.len());
        }
        let pos_scale = Self::activation_scale(h_pos);
        let neg_scale = Self::activation_scale(h_neg);
        let diff: Vec<f32> = h_pos
            .iter()
            .zip(h_neg.iter())
            .map(|(p, n)| p / pos_scale - n / neg_scale)
            .collect();
        let theta_len = Self::sum_of_squares(theta).sqrt();
        let theta_scale = theta_len + Self::FF_EPS;
        // margin = g_pos - g_neg, the quantity the hinge is applied to.
        let margin = diff
            .iter()
            .zip(theta.iter())
            .map(|(d, t)| d * t)
            .sum::<f32>()
            / theta_scale;
        // Derivative of ‖θ‖ is θ / ‖θ‖; undefined at 0, where the term vanishes anyway.
        let radial = if theta_len > 0.0 { margin / theta_len } else { 0.0 };
        Array1::from_iter(
            diff.iter()
                .zip(theta.iter())
                .map(|(d, t)| -(d - radial * t) / theta_scale),
        )
    }

    /// Sum of squared elements of any `&f32` iterable (e.g. `&Array1`, `&Array2`).
    fn sum_of_squares<'a>(values: impl IntoIterator<Item = &'a f32>) -> f32 {
        values.into_iter().map(|v| v * v).sum::<f32>()
    }

    /// `std(h) + FF_EPS`, using the population (divide-by-`n`) variance.
    fn activation_scale(h: &Array1<f32>) -> f32 {
        let n = h.len() as f32;
        let mean = h.iter().sum::<f32>() / n;
        let var = h.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        var.sqrt() + Self::FF_EPS
    }
}
/// All trainable parameters: the token embedding plus the layer stack.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainerParams {
/// `(vocab_size, d_model)` when built by `TrainerParams::random`.
pub embedding: Array2<f32>,
/// Layers in forward order.
pub layers: Vec<MambaLayerParams>,
}
impl TrainerParams {
pub fn random(vocab_size: usize, spec: LayerSpec, num_layers: usize, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
let embedding = random_matrix(vocab_size, spec.d_model, &mut rng, 0.02);
let mut layers = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
layers.push(MambaLayerParams::random(spec, &mut rng));
}
Self { embedding, layers }
}
}
/// Trainer that grows a base model, freezes selected parts and updates the rest.
///
/// The whole struct (including the private freeze snapshots) is what
/// `save_checkpoint` / `load_checkpoint` serialise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentalTrainer {
/// Configuration this trainer was built with.
pub cfg: ExperimentalTrainerConfig,
/// Live (trainable) parameters.
pub params: TrainerParams,
/// Embedding snapshot restored by `enforce_freeze`; `None` when the embedding is trainable.
frozen_embedding: Option<Array2<f32>>,
/// Sorted, de-duplicated indices of frozen layers (binary-searched by the training loops).
frozen_layer_indices: Vec<usize>,
/// Snapshot of each frozen layer, parallel to `frozen_layer_indices`.
frozen_layers: Vec<MambaLayerParams>,
/// Number of completed training calls.
pub step: usize,
}
/// Summary returned by `ExperimentalTrainer::train_step_cadenced`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CadencedStepStats {
/// Trainer step count *after* this call.
pub step: usize,
/// FF hinge loss per layer (`0.0` for frozen layers).
pub ff_losses: Vec<f32>,
/// Whether this call was a back-propagation cadence step.
pub bp_applied: bool,
/// Number of non-frozen layers whose FF loss was positive.
pub ff_updates_applied: usize,
/// Number of non-frozen layers that received the back-propagated update.
pub bp_updates_applied: usize,
}
impl ExperimentalTrainer {
    /// Builds a trainer from `base` parameters and an experiment config.
    ///
    /// The layer stack is first grown to `cfg.expansion.target_num_layers` with
    /// zero-initialised layers (see [`expand_layers_in_place`]). The freeze selection is
    /// then resolved against the *expanded* stack. Snapshots of the frozen layers, and of
    /// the embedding when `cfg.freeze_embedding` is set, are stored so that
    /// [`Self::enforce_freeze`] can restore them after every update.
    ///
    /// # Panics
    ///
    /// Panics if the expansion config is invalid (see [`expand_layers_in_place`]).
    pub fn from_base(base: TrainerParams, cfg: ExperimentalTrainerConfig) -> Self {
        let mut params = base;
        expand_layers_in_place(
            &mut params.layers,
            cfg.layer_spec,
            cfg.expansion.target_num_layers,
            &cfg.expansion.placement,
        );
        let frozen_layer_indices =
            resolve_freeze_indices(&cfg.freeze_selection, params.layers.len());
        let frozen_layers = frozen_layer_indices
            .iter()
            .map(|&idx| params.layers[idx].clone())
            .collect();
        let frozen_embedding = cfg.freeze_embedding.then(|| params.embedding.clone());
        Self {
            cfg,
            params,
            frozen_embedding,
            frozen_layer_indices,
            frozen_layers,
            step: 0,
        }
    }

    /// Serialises the whole trainer to `path` with bincode's standard configuration.
    ///
    /// The bytes are written to `<path>.tmp` first and then renamed over `path`. A crash
    /// while writing therefore cannot leave a truncated file at `path`, assuming `rename`
    /// is atomic (as on POSIX filesystems).
    ///
    /// # Errors
    ///
    /// Returns a message if serialisation, the write or the rename fails.
    pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<(), String> {
        let path = path.as_ref();
        let bytes = bincode::serde::encode_to_vec(self, bincode::config::standard())
            .map_err(|err| format!("serialize failed: {err}"))?;
        let mut tmp_path = path.as_os_str().to_owned();
        tmp_path.push(".tmp");
        fs::write(&tmp_path, bytes)
            .and_then(|()| fs::rename(&tmp_path, path))
            .map_err(|err| {
                // Best effort: do not leave a stray partial file behind.
                let _ = fs::remove_file(&tmp_path);
                format!("checkpoint write failed: {err}")
            })
    }

    /// Loads a trainer written by [`Self::save_checkpoint`].
    ///
    /// # Errors
    ///
    /// Returns a message if the file cannot be read or decoded, if bytes remain after the
    /// encoded trainer, or if the decoded freeze bookkeeping is inconsistent.
    pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Self, String> {
        let bytes = fs::read(path).map_err(|err| format!("checkpoint read failed: {err}"))?;
        let (decoded, bytes_read) =
            bincode::serde::decode_from_slice::<Self, _>(&bytes, bincode::config::standard())
                .map_err(|err| format!("deserialize failed: {err}"))?;
        if bytes_read != bytes.len() {
            return Err(format!(
                "deserialize failed: {} unexpected trailing bytes",
                bytes.len() - bytes_read
            ));
        }
        decoded.check_invariants()?;
        Ok(decoded)
    }

    /// Verifies the freeze bookkeeping that `enforce_freeze` and the training loops rely
    /// on, so a corrupt checkpoint fails at load time rather than panicking mid-training.
    fn check_invariants(&self) -> Result<(), String> {
        let layer_count = self.params.layers.len();
        if self.frozen_layers.len() != self.frozen_layer_indices.len() {
            return Err(format!(
                "invalid checkpoint: {} frozen layer snapshots for {} frozen indices",
                self.frozen_layers.len(),
                self.frozen_layer_indices.len()
            ));
        }
        // `is_frozen` binary-searches the indices, so they must be strictly increasing.
        if self
            .frozen_layer_indices
            .windows(2)
            .any(|pair| pair[0] >= pair[1])
        {
            return Err(
                "invalid checkpoint: frozen layer indices are not strictly increasing".to_string(),
            );
        }
        if let Some(&bad) = self
            .frozen_layer_indices
            .iter()
            .find(|&&idx| idx >= layer_count)
        {
            return Err(format!(
                "invalid checkpoint: frozen layer index {bad} out of range for {layer_count} layers"
            ));
        }
        if let Some(snapshot) = self
            .frozen_embedding
            .as_ref()
            .filter(|snapshot| snapshot.dim() != self.params.embedding.dim())
        {
            return Err(format!(
                "invalid checkpoint: frozen embedding shape {:?} differs from embedding shape {:?}",
                snapshot.dim(),
                self.params.embedding.dim()
            ));
        }
        Ok(())
    }

    /// Number of layers after expansion.
    pub fn expanded_layer_count(&self) -> usize {
        self.params.layers.len()
    }

    /// Sorted, de-duplicated indices (into the expanded stack) of the frozen layers.
    pub fn frozen_layer_indices(&self) -> &[usize] {
        &self.frozen_layer_indices
    }

    /// Overwrites the embedding (when frozen) and every frozen layer with the snapshots
    /// taken at construction, undoing any update applied to them.
    pub fn enforce_freeze(&mut self) {
        if let Some(snapshot) = &self.frozen_embedding {
            self.params.embedding.assign(snapshot);
        }
        for (&layer_idx, snapshot) in self.frozen_layer_indices.iter().zip(&self.frozen_layers) {
            self.params.layers[layer_idx] = snapshot.clone();
        }
    }

    /// Runs one forward-forward update on every non-frozen layer and advances `step`.
    ///
    /// For layer `i`, [`MambaLayerParams::ff_loss_pair`] is evaluated on
    /// `(h_pos_per_layer[i], h_neg_per_layer[i])` with the layer's `d_skip` as `theta`.
    /// When the loss is positive, `d_skip` takes a plain gradient step of size
    /// `cfg.ff_lr`. Frozen layers report a loss of `0.0`.
    ///
    /// Returns the per-layer FF losses.
    ///
    /// # Panics
    ///
    /// Panics if either slice does not hold one vector per layer, or if an activation
    /// vector's length differs from that layer's `d_skip` length.
    pub fn train_ff_cycle(
        &mut self,
        h_pos_per_layer: &[Array1<f32>],
        h_neg_per_layer: &[Array1<f32>],
    ) -> Vec<f32> {
        self.check_activation_counts(h_pos_per_layer, h_neg_per_layer);
        let ff_lr = self.cfg.ff_lr;
        let mut losses = Vec::with_capacity(self.params.layers.len());
        for (idx, (h_pos, h_neg)) in h_pos_per_layer.iter().zip(h_neg_per_layer).enumerate() {
            if self.is_frozen(idx) {
                losses.push(0.0_f32);
                continue;
            }
            let (loss, grad) = self.ff_loss_and_grad(idx, h_pos, h_neg);
            losses.push(loss);
            if loss > 0.0 {
                Self::sgd_step(&mut self.params.layers[idx].d_skip, &grad, ff_lr);
            }
        }
        self.enforce_freeze();
        self.step += 1;
        losses
    }

    /// One FF update per non-frozen layer plus, on cadence steps, a back-propagated update.
    ///
    /// A call is a cadence step when `cadence_steps > 0` and `step + 1` is a multiple of
    /// `cadence_steps`, i.e. every `cadence_steps`-th call. On cadence steps the FF
    /// gradient is first passed through `pcgrad` against the layer's bp gradient. Then
    /// `d_skip -= ff_lr * ff_after_surgery` and `d_skip -= adamw.lr * bp_grad` are
    /// applied. NOTE(review): the bp update is plain SGD at `adamw.lr`; the AdamW
    /// moments and weight decay are not used here.
    ///
    /// # Panics
    ///
    /// Panics if the activation slices do not hold one vector per layer, or if an
    /// activation vector's length differs from the layer's `d_skip`. On cadence steps it
    /// also panics if `bp_grads_per_layer` is `None`, has the wrong number of layers, or
    /// has a gradient whose length differs from `d_skip`.
    pub fn train_step_cadenced(
        &mut self,
        h_pos_per_layer: &[Array1<f32>],
        h_neg_per_layer: &[Array1<f32>],
        bp_grads_per_layer: Option<&[Array1<f32>]>,
        cadence_steps: usize,
        pcgrad_epsilon: f32,
    ) -> CadencedStepStats {
        self.check_activation_counts(h_pos_per_layer, h_neg_per_layer);
        let bp_due = cadence_steps > 0 && (self.step + 1).is_multiple_of(cadence_steps);
        let bp_grads = if bp_due {
            let bp = bp_grads_per_layer.expect("bp gradients required on cadence step");
            assert_eq!(
                bp.len(),
                self.params.layers.len(),
                "bp grads layers mismatch"
            );
            Some(bp)
        } else {
            None
        };
        let (ff_lr, bp_lr) = (self.cfg.ff_lr, self.cfg.adamw.lr);
        let mut ff_losses = Vec::with_capacity(self.params.layers.len());
        let mut ff_updates_applied = 0usize;
        let mut bp_updates_applied = 0usize;
        for (idx, (h_pos, h_neg)) in h_pos_per_layer.iter().zip(h_neg_per_layer).enumerate() {
            if self.is_frozen(idx) {
                ff_losses.push(0.0_f32);
                continue;
            }
            let (ff_loss, ff_grad) = self.ff_loss_and_grad(idx, h_pos, h_neg);
            ff_losses.push(ff_loss);
            if ff_loss > 0.0 {
                ff_updates_applied += 1;
            }
            let d_skip = &mut self.params.layers[idx].d_skip;
            match bp_grads {
                Some(bp) => {
                    let bp_grad = &bp[idx];
                    assert_eq!(
                        bp_grad.len(),
                        d_skip.len(),
                        "bp grad length mismatch for layer {idx}"
                    );
                    let ff_after_surgery = pcgrad(&ff_grad, bp_grad, pcgrad_epsilon);
                    assert_eq!(
                        ff_after_surgery.len(),
                        d_skip.len(),
                        "pcgrad output length mismatch for layer {idx}"
                    );
                    for ((p, ff), bp_g) in d_skip
                        .iter_mut()
                        .zip(ff_after_surgery.iter())
                        .zip(bp_grad.iter())
                    {
                        *p -= ff_lr * ff;
                        *p -= bp_lr * bp_g;
                    }
                    bp_updates_applied += 1;
                }
                None => Self::sgd_step(d_skip, &ff_grad, ff_lr),
            }
        }
        self.enforce_freeze();
        self.step += 1;
        CadencedStepStats {
            step: self.step,
            ff_losses,
            bp_applied: bp_due,
            ff_updates_applied,
            bp_updates_applied,
        }
    }

    /// L2 norm of every layer, in stack order.
    pub fn layer_norms(&self) -> Vec<f32> {
        self.params
            .layers
            .iter()
            .map(MambaLayerParams::l2_norm)
            .collect()
    }

    /// Whether layer `idx` is frozen (the index list is kept sorted).
    fn is_frozen(&self, idx: usize) -> bool {
        self.frozen_layer_indices.binary_search(&idx).is_ok()
    }

    /// Asserts there is exactly one positive and one negative activation per layer.
    #[track_caller]
    fn check_activation_counts(&self, h_pos: &[Array1<f32>], h_neg: &[Array1<f32>]) {
        let layer_count = self.params.layers.len();
        assert_eq!(h_pos.len(), layer_count, "h_pos layers mismatch");
        assert_eq!(h_neg.len(), layer_count, "h_neg layers mismatch");
    }

    /// FF loss of layer `idx` and its gradient w.r.t. `d_skip` (zeros when the loss <= 0).
    fn ff_loss_and_grad(
        &self,
        idx: usize,
        h_pos: &Array1<f32>,
        h_neg: &Array1<f32>,
    ) -> (f32, Array1<f32>) {
        let theta = &self.params.layers[idx].d_skip;
        assert!(
            h_pos.len() == theta.len() && h_neg.len() == theta.len(),
            "layer {idx}: activation lengths ({}, {}) must equal d_skip length {}",
            h_pos.len(),
            h_neg.len(),
            theta.len()
        );
        let loss = MambaLayerParams::ff_loss_pair(h_pos, h_neg, theta);
        let grad = if loss > 0.0 {
            MambaLayerParams::ff_grad_theta(h_pos, h_neg, theta)
        } else {
            Array1::zeros(theta.len())
        };
        (loss, grad)
    }

    /// In-place `params -= lr * grad`, element by element.
    fn sgd_step(params: &mut Array1<f32>, grad: &Array1<f32>, lr: f32) {
        for (p, g) in params.iter_mut().zip(grad.iter()) {
            *p -= lr * g;
        }
    }
}
/// Turns a [`FreezeSelection`] into sorted, unique layer indices below `total_layers`.
///
/// `FirstN(n)` yields `0..min(n, total_layers)`. `Indices` keeps only the in-range
/// entries; duplicates are removed.
pub fn resolve_freeze_indices(selection: &FreezeSelection, total_layers: usize) -> Vec<usize> {
    let mut resolved: Vec<usize> = match selection {
        FreezeSelection::FirstN(n) => {
            let upper = if *n < total_layers { *n } else { total_layers };
            (0..upper).collect()
        }
        FreezeSelection::Indices(requested) => {
            let mut kept = Vec::with_capacity(requested.len());
            for &candidate in requested {
                if candidate < total_layers {
                    kept.push(candidate);
                }
            }
            kept
        }
    };
    resolved.sort_unstable();
    resolved.dedup();
    resolved
}
/// Grows `base_layers` to `target_num_layers` by inserting all-zero layers
/// ([`MambaLayerParams::zero_residual`]) according to `placement`.
///
/// Existing layers keep their relative order and are moved, not cloned. Does nothing
/// when the stack already has `target_num_layers` layers. In that case
/// `SpecificPositions` is not validated.
///
/// # Panics
///
/// Panics if `target_num_layers < base_layers.len()`. For `SpecificPositions`, it also
/// panics unless the positions are exactly one unique index `< target_num_layers` per
/// new layer.
pub fn expand_layers_in_place(
    base_layers: &mut Vec<MambaLayerParams>,
    spec: LayerSpec,
    target_num_layers: usize,
    placement: &ExpansionPlacement,
) {
    assert!(
        target_num_layers >= base_layers.len(),
        "target layers must be >= base layers"
    );
    let new_count = target_num_layers - base_layers.len();
    if new_count == 0 {
        return;
    }
    // Append / Prepend / InsertAt all insert one contiguous block at some index.
    let insert_at = match placement {
        ExpansionPlacement::Append => base_layers.len(),
        ExpansionPlacement::Prepend => 0,
        ExpansionPlacement::InsertAt { index } => (*index).min(base_layers.len()),
        ExpansionPlacement::SpecificPositions(positions) => {
            interleave_new_layers(base_layers, spec, target_num_layers, new_count, positions);
            return;
        }
    };
    let tail = base_layers.split_off(insert_at);
    base_layers.extend((0..new_count).map(|_| MambaLayerParams::zero_residual(spec)));
    base_layers.extend(tail);
}

/// Rebuilds `base_layers` so that new zero layers sit exactly at the final `positions`
/// and the existing layers fill the remaining slots in their original order.
fn interleave_new_layers(
    base_layers: &mut Vec<MambaLayerParams>,
    spec: LayerSpec,
    target_num_layers: usize,
    new_count: usize,
    positions: &[usize],
) {
    assert!(
        positions.len() == new_count,
        "SpecificPositions length must equal number of new layers"
    );
    let mut final_positions = positions.to_vec();
    final_positions.sort_unstable();
    final_positions.dedup();
    assert!(
        final_positions.len() == new_count,
        "SpecificPositions must contain unique indices"
    );
    assert!(
        final_positions.iter().all(|idx| *idx < target_num_layers),
        "SpecificPositions indices must be < target_num_layers"
    );
    let mut old_layers = std::mem::take(base_layers).into_iter();
    *base_layers = (0..target_num_layers)
        .map(|final_idx| {
            if final_positions.binary_search(&final_idx).is_ok() {
                MambaLayerParams::zero_residual(spec)
            } else {
                old_layers
                    .next()
                    .expect("non-new slots equal target - new_count == base layer count")
            }
        })
        .collect();
}
/// `(rows, cols)` matrix of i.i.d. `N(0, std²)` samples, drawn in row-major order.
fn random_matrix(rows: usize, cols: usize, rng: &mut StdRng, std: f32) -> Array2<f32> {
    let shape = (rows, cols);
    Array2::from_shape_fn(shape, |_| {
        let z: f32 = rng.sample(StandardNormal);
        z * std
    })
}
/// Length-`len` vector of i.i.d. `N(0, std²)` samples.
fn random_vector(len: usize, rng: &mut StdRng, std: f32) -> Array1<f32> {
    let mut draw = || -> f32 { rng.sample(StandardNormal) };
    Array1::from_shape_fn(len, |_| draw() * std)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss::pcgrad;
    use ndarray::array;

    /// Small layer dimensions shared by every test.
    fn spec() -> LayerSpec {
        LayerSpec {
            d_model: 16,
            d_state: 8,
            d_conv: 4,
        }
    }

    /// Config expanding to 6 layers with the given placement/freeze and a frozen embedding.
    fn cfg_with(
        placement: ExpansionPlacement,
        freeze: FreezeSelection,
    ) -> ExperimentalTrainerConfig {
        ExperimentalTrainerConfig {
            vocab_size: 64,
            layer_spec: spec(),
            expansion: ExpansionConfig {
                target_num_layers: 6,
                placement,
            },
            freeze_selection: freeze,
            freeze_embedding: true,
            ff_lr: 1e-2,
            ff_threshold: 1e-3,
            adamw: AdamWConfig::default(),
        }
    }

    /// Layer norms of `params`, used as fingerprints to track where base layers land.
    fn base_layer_norms(params: &TrainerParams) -> Vec<f32> {
        params.layers.iter().map(MambaLayerParams::l2_norm).collect()
    }

    /// One constant activation vector of width `d` for each of `n` layers.
    fn constant_activations(n: usize, d: usize, value: f32) -> Vec<Array1<f32>> {
        (0..n).map(|_| Array1::from_elem(d, value)).collect()
    }

    /// Removes the wrapped file on drop, so a failing assertion does not leak it.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn append_expansion_adds_new_layers_at_end() {
        let base = TrainerParams::random(64, spec(), 2, 7);
        let base_norms = base_layer_norms(&base);
        let trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(ExpansionPlacement::Append, FreezeSelection::FirstN(2)),
        );
        assert_eq!(trainer.expanded_layer_count(), 6);
        assert_eq!(&trainer.layer_norms()[..2], base_norms.as_slice());
        for layer in &trainer.params.layers[2..] {
            assert!(layer.l2_norm() <= 1e-9);
        }
    }

    #[test]
    fn prepend_expansion_keeps_base_layers_after_new_ones() {
        let base = TrainerParams::random(64, spec(), 2, 17);
        let base_norms = base_layer_norms(&base);
        let trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(ExpansionPlacement::Prepend, FreezeSelection::FirstN(0)),
        );
        let norms = trainer.layer_norms();
        assert_eq!(norms.len(), 6);
        assert!(norms[..4].iter().all(|&n| n <= 1e-9));
        assert_eq!(&norms[4..], base_norms.as_slice());
    }

    #[test]
    fn insert_at_places_block_and_clamps_index() {
        let base = TrainerParams::random(64, spec(), 2, 23);
        let base_norms = base_layer_norms(&base);
        let trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(
                ExpansionPlacement::InsertAt { index: 1 },
                FreezeSelection::FirstN(0),
            ),
        );
        let norms = trainer.layer_norms();
        assert_eq!(norms[0], base_norms[0]);
        assert!(norms[1..5].iter().all(|&n| n <= 1e-9));
        assert_eq!(norms[5], base_norms[1]);

        // An index past the end behaves like Append.
        let base = TrainerParams::random(64, spec(), 2, 23);
        let trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(
                ExpansionPlacement::InsertAt { index: 99 },
                FreezeSelection::FirstN(0),
            ),
        );
        let norms = trainer.layer_norms();
        assert_eq!(&norms[..2], base_norms.as_slice());
        assert!(norms[2..].iter().all(|&n| n <= 1e-9));
    }

    #[test]
    fn specific_positions_support_interleave_style_insertions() {
        let base = TrainerParams::random(64, spec(), 2, 9);
        let base_norms = base_layer_norms(&base);
        let trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(
                ExpansionPlacement::SpecificPositions(vec![1, 3, 4, 5]),
                FreezeSelection::Indices(vec![0, 2]),
            ),
        );
        assert_eq!(trainer.expanded_layer_count(), 6);
        assert_eq!(trainer.frozen_layer_indices(), &[0, 2]);
        assert_eq!(trainer.params.layers[0].l2_norm(), base_norms[0]);
        assert_eq!(trainer.params.layers[2].l2_norm(), base_norms[1]);
        assert!(trainer.params.layers[1].l2_norm() <= 1e-9);
        assert!(trainer.params.layers[3].l2_norm() <= 1e-9);
        assert!(trainer.params.layers[4].l2_norm() <= 1e-9);
        assert!(trainer.params.layers[5].l2_norm() <= 1e-9);
    }

    #[test]
    fn freeze_by_indices_is_enforced_after_training_cycle() {
        let base = TrainerParams::random(64, spec(), 2, 11);
        let mut trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(
                ExpansionPlacement::Append,
                FreezeSelection::Indices(vec![0, 1]),
            ),
        );
        let d = spec().d_model;
        let n_layers = trainer.expanded_layer_count();
        let h_pos = constant_activations(n_layers, d, 0.5);
        let h_neg = constant_activations(n_layers, d, -0.5);
        let before = trainer.layer_norms();
        let _ = trainer.train_ff_cycle(&h_pos, &h_neg);
        let after = trainer.layer_norms();
        assert!((before[0] - after[0]).abs() <= 1e-6);
        assert!((before[1] - after[1]).abs() <= 1e-6);
    }

    #[test]
    fn enforce_freeze_restores_embedding_snapshot() {
        let base = TrainerParams::random(64, spec(), 2, 29);
        let mut trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(ExpansionPlacement::Append, FreezeSelection::FirstN(0)),
        );
        let snapshot = trainer.params.embedding.clone();
        trainer.params.embedding.fill(1.0);
        trainer.enforce_freeze();
        assert_eq!(trainer.params.embedding, snapshot);
    }

    #[test]
    fn checkpoint_roundtrip_preserves_state() {
        let base = TrainerParams::random(64, spec(), 2, 5);
        let mut trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(ExpansionPlacement::Append, FreezeSelection::FirstN(2)),
        );
        let d = spec().d_model;
        let n = trainer.expanded_layer_count();
        let h_pos = constant_activations(n, d, 0.3);
        let h_neg = constant_activations(n, d, -0.3);
        let _ = trainer.train_ff_cycle(&h_pos, &h_neg);
        // Per-process name so concurrent runs of this suite cannot clobber each other.
        let ckpt = TempFile(std::env::temp_dir().join(format!(
            "trainer_lab_roundtrip_{}.bincode",
            std::process::id()
        )));
        trainer.save_checkpoint(&ckpt.0).unwrap();
        let loaded = ExperimentalTrainer::load_checkpoint(&ckpt.0).unwrap();
        assert_eq!(trainer.step, loaded.step);
        assert_eq!(
            trainer.expanded_layer_count(),
            loaded.expanded_layer_count()
        );
        assert_eq!(
            trainer.frozen_layer_indices(),
            loaded.frozen_layer_indices()
        );
        assert_eq!(trainer.layer_norms(), loaded.layer_norms());
        assert_eq!(trainer.params.embedding, loaded.params.embedding);
    }

    #[test]
    fn load_checkpoint_reports_missing_file() {
        let missing = std::env::temp_dir().join(format!(
            "trainer_lab_missing_{}.bincode",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&missing);
        let err = ExperimentalTrainer::load_checkpoint(&missing).unwrap_err();
        assert!(err.starts_with("checkpoint read failed"));
    }

    #[test]
    fn pcgrad_projection_removes_conflict_component() {
        let ff = array![-1.0_f32, 0.0];
        let bp = array![1.0_f32, 0.0];
        let out = pcgrad(&ff, &bp, 1e-8);
        assert!(out.dot(&bp) >= -1e-6);
    }

    #[test]
    fn cadenced_step_applies_bp_only_when_due() {
        let base = TrainerParams::random(64, spec(), 2, 13);
        let mut trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(ExpansionPlacement::Append, FreezeSelection::Indices(vec![])),
        );
        let d = spec().d_model;
        let n = trainer.expanded_layer_count();
        let h_pos = constant_activations(n, d, 0.5);
        let h_neg = constant_activations(n, d, -0.5);
        let bp = constant_activations(n, d, 0.1);
        let s1 = trainer.train_step_cadenced(&h_pos, &h_neg, Some(&bp), 2, 1e-8);
        assert!(!s1.bp_applied);
        assert_eq!(s1.bp_updates_applied, 0);
        let s2 = trainer.train_step_cadenced(&h_pos, &h_neg, Some(&bp), 2, 1e-8);
        assert!(s2.bp_applied);
        assert_eq!(s2.bp_updates_applied, n);
        assert_eq!(s2.step, 2);
    }

    #[test]
    fn cadenced_bp_step_leaves_frozen_layers_untouched() {
        let base = TrainerParams::random(64, spec(), 2, 19);
        let mut trainer = ExperimentalTrainer::from_base(
            base,
            cfg_with(ExpansionPlacement::Append, FreezeSelection::FirstN(2)),
        );
        let d = spec().d_model;
        let n = trainer.expanded_layer_count();
        let h_pos = constant_activations(n, d, 0.5);
        let h_neg = constant_activations(n, d, -0.5);
        let bp = constant_activations(n, d, 0.1);
        let before = trainer.layer_norms();
        let stats = trainer.train_step_cadenced(&h_pos, &h_neg, Some(&bp), 1, 1e-8);
        assert!(stats.bp_applied);
        assert_eq!(stats.bp_updates_applied, n - 2);
        assert!(stats.ff_losses[..2].iter().all(|&l| l == 0.0));
        let after = trainer.layer_norms();
        assert_eq!(&before[..2], &after[..2]);
    }
}