use crate::FusionBackend;
use crate::{HandleContainer, TensorDescription};
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions};
use burn_tensor::{Distribution, Element};
use serde::{Deserialize, Serialize};
use std::ops::Range;
/// An operation that can be executed by a fusion backend `B`.
///
/// Implementors must be `Send + Sync`.
pub trait Operation<B: FusionBackend>: Send + Sync {
    /// Executes the operation, consuming the boxed value, against the backend handles
    /// stored in `handles`.
    fn execute(self: Box<Self>, handles: &mut HandleContainer<B>);
}
/// Serializable description of a single tensor operation.
///
/// Variants are grouped by operation family and by the kind of tensor (float, int, bool)
/// they apply to.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationDescription {
    /// Base operation on a float tensor.
    BaseFloat(BaseOperationDescription),
    /// Base operation on an int tensor.
    BaseInt(BaseOperationDescription),
    /// Base operation on a bool tensor.
    BaseBool(BaseOperationDescription),
    /// Numeric operation on a float tensor; scalar operands are stored as `f32`.
    NumericFloat(NumericOperationDescription<f32>),
    /// Numeric operation on an int tensor; scalar operands are stored as `i32`.
    NumericInt(NumericOperationDescription<i32>),
    /// Operation specific to bool tensors.
    Bool(BoolOperationDescription),
    /// Operation specific to int tensors.
    Int(IntOperationDescription),
    /// Operation specific to float tensors.
    Float(FloatOperationDescription),
    /// Module operation (embedding, convolution, pooling, interpolation).
    Module(ModuleOperationDescription),
}
/// Operations wrapped by [`OperationDescription::Float`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationDescription {
    Exp(UnaryOperationDescription),
    Log(UnaryOperationDescription),
    Log1p(UnaryOperationDescription),
    Erf(UnaryOperationDescription),
    /// Tensor/scalar power; the scalar is an `f32`.
    PowfScalar(ScalarOperationDescription<f32>),
    Sqrt(UnaryOperationDescription),
    Cos(UnaryOperationDescription),
    Sin(UnaryOperationDescription),
    Tanh(UnaryOperationDescription),
    /// Conversion from a float tensor (`input`) to an int tensor (`out`).
    IntoInt(UnaryOperationDescription),
    Matmul(BinaryOperationDescription),
    /// Creates a tensor from a random distribution; only the output tensor is referenced.
    Random(RandomOperationDescription),
    Recip(UnaryOperationDescription),
}
/// Module operations wrapped by [`OperationDescription::Module`]: embeddings, convolutions,
/// pooling and interpolation, along with the backward variants listed here.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationDescription {
    Embedding(EmbeddingDescription),
    EmbeddingBackward(EmbeddingBackwardDescription),
    Conv1d(Conv1dDescription),
    Conv2d(Conv2dDescription),
    ConvTranspose1d(ConvTranspose1dDescription),
    ConvTranspose2d(ConvTranspose2dDescription),
    AvgPool1d(AvgPool1dDescription),
    AvgPool2d(AvgPool2dDescription),
    AvgPool1dBackward(AvgPool1dBackwardDescription),
    AvgPool2dBackward(AvgPool2dBackwardDescription),
    AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription),
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardDescription),
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardDescription),
    MaxPool1d(MaxPool1dDescription),
    MaxPool1dWithIndices(MaxPool1dWithIndicesDescription),
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardDescription),
    MaxPool2d(MaxPool2dDescription),
    MaxPool2dWithIndices(MaxPool2dWithIndicesDescription),
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
    Interpolate(InterpolateDescription),
    InterpolateBackward(InterpolateBackwardDescription),
}
/// Operations shared by every tensor kind; wrapped by [`OperationDescription::BaseFloat`],
/// [`OperationDescription::BaseInt`] and [`OperationDescription::BaseBool`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationDescription {
    /// The tensor being moved; the same description is the only node reported.
    ToDevice(TensorDescription),
    Reshape(ReshapeDescription),
    SwapDims(SwapDimsDescription),
    Permute(PermuteOperationDescription),
    Flip(FlipOperationDescription),
    Expand(ExpandOperationDescription),
    Slice(SliceOperationDescription),
    SliceAssign(SliceAssignOperationDescription),
    Equal(BinaryOperationDescription),
    Repeat(RepeatOperationDescription),
    Cat(CatOperationDescription),
}
/// Numeric operations shared by float and int tensors, generic over the scalar element
/// type `E`.
///
/// `Hash` is not derived for this enum; it is implemented manually in this file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationDescription<E> {
    // Arithmetic: tensor/tensor and tensor/scalar forms.
    Add(BinaryOperationDescription),
    AddScalar(ScalarOperationDescription<E>),
    Sub(BinaryOperationDescription),
    SubScalar(ScalarOperationDescription<E>),
    Div(BinaryOperationDescription),
    DivScalar(ScalarOperationDescription<E>),
    Mul(BinaryOperationDescription),
    MulScalar(ScalarOperationDescription<E>),
    Abs(UnaryOperationDescription),
    // Tensor creation: only the output tensor description is carried.
    Ones(TensorDescription),
    Zeros(TensorDescription),
    /// Output tensor description paired with a scalar value of type `E`.
    Full((TensorDescription, E)),
    // Indexing and masking.
    Gather(GatherOperationDescription),
    Scatter(ScatterOperationDescription),
    Select(SelectOperationDescription),
    SelectAssign(SelectAssignOperationDescription),
    MaskWhere(MaskWhereOperationDescription),
    MaskFill(MaskFillOperationDescription<E>),
    // Reductions; the `*Dim` forms carry a `usize` scalar operand.
    MeanDim(ScalarOperationDescription<usize>),
    Mean(UnaryOperationDescription),
    Sum(UnaryOperationDescription),
    SumDim(ScalarOperationDescription<usize>),
    Prod(UnaryOperationDescription),
    ProdDim(ScalarOperationDescription<usize>),
    // Comparisons: tensor/tensor and tensor/scalar forms.
    EqualElem(ScalarOperationDescription<E>),
    Greater(BinaryOperationDescription),
    GreaterElem(ScalarOperationDescription<E>),
    GreaterEqual(BinaryOperationDescription),
    GreaterEqualElem(ScalarOperationDescription<E>),
    Lower(BinaryOperationDescription),
    LowerElem(ScalarOperationDescription<E>),
    LowerEqual(BinaryOperationDescription),
    LowerEqualElem(ScalarOperationDescription<E>),
    // Arg/min/max reductions.
    ArgMax(ScalarOperationDescription<usize>),
    ArgMin(ScalarOperationDescription<usize>),
    Max(UnaryOperationDescription),
    MaxDimWithIndices(ReduceDimWithIndicesDescription),
    MinDimWithIndices(ReduceDimWithIndicesDescription),
    Min(UnaryOperationDescription),
    MaxDim(ScalarOperationDescription<usize>),
    MinDim(ScalarOperationDescription<usize>),
    Clamp(ClampOperationDescription<E>),
    /// Random tensor creation; only the output tensor is referenced.
    IntRandom(RandomOperationDescription),
    Powf(BinaryOperationDescription),
}
/// Operations wrapped by [`OperationDescription::Int`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationDescription {
    /// Conversion from an int tensor (`input`) to a float tensor (`out`).
    IntoFloat(UnaryOperationDescription),
}

