1use std::collections::HashMap;
32
33use bb_ir::proto::onnx::{AttributeProto, GraphProto, NodeProto, TensorProto, ValueInfoProto};
34
35use super::backend::Backend;
36
/// Name given to the sole output of the one-node graph that `execute_single`
/// builds; the same name is used to pull the result back out of the map returned
/// by `Backend::execute`.
const SINGLE_OP_OUTPUT_NAME: &str = "__bb_default_walk_output";
38
/// Errors raised by the default execution helpers in this module.
///
/// The helpers convert these into the backend's own error type, relying on
/// `B::Error: From<BackendWalkError>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendWalkError {
    /// A node refers to an input name that is not bound in the walker's
    /// value environment.
    MissingInput {
        /// `op_type` of the node that made the reference.
        op_type: String,
        /// The unbound input name.
        input_name: String,
    },
    /// A per-op primitive returned fewer tensors than the node needs to bind.
    OutputArityMismatch {
        /// `op_type` of the node whose primitive came up short.
        op_type: String,
        /// How many tensors the primitive actually returned.
        produced: usize,
        /// How many output slots the node lists.
        declared: usize,
    },
    /// The per-op dispatcher has no primitive for this `op_type`.
    UnknownOpType(String),
    /// `Backend::execute` returned without one of the synthetic outputs that
    /// the single-op helpers asked it to produce.
    MissingExecuteOutput {
        /// `op_type` of the wrapped node.
        op_type: String,
        /// The synthetic output name that was absent from the result map.
        output_name: String,
    },
    /// Turning a wire-encoded value back into a tensor failed.
    ///
    /// NOTE(review): raised outside this view — presumably by a default
    /// `materialize_from_wire`, judging by the message; confirm.
    WireMaterializeFailed {
        /// Type hash of the value being materialized (rendered as 16-digit hex).
        type_hash: u64,
        /// Human-readable failure description.
        reason: String,
    },
}

impl std::fmt::Display for BackendWalkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Common lead-in for every failure that happens while walking nodes.
        const WALKER: &str = "Backend default walker: ";
        match self {
            Self::MissingInput { op_type, input_name } => {
                f.write_str(WALKER)?;
                write!(f, "`{op_type}` references input `{input_name}` not in the value env")
            }
            Self::OutputArityMismatch { op_type, produced, declared } => {
                f.write_str(WALKER)?;
                write!(f, "per-op `{op_type}` produced {produced} outputs but graph declares {declared}")
            }
            Self::UnknownOpType(op_type) => {
                f.write_str(WALKER)?;
                write!(f, "op_type `{op_type}` is not in TENSOR_PRIMITIVES_OPS")
            }
            Self::MissingExecuteOutput { op_type, output_name } => write!(
                f,
                "Backend::execute (op_type `{op_type}`) did not produce its declared output `{output_name}`"
            ),
            Self::WireMaterializeFailed { type_hash, reason } => write!(
                f,
                "Backend default materialize_from_wire (type_hash {type_hash:#018x}): {reason}"
            ),
        }
    }
}

