bb_runtime/contracts/backend.rs
1//! `bb::Backend` — Contract trait for tensor compute backends.
2//!
3//! The Contract has THREE surfaces, exposed side-by-side:
4//!
5//! 1. **One typed method per mandatory primitive op** (the 30
6//! entries in [`bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS`]):
7//! `add`, `mul`, `matmul`, `reduce_sum`, `reshape`, …. Components
8//! reach these inline through `ctx.backends` for short-form
9//! tensor math — an Index's distance kernel calls
10//! `backend.matmul(&query, &vectors)?` instead of hand-rolling
11//! a loop.
12//!
//! 2. **One method to execute a subgraph**:
//!    `execute(&GraphProto, HashMap<String, Tensor>, BackendAttrs) →
//!    Result<HashMap<String, Tensor>, Error>`. Backends that prefer
//!    whole-graph dispatch override this entry point; the per-op
//!    defaults call through to it via a one-node `GraphProto`.
18//!
19//! 3. **One dispatch entry point for `BackendSubgraph` carriers**:
20//! `dispatch(&GraphProto, inputs, attrs, completion) →
21//! ContractResponse`. The engine calls this for every
22//! `BackendSubgraph` carrier op. The default falls through to
23//! `execute` synchronously. Backends with per-subgraph caching,
24//! JIT compilation, or async device execution override `dispatch`
25//! to return `ContractResponse::Later` while device work runs.
26//!
//! ### How the per-op and whole-graph surfaces compose
28//!
29//! Default impls in [`crate::contracts::backend_default_walk`]
30//! bridge the surfaces so a backend author overrides only the side
31//! that's natural for their target:
32//!
33//! - **CpuBackend** overrides the 30 per-op methods directly
34//! (`add` runs ndarray's `Add` impl, `matmul` runs `dot`, …).
35//! It does NOT override `execute` — the default walker uses
36//! the overridden per-op methods.
37//!
38//! - **A Burn-style backend** overrides `execute` natively (Burn
39//! compiles the whole `GraphProto` to its own IR + runs once).
40//! It does NOT override per-op methods — they default-wrap a
41//! one-node `GraphProto` and call `execute`.
42//!
//! Backends overriding *neither* side stack-overflow on the first
//! call: every per-op default wraps into `execute`, whose default
//! walks back to per-op, ad infinitum (`dispatch`'s default enters
//! the same cycle, since it falls through to `execute`). Backends
//! MUST override at least one side.
47//!
48//! ### Extension ops
49//!
50//! Activation functions (Relu, Sigmoid, Softmax), pooling (MaxPool,
51//! AveragePool), normalization (BatchNormalization, LayerNorm),
52//! Conv, and so on are NOT on the Contract surface. They're
53//! extensions — a backend MAY declare them via
54//! [`crate::roles::BackendRuntime::extension_opsets`] and handle
55//! them through its own `execute` override; OR a future lowering
56//! pass decomposes them into primitives so the Contract surface
57//! covers any graph.
58
59use std::collections::HashMap;
60
61use bb_ir::proto::onnx::{AttributeProto, GraphProto, StringStringEntryProto, TensorProto};
62
63use crate::completion::{CompletionHandle, ContractResponse};
64use crate::contracts::backend_default_walk;
65
66/// Per-call NodeProto context surfaced to `Backend::execute` so
67/// kernels overriding the whole-graph path see the original call
68/// site's attributes + metadata alongside the body. Per-op
69/// methods (which receive their attributes positionally as typed
70/// args) don't need this struct.
71pub struct BackendAttrs<'a> {
72 /// Attribute list from the call NodeProto's `attribute` field.
73 pub current_node_attributes: &'a [AttributeProto],
74 /// `metadata_props` from the call NodeProto.
75 pub current_node_metadata: &'a [StringStringEntryProto],
76}
77
78/// User-facing Contract trait for a tensor compute backend.
79///
80/// The `Tensor` associated type lets backends dispatch over their
81/// native storage (`Dense<f32>`, an `ndarray::ArrayD<f32>`, an
82/// opaque GPU handle, …); the framework round-trips through the
83/// producer/consumer `SlotValue` carriers via the derive bridge
84/// in [`crate::roles::BackendRuntime`].
85///
86/// `Self::Tensor: Clone` is required because the per-op default
87/// impls clone tensors into a temporary `HashMap<String, _>` to
88/// feed [`Backend::execute`]. Backends overriding the per-op
89/// methods directly never invoke this clone; backends overriding
90/// `execute` natively pay one clone per per-op call. ndarray's
91/// `ArrayD<f32>` clones the shape + bumps an internal refcount —
92/// a few-hundred-nanosecond cost, not a memcpy.
93pub trait Backend: Send + Sync {
94 /// Library-maker-defined error type. The
95 /// `From<BackendWalkError>` bound lets the default per-op /
96 /// `execute_graph_via_per_op` walker surface graph-validation
97 /// failures as typed errors instead of panicking on
98 /// peer-supplied or malformed `GraphProto` bodies.
99 type Error: std::error::Error
100 + std::fmt::Display
101 + Send
102 + Sync
103 + From<crate::contracts::backend_default_walk::BackendWalkError>
104 + 'static;
105
106 /// Native tensor representation.
107 type Tensor: Clone + Send + Sync + 'static + bb_ir::types::Storage;
108
109 // ──────────────────────────────────────────────────────────
110 // Per-op surface — one method per primitive in
111 // `TENSOR_PRIMITIVES_OPS`. Each default wraps a one-node
112 // `GraphProto` and calls `execute`.
113 // ──────────────────────────────────────────────────────────
114
115 // ─── Arithmetic (6) ───────────────────────────────────────
116
117 /// Element-wise `a + b` with NumPy broadcasting.
118 fn add(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
119 backend_default_walk::execute_single(self, "Add", &[a, b], Vec::new())
120 }
121 /// Element-wise `a - b` with NumPy broadcasting.
122 fn sub(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
123 backend_default_walk::execute_single(self, "Sub", &[a, b], Vec::new())
124 }
125 /// Element-wise `a * b` with NumPy broadcasting.
126 fn mul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
127 backend_default_walk::execute_single(self, "Mul", &[a, b], Vec::new())
128 }
129 /// Element-wise `a / b` with NumPy broadcasting.
130 fn div(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
131 backend_default_walk::execute_single(self, "Div", &[a, b], Vec::new())
132 }
133 /// Element-wise unary negation.
134 fn neg(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
135 backend_default_walk::execute_single(self, "Neg", &[a], Vec::new())
136 }
137 /// Element-wise absolute value.
138 fn abs(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
139 backend_default_walk::execute_single(self, "Abs", &[a], Vec::new())
140 }
141
142 // ─── Math (4) ─────────────────────────────────────────────
143
144 /// Element-wise square root.
145 fn sqrt(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
146 backend_default_walk::execute_single(self, "Sqrt", &[a], Vec::new())
147 }
148 /// Element-wise `a ** b` with NumPy broadcasting.
149 fn pow(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
150 backend_default_walk::execute_single(self, "Pow", &[a, b], Vec::new())
151 }
152 /// Element-wise natural exponential.
153 fn exp(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
154 backend_default_walk::execute_single(self, "Exp", &[a], Vec::new())
155 }
156 /// Element-wise natural logarithm.
157 fn log(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
158 backend_default_walk::execute_single(self, "Log", &[a], Vec::new())
159 }
160
161 // ─── Linear algebra (1) ───────────────────────────────────
162
163 /// Matrix multiplication (NumPy semantics: 2-D × 2-D + batched
164 /// higher-rank broadcasting).
165 fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
166 backend_default_walk::execute_single(self, "MatMul", &[a, b], Vec::new())
167 }
168
169 // ─── Reductions (4) ───────────────────────────────────────
170
171 /// Sum-reduce `a` along `axes`. `keepdims = true` preserves
172 /// the reduced dims as length-1.
173 fn reduce_sum(
174 &self,
175 a: &Self::Tensor,
176 axes: &[i64],
177 keepdims: bool,
178 ) -> Result<Self::Tensor, Self::Error> {
179 backend_default_walk::execute_single(
180 self,
181 "ReduceSum",
182 &[a],
183 vec![
184 backend_default_walk::ints_attr("axes", axes),
185 backend_default_walk::int_attr("keepdims", keepdims as i64),
186 ],
187 )
188 }
189 /// Mean-reduce `a` along `axes`.
190 fn reduce_mean(
191 &self,
192 a: &Self::Tensor,
193 axes: &[i64],
194 keepdims: bool,
195 ) -> Result<Self::Tensor, Self::Error> {
196 backend_default_walk::execute_single(
197 self,
198 "ReduceMean",
199 &[a],
200 vec![
201 backend_default_walk::ints_attr("axes", axes),
202 backend_default_walk::int_attr("keepdims", keepdims as i64),
203 ],
204 )
205 }
206 /// Max-reduce `a` along `axes`.
207 fn reduce_max(
208 &self,
209 a: &Self::Tensor,
210 axes: &[i64],
211 keepdims: bool,
212 ) -> Result<Self::Tensor, Self::Error> {
213 backend_default_walk::execute_single(
214 self,
215 "ReduceMax",
216 &[a],
217 vec![
218 backend_default_walk::ints_attr("axes", axes),
219 backend_default_walk::int_attr("keepdims", keepdims as i64),
220 ],
221 )
222 }
223 /// Min-reduce `a` along `axes`.
224 fn reduce_min(
225 &self,
226 a: &Self::Tensor,
227 axes: &[i64],
228 keepdims: bool,
229 ) -> Result<Self::Tensor, Self::Error> {
230 backend_default_walk::execute_single(
231 self,
232 "ReduceMin",
233 &[a],
234 vec![
235 backend_default_walk::ints_attr("axes", axes),
236 backend_default_walk::int_attr("keepdims", keepdims as i64),
237 ],
238 )
239 }
240
241 // ─── Shape (9) ────────────────────────────────────────────
242
243 /// Reshape `a` to the given dims. Total element count must
244 /// match.
245 fn reshape(&self, a: &Self::Tensor, shape: &[i64]) -> Result<Self::Tensor, Self::Error> {
246 backend_default_walk::execute_single(
247 self,
248 "Reshape",
249 &[a],
250 vec![backend_default_walk::ints_attr("shape", shape)],
251 )
252 }
253 /// Transpose axes. Empty `perm` reverses all dims.
254 fn transpose(&self, a: &Self::Tensor, perm: &[i64]) -> Result<Self::Tensor, Self::Error> {
255 backend_default_walk::execute_single(
256 self,
257 "Transpose",
258 &[a],
259 vec![backend_default_walk::ints_attr("perm", perm)],
260 )
261 }
262 /// Concatenate `inputs` along `axis`.
263 fn concat(&self, inputs: &[&Self::Tensor], axis: i64) -> Result<Self::Tensor, Self::Error> {
264 backend_default_walk::execute_single(
265 self,
266 "Concat",
267 inputs,
268 vec![backend_default_walk::int_attr("axis", axis)],
269 )
270 }
271 /// NumPy-style slice. Empty `axes` defaults to all dims;
272 /// empty `steps` defaults to 1 per axis.
273 fn slice(
274 &self,
275 a: &Self::Tensor,
276 starts: &[i64],
277 ends: &[i64],
278 axes: &[i64],
279 steps: &[i64],
280 ) -> Result<Self::Tensor, Self::Error> {
281 backend_default_walk::execute_single(
282 self,
283 "Slice",
284 &[a],
285 vec![
286 backend_default_walk::ints_attr("starts", starts),
287 backend_default_walk::ints_attr("ends", ends),
288 backend_default_walk::ints_attr("axes", axes),
289 backend_default_walk::ints_attr("steps", steps),
290 ],
291 )
292 }
293 /// Split `a` along `axis` into parts of the given `sizes`.
294 /// Empty `sizes` means equal-sized splits (count comes from
295 /// the consumer side downstream).
296 fn split(
297 &self,
298 a: &Self::Tensor,
299 axis: i64,
300 sizes: &[i64],
301 ) -> Result<Vec<Self::Tensor>, Self::Error> {
302 // `Split` is the only primitive returning multiple
303 // tensors. We can't use `execute_single`'s single-output
304 // path — instead we wrap into a graph that names each
305 // output positionally and extract them.
306 backend_default_walk::execute_multi(
307 self,
308 "Split",
309 &[a],
310 vec![
311 backend_default_walk::int_attr("axis", axis),
312 backend_default_walk::ints_attr("split", sizes),
313 ],
314 sizes.len(),
315 )
316 }
317 /// Remove dimensions of size 1. Empty `axes` removes all
318 /// size-1 dims.
319 fn squeeze(&self, a: &Self::Tensor, axes: &[i64]) -> Result<Self::Tensor, Self::Error> {
320 backend_default_walk::execute_single(
321 self,
322 "Squeeze",
323 &[a],
324 vec![backend_default_walk::ints_attr("axes", axes)],
325 )
326 }
327 /// Insert dimensions of size 1 at the given axes.
328 fn unsqueeze(&self, a: &Self::Tensor, axes: &[i64]) -> Result<Self::Tensor, Self::Error> {
329 backend_default_walk::execute_single(
330 self,
331 "Unsqueeze",
332 &[a],
333 vec![backend_default_walk::ints_attr("axes", axes)],
334 )
335 }
336 /// Identity / clone — pass-through useful for graph rewrites.
337 fn identity(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
338 backend_default_walk::execute_single(self, "Identity", &[a], Vec::new())
339 }
340 /// Cast to the given ONNX `DataType` enum value (matches
341 /// `bb_ir::proto::onnx::tensor_proto::DataType`).
342 fn cast(&self, a: &Self::Tensor, dtype: i32) -> Result<Self::Tensor, Self::Error> {
343 backend_default_walk::execute_single(
344 self,
345 "Cast",
346 &[a],
347 vec![backend_default_walk::int_attr("to", dtype as i64)],
348 )
349 }
350
351 // ─── Comparison (3) ───────────────────────────────────────
352
353 /// Element-wise `a == b`. Result is boolean-typed.
354 fn equal(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
355 backend_default_walk::execute_single(self, "Equal", &[a, b], Vec::new())
356 }
357 /// Element-wise `a > b`. Result is boolean-typed.
358 fn greater(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
359 backend_default_walk::execute_single(self, "Greater", &[a, b], Vec::new())
360 }
361 /// Element-wise `a < b`. Result is boolean-typed.
362 fn less(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
363 backend_default_walk::execute_single(self, "Less", &[a, b], Vec::new())
364 }
365
366 // ─── Conditional (1) ──────────────────────────────────────
367
368 /// Element-wise ternary: `where cond { t } else { f }`.
369 /// Named `r#where` to dodge the reserved Rust keyword.
370 fn r#where(
371 &self,
372 cond: &Self::Tensor,
373 t: &Self::Tensor,
374 f: &Self::Tensor,
375 ) -> Result<Self::Tensor, Self::Error> {
376 backend_default_walk::execute_single(self, "Where", &[cond, t, f], Vec::new())
377 }
378
379 // ─── Creation (1) ─────────────────────────────────────────
380
381 /// Materialize a constant from an ONNX `TensorProto`. The
382 /// `value` attribute on the ONNX `Constant` op carries the
383 /// data; rank, dtype, raw bytes all come from the proto.
384 fn constant(&self, value: TensorProto) -> Result<Self::Tensor, Self::Error> {
385 backend_default_walk::execute_single(
386 self,
387 "Constant",
388 &[],
389 vec![backend_default_walk::tensor_attr("value", value)],
390 )
391 }
392
393 // ─── Indexing (1) ─────────────────────────────────────────
394
395 /// Gather slices of `data` along `axis` indexed by `indices`.
396 fn gather(
397 &self,
398 data: &Self::Tensor,
399 indices: &Self::Tensor,
400 axis: i64,
401 ) -> Result<Self::Tensor, Self::Error> {
402 backend_default_walk::execute_single(
403 self,
404 "Gather",
405 &[data, indices],
406 vec![backend_default_walk::int_attr("axis", axis)],
407 )
408 }
409
410 // ──────────────────────────────────────────────────────────
411 // Whole-graph surface — default walks `graph.node` and
412 // dispatches each through the typed per-op methods above.
413 // ──────────────────────────────────────────────────────────
414
415 /// Execute every NodeProto in `graph.node` against the value
416 /// env `inputs`. Returns the subset of values named in
417 /// `graph.output`.
418 ///
419 /// `graph.node` is topologically ordered per the ONNX spec,
420 /// so the default walker (a linear scan) suffices for any
421 /// `GraphProto` whose ops are all in
422 /// [`bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS`]. A
423 /// backend overriding this method may detect fused patterns,
424 /// compile to GPU, or any other strategy.
425 fn execute(
426 &self,
427 graph: &GraphProto,
428 inputs: HashMap<String, Self::Tensor>,
429 _attrs: BackendAttrs<'_>,
430 ) -> Result<HashMap<String, Self::Tensor>, Self::Error> {
431 backend_default_walk::execute_graph_via_per_op(self, graph, inputs)
432 }
433
434 /// Dispatch a `BackendSubgraph` carrier — the engine-facing entry
435 /// point for whole-subgraph execution.
436 ///
437 /// The default falls through to [`Self::execute`] synchronously,
438 /// keeping existing backends' behaviour identical. Backends with
439 /// per-subgraph caching, JIT compilation, or async device execution
440 /// override this to:
441 ///
442 /// - Cache the compiled subgraph by identity (e.g. graph name or
443 /// hash).
444 /// - Return [`ContractResponse::Later`] and retain `completion`
445 /// while the device runs. The engine schedules other work;
446 /// the backend completes the handle from whatever runtime it
447 /// uses — `std::thread`, tokio task, custom event loop, single-
448 /// thread no-std loop.
449 /// - Fall through to [`Self::execute`] on compile failure or
450 /// unsupported op.
451 ///
452 /// The `completion` parameter in the default impl is intentionally
453 /// discarded (`let _ = completion`) because [`ContractResponse::Now`]
454 /// does not retain the handle. This is correct — only overriders
455 /// that return [`ContractResponse::Later`] must hold it.
456 fn dispatch(
457 &self,
458 graph: &GraphProto,
459 inputs: HashMap<String, Self::Tensor>,
460 attrs: BackendAttrs<'_>,
461 completion: CompletionHandle<HashMap<String, Self::Tensor>, Self::Error>,
462 ) -> ContractResponse<HashMap<String, Self::Tensor>, Self::Error> {
463 let _ = completion; // default doesn't retain it; signature stays for opt-in overriders
464 ContractResponse::Now(self.execute(graph, inputs, attrs))
465 }
466
467 /// Materialise an inbound tensor `SlotFill` into this backend's
468 /// native tensor representation.
469 ///
470 /// The framework has already (a) capped `bytes.len()` against the
471 /// envelope's `EnvelopeCaps::max_per_fill_bytes`, (b) charged the
472 /// length against `NodeConfig::ingress_byte_budget`, and (c) moved
473 /// ownership of the wire bytes into this call. The backend may
474 /// adopt the `Vec<u8>` directly (zero-copy via
475 /// `ArrayD::from_shape_vec` when alignment permits), pull a buffer
476 /// from a pool and copy in, or allocate fresh. The framework will
477 /// not touch `bytes` after this call returns.
478 ///
479 /// The default delegates to the global wire-decoder registry: it
480 /// looks up the decoder for `type_hash`, runs it on the bytes,
481 /// then downcasts the resulting boxed `SlotValue` to `Self::Tensor`
482 /// via the registry's `Box<dyn Any>` repackaging. Backends that
483 /// have not implemented tensor pooling continue to work through
484 /// this path; backends that override pay the registry hop only at
485 /// override time.
486 ///
487 /// On `Err`, the engine drops the fill, releases the byte charge,
488 /// and emits `WireReceiveError { kind: BackendMaterializeFailed }`.
489 ///
490 /// Ownership note: `bytes: Vec<u8>` by value (not `&[u8]` or
491 /// `Cow`). This is the framework→backend handoff, NOT an external
492 /// boundary — the backend lives inside the framework ecosystem
493 /// and plays by the runtime contract. Principle 1a (ephemeral
494 /// borrowed slices at external boundaries) does not apply here:
495 /// the framework copied or owned the bytes already, and a backend
496 /// that wants to adopt them (zero-copy) needs ownership.
497 fn materialize_from_wire(
498 &self,
499 type_hash: u64,
500 bytes: Vec<u8>,
501 ) -> Result<Self::Tensor, Self::Error> {
502 use crate::contracts::backend_default_walk::BackendWalkError;
503 let decoder = bb_ir::slot_value::wire_decoder_registry()
504 .get(&type_hash)
505 .copied()
506 .ok_or_else(|| BackendWalkError::WireMaterializeFailed {
507 type_hash,
508 reason: "no decoder registered for type_hash".into(),
509 })?;
510 let boxed = decoder(&bytes).map_err(|e| BackendWalkError::WireMaterializeFailed {
511 type_hash,
512 reason: e.to_string(),
513 })?;
514 let any = boxed.into_any_boxed();
515 any.downcast::<Self::Tensor>().map(|b| *b).map_err(|_| {
516 BackendWalkError::WireMaterializeFailed {
517 type_hash,
518 reason: "decoded carrier is not Self::Tensor".into(),
519 }
520 .into()
521 })
522 }
523}