Skip to main content

bb_runtime/contracts/
backend.rs

1//! `bb::Backend` — Contract trait for tensor compute backends.
2//!
3//! The Contract has THREE surfaces, exposed side-by-side:
4//!
5//! 1. **One typed method per mandatory primitive op** (the 30
6//!    entries in [`bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS`]):
7//!    `add`, `mul`, `matmul`, `reduce_sum`, `reshape`, …. Components
8//!    reach these inline through `ctx.backends` for short-form
9//!    tensor math — an Index's distance kernel calls
10//!    `backend.matmul(&query, &vectors)?` instead of hand-rolling
11//!    a loop.
12//!
13//! 2. **One method to execute a subgraph**:
14//!    `execute(&GraphProto, HashMap<String, Tensor>, BackendAttrs) →
15//!    HashMap<String, Tensor>`. Backends that prefer whole-graph
16//!    dispatch override this entry point; the per-op defaults call
17//!    through to it via a one-node `GraphProto`.
18//!
19//! 3. **One dispatch entry point for `BackendSubgraph` carriers**:
20//!    `dispatch(&GraphProto, inputs, attrs, completion) →
21//!    ContractResponse`. The engine calls this for every
22//!    `BackendSubgraph` carrier op. The default falls through to
23//!    `execute` synchronously. Backends with per-subgraph caching,
24//!    JIT compilation, or async device execution override `dispatch`
25//!    to return `ContractResponse::Later` while device work runs.
26//!
//! ### How the per-op and whole-graph sides compose
28//!
29//! Default impls in [`crate::contracts::backend_default_walk`]
30//! bridge the surfaces so a backend author overrides only the side
31//! that's natural for their target:
32//!
33//! - **CpuBackend** overrides the 30 per-op methods directly
34//!   (`add` runs ndarray's `Add` impl, `matmul` runs `dot`, …).
35//!   It does NOT override `execute` — the default walker uses
36//!   the overridden per-op methods.
37//!
38//! - **A Burn-style backend** overrides `execute` natively (Burn
39//!   compiles the whole `GraphProto` to its own IR + runs once).
40//!   It does NOT override per-op methods — they default-wrap a
41//!   one-node `GraphProto` and call `execute`.
42//!
43//! Backends overriding *neither* side stack-overflow on the first
44//! call: every per-op default wraps into `execute`, whose default
45//! walks back to per-op, ad infinitum. Backends MUST override at
46//! least one side.
47//!
48//! ### Extension ops
49//!
50//! Activation functions (Relu, Sigmoid, Softmax), pooling (MaxPool,
51//! AveragePool), normalization (BatchNormalization, LayerNorm),
52//! Conv, and so on are NOT on the Contract surface. They're
53//! extensions — a backend MAY declare them via
54//! [`crate::roles::BackendRuntime::extension_opsets`] and handle
55//! them through its own `execute` override; OR a future lowering
56//! pass decomposes them into primitives so the Contract surface
57//! covers any graph.
58
59use std::collections::HashMap;
60
61use bb_ir::proto::onnx::{AttributeProto, GraphProto, StringStringEntryProto, TensorProto};
62
63use crate::completion::{CompletionHandle, ContractResponse};
64use crate::contracts::backend_default_walk;
65
66/// Per-call NodeProto context surfaced to `Backend::execute` so
67/// kernels overriding the whole-graph path see the original call
68/// site's attributes + metadata alongside the body. Per-op
69/// methods (which receive their attributes positionally as typed
70/// args) don't need this struct.
71pub struct BackendAttrs<'a> {
72    /// Attribute list from the call NodeProto's `attribute` field.
73    pub current_node_attributes: &'a [AttributeProto],
74    /// `metadata_props` from the call NodeProto.
75    pub current_node_metadata: &'a [StringStringEntryProto],
76}
77
78/// User-facing Contract trait for a tensor compute backend.
79///
80/// The `Tensor` associated type lets backends dispatch over their
81/// native storage (`Dense<f32>`, an `ndarray::ArrayD<f32>`, an
82/// opaque GPU handle, …); the framework round-trips through the
83/// producer/consumer `SlotValue` carriers via the derive bridge
84/// in [`crate::roles::BackendRuntime`].
85///
86/// `Self::Tensor: Clone` is required because the per-op default
87/// impls clone tensors into a temporary `HashMap<String, _>` to
88/// feed [`Backend::execute`]. Backends overriding the per-op
89/// methods directly never invoke this clone; backends overriding
90/// `execute` natively pay one clone per per-op call. ndarray's
91/// `ArrayD<f32>` clones the shape + bumps an internal refcount —
92/// a few-hundred-nanosecond cost, not a memcpy.
93pub trait Backend: Send + Sync {
94    /// Library-maker-defined error type. The
95    /// `From<BackendWalkError>` bound lets the default per-op /
96    /// `execute_graph_via_per_op` walker surface graph-validation
97    /// failures as typed errors instead of panicking on
98    /// peer-supplied or malformed `GraphProto` bodies.
99    type Error: std::error::Error
100        + std::fmt::Display
101        + Send
102        + Sync
103        + From<crate::contracts::backend_default_walk::BackendWalkError>
104        + 'static;
105
106    /// Native tensor representation.
107    type Tensor: Clone + Send + Sync + 'static + bb_ir::types::Storage;
108
109    // ──────────────────────────────────────────────────────────
110    // Per-op surface — one method per primitive in
111    // `TENSOR_PRIMITIVES_OPS`. Each default wraps a one-node
112    // `GraphProto` and calls `execute`.
113    // ──────────────────────────────────────────────────────────
114
115    // ─── Arithmetic (6) ───────────────────────────────────────
116
117    /// Element-wise `a + b` with NumPy broadcasting.
118    fn add(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
119        backend_default_walk::execute_single(self, "Add", &[a, b], Vec::new())
120    }
121    /// Element-wise `a - b` with NumPy broadcasting.
122    fn sub(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
123        backend_default_walk::execute_single(self, "Sub", &[a, b], Vec::new())
124    }
125    /// Element-wise `a * b` with NumPy broadcasting.
126    fn mul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
127        backend_default_walk::execute_single(self, "Mul", &[a, b], Vec::new())
128    }
129    /// Element-wise `a / b` with NumPy broadcasting.
130    fn div(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
131        backend_default_walk::execute_single(self, "Div", &[a, b], Vec::new())
132    }
133    /// Element-wise unary negation.
134    fn neg(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
135        backend_default_walk::execute_single(self, "Neg", &[a], Vec::new())
136    }
137    /// Element-wise absolute value.
138    fn abs(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
139        backend_default_walk::execute_single(self, "Abs", &[a], Vec::new())
140    }
141
142    // ─── Math (4) ─────────────────────────────────────────────
143
144    /// Element-wise square root.
145    fn sqrt(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
146        backend_default_walk::execute_single(self, "Sqrt", &[a], Vec::new())
147    }
148    /// Element-wise `a ** b` with NumPy broadcasting.
149    fn pow(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
150        backend_default_walk::execute_single(self, "Pow", &[a, b], Vec::new())
151    }
152    /// Element-wise natural exponential.
153    fn exp(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
154        backend_default_walk::execute_single(self, "Exp", &[a], Vec::new())
155    }
156    /// Element-wise natural logarithm.
157    fn log(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
158        backend_default_walk::execute_single(self, "Log", &[a], Vec::new())
159    }
160
161    // ─── Linear algebra (1) ───────────────────────────────────
162
163    /// Matrix multiplication (NumPy semantics: 2-D × 2-D + batched
164    /// higher-rank broadcasting).
165    fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
166        backend_default_walk::execute_single(self, "MatMul", &[a, b], Vec::new())
167    }
168
169    // ─── Reductions (4) ───────────────────────────────────────
170
171    /// Sum-reduce `a` along `axes`. `keepdims = true` preserves
172    /// the reduced dims as length-1.
173    fn reduce_sum(
174        &self,
175        a: &Self::Tensor,
176        axes: &[i64],
177        keepdims: bool,
178    ) -> Result<Self::Tensor, Self::Error> {
179        backend_default_walk::execute_single(
180            self,
181            "ReduceSum",
182            &[a],
183            vec![
184                backend_default_walk::ints_attr("axes", axes),
185                backend_default_walk::int_attr("keepdims", keepdims as i64),
186            ],
187        )
188    }
189    /// Mean-reduce `a` along `axes`.
190    fn reduce_mean(
191        &self,
192        a: &Self::Tensor,
193        axes: &[i64],
194        keepdims: bool,
195    ) -> Result<Self::Tensor, Self::Error> {
196        backend_default_walk::execute_single(
197            self,
198            "ReduceMean",
199            &[a],
200            vec![
201                backend_default_walk::ints_attr("axes", axes),
202                backend_default_walk::int_attr("keepdims", keepdims as i64),
203            ],
204        )
205    }
206    /// Max-reduce `a` along `axes`.
207    fn reduce_max(
208        &self,
209        a: &Self::Tensor,
210        axes: &[i64],
211        keepdims: bool,
212    ) -> Result<Self::Tensor, Self::Error> {
213        backend_default_walk::execute_single(
214            self,
215            "ReduceMax",
216            &[a],
217            vec![
218                backend_default_walk::ints_attr("axes", axes),
219                backend_default_walk::int_attr("keepdims", keepdims as i64),
220            ],
221        )
222    }
223    /// Min-reduce `a` along `axes`.
224    fn reduce_min(
225        &self,
226        a: &Self::Tensor,
227        axes: &[i64],
228        keepdims: bool,
229    ) -> Result<Self::Tensor, Self::Error> {
230        backend_default_walk::execute_single(
231            self,
232            "ReduceMin",
233            &[a],
234            vec![
235                backend_default_walk::ints_attr("axes", axes),
236                backend_default_walk::int_attr("keepdims", keepdims as i64),
237            ],
238        )
239    }
240
241    // ─── Shape (9) ────────────────────────────────────────────
242
243    /// Reshape `a` to the given dims. Total element count must
244    /// match.
245    fn reshape(&self, a: &Self::Tensor, shape: &[i64]) -> Result<Self::Tensor, Self::Error> {
246        backend_default_walk::execute_single(
247            self,
248            "Reshape",
249            &[a],
250            vec![backend_default_walk::ints_attr("shape", shape)],
251        )
252    }
253    /// Transpose axes. Empty `perm` reverses all dims.
254    fn transpose(&self, a: &Self::Tensor, perm: &[i64]) -> Result<Self::Tensor, Self::Error> {
255        backend_default_walk::execute_single(
256            self,
257            "Transpose",
258            &[a],
259            vec![backend_default_walk::ints_attr("perm", perm)],
260        )
261    }
262    /// Concatenate `inputs` along `axis`.
263    fn concat(&self, inputs: &[&Self::Tensor], axis: i64) -> Result<Self::Tensor, Self::Error> {
264        backend_default_walk::execute_single(
265            self,
266            "Concat",
267            inputs,
268            vec![backend_default_walk::int_attr("axis", axis)],
269        )
270    }
271    /// NumPy-style slice. Empty `axes` defaults to all dims;
272    /// empty `steps` defaults to 1 per axis.
273    fn slice(
274        &self,
275        a: &Self::Tensor,
276        starts: &[i64],
277        ends: &[i64],
278        axes: &[i64],
279        steps: &[i64],
280    ) -> Result<Self::Tensor, Self::Error> {
281        backend_default_walk::execute_single(
282            self,
283            "Slice",
284            &[a],
285            vec![
286                backend_default_walk::ints_attr("starts", starts),
287                backend_default_walk::ints_attr("ends", ends),
288                backend_default_walk::ints_attr("axes", axes),
289                backend_default_walk::ints_attr("steps", steps),
290            ],
291        )
292    }
293    /// Split `a` along `axis` into parts of the given `sizes`.
294    /// Empty `sizes` means equal-sized splits (count comes from
295    /// the consumer side downstream).
296    fn split(
297        &self,
298        a: &Self::Tensor,
299        axis: i64,
300        sizes: &[i64],
301    ) -> Result<Vec<Self::Tensor>, Self::Error> {
302        // `Split` is the only primitive returning multiple
303        // tensors. We can't use `execute_single`'s single-output
304        // path — instead we wrap into a graph that names each
305        // output positionally and extract them.
306        backend_default_walk::execute_multi(
307            self,
308            "Split",
309            &[a],
310            vec![
311                backend_default_walk::int_attr("axis", axis),
312                backend_default_walk::ints_attr("split", sizes),
313            ],
314            sizes.len(),
315        )
316    }
317    /// Remove dimensions of size 1. Empty `axes` removes all
318    /// size-1 dims.
319    fn squeeze(&self, a: &Self::Tensor, axes: &[i64]) -> Result<Self::Tensor, Self::Error> {
320        backend_default_walk::execute_single(
321            self,
322            "Squeeze",
323            &[a],
324            vec![backend_default_walk::ints_attr("axes", axes)],
325        )
326    }
327    /// Insert dimensions of size 1 at the given axes.
328    fn unsqueeze(&self, a: &Self::Tensor, axes: &[i64]) -> Result<Self::Tensor, Self::Error> {
329        backend_default_walk::execute_single(
330            self,
331            "Unsqueeze",
332            &[a],
333            vec![backend_default_walk::ints_attr("axes", axes)],
334        )
335    }
336    /// Identity / clone — pass-through useful for graph rewrites.
337    fn identity(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
338        backend_default_walk::execute_single(self, "Identity", &[a], Vec::new())
339    }
340    /// Cast to the given ONNX `DataType` enum value (matches
341    /// `bb_ir::proto::onnx::tensor_proto::DataType`).
342    fn cast(&self, a: &Self::Tensor, dtype: i32) -> Result<Self::Tensor, Self::Error> {
343        backend_default_walk::execute_single(
344            self,
345            "Cast",
346            &[a],
347            vec![backend_default_walk::int_attr("to", dtype as i64)],
348        )
349    }
350
351    // ─── Comparison (3) ───────────────────────────────────────
352
353    /// Element-wise `a == b`. Result is boolean-typed.
354    fn equal(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
355        backend_default_walk::execute_single(self, "Equal", &[a, b], Vec::new())
356    }
357    /// Element-wise `a > b`. Result is boolean-typed.
358    fn greater(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
359        backend_default_walk::execute_single(self, "Greater", &[a, b], Vec::new())
360    }
361    /// Element-wise `a < b`. Result is boolean-typed.
362    fn less(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
363        backend_default_walk::execute_single(self, "Less", &[a, b], Vec::new())
364    }
365
366    // ─── Conditional (1) ──────────────────────────────────────
367
368    /// Element-wise ternary: `where cond { t } else { f }`.
369    /// Named `r#where` to dodge the reserved Rust keyword.
370    fn r#where(
371        &self,
372        cond: &Self::Tensor,
373        t: &Self::Tensor,
374        f: &Self::Tensor,
375    ) -> Result<Self::Tensor, Self::Error> {
376        backend_default_walk::execute_single(self, "Where", &[cond, t, f], Vec::new())
377    }
378
379    // ─── Creation (1) ─────────────────────────────────────────
380
381    /// Materialize a constant from an ONNX `TensorProto`. The
382    /// `value` attribute on the ONNX `Constant` op carries the
383    /// data; rank, dtype, raw bytes all come from the proto.
384    fn constant(&self, value: TensorProto) -> Result<Self::Tensor, Self::Error> {
385        backend_default_walk::execute_single(
386            self,
387            "Constant",
388            &[],
389            vec![backend_default_walk::tensor_attr("value", value)],
390        )
391    }
392
393    // ─── Indexing (1) ─────────────────────────────────────────
394
395    /// Gather slices of `data` along `axis` indexed by `indices`.
396    fn gather(
397        &self,
398        data: &Self::Tensor,
399        indices: &Self::Tensor,
400        axis: i64,
401    ) -> Result<Self::Tensor, Self::Error> {
402        backend_default_walk::execute_single(
403            self,
404            "Gather",
405            &[data, indices],
406            vec![backend_default_walk::int_attr("axis", axis)],
407        )
408    }
409
410    // ──────────────────────────────────────────────────────────
411    // Whole-graph surface — default walks `graph.node` and
412    // dispatches each through the typed per-op methods above.
413    // ──────────────────────────────────────────────────────────
414
415    /// Execute every NodeProto in `graph.node` against the value
416    /// env `inputs`. Returns the subset of values named in
417    /// `graph.output`.
418    ///
419    /// `graph.node` is topologically ordered per the ONNX spec,
420    /// so the default walker (a linear scan) suffices for any
421    /// `GraphProto` whose ops are all in
422    /// [`bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS`]. A
423    /// backend overriding this method may detect fused patterns,
424    /// compile to GPU, or any other strategy.
425    fn execute(
426        &self,
427        graph: &GraphProto,
428        inputs: HashMap<String, Self::Tensor>,
429        _attrs: BackendAttrs<'_>,
430    ) -> Result<HashMap<String, Self::Tensor>, Self::Error> {
431        backend_default_walk::execute_graph_via_per_op(self, graph, inputs)
432    }
433
434    /// Dispatch a `BackendSubgraph` carrier — the engine-facing entry
435    /// point for whole-subgraph execution.
436    ///
437    /// The default falls through to [`Self::execute`] synchronously,
438    /// keeping existing backends' behaviour identical. Backends with
439    /// per-subgraph caching, JIT compilation, or async device execution
440    /// override this to:
441    ///
442    /// - Cache the compiled subgraph by identity (e.g. graph name or
443    ///   hash).
444    /// - Return [`ContractResponse::Later`] and retain `completion`
445    ///   while the device runs. The engine schedules other work;
446    ///   the backend completes the handle from whatever runtime it
447    ///   uses — `std::thread`, tokio task, custom event loop, single-
448    ///   thread no-std loop.
449    /// - Fall through to [`Self::execute`] on compile failure or
450    ///   unsupported op.
451    ///
452    /// The `completion` parameter in the default impl is intentionally
453    /// discarded (`let _ = completion`) because [`ContractResponse::Now`]
454    /// does not retain the handle. This is correct — only overriders
455    /// that return [`ContractResponse::Later`] must hold it.
456    fn dispatch(
457        &self,
458        graph: &GraphProto,
459        inputs: HashMap<String, Self::Tensor>,
460        attrs: BackendAttrs<'_>,
461        completion: CompletionHandle<HashMap<String, Self::Tensor>, Self::Error>,
462    ) -> ContractResponse<HashMap<String, Self::Tensor>, Self::Error> {
463        let _ = completion; // default doesn't retain it; signature stays for opt-in overriders
464        ContractResponse::Now(self.execute(graph, inputs, attrs))
465    }
466
467    /// Materialise an inbound tensor `SlotFill` into this backend's
468    /// native tensor representation.
469    ///
470    /// The framework has already (a) capped `bytes.len()` against the
471    /// envelope's `EnvelopeCaps::max_per_fill_bytes`, (b) charged the
472    /// length against `NodeConfig::ingress_byte_budget`, and (c) moved
473    /// ownership of the wire bytes into this call. The backend may
474    /// adopt the `Vec<u8>` directly (zero-copy via
475    /// `ArrayD::from_shape_vec` when alignment permits), pull a buffer
476    /// from a pool and copy in, or allocate fresh. The framework will
477    /// not touch `bytes` after this call returns.
478    ///
479    /// The default delegates to the global wire-decoder registry: it
480    /// looks up the decoder for `type_hash`, runs it on the bytes,
481    /// then downcasts the resulting boxed `SlotValue` to `Self::Tensor`
482    /// via the registry's `Box<dyn Any>` repackaging. Backends that
483    /// have not implemented tensor pooling continue to work through
484    /// this path; backends that override pay the registry hop only at
485    /// override time.
486    ///
487    /// On `Err`, the engine drops the fill, releases the byte charge,
488    /// and emits `WireReceiveError { kind: BackendMaterializeFailed }`.
489    ///
490    /// Ownership note: `bytes: Vec<u8>` by value (not `&[u8]` or
491    /// `Cow`). This is the framework→backend handoff, NOT an external
492    /// boundary — the backend lives inside the framework ecosystem
493    /// and plays by the runtime contract. Principle 1a (ephemeral
494    /// borrowed slices at external boundaries) does not apply here:
495    /// the framework copied or owned the bytes already, and a backend
496    /// that wants to adopt them (zero-copy) needs ownership.
497    fn materialize_from_wire(
498        &self,
499        type_hash: u64,
500        bytes: Vec<u8>,
501    ) -> Result<Self::Tensor, Self::Error> {
502        use crate::contracts::backend_default_walk::BackendWalkError;
503        let decoder = bb_ir::slot_value::wire_decoder_registry()
504            .get(&type_hash)
505            .copied()
506            .ok_or_else(|| BackendWalkError::WireMaterializeFailed {
507                type_hash,
508                reason: "no decoder registered for type_hash".into(),
509            })?;
510        let boxed = decoder(&bytes).map_err(|e| BackendWalkError::WireMaterializeFailed {
511            type_hash,
512            reason: e.to_string(),
513        })?;
514        let any = boxed.into_any_boxed();
515        any.downcast::<Self::Tensor>().map(|b| *b).map_err(|_| {
516            BackendWalkError::WireMaterializeFailed {
517                type_hash,
518                reason: "decoded carrier is not Self::Tensor".into(),
519            }
520            .into()
521        })
522    }
523}