Skip to main content

bb_ops/backends/cpu/
opset.rs

1//! Opset declaration for the `CpuBackend`.
2//!
3//! The backend's `atomic_opset` mirrors
4//! `bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS` exactly — the
//! 30-op primitive floor every `Backend` impl must declare. The
//! ops that aren't primitives but ARE backed by ndarray kernels in
//! this crate (Relu, Sigmoid, Tanh, Softmax, LeakyRelu, Gelu, Dot,
//! Gemm, Zeros, Ones, GlobalAveragePool) are listed via
//! `extension_opsets()`. Entries the prior 49-op declaration
//! falsely advertised (BatchNorm, LayerNorm, Conv, MaxPool,
//! AveragePool, Scatter, If, Loop) are dropped — they have no kernel.
12//!
13//! `BackendSubgraph` is the framework's collapse-carrier op; it
14//! lives in `ai.bytesandbrains.framework` and routes through
15//! `invoke_backend_subgraph`, not this opset.
16//!
17//! Each entry carries `type_relations` so the TypeSolver narrows
18//! the participating values' TypeNodes. The canonical relation
19//! slices live in `bb_ir::types::common_relations`.
20
21use bb_ir::types::{
22    common_relations::{
23        BROADCAST_BINARY, ELEMENTWISE, MATMUL_BINARY, NO_RELATIONS, REDUCE_AXIS, UNARY_SAME_ELEMENT,
24    },
25    relations::TypeRelation,
26};
27use bb_runtime::atomic::{AtomicOpDecl, AtomicOpKind, AtomicOpsetDecl};
28
/// `ai.onnx` opset domain — primitives + extension ops.
pub const ONNX_DOMAIN: &str = "ai.onnx";

/// Primitive-floor opset version.
pub const ONNX_VERSION: i64 = 1;

/// Backend-shipped extension version.
///
/// Kept as its own constant, separate from [`ONNX_VERSION`], so a future
/// opset bump on either side stays independent. The two currently hold
/// the same value (1).
pub const EXTENSION_VERSION: i64 = 1;

