use crate::tensor::Tensor3;
use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::cmp::Ordering;
#[cfg(not(feature = "std"))]
use crate::no_std_math::F32Ext as _;
pub const DEFAULT_HIDDEN_DIM: usize = 32;
#[derive(Clone, Debug)]
pub struct FastGrnnGate {
pub input_dim: usize,
pub hidden_dim: usize,
w_gate: Vec<f32>,
u_gate: Vec<f32>,
w_update: Vec<f32>,
u_update: Vec<f32>,
bias_gate: Vec<f32>,
bias_update: Vec<f32>,
zeta: f32,
nu: f32,
}
impl FastGrnnGate {
pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
assert!(input_dim > 0 && hidden_dim > 0, "dims must be > 0");
let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
let seed_w = |i: usize| -> f32 {
let r = ((i as u32).wrapping_mul(2654435761).wrapping_add(0x9e3779b9)) as f32;
(r / u32::MAX as f32 - 0.5) * 2.0 * scale * 0.1
};
let n_in_h = input_dim * hidden_dim;
let n_h_h = hidden_dim * hidden_dim;
Self {
input_dim,
hidden_dim,
w_gate: (0..n_in_h).map(|i| seed_w(i)).collect(),
u_gate: (0..n_h_h ).map(|i| seed_w(i + 1000)).collect(),
w_update: (0..n_in_h).map(|i| seed_w(i + 2000)).collect(),
u_update: (0..n_h_h ).map(|i| seed_w(i + 3000)).collect(),
bias_gate: vec![0.0; hidden_dim],
bias_update: vec![0.0; hidden_dim],
zeta: 1.0,
nu: 0.0,
}
}
#[allow(clippy::too_many_arguments)]
pub fn from_weights(
input_dim: usize,
hidden_dim: usize,
w_gate: Vec<f32>, u_gate: Vec<f32>,
w_update: Vec<f32>, u_update: Vec<f32>,
bias_gate: Vec<f32>, bias_update: Vec<f32>,
zeta: f32, nu: f32,
) -> Result<Self, String> {
let n_in_h = input_dim * hidden_dim;
let n_h_h = hidden_dim * hidden_dim;
if w_gate.len() != n_in_h { return Err(format!("w_gate len {} != {}", w_gate.len(), n_in_h)); }
if u_gate.len() != n_h_h { return Err(format!("u_gate len {} != {}", u_gate.len(), n_h_h)); }
if w_update.len() != n_in_h { return Err(format!("w_update len {} != {}", w_update.len(), n_in_h)); }
if u_update.len() != n_h_h { return Err(format!("u_update len {} != {}", u_update.len(), n_h_h)); }
if bias_gate.len() != hidden_dim { return Err("bias_gate len != hidden_dim".into()); }
if bias_update.len() != hidden_dim { return Err("bias_update len != hidden_dim".into()); }
Ok(Self {
input_dim, hidden_dim,
w_gate, u_gate, w_update, u_update,
bias_gate, bias_update,
zeta, nu,
})
}
pub fn score_sequence(&self, tokens: &[Vec<f32>]) -> Vec<f32> {
let seq = tokens.len();
let mut salience = Vec::with_capacity(seq);
let mut h = vec![0.0f32; self.hidden_dim];
for x in tokens {
debug_assert_eq!(x.len(), self.input_dim, "token feature dim mismatch");
let h_new = self.step(x, &h);
salience.push(l2_norm(&h_new));
h = h_new;
}
salience
}
pub fn score_kv(&self, k: &Tensor3) -> Vec<f32> {
assert_eq!(self.input_dim, k.dim, "gate input_dim must equal k.dim");
let seq = k.seq;
let mut tokens: Vec<Vec<f32>> = Vec::with_capacity(seq);
let inv_h = 1.0 / k.heads as f32;
for i in 0..seq {
let mut pooled = vec![0.0f32; k.dim];
for h in 0..k.heads {
let row = k.row(i, h);
for d in 0..k.dim { pooled[d] += row[d] * inv_h; }
}
tokens.push(pooled);
}
self.score_sequence(&tokens)
}
pub fn step_with_hidden(&self, x: &[f32], h: &[f32]) -> (Vec<f32>, f32) {
debug_assert_eq!(x.len(), self.input_dim);
debug_assert_eq!(h.len(), self.hidden_dim);
let h_new = self.step(x, h);
let s = l2_norm(&h_new);
(h_new, s)
}
pub fn keep_mask_quantile(salience: &[f32], quantile: f32) -> Vec<bool> {
let n = salience.len();
if n == 0 { return Vec::new(); }
let q = quantile.clamp(0.0, 1.0);
let mut sorted: Vec<f32> = salience.iter().copied().collect();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
let idx = ((n as f32) * q) as usize;
let threshold = sorted[idx.min(n - 1)];
salience.iter().map(|&s| s >= threshold).collect()
}
pub fn keep_mask_top_k(salience: &[f32], k: usize) -> Vec<bool> {
let n = salience.len();
let mut keep = vec![false; n];
if k == 0 { return keep; }
if k >= n { keep.iter_mut().for_each(|b| *b = true); return keep; }
let mut idx: Vec<usize> = (0..n).collect();
idx.sort_by(|&a, &b| {
salience[b].partial_cmp(&salience[a]).unwrap_or(Ordering::Equal)
.then(a.cmp(&b))
});
for &i in idx.iter().take(k) { keep[i] = true; }
keep
}
fn step(&self, x: &[f32], h: &[f32]) -> Vec<f32> {
let d = self.hidden_dim;
let mut gate = vec![0.0f32; d];
let mut update = vec![0.0f32; d];
matmul_add(&self.w_gate, x, &mut gate, self.input_dim);
matmul_add(&self.u_gate, h, &mut gate, self.hidden_dim);
matmul_add(&self.w_update, x, &mut update, self.input_dim);
matmul_add(&self.u_update, h, &mut update, self.hidden_dim);
let mut h_new = vec![0.0f32; d];
for i in 0..d {
let g = sigmoid(gate[i] + self.bias_gate[i]);
let c = (update[i] + self.bias_update[i]).tanh();
let gf = (self.zeta * g + self.nu).clamp(0.0, 1.0);
h_new[i] = gf * h[i] + (1.0 - gf) * c;
}
h_new
}
}
#[inline]
fn matmul_add(weights: &[f32], input: &[f32], result: &mut [f32], cols: usize) {
let rows = result.len();
for i in 0..rows {
let row_off = i * cols;
let mut s = result[i];
for j in 0..cols {
s += weights[row_off + j] * input[j];
}
result[i] = s;
}
}
#[inline]
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gate_score_sequence_shape() {
let g = FastGrnnGate::new(8, 16);
let tokens: Vec<Vec<f32>> = (0..32)
.map(|t| (0..8).map(|d| (t * 8 + d) as f32 * 0.01).collect())
.collect();
let s = g.score_sequence(&tokens);
assert_eq!(s.len(), 32);
assert!(s.iter().all(|&v| v.is_finite() && v >= 0.0));
}
#[test]
fn test_gate_score_sequence_deterministic() {
let g = FastGrnnGate::new(4, 8);
let tokens: Vec<Vec<f32>> = (0..16)
.map(|t| vec![t as f32 * 0.05; 4])
.collect();
let a = g.score_sequence(&tokens);
let b = g.score_sequence(&tokens);
assert_eq!(a, b);
}
#[test]
fn test_gate_zero_input_zero_baseline() {
let g = FastGrnnGate::new(4, 8);
let tokens = vec![vec![0.0f32; 4]; 32];
let s = g.score_sequence(&tokens);
assert!(s.iter().all(|&v| v < 0.1), "zero input should keep salience near zero, got {:?}", s);
}
#[test]
fn test_keep_mask_quantile_top_quartile() {
let salience = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
let keep = FastGrnnGate::keep_mask_quantile(&salience, 0.75);
let kept: usize = keep.iter().filter(|&&b| b).count();
assert_eq!(kept, 2, "keep = {:?}", keep);
assert!(keep[2]); assert!(keep[4]); }
#[test]
fn test_keep_mask_top_k() {
let salience = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
let keep = FastGrnnGate::keep_mask_top_k(&salience, 3);
assert!(keep[2] && keep[4] && keep[6]);
let kept: usize = keep.iter().filter(|&&b| b).count();
assert_eq!(kept, 3);
}
#[test]
fn test_keep_mask_top_k_edges() {
let salience = vec![0.1, 0.2, 0.3];
assert!(FastGrnnGate::keep_mask_top_k(&salience, 0).iter().all(|&b| !b));
assert!(FastGrnnGate::keep_mask_top_k(&salience, 99).iter().all(|&b| b));
}
#[test]
fn test_from_weights_validates_shapes() {
let r = FastGrnnGate::from_weights(
4, 8,
vec![0.0; 32], vec![0.0; 64], vec![0.0; 32],
vec![0.0; 64],
vec![0.0; 8],
vec![0.0; 8],
1.0, 0.0,
);
assert!(r.is_ok());
let bad = FastGrnnGate::from_weights(
4, 8,
vec![0.0; 7], vec![0.0; 64], vec![0.0; 32], vec![0.0; 64],
vec![0.0; 8], vec![0.0; 8],
1.0, 0.0,
);
assert!(bad.is_err());
}
#[test]
fn test_step_with_hidden_advances_state() {
let g = FastGrnnGate::new(4, 8);
let h0 = vec![0.0; 8];
let x = vec![0.5; 4];
let (h1, _) = g.step_with_hidden(&x, &h0);
let (h2, _) = g.step_with_hidden(&x, &h1);
assert!(h0.iter().zip(h1.iter()).any(|(a, b)| a != b));
assert!(h1.iter().zip(h2.iter()).any(|(a, b)| a != b));
}
}