use std::any::Any;
use tracing::Level;
use crate::{DType, Primitive, Tensor};
/// Negation primitive.
///
/// The type carries no state; its `Primitive` impl supplies the forward-mode
/// (`jvp`) and reverse-mode (`vjp`) derivative rules.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Negative;

impl Primitive for Negative {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Negative)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: returns the negated first tangent.
    ///
    /// # Panics
    /// Panics if `tangents` is empty (it is indexed at `0`).
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let incoming = &tangents[0];
        -incoming
    }

    /// Reverse-mode rule: returns a one-element vector holding the negated
    /// cotangent.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        vec![-cotangent]
    }
}
/// Sine primitive.
///
/// Stateless marker type; the derivative rules below multiply by `cos(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Sin;

impl Primitive for Sin {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Sin)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `cos(x) * tangent`, using the first primal and the
    /// first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let cos_x = input.cos();
        cos_x * incoming
    }

    /// Reverse-mode rule: a one-element vector holding `cotangent * cos(x)`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![cotangent * input.cos()]
    }
}
/// Cosine primitive.
///
/// Stateless marker type; the derivative rules below multiply by `-sin(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Cos;

impl Primitive for Cos {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Cos)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `(-sin(x)) * tangent`, using the first primal and
    /// the first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let neg_sin_x = -input.sin();
        neg_sin_x * incoming
    }

    /// Reverse-mode rule: a one-element vector holding
    /// `cotangent * (-sin(x))`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let neg_sin_x = -input.sin();
        vec![cotangent * neg_sin_x]
    }
}
/// Square primitive.
///
/// Stateless marker type; the derivative rules below multiply by `2 * x`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Square;

impl Primitive for Square {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Square)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `x * 2.0 * tangent`, using the first primal and the
    /// first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let incoming = &tangents[0];
        let input = &primals[0];
        input * 2.0 * incoming
    }

    /// Reverse-mode rule: a one-element vector holding `cotangent * x * 2.0`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![cotangent * input * 2.0]
    }
}
/// Square-root primitive.
///
/// Stateless marker type; the derivative rules below scale by
/// `0.5 / sqrt(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Sqrt;

impl Primitive for Sqrt {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Sqrt)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `tangent * 0.5 / sqrt(x)`, using the first primal
    /// and the first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let halved = incoming * 0.5;
        halved / input.sqrt()
    }

    /// Reverse-mode rule: a one-element vector holding
    /// `cotangent * 0.5 / sqrt(x)`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let halved = cotangent * 0.5;
        vec![halved / input.sqrt()]
    }
}
/// Reciprocal-square-root primitive.
///
/// Stateless marker type; the derivative rules below scale by
/// `-0.5 * rsqrt(x) / x`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Rsqrt;

impl Primitive for Rsqrt {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Rsqrt)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `-0.5 * tangent * (rsqrt(x) / x)`, using the first
    /// primal and the first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let scaled = -0.5 * incoming;
        let local_factor = input.rsqrt() / input;
        scaled * local_factor
    }

    /// Reverse-mode rule: a one-element vector holding
    /// `-0.5 * cotangent * (rsqrt(x) / x)`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let scaled = -0.5 * cotangent;
        let local_factor = input.rsqrt() / input;
        vec![scaled * local_factor]
    }
}
/// Sign primitive.
///
/// Stateless marker type; both derivative rules below return zeros shaped
/// like the input, ignoring the incoming tangent / cotangent.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Sign;

impl Primitive for Sign {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Sign)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `zeros_like` of the first primal.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        input.zeros_like()
    }

    /// Reverse-mode rule: a one-element vector holding `zeros_like` of the
    /// first primal.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![input.zeros_like()]
    }
}
/// Absolute-value primitive.
///
/// Stateless marker type; the derivative rules below multiply by `sign(x)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Abs;

impl Primitive for Abs {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Abs)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `tangent * sign(x)`, using the first primal and the
    /// first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let sign_x = input.sign();
        incoming * sign_x
    }

    /// Reverse-mode rule: a one-element vector holding `cotangent * sign(x)`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![cotangent * input.sign()]
    }
}
/// Exponential primitive.
///
/// Stateless marker type; the derivative rules below multiply by `exp(x)`,
/// recomputed from the primal rather than read from `_output`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Exp;

impl Primitive for Exp {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Exp)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `tangent * exp(x)`, using the first primal and the
    /// first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let exp_x = input.exp();
        incoming * exp_x
    }

    /// Reverse-mode rule: a one-element vector holding `cotangent * exp(x)`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![cotangent * input.exp()]
    }
}
/// Natural-logarithm primitive.
///
/// Stateless marker type; the derivative rules below divide by `x`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Log;

impl Primitive for Log {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Log)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `tangent / x`, using the first primal and the first
    /// tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let incoming = &tangents[0];
        let input = &primals[0];
        incoming / input
    }

    /// Reverse-mode rule: a one-element vector holding `cotangent / x`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![cotangent / input]
    }
}
/// Base-2 logarithm primitive.
///
/// Stateless marker type; the derivative rules below divide by `x * ln(2)`,
/// with `ln(2)` computed in `f32`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Log2;

