/// An n-dimensional array stored as a flat element buffer plus a shape.
///
/// [`Tensor::new`] checks that the product of the `shape` entries equals
/// `data.len()`. Both fields are public, so code that mutates them directly is
/// responsible for keeping them consistent.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<T> {
/// Size of each dimension; an empty shape has an element count of one.
pub shape: Vec<usize>,
/// The elements, stored as a single flat vector.
pub data: Vec<T>,
}
impl<T> Tensor<T> {
pub fn new(shape: impl Into<Vec<usize>>, data: Vec<T>) -> Self {
let shape = shape.into();
assert_eq!(
shape.iter().product::<usize>(),
data.len(),
"shape {:?} is incompatible with {} data elements",
shape,
data.len()
);
Self { shape, data }
}
pub fn update(&mut self, mut other: Tensor<T>) {
assert_eq!(self.shape, other.shape, "shape mismatch");
std::mem::swap(&mut self.data, &mut other.data);
}
}
/// Pairs a value with a second value of the same type that holds its gradient.
#[derive(Debug, Clone)]
pub struct WithGrad<T> {
/// The current value.
pub value: T,
/// The gradient associated with `value`; it has the same type as `value`.
pub grad: T,
}
/// Element-wise addition of two `f32` tensors, together with its backward function.
///
/// Returns the sum tensor and a closure. Given the gradient flowing into the
/// output, the closure returns the gradients for `a` and `b`; for addition each
/// is a copy of the incoming gradient's data carrying the operand's shape.
///
/// # Panics
///
/// Panics if the shapes of `a.value` and `b.value` differ. The returned closure
/// panics (inside `Tensor::new`) if the gradient it receives has a different
/// number of elements than the operands' shape implies.
pub fn add_tensor<'a>(
    a: &'a WithGrad<Tensor<f32>>,
    b: &'a WithGrad<Tensor<f32>>
) -> (Tensor<f32>, impl Fn(&Tensor<f32>) -> (Tensor<f32>, Tensor<f32>) + 'a) {
    let lhs = &a.value;
    let rhs = &b.value;
    assert_eq!(lhs.shape, rhs.shape);

    let mut sums = Vec::with_capacity(lhs.data.len());
    for (x, y) in lhs.data.iter().zip(rhs.data.iter()) {
        sums.push(x + y);
    }
    let out = Tensor::new(lhs.shape.clone(), sums);

    // The closure owns copies of the shapes; it does not read the operands again.
    let shape_a = lhs.shape.clone();
    let shape_b = rhs.shape.clone();
    let back = move |grad_output: &Tensor<f32>| {
        let grad_a = Tensor::new(shape_a.clone(), grad_output.data.clone());
        let grad_b = Tensor::new(shape_b.clone(), grad_output.data.clone());
        (grad_a, grad_b)
    };
    (out, back)
}
/// Scalar addition together with its backward function.
///
/// Returns `a.value + b.value` and a function that maps the gradient of the
/// output to the pair of gradients for `a` and `b`. Addition passes the
/// incoming gradient through to both operands unchanged.
pub fn add(a: &WithGrad<f32>, b: &WithGrad<f32>) -> (f32, impl Fn(f32) -> (f32, f32)) {
    // The backward pass needs nothing from the operands, so a plain fn suffices.
    fn backward(upstream: f32) -> (f32, f32) {
        (upstream, upstream)
    }
    (a.value + b.value, backward)
}
/// Scalar multiplication together with its backward function.
///
/// Returns `a.value * b.value` and a closure that maps the gradient of the
/// output to the pair `(grad_a, grad_b)`, where `grad_a = grad_output * b.value`
/// and `grad_b = grad_output * a.value`.
pub fn mul(a: &WithGrad<f32>, b: &WithGrad<f32>) -> (f32, impl Fn(f32) -> (f32, f32)) {
    // Copy the operands out before building the closure. Referring to `a.value`
    // / `b.value` inside a `move` closure captures the references themselves,
    // which ties the closure to lifetimes the return type does not name.
    let a_val = a.value;
    let b_val = b.value;
    let back = move |grad_output: f32| (grad_output * b_val, grad_output * a_val);
    (a_val * b_val, back)
}
pub fn sgd(w: &mut WithGrad<Tensor<f64>>, lr: f64) {
for (w_i, g_i) in w.value.data.iter_mut().zip(&w.grad.data) {
*w_i -= lr * *g_i;
}
for g_i in &mut w.grad.data {
*g_i = 0.0;
}
}
/// Builds a `Tensor` from a (possibly nested) bracketed list of literals.
///
/// * `tensor!(3.0)` produces a scalar tensor with an empty shape.
/// * `tensor!([1.0, -2.0, 3.0])` produces a tensor of shape `[3]`.
/// * `tensor!([[1, 2], [3, 4]])` produces a tensor of shape `[2, 2]`; the rows
///   are concatenated in the order written.
///
/// Negative literals are accepted both as scalars and inside rows. An empty
/// list `[]` is not accepted.
///
/// # Panics
///
/// The expansion panics at run time if the sub-lists of a nested literal do not
/// all have the same shape.
#[macro_export]
macro_rules! tensor {
    // A single literal: scalar tensor with an empty shape.
    ($lit:literal) => {
        $crate::tensors::Tensor::new(Vec::<usize>::new(), vec![$lit])
    };
    // A flat row of literals. This rule is needed for negative numbers: `-1.0`
    // is two token trees, so the generic `tt` rule below cannot match it, while
    // the `literal` fragment accepts a leading minus sign.
    ([ $( $lit:literal ),+ $(,)? ]) => {{
        let data = vec![ $( $lit ),+ ];
        $crate::tensors::Tensor::new(vec![data.len()], data)
    }};
    // A list of nested lists: build each child, check they agree, then flatten.
    ([ $( $inner:tt ),+ $(,)? ]) => {{
        // `$crate::` keeps the recursive call resolvable when the macro is
        // invoked by path from another crate without importing it by name.
        let children = vec![ $( $crate::tensor!($inner) ),+ ];
        let first_shape = &children[0].shape;
        assert!(children.iter().all(|c| c.shape == *first_shape),
            "ragged tensor literal (rows have mismatched shapes)");
        let mut shape = vec![children.len()];
        shape.extend_from_slice(first_shape);
        let mut data = Vec::with_capacity(children.len() * children[0].data.len());
        for c in children { data.extend(c.data); }
        $crate::tensors::Tensor::new(shape, data)
    }};
}
/// Parses a JSON-style nested array of numbers into a `Tensor<f64>`.
///
/// Accepted input is either a single number (giving a tensor with an empty
/// shape) or a bracketed, comma-separated list whose items are all numbers or
/// all lists of identical shape, e.g. `[[1, 2, 3], [4, 5, 6]]` gives shape
/// `[2, 3]`. ASCII whitespace between tokens is ignored. Elements are stored in
/// the order they appear in the text.
///
/// # Errors
///
/// Returns a static message instead of panicking on malformed input:
///
/// * `"unexpected EOF"` – the input ends before the value is complete.
/// * `"invalid char"` – a byte that cannot start a token.
/// * `"bad number"` – a numeric token that does not parse as `f64`.
/// * `"comma where value expected"` – a leading or doubled comma.
/// * `"two values without comma"` – adjacent items with no separator.
/// * `"trailing comma"` – a comma directly before `]`.
/// * `"unexpected ']'"` – a closing bracket with nothing open.
/// * `"ragged tensor"` – sibling lists of different lengths, or numbers mixed
///   with lists at the same nesting level.
/// * `"trailing characters"` – non-whitespace after the top-level value.
pub fn parse_tensor(json: &str) -> Result<Tensor<f64>, &'static str> {
    enum Tok {
        LBrack,
        RBrack,
        Comma,
        Num(f64),
    }

    /// Advances `i` past any ASCII whitespace.
    fn skip_whitespace(i: &mut usize, s: &[u8]) {
        while *i < s.len() && s[*i].is_ascii_whitespace() {
            *i += 1;
        }
    }

    /// Reads the next token starting at `*i`, leaving `*i` just past it.
    fn next_token(i: &mut usize, s: &[u8]) -> Result<Tok, &'static str> {
        skip_whitespace(i, s);
        let c = *s.get(*i).ok_or("unexpected EOF")?;
        let start = *i;
        *i += 1;
        Ok(match c {
            b'[' => Tok::LBrack,
            b']' => Tok::RBrack,
            b',' => Tok::Comma,
            b'-' | b'0'..=b'9' => {
                // Greedily take every byte that can occur in a float literal and
                // let `f64::from_str` decide whether the result is well-formed.
                while *i < s.len()
                    && matches!(s[*i], b'0'..=b'9' | b'.' | b'e' | b'E' | b'+' | b'-')
                {
                    *i += 1;
                }
                let num = std::str::from_utf8(&s[start..*i])
                    .ok()
                    .and_then(|text| text.parse::<f64>().ok())
                    .ok_or("bad number")?;
                Tok::Num(num)
            }
            _ => return Err("invalid char"),
        })
    }

    let bytes = json.as_bytes();
    let mut idx = 0;
    let mut data: Vec<f64> = Vec::new();
    // dims[d] is the length established by the first list closed at depth d
    // (0 = outermost); every later list at that depth must match it.
    let mut dims: Vec<Option<usize>> = Vec::new();
    // One running item count per currently open bracket.
    let mut open_counts: Vec<usize> = Vec::new();
    // Nesting level (number of enclosing brackets) at which numbers appear.
    let mut leaf_level: Option<usize> = None;
    let mut expect_val = true;

    loop {
        match next_token(&mut idx, bytes)? {
            Tok::LBrack => {
                if !expect_val {
                    return Err("two values without comma");
                }
                // `expect_val` stays true: the new list may start with a value.
                open_counts.push(0);
            }
            Tok::RBrack => {
                let count = open_counts.pop().ok_or("unexpected ']'")?;
                // An empty list `[]` also arrives here with `expect_val` set,
                // so only complain when at least one item preceded the `]`.
                if expect_val && count != 0 {
                    return Err("trailing comma");
                }
                let depth = open_counts.len();
                while dims.len() <= depth {
                    dims.push(None);
                }
                if *dims[depth].get_or_insert(count) != count {
                    return Err("ragged tensor");
                }
                match open_counts.last_mut() {
                    // The closed list counts as one item of its parent.
                    Some(parent) => {
                        *parent += 1;
                        expect_val = false;
                    }
                    // The outermost list just closed: the value is complete.
                    None => break,
                }
            }
            Tok::Comma => {
                if expect_val {
                    return Err("comma where value expected");
                }
                expect_val = true;
            }
            Tok::Num(n) => {
                if !expect_val {
                    return Err("two values without comma");
                }
                let level = open_counts.len();
                if *leaf_level.get_or_insert(level) != level {
                    return Err("ragged tensor");
                }
                data.push(n);
                match open_counts.last_mut() {
                    Some(count) => {
                        *count += 1;
                        expect_val = false;
                    }
                    // A bare number outside any brackets is a complete scalar.
                    None => break,
                }
            }
        }
    }

    skip_whitespace(&mut idx, bytes);
    if idx != bytes.len() {
        return Err("trailing characters");
    }

    let shape = dims
        .into_iter()
        .collect::<Option<Vec<usize>>>()
        .ok_or("ragged tensor")?;
    // Numbers must sit exactly at the innermost level; this rejects inputs such
    // as `[1, []]` or `[[1], 2]` whose per-depth counts happen to agree.
    if leaf_level.map_or(false, |level| level != shape.len()) {
        return Err("ragged tensor");
    }
    // Final guard so `Tensor::new` can never panic on untrusted input.
    if shape.iter().product::<usize>() != data.len() {
        return Err("ragged tensor");
    }
    Ok(Tensor::new(shape, data))
}