use alloc::vec::Vec;
use super::tensor::Tensor;
use core::ops::{Add, Mul};
/// Multiplies two matrices held in row-major `Tensor`s, returning an
/// `m × p` result tensor.
///
/// `zero` supplies the additive identity for `T` (the accumulator seed),
/// since `T` is not required to implement `Default` or a `Zero` trait.
///
/// # Panics
/// Panics if either argument is not a 2-D tensor, or if the inner
/// dimensions disagree (`a` is `m × n`, `b` must be `n × p`).
pub fn matmul<T>(a: &Tensor<T>, b: &Tensor<T>, zero: T) -> Tensor<T>
where
    T: Add<Output = T> + Mul<Output = T> + Clone,
{
    assert!(a.is_matrix(), "First argument must be a matrix");
    assert!(b.is_matrix(), "Second argument must be a matrix");
    let (m, n) = a.matrix_dims();
    let (n2, p) = b.matrix_dims();
    assert_eq!(
        n, n2,
        "Matrix dimensions don't match: ({}, {}) × ({}, {})",
        m, n, n2, p
    );
    // Push each finished dot product in row-major order instead of
    // pre-filling with `zero` — avoids cloning `zero` m*p times just to
    // overwrite every slot.
    let mut result = Vec::with_capacity(m * p);
    for i in 0..m {
        for j in 0..p {
            let mut sum = zero.clone();
            for k in 0..n {
                // Row-major indexing: a[i][k] * b[k][j].
                let a_val = a.data[i * n + k].clone();
                let b_val = b.data[k * p + j].clone();
                sum = sum + (a_val * b_val);
            }
            result.push(sum);
        }
    }
    Tensor::new(result, vec![m, p])
}
/// Returns the transpose of a 2-D tensor: element `(i, j)` of the input
/// becomes element `(j, i)` of the output, and the shape flips from
/// `[rows, cols]` to `[cols, rows]`.
///
/// # Panics
/// Panics if the tensor is not 2-D.
pub fn transpose<T: Clone>(tensor: &Tensor<T>) -> Tensor<T> {
    assert!(tensor.is_matrix(), "Can only transpose 2D tensors");
    let (rows, cols) = tensor.matrix_dims();
    // Walk the source column-by-column so the output lands in row-major
    // order for the transposed shape.
    let data: Vec<T> = (0..cols)
        .flat_map(|j| (0..rows).map(move |i| tensor.data[i * cols + j].clone()))
        .collect();
    Tensor::new(data, vec![cols, rows])
}
/// Element-wise rectified linear unit: each element strictly greater than
/// `zero` is kept, everything else (including incomparable values such as
/// NaN-like elements, where `>` is false) is replaced by `zero`.
///
/// The shape of the input tensor is preserved.
pub fn relu<T>(tensor: &Tensor<T>, zero: T) -> Tensor<T>
where
    T: Clone + PartialOrd,
{
    let mut data = Vec::with_capacity(tensor.data.len());
    for value in tensor.data.iter() {
        let kept = if *value > zero {
            value.clone()
        } else {
            zero.clone()
        };
        data.push(kept);
    }
    Tensor::new(data, tensor.shape.clone())
}
/// Multiplies every element of the tensor by `scalar`, returning a new
/// tensor with the same shape.
pub fn scale<T>(tensor: &Tensor<T>, scalar: T) -> Tensor<T>
where
    T: Mul<Output = T> + Clone,
{
    let mut data = Vec::with_capacity(tensor.data.len());
    for element in tensor.data.iter() {
        data.push(element.clone() * scalar.clone());
    }
    Tensor::new(data, tensor.shape.clone())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ScalarF4E4;
// NOTE(review): expected values below assume ScalarF4E4 represents these
// small integers exactly — TODO confirm against the scalar type's precision.
#[test]
fn test_matmul_2x2() {
// a = [[1, 2], [3, 4]] in row-major order.
let a = Tensor::new(
vec![
ScalarF4E4::from(1.0),
ScalarF4E4::from(2.0),
ScalarF4E4::from(3.0),
ScalarF4E4::from(4.0),
],
vec![2, 2],
);
// b = [[5, 6], [7, 8]] in row-major order.
let b = Tensor::new(
vec![
ScalarF4E4::from(5.0),
ScalarF4E4::from(6.0),
ScalarF4E4::from(7.0),
ScalarF4E4::from(8.0),
],
vec![2, 2],
);
let c = matmul(&a, &b, ScalarF4E4::ZERO);
// a × b = [[19, 22], [43, 50]].
assert_eq!(c.data[0].to_f64(), 19.0); assert_eq!(c.data[1].to_f64(), 22.0); assert_eq!(c.data[2].to_f64(), 43.0); assert_eq!(c.data[3].to_f64(), 50.0); }
#[test]
fn test_transpose() {
// a = [[1, 2, 3], [4, 5, 6]] (2 × 3).
let a = Tensor::new(
vec![
ScalarF4E4::from(1.0),
ScalarF4E4::from(2.0),
ScalarF4E4::from(3.0),
ScalarF4E4::from(4.0),
ScalarF4E4::from(5.0),
ScalarF4E4::from(6.0),
],
vec![2, 3],
);
let at = transpose(&a);
// Transpose is 3 × 2: [[1, 4], [2, 5], [3, 6]] — first four elements checked.
assert_eq!(at.shape, vec![3, 2]);
assert_eq!(at.data[0].to_f64(), 1.0);
assert_eq!(at.data[1].to_f64(), 4.0);
assert_eq!(at.data[2].to_f64(), 2.0);
assert_eq!(at.data[3].to_f64(), 5.0);
}
#[test]
fn test_relu() {
// Mixed-sign 1-D input; negatives should clamp to zero.
let a = Tensor::new(
vec![
ScalarF4E4::from(-1.0),
ScalarF4E4::from(2.0),
ScalarF4E4::from(-3.0),
ScalarF4E4::from(4.0),
],
vec![4],
);
let r = relu(&a, ScalarF4E4::ZERO);
assert_eq!(r.data[0].to_f64(), 0.0); assert_eq!(r.data[1].to_f64(), 2.0); assert_eq!(r.data[2].to_f64(), 0.0); assert_eq!(r.data[3].to_f64(), 4.0); }
}