impl Primitive for Log2 {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Log2)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `tangent / (x * ln(2))`, using the first primal and
    /// the first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let denominator = input * f32::ln(2.0);
        incoming / denominator
    }

    /// Reverse-mode rule: a one-element vector holding
    /// `cotangent / (x * ln(2))`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let denominator = input * f32::ln(2.0);
        vec![cotangent / denominator]
    }
}
/// Base-10 logarithm primitive.
///
/// Stateless marker type; the derivative rules below divide by `x * ln(10)`,
/// with `ln(10)` computed in `f32`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Log10;

impl Primitive for Log10 {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(Log10)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: `tangent / (x * ln(10))`, using the first primal
    /// and the first tangent.
    ///
    /// # Panics
    /// Panics if `primals` or `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input = &primals[0];
        let incoming = &tangents[0];
        let denominator = input * f32::ln(10.0);
        incoming / denominator
    }

    /// Reverse-mode rule: a one-element vector holding
    /// `cotangent / (x * ln(10))`.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let denominator = input * f32::ln(10.0);
        vec![cotangent / denominator]
    }
}
/// Dtype-conversion primitive.
///
/// Stores the target `dtype`; the derivative rules below cast the tangent to
/// that dtype and cast the cotangent back to the input's dtype.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AsType {
    /// Target dtype of the conversion.
    pub dtype: DType,
}

impl AsType {
    /// Creates a conversion primitive targeting `dtype`.
    pub fn new(dtype: DType) -> Self {
        AsType { dtype }
    }
}

impl Primitive for AsType {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Label of the form `AsType(<dtype debug repr>)`.
    fn dot_label(&self) -> String {
        format!("AsType({:?})", self.dtype)
    }

    /// Forward-mode rule: the first tangent cast to the target dtype.
    ///
    /// # Panics
    /// Panics if `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let incoming = &tangents[0];
        incoming.as_type(self.dtype)
    }

    /// Reverse-mode rule: a one-element vector holding the cotangent cast to
    /// the dtype of the first primal.
    ///
    /// # Panics
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let source_dtype = input.dtype();
        vec![cotangent.as_type(source_dtype)]
    }
}
/// Softmax primitive over a single dimension.
///
/// Both derivative rules are expressed in terms of the primitive's `output`
/// and reduce along `dim`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Softmax {
    /// Dimension passed to the `sum` reductions in the derivative rules.
    pub dim: usize,
}

impl Softmax {
    /// Creates a softmax primitive for dimension `dim`.
    pub fn new(dim: usize) -> Self {
        Softmax { dim }
    }
}

impl Primitive for Softmax {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Label of the form `Softmax(<dim>)`.
    fn dot_label(&self) -> String {
        format!("Softmax({:?})", self.dim)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Forward-mode rule: with `w = output * tangent`, returns
    /// `w - output * sum(w, dim)`.
    ///
    /// The `true` passed to `sum` is presumably a keep-dim flag — confirm
    /// against `Tensor::sum`.
    ///
    /// # Panics
    /// Panics if `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let incoming = &tangents[0];
        let weighted = &(output * incoming);
        let reduced = weighted.sum((self.dim, true));
        weighted - output * reduced
    }

    /// Reverse-mode rule: with `w = output * cotangent`, returns a
    /// one-element vector holding `w - output * sum(w, dim)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let weighted = &(output * cotangent);
        let reduced = weighted.sum((self.dim, true));
        vec![weighted - output * reduced]
    }
}
/// Log-softmax primitive over a single dimension.
///
/// Both derivative rules are expressed in terms of the primitive's `output`
/// `y`, using `exp(y)` as the softmax of the input, and reduce along `dim`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogSoftmax {
    /// Dimension passed to the `sum` reductions in the derivative rules.
    pub dim: usize,
}

impl LogSoftmax {
    /// Creates a log-softmax primitive for dimension `dim`.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }
}

impl Primitive for LogSoftmax {
    /// Returns a boxed copy of this primitive as a trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }

    /// Label of the form `LogSoftmax(<dim>)`.
    fn dot_label(&self) -> String {
        format!("LogSoftmax({:?})", &self.dim)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Forward-mode rule: `tangent - sum(tangent * exp(y), dim)`.
    ///
    /// The Jacobian of log-softmax is `dy_i/dx_j = delta_ij - softmax_j`, so
    /// the Jacobian-vector product weights the tangent by the softmax
    /// *before* reducing. Reducing the raw tangent and scaling by the softmax
    /// afterwards is the transposed (VJP) form and is wrong here, because the
    /// Jacobian is not symmetric.
    ///
    /// The `true` passed to `sum` is presumably a keep-dim flag — confirm
    /// against `Tensor::sum`.
    ///
    /// # Panics
    /// Panics if `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let tangent_x = &tangents[0];
        // softmax(x) == exp(log_softmax(x)) == exp(output)
        let weighted = &(tangent_x * output.exp());
        tangent_x - weighted.sum((self.dim, true))
    }

    /// Reverse-mode rule: a one-element vector holding
    /// `cotangent - sum(cotangent, dim) * exp(y)`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let cotangent_x = cotangent - cotangent.sum((self.dim, true)) * output.exp();
        vec![cotangent_x]
    }
}