use crate::autograd::BackwardOp;
use crate::storage::Storage;
use crate::Tensor;
use rayon::prelude::*;
use std::sync::Arc;
pub mod activations;
pub mod conv;
pub mod embedding;
pub mod norm;
pub mod pool;
pub mod view;
pub use activations::{sigmoid, softmax, tanh};
pub use conv::conv2d;
pub use embedding::embedding;
pub use norm::{batch_norm2d, layer_norm};
pub use pool::max_pool2d;
pub use view::ReshapeBackward;
#[derive(Debug)]
pub struct MulBackward {
pub lhs: Tensor,
pub rhs: Tensor,
}
impl BackwardOp for MulBackward {
fn backward(&self, grad: &Tensor) {
if self.lhs.requires_grad() {
let grad_lhs = crate::ops::mul(grad, &self.rhs);
self.lhs.accumulate_grad(&grad_lhs);
self.lhs.backward_step();
}
if self.rhs.requires_grad() {
let grad_rhs = crate::ops::mul(grad, &self.lhs);
self.rhs.accumulate_grad(&grad_rhs);
self.rhs.backward_step();
}
}
}
pub fn mul(lhs: &Tensor, rhs: &Tensor) -> Tensor {
if lhs.shape() != rhs.shape() {
panic!(
"Mul: shapes mismatch {:?} vs {:?}",
lhs.shape(),
rhs.shape()
);
}
let lhs_guard = lhs.data();
let rhs_guard = rhs.data();
let lhs_data = &*lhs_guard;
let rhs_data = &*rhs_guard;
let result_data: Vec<f32> = lhs_data
.par_iter()
.zip(rhs_data.par_iter())
.map(|(a, b)| a * b)
.collect();
let storage = Storage::new(result_data);
let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
if lhs.requires_grad() || rhs.requires_grad() {
tensor.set_requires_grad_mut(true);
tensor.set_op(Arc::new(MulBackward {
lhs: lhs.clone(),
rhs: rhs.clone(),
}));
}
tensor
}
#[derive(Debug)]
pub struct AddBackward {
pub lhs: Tensor,
pub rhs: Tensor,
}
impl BackwardOp for AddBackward {
fn backward(&self, grad: &Tensor) {
if self.lhs.requires_grad() {
let grad_lhs = grad.clone();
self.lhs.accumulate_grad(&grad_lhs);
self.lhs.backward_step();
}
if self.rhs.requires_grad() {
let grad_rhs = grad.clone();
self.rhs.accumulate_grad(&grad_rhs);
self.rhs.backward_step();
}
}
}
pub fn add(lhs: &Tensor, rhs: &Tensor) -> Tensor {
if lhs.shape() == rhs.shape() {
let lhs_guard = lhs.data();
let rhs_guard = rhs.data();
let lhs_data = &*lhs_guard;
let rhs_data = &*rhs_guard;
let result_data: Vec<f32> = lhs_data
.par_iter()
.zip(rhs_data.par_iter())
.map(|(a, b)| a + b)
.collect();
let storage = Storage::new(result_data);
let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
if lhs.requires_grad() || rhs.requires_grad() {
tensor.set_requires_grad_mut(true);
tensor.set_op(Arc::new(AddBackward {
lhs: lhs.clone(),
rhs: rhs.clone(),
}));
}
return tensor;
}
let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
.expect("Shapes not broadcastable");
let lhs_expanded = lhs.expand(&target_shape);
let rhs_expanded = rhs.expand(&target_shape);
let lhs_guard = lhs_expanded.data();
let rhs_guard = rhs_expanded.data();
let lhs_data = &*lhs_guard;
let rhs_data = &*rhs_guard;
let result_data: Vec<f32> = lhs_data
.par_iter()
.zip(rhs_data.par_iter())
.map(|(a, b)| a + b)
.collect();
let storage = Storage::new(result_data);
let mut tensor = Tensor::new_with_storage(storage, &target_shape);
if lhs.requires_grad() || rhs.requires_grad() {
tensor.set_requires_grad_mut(true);
tensor.set_op(Arc::new(AddBackward {
lhs: lhs.clone(),
rhs: rhs.clone(),
}));
}
tensor
}
pub fn sub(lhs: &Tensor, rhs: &Tensor) -> Tensor {
if lhs.shape() != rhs.shape() {
panic!("Sub shape mismatch");
}
let lhs_guard = lhs.data();
let rhs_guard = rhs.data();
let lhs_data = &*lhs_guard;
let rhs_data = &*rhs_guard;
let result_data: Vec<f32> = lhs_data
.par_iter()
.zip(rhs_data.par_iter())
.map(|(a, b)| a - b)
.collect();
let storage = Storage::new(result_data);
let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
if lhs.requires_grad() || rhs.requires_grad() {
tensor.set_requires_grad_mut(true);
}
tensor
}
pub fn neg(input: &Tensor) -> Tensor {
let input_guard = input.data();
let input_data = &*input_guard;
let result_data: Vec<f32> = input_data.par_iter().map(|x| -x).collect();
let storage = Storage::new(result_data);
let mut tensor = Tensor::new_with_storage(storage, input.shape());
if input.requires_grad() {
tensor.set_requires_grad_mut(true);
}
tensor
}
pub fn relu(input: &Tensor) -> Tensor {
let input_guard = input.data();
let input_data = &*input_guard;
let result_data: Vec<f32> = input_data.par_iter().map(|x| x.max(0.0)).collect();
let storage = Storage::new(result_data);
let mut tensor = Tensor::new_with_storage(storage, input.shape());
if input.requires_grad() {
tensor.set_requires_grad_mut(true);
}
tensor
}