use crate::atomic::AtomicOpsetDecl;
/// Operator-set domain under which a backend must declare the tensor primitives.
pub const TENSOR_PRIMITIVES_DOMAIN: &str = "ai.onnx";

/// Version number associated with the tensor-primitives operator set.
///
/// NOTE(review): `opset_covers_primitives` never compares a backend's declared
/// version against this value — confirm whether that is intended.
pub const TENSOR_PRIMITIVES_VERSION: i64 = 1;

/// Op names a backend opset must declare to cover the tensor primitives.
///
/// The order matters: `opset_covers_primitives` reports missing ops in
/// exactly this order.
// Packed layout keeps related names on one line; skip rustfmt so it is not
// exploded into one entry per line.
#[rustfmt::skip]
pub const TENSOR_PRIMITIVES_OPS: &[&str] = &[
    "Add", "Sub", "Mul", "Div", "Neg", "Abs", "Sqrt", "Pow", "Exp", "Log",
    "MatMul",
    "ReduceSum", "ReduceMean", "ReduceMax", "ReduceMin",
    "Reshape", "Transpose", "Concat", "Slice", "Split", "Squeeze", "Unsqueeze",
    "Identity", "Cast",
    "Equal", "Greater", "Less", "Where",
    "Constant", "Gather",
];
/// Error describing the tensor primitives a backend opset failed to declare.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MissingPrimitives {
    /// Domain string of the opset that was checked.
    pub backend_domain: &'static str,
    /// Version number of the opset that was checked.
    pub backend_version: i64,
    /// Names of the primitive ops the opset does not declare (listed in
    /// `TENSOR_PRIMITIVES_OPS` order when produced by `opset_covers_primitives`).
    pub missing: Vec<&'static str>,
}

impl std::fmt::Display for MissingPrimitives {
    /// Renders as
    /// `backend opset <domain>@v<version> missing <count> primitive op(s): <op>, <op>, ...`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            backend_domain,
            backend_version,
            missing,
        } = self;
        let op_list = missing.join(", ");
        write!(
            f,
            "backend opset {}@v{} missing {} primitive op(s): {}",
            backend_domain,
            backend_version,
            missing.len(),
            op_list,
        )
    }
}

// Relies on the default `source()` (no wrapped cause); all context lives in the fields.
impl std::error::Error for MissingPrimitives {}
/// Checks whether `opset` declares every op in [`TENSOR_PRIMITIVES_OPS`].
///
/// Declared ops only count when the opset's domain is the tensor-primitives
/// domain ([`TENSOR_PRIMITIVES_DOMAIN`], `"ai.onnx"`). The empty domain `""`
/// is accepted too: the ONNX IR treats an empty operator-set domain as the
/// default `ai.onnx` set. Without this, a backend using the empty spelling
/// would be reported as missing every primitive.
///
/// `opset.version` is only copied into the error. It is not compared against
/// [`TENSOR_PRIMITIVES_VERSION`].
///
/// # Errors
///
/// Returns [`MissingPrimitives`] naming every primitive the opset does not
/// declare, in [`TENSOR_PRIMITIVES_OPS`] order. An opset in any other domain
/// has every primitive reported as missing.
pub fn opset_covers_primitives(opset: &AtomicOpsetDecl) -> Result<(), MissingPrimitives> {
    // ONNX IR: "" is an alias for the default `ai.onnx` operator-set domain.
    let in_primitives_domain =
        opset.domain == TENSOR_PRIMITIVES_DOMAIN || opset.domain.is_empty();

    // Ops declared under a foreign domain cannot satisfy the primitives.
    let declared: std::collections::HashSet<&str> = if in_primitives_domain {
        opset.ops.iter().map(|o| o.name).collect()
    } else {
        std::collections::HashSet::new()
    };

    // Iterate the canonical list (not `declared`) so the report order is stable.
    let missing: Vec<&'static str> = TENSOR_PRIMITIVES_OPS
        .iter()
        .copied()
        .filter(|name| !declared.contains(*name))
        .collect();

    if missing.is_empty() {
        Ok(())
    } else {
        Err(MissingPrimitives {
            backend_domain: opset.domain,
            backend_version: opset.version,
            missing,
        })
    }
}