//! Minimal autodiff building blocks: each forward op returns its output
//! together with a closure that backpropagates an upstream gradient.

use crate::tensors::{Tensor, WithGrad};
/// Element-wise ReLU: `max(x, 0)`.
///
/// Returns the output tensor and a backward closure mapping the upstream
/// gradient to the gradient with respect to `input`.
pub fn relu(
    input: &WithGrad<Tensor<f64>>,
) -> (Tensor<f64>, impl Fn(&Tensor<f64>) -> Tensor<f64>) {
let shape = input.value.shape.clone();
let out_data = input
.value
.data
.iter()
.map(|&x| x.max(0.0))
.collect();
let out = Tensor::new(shape.clone(), out_data);
    // Clone the input values so the backward closure owns its captures;
    // borrowing `input` here would tie the closure to a lifetime the
    // returned `impl Fn` type does not capture, and the function would
    // not compile.
    let in_data = input.value.data.clone();
    let back = move |grad_output: &Tensor<f64>| {
        let grad_data = in_data
            .iter()
            .zip(&grad_output.data)
            // Gradient flows only where the input was strictly positive.
            .map(|(&x, &g)| if x > 0.0 { g } else { 0.0 })
            .collect();
        Tensor::new(shape.clone(), grad_data)
    };
(out, back)
}
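
// A minimal smoke test for `relu`. A sketch only: it assumes `WithGrad`
// is just the `value`/`grad` pair used above and can be built with a
// struct literal here.
#[test]
fn relu_masks_gradient_where_input_is_nonpositive() {
    let x = WithGrad {
        value: Tensor::new(vec![4], vec![-2.0, -0.5, 0.0, 3.0]),
        grad: Tensor::new(vec![4], vec![0.0; 4]),
    };
    let (y, back) = relu(&x);
    assert_eq!(y.data, vec![0.0, 0.0, 0.0, 3.0]);
    // With an all-ones upstream gradient, only the strictly positive
    // input slot lets gradient through.
    let dx = back(&Tensor::new(vec![4], vec![1.0; 4]));
    assert_eq!(dx.data, vec![0.0, 0.0, 0.0, 1.0]);
}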
/// Matrix product of two row-major 2-D tensors: `(m, k) x (k, n) -> (m, n)`.
///
/// Returns the output and a backward closure computing
/// `dA = dOut * B^T` and `dB = A^T * dOut` from the upstream gradient.
/// The backward closure owns clones of the operands, so no lifetime
/// parameter is needed.
pub fn matmul(
    a: &WithGrad<Tensor<f64>>,
    b: &WithGrad<Tensor<f64>>,
) -> (Tensor<f64>, impl Fn(&Tensor<f64>) -> (Tensor<f64>, Tensor<f64>)) {
    let (m, k1) = (a.value.shape[0], a.value.shape[1]);
    let (k2, n) = (b.value.shape[0], b.value.shape[1]);
    assert_eq!(k1, k2, "matmul shape mismatch: {}x{} vs {}x{}", m, k1, k2, n);
    // Naive O(m * n * k) triple loop over the row-major buffers.
    let mut out_data = vec![0.0; m * n];
for i in 0..m {
for j in 0..n {
for k in 0..k1 {
out_data[i * n + j] += a.value.data[i * k1 + k] * b.value.data[k * n + j];
}
}
}
let out = Tensor::new(vec![m, n], out_data);
    // Clone shapes and data up front so the backward closure owns everything
    // it captures and can outlive `a` and `b`.
    let a_shape = a.value.shape.clone();
    let b_shape = b.value.shape.clone();
    let a_data = a.value.data.clone();
    let b_data = b.value.data.clone();
    let back = move |grad_output: &Tensor<f64>| {
        // dA[i, k] = sum_j dOut[i, j] * B[k, j], i.e. dA = dOut * B^T.
        let mut da = vec![0.0; m * k1];
for i in 0..m {
for k in 0..k1 {
for j in 0..n {
da[i * k1 + k] += grad_output.data[i * n + j] * b_data[k * n + j];
}
}
}
        // dB[k, j] = sum_i A[i, k] * dOut[i, j], i.e. dB = A^T * dOut.
        let mut db = vec![0.0; k1 * n];
for k in 0..k1 {
for j in 0..n {
for i in 0..m {
db[k * n + j] += a_data[i * k1 + k] * grad_output.data[i * n + j];
}
}
}
(
Tensor::new(a_shape.clone(), da),
Tensor::new(b_shape.clone(), db),
)
};
(out, back)
}
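
// A hand-checked sketch test for `matmul`: [[1,2],[3,4]] * [[5,6],[7,8]],
// and, for an all-ones upstream gradient, dA = dOut * B^T (row sums of B)
// and dB = A^T * dOut (column sums of A). Assumes `WithGrad` struct-literal
// construction as in the test above.
#[test]
fn matmul_forward_and_backward_match_hand_computation() {
    let a = WithGrad {
        value: Tensor::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]),
        grad: Tensor::new(vec![2, 2], vec![0.0; 4]),
    };
    let b = WithGrad {
        value: Tensor::new(vec![2, 2], vec![5.0, 6.0, 7.0, 8.0]),
        grad: Tensor::new(vec![2, 2], vec![0.0; 4]),
    };
    let (out, back) = matmul(&a, &b);
    assert_eq!(out.data, vec![19.0, 22.0, 43.0, 50.0]);
    let (da, db) = back(&Tensor::new(vec![2, 2], vec![1.0; 4]));
    assert_eq!(da.data, vec![11.0, 15.0, 11.0, 15.0]);
    assert_eq!(db.data, vec![4.0, 4.0, 6.0, 6.0]);
}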
/// Mean squared error: `mean((prediction - target)^2)`.
///
/// Returns the scalar loss and a backward closure mapping an upstream
/// scalar gradient to the gradient with respect to `prediction`. The
/// closure owns clones of its inputs, so no lifetime parameter is needed.
pub fn mse_loss(
    prediction: &WithGrad<Tensor<f64>>,
    target: &Tensor<f64>,
) -> (f64, impl Fn(f64) -> Tensor<f64>) {
    assert_eq!(
        prediction.value.shape, target.shape,
        "mse_loss shape mismatch"
    );
    let n = prediction.value.data.len();
    // Mean of the elementwise squared differences.
    let loss = prediction
.value
.data
.iter()
.zip(&target.data)
.map(|(&y, &t)| (y - t).powi(2))
.sum::<f64>() / n as f64;
    // Clone what the backward pass needs so the closure owns its captures.
    let shape = prediction.value.shape.clone();
    let pred_data = prediction.value.data.clone();
    let tgt_data = target.data.clone();
    let back = move |grad_out: f64| {
        // d/dy of mean((y - t)^2) is 2 (y - t) / n, scaled by `grad_out`.
        let grad_vec: Vec<f64> = pred_data
.iter()
.zip(&tgt_data)
.map(|(&y, &t)| 2.0 * (y - t) * grad_out / n as f64)
.collect();
Tensor::new(shape.clone(), grad_vec)
};
(loss, back)
}
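
// A sketch test for `mse_loss` using exactly representable values:
// predictions [1, 2] vs targets [0, 2] give a mean loss of 0.5, and a
// unit upstream gradient gives 2 (y - t) / n = [1, 0].
#[test]
fn mse_loss_matches_hand_computation() {
    let pred = WithGrad {
        value: Tensor::new(vec![2], vec![1.0, 2.0]),
        grad: Tensor::new(vec![2], vec![0.0; 2]),
    };
    let target = Tensor::new(vec![2], vec![0.0, 2.0]);
    let (loss, back) = mse_loss(&pred, &target);
    assert_eq!(loss, 0.5);
    let grad = back(1.0);
    assert_eq!(grad.data, vec![1.0, 0.0]);
}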
/// One step of plain stochastic gradient descent: `w -= lr * grad`,
/// followed by zeroing the gradient buffer for the next iteration.
pub fn sgd(w: &mut WithGrad<Tensor<f64>>, lr: f64) {
    for (param, grad) in w.value.data.iter_mut().zip(&w.grad.data) {
        *param -= lr * *grad;
    }
    // Reset gradients so the next backward pass starts from zero.
    for grad in &mut w.grad.data {
        *grad = 0.0;
    }
}
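
// A sketch test for `sgd` with exact binary fractions so the float
// comparison is safe: 1.0 - 0.5 * 0.5 == 0.75, and the gradient buffer
// is cleared after the step.
#[test]
fn sgd_steps_against_gradient_and_clears_it() {
    let mut w = WithGrad {
        value: Tensor::new(vec![1], vec![1.0]),
        grad: Tensor::new(vec![1], vec![0.5]),
    };
    sgd(&mut w, 0.5);
    assert_eq!(w.value.data, vec![0.75]);
    assert_eq!(w.grad.data, vec![0.0]);
}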