use crate::{op, prelude::*};
use std::ops::{Add, Mul, Neg};
/// Unary minus for graph tensors.
impl<S: Shape> Neg for GraphTensor<S> {
    type Output = GraphTensor<S>;

    /// Negates the tensor by multiplying it with the scalar `-1.0`.
    fn neg(self) -> Self::Output {
        // Negation is expressed through the existing scalar multiplication.
        Mul::mul(self, -1.0)
    }
}
impl<S: Shape> GraphTensor<S> {
pub fn log2(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Log2)
.input(self.id, 0, self.shape)
.finish();
GraphTensor::from_id(new_id, self.shape, self.graph_ref)
}
pub fn exp2(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Exp2)
.input(self.id, 0, self.shape)
.finish();
GraphTensor::from_id(new_id, self.shape, self.graph_ref)
}
/// Natural exponential, built on top of [`Self::exp2`].
///
/// Uses the identity `e^x = 2^(x / ln 2)`: the tensor is scaled by
/// `1 / ln(2)` and then passed through `exp2`.
pub fn exp(self) -> GraphTensor<S> {
    let inv_ln_2 = 1.0 / f32::ln(2.);
    let scaled = self * inv_ln_2;
    scaled.exp2()
}
/// Natural logarithm, built on top of [`Self::log2`].
///
/// Uses the identity `ln(x) = log2(x) * ln(2)`.
pub fn ln(self) -> GraphTensor<S> {
    let base_two_log = self.log2();
    let ln_2 = f32::ln(2.);
    base_two_log * ln_2
}
pub fn recip(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Recip)
.input(self.id, 0, self.shape)
.finish();
GraphTensor::from_id(new_id, self.shape, self.graph_ref)
}
pub fn sin(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Sin)
.input(self.id, 0, self.shape)
.finish();
GraphTensor::from_id(new_id, self.shape, self.graph_ref)
}
pub fn cos(self) -> GraphTensor<S> {
(-self + (std::f32::consts::PI / 2.)).sin()
}
pub fn sqrt(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Sqrt)
.input(self.id, 0, self.shape)
.finish();
GraphTensor::from_id(new_id, self.shape, self.graph_ref)
}
/// Scales the tensor by `1 / sqrt(mean(x^2) + epsilon)` along axis `DIM`.
///
/// The mean of the squared values is reduced over `Axis<DIM>`, `epsilon` is
/// added to it, and the reciprocal square root is expanded back and
/// multiplied with the original tensor. The mean itself is *not* subtracted
/// here, so this only equals division by the standard deviation when the
/// input is already zero-mean along `DIM` (see `mean_norm`).
pub fn std_norm<const DIM: usize, T>(self, epsilon: T) -> GraphTensor<S>
where
    <S as ReduceShape<Axis<DIM>>>::Reduced: Shape,
    GraphTensor<<S as ReduceShape<Axis<DIM>>>::Reduced>:
        Add<T, Output = GraphTensor<<S as ReduceShape<Axis<DIM>>>::Reduced>>,
    S: ReduceShape<Axis<DIM>>,
{
    // mean(x^2) over the chosen axis
    let mean_of_squares =
        (self * self).mean_reduce::<<S as ReduceShape<Axis<DIM>>>::Reduced, _>();
    // 1 / sqrt(mean(x^2) + epsilon), still in the reduced shape
    let inv_scale = mean_of_squares.add(epsilon).sqrt().recip();
    // expand back to the full shape and apply
    inv_scale.expand().mul(self)
}
/// Centers the tensor along axis `DIM`.
///
/// The mean over `Axis<DIM>` is computed, expanded back to the full shape,
/// and subtracted from the tensor.
pub fn mean_norm<const DIM: usize>(self) -> GraphTensor<S>
where
    <S as ReduceShape<Axis<DIM>>>::Reduced: Shape,
    S: ReduceShape<Axis<DIM>>,
{
    let axis_mean = self.mean_reduce::<<S as ReduceShape<Axis<DIM>>>::Reduced, _>();
    self - axis_mean.expand()
}
/// Normalizes the tensor along axis `DIM`.
///
/// This is `mean_norm` followed by `std_norm` over the same axis: the
/// tensor is first centered, then scaled by
/// `1 / sqrt(mean(centered^2) + epsilon)`. No additional scale or shift is
/// applied.
pub fn layer_norm<const DIM: usize, T>(self, epsilon: T) -> GraphTensor<S>
where
    <S as ReduceShape<Axis<DIM>>>::Reduced: Shape,
    GraphTensor<<S as ReduceShape<Axis<DIM>>>::Reduced>:
        Add<T, Output = GraphTensor<<S as ReduceShape<Axis<DIM>>>::Reduced>>,
    S: ReduceShape<Axis<DIM>>,
{
    let centered = self.mean_norm::<DIM>();
    centered.std_norm::<DIM, T>(epsilon)
}
/// Softmax along axis `DIM`.
///
/// The maximum over `Axis<DIM>` is subtracted before exponentiating. This
/// leaves the mathematical result unchanged while keeping every exponent
/// argument at or below zero. The exponentials are then divided by their
/// sum over the same axis.
pub fn softmax<const DIM: usize>(self) -> GraphTensor<S>
where
    <S as ReduceShape<Axis<DIM>>>::Reduced: Shape,
    S: ReduceShape<Axis<DIM>>,
{
    let axis_max = self.max_reduce::<<S as ReduceShape<Axis<DIM>>>::Reduced, _>();
    let shifted = self - axis_max.expand();
    let numerator = shifted.exp();
    let denominator = numerator.sum_reduce::<<S as ReduceShape<Axis<DIM>>>::Reduced, _>();
    numerator / denominator.expand()
}
/// Index of the maximum along the last axis, returned as a tensor with that
/// axis reduced away.
///
/// Built from three steps:
/// 1. compare every element with the maximum of its last-axis slice
///    (presumably `equals` yields 1.0 where equal and 0.0 elsewhere — TODO confirm);
/// 2. build a tensor of positions as `cumsum_last_dim(ones) - 1`
///    (presumably 0-based indices along the last axis — TODO confirm);
/// 3. multiply the two and take the maximum over the last axis.
///
/// Under those assumptions, ties resolve to the largest matching index.
pub fn argmax(self) -> GraphTensor<<S as ReduceShape<<S as Shape>::LastAxis>>::Reduced> {
    let slice_max = self.max_reduce::<_, S::LastAxis>();
    let is_max = self.equals(slice_max.expand());
    let positions = self.graph().constant(1.).expand().cumsum_last_dim() - 1.;
    (is_max * positions).max_reduce::<_, S::LastAxis>()
}
/// Absolute value, computed as `relu(x) + relu(-x)`.
pub fn abs(self) -> GraphTensor<S> {
    let positive_part = self.relu();
    let negative_part = (-self).relu();
    positive_part + negative_part
}
/// Approximate sign of the tensor, computed as `x / (|x| + 1e-10)`.
///
/// The small constant keeps the denominator non-zero, so a zero input maps
/// to zero. Inputs whose magnitude is comparable to `1e-10` give a result
/// with magnitude below one.
pub fn sign(self) -> GraphTensor<S> {
    let magnitude = self.abs() + 1e-10;
    self / magnitude
}
/// Raises the magnitude of the tensor to the power `e`.
///
/// Computed as `exp(e * ln(|x|))`. Because the absolute value is taken
/// first, the sign of the base is discarded: the result is `|x|^e`.
pub fn pow<T>(self, e: T) -> GraphTensor<S>
where
    Self: Mul<T, Output = Self>,
{
    let log_magnitude = self.abs().ln();
    let scaled = log_magnitude.mul(e);
    scaled.exp()
}
/// Raises `|base|` to the power of this tensor.
///
/// Computed as `exp(x * ln(|base|))`; the sign of `base` is discarded.
pub fn inv_pow(self, base: f32) -> GraphTensor<S> {
    let ln_base = base.abs().ln();
    let exponent = self * ln_base;
    exponent.exp()
}
/// Rectified linear unit: delegates to `max_f32` with a threshold of `0`.
pub fn relu(self) -> GraphTensor<S> {
    let threshold = 0.;
    self.max_f32(threshold)
}
/// Logistic sigmoid, computed as `1 / (1 + exp(-x))`.
///
/// A single constant `1.0` is added to the graph and expanded twice, once
/// for the numerator and once for the denominator.
pub fn sigmoid(self) -> GraphTensor<S> {
    let one = self.graph().constant(1.0);
    let numerator = one.expand();
    let denominator = one.expand() + (-self).exp();
    numerator / denominator
}
/// Swish activation: the tensor multiplied by its own sigmoid,
/// `x * sigmoid(x)`.
pub fn swish(self) -> GraphTensor<S> {
    let gate = self.sigmoid();
    self * gate
}
/// Hyperbolic tangent, built on top of [`Self::sigmoid`].
///
/// Uses the identity `tanh(x) = 2 * sigmoid(2x) - 1`.
pub fn tanh(self) -> GraphTensor<S> {
    let doubled = self * 2.0;
    let squashed = doubled.sigmoid();
    squashed * 2.0 - 1.0
}
/// Leaky ReLU, computed as `relu(x) - relu(-neg_slope * x)`.
///
/// For `neg_slope >= 0` this yields `x` for positive inputs and
/// `neg_slope * x` for negative inputs.
pub fn leaky_relu(self, neg_slope: f32) -> GraphTensor<S> {
    let positive_part = self.relu();
    // Non-zero only where the input is negative (given neg_slope >= 0).
    let scaled_negative_part = (self * -neg_slope).relu();
    positive_part - scaled_negative_part
}
}
#[cfg(test)]
mod tests {
crate::test_imports!();
/// `exp` on a 2x3 graph tensor must match the `Cpu` reference tensor.
#[test]
fn test_exp() {
    // Graph under test.
    let mut graph = Graph::new();
    let input = random_vec(6);
    let x = graph.tensor::<R2<2, 3>>().set(input.clone());
    let actual = x.exp().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<3>));
    let expected = reference_x.exp();

    assert_close(&actual.data(), &expected.as_vec());
}
/// `layer_norm` over axis 0 and axis 1 must match the reference tensor's
/// `normalize` over the corresponding axis with the same epsilon.
#[test]
fn test_layer_norm() {
    // Graph under test: normalize the same input over each axis.
    let mut graph = Graph::new();
    let input = random_vec(6);
    let x = graph.tensor::<R2<2, 3>>().set(input.clone());
    let actual_axis0 = x.layer_norm::<0, _>(1e-5).retrieve();
    let actual_axis1 = x.layer_norm::<1, _>(1e-5).retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<3>));
    let expected_axis0 = reference_x.clone().normalize::<DAxis<0>>(1e-5);
    let expected_axis1 = reference_x.normalize::<DAxis<1>>(1e-5);

    assert_close(&actual_axis0.data(), &expected_axis0.as_vec());
    assert_close(&actual_axis1.data(), &expected_axis1.as_vec());
}
/// `softmax` over axis 1 must match the reference tensor's softmax over the
/// same axis.
#[test]
fn test_softmax() {
    // Graph under test.
    let mut graph = Graph::new();
    let input = random_vec(6);
    let x = graph.tensor::<R2<2, 3>>().set(input.clone());
    let output = x.softmax::<1>().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<3>));
    let expected = reference_x.softmax::<DAxis<1>>();

    let actual = output.data();
    assert_close(&actual, &expected.as_vec());
}
/// `sin` on a 2x3 graph tensor must match the `Cpu` reference tensor.
#[test]
fn test_sin() {
    // Graph under test.
    let mut graph = Graph::new();
    let input = random_vec(6);
    let x = graph.tensor::<R2<2, 3>>().set(input.clone());
    let output = x.sin().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<3>));
    let expected = reference_x.sin();

    let actual = output.data();
    assert_close(&actual, &expected.as_vec());
}
/// `cos` on a 2x3 graph tensor must match the `Cpu` reference tensor.
#[test]
fn test_cos() {
    // Graph under test.
    let mut graph = Graph::new();
    let input = random_vec(6);
    let x = graph.tensor::<R2<2, 3>>().set(input.clone());
    let actual = x.cos().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<3>));
    let expected = reference_x.cos();

    assert_close(&actual.data(), &expected.as_vec());
}
/// `relu` on a dynamically shaped tensor (set to 2x2) must match the `Cpu`
/// reference tensor.
#[test]
fn test_relu() {
    // Graph under test, with both dimensions dynamic.
    let mut graph = Graph::new();
    let input = random_vec(4);
    let x = graph
        .tensor::<(Dyn<'a'>, Dyn<'b'>)>()
        .set_dyn(input.clone(), &[2, 2]);
    let actual = x.relu().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<2>));
    let expected = reference_x.relu();

    assert_close(&actual.data(), &expected.as_vec());
}
/// `sigmoid` on a dynamically shaped tensor (set to 2x2) must match the
/// `Cpu` reference tensor.
#[test]
fn test_sigmoid() {
    // Graph under test, with both dimensions dynamic.
    let mut graph = Graph::new();
    let input = random_vec(4);
    let x = graph
        .tensor::<(Dyn<'a'>, Dyn<'b'>)>()
        .set_dyn(input.clone(), &[2, 2]);
    let actual = x.sigmoid().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<2>));
    let expected = reference_x.sigmoid();

    assert_close(&actual.data(), &expected.as_vec());
}
/// `swish` on a dynamically shaped tensor (set to 2x2) must match
/// `x * sigmoid(x)` computed on the `Cpu` reference tensor.
#[test]
fn test_swish() {
    // Graph under test, with both dimensions dynamic.
    let mut graph = Graph::new();
    let input = random_vec(4);
    let x = graph
        .tensor::<(Dyn<'a'>, Dyn<'b'>)>()
        .set_dyn(input.clone(), &[2, 2]);
    let actual = x.swish().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<2>));
    let expected = reference_x.clone() * reference_x.sigmoid();

    assert_close(&actual.data(), &expected.as_vec());
}
/// `tanh` on a dynamically shaped tensor (set to 2x2) must match the `Cpu`
/// reference tensor.
#[test]
fn test_tanh() {
    // Graph under test, with both dimensions dynamic.
    let mut graph = Graph::new();
    let input = random_vec(4);
    let x = graph
        .tensor::<(Dyn<'a'>, Dyn<'b'>)>()
        .set_dyn(input.clone(), &[2, 2]);
    let actual = x.tanh().retrieve();
    graph.execute();

    // Reference computation on the same data.
    let reference_dev = Cpu::default();
    let reference_x = reference_dev.tensor_from_vec(input, (DConst::<2>, DConst::<2>));
    let expected = reference_x.tanh();

    assert_close(&actual.data(), &expected.as_vec());
}
}