use std::collections::HashMap;
use bb_ir::proto::onnx::{AttributeProto, GraphProto, StringStringEntryProto, TensorProto};
use crate::completion::{CompletionHandle, ContractResponse};
use crate::contracts::backend_default_walk;
/// Borrowed attribute and metadata slices handed to [`Backend::execute`] and
/// [`Backend::dispatch`] alongside the graph and its inputs.
///
/// Both fields are plain borrows with lifetime `'a`; the struct owns nothing.
// NOTE(review): the field names suggest these describe the node currently
// being evaluated by the caller — confirm against the call sites.
pub struct BackendAttrs<'a> {
/// ONNX attribute protos made available to the backend for this call.
pub current_node_attributes: &'a [AttributeProto],
/// String key/value entries made available to the backend for this call.
pub current_node_metadata: &'a [StringStringEntryProto],
}
/// Contract implemented by a tensor execution backend.
///
/// Every method has a default body. The per-op methods build a single-op
/// request (op name, input tensors, attribute list) and hand it to
/// `backend_default_walk::execute_single` / `execute_multi`, while
/// [`Backend::execute`] hands a whole graph to
/// `backend_default_walk::execute_graph_via_per_op`.
// NOTE(review): both directions of defaults route through
// `backend_default_walk`; presumably an implementor is expected to override
// either the per-op methods or `execute` — confirm in that module.
pub trait Backend: Send + Sync {
    /// Error type returned by every backend operation.
    ///
    /// It must be constructible from `BackendWalkError` so the default bodies
    /// can surface failures raised by the default walker.
    type Error: std::error::Error
        + std::fmt::Display
        + Send
        + Sync
        + From<crate::contracts::backend_default_walk::BackendWalkError>
        + 'static;

    /// Tensor value the backend operates on.
    type Tensor: Clone + Send + Sync + 'static + bb_ir::types::Storage;

    /// Runs a single `"Add"` op on `a` and `b` with no attributes.
    fn add(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Add", &[a, b], vec![])
    }

    /// Runs a single `"Sub"` op on `a` and `b` with no attributes.
    fn sub(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Sub", &[a, b], vec![])
    }

