use bb_ir::types::{
common_relations::{
BROADCAST_BINARY, ELEMENTWISE, MATMUL_BINARY, NO_RELATIONS, REDUCE_AXIS, UNARY_SAME_ELEMENT,
},
relations::TypeRelation,
};
use bb_runtime::atomic::{AtomicOpDecl, AtomicOpKind, AtomicOpsetDecl};
/// Domain string used by [`ONNX_V1_OPSET`].
pub const ONNX_DOMAIN: &str = "ai.onnx";
/// Version number used by [`ONNX_V1_OPSET`].
pub const ONNX_VERSION: i64 = 1;
/// Version number used by [`EXTENSION_OPSET`].
pub const EXTENSION_VERSION: i64 = 1;
/// Domain string used by [`EXTENSION_OPSET`].
// NOTE(review): this is the same string as `ONNX_DOMAIN`, and both versions are 1,
// so the two opsets declare an identical (domain, version) pair. Confirm that this
// is intentional and that the extension ops are not meant to live in their own domain.
pub const EXTENSION_DOMAIN: &str = "ai.onnx";
pub static PRIMITIVE_OPS: &[AtomicOpDecl] = &[
op("Add", BROADCAST_BINARY),
op("Sub", BROADCAST_BINARY),
op("Mul", BROADCAST_BINARY),
op("Div", BROADCAST_BINARY),
op("Neg", ELEMENTWISE),
op("Abs", ELEMENTWISE),
op("Sqrt", ELEMENTWISE),
op("Pow", BROADCAST_BINARY),
op("Exp", ELEMENTWISE),
op("Log", ELEMENTWISE),
op("MatMul", MATMUL_BINARY),
op("ReduceSum", REDUCE_AXIS),
op("ReduceMean", REDUCE_AXIS),
op("ReduceMax", REDUCE_AXIS),
op("ReduceMin", REDUCE_AXIS),
op("Reshape", UNARY_SAME_ELEMENT),
op("Transpose", UNARY_SAME_ELEMENT),
op("Concat", NO_RELATIONS),
op("Slice", UNARY_SAME_ELEMENT),
op("Split", NO_RELATIONS),
op("Squeeze", UNARY_SAME_ELEMENT),
op("Unsqueeze", UNARY_SAME_ELEMENT),
op("Identity", ELEMENTWISE),
op("Cast", NO_RELATIONS),
op("Equal", NO_RELATIONS),
op("Greater", NO_RELATIONS),
op("Less", NO_RELATIONS),
op("Where", NO_RELATIONS),
op("Constant", NO_RELATIONS),
op("Gather", NO_RELATIONS),
];
/// Operator declarations exposed through [`EXTENSION_OPSET`].
///
/// Each entry pairs an operator name with the type-relation set it is declared
/// with. Entry order is part of the table and must be kept stable.
pub static EXTENSION_OPS: &[AtomicOpDecl] = &[
    // Activation-style operators.
    // NOTE(review): `Softmax` (and `GlobalAveragePool` below) are tagged
    // `ELEMENTWISE`; confirm that relation set is the intended one for them.
    op("Relu", ELEMENTWISE),
    op("Sigmoid", ELEMENTWISE),
    op("Tanh", ELEMENTWISE),
    op("Softmax", ELEMENTWISE),
    op("LeakyRelu", ELEMENTWISE),
    op("Gelu", ELEMENTWISE),
    // Product-style operators.
    op("Dot", MATMUL_BINARY),
    op("Gemm", MATMUL_BINARY),
    // Operators declared without any type relations.
    op("Zeros", NO_RELATIONS),
    op("Ones", NO_RELATIONS),
    // Pooling.
    op("GlobalAveragePool", ELEMENTWISE),
];
/// Opset declaration bundling [`PRIMITIVE_OPS`] under [`ONNX_DOMAIN`] at
/// [`ONNX_VERSION`].
pub const ONNX_V1_OPSET: AtomicOpsetDecl = AtomicOpsetDecl {
    ops: PRIMITIVE_OPS,
    domain: ONNX_DOMAIN,
    version: ONNX_VERSION,
};
/// Opset declaration bundling [`EXTENSION_OPS`] under [`EXTENSION_DOMAIN`] at
/// [`EXTENSION_VERSION`].
pub const EXTENSION_OPSET: AtomicOpsetDecl = AtomicOpsetDecl {
    ops: EXTENSION_OPS,
    domain: EXTENSION_DOMAIN,
    version: EXTENSION_VERSION,
};
/// Builds the [`AtomicOpDecl`] for one table entry.
///
/// `name` and `type_relations` are stored as given. Every declaration built
/// here has empty `inputs` and `outputs` slices and is marked
/// [`AtomicOpKind::Immediate`].
///
/// This is a `const fn` so it can be called from the `static` operator tables.
const fn op(name: &'static str, type_relations: &'static [TypeRelation]) -> AtomicOpDecl {
    AtomicOpDecl {
        // Caller-supplied parts of the declaration.
        name,
        type_relations,
        // Fixed defaults shared by every operator in this file.
        kind: AtomicOpKind::Immediate,
        inputs: &[],
        outputs: &[],
    }
}