/// Operations wrapped by [`OperationDescription::Bool`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationDescription {
    /// Conversion from a bool tensor to a float tensor.
    IntoFloat(UnaryOperationDescription),
    /// Conversion from a bool tensor to an int tensor.
    IntoInt(UnaryOperationDescription),
    Not(UnaryOperationDescription),
}
/// Description of a swap-dims operation exchanging `dim1` and `dim2` of `input` into `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
    pub dim1: usize,
    pub dim2: usize,
}

/// Description of a permute operation reordering the axes of `input` according to `axes`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
    pub axes: Vec<usize>,
}

/// Description of an expand operation from `input` to the target `shape`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
    pub shape: Vec<usize>,
}

/// Description of a flip operation on `input` along each axis listed in `axes`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
    pub axes: Vec<usize>,
}
/// Description of an operation that creates `out` from a random [`Distribution`].
///
/// `Hash` is not derived for this type; it is implemented manually in this file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOperationDescription {
    pub out: TensorDescription,
    pub distribution: Distribution,
}

/// Description of a reshape operation from `input` to `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReshapeDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}

/// Input/output pair for an expand operation.
///
/// NOTE(review): no enum variant in this file wraps this type
/// (`BaseOperationDescription::Expand` uses `ExpandOperationDescription`); confirm whether it
/// is still referenced elsewhere.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ExpandDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}

