use std::any::Any;
use tracing::Level;
use crate::{Primitive, Shape, Tensor};
/// Primitive describing a broadcast of its input to a target shape.
///
/// The stored `shape` is the shape passed to `broadcast_to` when the
/// tangent is propagated forward.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Broadcast {
    /// Target shape of the broadcast, one extent per dimension.
    pub shape: Vec<usize>,
}
impl Broadcast {
    /// Creates a broadcast primitive targeting `shape`.
    ///
    /// The extents reported by `shape.shape()` are copied into an owned vector.
    pub fn new(shape: impl Shape) -> Self {
        let target = shape.shape().to_vec();
        Self { shape: target }
    }

    /// Returns the target shape as a slice.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }
}
impl Primitive for Broadcast {
    /// Clones this primitive into a boxed trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Returns the label `Broadcast(<shape>)`, with the shape in `Debug` form.
    fn dot_label(&self) -> String {
        format!("Broadcast({:?})", self.shape)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Forward-mode rule: broadcasts the first tangent to the target shape.
    ///
    /// Panics if `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let tangent = &tangents[0];
        tangent.broadcast_to(self.shape())
    }

    /// Reverse-mode rule: reduces the cotangent back to the input's shape.
    ///
    /// Every cotangent axis that has no counterpart in the input (the leading
    /// axes) or whose extent differs from the input's is summed with the
    /// dimension kept; the result is then reshaped to the input shape.
    ///
    /// Panics if `primals` is empty. Assumes the cotangent has at least as
    /// many dimensions as the input (the subtraction below would otherwise
    /// underflow).
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        let input_shape = input.shape().to_vec();

        // Number of axes the cotangent has in front of the input's axes.
        let lead = cotangent.ndim() - input_shape.ndim();

        // Axes to reduce: the extra leading ones, plus those whose extent
        // does not match the input's.
        let reduce_axes: Vec<usize> = (0..cotangent.ndim())
            .filter(|&axis| axis < lead || input_shape[axis - lead] != cotangent.shape_at(axis))
            .collect();

        let summed = cotangent.sum((reduce_axes, true));
        vec![summed.reshape(&input_shape)]
    }
}
/// Primitive describing a reshape of its input to a target shape.
///
/// The stored `shape` is the shape passed to `reshape` when the tangent is
/// propagated forward.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Reshape {
    /// Target shape of the reshape, one extent per dimension.
    pub shape: Vec<usize>,
}
impl Reshape {
    /// Creates a reshape primitive targeting `shape`.
    ///
    /// The extents reported by `shape.shape()` are copied into an owned vector.
    pub fn new(shape: impl Shape) -> Self {
        let target = shape.shape().to_vec();
        Self { shape: target }
    }

    /// Returns the target shape as a slice.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }
}
impl Primitive for Reshape {
    /// Clones this primitive into a boxed trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the label `Reshape(<shape>)`, with the shape in `Debug` form.
    fn dot_label(&self) -> String {
        format!("Reshape({:?})", self.shape)
    }

    /// Forward-mode rule: reshapes the first tangent to the target shape.
    ///
    /// Panics if `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let tangent = &tangents[0];
        tangent.reshape(self.shape())
    }

    /// Reverse-mode rule: reshapes the cotangent using the first primal as
    /// the shape argument, i.e. back to the input's shape.
    ///
    /// Panics if `primals` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];
        vec![cotangent.reshape(input)]
    }
}
/// Primitive describing a transpose of its input by a list of dimensions.
///
/// The stored `dims` is the argument passed to `transpose` when the tangent
/// is propagated forward.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Transpose {
    /// Dimension order handed to `transpose`.
    pub dims: Vec<usize>,
}
impl Transpose {
    /// Creates a transpose primitive from anything convertible into a
    /// `Vec<usize>` of dimensions.
    pub fn new(dims: impl Into<Vec<usize>>) -> Self {
        let dims: Vec<usize> = dims.into();
        Self { dims }
    }

    /// Returns the stored dimension order as a slice.
    pub fn dims(&self) -> &[usize] {
        self.dims.as_slice()
    }
}
impl Primitive for Transpose {
    /// Clones this primitive into a boxed trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the label `Transpose(<dims>)`, with the dims in `Debug` form.
    fn dot_label(&self) -> String {
        format!("Transpose({:?})", self.dims())
    }

    /// Forward-mode rule: transposes the first tangent by the stored dims.
    ///
    /// Panics if `tangents` is empty.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let tangent = &tangents[0];
        tangent.transpose(self.dims())
    }

    /// Reverse-mode rule: transposes the cotangent by the inverse of the
    /// stored dimension order.
    ///
    /// The inverse is built by writing each position `i` at index
    /// `self.dims[i]`. This assumes `self.dims` is a permutation of
    /// `0..self.dims.len()`; an entry outside that range panics on indexing.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let mut inverse = vec![0; self.dims.len()];
        for (position, &axis) in self.dims.iter().enumerate() {
            inverse[axis] = position;
        }
        vec![cotangent.transpose(inverse)]
    }
}