Skip to main content

bb_runtime/contracts/
backend_default_walk.rs

1//! Shared default impls for the `Backend` Contract trait.
2//!
3//! The Contract surface is *thirty typed per-op methods*
4//! ([`super::Backend::add`], [`super::Backend::matmul`], …) plus
5//! [`super::Backend::execute`] (`&GraphProto, HashMap` →
6//! `HashMap`). Each side has a default body that calls into the
7//! other so backend authors override whichever side is natural:
8//!
9//! - **Override per-op methods** (e.g. `CpuBackend` over ndarray):
10//!   each Contract method is a direct kernel call. The default
11//!   `execute` walks the graph node-by-node and dispatches through
12//!   the overridden per-op methods.
13//!
14//! - **Override `execute`** (e.g. a graph-compiling backend like
15//!   Burn): the whole `GraphProto` body is handed to the native
16//!   execution engine. The per-op defaults wrap a single-node
17//!   `GraphProto` and call back into `execute`.
18//!
19//! Pathological case: a backend that overrides *neither* side
20//! stack-overflows — every `add` call walks into a single-node
21//! graph that calls `add` again. Document loudly on the trait.
22//!
23//! This module also encodes the attribute conventions every
24//! `BackendSubgraph_*` carrier uses for primitive ops with
25//! attributes (`ReduceSum.axes`, `Reshape.shape`, `Cast.to`, …).
26//! Per-op defaults call [`ints_attr`] / [`int_attr`] / [`tensor_attr`]
27//! to encode; the walker calls `attr_ints` / `attr_int` /
28//! `attr_tensor` to decode. ONNX-style names are preserved
29//! (`axes`, `keepdims`, `shape`, `perm`, `axis`, `to`, `value`).
30
31use std::collections::HashMap;
32
33use bb_ir::proto::onnx::{AttributeProto, GraphProto, NodeProto, TensorProto, ValueInfoProto};
34
35use super::backend::Backend;
36
37const SINGLE_OP_OUTPUT_NAME: &str = "__bb_default_walk_output";
38
/// Typed failures raised by the default `Backend` walker and the
/// single-node `execute` wrappers in this module.
///
/// These are returned (converted into the backend's own error type)
/// rather than surfaced via `panic!`, so malformed graph bodies and
/// misbehaving `Backend::execute` impls are reported as values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendWalkError {
    /// A node names an input value that is absent from the running
    /// value environment: either the caller never bound it, or an
    /// earlier node did not populate one of its declared outputs.
    MissingInput {
        /// `op_type` of the node that tried to consume the value.
        op_type: String,
        /// Name of the value that could not be found.
        input_name: String,
    },
    /// The number of tensors a per-op method returned does not cover
    /// the outputs its `NodeProto` declares. Points at a compiler or
    /// graph-builder bug.
    OutputArityMismatch {
        /// `op_type` of the offending node.
        op_type: String,
        /// How many outputs the per-op method actually returned.
        produced: usize,
        /// How many outputs the node declares in the graph.
        declared: usize,
    },
    /// The node's `op_type` is outside
    /// [`bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS`]. The default
    /// per-op walker only handles primitives; an adversarial peer can
    /// trigger this on the wire path by shipping a
    /// `BackendSubgraph_*` carrier whose body uses an extension op.
    UnknownOpType(String),
    /// `Backend::execute` returned `Ok` but its result map lacks the
    /// declared output `output_name`. This is a Backend impl bug:
    /// `execute` MUST honor the graph's output names.
    MissingExecuteOutput {
        /// `op_type` of the single-node graph handed to `execute`.
        op_type: String,
        /// Declared output name missing from the result map.
        output_name: String,
    },
    /// The default [`crate::contracts::Backend::materialize_from_wire`]
    /// gave up before reaching the backend: no decoder was registered,
    /// the decoder rejected the bytes, or the decoded carrier could
    /// not be assigned to `Self::Tensor`. Backends that override
    /// `materialize_from_wire` report their own typed error instead.
    WireMaterializeFailed {
        /// `type_hash` advertised by the inbound `SlotFill`.
        type_hash: u64,
        /// Human-readable explanation of the failure.
        reason: String,
    },
}