/// Description of an operation with two tensor operands, `lhs` and `rhs`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOperationDescription {
    pub lhs: TensorDescription,
    pub rhs: TensorDescription,
    pub out: TensorDescription,
}

/// Description of an operation with a single tensor operand `input`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}

/// Description of an operation between the tensor `lhs` and a non-tensor value `rhs` of type
/// `E`, producing `out`.
///
/// Also used with `E = usize` (e.g. `NumericOperationDescription::SumDim`, `ArgMax`).
/// `Hash` is implemented manually in this file and does not hash `rhs`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOperationDescription<E> {
    pub lhs: TensorDescription,
    pub rhs: E,
    pub out: TensorDescription,
}
/// Description of a gather from `tensor` along `dim` at `indices`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a scatter of `value` into `tensor` along `dim` at `indices`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub value: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a select from `tensor` along `dim` at `indices`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a select-assign writing `value` into `tensor` along `dim` at `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub value: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a slice of `tensor` by the index `ranges`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOperationDescription {
    pub tensor: TensorDescription,
    pub ranges: Vec<Range<usize>>,
    pub out: TensorDescription,
}

/// Description of a slice-assign writing `value` into the region of `tensor` given by
/// `ranges`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceAssignOperationDescription {
    pub tensor: TensorDescription,
    pub ranges: Vec<Range<usize>>,
    pub value: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a mask-where operation combining `tensor`, `mask` and the tensor `value`
/// into `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOperationDescription {
    pub tensor: TensorDescription,
    pub mask: TensorDescription,
    pub value: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a mask-fill operation using `mask` and the scalar `value` of type `E`.
///
/// `Hash` is implemented manually in this file and does not hash `value`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOperationDescription<E> {
    pub tensor: TensorDescription,
    pub mask: TensorDescription,
    pub value: E,
    pub out: TensorDescription,
}

/// Description of a clamp of `tensor` to the scalar bounds `min` and `max`.
///
/// `Hash` is implemented manually in this file and does not hash `min` or `max`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOperationDescription<E> {
    pub tensor: TensorDescription,
    pub min: E,
    pub max: E,
    pub out: TensorDescription,
}

/// Description of a repeat of `tensor` `times` times along `dim`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub times: usize,
    pub out: TensorDescription,
}

/// Description of a concatenation of `tensors` along `dim`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOperationDescription {
    pub tensors: Vec<TensorDescription>,
    pub dim: usize,
    pub out: TensorDescription,
}

/// Description of a reduction of `tensor` along `dim` that produces both values (`out`) and
/// indices (`out_indices`).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub out: TensorDescription,
    pub out_indices: TensorDescription,
}
/// Description of an embedding lookup of `indices` in `weights`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingDescription {
    pub weights: TensorDescription,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}

/// Description of the backward pass of an embedding lookup, given the gradient `out_grad`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardDescription {
    pub weights: TensorDescription,
    pub out_grad: TensorDescription,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}

