use std::collections::HashMap;
use bb_ir::proto::onnx::GraphProto;
use bb_runtime::atomic::DispatchResult;
use bb_runtime::bus::OpError;
use bb_runtime::slot_value::SlotValue;
use crate::backends::cpu::{ops, CpuBackend, CpuTensor};
/// Errors produced by [`execute_graph`] while interpreting a graph on the CPU backend.
#[derive(Debug)]
pub enum BackendError {
/// A node referenced an input tensor name that was neither supplied by the
/// caller nor produced by an earlier node.
MissingInput {
/// The unresolved tensor name.
name: String,
/// Operator type of the node that requested the input.
op_type: String,
},
/// A kernel returned a value that is not a `CpuTensor` for one of the node's
/// (non-empty) named outputs.
OutputNotTensor {
/// Operator type of the offending node.
op_type: String,
},
/// `ops::dispatch` returned an error, or returned an async result, which
/// `execute_graph` does not support.
KernelFailed {
/// Operator type of the node whose kernel failed.
op_type: String,
/// The kernel-level error (or a synthesized one for the async case).
source: OpError,
},
/// Wraps a `BackendWalkError`; constructed through the `From` conversion
/// below so that `?` can lift such errors.
DefaultWalker(bb_runtime::contracts::backend_default_walk::BackendWalkError),
}
impl From<bb_runtime::contracts::backend_default_walk::BackendWalkError> for BackendError {
fn from(value: bb_runtime::contracts::backend_default_walk::BackendWalkError) -> Self {
Self::DefaultWalker(value)
}
}
impl std::fmt::Display for BackendError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::MissingInput { name, op_type } => {
write!(f, "execute_graph: input `{name}` missing for `{op_type}`",)
}
Self::OutputNotTensor { op_type } => write!(
f,
"execute_graph: `{op_type}` produced a non-CpuTensor output",
),
Self::KernelFailed { op_type, source } => {
write!(f, "execute_graph: `{op_type}` kernel failed: {source}",)
}
Self::DefaultWalker(e) => write!(f, "{e}"),
}
}
}
/// Marks `BackendError` as a standard error type. The impl is empty, so the
/// default `source()` is used and returns `None` even for `KernelFailed` and
/// `DefaultWalker`; the wrapped cause is only visible through `Display`.
/// NOTE(review): if `OpError` and `BackendWalkError` implement
/// `std::error::Error`, overriding `source()` would expose the cause chain.
impl std::error::Error for BackendError {}
/// Interprets `graph` on the CPU backend, one node at a time in declaration order.
///
/// `inputs` seeds a value table keyed by tensor name. For each node:
/// - every non-empty input name is looked up in that table;
/// - the node is run through `ops::dispatch`;
/// - each returned output is stored under the node's output name at the same
///   position.
///
/// Outputs whose position has no declared name, or whose name is empty, are
/// discarded. After the last node, the tensors named in `graph.output` are moved
/// out of the table and returned. A graph output that was never produced is
/// simply absent from the returned map rather than reported as an error.
///
/// # Errors
///
/// - [`BackendError::MissingInput`] if a node reads a name not present in the table.
/// - [`BackendError::KernelFailed`] if `ops::dispatch` fails, or if it returns
///   `DispatchResult::Async`, which this synchronous walker rejects.
/// - [`BackendError::OutputNotTensor`] if a kept output is not a `CpuTensor`.
pub fn execute_graph(
    backend: &CpuBackend,
    graph: &GraphProto,
    inputs: HashMap<String, CpuTensor>,
) -> Result<HashMap<String, CpuTensor>, BackendError> {
    // Every tensor available so far, keyed by its graph-level name.
    let mut values = inputs;

    for node in &graph.node {
        let op_type = &node.op_type;

        // Empty input names mark omitted optional inputs and are skipped.
        let operands = node
            .input
            .iter()
            .filter(|name| !name.is_empty())
            .map(|name| match values.get(name) {
                Some(tensor) => Ok((name.as_str(), tensor as &dyn SlotValue)),
                None => Err(BackendError::MissingInput {
                    name: name.clone(),
                    op_type: op_type.clone(),
                }),
            })
            .collect::<Result<Vec<_>, BackendError>>()?;

        let dispatched = ops::dispatch(backend, op_type, &operands, &node.attribute)
            .map_err(|source| BackendError::KernelFailed {
                op_type: op_type.clone(),
                source,
            })?;

        // This walker is synchronous: deferred (async) kernel results are rejected.
        let DispatchResult::Immediate(produced) = dispatched else {
            return Err(BackendError::KernelFailed {
                op_type: op_type.clone(),
                source: OpError {
                    detail: format!(
                        "{}: async dispatch unsupported inside execute_graph",
                        op_type
                    ),
                    ..Default::default()
                },
            });
        };

        // Pair kernel outputs with the node's declared output names by position;
        // kernel outputs beyond the declared names are dropped unchecked.
        for (graph_name, (_, slot)) in node.output.iter().zip(produced) {
            if graph_name.is_empty() {
                continue;
            }
            let tensor = slot
                .into_any_boxed()
                .downcast::<CpuTensor>()
                .map_err(|_| BackendError::OutputNotTensor {
                    op_type: op_type.clone(),
                })?;
            values.insert(graph_name.clone(), *tensor);
        }
    }

    // Move (not clone) the requested graph outputs out of the value table.
    Ok(graph
        .output
        .iter()
        .filter_map(|vi| values.remove(&vi.name).map(|t| (vi.name.clone(), t)))
        .collect())
}