use crate::{Primitive, Tensor};
use std::any::Any;
use tracing::Level;
/// Gradient rules for the binary addition primitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Add;

impl Primitive for Add {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the output tangent is the sum of both input tangents.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        &tangents[0] + &tangents[1]
    }

    /// Reverse mode: the incoming cotangent is passed unchanged to both operands.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        vec![cotangent.clone(); 2]
    }
}
/// Gradient rules for the binary subtraction primitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Sub;

impl Primitive for Sub {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: d(lhs - rhs) = d(lhs) - d(rhs).
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (d_lhs, d_rhs) = (&tangents[0], &tangents[1]);
        d_lhs - d_rhs
    }

    /// Reverse mode: lhs receives the cotangent as-is, rhs receives its negation.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        vec![cotangent.clone(), -cotangent]
    }
}
/// Gradient rules for the element-wise multiplication primitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Mul;

impl Primitive for Mul {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode (product rule): d(x * y) = dx * y + dy * x.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, y) = (&primals[0], &primals[1]);
        let (dx, dy) = (&tangents[0], &tangents[1]);
        dx * y + dy * x
    }

    /// Reverse mode: each operand's cotangent is the incoming cotangent
    /// scaled by the *other* operand.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let (x, y) = (&primals[0], &primals[1]);
        vec![cotangent * y, cotangent * x]
    }
}
/// Gradient rules for the element-wise division primitive, `output = lhs / rhs`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Div;
impl Primitive for Div {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(*self)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Forward mode (quotient rule):
    ///
    /// d(lhs / rhs) = d(lhs) / rhs - lhs * d(rhs) / rhs^2
    ///              = d(lhs) / rhs - output * d(rhs) / rhs
    ///
    /// `output` is reused so `lhs` does not need to be read.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let rhs = &primals[1];
        let tangent_lhs = &tangents[0];
        let tangent_rhs = &tangents[1];
        // The rhs term must also be divided by `rhs`. Without that factor the
        // tangent is off by `rhs`, and it would disagree with `vjp` below.
        tangent_lhs / rhs - output * tangent_rhs / rhs
    }
    /// Reverse mode:
    /// - lhs receives `cotangent / rhs`;
    /// - rhs receives `-cotangent * output / rhs` (i.e. `-cotangent * lhs / rhs^2`).
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let rhs = &primals[1];
        let cotangent_lhs = cotangent / rhs;
        let cotangent_rhs = -cotangent * output / rhs;
        vec![cotangent_lhs, cotangent_rhs]
    }
}
/// Gradient rules for the matrix-multiplication primitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatMul;

impl Primitive for MatMul {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode (product rule for matmul): dA·B + A·dB.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        let (tangent_lhs, tangent_rhs) = (&tangents[0], &tangents[1]);
        tangent_lhs.matmul(rhs) + lhs.matmul(tangent_rhs)
    }

    /// Reverse mode: lhs receives `cotangent · rhs.t()`, rhs receives `lhs.t() · cotangent`.
    /// NOTE(review): exact `t()` semantics for inputs above rank 2 are defined elsewhere — confirm.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![cotangent.matmul(rhs.t()), lhs.t().matmul(cotangent)]
    }
}
/// Gradient rules for the `Equal` comparison primitive.
///
/// Both rules return zeros: no gradient flows through this comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Equal;

impl Primitive for Equal {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the tangent is `output.zeros_like()`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        output.zeros_like()
    }

    /// Reverse mode: each operand receives `zeros_like` of itself.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![lhs.zeros_like(), rhs.zeros_like()]
    }
}
/// Gradient rules for the `NotEqual` comparison primitive.
///
/// Both rules return zeros: no gradient flows through this comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NotEqual;

impl Primitive for NotEqual {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the tangent is `output.zeros_like()`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        output.zeros_like()
    }

    /// Reverse mode: each operand receives `zeros_like` of itself.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![lhs.zeros_like(), rhs.zeros_like()]
    }
}
/// Gradient rules for the `Greater` comparison primitive.
///
/// Both rules return zeros: no gradient flows through this comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Greater;

