use std::collections::HashMap;
use bb_ir::proto::onnx::{AttributeProto, GraphProto, NodeProto, TensorProto, ValueInfoProto};
use super::backend::Backend;
/// Name given to the sole output of the one-node graph that `execute_single` builds.
/// The same name is used as the key when reading the result back out of the map returned
/// by `Backend::execute`.
const SINGLE_OP_OUTPUT_NAME: &str = "__bb_default_walk_output";
/// Errors raised by the default per-op graph walker and the single-op execution helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendWalkError {
    /// A node referenced an input value that the walker could not supply.
    MissingInput {
        /// `op_type` of the node making the reference.
        op_type: String,
        /// Name of the value that could not be found.
        input_name: String,
    },
    /// A per-op primitive returned fewer tensors than the node's outputs require.
    OutputArityMismatch {
        /// `op_type` of the offending node.
        op_type: String,
        /// Number of tensors the primitive actually returned.
        produced: usize,
        /// Number of output names listed on the node.
        declared: usize,
    },
    /// The op type has no entry in the per-op dispatch table.
    UnknownOpType(String),
    /// `Backend::execute` returned a map that lacks one of the outputs it was asked for.
    MissingExecuteOutput {
        /// `op_type` of the single-node graph that was executed.
        op_type: String,
        /// The requested output name that was absent.
        output_name: String,
    },
    /// The default `materialize_from_wire` path failed for the value with this `type_hash`.
    WireMaterializeFailed {
        /// Hash identifying the wire type; rendered as zero-padded hex in messages.
        type_hash: u64,
        /// Human-readable cause.
        reason: String,
    },
}

impl std::fmt::Display for BackendWalkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingInput {
                op_type: op,
                input_name: wanted,
            } => write!(
                f,
                "Backend default walker: `{}` references input `{}` not in the value env",
                op, wanted
            ),
            Self::OutputArityMismatch {
                op_type: op,
                produced: got,
                declared: expected,
            } => write!(
                f,
                "Backend default walker: per-op `{}` produced {} outputs but graph declares {}",
                op, got, expected
            ),
            Self::UnknownOpType(op) => write!(
                f,
                "Backend default walker: op_type `{}` is not in TENSOR_PRIMITIVES_OPS",
                op
            ),
            Self::MissingExecuteOutput {
                op_type: op,
                output_name: missing,
            } => write!(
                f,
                "Backend::execute (op_type `{}`) did not produce its declared output `{}`",
                op, missing
            ),
            Self::WireMaterializeFailed {
                type_hash: hash,
                reason: why,
            } => write!(
                f,
                "Backend default materialize_from_wire (type_hash {:#018x}): {}",
                hash, why
            ),
        }
    }
}

