use alloc::vec;
use alloc::vec::Vec;
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AttentionState {
Vector(Vec<f64>),
Matrix {
data: Vec<f64>,
rows: usize,
cols: usize,
},
}
impl AttentionState {
pub fn new_vector(d: usize) -> Self {
Self::Vector(vec![0.0; d])
}
pub fn new_matrix(d_k: usize, d_v: usize) -> Self {
Self::Matrix {
data: vec![0.0; d_k * d_v],
rows: d_k,
cols: d_v,
}
}
pub fn reset(&mut self) {
match self {
Self::Vector(v) => {
for x in v.iter_mut() {
*x = 0.0;
}
}
Self::Matrix { data, .. } => {
for x in data.iter_mut() {
*x = 0.0;
}
}
}
}
pub fn as_slice(&self) -> &[f64] {
match self {
Self::Vector(v) => v.as_slice(),
Self::Matrix { data, .. } => data.as_slice(),
}
}
pub fn get_matrix(&self, row: usize, col: usize) -> f64 {
match self {
Self::Matrix { data, cols, rows } => {
debug_assert!(row < *rows, "row {} out of bounds (rows={})", row, rows);
debug_assert!(col < *cols, "col {} out of bounds (cols={})", col, cols);
data[row * cols + col]
}
Self::Vector(_) => panic!("get_matrix called on Vector state"),
}
}
pub fn set_matrix(&mut self, row: usize, col: usize, val: f64) {
match self {
Self::Matrix { data, cols, rows } => {
debug_assert!(row < *rows, "row {} out of bounds (rows={})", row, rows);
debug_assert!(col < *cols, "col {} out of bounds (cols={})", col, cols);
data[row * *cols + col] = val;
}
Self::Vector(_) => panic!("set_matrix called on Vector state"),
}
}
pub fn scale(&mut self, factor: f64) {
match self {
Self::Vector(v) => {
for x in v.iter_mut() {
*x *= factor;
}
}
Self::Matrix { data, .. } => {
for x in data.iter_mut() {
*x *= factor;
}
}
}
}
#[allow(clippy::needless_range_loop)]
pub fn add_outer_product(&mut self, u: &[f64], v: &[f64]) {
match self {
Self::Matrix { data, rows, cols } => {
debug_assert_eq!(u.len(), *rows, "u.len() must equal rows");
debug_assert_eq!(v.len(), *cols, "v.len() must equal cols");
for i in 0..*rows {
let row_start = i * *cols;
for j in 0..*cols {
data[row_start + j] += u[i] * v[j];
}
}
}
Self::Vector(_) => panic!("add_outer_product called on Vector state"),
}
}
#[allow(clippy::needless_range_loop)]
pub fn scale_per_row(&mut self, weights: &[f64]) {
match self {
Self::Matrix { data, rows, cols } => {
debug_assert_eq!(weights.len(), *rows, "weights.len() must equal rows");
for i in 0..*rows {
let row_start = i * *cols;
let wi = weights[i];
for j in 0..*cols {
data[row_start + j] *= wi;
}
}
}
Self::Vector(_) => panic!("scale_per_row called on Vector state"),
}
}
#[allow(clippy::needless_range_loop)]
pub fn query(&self, q: &[f64]) -> Vec<f64> {
match self {
Self::Matrix { data, rows, cols } => {
debug_assert_eq!(q.len(), *rows, "q.len() must equal rows (d_key)");
let mut out = vec![0.0; *cols];
for i in 0..*rows {
let row_start = i * *cols;
let qi = q[i];
for j in 0..*cols {
out[j] += data[row_start + j] * qi;
}
}
out
}
Self::Vector(_) => panic!("query called on Vector state"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn vector_state_init_and_reset() {
let mut state = AttentionState::new_vector(8);
match &state {
AttentionState::Vector(v) => {
assert_eq!(v.len(), 8, "vector should have length 8");
assert!(
v.iter().all(|&x| x == 0.0),
"initial vector should be all zeros"
);
}
_ => panic!("expected Vector state"),
}
if let AttentionState::Vector(v) = &mut state {
v[0] = 5.0;
v[3] = -2.0;
}
state.reset();
assert!(
state.as_slice().iter().all(|&x| x == 0.0),
"after reset, all values should be zero"
);
}
#[test]
fn matrix_state_init_dimensions() {
let state = AttentionState::new_matrix(3, 5);
match &state {
AttentionState::Matrix { data, rows, cols } => {
assert_eq!(*rows, 3, "should have 3 rows");
assert_eq!(*cols, 5, "should have 5 cols");
assert_eq!(data.len(), 15, "data should have 3*5=15 elements");
assert!(
data.iter().all(|&x| x == 0.0),
"initial matrix should be all zeros"
);
}
_ => panic!("expected Matrix state"),
}
}
#[test]
fn matrix_get_set() {
let mut state = AttentionState::new_matrix(2, 3);
state.set_matrix(0, 1, 7.0);
state.set_matrix(1, 2, -3.0);
assert!(
(state.get_matrix(0, 1) - 7.0).abs() < 1e-12,
"get(0,1) should be 7.0"
);
assert!(
(state.get_matrix(1, 2) - (-3.0)).abs() < 1e-12,
"get(1,2) should be -3.0"
);
assert!(
state.get_matrix(0, 0).abs() < 1e-12,
"unset elements should be 0.0"
);
}
#[test]
fn scale_multiplies_all() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 1.0);
state.set_matrix(0, 1, 2.0);
state.set_matrix(1, 0, 3.0);
state.set_matrix(1, 1, 4.0);
state.scale(0.5);
assert!(
(state.get_matrix(0, 0) - 0.5).abs() < 1e-12,
"scaled (0,0) should be 0.5"
);
assert!(
(state.get_matrix(1, 1) - 2.0).abs() < 1e-12,
"scaled (1,1) should be 2.0"
);
}
#[test]
fn add_outer_product_correct() {
let mut state = AttentionState::new_matrix(2, 3);
let u = [1.0, 2.0];
let v = [3.0, 4.0, 5.0];
state.add_outer_product(&u, &v);
assert!(
(state.get_matrix(0, 0) - 3.0).abs() < 1e-12,
"S[0][0] should be 1*3=3"
);
assert!(
(state.get_matrix(0, 2) - 5.0).abs() < 1e-12,
"S[0][2] should be 1*5=5"
);
assert!(
(state.get_matrix(1, 1) - 8.0).abs() < 1e-12,
"S[1][1] should be 2*4=8"
);
assert!(
(state.get_matrix(1, 2) - 10.0).abs() < 1e-12,
"S[1][2] should be 2*5=10"
);
}
#[test]
fn query_computes_st_times_q() {
let mut state = AttentionState::new_matrix(2, 3);
state.set_matrix(0, 0, 1.0);
state.set_matrix(0, 1, 2.0);
state.set_matrix(0, 2, 3.0);
state.set_matrix(1, 0, 4.0);
state.set_matrix(1, 1, 5.0);
state.set_matrix(1, 2, 6.0);
let q = [1.0, 1.0];
let out = state.query(&q);
assert_eq!(out.len(), 3, "query output should have d_value=3 elements");
assert!(
(out[0] - 5.0).abs() < 1e-12,
"S^T*q[0] should be 1+4=5, got {}",
out[0]
);
assert!(
(out[1] - 7.0).abs() < 1e-12,
"S^T*q[1] should be 2+5=7, got {}",
out[1]
);
assert!(
(out[2] - 9.0).abs() < 1e-12,
"S^T*q[2] should be 3+6=9, got {}",
out[2]
);
}
#[test]
fn as_slice_returns_flat_data() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 1.0);
state.set_matrix(1, 1, 4.0);
let s = state.as_slice();
assert_eq!(s.len(), 4, "flat slice should have 4 elements");
assert!((s[0] - 1.0).abs() < 1e-12, "s[0] should be 1.0");
assert!((s[3] - 4.0).abs() < 1e-12, "s[3] should be 4.0");
}
#[test]
fn scale_per_row_applies_different_factors() {
let mut state = AttentionState::new_matrix(2, 3);
state.set_matrix(0, 0, 2.0);
state.set_matrix(0, 1, 4.0);
state.set_matrix(0, 2, 6.0);
state.set_matrix(1, 0, 3.0);
state.set_matrix(1, 1, 5.0);
state.set_matrix(1, 2, 7.0);
state.scale_per_row(&[0.5, 2.0]);
assert!(
(state.get_matrix(0, 0) - 1.0).abs() < 1e-12,
"row 0 scaled by 0.5: 2*0.5=1"
);
assert!(
(state.get_matrix(0, 2) - 3.0).abs() < 1e-12,
"row 0 scaled by 0.5: 6*0.5=3"
);
assert!(
(state.get_matrix(1, 0) - 6.0).abs() < 1e-12,
"row 1 scaled by 2: 3*2=6"
);
assert!(
(state.get_matrix(1, 2) - 14.0).abs() < 1e-12,
"row 1 scaled by 2: 7*2=14"
);
}
#[test]
fn scale_vector_state() {
let mut state = AttentionState::new_vector(3);
if let AttentionState::Vector(v) = &mut state {
v[0] = 2.0;
v[1] = -4.0;
v[2] = 6.0;
}
state.scale(0.5);
let s = state.as_slice();
assert!((s[0] - 1.0).abs() < 1e-12, "scaled v[0] should be 1.0");
assert!((s[1] - (-2.0)).abs() < 1e-12, "scaled v[1] should be -2.0");
assert!((s[2] - 3.0).abs() < 1e-12, "scaled v[2] should be 3.0");
}
}