use crate::errors::{NeuroxError, NeuroxResult};
use rand::Rng;
use std::fmt;
#[derive(Clone)]
pub struct Tensor {
pub data: Vec<f32>,
pub rows: usize,
pub cols: usize,
}
impl Tensor {
pub fn new(rows: usize, cols: usize) -> Self {
Self {
data: vec![0.0; rows * cols],
rows,
cols,
}
}
pub fn from_data(data: Vec<f32>, rows: usize, cols: usize) -> Self {
assert_eq!(
data.len(),
rows * cols,
"Data size must match tensor dimensions."
);
Self { data, rows, cols }
}
pub fn random(rows: usize, cols: usize) -> Self {
let mut rng = rand::rng();
let data = (0..rows * cols)
.map(|_| rng.random_range(-1.0..1.0))
.collect();
Self { data, rows, cols }
}
pub fn shape(&self) -> (usize, usize) {
(self.rows, self.cols)
}
pub fn get(&self, r: usize, c: usize) -> f32 {
self.data[r * self.cols + c]
}
pub fn set(&mut self, r: usize, c: usize, v: f32) {
self.data[r * self.cols + c] = v;
}
pub fn map<F>(&self, mut f: F) -> Tensor
where
F: FnMut(f32) -> f32,
{
let d = self.data.iter().map(|&x| f(x)).collect();
Tensor::from_data(d, self.rows, self.cols)
}
pub fn add_row_broadcast(&self, bias: &Tensor) -> NeuroxResult<Tensor> {
if bias.rows != 1 || bias.cols != self.cols {
return Err(NeuroxError::ShapeMismatch(
"bias shape must be (1, cols)".into(),
));
}
let mut out = self.clone();
for i in 0..self.rows {
for j in 0..self.cols {
let idx = i * self.cols + j;
out.data[idx] += bias.data[j];
}
}
Ok(out)
}
pub fn zeros(rows: usize, cols: usize) -> Self {
Self::new(rows, cols)
}
pub fn transpose(&self) -> Tensor {
let mut out = vec![0.0; self.rows * self.cols];
for i in 0..self.rows {
for j in 0..self.cols {
out[j * self.rows + i] = self.get(i, j);
}
}
Tensor::from_data(out, self.cols, self.rows)
}
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut s = String::new();
s.push_str(&format!("Tensor({}, {}) [", self.rows, self.cols));
for i in 0..self.rows.min(3) {
s.push('[');
for j in 0..self.cols.min(6) {
s.push_str(&format!("{:.4}", self.get(i, j)));
if j + 1 < self.cols.min(6) {
s.push_str(", ");
}
}
if self.cols > 6 {
s.push_str(", ...");
}
s.push(']');
if i + 1 < self.rows.min(3) {
s.push_str(", ");
}
}
if self.rows > 3 {
s.push_str(", ...");
}
s.push(']');
write!(f, "{}", s)
}
}