/// Description of a 1D convolution of `x` with `weight` and an optional `bias`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: Conv1dOptionsDescription,
    pub out: TensorDescription,
}

/// Description of a 2D convolution of `x` with `weight` and an optional `bias`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: Conv2dOptionsDescription,
    pub out: TensorDescription,
}

/// Description of a 1D transposed convolution of `x` with `weight` and an optional `bias`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: ConvTranspose1dOptionsDescription,
    pub out: TensorDescription,
}

/// Description of a 2D transposed convolution of `x` with `weight` and an optional `bias`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: ConvTranspose2dOptionsDescription,
    pub out: TensorDescription,
}

/// Serializable mirror of `ConvOptions<1>`; converts in both directions via `From`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsDescription {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}

/// Serializable mirror of `ConvOptions<2>`; converts in both directions via `From`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsDescription {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}

/// Serializable mirror of `ConvTransposeOptions<1>`; converts in both directions via `From`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsDescription {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub padding_out: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}

/// Serializable mirror of `ConvTransposeOptions<2>`; converts in both directions via `From`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsDescription {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub padding_out: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
impl From<ConvOptions<1>> for Conv1dOptionsDescription {
fn from(value: ConvOptions<1>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvOptions<2>> for Conv2dOptionsDescription {
fn from(value: ConvOptions<2>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsDescription {
fn from(value: ConvTransposeOptions<1>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
padding_out: value.padding_out,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsDescription {
fn from(value: ConvTransposeOptions<2>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
padding_out: value.padding_out,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<Conv1dOptionsDescription> for ConvOptions<1> {
fn from(val: Conv1dOptionsDescription) -> Self {
ConvOptions {
stride: val.stride,
padding: val.padding,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<Conv2dOptionsDescription> for ConvOptions<2> {
fn from(val: Conv2dOptionsDescription) -> Self {
ConvOptions {
stride: val.stride,
padding: val.padding,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<ConvTranspose1dOptionsDescription> for ConvTransposeOptions<1> {
fn from(val: ConvTranspose1dOptionsDescription) -> Self {
ConvTransposeOptions {
stride: val.stride,
padding: val.padding,
padding_out: val.padding_out,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<ConvTranspose2dOptionsDescription> for ConvTransposeOptions<2> {
fn from(val: ConvTranspose2dOptionsDescription) -> Self {
ConvTransposeOptions {
stride: val.stride,
padding: val.padding,
padding_out: val.padding_out,
dilation: val.dilation,
groups: val.groups,
}
}
}
/// Description of a 1D average pooling of `x`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dDescription {
    pub x: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub out: TensorDescription,
}

/// Description of a 2D average pooling of `x`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dDescription {
    pub x: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub out: TensorDescription,
}

/// Description of the backward pass of a 1D average pooling, given the gradient `grad`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub out: TensorDescription,
}

/// Description of the backward pass of a 2D average pooling, given the gradient `grad`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub out: TensorDescription,
}

/// Description of a 1D adaptive average pooling of `x` to `output_size`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dDescription {
    pub x: TensorDescription,
    pub output_size: usize,
    pub out: TensorDescription,
}

/// Description of a 2D adaptive average pooling of `x` to `output_size`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dDescription {
    pub x: TensorDescription,
    pub output_size: [usize; 2],
    pub out: TensorDescription,
}

/// Description of the backward pass of a 1D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub out: TensorDescription,
}

/// Description of the backward pass of a 2D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub out: TensorDescription,
}
/// Description of a 1D max pooling of `x`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dDescription {
    pub x: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub out: TensorDescription,
}

/// Description of a 1D max pooling that also outputs the selected indices (`out_indices`).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesDescription {
    pub x: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub out: TensorDescription,
    pub out_indices: TensorDescription,
}

/// Description of the backward pass of a 1D max pooling, given `grad` and the forward
/// `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub indices: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub out: TensorDescription,
}

/// Description of a 2D max pooling of `x`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dDescription {
    pub x: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub out: TensorDescription,
}

/// Description of a 2D max pooling that also outputs the selected indices (`out_indices`).
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesDescription {
    pub x: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub out: TensorDescription,
    pub out_indices: TensorDescription,
}

/// Description of the backward pass of a 2D max pooling, given `grad` and the forward
/// `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub indices: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub out: TensorDescription,
}
/// Serializable mirror of [`InterpolateMode`]; converts in both directions via `From`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeDescription {
    Nearest,
    Bilinear,
    Bicubic,
}

/// Serializable mirror of [`InterpolateOptions`]; converts in both directions via `From`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsDescription {
    pub mode: InterpolateModeDescription,
}

/// Description of an interpolation of `x` to `output_size`, producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateDescription {
    pub x: TensorDescription,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsDescription,
    pub out: TensorDescription,
}
impl From<InterpolateModeDescription> for InterpolateMode {
fn from(val: InterpolateModeDescription) -> Self {
match val {
InterpolateModeDescription::Nearest => Self::Nearest,
InterpolateModeDescription::Bilinear => Self::Bilinear,
InterpolateModeDescription::Bicubic => Self::Bicubic,
}
}
}
impl From<InterpolateOptionsDescription> for InterpolateOptions {
fn from(val: InterpolateOptionsDescription) -> Self {
Self {
mode: val.mode.into(),
}
}
}
impl From<InterpolateMode> for InterpolateModeDescription {
fn from(val: InterpolateMode) -> Self {
match val {
InterpolateMode::Nearest => Self::Nearest,
InterpolateMode::Bilinear => Self::Bilinear,
InterpolateMode::Bicubic => Self::Bicubic,
}
}
}
impl From<InterpolateOptions> for InterpolateOptionsDescription {
fn from(val: InterpolateOptions) -> Self {
Self {
mode: val.mode.into(),
}
}
}
/// Description of the backward pass of an interpolation, given the gradient `grad`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsDescription,
    pub out: TensorDescription,
}
impl OperationDescription {
pub(crate) fn nodes(&self) -> Vec<&TensorDescription> {
match self {
OperationDescription::BaseFloat(ops) => ops.nodes(),
OperationDescription::BaseInt(ops) => ops.nodes(),
OperationDescription::BaseBool(ops) => ops.nodes(),
OperationDescription::NumericFloat(ops) => ops.nodes(),
OperationDescription::NumericInt(ops) => ops.nodes(),
OperationDescription::Bool(ops) => ops.nodes(),
OperationDescription::Int(ops) => ops.nodes(),
OperationDescription::Float(ops) => ops.nodes(),
OperationDescription::Module(ops) => ops.nodes(),
}
}
}
impl BaseOperationDescription {
    /// Returns references to every tensor description used by this operation.
    ///
    /// Inputs come first and the output (when the operation has a separate one) last.
    /// `ToDevice` carries a single description, which is the only node returned.
    fn nodes(&self) -> Vec<&TensorDescription> {
        match self {
            BaseOperationDescription::ToDevice(desc) => vec![desc],
            BaseOperationDescription::Reshape(desc) => {
                vec![&desc.input, &desc.out]
            }
            BaseOperationDescription::SwapDims(desc) => {
                vec![&desc.input, &desc.out]
            }
            BaseOperationDescription::Permute(desc) => {
                vec![&desc.input, &desc.out]
            }
            BaseOperationDescription::Expand(desc) => {
                vec![&desc.input, &desc.out]
            }
            BaseOperationDescription::Flip(desc) => {
                vec![&desc.input, &desc.out]
            }
            BaseOperationDescription::Slice(desc) => {
                vec![&desc.tensor, &desc.out]
            }
            BaseOperationDescription::SliceAssign(desc) => {
                vec![&desc.tensor, &desc.value, &desc.out]
            }
            BaseOperationDescription::Equal(desc) => {
                vec![&desc.lhs, &desc.rhs, &desc.out]
            }
            BaseOperationDescription::Repeat(desc) => {
                vec![&desc.tensor, &desc.out]
            }
            // The concatenated inputs, followed by the output like every other variant.
            BaseOperationDescription::Cat(desc) => desc
                .tensors
                .iter()
                .chain(std::iter::once(&desc.out))
                .collect(),
        }
    }
}
impl<E: Element> NumericOperationDescription<E> {
fn nodes(&self) -> Vec<&TensorDescription> {
match self {
NumericOperationDescription::Add(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::AddScalar(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Sub(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::SubScalar(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Mul(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::MulScalar(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Div(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::DivScalar(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Ones(desc) => vec![desc],
NumericOperationDescription::Gather(desc) => {
vec![&desc.tensor, &desc.indices, &desc.out]
}
NumericOperationDescription::Scatter(desc) => {
vec![&desc.tensor, &desc.indices, &desc.value, &desc.out]
}
NumericOperationDescription::Select(desc) => {
vec![&desc.tensor, &desc.indices, &desc.out]
}
NumericOperationDescription::SelectAssign(desc) => {
vec![&desc.tensor, &desc.indices, &desc.value, &desc.out]
}
NumericOperationDescription::MaskWhere(desc) => {
vec![&desc.tensor, &desc.mask, &desc.value, &desc.out]
}
NumericOperationDescription::MaskFill(desc) => {
vec![&desc.tensor, &desc.mask, &desc.out]
}
NumericOperationDescription::EqualElem(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::GreaterElem(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::GreaterEqualElem(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::LowerElem(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::LowerEqualElem(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Greater(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::GreaterEqual(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::Lower(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::LowerEqual(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
NumericOperationDescription::ArgMax(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::ArgMin(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Clamp(desc) => {
vec![&desc.tensor, &desc.out]
}
NumericOperationDescription::Abs(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::Zeros(desc) => vec![desc],
NumericOperationDescription::Full(desc) => vec![&desc.0],
NumericOperationDescription::MeanDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Mean(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::Sum(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::SumDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Prod(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::ProdDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::Max(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::MaxDimWithIndices(desc) => {
vec![&desc.tensor, &desc.out_indices, &desc.out]
}
NumericOperationDescription::MinDimWithIndices(desc) => {
vec![&desc.tensor, &desc.out_indices, &desc.out]
}
NumericOperationDescription::Min(desc) => {
vec![&desc.input, &desc.out]
}
NumericOperationDescription::MaxDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::MinDim(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOperationDescription::IntRandom(desc) => {
vec![&desc.out]
}
NumericOperationDescription::Powf(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
}
}
}
impl FloatOperationDescription {
fn nodes(&self) -> Vec<&TensorDescription> {
match self {
FloatOperationDescription::Matmul(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
FloatOperationDescription::Random(desc) => vec![&desc.out],
FloatOperationDescription::Exp(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Log(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Log1p(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Erf(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Recip(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::PowfScalar(desc) => vec![&desc.lhs, &desc.out],
FloatOperationDescription::Sqrt(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
}
}
}
impl IntOperationDescription {
fn nodes(&self) -> Vec<&TensorDescription> {
match self {
IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
}
}
}
impl BoolOperationDescription {
fn nodes(&self) -> Vec<&TensorDescription> {
match self {
BoolOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
BoolOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
BoolOperationDescription::Not(desc) => vec![&desc.input, &desc.out],
}
}
}
impl ModuleOperationDescription {
    /// Returns references to every tensor description used by this module operation.
    ///
    /// Embeddings list `weights` first. Convolutions list `x`, `weight`, the bias when present,
    /// then `out`. All other operations list `x` and `out` first, followed by any extra
    /// tensors (`out_indices`, `indices`, `grad`).
    fn nodes(&self) -> Vec<&TensorDescription> {
        match self {
            ModuleOperationDescription::Embedding(desc) => {
                vec![&desc.weights, &desc.indices, &desc.out]
            }
            ModuleOperationDescription::EmbeddingBackward(desc) => {
                vec![&desc.weights, &desc.out_grad, &desc.indices, &desc.out]
            }
            ModuleOperationDescription::Conv1d(desc) => {
                Self::conv_nodes(&desc.x, &desc.weight, desc.bias.as_ref(), &desc.out)
            }
            ModuleOperationDescription::Conv2d(desc) => {
                Self::conv_nodes(&desc.x, &desc.weight, desc.bias.as_ref(), &desc.out)
            }
            ModuleOperationDescription::ConvTranspose1d(desc) => {
                Self::conv_nodes(&desc.x, &desc.weight, desc.bias.as_ref(), &desc.out)
            }
            ModuleOperationDescription::ConvTranspose2d(desc) => {
                Self::conv_nodes(&desc.x, &desc.weight, desc.bias.as_ref(), &desc.out)
            }
            ModuleOperationDescription::AvgPool1d(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::AvgPool2d(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::AvgPool1dBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.grad]
            }
            ModuleOperationDescription::AvgPool2dBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.grad]
            }
            ModuleOperationDescription::AdaptiveAvgPool1d(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::AdaptiveAvgPool2d(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.grad]
            }
            ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.grad]
            }
            ModuleOperationDescription::MaxPool1d(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::MaxPool1dWithIndices(desc) => {
                vec![&desc.x, &desc.out, &desc.out_indices]
            }
            ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
            }
            ModuleOperationDescription::MaxPool2d(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::MaxPool2dWithIndices(desc) => {
                vec![&desc.x, &desc.out, &desc.out_indices]
            }
            ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
            }
            ModuleOperationDescription::Interpolate(desc) => {
                vec![&desc.x, &desc.out]
            }
            ModuleOperationDescription::InterpolateBackward(desc) => {
                vec![&desc.x, &desc.out, &desc.grad]
            }
        }
    }

    /// Collects the tensors of a (transposed) convolution in order: input, weight, optional
    /// bias, output.
    fn conv_nodes<'a>(
        x: &'a TensorDescription,
        weight: &'a TensorDescription,
        bias: Option<&'a TensorDescription>,
        out: &'a TensorDescription,
    ) -> Vec<&'a TensorDescription> {
        let mut nodes = Vec::with_capacity(4);
        nodes.push(x);
        nodes.push(weight);
        // `Option` is iterable: this pushes the bias only when one is present.
        nodes.extend(bias);
        nodes.push(out);
        nodes
    }
}
impl core::hash::Hash for RandomOperationDescription {
    /// Hashes the output description, then a tag identifying the kind of [`Distribution`].
    ///
    /// Distribution parameters are not hashed. Two random operations that differ only in
    /// their parameters therefore hash equally. This stays consistent with `PartialEq`,
    /// because equal values still always produce equal hashes.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // `.hash()` is a trait method; bring the trait into scope for the field types.
        use core::hash::Hash as _;

        let kind: u8 = match self.distribution {
            Distribution::Default => 1,
            Distribution::Bernoulli(_) => 2,
            Distribution::Uniform(_, _) => 3,
            Distribution::Normal(_, _) => 4,
        };

        self.out.hash(state);
        kind.hash(state);
    }
}
impl<E> core::hash::Hash for ScalarOperationDescription<E> {
    /// Hashes only the tensor descriptions (`lhs`, then `out`); the scalar `rhs` is skipped.
    ///
    /// Skipping fields is consistent with the derived `PartialEq`: equal values still hash
    /// equally. NOTE(review): presumably done so that operations differing only by scalar
    /// value share a hash; confirm against the users of this hash.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // `.hash()` is a trait method; bring the trait into scope for the field types.
        use core::hash::Hash as _;

        // Exhaustive destructuring: adding a field becomes a compile error here.
        let Self { lhs, rhs: _, out } = self;
        lhs.hash(state);
        out.hash(state);
    }
}

impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
    /// Hashes `tensor`, `mask`, then `out`; the scalar fill `value` is skipped.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use core::hash::Hash as _;

        let Self {
            tensor,
            mask,
            value: _,
            out,
        } = self;
        tensor.hash(state);
        mask.hash(state);
        out.hash(state);
    }
}

impl<E> core::hash::Hash for ClampOperationDescription<E> {
    /// Hashes `tensor`, then `out`; the scalar bounds `min` and `max` are skipped.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use core::hash::Hash as _;

        let Self {
            tensor,
            min: _,
            max: _,
            out,
        } = self;
        tensor.hash(state);
        out.hash(state);
    }
}
impl<E> core::hash::Hash for NumericOperationDescription<E> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
NumericOperationDescription::Add(desc) => desc.hash(state),
NumericOperationDescription::AddScalar(desc) => desc.hash(state),
NumericOperationDescription::Sub(desc) => desc.hash(state),
NumericOperationDescription::SubScalar(desc) => desc.hash(state),
NumericOperationDescription::Div(desc) => desc.hash(state),
NumericOperationDescription::DivScalar(desc) => desc.hash(state),
NumericOperationDescription::Mul(desc) => desc.hash(state),
NumericOperationDescription::MulScalar(desc) => desc.hash(state),
NumericOperationDescription::Abs(desc) => desc.hash(state),
NumericOperationDescription::Ones(desc) => desc.hash(state),
NumericOperationDescription::Zeros(desc) => desc.hash(state),
NumericOperationDescription::Full(desc) => desc.0.hash(state),
NumericOperationDescription::Gather(desc) => desc.hash(state),
NumericOperationDescription::Scatter(desc) => desc.hash(state),
NumericOperationDescription::Select(desc) => desc.hash(state),
NumericOperationDescription::SelectAssign(desc) => desc.hash(state),
NumericOperationDescription::MaskWhere(desc) => desc.hash(state),
NumericOperationDescription::MaskFill(desc) => desc.hash(state),
NumericOperationDescription::MeanDim(desc) => desc.hash(state),
NumericOperationDescription::Mean(desc) => desc.hash(state),
NumericOperationDescription::Sum(desc) => desc.hash(state),
NumericOperationDescription::SumDim(desc) => desc.hash(state),
NumericOperationDescription::Prod(desc) => desc.hash(state),
NumericOperationDescription::ProdDim(desc) => desc.hash(state),
NumericOperationDescription::EqualElem(desc) => desc.hash(state),
NumericOperationDescription::Greater(desc) => desc.hash(state),
NumericOperationDescription::GreaterElem(desc) => desc.hash(state),
NumericOperationDescription::GreaterEqual(desc) => desc.hash(state),
NumericOperationDescription::GreaterEqualElem(desc) => desc.hash(state),
NumericOperationDescription::Lower(desc) => desc.hash(state),
NumericOperationDescription::LowerElem(desc) => desc.hash(state),
NumericOperationDescription::LowerEqual(desc) => desc.hash(state),
NumericOperationDescription::LowerEqualElem(desc) => desc.hash(state),
NumericOperationDescription::ArgMax(desc) => desc.hash(state),
NumericOperationDescription::ArgMin(desc) => desc.hash(state),
NumericOperationDescription::Max(desc) => desc.hash(state),
NumericOperationDescription::MaxDimWithIndices(desc) => desc.hash(state),
NumericOperationDescription::MinDimWithIndices(desc) => desc.hash(state),
NumericOperationDescription::Min(desc) => desc.hash(state),
NumericOperationDescription::MaxDim(desc) => desc.hash(state),
NumericOperationDescription::MinDim(desc) => desc.hash(state),
NumericOperationDescription::Clamp(desc) => desc.hash(state),
NumericOperationDescription::IntRandom(desc) => desc.hash(state),
NumericOperationDescription::Powf(desc) => desc.hash(state),
}
}
}