use crate::{Primitive, Tensor};
use std::{any::Any, f64::consts::PI};
use tracing::Level;
/// Negation primitive, `y = -x`.
///
/// The derivative is the constant `-1`, so both differentiation rules just
/// negate whatever flows through them.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Negative;
impl Primitive for Negative {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        // Unit struct: building a fresh value is the same as cloning.
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = -dx`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        -&tangents[0]
    }
    /// Reverse mode: `dx = -dy`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        vec![-cotangent]
    }
}
/// Sine primitive, `y = sin(x)`, whose derivative is `cos(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Sin;
impl Primitive for Sin {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = cos(x) * dx`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        x.cos() * dx
    }
    /// Reverse mode: `dx = dy * cos(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![cotangent * x.cos()]
    }
}
/// Cosine primitive, `y = cos(x)`, whose derivative is `-sin(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Cos;
impl Primitive for Cos {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = -sin(x) * dx`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        -x.sin() * dx
    }
    /// Reverse mode: `dx = dy * -sin(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![cotangent * -x.sin()]
    }
}
/// Squaring primitive, `y = x^2`, whose derivative is `2x`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Square;
impl Primitive for Square {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = 2 * x * dx`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        2 * x * dx
    }
    /// Reverse mode: `dx = 2 * dy * x`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![2 * cotangent * x]
    }
}
/// Square-root primitive, `y = sqrt(x)`, whose derivative is `0.5 / sqrt(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Sqrt;
impl Primitive for Sqrt {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = 0.5 * dx / sqrt(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        0.5 * dx / x.sqrt()
    }
    /// Reverse mode: `dx = 0.5 * dy / sqrt(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![0.5 * cotangent / x.sqrt()]
    }
}
/// Reciprocal square-root primitive, `y = x^(-1/2)`.
///
/// Its derivative is `-0.5 * x^(-3/2)`, evaluated here as
/// `-0.5 * (rsqrt(x) / x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Rsqrt;
impl Primitive for Rsqrt {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = -0.5 * dx * (rsqrt(x) / x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        -0.5 * dx * (x.rsqrt() / x)
    }
    /// Reverse mode: `dx = -0.5 * dy * (rsqrt(x) / x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![-0.5 * cotangent * (x.rsqrt() / x)]
    }
}
/// Sign primitive, `y = sign(x)`.
///
/// `sign` is piecewise constant, so both rules return zeros shaped like `x`
/// and ignore the incoming (co)tangent entirely.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Sign;
impl Primitive for Sign {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: always zero.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        input.zeros_like()
    }
    /// Reverse mode: always zero.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![input.zeros_like()]
    }
}
/// Absolute-value primitive, `y = |x|`.
///
/// Uses `sign(x)` as the derivative, so the gradient at exactly `x == 0`
/// is whatever `sign(0)` yields.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Abs;
impl Primitive for Abs {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = dx * sign(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        dx * x.sign()
    }
    /// Reverse mode: `dx = dy * sign(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![cotangent * x.sign()]
    }
}
/// Exponential primitive, `y = e^x`, which is its own derivative.
///
/// The rules recompute `exp(x)` from the primal rather than reusing `output`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Exp;
impl Primitive for Exp {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = dx * exp(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        dx * x.exp()
    }
    /// Reverse mode: `dx = dy * exp(x)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![cotangent * x.exp()]
    }
}
/// Natural-logarithm primitive, `y = ln(x)`, whose derivative is `1 / x`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Log;
impl Primitive for Log {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = dx / x`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        dx / x
    }
    /// Reverse mode: `dx = dy / x`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        vec![cotangent / x]
    }
}
/// Base-2 logarithm primitive, `y = log2(x)`, whose derivative is `1 / (x * ln 2)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Log2;
impl Primitive for Log2 {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Forward mode: `dy = dx / (x * ln 2)`.
    // The std constant replaces a runtime `f32::ln(2.0)` call; same type and value.
    // NOTE(review): the scale is f32; if tensors can be f64 this caps the scale's
    // precision at f32 — confirm whether an f64 constant is supported here.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let x = &primals[0];
        let tangent_x = &tangents[0];
        tangent_x / (x * std::f32::consts::LN_2)
    }
    /// Reverse mode: `dx = dy / (x * ln 2)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        let cotangent_x = cotangent / (x * std::f32::consts::LN_2);
        vec![cotangent_x]
    }
}
/// Base-10 logarithm primitive, `y = log10(x)`, whose derivative is `1 / (x * ln 10)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Log10;
impl Primitive for Log10 {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Forward mode: `dy = dx / (x * ln 10)`.
    // The std constant replaces a runtime `f32::ln(10.0)` call; same type and value.
    // NOTE(review): the scale is f32; if tensors can be f64 this caps the scale's
    // precision at f32 — confirm whether an f64 constant is supported here.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let x = &primals[0];
        let tangent_x = &tangents[0];
        tangent_x / (x * std::f32::consts::LN_10)
    }
    /// Reverse mode: `dx = dy / (x * ln 10)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        let cotangent_x = cotangent / (x * std::f32::consts::LN_10);
        vec![cotangent_x]
    }
}
/// Softmax primitive along axis `dim`.
///
/// Both rules work from the saved output `s`. The softmax Jacobian
/// `diag(s) - s sᵀ` is symmetric, so forward and reverse mode share the formula
/// `s*v - s * sum(s*v, dim)`. (`(dim, true)` presumably keeps the reduced axis
/// so the sum broadcasts back — TODO confirm.)
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Softmax {
    pub dim: usize,
}
impl Softmax {
    /// Creates a softmax over axis `dim`.
    pub fn new(dim: usize) -> Self {
        Softmax { dim }
    }
}
impl Primitive for Softmax {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }
    fn dot_label(&self) -> String {
        format!("Softmax({:?})", self.dim)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = s*dx - s * sum(s*dx, dim)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let weighted = output * &tangents[0];
        let total = weighted.sum((self.dim, true));
        &weighted - output * total
    }
    /// Reverse mode: `dx = s*dy - s * sum(s*dy, dim)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let weighted = output * cotangent;
        let total = weighted.sum((self.dim, true));
        vec![&weighted - output * total]
    }
}
/// Log-softmax primitive along axis `dim`: `y = x - log(sum(exp(x), dim))`.
///
/// Both rules use the saved output `y`, since `exp(y)` is the softmax `s`.
/// The Jacobian `J = I - 1 sᵀ` is NOT symmetric, so the two modes differ:
/// - forward (`jvp`): `dy = dx - sum(s * dx, dim)`  (i.e. `J · dx`)
/// - reverse (`vjp`): `dx = g - s * sum(g, dim)`    (i.e. `Jᵀ · g`)
///
/// `(dim, true)` presumably keeps the reduced axis so the sum broadcasts
/// back — TODO confirm.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogSoftmax {
    pub dim: usize,
}
impl LogSoftmax {
    /// Creates a log-softmax over axis `dim`.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }
}
impl Primitive for LogSoftmax {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }
    fn dot_label(&self) -> String {
        format!("LogSoftmax({:?})", &self.dim)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Forward mode: `dy_i = dx_i - Σ_j s_j dx_j`.
    ///
    /// The previous form `dx - sum(dx) * s` was the transpose (VJP) rule. It is
    /// only correct when the softmax is uniform.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let tangent_x = &tangents[0];
        // exp(log_softmax(x)) == softmax(x)
        tangent_x - (output.exp() * tangent_x).sum((self.dim, true))
    }
    /// Reverse mode: `dx_j = g_j - s_j Σ_i g_i`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let cotangent_x = cotangent - cotangent.sum((self.dim, true)) * output.exp();
        vec![cotangent_x]
    }
}
/// Error-function primitive, `y = erf(x)`, whose derivative is `(2/√π) · exp(-x²)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Erf;
impl Primitive for Erf {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = (2/√π) · exp(-x²) · dx`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        let gaussian = x.square().neg().exp();
        (2.0 / PI.sqrt()) * gaussian * dx
    }
    /// Reverse mode: `dx = (2/√π) · exp(-x²) · dy`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        let gaussian = x.square().neg().exp();
        vec![(2.0 / PI.sqrt()) * gaussian * cotangent]
    }
}
/// Hyperbolic-tangent primitive, `y = tanh(x)`.
///
/// The derivative `1 - y²` is computed from the saved output as
/// `(1 + y)(1 - y)`. The primal is used only to build a ones tensor of the
/// right shape.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Tanh;
impl Primitive for Tanh {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Self)
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
    /// Forward mode: `dy = dx·(1 + y)·(1 - y)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let (x, dx) = (&primals[0], &tangents[0]);
        let scaled = dx + dx * output;
        let complement = x.ones_like() - output;
        scaled * complement
    }
    /// Reverse mode: `dx = dy·(1 + y)·(1 - y)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        let scaled = cotangent + cotangent * output;
        let complement = x.ones_like() - output;
        vec![scaled * complement]
    }
}
/// Power-by-scalar primitive, `y = x^p` with a fixed `f64` exponent `p`.
///
/// The derivative is `p · x^(p-1)`. When `p == 0` the output is constant, so
/// the derivative is returned as zeros. The general formula would compute
/// `x^(-1) · 0`, which is `inf · 0 = NaN` at `x == 0` and would poison
/// gradients.
#[derive(Clone, Debug, PartialEq)]
pub struct PowerFloat {
    pub exponent: f64,
}
impl PowerFloat {
    /// Creates a power primitive raising its input to `exponent`.
    pub fn new(exponent: f64) -> Self {
        Self { exponent }
    }
    /// Returns the fixed exponent `p`.
    pub fn exponent(&self) -> f64 {
        self.exponent
    }
}
impl Primitive for PowerFloat {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }
    fn dot_label(&self) -> String {
        format!("PowerFloat({:?})", &self.exponent)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Forward mode: `dy = dx · x^(p-1) · p`, or zeros when `p == 0`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let x = &primals[0];
        let tangent_x = &tangents[0];
        if self.exponent == 0.0 {
            // d/dx x^0 = 0 everywhere; avoid NaN from 0^(-1) * 0.
            x.zeros_like()
        } else {
            tangent_x * x.powf(self.exponent - 1.0) * self.exponent
        }
    }
    /// Reverse mode: `dx = dy · x^(p-1) · p`, or zeros when `p == 0`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        let cotangent_x = if self.exponent == 0.0 {
            // d/dx x^0 = 0 everywhere; avoid NaN from 0^(-1) * 0.
            x.zeros_like()
        } else {
            cotangent * x.powf(self.exponent - 1.0) * self.exponent
        };
        vec![cotangent_x]
    }
}