bb_ir/tensor_primitives.rs
1//! The framework's curated floor of primitive tensor ops.
2//!
3//! Every `Backend` impl MUST declare an `atomic_opset()` whose ops
4//! list contains every entry in [`TENSOR_PRIMITIVES_OPS`]. Ops not
5//! in this set (Relu, Sigmoid, Tanh, Softmax, LeakyRelu, Gelu,
6//! Conv, MaxPool, AveragePool, BatchNormalization,
7//! LayerNormalization, …) are **extensions**: a backend MAY
8//! support them via `extension_opsets()`; a graph using them
9//! either binds to a backend that declares them OR (future work)
10//! a lowering pass decomposes them into primitives.
11//!
12//! Naming rationale — the op-types here live in the `ai.onnx`
13//! domain because that's where `Add`, `MatMul`, `Reshape`, etc.
14//! are canonically named. The framework deliberately avoids
15//! `ONNX_V1_*` / `onnx_v1` identifiers anywhere so users don't
16//! read this floor as a claim to implement the ONNX v1
17//! specification — the floor is OUR curation of primitives, not
18//! the formal ONNX v1 catalog.
19
20use crate::atomic::AtomicOpsetDecl;
21
/// Opset domain under which the primitive tensor ops are named.
///
/// This is the same `"ai.onnx"` domain string the upstream ONNX project
/// uses for its op-type catalog; reusing the string is about canonical
/// op naming and is not a claim of ONNX specification conformance.
pub const TENSOR_PRIMITIVES_DOMAIN: &str = "ai.onnx";

/// Revision number of the primitive-tensor floor defined in this module.
/// Increment it whenever the op set changes meaningfully.
pub const TENSOR_PRIMITIVES_VERSION: i64 = 1;
29
/// The floor of 30 primitive tensor op-types that every `Backend` impl
/// MUST declare in its `atomic_opset()`.
///
/// Per-category counts: arithmetic 6, math 4, linear algebra 1,
/// reductions 4, shape 9, comparison 3, conditional 1, creation 1,
/// indexing 1, for 30 in total.
///
/// Declaration order is observable: [`opset_covers_primitives`] reports
/// absent ops in exactly this order.
// One row per category keeps the per-category counts easy to audit;
// rustfmt would otherwise explode the list to one entry per line.
#[rustfmt::skip]
pub const TENSOR_PRIMITIVES_OPS: &[&str] = &[
    // Arithmetic (6)
    "Add", "Sub", "Mul", "Div", "Neg", "Abs",
    // Math (4)
    "Sqrt", "Pow", "Exp", "Log",
    // Linear algebra (1)
    "MatMul",
    // Reductions (4)
    "ReduceSum", "ReduceMean", "ReduceMax", "ReduceMin",
    // Shape (9)
    "Reshape", "Transpose", "Concat", "Slice", "Split",
    "Squeeze", "Unsqueeze", "Identity", "Cast",
    // Comparison (3)
    "Equal", "Greater", "Less",
    // Conditional (1)
    "Where",
    // Creation (1)
    "Constant",
    // Indexing (1)
    "Gather",
];
76
/// Error produced by [`opset_covers_primitives`] when the inspected
/// opset does not declare every op in [`TENSOR_PRIMITIVES_OPS`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingPrimitives {
    /// Domain string of the inspected opset (`atomic_opset().domain`).
    pub backend_domain: &'static str,
    /// Version of the inspected opset (`atomic_opset().version`).
    pub backend_version: i64,
    /// Names of the primitive ops the opset lacks, listed in the same
    /// order as [`TENSOR_PRIMITIVES_OPS`].
    pub missing: Vec<&'static str>,
}

impl std::fmt::Display for MissingPrimitives {
    /// Renders e.g. `backend opset ai.onnx@v1 missing 2 primitive op(s): Add, Gather`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.missing.len();
        write!(
            f,
            "backend opset {}@v{} missing {} primitive op(s): ",
            self.backend_domain, self.backend_version, count,
        )?;
        // Stream the names with a ", " separator instead of building an
        // intermediate joined String.
        let mut separator = "";
        for op in &self.missing {
            f.write_str(separator)?;
            f.write_str(op)?;
            separator = ", ";
        }
        Ok(())
    }
}

impl std::error::Error for MissingPrimitives {}
104
105/// Confirm `opset` declares every primitive in
106/// [`TENSOR_PRIMITIVES_OPS`]. Returns the list of missing names so
107/// the caller surfaces a typed error instead of a flat boolean.
108/// Ops in opsets with a non-`TENSOR_PRIMITIVES_DOMAIN` domain
109/// don't count toward the check — primitives are sourced from the
110/// canonical `ai.onnx` namespace.
111pub fn opset_covers_primitives(opset: &AtomicOpsetDecl) -> Result<(), MissingPrimitives> {
112 // Collect the backend's declared op names from the primitives
113 // domain. Backends layer their non-primitive ops via
114 // `extension_opsets()`; those aren't relevant here.
115 let declared: std::collections::HashSet<&str> = if opset.domain == TENSOR_PRIMITIVES_DOMAIN {
116 opset.ops.iter().map(|o| o.name).collect()
117 } else {
118 std::collections::HashSet::new()
119 };
120
121 let missing: Vec<&'static str> = TENSOR_PRIMITIVES_OPS
122 .iter()
123 .copied()
124 .filter(|name| !declared.contains(*name))
125 .collect();
126
127 if missing.is_empty() {
128 Ok(())
129 } else {
130 Err(MissingPrimitives {
131 backend_domain: opset.domain,
132 backend_version: opset.version,
133 missing,
134 })
135 }
136}
137