impl std::error::Error for BackendWalkError {}
125
126pub fn execute_single<B: Backend + ?Sized>(
139 backend: &B,
140 op_type: &str,
141 inputs: &[&B::Tensor],
142 attributes: Vec<AttributeProto>,
143) -> Result<B::Tensor, B::Error> {
144 let input_names: Vec<String> = (0..inputs.len())
145 .map(|i| format!("__bb_default_walk_in_{i}"))
146 .collect();
147
148 let node = NodeProto {
149 op_type: op_type.to_string(),
150 input: input_names.clone(),
151 output: vec![SINGLE_OP_OUTPUT_NAME.to_string()],
152 attribute: attributes,
153 ..Default::default()
154 };
155 let graph = GraphProto {
156 node: vec![node],
157 output: vec![ValueInfoProto {
158 name: SINGLE_OP_OUTPUT_NAME.to_string(),
159 ..Default::default()
160 }],
161 ..Default::default()
162 };
163
164 let input_map: HashMap<String, B::Tensor> = input_names
165 .into_iter()
166 .zip(inputs.iter().map(|t| (*t).clone()))
167 .collect();
168
169 let mut output_map = backend.execute(
170 &graph,
171 input_map,
172 super::backend::BackendAttrs {
173 current_node_attributes: &[],
174 current_node_metadata: &[],
175 },
176 )?;
177 let result = output_map.remove(SINGLE_OP_OUTPUT_NAME).ok_or_else(|| {
178 BackendWalkError::MissingExecuteOutput {
179 op_type: op_type.to_string(),
180 output_name: SINGLE_OP_OUTPUT_NAME.to_string(),
181 }
182 })?;
183 Ok(result)
184}
185
186pub fn execute_multi<B: Backend + ?Sized>(
195 backend: &B,
196 op_type: &str,
197 inputs: &[&B::Tensor],
198 attributes: Vec<AttributeProto>,
199 output_count: usize,
200) -> Result<Vec<B::Tensor>, B::Error> {
201 if output_count == 0 {
202 return Ok(Vec::new());
203 }
204
205 let input_names: Vec<String> = (0..inputs.len())
206 .map(|i| format!("__bb_default_walk_in_{i}"))
207 .collect();
208 let output_names: Vec<String> = (0..output_count)
209 .map(|i| format!("__bb_default_walk_out_{i}"))
210 .collect();
211
212 let node = NodeProto {
213 op_type: op_type.to_string(),
214 input: input_names.clone(),
215 output: output_names.clone(),
216 attribute: attributes,
217 ..Default::default()
218 };
219 let graph = GraphProto {
220 node: vec![node],
221 output: output_names
222 .iter()
223 .map(|n| ValueInfoProto {
224 name: n.clone(),
225 ..Default::default()
226 })
227 .collect(),
228 ..Default::default()
229 };
230
231 let input_map: HashMap<String, B::Tensor> = input_names
232 .into_iter()
233 .zip(inputs.iter().map(|t| (*t).clone()))
234 .collect();
235
236 let mut output_map = backend.execute(
237 &graph,
238 input_map,
239 super::backend::BackendAttrs {
240 current_node_attributes: &[],
241 current_node_metadata: &[],
242 },
243 )?;
244 output_names
245 .into_iter()
246 .map(|n| {
247 output_map.remove(&n).ok_or_else(|| {
248 BackendWalkError::MissingExecuteOutput {
249 op_type: op_type.to_string(),
250 output_name: n,
251 }
252 .into()
253 })
254 })
255 .collect()
256}
257
258pub fn execute_graph_via_per_op<B: Backend + ?Sized>(
269 backend: &B,
270 graph: &GraphProto,
271 inputs: HashMap<String, B::Tensor>,
272) -> Result<HashMap<String, B::Tensor>, B::Error> {
273 let mut env: HashMap<String, B::Tensor> = inputs;
274
275 for node in &graph.node {
276 let input_tensors: Vec<&B::Tensor> = node
277 .input
278 .iter()
279 .filter(|n| !n.is_empty())
280 .map(|n| {
281 env.get(n).ok_or_else(|| BackendWalkError::MissingInput {
282 op_type: node.op_type.clone(),
283 input_name: n.clone(),
284 })
285 })
286 .collect::<Result<Vec<&B::Tensor>, BackendWalkError>>()
287 .map_err(B::Error::from)?;
288
289 let outputs = dispatch_per_op(backend, &node.op_type, &input_tensors, &node.attribute)?;
290
291 for (i, name) in node.output.iter().enumerate() {
292 if name.is_empty() {
293 continue;
294 }
295 let Some(tensor) = outputs.get(i) else {
296 return Err(BackendWalkError::OutputArityMismatch {
297 op_type: node.op_type.clone(),
298 produced: outputs.len(),
299 declared: node.output.len(),
300 }
301 .into());
302 };
303 env.insert(name.clone(), tensor.clone());
304 }
305 }
306
307 let mut result: HashMap<String, B::Tensor> = HashMap::new();
308 for vi in &graph.output {
309 if let Some(t) = env.remove(&vi.name) {
310 result.insert(vi.name.clone(), t);
311 }
312 }
313 Ok(result)
314}
315
316fn dispatch_per_op<B: Backend + ?Sized>(
321 backend: &B,
322 op_type: &str,
323 inputs: &[&B::Tensor],
324 attrs: &[AttributeProto],
325) -> Result<Vec<B::Tensor>, B::Error> {
326 let single = |t: B::Tensor| Ok(vec![t]);
327 match op_type {
328 "Add" => single(backend.add(inputs[0], inputs[1])?),
330 "Sub" => single(backend.sub(inputs[0], inputs[1])?),
331 "Mul" => single(backend.mul(inputs[0], inputs[1])?),
332 "Div" => single(backend.div(inputs[0], inputs[1])?),
333 "Neg" => single(backend.neg(inputs[0])?),
334 "Abs" => single(backend.abs(inputs[0])?),
335 "Sqrt" => single(backend.sqrt(inputs[0])?),
337 "Pow" => single(backend.pow(inputs[0], inputs[1])?),
338 "Exp" => single(backend.exp(inputs[0])?),
339 "Log" => single(backend.log(inputs[0])?),
340 "MatMul" => single(backend.matmul(inputs[0], inputs[1])?),
342 "ReduceSum" => single(backend.reduce_sum(
344 inputs[0],
345 &attr_ints(attrs, "axes"),
346 attr_int(attrs, "keepdims", 1) != 0,
347 )?),
348 "ReduceMean" => single(backend.reduce_mean(
349 inputs[0],
350 &attr_ints(attrs, "axes"),
351 attr_int(attrs, "keepdims", 1) != 0,
352 )?),
353 "ReduceMax" => single(backend.reduce_max(
354 inputs[0],
355 &attr_ints(attrs, "axes"),
356 attr_int(attrs, "keepdims", 1) != 0,
357 )?),
358 "ReduceMin" => single(backend.reduce_min(
359 inputs[0],
360 &attr_ints(attrs, "axes"),
361 attr_int(attrs, "keepdims", 1) != 0,
362 )?),
363 "Reshape" => single(backend.reshape(inputs[0], &attr_ints(attrs, "shape"))?),
365 "Transpose" => single(backend.transpose(inputs[0], &attr_ints(attrs, "perm"))?),
366 "Concat" => single(backend.concat(inputs, attr_int(attrs, "axis", 0))?),
367 "Slice" => single(backend.slice(
368 inputs[0],
369 &attr_ints(attrs, "starts"),
370 &attr_ints(attrs, "ends"),
371 &attr_ints(attrs, "axes"),
372 &attr_ints(attrs, "steps"),
373 )?),
374 "Split" => Ok(backend.split(
375 inputs[0],
376 attr_int(attrs, "axis", 0),
377 &attr_ints(attrs, "split"),
378 )?),
379 "Squeeze" => single(backend.squeeze(inputs[0], &attr_ints(attrs, "axes"))?),
380 "Unsqueeze" => single(backend.unsqueeze(inputs[0], &attr_ints(attrs, "axes"))?),
381 "Identity" => single(backend.identity(inputs[0])?),
382 "Cast" => single(backend.cast(inputs[0], attr_int(attrs, "to", 1) as i32)?),
383 "Equal" => single(backend.equal(inputs[0], inputs[1])?),
385 "Greater" => single(backend.greater(inputs[0], inputs[1])?),
386 "Less" => single(backend.less(inputs[0], inputs[1])?),
387 "Where" => single(backend.r#where(inputs[0], inputs[1], inputs[2])?),
389 "Constant" => single(backend.constant(attr_tensor(attrs, "value").unwrap_or_default())?),
391 "Gather" => single(backend.gather(inputs[0], inputs[1], attr_int(attrs, "axis", 0))?),
393 other => Err(BackendWalkError::UnknownOpType(other.to_string()).into()),
394 }
395}
396
397pub fn int_attr(name: &str, value: i64) -> AttributeProto {
405 AttributeProto {
406 name: name.to_string(),
407 r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Int as i32,
408 i: value,
409 ..Default::default()
410 }
411}
412
413pub fn ints_attr(name: &str, values: &[i64]) -> AttributeProto {
417 AttributeProto {
418 name: name.to_string(),
419 r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Ints as i32,
420 ints: values.to_vec(),
421 ..Default::default()
422 }
423}
424
425pub fn tensor_attr(name: &str, tensor: TensorProto) -> AttributeProto {
428 AttributeProto {
429 name: name.to_string(),
430 r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Tensor as i32,
431 t: Some(tensor),
432 ..Default::default()
433 }
434}
435
436fn attr_int(attrs: &[AttributeProto], name: &str, default: i64) -> i64 {
441 attrs
442 .iter()
443 .find(|a| a.name == name)
444 .map(|a| a.i)
445 .unwrap_or(default)
446}
447
448fn attr_ints(attrs: &[AttributeProto], name: &str) -> Vec<i64> {
449 attrs
450 .iter()
451 .find(|a| a.name == name)
452 .map(|a| a.ints.clone())
453 .unwrap_or_default()
454}
455
456fn attr_tensor(attrs: &[AttributeProto], name: &str) -> Option<TensorProto> {
457 attrs
458 .iter()
459 .find(|a| a.name == name)
460 .and_then(|a| a.t.clone())
461}
462