bb_ops/backends/cpu/opset.rs
1//! Opset declaration for the `CpuBackend`.
2//!
//! The backend's `atomic_opset` mirrors
//! `bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS` exactly — the
//! 30-op primitive floor every `Backend` impl must declare. The
//! ops that aren't primitives but ARE backed by ndarray kernels in
//! this crate (Relu, Sigmoid, Tanh, Softmax, LeakyRelu, Gelu, Dot,
//! Gemm, Zeros, Ones, GlobalAveragePool) get listed via
//! `extension_opsets()`. Entries that the prior 49-op declaration
//! carried without a backing kernel (BatchNorm, LayerNorm, Conv,
//! MaxPool, AveragePool, Scatter, If, Loop) have been dropped.
12//!
13//! `BackendSubgraph` is the framework's collapse-carrier op; it
14//! lives in `ai.bytesandbrains.framework` and routes through
15//! `invoke_backend_subgraph`, not this opset.
16//!
17//! Each entry carries `type_relations` so the TypeSolver narrows
18//! the participating values' TypeNodes. The canonical relation
19//! slices live in `bb_ir::types::common_relations`.
20
21use bb_ir::types::{
22 common_relations::{
23 BROADCAST_BINARY, ELEMENTWISE, MATMUL_BINARY, NO_RELATIONS, REDUCE_AXIS, UNARY_SAME_ELEMENT,
24 },
25 relations::TypeRelation,
26};
27use bb_runtime::atomic::{AtomicOpDecl, AtomicOpKind, AtomicOpsetDecl};
28
/// Domain string of the `ai.onnx` opset, used by the primitive-floor
/// declaration in this file.
pub const ONNX_DOMAIN: &str = "ai.onnx";
31
/// Version number carried by the primitive-floor opset declaration.
pub const ONNX_VERSION: i64 = 1;
34
/// Version number carried by the backend-shipped extension opset.
///
/// Kept as its own constant (rather than reusing `ONNX_VERSION`) so
/// that bumping either the primitive floor or the extension set does
/// not force a bump of the other. Both are `1` today.
pub const EXTENSION_VERSION: i64 = 1;
38
/// Domain string under which the CpuBackend's extension opset is
/// declared.
///
/// This is the same canonical `ai.onnx` namespace as `ONNX_DOMAIN`.
/// The activation, linear-algebra, creation and pooling extras live
/// in their own opset declaration so the floor and the extras stay
/// inspectable as separate declarations.
///
/// NOTE(review): earlier wording said the extension opset is told
/// apart from the floor by a "different version", but
/// `EXTENSION_VERSION` and `ONNX_VERSION` are both `1` today, so the
/// two declarations currently share the same (domain, version) pair.
/// Confirm that nothing downstream keys opsets on that pair alone.
pub const EXTENSION_DOMAIN: &str = "ai.onnx";
44
45/// 30-entry primitive-floor opset returned by
46/// `BackendRuntime::atomic_opset`. Matches
47/// `bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS` element-for-
48/// element.
49pub static PRIMITIVE_OPS: &[AtomicOpDecl] = &[
50 // Arithmetic (6) — broadcast binary on same element type.
51 op("Add", BROADCAST_BINARY),
52 op("Sub", BROADCAST_BINARY),
53 op("Mul", BROADCAST_BINARY),
54 op("Div", BROADCAST_BINARY),
55 op("Neg", ELEMENTWISE),
56 op("Abs", ELEMENTWISE),
57 // Math (4) — Sqrt/Exp/Log are elementwise; Pow is broadcast binary.
58 op("Sqrt", ELEMENTWISE),
59 op("Pow", BROADCAST_BINARY),
60 op("Exp", ELEMENTWISE),
61 op("Log", ELEMENTWISE),
62 // Linear algebra (1) — same element type; shape is matmul-specific.
63 op("MatMul", MATMUL_BINARY),
64 // Reductions (4).
65 op("ReduceSum", REDUCE_AXIS),
66 op("ReduceMean", REDUCE_AXIS),
67 op("ReduceMax", REDUCE_AXIS),
68 op("ReduceMin", REDUCE_AXIS),
69 // Shape (9) — Reshape/Transpose/Slice/Squeeze/Unsqueeze preserve
70 // element type, change shape; Identity is a true pass-through;
71 // Concat/Split are variadic + Cast is attribute-driven, all left
72 // unconstrained until a Custom relation lands.
73 op("Reshape", UNARY_SAME_ELEMENT),
74 op("Transpose", UNARY_SAME_ELEMENT),
75 op("Concat", NO_RELATIONS),
76 op("Slice", UNARY_SAME_ELEMENT),
77 op("Split", NO_RELATIONS),
78 op("Squeeze", UNARY_SAME_ELEMENT),
79 op("Unsqueeze", UNARY_SAME_ELEMENT),
80 op("Identity", ELEMENTWISE),
81 op("Cast", NO_RELATIONS),
82 // Comparison (3) — element-wise inputs share type; output is
83 // bool. Leave unconstrained until the lattice ships a `bool`
84 // tensor leaf.
85 op("Equal", NO_RELATIONS),
86 op("Greater", NO_RELATIONS),
87 op("Less", NO_RELATIONS),
88 // Conditional (1).
89 op("Where", NO_RELATIONS),
90 // Creation (1) — value comes from an embedded `TensorProto`
91 // attribute, so the type is attribute-driven.
92 op("Constant", NO_RELATIONS),
93 // Indexing (1) — Gather mixes tensor + index types.
94 op("Gather", NO_RELATIONS),
95];
96
97/// Non-primitive ops the CpuBackend backs with ndarray kernels.
98/// Surfaces via `BackendRuntime::extension_opsets()` so the
99/// install-time check classifies them correctly (they're NOT in
100/// the primitive floor; users who bind a different backend may
101/// not get them).
102pub static EXTENSION_OPS: &[AtomicOpDecl] = &[
103 // Activations — pure element-wise; element type + shape
104 // preserved.
105 op("Relu", ELEMENTWISE),
106 op("Sigmoid", ELEMENTWISE),
107 op("Tanh", ELEMENTWISE),
108 op("Softmax", ELEMENTWISE),
109 op("LeakyRelu", ELEMENTWISE),
110 op("Gelu", ELEMENTWISE),
111 // Linear algebra extras — same element type across operands;
112 // shape is matmul-specific. `Gemm` takes an optional bias
113 // (3-input variadic) so its element-type relation is captured
114 // by the 2-operand `MATMUL_BINARY` and the optional `c` falls
115 // out of the constraint until a Custom relation lands.
116 op("Dot", MATMUL_BINARY),
117 op("Gemm", MATMUL_BINARY),
118 // Creation extras — attribute-driven shape; element type
119 // determined by the `dtype` attribute (defaulted to f32 today).
120 op("Zeros", NO_RELATIONS),
121 op("Ones", NO_RELATIONS),
122 // Pooling — element type preserved; spatial dims collapse, but
123 // the element-type constraint holds via `ELEMENTWISE`.
124 op("GlobalAveragePool", ELEMENTWISE),
125];
126
127/// Primitive-floor opset declaration. Returned by
128/// `CpuBackend::atomic_opset()`.
129pub const ONNX_V1_OPSET: AtomicOpsetDecl = AtomicOpsetDecl {
130 domain: ONNX_DOMAIN,
131 version: ONNX_VERSION,
132 ops: PRIMITIVE_OPS,
133};
134
135/// Extension opset declaration. Returned alongside the primitive
136/// floor by `CpuBackend::extension_opsets()`.
137pub const EXTENSION_OPSET: AtomicOpsetDecl = AtomicOpsetDecl {
138 domain: EXTENSION_DOMAIN,
139 version: EXTENSION_VERSION,
140 ops: EXTENSION_OPS,
141};
142
143const fn op(name: &'static str, type_relations: &'static [TypeRelation]) -> AtomicOpDecl {
144 AtomicOpDecl {
145 name,
146 inputs: &[],
147 outputs: &[],
148 kind: AtomicOpKind::Immediate,
149 type_relations,
150 }
151}
152