Skip to main content

bb_ops/backends/cpu/ops/
mod.rs

1//! Per-op kernel dispatch for the `CpuBackend`.
2//!
3//! The 51-op `ai.onnx v1` catalog is grouped into category modules.
4//! This module routes each declared op to its category-specific
5//! kernel; categories that don't yet have implementations return a
6//! clear "not yet implemented" error.
7//!
//! **Attribute access.** `BackendRuntime::dispatch_atomic` receives
//! `(op_type, inputs, attrs)` - no NodeProto, only the node's
//! attribute slice. Ops whose semantics depend on attributes
//! (Reshape's `dims`, Softmax's `axis`, Gemm's `alpha`/`beta`, etc.)
//! read them from `attrs`, which `dispatch` forwards to their
//! kernels. Ops outside this backend's declared opset (Conv,
//! MaxPool, etc.) fall through to an "unsupported op_type" error.
15
16mod elementwise;
17mod linalg;
18mod shape;
19
20use bb_ir::proto::onnx::{attribute_proto::AttributeType, AttributeProto};
21use bb_runtime::atomic::DispatchResult;
22use bb_runtime::bus::OpError;
23use bb_runtime::slot_value::SlotValue;
24
25use crate::backends::cpu::CpuBackend;
26use crate::backends::cpu::CpuTensor;
27
28/// Find an attribute by name in the supplied `attrs` slice.
29fn find_attr<'a>(attrs: &'a [AttributeProto], name: &str) -> Option<&'a AttributeProto> {
30    attrs.iter().find(|a| a.name == name)
31}
32
33fn need_int_attr(op: &str, attrs: &[AttributeProto], name: &str) -> Result<i64, OpError> {
34    let a = find_attr(attrs, name).ok_or_else(|| OpError {
35        detail: format!("{op}: missing `{name}` attribute"),
36        ..Default::default()
37    })?;
38    if a.r#type != AttributeType::Int as i32 {
39        return Err(OpError {
40            detail: format!("{op}: `{name}` attribute must be INT"),
41            ..Default::default()
42        });
43    }
44    Ok(a.i)
45}
46
47fn need_ints_attr(op: &str, attrs: &[AttributeProto], name: &str) -> Result<Vec<i64>, OpError> {
48    let a = find_attr(attrs, name).ok_or_else(|| OpError {
49        detail: format!("{op}: missing `{name}` attribute"),
50        ..Default::default()
51    })?;
52    if a.r#type != AttributeType::Ints as i32 {
53        return Err(OpError {
54            detail: format!("{op}: `{name}` attribute must be INTS"),
55            ..Default::default()
56        });
57    }
58    Ok(a.ints.clone())
59}
60
61fn opt_float_attr(attrs: &[AttributeProto], name: &str, default: f32) -> f32 {
62    find_attr(attrs, name)
63        .filter(|a| a.r#type == AttributeType::Float as i32)
64        .map(|a| a.f)
65        .unwrap_or(default)
66}
67
68fn opt_int_attr(attrs: &[AttributeProto], name: &str, default: i64) -> i64 {
69    find_attr(attrs, name)
70        .filter(|a| a.r#type == AttributeType::Int as i32)
71        .map(|a| a.i)
72        .unwrap_or(default)
73}
74
75/// Downcast a `&dyn SlotValue` to `&CpuTensor`, raising `OpError`
76/// with a clear message on type mismatch.
77fn as_cpu_tensor<'a>(op: &str, role: &str, h: &'a dyn SlotValue) -> Result<&'a CpuTensor, OpError> {
78    h.as_any()
79        .downcast_ref::<CpuTensor>()
80        .ok_or_else(|| OpError {
81            detail: format!("{op}: {role} is not a CpuTensor"),
82            ..Default::default()
83        })
84}
85
86fn need_two_inputs<'a>(
87    op: &str,
88    inputs: &'a [(&str, &dyn SlotValue)],
89) -> Result<(&'a CpuTensor, &'a CpuTensor), OpError> {
90    if inputs.len() < 2 {
91        return Err(OpError {
92            detail: format!("{op}: requires two inputs, got {}", inputs.len()),
93            ..Default::default()
94        });
95    }
96    let a = as_cpu_tensor(op, "input 0", inputs[0].1)?;
97    let b = as_cpu_tensor(op, "input 1", inputs[1].1)?;
98    Ok((a, b))
99}
100
101fn need_one_input<'a>(
102    op: &str,
103    inputs: &'a [(&str, &dyn SlotValue)],
104) -> Result<&'a CpuTensor, OpError> {
105    if inputs.is_empty() {
106        return Err(OpError {
107            detail: format!("{op}: requires one input, got 0"),
108            ..Default::default()
109        });
110    }
111    as_cpu_tensor(op, "input 0", inputs[0].1)
112}
113
114fn out(name: &str, tensor: CpuTensor) -> DispatchResult {
115    DispatchResult::Immediate(vec![(name.to_string(), Box::new(tensor))])
116}
117
118/// Route `op_type` to the matching kernel. `attrs` is the
119/// NodeProto's attribute slice — ops whose semantics depend on
120/// attributes (Reshape's `dims`, Softmax's `axis`, Gemm's
121/// `alpha`/`beta`, …) read it directly. The framework's per-op
122/// dispatch path threads `ctx.current.node_attributes` here; the
123/// `execute_graph` walker threads `node.attribute` per node.
124pub fn dispatch(
125    backend: &CpuBackend,
126    op_type: &str,
127    inputs: &[(&str, &dyn SlotValue)],
128    attrs: &[AttributeProto],
129) -> Result<DispatchResult, OpError> {
130    match op_type {
131        // --- Element-wise binary -----------------------------------
132        "Add" => Ok(out("C", elementwise::add(backend, op_type, inputs)?)),
133        "Sub" => Ok(out("C", elementwise::sub(backend, op_type, inputs)?)),
134        "Mul" => Ok(out("C", elementwise::mul(backend, op_type, inputs)?)),
135        "Div" => Ok(out("C", elementwise::div(backend, op_type, inputs)?)),
136        "Pow" => Ok(out("C", elementwise::pow(backend, op_type, inputs)?)),
137
138        // --- Element-wise unary ------------------------------------
139        "Neg" => Ok(out("Y", elementwise::neg(backend, op_type, inputs)?)),
140        "Abs" => Ok(out("Y", elementwise::abs(backend, op_type, inputs)?)),
141        "Sqrt" => Ok(out("Y", elementwise::sqrt(backend, op_type, inputs)?)),
142        "Exp" => Ok(out("Y", elementwise::exp(backend, op_type, inputs)?)),
143        "Log" => Ok(out("Y", elementwise::log(backend, op_type, inputs)?)),
144
145        // --- Activations -------------------------------------------
146        "Relu" => Ok(out("Y", elementwise::relu(backend, op_type, inputs)?)),
147        "Sigmoid" => Ok(out("Y", elementwise::sigmoid(backend, op_type, inputs)?)),
148        "Tanh" => Ok(out("Y", elementwise::tanh(backend, op_type, inputs)?)),
149        "Gelu" => Ok(out("Y", elementwise::gelu(backend, op_type, inputs)?)),
150        "Identity" => Ok(out("Y", elementwise::identity(backend, op_type, inputs)?)),
151        "Softmax" => Ok(out("Y", shape::softmax(backend, op_type, inputs, attrs)?)),
152        "LeakyRelu" => Ok(out(
153            "Y",
154            shape::leaky_relu(backend, op_type, inputs, attrs)?,
155        )),
156
157        // --- Element-wise comparison -------------------------------
158        "Equal" => Ok(out("C", elementwise::equal(backend, op_type, inputs)?)),
159        "Greater" => Ok(out("C", elementwise::greater(backend, op_type, inputs)?)),
160        "Less" => Ok(out("C", elementwise::less(backend, op_type, inputs)?)),
161
162        // --- Linear algebra ----------------------------------------
163        "MatMul" => Ok(out("Y", linalg::matmul(backend, op_type, inputs)?)),
164        "Dot" => Ok(out("Y", linalg::dot(backend, op_type, inputs)?)),
165        "Gemm" => Ok(out("Y", shape::gemm(backend, op_type, inputs, attrs)?)),
166
167        // --- Reductions --------------------------------------------
168        "ReduceSum" => Ok(out("Y", linalg::reduce_sum(backend, op_type, inputs)?)),
169        "ReduceMean" => Ok(out("Y", linalg::reduce_mean(backend, op_type, inputs)?)),
170        "ReduceMax" => Ok(out("Y", linalg::reduce_max(backend, op_type, inputs)?)),
171        "ReduceMin" => Ok(out("Y", linalg::reduce_min(backend, op_type, inputs)?)),
172
173        // --- Shape / structural ------------------------------------
174        "Reshape" => Ok(out("Y", shape::reshape(backend, op_type, inputs, attrs)?)),
175        "Transpose" => Ok(out("Y", shape::transpose(backend, op_type, inputs, attrs)?)),
176        "Concat" => Ok(out("Y", shape::concat(backend, op_type, inputs, attrs)?)),
177        "Squeeze" => Ok(out("Y", shape::squeeze(backend, op_type, inputs, attrs)?)),
178        "Unsqueeze" => Ok(out("Y", shape::unsqueeze(backend, op_type, inputs, attrs)?)),
179        "Cast" => Ok(out("Y", shape::cast(backend, op_type, inputs, attrs)?)),
180        "Slice" => Ok(out("Y", shape::slice(backend, op_type, inputs, attrs)?)),
181        "Split" => Ok(shape::split(backend, op_type, inputs, attrs)?),
182
183        // --- Indexing ----------------------------------------------
184        "Gather" => Ok(out("Y", shape::gather(backend, op_type, inputs, attrs)?)),
185
186        // --- Pooling without attributes ----------------------------
187        "GlobalAveragePool" => Ok(out(
188            "Y",
189            linalg::global_average_pool(backend, op_type, inputs)?,
190        )),
191
192        // --- Creation ----------------------------------------------
193        "Zeros" => Ok(out("Y", shape::zeros(backend, op_type, attrs)?)),
194        "Ones" => Ok(out("Y", shape::ones(backend, op_type, attrs)?)),
195        "Constant" => Ok(out("Y", shape::constant(backend, op_type, attrs)?)),
196
197        // `BackendSubgraph` is routed by the engine's
198        // `invoke_backend_subgraph` path, not this dispatch.
199        // BatchNorm / LayerNorm / Conv / MaxPool / AveragePool /
200        // Scatter / If / Loop are NOT in this backend's declared
201        // opset — `Node::ready`'s per-`BackendSubgraph` check
202        // catches them at install time. They surface here only as
203        // a defensive fallthrough.
204        other => Err(OpError {
205            detail: format!("CpuBackend: unsupported op_type '{other}'"),
206            ..Default::default()
207        }),
208    }
209}
210