    /// Runs a single `"Mul"` op on `a` and `b` with no attributes.
    fn mul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Mul", &[a, b], vec![])
    }

    /// Runs a single `"Div"` op on `a` and `b` with no attributes.
    fn div(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Div", &[a, b], vec![])
    }

    /// Runs a single `"Neg"` op on `a` with no attributes.
    fn neg(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Neg", &[a], vec![])
    }

    /// Runs a single `"Abs"` op on `a` with no attributes.
    fn abs(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Abs", &[a], vec![])
    }

    /// Runs a single `"Sqrt"` op on `a` with no attributes.
    fn sqrt(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Sqrt", &[a], vec![])
    }

    /// Runs a single `"Pow"` op on `a` and `b` (in that input order) with no attributes.
    fn pow(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Pow", &[a, b], vec![])
    }

    /// Runs a single `"Exp"` op on `a` with no attributes.
    fn exp(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Exp", &[a], vec![])
    }

    /// Runs a single `"Log"` op on `a` with no attributes.
    fn log(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Log", &[a], vec![])
    }

    /// Runs a single `"MatMul"` op on `a` and `b` (in that input order) with no attributes.
    fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "MatMul", &[a, b], vec![])
    }

    /// Runs a single `"ReduceSum"` op on `a`.
    ///
    /// `axes` is forwarded as the `axes` ints attribute and `keepdims` as the
    /// `keepdims` int attribute (`true` -> 1, `false` -> 0).
    fn reduce_sum(
        &self,
        a: &Self::Tensor,
        axes: &[i64],
        keepdims: bool,
    ) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![
            backend_default_walk::ints_attr("axes", axes),
            backend_default_walk::int_attr("keepdims", i64::from(keepdims)),
        ];
        backend_default_walk::execute_single(self, "ReduceSum", &[a], attributes)
    }

    /// Runs a single `"ReduceMean"` op on `a`.
    ///
    /// `axes` is forwarded as the `axes` ints attribute and `keepdims` as the
    /// `keepdims` int attribute (`true` -> 1, `false` -> 0).
    fn reduce_mean(
        &self,
        a: &Self::Tensor,
        axes: &[i64],
        keepdims: bool,
    ) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![
            backend_default_walk::ints_attr("axes", axes),
            backend_default_walk::int_attr("keepdims", i64::from(keepdims)),
        ];
        backend_default_walk::execute_single(self, "ReduceMean", &[a], attributes)
    }

    /// Runs a single `"ReduceMax"` op on `a`.
    ///
    /// `axes` is forwarded as the `axes` ints attribute and `keepdims` as the
    /// `keepdims` int attribute (`true` -> 1, `false` -> 0).
    fn reduce_max(
        &self,
        a: &Self::Tensor,
        axes: &[i64],
        keepdims: bool,
    ) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![
            backend_default_walk::ints_attr("axes", axes),
            backend_default_walk::int_attr("keepdims", i64::from(keepdims)),
        ];
        backend_default_walk::execute_single(self, "ReduceMax", &[a], attributes)
    }

    /// Runs a single `"ReduceMin"` op on `a`.
    ///
    /// `axes` is forwarded as the `axes` ints attribute and `keepdims` as the
    /// `keepdims` int attribute (`true` -> 1, `false` -> 0).
    fn reduce_min(
        &self,
        a: &Self::Tensor,
        axes: &[i64],
        keepdims: bool,
    ) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![
            backend_default_walk::ints_attr("axes", axes),
            backend_default_walk::int_attr("keepdims", i64::from(keepdims)),
        ];
        backend_default_walk::execute_single(self, "ReduceMin", &[a], attributes)
    }

    /// Runs a single `"Reshape"` op on `a`, forwarding `shape` as the `shape`
    /// ints attribute.
    fn reshape(&self, a: &Self::Tensor, shape: &[i64]) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::ints_attr("shape", shape)];
        backend_default_walk::execute_single(self, "Reshape", &[a], attributes)
    }

    /// Runs a single `"Transpose"` op on `a`, forwarding `perm` as the `perm`
    /// ints attribute.
    fn transpose(&self, a: &Self::Tensor, perm: &[i64]) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::ints_attr("perm", perm)];
        backend_default_walk::execute_single(self, "Transpose", &[a], attributes)
    }

    /// Runs a single `"Concat"` op over `inputs` (order preserved), forwarding
    /// `axis` as the `axis` int attribute.
    fn concat(&self, inputs: &[&Self::Tensor], axis: i64) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::int_attr("axis", axis)];
        backend_default_walk::execute_single(self, "Concat", inputs, attributes)
    }

    /// Runs a single `"Slice"` op on `a`.
    ///
    /// `starts`, `ends`, `axes` and `steps` are forwarded, in that order, as
    /// ints attributes of the same names.
    fn slice(
        &self,
        a: &Self::Tensor,
        starts: &[i64],
        ends: &[i64],
        axes: &[i64],
        steps: &[i64],
    ) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![
            backend_default_walk::ints_attr("starts", starts),
            backend_default_walk::ints_attr("ends", ends),
            backend_default_walk::ints_attr("axes", axes),
            backend_default_walk::ints_attr("steps", steps),
        ];
        backend_default_walk::execute_single(self, "Slice", &[a], attributes)
    }

    /// Runs a single multi-output `"Split"` op on `a`.
    ///
    /// `axis` is forwarded as the `axis` int attribute and `sizes` as the
    /// `split` ints attribute. `sizes.len()` is passed to `execute_multi` as
    /// its final argument.
    // NOTE(review): that final argument is presumably the expected number of
    // outputs — confirm against `execute_multi`.
    fn split(
        &self,
        a: &Self::Tensor,
        axis: i64,
        sizes: &[i64],
    ) -> Result<Vec<Self::Tensor>, Self::Error> {
        let output_count = sizes.len();
        let attributes = vec![
            backend_default_walk::int_attr("axis", axis),
            backend_default_walk::ints_attr("split", sizes),
        ];
        backend_default_walk::execute_multi(self, "Split", &[a], attributes, output_count)
    }

    /// Runs a single `"Squeeze"` op on `a`, forwarding `axes` as the `axes`
    /// ints attribute.
    fn squeeze(&self, a: &Self::Tensor, axes: &[i64]) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::ints_attr("axes", axes)];
        backend_default_walk::execute_single(self, "Squeeze", &[a], attributes)
    }

    /// Runs a single `"Unsqueeze"` op on `a`, forwarding `axes` as the `axes`
    /// ints attribute.
    fn unsqueeze(&self, a: &Self::Tensor, axes: &[i64]) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::ints_attr("axes", axes)];
        backend_default_walk::execute_single(self, "Unsqueeze", &[a], attributes)
    }

    /// Runs a single `"Identity"` op on `a` with no attributes.
    fn identity(&self, a: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Identity", &[a], vec![])
    }

    /// Runs a single `"Cast"` op on `a`, forwarding `dtype` (widened losslessly
    /// to `i64`) as the `to` int attribute.
    fn cast(&self, a: &Self::Tensor, dtype: i32) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::int_attr("to", i64::from(dtype))];
        backend_default_walk::execute_single(self, "Cast", &[a], attributes)
    }

    /// Runs a single `"Equal"` op on `a` and `b` with no attributes.
    fn equal(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Equal", &[a, b], vec![])
    }

    /// Runs a single `"Greater"` op on `a` and `b` (in that input order) with no attributes.
    fn greater(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Greater", &[a, b], vec![])
    }

    /// Runs a single `"Less"` op on `a` and `b` (in that input order) with no attributes.
    fn less(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Less", &[a, b], vec![])
    }

    /// Runs a single `"Where"` op with inputs `cond`, `t`, `f` (in that order)
    /// and no attributes.
    fn r#where(
        &self,
        cond: &Self::Tensor,
        t: &Self::Tensor,
        f: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        backend_default_walk::execute_single(self, "Where", &[cond, t, f], vec![])
    }

    /// Runs a single input-less `"Constant"` op whose `value` tensor attribute
    /// carries the given proto.
    fn constant(&self, value: TensorProto) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::tensor_attr("value", value)];
        backend_default_walk::execute_single(self, "Constant", &[], attributes)
    }

    /// Runs a single `"Gather"` op with inputs `data` then `indices`,
    /// forwarding `axis` as the `axis` int attribute.
    fn gather(
        &self,
        data: &Self::Tensor,
        indices: &Self::Tensor,
        axis: i64,
    ) -> Result<Self::Tensor, Self::Error> {
        let attributes = vec![backend_default_walk::int_attr("axis", axis)];
        backend_default_walk::execute_single(self, "Gather", &[data, indices], attributes)
    }

    /// Executes `graph` with the named `inputs` and returns a map of named
    /// result tensors.
    ///
    /// The default body forwards to
    /// `backend_default_walk::execute_graph_via_per_op` and does not read
    /// `_attrs`.
    fn execute(
        &self,
        graph: &GraphProto,
        inputs: HashMap<String, Self::Tensor>,
        _attrs: BackendAttrs<'_>,
    ) -> Result<HashMap<String, Self::Tensor>, Self::Error> {
        backend_default_walk::execute_graph_via_per_op(self, graph, inputs)
    }

    /// Dispatches `graph` for execution and reports the outcome as a
    /// [`ContractResponse`].
    ///
    /// The default body is synchronous: it calls [`Backend::execute`] and wraps
    /// the result in `ContractResponse::Now`. The `completion` handle is not
    /// used by this default.
    fn dispatch(
        &self,
        graph: &GraphProto,
        inputs: HashMap<String, Self::Tensor>,
        attrs: BackendAttrs<'_>,
        completion: CompletionHandle<HashMap<String, Self::Tensor>, Self::Error>,
    ) -> ContractResponse<HashMap<String, Self::Tensor>, Self::Error> {
        // `let _ =` marks the handle as intentionally unused without moving
        // it, so it is still dropped when the function returns, after
        // `execute` has run.
        let _ = completion;
        let outcome = self.execute(graph, inputs, attrs);
        ContractResponse::Now(outcome)
    }

    /// Rebuilds a tensor from wire `bytes` tagged with `type_hash`.
    ///
    /// Looks up a decoder for `type_hash` in
    /// `bb_ir::slot_value::wire_decoder_registry()`, decodes `bytes`, and
    /// downcasts the decoded value to `Self::Tensor`.
    ///
    /// # Errors
    ///
    /// Returns `BackendWalkError::WireMaterializeFailed` (converted into
    /// `Self::Error`) when no decoder is registered for `type_hash`, when the
    /// decoder fails, or when the decoded value is not a `Self::Tensor`.
    fn materialize_from_wire(
        &self,
        type_hash: u64,
        bytes: Vec<u8>,
    ) -> Result<Self::Tensor, Self::Error> {
        // All three failure paths share the same variant and `type_hash`;
        // only the reason text differs.
        let fail = |reason: String| {
            backend_default_walk::BackendWalkError::WireMaterializeFailed { type_hash, reason }
        };

        // Kept as one expression so any temporary returned by the registry
        // accessor lives exactly as long as it did before.
        let decoder = bb_ir::slot_value::wire_decoder_registry()
            .get(&type_hash)
            .copied()
            .ok_or_else(|| fail("no decoder registered for type_hash".to_owned()))?;

        let carrier = decoder(&bytes).map_err(|e| fail(e.to_string()))?;

        match carrier.into_any_boxed().downcast::<Self::Tensor>() {
            Ok(tensor) => Ok(*tensor),
            Err(_) => Err(fail("decoded carrier is not Self::Tensor".to_owned()).into()),
        }
    }
}