use std::sync::Arc;
use super::grad_fn::{
AddBackward, BroadcastAddBackward, DivBackward, ExpBackward, GeluBackward, LeakyReluBackward,
LogBackward, MatmulBackward, MeanBackward, MulBackward, NegBackward, PowBackward, ReluBackward,
SigmoidBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, TanhBackward,
TransposeBackward, ViewBackward,
};
use super::tensor::Tensor;
use super::{is_grad_enabled, with_graph};
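// Every op below follows the same recipe: compute the forward result, then, if
// gradient tracking is enabled and an input requires grad, mark the output as
// requiring grad, attach the matching backward node, and record the edge
// (output id, grad_fn, input ids) in the graph via `with_graph`.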
impl Tensor {
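/// Element-wise addition. Records an `AddBackward` node holding both input shapes when gradient tracking is enabled.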
#[must_use]
pub fn add(&self, other: &Tensor) -> Tensor {
let data = trueno::blis::elementwise::add_alloc(self.data(), other.data());
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && (self.requires_grad_enabled() || other.requires_grad_enabled()) {
result.requires_grad_(true);
let grad_fn = Arc::new(AddBackward {
x_shape: self.shape().to_vec(),
y_shape: other.shape().to_vec(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.register_tensor(other.clone());
graph.record(result.id(), grad_fn, vec![self.id(), other.id()]);
});
}
result
}
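/// Element-wise subtraction; assumes both tensors have the same shape. Records a `SubBackward` node holding both input shapes.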
#[must_use]
pub fn sub(&self, other: &Tensor) -> Tensor {
let data: Vec<f32> = self.data().iter().zip(other.data().iter()).map(|(&a, &b)| a - b).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && (self.requires_grad_enabled() || other.requires_grad_enabled()) {
result.requires_grad_(true);
let grad_fn = Arc::new(SubBackward {
x_shape: self.shape().to_vec(),
y_shape: other.shape().to_vec(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.register_tensor(other.clone());
graph.record(result.id(), grad_fn, vec![self.id(), other.id()]);
});
}
result
}
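/// Element-wise multiplication; assumes both tensors have the same shape. `MulBackward` captures both operands for the backward pass.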
#[must_use]
pub fn mul(&self, other: &Tensor) -> Tensor {
let data: Vec<f32> = self.data().iter().zip(other.data().iter()).map(|(&a, &b)| a * b).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && (self.requires_grad_enabled() || other.requires_grad_enabled()) {
result.requires_grad_(true);
let grad_fn = Arc::new(MulBackward {
x: self.clone(),
y: other.clone(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.register_tensor(other.clone());
graph.record(result.id(), grad_fn, vec![self.id(), other.id()]);
});
}
result
}
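/// Element-wise division; assumes both tensors have the same shape. `DivBackward` captures both operands for the backward pass.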
#[must_use]
pub fn div(&self, other: &Tensor) -> Tensor {
let data: Vec<f32> = self.data().iter().zip(other.data().iter()).map(|(&a, &b)| a / b).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && (self.requires_grad_enabled() || other.requires_grad_enabled()) {
result.requires_grad_(true);
let grad_fn = Arc::new(DivBackward {
x: self.clone(),
y: other.clone(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.register_tensor(other.clone());
graph.record(result.id(), grad_fn, vec![self.id(), other.id()]);
});
}
result
}
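/// Element-wise negation. Records a `NegBackward` node when gradient tracking is enabled.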
#[must_use]
pub fn neg(&self) -> Tensor {
let data: Vec<f32> = self.data().iter().map(|&a| -a).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(NegBackward);
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
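/// Multiplies every element by `scalar`. For the backward pass the scalar is materialized as a constant tensor inside a `MulBackward` node.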
#[must_use]
pub fn mul_scalar(&self, scalar: f32) -> Tensor {
let data = trueno::blis::elementwise::mul_scalar_alloc(self.data(), scalar);
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(MulBackward {
x: self.clone(),
y: Tensor::new(&vec![scalar; self.numel()], self.shape()),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
}
impl Tensor {
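/// Element-wise exponential. `ExpBackward` captures the forward output, since d/dx e^x = e^x.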
#[must_use]
pub fn exp(&self) -> Tensor {
let data: Vec<f32> = self.data().iter().map(|&a| a.exp()).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(ExpBackward {
output: result.clone(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
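/// Element-wise natural logarithm. `LogBackward` captures the input.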
#[must_use]
pub fn log(&self) -> Tensor {
let data: Vec<f32> = self.data().iter().map(|&a| a.ln()).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(LogBackward { x: self.clone() });
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
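/// Raises every element to the power `n`. `PowBackward` captures the input and the exponent.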
#[must_use]
pub fn pow(&self, n: f32) -> Tensor {
let data: Vec<f32> = self.data().iter().map(|&a| a.powf(n)).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(PowBackward { x: self.clone(), n });
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
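/// Element-wise square root. `SqrtBackward` captures the forward output.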
#[must_use]
pub fn sqrt(&self) -> Tensor {
let data: Vec<f32> = self.data().iter().map(|&a| a.sqrt()).collect();
let mut result = Tensor::from_vec(data, self.shape());
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(SqrtBackward {
output: result.clone(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
}
impl Tensor {
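/// Sums all elements into a single-element tensor. `SumBackward` stores the input shape so the gradient can be expanded back to it.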
#[must_use]
pub fn sum(&self) -> Tensor {
let sum: f32 = self.data().iter().sum();
let mut result = Tensor::new(&[sum], &[1]);
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(SumBackward {
input_shape: self.shape().to_vec(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
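/// Averages all elements into a single-element tensor. `MeanBackward` stores the input shape for the backward pass.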
#[must_use]
pub fn mean(&self) -> Tensor {
let sum: f32 = self.data().iter().sum();
let mean = sum / self.numel() as f32;
let mut result = Tensor::new(&[mean], &[1]);
if is_grad_enabled() && self.requires_grad_enabled() {
result.requires_grad_(true);
let grad_fn = Arc::new(MeanBackward {
input_shape: self.shape().to_vec(),
});
result.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(self.clone());
graph.record(result.id(), grad_fn, vec![self.id()]);
});
}
result
}
}
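// activation.rs is textually included so its ops can use the remaining backward-struct imports above (ReluBackward, GeluBackward, MatmulBackward, ...).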
include!("activation.rs");