#[derive(Debug, Clone)]
pub struct Tensor {
pub data: Vec<f32>,
pub grad: Option<Vec<f32>>,
pub requires_grad: bool,
shape: Vec<usize>,
}
impl Tensor {
pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
debug_assert_eq!(
data.len(),
shape.iter().product::<usize>(),
"data length must match shape product"
);
Self {
data,
grad: None,
requires_grad: false,
shape,
}
}
pub fn requires_grad(mut self) -> Self {
self.requires_grad = true;
self
}
pub fn zeros(shape: &[usize]) -> Self {
let n = shape.iter().product();
Self::new(vec![0.0f32; n], shape.to_vec())
}
pub fn ones(shape: &[usize]) -> Self {
let n = shape.iter().product();
Self::new(vec![1.0f32; n], shape.to_vec())
}
pub fn from_scalar(v: f32) -> Self {
Self::new(vec![v], vec![1])
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn numel(&self) -> usize {
self.data.len()
}
pub fn zero_grad(&mut self) {
match self.grad.as_mut() {
Some(g) => g.iter_mut().for_each(|x| *x = 0.0),
None => self.grad = Some(vec![0.0f32; self.data.len()]),
}
}
pub fn accumulate_grad(&mut self, grad: &[f32]) {
assert_eq!(
grad.len(),
self.data.len(),
"gradient length must match tensor length"
);
match self.grad.as_mut() {
Some(g) => {
for (dst, src) in g.iter_mut().zip(grad.iter()) {
*dst += src;
}
}
None => {
self.grad = Some(grad.to_vec());
}
}
}
pub fn detach(&self) -> Self {
Self {
data: self.data.clone(),
grad: None,
requires_grad: false,
shape: self.shape.clone(),
}
}
pub fn add(&self, other: &Tensor) -> Tensor {
assert_eq!(
self.shape, other.shape,
"add: shapes must match ({:?} vs {:?})",
self.shape, other.shape
);
let data: Vec<f32> = self
.data
.iter()
.zip(other.data.iter())
.map(|(a, b)| a + b)
.collect();
Tensor::new(data, self.shape.clone())
}
pub fn mul(&self, other: &Tensor) -> Tensor {
assert_eq!(
self.shape, other.shape,
"mul: shapes must match ({:?} vs {:?})",
self.shape, other.shape
);
let data: Vec<f32> = self
.data
.iter()
.zip(other.data.iter())
.map(|(a, b)| a * b)
.collect();
Tensor::new(data, self.shape.clone())
}
pub fn matmul(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Tensor {
assert_eq!(
self.data.len(),
m * k,
"matmul: self must have m*k elements"
);
assert_eq!(
other.data.len(),
k * n,
"matmul: other must have k*n elements"
);
let mut out = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for l in 0..k {
sum += self.data[i * k + l] * other.data[l * n + j];
}
out[i * n + j] = sum;
}
}
Tensor::new(out, vec![m, n])
}
pub fn relu(&self) -> Tensor {
let data: Vec<f32> = self.data.iter().map(|&x| x.max(0.0)).collect();
Tensor::new(data, self.shape.clone())
}
pub fn sigmoid(&self) -> Tensor {
let data: Vec<f32> = self
.data
.iter()
.map(|&x| 1.0 / (1.0 + (-x).exp()))
.collect();
Tensor::new(data, self.shape.clone())
}
pub fn softmax(&self) -> Tensor {
let last_dim = self.shape.last().copied().unwrap_or(self.data.len());
let batch = self.data.len() / last_dim.max(1);
let mut data = self.data.clone();
for b in 0..batch {
let start = b * last_dim;
let slice = &mut data[start..start + last_dim];
let max_val = slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for v in slice.iter_mut() {
*v = (*v - max_val).exp();
sum += *v;
}
if sum > 0.0 {
for v in slice.iter_mut() {
*v /= sum;
}
}
}
Tensor::new(data, self.shape.clone())
}
pub fn mean(&self) -> Tensor {
let n = self.data.len() as f32;
let m = self.data.iter().sum::<f32>() / n.max(1.0);
Tensor::from_scalar(m)
}
pub fn sum(&self) -> Tensor {
let s = self.data.iter().sum::<f32>();
Tensor::from_scalar(s)
}
pub fn neg(&self) -> Tensor {
let data: Vec<f32> = self.data.iter().map(|&x| -x).collect();
Tensor::new(data, self.shape.clone())
}
pub fn log(&self) -> Tensor {
let data: Vec<f32> = self.data.iter().map(|&x| x.ln()).collect();
Tensor::new(data, self.shape.clone())
}
}
pub mod backward {
use super::Tensor;
pub fn sum_backward(grad_output: f32, input: &Tensor) -> Vec<f32> {
vec![grad_output; input.data.len()]
}
pub fn mean_backward(grad_output: f32, input: &Tensor) -> Vec<f32> {
let n = input.data.len() as f32;
let scale = grad_output / n.max(1.0);
vec![scale; input.data.len()]
}
pub fn relu_backward(grad_output: &[f32], input: &Tensor) -> Vec<f32> {
input
.data
.iter()
.zip(grad_output.iter())
.map(|(&x, &g)| if x > 0.0 { g } else { 0.0 })
.collect()
}
pub fn sigmoid_backward(grad_output: &[f32], output: &Tensor) -> Vec<f32> {
output
.data
.iter()
.zip(grad_output.iter())
.map(|(&s, &g)| g * s * (1.0 - s))
.collect()
}
pub fn linear_backward(
grad_output: &[f32],
input: &[f32],
weights: &[f32],
m: usize,
k: usize,
n: usize,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let mut dx = vec![0.0f32; m * k];
for i in 0..m {
for j in 0..k {
let mut s = 0.0f32;
for l in 0..n {
s += grad_output[i * n + l] * weights[j * n + l];
}
dx[i * k + j] = s;
}
}
let mut dw = vec![0.0f32; k * n];
for i in 0..k {
for j in 0..n {
let mut s = 0.0f32;
for l in 0..m {
s += input[l * k + i] * grad_output[l * n + j];
}
dw[i * n + j] = s;
}
}
let mut db = vec![0.0f32; n];
for i in 0..m {
for j in 0..n {
db[j] += grad_output[i * n + j];
}
}
(dx, dw, db)
}
pub fn cross_entropy_backward(logits: &[f32], target_id: u32) -> Vec<f32> {
let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut probs: Vec<f32> = logits.iter().map(|&x| (x - max_val).exp()).collect();
let sum: f32 = probs.iter().sum();
let inv_sum = if sum > 0.0 { 1.0 / sum } else { 1.0 };
for p in probs.iter_mut() {
*p *= inv_sum;
}
let idx = target_id as usize;
if idx < probs.len() {
probs[idx] -= 1.0;
}
probs
}
}
pub fn cross_entropy_loss(logits: &[f32], target_id: u32) -> f32 {
let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&x| (x - max_val).exp()).collect();
let sum: f32 = exps.iter().sum();
let idx = target_id as usize;
if idx >= logits.len() || sum == 0.0 {
return f32::INFINITY;
}
let log_prob = exps[idx].ln() - sum.ln();
-log_prob
}
pub fn sequence_perplexity(pairs: &[(Vec<f32>, u32)]) -> f32 {
if pairs.is_empty() {
return f32::INFINITY;
}
let avg_nll: f32 = pairs
.iter()
.map(|(logits, target)| cross_entropy_loss(logits, *target))
.sum::<f32>()
/ pairs.len() as f32;
avg_nll.exp()
}
#[cfg(test)]
mod tests {
use super::*;
const EPS: f32 = 1e-5;
fn approx_eq(a: f32, b: f32) -> bool {
(a - b).abs() < EPS
}
#[test]
fn test_tensor_add() {
let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]);
let b = Tensor::new(vec![4.0, 5.0, 6.0], vec![3]);
let c = a.add(&b);
assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
}
#[test]
fn test_tensor_mul() {
let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]);
let b = Tensor::new(vec![2.0, 3.0, 4.0], vec![3]);
let c = a.mul(&b);
assert_eq!(c.data, vec![2.0, 6.0, 12.0]);
}
#[test]
fn test_tensor_matmul() {
let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let b = Tensor::new(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
let c = a.matmul(&b, 2, 3, 2);
assert!(approx_eq(c.data[0], 58.0));
assert!(approx_eq(c.data[1], 64.0));
assert!(approx_eq(c.data[2], 139.0));
assert!(approx_eq(c.data[3], 154.0));
assert_eq!(c.shape(), &[2, 2]);
}
#[test]
fn test_tensor_relu_forward() {
let t = Tensor::new(vec![-2.0, -0.5, 0.0, 0.5, 2.0], vec![5]);
let r = t.relu();
assert_eq!(r.data, vec![0.0, 0.0, 0.0, 0.5, 2.0]);
}
#[test]
fn test_tensor_sigmoid_forward() {
let t = Tensor::from_scalar(0.0);
let s = t.sigmoid();
assert!(approx_eq(s.data[0], 0.5), "sigmoid(0) must be 0.5");
let large = Tensor::from_scalar(100.0);
let sl = large.sigmoid();
assert!(sl.data[0] > 0.99, "sigmoid(100) must be close to 1");
}
#[test]
fn test_tensor_softmax_sums_to_one() {
let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
let s = t.softmax();
let total: f32 = s.data.iter().sum();
assert!(approx_eq(total, 1.0), "softmax must sum to 1, got {total}");
for &p in &s.data {
assert!(
(0.0..=1.0).contains(&p),
"each probability must be in [0,1]"
);
}
}
#[test]
fn test_tensor_mean_scalar() {
let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
let m = t.mean();
assert_eq!(m.shape(), &[1]);
assert!(approx_eq(m.data[0], 2.5));
}
#[test]
fn test_relu_backward_zeros_negatives() {
let input = Tensor::new(vec![-1.0, 0.0, 1.0, 2.0], vec![4]);
let grad_out = vec![1.0f32; 4];
let grad_in = backward::relu_backward(&grad_out, &input);
assert_eq!(grad_in, vec![0.0, 0.0, 1.0, 1.0]);
}
#[test]
fn test_sigmoid_backward_shape() {
let output = Tensor::new(vec![0.5f32; 4], vec![4]);
let grad_out = vec![1.0f32; 4];
let grad_in = backward::sigmoid_backward(&grad_out, &output);
assert_eq!(grad_in.len(), 4);
for &g in &grad_in {
assert!(approx_eq(g, 0.25), "expected 0.25, got {g}");
}
}
#[test]
fn test_linear_backward_shapes() {
let grad_output = vec![1.0f32; 2 * 4];
let input = vec![0.5f32; 2 * 3];
let weights = vec![0.1f32; 3 * 4];
let (dx, dw, db) = backward::linear_backward(&grad_output, &input, &weights, 2, 3, 4);
assert_eq!(dx.len(), 2 * 3, "dL/dx shape mismatch");
assert_eq!(dw.len(), 3 * 4, "dL/dW shape mismatch");
assert_eq!(db.len(), 4, "dL/db shape mismatch");
}
#[test]
fn test_cross_entropy_loss_basic() {
let logits = vec![10.0f32, 0.0, 0.0];
let loss_correct = cross_entropy_loss(&logits, 0);
let loss_wrong = cross_entropy_loss(&logits, 1);
assert!(loss_correct < 0.01, "loss for correct class must be near 0");
assert!(loss_wrong > 5.0, "loss for wrong class must be high");
}
#[test]
fn test_cross_entropy_backward_sums_to_zero() {
let logits = vec![1.0f32, 2.0, 3.0];
let grad = backward::cross_entropy_backward(&logits, 1);
let total: f32 = grad.iter().sum();
assert!(
total.abs() < EPS,
"cross-entropy gradient must sum to ~0, got {total}"
);
}
#[test]
fn test_sequence_perplexity() {
let pairs: Vec<(Vec<f32>, u32)> = (0..3)
.map(|i| {
let mut l = vec![0.0f32; 4];
l[i] = 20.0;
(l, i as u32)
})
.collect();
let ppl = sequence_perplexity(&pairs);
assert!(
ppl < 1.01,
"near-perfect predictor must have perplexity ~1, got {ppl}"
);
assert!(sequence_perplexity(&[]).is_infinite());
}
#[test]
fn test_tensor_accumulate_grad() {
let mut t = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).requires_grad();
t.accumulate_grad(&[0.1, 0.2, 0.3]);
t.accumulate_grad(&[0.1, 0.2, 0.3]);
let grad = t.grad.as_ref().expect("grad must be Some");
assert!(approx_eq(grad[0], 0.2));
assert!(approx_eq(grad[1], 0.4));
assert!(approx_eq(grad[2], 0.6));
}
}