use ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis};
use crate::state_dict::{StateDict, StateDictError};
/// Embeds integer class indices by selecting columns of a dense weight matrix.
///
/// Mathematically this equals one-hot encoding each index into a
/// `num_classes`-wide vector and applying a linear layer. It is implemented as
/// a column lookup, so the one-hot tensor is never materialised.
#[derive(Debug, Clone)]
pub struct OneHotAndLinear {
    /// Number of distinct class indices accepted by `forward`.
    pub num_classes: usize,
    /// Width of each output embedding.
    pub embed_dim: usize,
    /// Weight matrix of shape `(embed_dim, num_classes)`; column `c` is the
    /// embedding for class `c`.
    pub weight: Array2<f32>,
    /// Optional per-output bias, indexed by embedding position `0..embed_dim`.
    pub bias: Option<Vec<f32>>,
}

impl OneHotAndLinear {
    /// Builds the layer from an existing weight matrix.
    ///
    /// `embed_dim` and `num_classes` are inferred from the weight's shape
    /// `(embed_dim, num_classes)`.
    pub fn from_raw_weight(weight: Array2<f32>, bias: Option<Vec<f32>>) -> Self {
        let (embed_dim, num_classes) = (weight.shape()[0], weight.shape()[1]);
        Self {
            num_classes,
            embed_dim,
            weight,
            bias,
        }
    }

    /// Loads `{prefix}.weight` and, if present, `{prefix}.bias` from `sd`.
    ///
    /// The weight is requested with the layer's current
    /// `(embed_dim, num_classes)` shape, and the bias with length `embed_dim`.
    /// If `sd` has no bias entry, the existing `bias` field is left unchanged.
    ///
    /// # Errors
    /// Propagates any [`StateDictError`] returned by `take_array2` / `take_vec`.
    /// Presumably these cover missing keys and shape mismatches; confirm in
    /// `state_dict`.
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        self.weight = sd.take_array2(
            &format!("{prefix}.weight"),
            self.embed_dim,
            self.num_classes,
        )?;
        let bias_key = format!("{prefix}.bias");
        if sd.tensors.contains_key(&bias_key) {
            self.bias = Some(sd.take_vec(&bias_key, self.embed_dim)?);
        }
        Ok(())
    }

    /// Maps a `(batch, seq)` grid of class indices to `(batch, seq, embed_dim)`.
    ///
    /// `out[b, t, e] = weight[e, src[b, t]] + bias[e]` (bias term only when present).
    ///
    /// # Panics
    /// - If any index in `src` is `>= num_classes`. This is checked in all build
    ///   profiles, so it also fires when `embed_dim == 0` and the weight is
    ///   never indexed.
    /// - If `bias` is shorter than `embed_dim`, or `weight` is smaller than
    ///   `(embed_dim, num_classes)`.
    pub fn forward(&self, src: ArrayView2<usize>) -> Array3<f32> {
        let (b, t) = (src.shape()[0], src.shape()[1]);
        let mut out = Array3::<f32>::zeros((b, t, self.embed_dim));
        // Resolve the optional bias once instead of per token.
        let bias = self.bias.as_deref();
        for ((bi, ti), &c) in src.indexed_iter() {
            // A plain `assert!` (not `debug_assert!`) so release builds reject
            // bad indices with a clear message rather than a generic ndarray
            // out-of-bounds panic, or no panic at all when `embed_dim == 0`.
            assert!(
                c < self.num_classes,
                "class index {c} out of range (num_classes = {})",
                self.num_classes
            );
            for e in 0..self.embed_dim {
                let w = self.weight[(e, c)];
                out[(bi, ti, e)] = match bias {
                    Some(bias) => w + bias[e],
                    None => w,
                };
            }
        }
        out
    }
}
/// Linear layer that passes "sentinel" rows through instead of projecting them.
///
/// Some positions have every input feature equal to `skip_value`. At those
/// positions, every output feature is set to `skip_value`. All other positions
/// receive the ordinary affine map `weight · x + bias`.
#[derive(Debug, Clone)]
pub struct SkippableLinear {
    /// Weight matrix of shape `(out_features, in_features)`.
    pub weight: Array2<f32>,
    /// Optional bias, indexed by output feature.
    pub bias: Option<Vec<f32>>,
    /// Marks rows to skip; also the fill value for skipped outputs.
    pub skip_value: f32,
}

impl SkippableLinear {
    /// Creates a layer from its raw parameters. No shape validation is performed.
    pub fn new(weight: Array2<f32>, bias: Option<Vec<f32>>, skip_value: f32) -> Self {
        Self { weight, bias, skip_value }
    }

    /// Number of input features (the column count of `weight`).
    pub fn in_features(&self) -> usize {
        self.weight.ncols()
    }

    /// Number of output features (the row count of `weight`).
    pub fn out_features(&self) -> usize {
        self.weight.nrows()
    }

