use onnx_ir::{Argument, Node};
use proc_macro2::TokenStream;
use super::node_traits::NodeCodegen;
use crate::burn::{BurnImports, Field};
use burn_store::TensorSnapshot;
/// Generates `impl NodeCodegen for Node` by forwarding every trait method to
/// the payload of each listed variant.
///
/// Input: a comma-separated list of `Node` variant identifiers (a trailing
/// comma is accepted). Each variant is matched as a single-field tuple
/// variant and the method is called on that field.
///
/// Behavior for variants that are NOT listed:
/// - `inputs`, `outputs`, `forward`: panic with a message containing the
///   `Debug` rendering of the node.
/// - `field`: returns `None`.
/// - `register_imports`: does nothing.
/// - `collect_snapshots`: returns an empty `Vec`.
macro_rules! impl_node_codegen_dispatch {
    ($($op:ident),* $(,)?) => {
        impl NodeCodegen for Node {
            // --- Methods without a sensible default: unlisted variants panic. ---

            fn inputs(&self) -> &[Argument] {
                match self {
                    $(Self::$op(inner) => inner.inputs(),)*
                    unsupported => {
                        panic!("Unsupported node type for inputs: {:?}", unsupported)
                    }
                }
            }

            fn outputs(&self) -> &[Argument] {
                match self {
                    $(Self::$op(inner) => inner.outputs(),)*
                    unsupported => {
                        panic!("Unsupported node type for outputs: {:?}", unsupported)
                    }
                }
            }

            fn forward(
                &self,
                scope: &mut crate::burn::scope::ScopeAtPosition<'_>,
            ) -> TokenStream {
                match self {
                    $(Self::$op(inner) => inner.forward(scope),)*
                    unsupported => {
                        panic!("Unsupported node type for forward: {:?}", unsupported)
                    }
                }
            }

            // --- Methods with a neutral default: unlisted variants contribute nothing. ---

            fn field(&self) -> Option<Field> {
                match self {
                    $(Self::$op(inner) => inner.field(),)*
                    _ => None,
                }
            }

            fn register_imports(&self, imports: &mut BurnImports) {
                match self {
                    $(Self::$op(inner) => inner.register_imports(imports),)*
                    _ => (),
                }
            }

            fn collect_snapshots(&self, field_name: &str) -> Vec<TensorSnapshot> {
                match self {
                    $(Self::$op(inner) => inner.collect_snapshots(field_name),)*
                    _ => Vec::new(),
                }
            }
        }
    };
}
// `Node` variants that get forwarding arms in the generated `NodeCodegen` impl.
// The group headings below are informal reading aids inferred from the variant
// names; they have no effect on the generated code.
impl_node_codegen_dispatch! {
    // Arithmetic and matrix products
    Add, Sub, Mul, Div, Max, Min, MatMul, Einsum,

    // Comparisons
    Equal, Greater, GreaterOrEqual, Less, LessOrEqual,

    // Logical
    And, Or, Xor,

    // Element-wise unary
    Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh,
    Ceil, Cos, Cosh, Erf, Exp, Floor,
    Identity, Log, Neg, Not, Reciprocal, Round,
    Sigmoid, Sign, Sin, Sinh, Sqrt, Tan, Tanh,

    // Activations
    Relu, Gelu, Mish, LeakyRelu, HardSigmoid, HardSwish,
    Softmax, LogSoftmax, PRelu, Celu, Elu, Selu,
    Softplus, Softsign, ThresholdedRelu, Swish, Hardmax, Shrink,

    // Shape manipulation and indexing
    Reshape, Flatten, Squeeze, Unsqueeze, Transpose, Shape, Size,
    Concat, Split, Slice,
    Gather, GatherElements, GatherND, ScatterElements, ScatterND,
    Tile, Expand, Pad,

    // Convolutions
    Conv1d, Conv2d, Conv3d,
    ConvTranspose1d, ConvTranspose2d, ConvTranspose3d,
    DeformConv, Col2Im,

    // Pooling
    AveragePool1d, AveragePool2d, LpPool1d, LpPool2d,
    MaxPool1d, MaxPool2d, GlobalAveragePool,

    // Normalization
    BatchNormalization, LayerNormalization, Lrn,
    GroupNormalization, InstanceNormalization,

    // Miscellaneous tensor ops
    Cast, CastLike, Clip, CumSum, Dropout, Where,
    ArgMax, ArgMin, TopK, NonZero, OneHot,
    Pow, Mod, Trilu,

    // Bitwise
    BitShift, BitwiseAnd, BitwiseOr, BitwiseXor, BitwiseNot,

    // Sum/mean, linear layers, quantization
    Sum, Mean, Gemm, Linear, MatMulInteger,
    DequantizeLinear, QuantizeLinear,

    // Constants and generators
    Constant, ConstantOfShape, EyeLike, Range,
    RandomNormal, RandomUniform, RandomNormalLike, RandomUniformLike,
    Bernoulli,

    // Other
    DepthToSpace, SpaceToDepth, Resize, GridSample,
    Det, IsInf, IsNaN, Attention,

    // Control flow
    If, Loop, Scan,

    // Recurrent
    Lstm, Rnn, Gru,

    // Reductions
    ReduceMax, ReduceMin, ReduceMean, ReduceProd,
    ReduceSum, ReduceSumSquare, ReduceL1, ReduceL2,
    ReduceLogSum, ReduceLogSumExp,
}