/// Opset domain for the extras the CpuBackend ships beyond the primitive
/// floor: activations, `Dot`/`Gemm`, `Zeros`/`Ones` and
/// `GlobalAveragePool`.
///
/// This is the same canonical `ai.onnx` namespace as [`ONNX_DOMAIN`].
/// Because [`EXTENSION_VERSION`] also equals [`ONNX_VERSION`] today, the
/// floor and the extras are NOT distinguishable by `(domain, version)`.
/// They stay inspectable only because they are exposed as two separate
/// `AtomicOpsetDecl` values.
///
/// NOTE(review): if any consumer keys opsets by `(domain, version)`, the
/// two declarations collide — confirm, or bump one of the versions.
pub const EXTENSION_DOMAIN: &str = ONNX_DOMAIN;
44
45/// 30-entry primitive-floor opset returned by
46/// `BackendRuntime::atomic_opset`. Matches
47/// `bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS` element-for-
48/// element.
49pub static PRIMITIVE_OPS: &[AtomicOpDecl] = &[
50    // Arithmetic (6) — broadcast binary on same element type.
51    op("Add", BROADCAST_BINARY),
52    op("Sub", BROADCAST_BINARY),
53    op("Mul", BROADCAST_BINARY),
54    op("Div", BROADCAST_BINARY),
55    op("Neg", ELEMENTWISE),
56    op("Abs", ELEMENTWISE),
57    // Math (4) — Sqrt/Exp/Log are elementwise; Pow is broadcast binary.
58    op("Sqrt", ELEMENTWISE),
59    op("Pow", BROADCAST_BINARY),
60    op("Exp", ELEMENTWISE),
61    op("Log", ELEMENTWISE),
62    // Linear algebra (1) — same element type; shape is matmul-specific.
63    op("MatMul", MATMUL_BINARY),
64    // Reductions (4).
65    op("ReduceSum", REDUCE_AXIS),
66    op("ReduceMean", REDUCE_AXIS),
67    op("ReduceMax", REDUCE_AXIS),
68    op("ReduceMin", REDUCE_AXIS),
69    // Shape (9) — Reshape/Transpose/Slice/Squeeze/Unsqueeze preserve
70    // element type, change shape; Identity is a true pass-through;
71    // Concat/Split are variadic + Cast is attribute-driven, all left
72    // unconstrained until a Custom relation lands.
73    op("Reshape", UNARY_SAME_ELEMENT),
74    op("Transpose", UNARY_SAME_ELEMENT),
75    op("Concat", NO_RELATIONS),
76    op("Slice", UNARY_SAME_ELEMENT),
77    op("Split", NO_RELATIONS),
78    op("Squeeze", UNARY_SAME_ELEMENT),
79    op("Unsqueeze", UNARY_SAME_ELEMENT),
80    op("Identity", ELEMENTWISE),
81    op("Cast", NO_RELATIONS),
82    // Comparison (3) — element-wise inputs share type; output is
83    // bool. Leave unconstrained until the lattice ships a `bool`
84    // tensor leaf.
85    op("Equal", NO_RELATIONS),
86    op("Greater", NO_RELATIONS),
87    op("Less", NO_RELATIONS),
88    // Conditional (1).
89    op("Where", NO_RELATIONS),
90    // Creation (1) — value comes from an embedded `TensorProto`
91    // attribute, so the type is attribute-driven.
92    op("Constant", NO_RELATIONS),
93    // Indexing (1) — Gather mixes tensor + index types.
94    op("Gather", NO_RELATIONS),
95];
96
97/// Non-primitive ops the CpuBackend backs with ndarray kernels.
98/// Surfaces via `BackendRuntime::extension_opsets()` so the
99/// install-time check classifies them correctly (they're NOT in
100/// the primitive floor; users who bind a different backend may
101/// not get them).
102pub static EXTENSION_OPS: &[AtomicOpDecl] = &[
103    // Activations — pure element-wise; element type + shape
104    // preserved.
105    op("Relu", ELEMENTWISE),
106    op("Sigmoid", ELEMENTWISE),
107    op("Tanh", ELEMENTWISE),
108    op("Softmax", ELEMENTWISE),
109    op("LeakyRelu", ELEMENTWISE),
110    op("Gelu", ELEMENTWISE),
111    // Linear algebra extras — same element type across operands;
112    // shape is matmul-specific. `Gemm` takes an optional bias
113    // (3-input variadic) so its element-type relation is captured
114    // by the 2-operand `MATMUL_BINARY` and the optional `c` falls
115    // out of the constraint until a Custom relation lands.
116    op("Dot", MATMUL_BINARY),
117    op("Gemm", MATMUL_BINARY),
118    // Creation extras — attribute-driven shape; element type
119    // determined by the `dtype` attribute (defaulted to f32 today).
120    op("Zeros", NO_RELATIONS),
121    op("Ones", NO_RELATIONS),
122    // Pooling — element type preserved; spatial dims collapse, but
123    // the element-type constraint holds via `ELEMENTWISE`.
124    op("GlobalAveragePool", ELEMENTWISE),
125];
126
127/// Primitive-floor opset declaration. Returned by
128/// `CpuBackend::atomic_opset()`.
129pub const ONNX_V1_OPSET: AtomicOpsetDecl = AtomicOpsetDecl {
130    domain: ONNX_DOMAIN,
131    version: ONNX_VERSION,
132    ops: PRIMITIVE_OPS,
133};
134
135/// Extension opset declaration. Returned alongside the primitive
136/// floor by `CpuBackend::extension_opsets()`.
137pub const EXTENSION_OPSET: AtomicOpsetDecl = AtomicOpsetDecl {
138    domain: EXTENSION_DOMAIN,
139    version: EXTENSION_VERSION,
140    ops: EXTENSION_OPS,
141};
142
143const fn op(name: &'static str, type_relations: &'static [TypeRelation]) -> AtomicOpDecl {
144    AtomicOpDecl {
145        name,
146        inputs: &[],
147        outputs: &[],
148        kind: AtomicOpKind::Immediate,
149        type_relations,
150    }
151}
152