impl std::fmt::Display for BackendWalkError {
    /// Renders a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownOpType(op_type) => f.write_fmt(format_args!(
                "Backend default walker: op_type `{op_type}` is not in TENSOR_PRIMITIVES_OPS",
            )),
            Self::MissingInput { op_type, input_name } => f.write_fmt(format_args!(
                "Backend default walker: `{op_type}` references input `{input_name}` not in the value env",
            )),
            Self::OutputArityMismatch { op_type, produced, declared } => f.write_fmt(format_args!(
                "Backend default walker: per-op `{op_type}` produced {produced} outputs but graph declares {declared}",
            )),
            Self::MissingExecuteOutput { op_type, output_name } => f.write_fmt(format_args!(
                "Backend::execute (op_type `{op_type}`) did not produce its declared output `{output_name}`",
            )),
            Self::WireMaterializeFailed { type_hash, reason } => f.write_fmt(format_args!(
                "Backend default materialize_from_wire (type_hash {type_hash:#018x}): {reason}",
            )),
        }
    }
}

// No underlying source error is carried, so the default trait methods suffice.
impl std::error::Error for BackendWalkError {}
125
126/// Default body for every per-op method on [`Backend`] — wraps the
127/// op in a one-node `GraphProto` and routes through
128/// [`Backend::execute`]. Backends that override `execute` natively
129/// (graph-compiling backends) get the per-op surface free; backends
130/// that override the per-op methods directly bypass this helper
131/// entirely.
132///
133/// `attributes` carries the ONNX-style per-op attributes the typed
134/// per-op signatures encode (`ReduceSum`'s `axes` + `keepdims`,
135/// `Reshape`'s `shape`, etc.). The walker decodes them back via
136/// `attr_ints` / `attr_int` / `attr_tensor` before calling
137/// the typed per-op method on the backend.
138pub fn execute_single<B: Backend + ?Sized>(
139    backend: &B,
140    op_type: &str,
141    inputs: &[&B::Tensor],
142    attributes: Vec<AttributeProto>,
143) -> Result<B::Tensor, B::Error> {
144    let input_names: Vec<String> = (0..inputs.len())
145        .map(|i| format!("__bb_default_walk_in_{i}"))
146        .collect();
147
148    let node = NodeProto {
149        op_type: op_type.to_string(),
150        input: input_names.clone(),
151        output: vec![SINGLE_OP_OUTPUT_NAME.to_string()],
152        attribute: attributes,
153        ..Default::default()
154    };
155    let graph = GraphProto {
156        node: vec![node],
157        output: vec![ValueInfoProto {
158            name: SINGLE_OP_OUTPUT_NAME.to_string(),
159            ..Default::default()
160        }],
161        ..Default::default()
162    };
163
164    let input_map: HashMap<String, B::Tensor> = input_names
165        .into_iter()
166        .zip(inputs.iter().map(|t| (*t).clone()))
167        .collect();
168
169    let mut output_map = backend.execute(
170        &graph,
171        input_map,
172        super::backend::BackendAttrs {
173            current_node_attributes: &[],
174            current_node_metadata: &[],
175        },
176    )?;
177    let result = output_map.remove(SINGLE_OP_OUTPUT_NAME).ok_or_else(|| {
178        BackendWalkError::MissingExecuteOutput {
179            op_type: op_type.to_string(),
180            output_name: SINGLE_OP_OUTPUT_NAME.to_string(),
181        }
182    })?;
183    Ok(result)
184}
185
186/// Default body for multi-output per-op methods ([`Backend::split`]
187/// today). Builds a one-node `GraphProto` with `output_count`
188/// positionally-named outputs, calls [`Backend::execute`], and
189/// extracts the outputs in declared order.
190///
191/// If `output_count == 0`, returns an empty `Vec` without invoking
192/// `execute` — multi-output ops with zero outputs are degenerate
193/// and the engine never produces such carriers.
194pub fn execute_multi<B: Backend + ?Sized>(
195    backend: &B,
196    op_type: &str,
197    inputs: &[&B::Tensor],
198    attributes: Vec<AttributeProto>,
199    output_count: usize,
200) -> Result<Vec<B::Tensor>, B::Error> {
201    if output_count == 0 {
202        return Ok(Vec::new());
203    }
204
205    let input_names: Vec<String> = (0..inputs.len())
206        .map(|i| format!("__bb_default_walk_in_{i}"))
207        .collect();
208    let output_names: Vec<String> = (0..output_count)
209        .map(|i| format!("__bb_default_walk_out_{i}"))
210        .collect();
211
212    let node = NodeProto {
213        op_type: op_type.to_string(),
214        input: input_names.clone(),
215        output: output_names.clone(),
216        attribute: attributes,
217        ..Default::default()
218    };
219    let graph = GraphProto {
220        node: vec![node],
221        output: output_names
222            .iter()
223            .map(|n| ValueInfoProto {
224                name: n.clone(),
225                ..Default::default()
226            })
227            .collect(),
228        ..Default::default()
229    };
230
231    let input_map: HashMap<String, B::Tensor> = input_names
232        .into_iter()
233        .zip(inputs.iter().map(|t| (*t).clone()))
234        .collect();
235
236    let mut output_map = backend.execute(
237        &graph,
238        input_map,
239        super::backend::BackendAttrs {
240            current_node_attributes: &[],
241            current_node_metadata: &[],
242        },
243    )?;
244    output_names
245        .into_iter()
246        .map(|n| {
247            output_map.remove(&n).ok_or_else(|| {
248                BackendWalkError::MissingExecuteOutput {
249                    op_type: op_type.to_string(),
250                    output_name: n,
251                }
252                .into()
253            })
254        })
255        .collect()
256}
257
258/// Default body for [`Backend::execute`] — walks `graph.node` in
259/// topological order, dispatching each through the typed per-op
260/// methods on `backend`. The implementation is a tight linear scan:
261/// ONNX guarantees `graph.node` is topologically ordered, so no
262/// petgraph / explicit ordering is needed.
263///
264/// Op-types must be in [`bb_ir::tensor_primitives::TENSOR_PRIMITIVES_OPS`].
265/// A graph containing extension ops (Relu, Conv, …) needs either a
266/// backend that overrides `execute` natively OR a lowering pass
267/// (future work) decomposing the extensions into primitives.
268pub fn execute_graph_via_per_op<B: Backend + ?Sized>(
269    backend: &B,
270    graph: &GraphProto,
271    inputs: HashMap<String, B::Tensor>,
272) -> Result<HashMap<String, B::Tensor>, B::Error> {
273    let mut env: HashMap<String, B::Tensor> = inputs;
274
275    for node in &graph.node {
276        let input_tensors: Vec<&B::Tensor> = node
277            .input
278            .iter()
279            .filter(|n| !n.is_empty())
280            .map(|n| {
281                env.get(n).ok_or_else(|| BackendWalkError::MissingInput {
282                    op_type: node.op_type.clone(),
283                    input_name: n.clone(),
284                })
285            })
286            .collect::<Result<Vec<&B::Tensor>, BackendWalkError>>()
287            .map_err(B::Error::from)?;
288
289        let outputs = dispatch_per_op(backend, &node.op_type, &input_tensors, &node.attribute)?;
290
291        for (i, name) in node.output.iter().enumerate() {
292            if name.is_empty() {
293                continue;
294            }
295            let Some(tensor) = outputs.get(i) else {
296                return Err(BackendWalkError::OutputArityMismatch {
297                    op_type: node.op_type.clone(),
298                    produced: outputs.len(),
299                    declared: node.output.len(),
300                }
301                .into());
302            };
303            env.insert(name.clone(), tensor.clone());
304        }
305    }
306
307    let mut result: HashMap<String, B::Tensor> = HashMap::new();
308    for vi in &graph.output {
309        if let Some(t) = env.remove(&vi.name) {
310            result.insert(vi.name.clone(), t);
311        }
312    }
313    Ok(result)
314}
315
316/// Dispatch a single `NodeProto` (whose `op_type` MUST be one of
317/// the 30 `TENSOR_PRIMITIVES_OPS`) through the appropriate typed
318/// per-op method on `backend`. Returns a `Vec` to handle multi-
319/// output primitives (`Split`).
320fn dispatch_per_op<B: Backend + ?Sized>(
321    backend: &B,
322    op_type: &str,
323    inputs: &[&B::Tensor],
324    attrs: &[AttributeProto],
325) -> Result<Vec<B::Tensor>, B::Error> {
326    let single = |t: B::Tensor| Ok(vec![t]);
327    match op_type {
328        // Arithmetic
329        "Add" => single(backend.add(inputs[0], inputs[1])?),
330        "Sub" => single(backend.sub(inputs[0], inputs[1])?),
331        "Mul" => single(backend.mul(inputs[0], inputs[1])?),
332        "Div" => single(backend.div(inputs[0], inputs[1])?),
333        "Neg" => single(backend.neg(inputs[0])?),
334        "Abs" => single(backend.abs(inputs[0])?),
335        // Math
336        "Sqrt" => single(backend.sqrt(inputs[0])?),
337        "Pow" => single(backend.pow(inputs[0], inputs[1])?),
338        "Exp" => single(backend.exp(inputs[0])?),
339        "Log" => single(backend.log(inputs[0])?),
340        // Linear algebra
341        "MatMul" => single(backend.matmul(inputs[0], inputs[1])?),
342        // Reductions
343        "ReduceSum" => single(backend.reduce_sum(
344            inputs[0],
345            &attr_ints(attrs, "axes"),
346            attr_int(attrs, "keepdims", 1) != 0,
347        )?),
348        "ReduceMean" => single(backend.reduce_mean(
349            inputs[0],
350            &attr_ints(attrs, "axes"),
351            attr_int(attrs, "keepdims", 1) != 0,
352        )?),
353        "ReduceMax" => single(backend.reduce_max(
354            inputs[0],
355            &attr_ints(attrs, "axes"),
356            attr_int(attrs, "keepdims", 1) != 0,
357        )?),
358        "ReduceMin" => single(backend.reduce_min(
359            inputs[0],
360            &attr_ints(attrs, "axes"),
361            attr_int(attrs, "keepdims", 1) != 0,
362        )?),
363        // Shape
364        "Reshape" => single(backend.reshape(inputs[0], &attr_ints(attrs, "shape"))?),
365        "Transpose" => single(backend.transpose(inputs[0], &attr_ints(attrs, "perm"))?),
366        "Concat" => single(backend.concat(inputs, attr_int(attrs, "axis", 0))?),
367        "Slice" => single(backend.slice(
368            inputs[0],
369            &attr_ints(attrs, "starts"),
370            &attr_ints(attrs, "ends"),
371            &attr_ints(attrs, "axes"),
372            &attr_ints(attrs, "steps"),
373        )?),
374        "Split" => Ok(backend.split(
375            inputs[0],
376            attr_int(attrs, "axis", 0),
377            &attr_ints(attrs, "split"),
378        )?),
379        "Squeeze" => single(backend.squeeze(inputs[0], &attr_ints(attrs, "axes"))?),
380        "Unsqueeze" => single(backend.unsqueeze(inputs[0], &attr_ints(attrs, "axes"))?),
381        "Identity" => single(backend.identity(inputs[0])?),
382        "Cast" => single(backend.cast(inputs[0], attr_int(attrs, "to", 1) as i32)?),
383        // Comparison
384        "Equal" => single(backend.equal(inputs[0], inputs[1])?),
385        "Greater" => single(backend.greater(inputs[0], inputs[1])?),
386        "Less" => single(backend.less(inputs[0], inputs[1])?),
387        // Conditional
388        "Where" => single(backend.r#where(inputs[0], inputs[1], inputs[2])?),
389        // Creation
390        "Constant" => single(backend.constant(attr_tensor(attrs, "value").unwrap_or_default())?),
391        // Indexing
392        "Gather" => single(backend.gather(inputs[0], inputs[1], attr_int(attrs, "axis", 0))?),
393        other => Err(BackendWalkError::UnknownOpType(other.to_string()).into()),
394    }
395}
396
397// ──────────────────────────────────────────────────────────────────
398// Attribute encoders — called from the Contract per-op defaults.
399// ──────────────────────────────────────────────────────────────────
400
401/// Build an `AttributeProto` of type `INT` (per ONNX
402/// [`AttributeType::Int`]). Used by the Contract per-op defaults
403/// for scalar `i64` attributes (`axis`, `to`, `keepdims`).
404pub fn int_attr(name: &str, value: i64) -> AttributeProto {
405    AttributeProto {
406        name: name.to_string(),
407        r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Int as i32,
408        i: value,
409        ..Default::default()
410    }
411}
412
413/// Build an `AttributeProto` of type `INTS`. Used for vector
414/// attributes (`axes`, `shape`, `perm`, `starts`, `ends`, `steps`,
415/// `split`).
416pub fn ints_attr(name: &str, values: &[i64]) -> AttributeProto {
417    AttributeProto {
418        name: name.to_string(),
419        r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Ints as i32,
420        ints: values.to_vec(),
421        ..Default::default()
422    }
423}
424
425/// Build an `AttributeProto` of type `TENSOR`. Used for
426/// `Constant`'s `value` attribute.
427pub fn tensor_attr(name: &str, tensor: TensorProto) -> AttributeProto {
428    AttributeProto {
429        name: name.to_string(),
430        r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Tensor as i32,
431        t: Some(tensor),
432        ..Default::default()
433    }
434}
435
436// ──────────────────────────────────────────────────────────────────
437// Attribute decoders — called from the walker.
438// ──────────────────────────────────────────────────────────────────
439
440fn attr_int(attrs: &[AttributeProto], name: &str, default: i64) -> i64 {
441    attrs
442        .iter()
443        .find(|a| a.name == name)
444        .map(|a| a.i)
445        .unwrap_or(default)
446}
447
448fn attr_ints(attrs: &[AttributeProto], name: &str) -> Vec<i64> {
449    attrs
450        .iter()
451        .find(|a| a.name == name)
452        .map(|a| a.ints.clone())
453        .unwrap_or_default()
454}
455
456fn attr_tensor(attrs: &[AttributeProto], name: &str) -> Option<TensorProto> {
457    attrs
458        .iter()
459        .find(|a| a.name == name)
460        .and_then(|a| a.t.clone())
461}
462