use crate::metrics::grad_norm;
use crate::model::Layer;
use crate::moe_model::MoEModel;
use crate::moe_model::topology::{IN_DIM, N_EXPERTS, OUT_DIM, TOP_K};
/// Parameters of a dense (fully connected) layer held as plain `f32` buffers.
#[derive(Clone, Debug)]
pub struct Fp32Linear {
    /// Flattened weight matrix. `linear_fwd` reads it row-major as
    /// `[out_features, in_features]`, where `out_features == bias.len()`.
    pub weight: Vec<f32>,
    /// One bias value per output feature.
    pub bias: Vec<f32>,
}
/// Parameters of a layer-normalisation layer held as plain `f32` buffers.
#[derive(Clone, Debug)]
pub struct Fp32LayerNorm {
    /// Per-feature scale; its length defines the normalised width.
    pub gamma: Vec<f32>,
    /// Per-feature shift, added after scaling.
    pub beta: Vec<f32>,
    /// Constant added to the variance before taking the square root.
    pub eps: f32,
}
/// One expert network: a stack of linear layers with a layer norm after
/// every linear except the last.
#[derive(Clone, Debug)]
pub struct Fp32Expert {
    /// Linear layers in application order; the final entry is the output
    /// projection.
    pub linears: Vec<Fp32Linear>,
    /// Layer norms in application order; entry `i` follows `linears[i]`.
    pub layernorms: Vec<Fp32LayerNorm>,
}
/// Gating network that scores the experts for every input row.
#[derive(Clone, Debug)]
pub struct Fp32Router {
    /// Linear layer producing one logit per expert.
    pub linear: Fp32Linear,
    /// Number of highest-probability experts kept per row.
    pub top_k: usize,
}
/// Complete mixture-of-experts reference model: one router plus its experts.
#[derive(Clone, Debug)]
pub struct Fp32MoE {
    /// Gating network that weights the experts.
    pub router: Fp32Router,
    /// Expert networks, indexed by expert id.
    pub experts: Vec<Fp32Expert>,
}
/// Copies the parameters of `model` into a standalone [`Fp32MoE`].
///
/// The router's first two parameters become its weight and bias. For each
/// expert, layers named `"Linear"` and `"LayerNorm"` contribute their first
/// two parameters (weight/bias and gamma/beta respectively); layers with any
/// other name are skipped.
///
/// # Panics
///
/// Panics if a router, `"Linear"` or `"LayerNorm"` layer exposes fewer than
/// two parameters.
pub fn build_from_model(model: &MoEModel) -> Fp32MoE {
    let router_params = model.router.linear.parameters();
    let router = Fp32Router {
        linear: Fp32Linear {
            weight: router_params[0].data.data.clone(),
            bias: router_params[1].data.data.clone(),
        },
        top_k: model.router.top_k,
    };

    let mut experts: Vec<Fp32Expert> = Vec::with_capacity(N_EXPERTS);
    for seq in &model.experts {
        let mut expert = Fp32Expert {
            linears: Vec::new(),
            layernorms: Vec::new(),
        };
        for layer in seq.layers.iter() {
            let layer_params = layer.parameters();
            let kind = layer.name();
            if kind == "Linear" {
                expert.linears.push(Fp32Linear {
                    weight: layer_params[0].data.data.clone(),
                    bias: layer_params[1].data.data.clone(),
                });
            } else if kind == "LayerNorm" {
                expert.layernorms.push(Fp32LayerNorm {
                    gamma: layer_params[0].data.data.clone(),
                    beta: layer_params[1].data.data.clone(),
                    // Epsilon is hard-coded here; it is not read from the
                    // source layer.
                    eps: 1e-5f32,
                });
            }
            // Every other layer kind carries nothing to copy.
        }
        experts.push(expert);
    }

    Fp32MoE { router, experts }
}
/// Applies the affine map `y = x · Wᵀ + b` to every row of `x`.
///
/// All shapes are derived from the parameters: the output width is
/// `w.bias.len()`, the input width is `w.weight.len() / w.bias.len()`, and the
/// row count is `x.len()` divided by the input width. Both divisors are
/// clamped to at least 1. Trailing elements of `x` that do not fill a whole
/// row are ignored.
///
/// Returns a row-major buffer of `rows * out_width` values.
fn linear_fwd(x: &[f32], w: &Fp32Linear) -> Vec<f32> {
    let out_dim = w.bias.len();
    let in_dim = w.weight.len() / out_dim.max(1);
    let rows = x.len() / in_dim.max(1);

    let mut y = Vec::with_capacity(rows * out_dim);
    for r in 0..rows {
        let x_row = &x[r * in_dim..(r + 1) * in_dim];
        for (o, &b) in w.bias.iter().enumerate() {
            let w_row = &w.weight[o * in_dim..(o + 1) * in_dim];
            // Start from the bias and add products left to right.
            let acc = x_row
                .iter()
                .zip(w_row)
                .fold(b, |s, (&xv, &wv)| s + xv * wv);
            y.push(acc);
        }
    }
    y
}
/// Layer-normalises each of the `n_rows` rows of `x`.
///
/// The row width is `ln.gamma.len()`. Each row is centred on its mean, scaled
/// by `1 / sqrt(variance + eps)` using the biased (divide-by-width) variance,
/// then multiplied by `gamma` and shifted by `beta`.
///
/// Returns `(output, mean, rstd)`, where `mean` and `rstd` hold one value per
/// row and are the statistics the backward pass needs.
///
/// # Panics
///
/// Panics if `x` holds fewer than `n_rows * width` values or `ln.beta` is
/// shorter than `ln.gamma`.
fn layernorm_fwd(x: &[f32], ln: &Fp32LayerNorm, n_rows: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let width = ln.gamma.len();
    let inv_width = 1.0f32 / width as f32;

    let mut normed = vec![0.0f32; n_rows * width];
    let mut means = Vec::with_capacity(n_rows);
    let mut rstds = Vec::with_capacity(n_rows);

    for r in 0..n_rows {
        let span = r * width..(r + 1) * width;
        let src = &x[span.clone()];

        let mu = src.iter().copied().sum::<f32>() * inv_width;
        let var = src.iter().map(|v| (v - mu) * (v - mu)).sum::<f32>() * inv_width;
        let inv_sigma = 1.0f32 / (var + ln.eps).sqrt();
        means.push(mu);
        rstds.push(inv_sigma);

        for (c, (dst, &xv)) in normed[span].iter_mut().zip(src).enumerate() {
            *dst = (xv - mu) * inv_sigma * ln.gamma[c] + ln.beta[c];
        }
    }

    (normed, means, rstds)
}
/// Element-wise GELU using the tanh approximation
/// `0.5 · v · (1 + tanh(√(2/π) · (v + 0.044715 · v³)))`.
fn gelu_fwd(x: &[f32]) -> Vec<f32> {
    let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
    let mut y = Vec::with_capacity(x.len());
    for &v in x {
        let arg = sqrt_2_over_pi * (v + 0.044715f32 * v * v * v);
        y.push(0.5 * v * (1.0 + arg.tanh()));
    }
    y
}
/// Row-wise softmax over a row-major `n_rows × n_cols` buffer.
///
/// The row maximum is subtracted before exponentiating, and the row sum is
/// clamped to at least `1e-30` before dividing. The result has the same
/// length as `z`; elements beyond `n_rows * n_cols` stay zero.
///
/// # Panics
///
/// Panics if `z` holds fewer than `n_rows * n_cols` values.
fn softmax_fwd(z: &[f32], n_rows: usize, n_cols: usize) -> Vec<f32> {
    let mut probs = vec![0.0f32; z.len()];
    for r in 0..n_rows {
        let span = r * n_cols..(r + 1) * n_cols;
        let logits = &z[span.clone()];
        let dst = &mut probs[span];

        let peak = logits
            .iter()
            .fold(f32::NEG_INFINITY, |acc, v| acc.max(*v));

        // First pass: store the shifted exponentials and total them.
        let mut total = 0.0f32;
        for (slot, &logit) in dst.iter_mut().zip(logits) {
            let e = (logit - peak).exp();
            *slot = e;
            total += e;
        }

        // Second pass: normalise in place.
        let scale = 1.0f32 / total.max(1e-30f32);
        for slot in dst.iter_mut() {
            *slot *= scale;
        }
    }
    probs
}
/// Keeps the `k` largest entries of every row of `p` and zeroes the rest.
///
/// `p` is a row-major `n_rows × n_cols` buffer that is modified in place.
/// Ties keep the lower column index first because the sort is stable. NaN
/// entries are ordered after every number, so a NaN is only selected when
/// fewer than `k` numeric entries exist.
///
/// Returns the selected column indices, `min(k, n_cols)` per row, in
/// descending order of value, rows concatenated.
///
/// # Panics
///
/// Panics if `p` holds fewer than `n_rows * n_cols` values.
fn topk_mask(p: &mut [f32], n_rows: usize, n_cols: usize, k: usize) -> Vec<usize> {
    // A row can never yield more than `n_cols` indices, so clamp the hint.
    let mut idx = Vec::with_capacity(n_rows * k.min(n_cols));
    let mut order: Vec<usize> = Vec::with_capacity(n_cols);
    for r in 0..n_rows {
        let row = &mut p[r * n_cols..(r + 1) * n_cols];
        order.clear();
        order.extend(0..n_cols);
        // `sort_by` needs a total order; mapping NaN comparisons to `Equal`
        // does not provide one. Rank NaN explicitly as the smallest value.
        order.sort_by(|&a, &b| match (row[a].is_nan(), row[b].is_nan()) {
            (false, false) => row[b]
                .partial_cmp(&row[a])
                .unwrap_or(std::cmp::Ordering::Equal),
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        });
        for (rank, &col) in order.iter().enumerate() {
            if rank < k {
                idx.push(col);
            } else {
                row[col] = 0.0;
            }
        }
    }
    idx
}
/// Runs the router: linear logits, softmax over `N_EXPERTS` columns, then
/// top-k masking.
///
/// Returns the masked probabilities (zero for unselected experts, not
/// renormalised) together with the selected expert indices per row.
fn router_fwd(x: &[f32], router: &Fp32Router, n_rows: usize) -> (Vec<f32>, Vec<usize>) {
    let mut gate = softmax_fwd(&linear_fwd(x, &router.linear), n_rows, N_EXPERTS);
    let selected = topk_mask(&mut gate, n_rows, N_EXPERTS, router.top_k);
    (gate, selected)
}
/// Activations saved from one hidden block (linear, layer norm, GELU) of an
/// expert so the backward pass can be computed.
#[derive(Clone, Debug)]
struct ExpertBlockCache {
    /// Input to the layer norm, i.e. the output of the block's linear layer.
    ln_in: Vec<f32>,
    /// Per-row mean computed by the layer norm.
    ln_mean: Vec<f32>,
    /// Per-row reciprocal standard deviation computed by the layer norm.
    ln_rstd: Vec<f32>,
    /// Input to the GELU, i.e. the layer norm output.
    gelu_in: Vec<f32>,
    /// Output of the GELU.
    gelu_out: Vec<f32>,
}
/// Everything one expert's forward pass saves for its backward pass.
#[derive(Clone, Debug)]
struct ExpertFwdCache {
    /// One cache per hidden block, in forward order.
    blocks: Vec<ExpertBlockCache>,
    /// Input fed to the final (output) linear layer.
    pre_out_lin_in: Vec<f32>,
    /// Output of the final linear layer.
    out: Vec<f32>,
}
/// Runs one expert forward in full `f32` and records what backward needs.
///
/// Every linear layer except the last is followed by layer norm `i` and a
/// GELU; the last linear layer produces the expert output directly.
///
/// # Panics
///
/// Panics if the expert has no linear layers, or has fewer layer norms than
/// hidden linear layers.
fn expert_fwd_v2(x: &[f32], exp: &Fp32Expert, n_rows: usize) -> ExpertFwdCache {
    let (out_linear, hidden_linears) = exp
        .linears
        .split_last()
        .expect("expert must contain at least one linear layer");

    let mut current: Vec<f32> = x.to_vec();
    let mut blocks: Vec<ExpertBlockCache> = Vec::with_capacity(hidden_linears.len());
    for (li, linear) in hidden_linears.iter().enumerate() {
        let lin = linear_fwd(&current, linear);
        // The layer norm output is the GELU input, so it is cached as such.
        let (gelu_in, ln_mean, ln_rstd) = layernorm_fwd(&lin, &exp.layernorms[li], n_rows);
        let gelu_out = gelu_fwd(&gelu_in);
        blocks.push(ExpertBlockCache {
            ln_in: lin,
            ln_mean,
            ln_rstd,
            gelu_in,
            gelu_out: gelu_out.clone(),
        });
        current = gelu_out;
    }

    let out = linear_fwd(&current, out_linear);
    ExpertFwdCache {
        blocks,
        pre_out_lin_in: current,
        out,
    }
}
/// Backward pass of [`linear_fwd`].
///
/// Adds the weight gradient `dyᵀ · x` into `dw` and the bias gradient (the
/// column sums of `dy`) into `db`; both are accumulated, not overwritten.
/// Returns the input gradient `dy · W` as a fresh row-major buffer.
///
/// Shapes are derived as in the forward pass: the output width is
/// `w.bias.len()`, the input width is `w.weight.len() / w.bias.len()`, and
/// the row count is `dy.len()` divided by the output width.
///
/// # Panics
///
/// Panics if `x`, `dw` or `db` is too short for the derived shapes.
fn linear_bwd_inplace(
    dy: &[f32],
    x: &[f32],
    w: &Fp32Linear,
    dw: &mut [f32],
    db: &mut [f32],
) -> Vec<f32> {
    let out_dim = w.bias.len();
    let in_dim = w.weight.len() / out_dim.max(1);
    let rows = dy.len() / out_dim.max(1);

    let mut dx = vec![0.0f32; rows * in_dim];
    // One sweep updates all three gradients. Each destination element still
    // receives its contributions in the same (row, output) order.
    for r in 0..rows {
        for o in 0..out_dim {
            let g = dy[r * out_dim + o];
            for i in 0..in_dim {
                dw[o * in_dim + i] += g * x[r * in_dim + i];
                dx[r * in_dim + i] += g * w.weight[o * in_dim + i];
            }
            db[o] += g;
        }
    }
    dx
}
/// Backward pass of [`layernorm_fwd`].
///
/// `x` is the layer norm input and `mean` / `rstd` are the per-row statistics
/// returned by the forward pass. Parameter gradients are accumulated into
/// `dgamma` and `dbeta`; the input gradient is returned as a fresh buffer of
/// `n_rows * width` values, with `width = ln.gamma.len()`.
///
/// # Panics
///
/// Panics if any input buffer is too short for `n_rows` rows of `width`
/// columns.
fn layernorm_bwd(
    dy: &[f32],
    x: &[f32],
    ln: &Fp32LayerNorm,
    mean: &[f32],
    rstd: &[f32],
    n_rows: usize,
    dgamma: &mut [f32],
    dbeta: &mut [f32],
) -> Vec<f32> {
    let width = ln.gamma.len();
    let inv_width = 1.0f32 / width as f32;

    // Input gradient, one row at a time.
    let mut dx = vec![0.0f32; n_rows * width];
    for r in 0..n_rows {
        let mu = mean[r];
        let inv_sigma = rstd[r];
        let base = r * width;

        // Row-wise reductions of the gradient w.r.t. the normalised values.
        let mut g_sum = 0.0f32;
        let mut g_dot_xhat = 0.0f32;
        for c in 0..width {
            let xhat = (x[base + c] - mu) * inv_sigma;
            let g = dy[base + c] * ln.gamma[c];
            g_sum += g;
            g_dot_xhat += g * xhat;
        }

        let lead = inv_width * inv_sigma;
        for c in 0..width {
            let xhat = (x[base + c] - mu) * inv_sigma;
            let g = dy[base + c] * ln.gamma[c];
            dx[base + c] = lead * (width as f32 * g - g_sum - xhat * g_dot_xhat);
        }
    }

    // Parameter gradients: reduce each column over the rows first, then add
    // the column total to the caller's accumulator.
    for c in 0..width {
        let mut col_gamma = 0.0f32;
        let mut col_beta = 0.0f32;
        for r in 0..n_rows {
            let mu = mean[r];
            let inv_sigma = rstd[r];
            let xhat = (x[r * width + c] - mu) * inv_sigma;
            let upstream = dy[r * width + c];
            col_gamma += xhat * upstream;
            col_beta += upstream;
        }
        dgamma[c] += col_gamma;
        dbeta[c] += col_beta;
    }

    dx
}
/// Backward pass of the tanh-approximated GELU.
///
/// `go` is the upstream gradient and `x` the GELU input; the result is
/// `go[i] * gelu'(x[i])` for each pair. The output is as long as the shorter
/// of the two slices.
fn gelu_bwd(go: &[f32], x: &[f32]) -> Vec<f32> {
    let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
    let cubic = 0.044715f32;

    let mut grad = Vec::with_capacity(go.len().min(x.len()));
    for (&upstream, &v) in go.iter().zip(x.iter()) {
        let v_sq = v * v;
        let v_cu = v_sq * v;
        let th = (sqrt_2_over_pi * (v + cubic * v_cu)).tanh();
        let sech_sq = 1.0f32 - th * th;
        let slope = 0.5 * (1.0 + th)
            + 0.5 * v * sech_sq * sqrt_2_over_pi * (1.0 + 3.0 * cubic * v_sq);
        grad.push(upstream * slope);
    }
    grad
}
/// Backward pass of a row-wise softmax.
///
/// For every row computes `p[c] * (g[c] - Σ_j g[j] * p[j])`, where `g` is
/// the upstream gradient and `p` the probabilities. The result has the same
/// length as `g`; elements beyond `n_rows * n_cols` stay zero.
///
/// # Panics
///
/// Panics if `g` or `p` holds fewer than `n_rows * n_cols` values.
fn softmax_bwd(g: &[f32], p: &[f32], n_rows: usize, n_cols: usize) -> Vec<f32> {
    let mut dz = vec![0.0f32; g.len()];
    for r in 0..n_rows {
        let span = r * n_cols..(r + 1) * n_cols;
        let g_row = &g[span.clone()];
        let p_row = &p[span.clone()];

        let inner = g_row
            .iter()
            .zip(p_row.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>();

        for ((dst, &gv), &pv) in dz[span].iter_mut().zip(g_row).zip(p_row) {
            *dst = pv * (gv - inner);
        }
    }
    dz
}
/// Output of [`simulated_fp16_forward_backward`].
pub struct Fp16SimResult {
    /// Parameter gradients, one buffer per parameter tensor: router weight,
    /// router bias, then for each expert all linear (weight, bias) pairs
    /// followed by all layer norm (gamma, beta) pairs.
    pub param_grads: Vec<Vec<f32>>,
    /// Mixed model output, row-major.
    pub logits: Vec<f32>,
    /// Number of input rows processed.
    pub batch: usize,
    /// Expert indices selected by the router, rows concatenated.
    pub top_k_indices: Vec<usize>,
}
/// Runs the MoE forward and backward pass with intermediate values passed
/// through [`q_vec`] at selected points.
///
/// `q_vec` round-trips values through the crate's `f32_to_f16` /
/// `f16_to_f32` helpers (presumably half-precision rounding; confirm against
/// `crate::model::util`). It is applied to the input, the router logits and
/// masked probabilities, each expert output, the mixed output, the upstream
/// gradient, each expert's scaled upstream gradient, and the gradient
/// entering each GELU backward.
///
/// * `x`: row-major input of `batch * IN_DIM` values.
/// * `grad_output`: upstream gradient of `batch * OUT_DIM` values.
///
/// The returned `param_grads` are ordered: router weight, router bias, then
/// per expert all linear (weight, bias) pairs followed by all layer norm
/// (gamma, beta) pairs.
///
/// NOTE(review): `param_names` lists names in layer order (linears and layer
/// norms interleaved), which differs from the order used here whenever an
/// expert has a layer norm before its last linear. Confirm how callers pair
/// the two.
///
/// NOTE(review): the router gradient applies `softmax_bwd` to the masked
/// probabilities, so logits of unselected experts receive zero gradient.
/// Presumably this mirrors the model under test; confirm.
///
/// # Panics
///
/// Panics if `x.len()` is not a multiple of `IN_DIM`, if `grad_output.len()`
/// is not `batch * OUT_DIM`, or if `ref_model` has fewer than `N_EXPERTS`
/// experts.
pub fn simulated_fp16_forward_backward(
    ref_model: &Fp32MoE,
    x: &[f32],
    grad_output: &[f32],
) -> Fp16SimResult {
    let k = IN_DIM;
    let batch = x.len() / k;
    assert_eq!(x.len(), batch * k);
    assert_eq!(grad_output.len(), batch * OUT_DIM);

    // ---- Forward: router ----
    let x = q_vec(x);
    let router_w_logits = linear_fwd(&x, &ref_model.router.linear);
    let mut probs = softmax_fwd(&q_vec(&router_w_logits), batch, N_EXPERTS);
    let top_k_indices = topk_mask(&mut probs, batch, N_EXPERTS, ref_model.router.top_k);
    let router_w = q_vec(&probs);
    // `topk_mask` emits `min(top_k, N_EXPERTS)` indices per row. Use that as
    // the row stride rather than the `TOP_K` constant, so a model whose
    // `router.top_k` differs from the constant is still read correctly.
    let n_selected = ref_model.router.top_k.min(N_EXPERTS);

    // ---- Forward: experts (every expert runs on every row) ----
    let mut exp_caches: Vec<ExpertFwdCache> = Vec::with_capacity(N_EXPERTS);
    let mut exp_outs: Vec<Vec<f32>> = Vec::with_capacity(N_EXPERTS);
    for ei in 0..N_EXPERTS {
        let c = expert_fwd_sim_v2(&x, &ref_model.experts[ei], batch);
        exp_outs.push(q_vec(&c.out));
        exp_caches.push(c);
    }

    // ---- Forward: mix expert outputs with the masked router weights ----
    let mut logits = vec![0.0f32; batch * OUT_DIM];
    for ei in 0..N_EXPERTS {
        for bi in 0..batch {
            let w = router_w[bi * N_EXPERTS + ei];
            if w == 0.0 {
                continue;
            }
            for d in 0..OUT_DIM {
                logits[bi * OUT_DIM + d] += w * exp_outs[ei][bi * OUT_DIM + d];
            }
        }
    }
    let logits = q_vec(&logits);

    // ---- Zero-initialised gradient buffers ----
    let mut router_dw = vec![0.0f32; ref_model.router.linear.weight.len()];
    let mut router_db = vec![0.0f32; ref_model.router.linear.bias.len()];
    let mut exp_dw: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    let mut exp_db: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    let mut exp_dgamma: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    let mut exp_dbeta: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    for ei in 0..N_EXPERTS {
        let e = &ref_model.experts[ei];
        exp_dw.push(e.linears.iter().map(|l| vec![0.0f32; l.weight.len()]).collect());
        exp_db.push(e.linears.iter().map(|l| vec![0.0f32; l.bias.len()]).collect());
        exp_dgamma.push(e.layernorms.iter().map(|ln| vec![0.0f32; ln.gamma.len()]).collect());
        exp_dbeta.push(e.layernorms.iter().map(|ln| vec![0.0f32; ln.beta.len()]).collect());
    }

    // ---- Backward: experts ----
    let grad_output = q_vec(grad_output);
    for ei in 0..N_EXPERTS {
        // Upstream gradient for this expert is the output gradient scaled by
        // the expert's router weight for each row.
        let mut go_e = vec![0.0f32; batch * OUT_DIM];
        for bi in 0..batch {
            let w = router_w[bi * N_EXPERTS + ei];
            for d in 0..OUT_DIM {
                go_e[bi * OUT_DIM + d] = grad_output[bi * OUT_DIM + d] * w;
            }
        }
        let go_e = q_vec(&go_e);

        let exp = &ref_model.experts[ei];
        let cache = &exp_caches[ei];
        let last_lin_idx = exp.linears.len() - 1;
        let mut cur = linear_bwd_inplace(
            &go_e,
            &cache.pre_out_lin_in,
            &exp.linears[last_lin_idx],
            &mut exp_dw[ei][last_lin_idx],
            &mut exp_db[ei][last_lin_idx],
        );
        // Walk the hidden blocks in reverse: GELU, layer norm, linear.
        for bi_idx in (0..cache.blocks.len()).rev() {
            let blk = &cache.blocks[bi_idx];
            let dgelu = gelu_bwd(&q_vec(&cur), &blk.gelu_in);
            let dln_in = layernorm_bwd(
                &dgelu,
                &blk.ln_in,
                &exp.layernorms[bi_idx],
                &blk.ln_mean,
                &blk.ln_rstd,
                batch,
                &mut exp_dgamma[ei][bi_idx],
                &mut exp_dbeta[ei][bi_idx],
            );
            cur = linear_bwd_inplace(
                &dln_in,
                // The first block consumed the model input; later blocks
                // consumed the previous block's cached GELU output.
                if bi_idx == 0 {
                    &x
                } else {
                    &cache.blocks[bi_idx - 1].gelu_out
                },
                &exp.linears[bi_idx],
                &mut exp_dw[ei][bi_idx],
                &mut exp_db[ei][bi_idx],
            );
        }
        // The gradient w.r.t. the model input is not part of the result, so
        // it is not accumulated.
    }

    // ---- Backward: router ----
    // Gradient w.r.t. each router weight is <grad_output, expert_output>.
    let mut grad_w = vec![0.0f32; batch * N_EXPERTS];
    for bi in 0..batch {
        for ei in 0..N_EXPERTS {
            let mut s = 0.0f32;
            for d in 0..OUT_DIM {
                s += grad_output[bi * OUT_DIM + d] * exp_outs[ei][bi * OUT_DIM + d];
            }
            grad_w[bi * N_EXPERTS + ei] = s;
        }
    }
    // Experts that were not selected for a row get no gradient.
    for bi in 0..batch {
        let row_topk = &top_k_indices[bi * n_selected..(bi + 1) * n_selected];
        for ei in 0..N_EXPERTS {
            if !row_topk.contains(&ei) {
                grad_w[bi * N_EXPERTS + ei] = 0.0;
            }
        }
    }
    let grad_logits = softmax_bwd(&grad_w, &router_w, batch, N_EXPERTS);
    // Called for its accumulation into `router_dw` / `router_db`; the
    // returned input gradient is unused.
    let _router_dx = linear_bwd_inplace(
        &grad_logits,
        &x,
        &ref_model.router.linear,
        &mut router_dw,
        &mut router_db,
    );

    // ---- Flatten gradients into the documented order ----
    let mut param_grads: Vec<Vec<f32>> = Vec::new();
    param_grads.push(router_dw);
    param_grads.push(router_db);
    for ei in 0..N_EXPERTS {
        let e = &ref_model.experts[ei];
        for li in 0..e.linears.len() {
            param_grads.push(std::mem::take(&mut exp_dw[ei][li]));
            param_grads.push(std::mem::take(&mut exp_db[ei][li]));
        }
        for li in 0..e.layernorms.len() {
            param_grads.push(std::mem::take(&mut exp_dgamma[ei][li]));
            param_grads.push(std::mem::take(&mut exp_dbeta[ei][li]));
        }
    }

    Fp16SimResult {
        param_grads,
        logits,
        batch,
        top_k_indices,
    }
}
/// Runs one expert forward for the simulated-fp16 path and records what
/// backward needs.
///
/// Same structure as the `f32` expert forward (linear, layer norm, GELU for
/// every hidden block, then a final linear), with [`q_vec`] applied to the
/// linear output, to the layer norm output, and to the cached GELU output.
///
/// NOTE(review): the value fed to the next linear layer (and stored as
/// `pre_out_lin_in`) is the GELU output *before* `q_vec`, while the cached
/// `gelu_out` used by the backward pass is the value *after* `q_vec`.
/// Behaviour is kept as found; confirm this asymmetry is intended.
///
/// # Panics
///
/// Panics if the expert has no linear layers, or has fewer layer norms than
/// hidden linear layers.
fn expert_fwd_sim_v2(x: &[f32], exp: &Fp32Expert, n_rows: usize) -> ExpertFwdCache {
    let (out_linear, hidden_linears) = exp
        .linears
        .split_last()
        .expect("expert must contain at least one linear layer");

    let mut current: Vec<f32> = x.to_vec();
    let mut blocks: Vec<ExpertBlockCache> = Vec::with_capacity(hidden_linears.len());
    for (li, linear) in hidden_linears.iter().enumerate() {
        let lin = linear_fwd(&current, linear);
        let lin_q = q_vec(&lin);
        let (ln_out, ln_mean, ln_rstd) = layernorm_fwd(&lin_q, &exp.layernorms[li], n_rows);
        let gelu_in = q_vec(&ln_out);
        let gelu_out = gelu_fwd(&gelu_in);
        blocks.push(ExpertBlockCache {
            ln_in: lin_q,
            ln_mean,
            ln_rstd,
            gelu_in,
            gelu_out: q_vec(&gelu_out),
        });
        current = gelu_out;
    }

    let out = linear_fwd(&current, out_linear);
    ExpertFwdCache {
        blocks,
        pre_out_lin_in: current,
        out,
    }
}
#[inline]
fn q_vec(v: &[f32]) -> Vec<f32> {
v.iter()
.map(|x| crate::model::util::f16_to_f32(crate::model::util::f32_to_f16(*x)))
.collect()
}
/// Output of [`forward_backward`].
pub struct Fp32RunResult {
    /// Parameter gradients, one buffer per parameter tensor: router weight,
    /// router bias, then for each expert all linear (weight, bias) pairs
    /// followed by all layer norm (gamma, beta) pairs.
    pub param_grads: Vec<Vec<f32>>,
    /// Mixed model output, row-major.
    pub logits: Vec<f32>,
    /// Number of input rows processed.
    pub batch: usize,
    /// Expert indices selected by the router, rows concatenated.
    pub top_k_indices: Vec<usize>,
}
/// Runs the MoE forward and backward pass entirely in `f32`.
///
/// * `x`: row-major input of `batch * IN_DIM` values.
/// * `grad_output`: upstream gradient of `batch * OUT_DIM` values.
///
/// Every expert is evaluated on every row; the output mixes the expert
/// outputs with the top-k-masked router probabilities, which are not
/// renormalised after masking.
///
/// The returned `param_grads` are ordered: router weight, router bias, then
/// per expert all linear (weight, bias) pairs followed by all layer norm
/// (gamma, beta) pairs.
///
/// NOTE(review): `param_names` lists names in layer order (linears and layer
/// norms interleaved), which differs from the order used here whenever an
/// expert has a layer norm before its last linear. Confirm how callers pair
/// the two.
///
/// NOTE(review): the router gradient applies `softmax_bwd` to the masked
/// probabilities, so logits of unselected experts receive zero gradient.
/// Presumably this mirrors the model under test; confirm.
///
/// # Panics
///
/// Panics if `x.len()` is not a multiple of `IN_DIM`, if `grad_output.len()`
/// is not `batch * OUT_DIM`, or if `ref_model` has fewer than `N_EXPERTS`
/// experts.
pub fn forward_backward(ref_model: &Fp32MoE, x: &[f32], grad_output: &[f32]) -> Fp32RunResult {
    let k = IN_DIM;
    let batch = x.len() / k;
    assert_eq!(x.len(), batch * k);
    assert_eq!(grad_output.len(), batch * OUT_DIM);

    // ---- Forward: router ----
    let (router_w, top_k_indices) = router_fwd(x, &ref_model.router, batch);
    // `topk_mask` emits `min(top_k, N_EXPERTS)` indices per row. Use that as
    // the row stride rather than the `TOP_K` constant, so a model whose
    // `router.top_k` differs from the constant is still read correctly.
    let n_selected = ref_model.router.top_k.min(N_EXPERTS);

    // ---- Forward: experts (every expert runs on every row) ----
    let mut exp_caches: Vec<ExpertFwdCache> = Vec::with_capacity(N_EXPERTS);
    let mut exp_outs: Vec<Vec<f32>> = Vec::with_capacity(N_EXPERTS);
    for ei in 0..N_EXPERTS {
        let c = expert_fwd_v2(x, &ref_model.experts[ei], batch);
        exp_outs.push(c.out.clone());
        exp_caches.push(c);
    }

    // ---- Forward: mix expert outputs with the masked router weights ----
    let mut logits = vec![0.0f32; batch * OUT_DIM];
    for ei in 0..N_EXPERTS {
        for bi in 0..batch {
            let w = router_w[bi * N_EXPERTS + ei];
            if w == 0.0 {
                continue;
            }
            for d in 0..OUT_DIM {
                logits[bi * OUT_DIM + d] += w * exp_outs[ei][bi * OUT_DIM + d];
            }
        }
    }

    // ---- Zero-initialised gradient buffers ----
    let mut router_dw = vec![0.0f32; ref_model.router.linear.weight.len()];
    let mut router_db = vec![0.0f32; ref_model.router.linear.bias.len()];
    let mut exp_dw: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    let mut exp_db: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    let mut exp_dgamma: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    let mut exp_dbeta: Vec<Vec<Vec<f32>>> = Vec::with_capacity(N_EXPERTS);
    for ei in 0..N_EXPERTS {
        let e = &ref_model.experts[ei];
        exp_dw.push(e.linears.iter().map(|l| vec![0.0f32; l.weight.len()]).collect());
        exp_db.push(e.linears.iter().map(|l| vec![0.0f32; l.bias.len()]).collect());
        exp_dgamma.push(e.layernorms.iter().map(|ln| vec![0.0f32; ln.gamma.len()]).collect());
        exp_dbeta.push(e.layernorms.iter().map(|ln| vec![0.0f32; ln.beta.len()]).collect());
    }

    // ---- Backward: experts ----
    for ei in 0..N_EXPERTS {
        // Upstream gradient for this expert is the output gradient scaled by
        // the expert's router weight for each row.
        let mut go_e = vec![0.0f32; batch * OUT_DIM];
        for bi in 0..batch {
            let w = router_w[bi * N_EXPERTS + ei];
            for d in 0..OUT_DIM {
                go_e[bi * OUT_DIM + d] = grad_output[bi * OUT_DIM + d] * w;
            }
        }

        let exp = &ref_model.experts[ei];
        let cache = &exp_caches[ei];
        let last_lin_idx = exp.linears.len() - 1;
        let mut cur = linear_bwd_inplace(
            &go_e,
            &cache.pre_out_lin_in,
            &exp.linears[last_lin_idx],
            &mut exp_dw[ei][last_lin_idx],
            &mut exp_db[ei][last_lin_idx],
        );
        // Walk the hidden blocks in reverse: GELU, layer norm, linear.
        for bi_idx in (0..cache.blocks.len()).rev() {
            let blk = &cache.blocks[bi_idx];
            let dgelu = gelu_bwd(&cur, &blk.gelu_in);
            let dln_in = layernorm_bwd(
                &dgelu,
                &blk.ln_in,
                &exp.layernorms[bi_idx],
                &blk.ln_mean,
                &blk.ln_rstd,
                batch,
                &mut exp_dgamma[ei][bi_idx],
                &mut exp_dbeta[ei][bi_idx],
            );
            cur = linear_bwd_inplace(
                &dln_in,
                // The first block consumed the model input; later blocks
                // consumed the previous block's cached GELU output.
                if bi_idx == 0 {
                    x
                } else {
                    &cache.blocks[bi_idx - 1].gelu_out
                },
                &exp.linears[bi_idx],
                &mut exp_dw[ei][bi_idx],
                &mut exp_db[ei][bi_idx],
            );
        }
        // The gradient w.r.t. the model input is not part of the result, so
        // it is not accumulated.
    }

    // ---- Backward: router ----
    // Gradient w.r.t. each router weight is <grad_output, expert_output>.
    let mut grad_w = vec![0.0f32; batch * N_EXPERTS];
    for bi in 0..batch {
        for ei in 0..N_EXPERTS {
            let mut s = 0.0f32;
            for d in 0..OUT_DIM {
                s += grad_output[bi * OUT_DIM + d] * exp_outs[ei][bi * OUT_DIM + d];
            }
            grad_w[bi * N_EXPERTS + ei] = s;
        }
    }
    // Experts that were not selected for a row get no gradient.
    for bi in 0..batch {
        let row_topk = &top_k_indices[bi * n_selected..(bi + 1) * n_selected];
        for ei in 0..N_EXPERTS {
            if !row_topk.contains(&ei) {
                grad_w[bi * N_EXPERTS + ei] = 0.0;
            }
        }
    }
    let grad_logits = softmax_bwd(&grad_w, &router_w, batch, N_EXPERTS);
    // Called for its accumulation into `router_dw` / `router_db`; the
    // returned input gradient is unused.
    let _router_dx = linear_bwd_inplace(
        &grad_logits,
        x,
        &ref_model.router.linear,
        &mut router_dw,
        &mut router_db,
    );

    // ---- Flatten gradients into the documented order ----
    let mut param_grads: Vec<Vec<f32>> = Vec::new();
    param_grads.push(router_dw);
    param_grads.push(router_db);
    for ei in 0..N_EXPERTS {
        let e = &ref_model.experts[ei];
        for li in 0..e.linears.len() {
            param_grads.push(std::mem::take(&mut exp_dw[ei][li]));
            param_grads.push(std::mem::take(&mut exp_db[ei][li]));
        }
        for li in 0..e.layernorms.len() {
            param_grads.push(std::mem::take(&mut exp_dgamma[ei][li]));
            param_grads.push(std::mem::take(&mut exp_dbeta[ei][li]));
        }
    }

    Fp32RunResult {
        param_grads,
        logits,
        batch,
        top_k_indices,
    }
}
/// Builds human-readable names for the model's parameter tensors.
///
/// The list starts with `router.weight` and `router.bias`. Each expert's
/// layers are then visited in order: a `"Linear"` layer contributes
/// `expert{e}.lin{i}.weight` / `.bias`, and a `"LayerNorm"` layer contributes
/// `expert{e}.ln{j}.gamma` / `.beta`, with `i` and `j` counted separately per
/// expert. Layers with any other name add nothing.
///
/// NOTE(review): names follow layer order, whereas the gradient lists built
/// by `forward_backward` group all linears before all layer norms within an
/// expert. Confirm how callers pair names with gradients.
pub fn param_names(model: &MoEModel) -> Vec<String> {
    let mut names = vec!["router.weight".to_string(), "router.bias".to_string()];
    for (ei, seq) in model.experts.iter().enumerate() {
        let (mut lin_idx, mut ln_idx) = (0, 0);
        for layer in &seq.layers {
            let kind = layer.name();
            if kind == "Linear" {
                names.extend([
                    format!("expert{ei}.lin{lin_idx}.weight"),
                    format!("expert{ei}.lin{lin_idx}.bias"),
                ]);
                lin_idx += 1;
            } else if kind == "LayerNorm" {
                names.extend([
                    format!("expert{ei}.ln{ln_idx}.gamma"),
                    format!("expert{ei}.ln{ln_idx}.beta"),
                ]);
                ln_idx += 1;
            }
        }
    }
    names
}
/// Returns the norm of a gradient buffer by delegating to
/// `crate::metrics::grad_norm`.
///
/// NOTE(review): the kind of norm (presumably L2) is defined by `grad_norm`
/// and is not visible from this file; confirm there.
pub fn grad_norm_of(grad: &[f32]) -> f32 {
grad_norm(grad)
}