impl Primitive for Greater {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the tangent is `output.zeros_like()`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        output.zeros_like()
    }

    /// Reverse mode: each operand receives `zeros_like` of itself.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![lhs.zeros_like(), rhs.zeros_like()]
    }
}
/// Gradient rules for the `GreaterEqual` comparison primitive.
///
/// Both rules return zeros: no gradient flows through this comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GreaterEqual;

impl Primitive for GreaterEqual {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the tangent is `output.zeros_like()`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        output.zeros_like()
    }

    /// Reverse mode: each operand receives `zeros_like` of itself.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![lhs.zeros_like(), rhs.zeros_like()]
    }
}
/// Gradient rules for the `Less` comparison primitive.
///
/// Both rules return zeros: no gradient flows through this comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Less;

impl Primitive for Less {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the tangent is `output.zeros_like()`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        output.zeros_like()
    }

    /// Reverse mode: each operand receives `zeros_like` of itself.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![lhs.zeros_like(), rhs.zeros_like()]
    }
}
/// Gradient rules for the `LessEqual` comparison primitive.
///
/// Both rules return zeros: no gradient flows through this comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LessEqual;

impl Primitive for LessEqual {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: the tangent is `output.zeros_like()`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        output.zeros_like()
    }

    /// Reverse mode: each operand receives `zeros_like` of itself.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        vec![lhs.zeros_like(), rhs.zeros_like()]
    }
}
/// Gradient rules for the element-wise `Maximum` primitive.
///
/// Gradient routing:
/// - An operand receives gradient where it equals `output`.
/// - Each operand's share is `own_mask / (other_mask + 1)`. That is 1 when
///   only that operand matches the output, and 0.5 for each operand on a tie.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Maximum;

impl Primitive for Maximum {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: mix the input tangents using the equality masks.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        let (tangent_lhs, tangent_rhs) = (&tangents[0], &tangents[1]);
        // Cast each boolean mask to its tangent's dtype so the arithmetic below is well-typed.
        let lhs_mask = output.eq(lhs).to_dtype(tangent_lhs);
        let rhs_mask = output.eq(rhs).to_dtype(tangent_rhs);
        let from_lhs = tangent_lhs * &lhs_mask / (&rhs_mask + 1.0);
        let from_rhs = tangent_rhs * &rhs_mask / (&lhs_mask + 1.0);
        from_lhs + from_rhs
    }

    /// Reverse mode: split the cotangent between the operands using the equality masks.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        let lhs_mask = output.eq(lhs).to_dtype(cotangent);
        let rhs_mask = output.eq(rhs).to_dtype(cotangent);
        vec![
            cotangent * &lhs_mask / (&rhs_mask + 1.0),
            cotangent * &rhs_mask / (&lhs_mask + 1.0),
        ]
    }
}
/// Gradient rules for the element-wise `Minimum` primitive.
///
/// Gradient routing:
/// - An operand receives gradient where it equals `output`.
/// - Each operand's share is `own_mask / (other_mask + 1)`. That is 1 when
///   only that operand matches the output, and 0.5 for each operand on a tie.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Minimum;

impl Primitive for Minimum {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }

    /// Forward mode: mix the input tangents using the equality masks.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        let (tangent_lhs, tangent_rhs) = (&tangents[0], &tangents[1]);
        // Cast each boolean mask to its tangent's dtype so the arithmetic below is well-typed.
        let lhs_mask = output.eq(lhs).to_dtype(tangent_lhs);
        let rhs_mask = output.eq(rhs).to_dtype(tangent_rhs);
        let from_lhs = tangent_lhs * &lhs_mask / (&rhs_mask + 1.0);
        let from_rhs = tangent_rhs * &rhs_mask / (&lhs_mask + 1.0);
        from_lhs + from_rhs
    }

    /// Reverse mode: split the cotangent between the operands using the equality masks.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let (lhs, rhs) = (&primals[0], &primals[1]);
        let lhs_mask = output.eq(lhs).to_dtype(cotangent);
        let rhs_mask = output.eq(rhs).to_dtype(cotangent);
        vec![
            cotangent * &lhs_mask / (&rhs_mask + 1.0),
            cotangent * &rhs_mask / (&lhs_mask + 1.0),
        ]
    }
}