    /// Loads `{prefix}.weight` and, when present, `{prefix}.bias` from `sd`.
    ///
    /// Expected shapes come from the layer's current weight. When `sd` has no
    /// bias entry, the existing `bias` is kept as-is.
    ///
    /// # Errors
    /// Returns whatever [`StateDictError`] the state-dict accessors produce.
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        let rows = self.out_features();
        let cols = self.in_features();
        self.weight = sd.take_array2(&format!("{prefix}.weight"), rows, cols)?;
        let bias_key = format!("{prefix}.bias");
        if sd.tensors.contains_key(&bias_key) {
            let loaded = sd.take_vec(&bias_key, rows)?;
            self.bias = Some(loaded);
        }
        Ok(())
    }

    /// Projects `src` of shape `(batch, seq, in_features)` to `(batch, seq, out_features)`.
    ///
    /// Dot products are accumulated in `f32`. The sentinel test uses float
    /// equality, so a NaN `skip_value` never matches any row.
    ///
    /// # Panics
    /// If the last dimension of `src` differs from `in_features()`.
    pub fn forward(&self, src: ArrayView3<f32>) -> Array3<f32> {
        let (batch, seq, n_in) = src.dim();
        assert_eq!(n_in, self.in_features());
        let n_out = self.out_features();
        let sentinel = self.skip_value;
        let mut out = Array3::<f32>::zeros((batch, seq, n_out));
        for bi in 0..batch {
            for ti in 0..seq {
                // Project the row only if some feature differs from the sentinel.
                // A row with zero features therefore counts as skipped.
                let keep = (0..n_in).any(|j| src[(bi, ti, j)] != sentinel);
                for k in 0..n_out {
                    out[(bi, ti, k)] = if keep {
                        let dot = (0..n_in).fold(0.0_f32, |acc, j| {
                            acc + src[(bi, ti, j)] * self.weight[(k, j)]
                        });
                        match &self.bias {
                            Some(bias) => dot + bias[k],
                            None => dot,
                        }
                    } else {
                        sentinel
                    };
                }
            }
        }
        out
    }
}
/// Applies a dense linear map to the last axis of a 3-D tensor.
///
/// Let `src` have shape `(b, t, in_f)` and `weight` have shape `(out_f, in_f)`.
/// The result has shape `(b, t, out_f)` with
/// `out[bi, ti, k] = sum_j src[bi, ti, j] * weight[k, j] (+ bias[k])`.
/// Each dot product is accumulated in `f64` and rounded to `f32` once at the end.
///
/// # Panics
/// - If `weight`'s column count differs from `src`'s last dimension.
/// - If `bias` is given with a length other than `out_f`. Previously a longer
///   bias was silently truncated and a shorter one hit an index panic mid-loop.
pub fn linear3d(
    src: ArrayView3<f32>,
    weight: ArrayView2<f32>,
    bias: Option<&[f32]>,
) -> Array3<f32> {
    let (b, t, in_f) = src.dim();
    let (out_f, w_in) = weight.dim();
    assert_eq!(
        in_f, w_in,
        "linear3d: input feature dim {in_f} does not match weight in_features {w_in}"
    );
    if let Some(bias) = bias {
        assert_eq!(
            bias.len(),
            out_f,
            "linear3d: bias length must equal weight out_features"
        );
    }
    let mut out = Array3::<f32>::zeros((b, t, out_f));
    for bi in 0..b {
        for ti in 0..t {
            for k in 0..out_f {
                let mut acc = (0..in_f).fold(0.0_f64, |acc, j| {
                    acc + f64::from(src[(bi, ti, j)]) * f64::from(weight[(k, j)])
                });
                if let Some(bias) = bias {
                    acc += f64::from(bias[k]);
                }
                out[(bi, ti, k)] = acc as f32;
            }
        }
    }
    out
}
/// Layer-normalises each `(bi, ti)` row of a 3-D tensor over its last axis.
///
/// For a row `x` of length `d`, the output is
/// `(x - mean) / sqrt(var + eps) * gamma (+ beta)`.
/// `var` is the biased (divide-by-`d`) variance. Mean and variance are
/// computed in `f64`. The normalised value is rounded to `f32` before the
/// affine `gamma`/`beta` step.
///
/// # Panics
/// If `gamma.len()` differs from the last dimension, or if `beta` is given
/// with a length other than the last dimension.
pub fn layer_norm_last(
    x: ArrayView3<f32>,
    gamma: &[f32],
    beta: Option<&[f32]>,
    eps: f32,
) -> Array3<f32> {
    let (b, t, d) = x.dim();
    assert_eq!(
        gamma.len(),
        d,
        "layer_norm_last: gamma length must equal the last dimension"
    );
    // Validate beta the same way as gamma. Without this check, a longer beta
    // was accepted silently and a shorter one panicked on raw indexing.
    if let Some(beta) = beta {
        assert_eq!(
            beta.len(),
            d,
            "layer_norm_last: beta length must equal the last dimension"
        );
    }
    let mut out = Array3::<f32>::zeros((b, t, d));
    let inv_d = 1.0_f64 / d as f64;
    for bi in 0..b {
        for ti in 0..t {
            // Two passes (mean first, then squared deviations) avoid the
            // cancellation of the single-pass E[x^2] - E[x]^2 formula.
            let mut mean: f64 = 0.0;
            for k in 0..d {
                mean += x[(bi, ti, k)] as f64;
            }
            mean *= inv_d;
            let mut var: f64 = 0.0;
            for k in 0..d {
                let dx = (x[(bi, ti, k)] as f64) - mean;
                var += dx * dx;
            }
            var *= inv_d;
            let inv_std = (var + eps as f64).sqrt().recip();
            for k in 0..d {
                let normed = ((x[(bi, ti, k)] as f64 - mean) * inv_std) as f32 * gamma[k];
                out[(bi, ti, k)] = match beta {
                    Some(beta) => normed + beta[k],
                    None => normed,
                };
            }
        }
    }
    out
}
// Appears to exist only to keep the `Axis` import referenced, so it does not
// trigger an unused-import warning. `dead_code` is allowed because nothing
// needs to call it.
#[allow(dead_code)]
fn _silence_unused(_axis: Axis) {}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn one_hot_and_linear_picks_column() {
        let layer = OneHotAndLinear::from_raw_weight(
            array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            None,
        );
        let classes = array![[0_usize, 1, 2]];
        let embedded = layer.forward(classes.view());
        assert_eq!(embedded.shape(), &[1, 3, 2]);
        // Class `t` selects column `t` of the 2x3 weight, whose entry in
        // row `e` is `3 * e + t + 1`.
        for t in 0..3 {
            for e in 0..2 {
                assert_eq!(embedded[(0, t, e)], (e * 3 + t) as f32 + 1.0);
            }
        }
    }

    #[test]
    fn one_hot_and_linear_adds_bias() {
        let layer = OneHotAndLinear::from_raw_weight(
            array![[1.0, 2.0], [3.0, 4.0]],
            Some(vec![10.0, 100.0]),
        );
        let classes = array![[0_usize]];
        let embedded = layer.forward(classes.view());
        // Column 0 is [1, 3]; the bias is added element-wise.
        assert_eq!(embedded[(0, 0, 0)], 1.0 + 10.0);
        assert_eq!(embedded[(0, 0, 1)], 3.0 + 100.0);
    }

    #[test]
    fn skippable_linear_skips_sentinel_rows() {
        let identity = SkippableLinear::new(array![[1.0, 0.0], [0.0, 1.0]], None, -100.0);
        let input = array![[[1.0, 2.0], [-100.0, -100.0]]];
        let projected = identity.forward(input.view());
        // Row 0 is ordinary input and goes through the identity weight.
        assert_eq!(projected[(0, 0, 0)], 1.0);
        assert_eq!(projected[(0, 0, 1)], 2.0);
        // Row 1 is entirely the sentinel, so it is filled with the sentinel.
        assert_eq!(projected[(0, 1, 0)], -100.0);
        assert_eq!(projected[(0, 1, 1)], -100.0);
    }

    #[test]
    fn skippable_linear_partial_sentinel_is_not_skipped() {
        let summer = SkippableLinear::new(array![[1.0, 1.0]], None, -100.0);
        let input = array![[[-100.0, 1.0]]];
        let projected = summer.forward(input.view());
        // Only one feature equals the sentinel, so the row is projected normally.
        assert_eq!(projected[(0, 0, 0)], -99.0);
    }

    #[test]
    fn layer_norm_zero_mean_unit_var() {
        let x = array![[[-1.0_f32, 0.0, 1.0]]];
        let normed = layer_norm_last(x.view(), &[1.0, 1.0, 1.0], None, 1e-5);
        let mean: f32 = normed.iter().sum::<f32>() / 3.0;
        let var: f32 = normed.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / 3.0;
        assert!(mean.abs() < 1e-5);
        // `eps` keeps the variance slightly below 1, hence the loose tolerance.
        assert!((var - 1.0).abs() < 1e-3);
    }

    #[test]
    fn linear3d_matches_manual_compute() {
        let x = array![[[1.0_f32, 2.0], [3.0, 4.0]]];
        let w = array![[1.0, 1.0], [0.0, 2.0]];
        let bias = [0.0_f32, 0.5];
        let y = linear3d(x.view(), w.view(), Some(&bias));
        // [1, 2] -> [1 + 2, 2 * 2 + 0.5]
        assert_eq!(y[(0, 0, 0)], 3.0);
        assert_eq!(y[(0, 0, 1)], 4.5);
        // [3, 4] -> [3 + 4, 2 * 4 + 0.5]
        assert_eq!(y[(0, 1, 0)], 7.0);
        assert_eq!(y[(0, 1, 1)], 8.5);
    }
}