impl std::error::Error for BackendWalkError {}
/// Runs one op through `Backend::execute` by wrapping it in a single-node graph.
///
/// The inputs are bound positionally to synthesized names (`__bb_default_walk_in_<i>`), and
/// the node's single output is named `SINGLE_OP_OUTPUT_NAME`. The op's attributes are
/// carried on the `NodeProto` itself. The `BackendAttrs` context handed to `execute` is left
/// empty.
///
/// # Errors
/// Propagates any error from `Backend::execute`. Returns
/// `BackendWalkError::MissingExecuteOutput` (converted into `B::Error`) when the backend's
/// result map lacks the expected output.
pub fn execute_single<B: Backend + ?Sized>(
    backend: &B,
    op_type: &str,
    inputs: &[&B::Tensor],
    attributes: Vec<AttributeProto>,
) -> Result<B::Tensor, B::Error> {
    // Name each input slot and build the feed map in the same pass.
    let mut slot_names: Vec<String> = Vec::with_capacity(inputs.len());
    let mut feed: HashMap<String, B::Tensor> = HashMap::with_capacity(inputs.len());
    for (idx, tensor) in inputs.iter().copied().enumerate() {
        let slot = format!("__bb_default_walk_in_{idx}");
        feed.insert(slot.clone(), tensor.clone());
        slot_names.push(slot);
    }

    let op_node = NodeProto {
        op_type: op_type.to_string(),
        input: slot_names,
        output: vec![SINGLE_OP_OUTPUT_NAME.to_string()],
        attribute: attributes,
        ..Default::default()
    };
    let declared_output = ValueInfoProto {
        name: SINGLE_OP_OUTPUT_NAME.to_string(),
        ..Default::default()
    };
    let graph = GraphProto {
        node: vec![op_node],
        output: vec![declared_output],
        ..Default::default()
    };

    let context = super::backend::BackendAttrs {
        current_node_attributes: &[],
        current_node_metadata: &[],
    };
    let mut produced = backend.execute(&graph, feed, context)?;

    match produced.remove(SINGLE_OP_OUTPUT_NAME) {
        Some(tensor) => Ok(tensor),
        None => Err(BackendWalkError::MissingExecuteOutput {
            op_type: op_type.to_string(),
            output_name: SINGLE_OP_OUTPUT_NAME.to_string(),
        }
        .into()),
    }
}
pub fn execute_multi<B: Backend + ?Sized>(
backend: &B,
op_type: &str,
inputs: &[&B::Tensor],
attributes: Vec<AttributeProto>,
output_count: usize,
) -> Result<Vec<B::Tensor>, B::Error> {
if output_count == 0 {
return Ok(Vec::new());
}
let input_names: Vec<String> = (0..inputs.len())
.map(|i| format!("__bb_default_walk_in_{i}"))
.collect();
let output_names: Vec<String> = (0..output_count)
.map(|i| format!("__bb_default_walk_out_{i}"))
.collect();
let node = NodeProto {
op_type: op_type.to_string(),
input: input_names.clone(),
output: output_names.clone(),
attribute: attributes,
..Default::default()
};
let graph = GraphProto {
node: vec![node],
output: output_names
.iter()
.map(|n| ValueInfoProto {
name: n.clone(),
..Default::default()
})
.collect(),
..Default::default()
};
let input_map: HashMap<String, B::Tensor> = input_names
.into_iter()
.zip(inputs.iter().map(|t| (*t).clone()))
.collect();
let mut output_map = backend.execute(
&graph,
input_map,
super::backend::BackendAttrs {
current_node_attributes: &[],
current_node_metadata: &[],
},
)?;
output_names
.into_iter()
.map(|n| {
output_map.remove(&n).ok_or_else(|| {
BackendWalkError::MissingExecuteOutput {
op_type: op_type.to_string(),
output_name: n,
}
.into()
})
})
.collect()
}
/// Evaluates `graph` node by node through the backend's per-op primitives.
///
/// This is the fallback for backends that do not execute whole graphs. Each node is handed
/// to `dispatch_per_op`, which calls the matching `Backend` method (`add`, `matmul`, ...).
/// Its outputs are stored in a name-keyed value environment.
///
/// * `inputs` seeds the environment.
/// * Every entry of `graph.initializer` is materialized with `Backend::constant`, unless
///   the caller already supplied a value of the same name. ONNX lets graph inputs override
///   initializers.
/// * Empty input names (ONNX's marker for an omitted optional input) are skipped. The
///   remaining inputs are passed positionally in their original order.
/// * Empty output names are skipped. Every non-empty output name must be backed by a
///   tensor the op produced. Extra produced tensors are dropped.
/// * Only values named in `graph.output` are returned. Declared outputs that were never
///   produced are absent from the result map rather than reported as errors.
///
/// Nodes run in `graph.node` order.
/// NOTE(review): this assumes the graph is already topologically sorted.
///
/// # Errors
/// - `BackendWalkError::MissingInput` if a node references a value not yet in the
///   environment.
/// - `BackendWalkError::OutputArityMismatch` if an op produced fewer tensors than its named
///   outputs require.
/// - Any error from `dispatch_per_op` or `Backend::constant`.
pub fn execute_graph_via_per_op<B: Backend + ?Sized>(
    backend: &B,
    graph: &GraphProto,
    inputs: HashMap<String, B::Tensor>,
) -> Result<HashMap<String, B::Tensor>, B::Error> {
    let mut env: HashMap<String, B::Tensor> = inputs;

    // ONNX stores weights and constants as graph initializers rather than as nodes. Without
    // this step, any node consuming one fails with `MissingInput`. Caller-supplied inputs
    // take precedence over an initializer with the same name.
    for init in &graph.initializer {
        if !env.contains_key(&init.name) {
            let tensor = backend.constant(init.clone())?;
            env.insert(init.name.clone(), tensor);
        }
    }

    for node in &graph.node {
        let args = node
            .input
            .iter()
            .filter(|name| !name.is_empty())
            .map(|name| {
                env.get(name).ok_or_else(|| BackendWalkError::MissingInput {
                    op_type: node.op_type.clone(),
                    input_name: name.clone(),
                })
            })
            .collect::<Result<Vec<&B::Tensor>, BackendWalkError>>()
            .map_err(B::Error::from)?;
        let produced = dispatch_per_op(backend, &node.op_type, &args, &node.attribute)?;

        // Validate arity up front so the produced tensors can be moved into the environment
        // instead of cloned. Only non-empty output names need a backing tensor, so the
        // requirement is "one past the last non-empty name".
        let needed = node
            .output
            .iter()
            .rposition(|name| !name.is_empty())
            .map_or(0, |last| last + 1);
        if produced.len() < needed {
            return Err(BackendWalkError::OutputArityMismatch {
                op_type: node.op_type.clone(),
                produced: produced.len(),
                declared: node.output.len(),
            }
            .into());
        }
        for (name, tensor) in node.output.iter().zip(produced) {
            if !name.is_empty() {
                env.insert(name.clone(), tensor);
            }
        }
    }

    let mut result: HashMap<String, B::Tensor> = HashMap::with_capacity(graph.output.len());
    for vi in &graph.output {
        if let Some(t) = env.remove(&vi.name) {
            result.insert(vi.name.clone(), t);
        }
    }
    Ok(result)
}
fn dispatch_per_op<B: Backend + ?Sized>(
backend: &B,
op_type: &str,
inputs: &[&B::Tensor],
attrs: &[AttributeProto],
) -> Result<Vec<B::Tensor>, B::Error> {
let single = |t: B::Tensor| Ok(vec![t]);
match op_type {
"Add" => single(backend.add(inputs[0], inputs[1])?),
"Sub" => single(backend.sub(inputs[0], inputs[1])?),
"Mul" => single(backend.mul(inputs[0], inputs[1])?),
"Div" => single(backend.div(inputs[0], inputs[1])?),
"Neg" => single(backend.neg(inputs[0])?),
"Abs" => single(backend.abs(inputs[0])?),
"Sqrt" => single(backend.sqrt(inputs[0])?),
"Pow" => single(backend.pow(inputs[0], inputs[1])?),
"Exp" => single(backend.exp(inputs[0])?),
"Log" => single(backend.log(inputs[0])?),
"MatMul" => single(backend.matmul(inputs[0], inputs[1])?),
"ReduceSum" => single(backend.reduce_sum(
inputs[0],
&attr_ints(attrs, "axes"),
attr_int(attrs, "keepdims", 1) != 0,
)?),
"ReduceMean" => single(backend.reduce_mean(
inputs[0],
&attr_ints(attrs, "axes"),
attr_int(attrs, "keepdims", 1) != 0,
)?),
"ReduceMax" => single(backend.reduce_max(
inputs[0],
&attr_ints(attrs, "axes"),
attr_int(attrs, "keepdims", 1) != 0,
)?),
"ReduceMin" => single(backend.reduce_min(
inputs[0],
&attr_ints(attrs, "axes"),
attr_int(attrs, "keepdims", 1) != 0,
)?),
"Reshape" => single(backend.reshape(inputs[0], &attr_ints(attrs, "shape"))?),
"Transpose" => single(backend.transpose(inputs[0], &attr_ints(attrs, "perm"))?),
"Concat" => single(backend.concat(inputs, attr_int(attrs, "axis", 0))?),
"Slice" => single(backend.slice(
inputs[0],
&attr_ints(attrs, "starts"),
&attr_ints(attrs, "ends"),
&attr_ints(attrs, "axes"),
&attr_ints(attrs, "steps"),
)?),
"Split" => Ok(backend.split(
inputs[0],
attr_int(attrs, "axis", 0),
&attr_ints(attrs, "split"),
)?),
"Squeeze" => single(backend.squeeze(inputs[0], &attr_ints(attrs, "axes"))?),
"Unsqueeze" => single(backend.unsqueeze(inputs[0], &attr_ints(attrs, "axes"))?),
"Identity" => single(backend.identity(inputs[0])?),
"Cast" => single(backend.cast(inputs[0], attr_int(attrs, "to", 1) as i32)?),
"Equal" => single(backend.equal(inputs[0], inputs[1])?),
"Greater" => single(backend.greater(inputs[0], inputs[1])?),
"Less" => single(backend.less(inputs[0], inputs[1])?),
"Where" => single(backend.r#where(inputs[0], inputs[1], inputs[2])?),
"Constant" => single(backend.constant(attr_tensor(attrs, "value").unwrap_or_default())?),
"Gather" => single(backend.gather(inputs[0], inputs[1], attr_int(attrs, "axis", 0))?),
other => Err(BackendWalkError::UnknownOpType(other.to_string()).into()),
}
}
pub fn int_attr(name: &str, value: i64) -> AttributeProto {
AttributeProto {
name: name.to_string(),
r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Int as i32,
i: value,
..Default::default()
}
}
pub fn ints_attr(name: &str, values: &[i64]) -> AttributeProto {
AttributeProto {
name: name.to_string(),
r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Ints as i32,
ints: values.to_vec(),
..Default::default()
}
}
pub fn tensor_attr(name: &str, tensor: TensorProto) -> AttributeProto {
AttributeProto {
name: name.to_string(),
r#type: bb_ir::proto::onnx::attribute_proto::AttributeType::Tensor as i32,
t: Some(tensor),
..Default::default()
}
}
/// Reads the `i` field of the first attribute called `name`. Returns `default` when no
/// attribute has that name. The attribute's declared type is not checked.
fn attr_int(attrs: &[AttributeProto], name: &str, default: i64) -> i64 {
    match attrs.iter().find(|candidate| candidate.name == name) {
        Some(attr) => attr.i,
        None => default,
    }
}
/// Returns a copy of the `ints` list of the first attribute called `name`. Returns an empty
/// `Vec` when no attribute has that name.
fn attr_ints(attrs: &[AttributeProto], name: &str) -> Vec<i64> {
    for attr in attrs {
        if attr.name == name {
            return attr.ints.clone();
        }
    }
    Vec::new()
}
/// Returns a clone of the tensor held by the first attribute called `name`. Returns `None`
/// when no attribute has that name or when the matching attribute carries no tensor.
fn attr_tensor(attrs: &[AttributeProto], name: &str) -> Option<TensorProto> {
    let attr = attrs.iter().find(|candidate| candidate.name == name)?;
    attr.t.clone()
}