use std::collections::BTreeMap;
/// Coverage status of a single ONNX operator.
///
/// One value is produced per operator listed by `generate_coverage_report`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpCoverage {
    /// ONNX operator name, e.g. `"Add"`.
    pub op_type: String,
    /// Whether the operator registry reports this operator as available.
    pub supported: bool,
    /// Opset version recorded for this operator in the coverage table.
    pub since_opset: u32,
    /// Grouping label used for per-category summaries, e.g. `"math"` or `"nn"`.
    pub category: String,
}
/// Builds a coverage report for a fixed table of ONNX operators.
///
/// Each table entry is `(op_type, since_opset, category)`. The `supported`
/// flag of every returned [`OpCoverage`] comes from
/// `oxionnx_ops::default_registry().contains(op_type)`. Entries are returned
/// in table order.
pub fn generate_coverage_report() -> Vec<OpCoverage> {
    // The table is static data that is only iterated, so a `const` slice
    // avoids allocating a `Vec` on every call.
    const OPS: &[(&str, u32, &str)] = &[
        ("Add", 1, "math"),
        ("Sub", 1, "math"),
        ("Mul", 1, "math"),
        ("Div", 1, "math"),
        ("MatMul", 1, "math"),
        ("Gemm", 1, "math"),
        ("Pow", 1, "math"),
        ("Sqrt", 1, "math"),
        ("Exp", 1, "math"),
        ("Log", 1, "math"),
        ("Abs", 1, "math"),
        ("Neg", 1, "math"),
        ("Reciprocal", 1, "math"),
        ("Ceil", 1, "math"),
        ("Floor", 1, "math"),
        ("Round", 11, "math"),
        ("Sign", 9, "math"),
        ("Mod", 10, "math"),
        ("Sin", 7, "math"),
        ("Cos", 7, "math"),
        ("Tan", 7, "math"),
        ("Asin", 7, "math"),
        ("Acos", 7, "math"),
        ("Atan", 7, "math"),
        ("Sinh", 9, "math"),
        ("Cosh", 9, "math"),
        ("Asinh", 9, "math"),
        ("Acosh", 9, "math"),
        ("Atanh", 9, "math"),
        ("ReduceMean", 1, "math"),
        ("ReduceSum", 1, "math"),
        ("ReduceMax", 1, "math"),
        ("ReduceMin", 1, "math"),
        ("ReduceProd", 1, "math"),
        ("ArgMax", 1, "math"),
        ("ArgMin", 1, "math"),
        ("CumSum", 11, "math"),
        ("Range", 11, "math"),
        ("TopK", 1, "math"),
        ("BitShift", 11, "math"),
        ("ReduceL1", 1, "math"),
        ("ReduceL2", 1, "math"),
        ("ReduceLogSum", 1, "math"),
        ("ReduceLogSumExp", 1, "math"),
        ("ReduceSumSquare", 1, "math"),
        ("Min", 1, "math"),
        ("Max", 1, "math"),
        ("Mean", 1, "math"),
        ("Sum", 1, "math"),
        ("Relu", 1, "nn"),
        ("Sigmoid", 1, "nn"),
        ("Tanh", 1, "nn"),
        ("Softmax", 1, "nn"),
        ("LogSoftmax", 1, "nn"),
        ("Gelu", 20, "nn"),
        ("Erf", 9, "nn"),
        ("LeakyRelu", 1, "nn"),
        ("PRelu", 1, "nn"),
        ("Elu", 1, "nn"),
        ("Selu", 1, "nn"),
        ("Celu", 12, "nn"),
        ("HardSigmoid", 1, "nn"),
        ("HardSwish", 14, "nn"),
        ("Softplus", 1, "nn"),
        ("Softsign", 1, "nn"),
        ("Mish", 18, "nn"),
        ("ThresholdedRelu", 10, "nn"),
        ("Hardmax", 1, "nn"),
        ("Shrink", 9, "nn"),
        ("BatchNormalization", 1, "nn"),
        ("LayerNormalization", 17, "nn"),
        ("GroupNormalization", 18, "nn"),
        ("InstanceNormalization", 1, "nn"),
        ("LpNormalization", 1, "nn"),
        ("MeanVarianceNormalization", 9, "nn"),
        ("Dropout", 1, "nn"),
        ("Conv", 1, "conv"),
        ("ConvTranspose", 1, "conv"),
        ("MaxPool", 1, "conv"),
        ("AveragePool", 1, "conv"),
        ("GlobalAveragePool", 1, "conv"),
        ("GlobalMaxPool", 1, "conv"),
        ("Pad", 1, "conv"),
        ("Resize", 10, "conv"),
        ("Reshape", 1, "shape"),
        ("Transpose", 1, "shape"),
        ("Squeeze", 1, "shape"),
        ("Unsqueeze", 1, "shape"),
        ("Flatten", 1, "shape"),
        ("Concat", 1, "shape"),
        ("Slice", 1, "shape"),
        ("Expand", 8, "shape"),
        ("Split", 1, "shape"),
        ("Tile", 1, "shape"),
        ("DepthToSpace", 1, "shape"),
        ("SpaceToDepth", 1, "shape"),
        ("ReverseSequence", 10, "shape"),
        ("Gather", 1, "indexing"),
        ("GatherElements", 11, "indexing"),
        ("GatherND", 11, "indexing"),
        ("Where", 9, "indexing"),
        ("ScatterElements", 11, "indexing"),
        ("ScatterND", 11, "indexing"),
        ("OneHot", 9, "indexing"),
        ("Compress", 9, "indexing"),
        ("Unique", 11, "indexing"),
        ("NonZero", 9, "indexing"),
        ("Equal", 1, "comparison"),
        ("Greater", 1, "comparison"),
        ("GreaterOrEqual", 12, "comparison"),
        ("Less", 1, "comparison"),
        ("LessOrEqual", 12, "comparison"),
        ("And", 1, "logic"),
        ("Or", 1, "logic"),
        ("Xor", 1, "logic"),
        ("Not", 1, "logic"),
        ("BitwiseAnd", 18, "bitwise"),
        ("BitwiseOr", 18, "bitwise"),
        ("BitwiseXor", 18, "bitwise"),
        ("BitwiseNot", 18, "bitwise"),
        ("IsInf", 10, "logic"),
        ("IsNaN", 9, "logic"),
        ("Identity", 1, "misc"),
        ("Cast", 1, "misc"),
        ("Shape", 1, "misc"),
        ("Size", 1, "misc"),
        ("Constant", 1, "misc"),
        ("ConstantOfShape", 9, "misc"),
        ("Clip", 1, "misc"),
        ("EyeLike", 9, "misc"),
        ("Trilu", 14, "misc"),
        ("QuantizeLinear", 10, "quantization"),
        ("DequantizeLinear", 10, "quantization"),
        ("LSTM", 1, "rnn"),
        ("GRU", 1, "rnn"),
        ("Einsum", 12, "math"),
        ("NonMaxSuppression", 10, "detection"),
        ("GridSample", 16, "spatial"),
        ("RoiAlign", 10, "spatial"),
    ];

    let registry = oxionnx_ops::default_registry();
    OPS.iter()
        .map(|&(name, since, cat)| OpCoverage {
            op_type: name.to_string(),
            supported: registry.contains(name),
            since_opset: since,
            category: cat.to_string(),
        })
        .collect()
}
/// Renders a coverage report as a Markdown document.
///
/// The output has three parts:
/// - a `# ONNX Operator Coverage Report` heading;
/// - a bold `supported/total` line with the percentage rounded to an integer;
/// - a table with one row per entry of `report`, in the given order, showing
///   operator, category, since-opset and a `yes`/`no` support status.
///
/// An empty `report` renders as `0/0 operators supported (0%)` instead of `NaN%`.
pub fn format_coverage_markdown(report: &[OpCoverage]) -> String {
    use std::fmt::Write;

    let total = report.len();
    let supported = report.iter().filter(|o| o.supported).count();
    // Guard the division: 0.0 / 0.0 would otherwise be formatted as "NaN".
    let percent = if total == 0 {
        0.0
    } else {
        100.0 * supported as f64 / total as f64
    };

    let mut out = String::from("# ONNX Operator Coverage Report\n\n");
    // Writing into a `String` cannot fail, so the `fmt::Result`s are discarded.
    let _ = writeln!(
        out,
        "**{}/{} operators supported ({:.0}%)**\n",
        supported, total, percent
    );
    out.push_str("| Operator | Category | Since Opset | Supported |\n");
    out.push_str("|----------|----------|-------------|:---------:|\n");
    for op in report {
        let status = if op.supported { "yes" } else { "no" };
        let _ = writeln!(
            out,
            "| {} | {} | {} | {} |",
            op.op_type, op.category, op.since_opset, status
        );
    }
    out
}
/// Renders a per-category summary of a coverage report.
///
/// Returns a `Category Coverage:` header line followed by one line per
/// category, in ascending category order. Each line contains the category
/// left-aligned in a 15-character column, then `supported/total`, then the
/// rounded percentage in parentheses. A category only appears if it has at
/// least one entry, so the percentage never divides by zero.
pub fn format_coverage_summary(report: &[OpCoverage]) -> String {
    // category -> (entries seen, entries supported); BTreeMap iterates sorted by key.
    let tallies = report
        .iter()
        .fold(BTreeMap::<String, (usize, usize)>::new(), |mut acc, op| {
            let slot = acc.entry(op.category.clone()).or_insert((0, 0));
            slot.0 += 1;
            slot.1 += usize::from(op.supported);
            acc
        });

    let body: String = tallies
        .iter()
        .map(|(category, &(seen, ok))| {
            let pct = 100.0 * ok as f64 / seen as f64;
            format!(" {:<15} {}/{} ({:.0}%)\n", category, ok, seen, pct)
        })
        .collect();

    format!("Category Coverage:\n{}", body)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// The report is non-empty and at least one operator is found in the registry.
    #[test]
    fn test_coverage_report() {
        let report = generate_coverage_report();
        assert!(!report.is_empty(), "report should have entries");
        let supported_count = report.iter().filter(|o| o.supported).count();
        assert!(
            supported_count > 0,
            "at least some operators should be supported"
        );
    }

    /// Each operator appears only once in the table, so totals are not inflated.
    #[test]
    fn test_coverage_report_unique_ops() {
        let report = generate_coverage_report();
        let mut seen = HashSet::new();
        for op in &report {
            assert!(
                seen.insert(op.op_type.as_str()),
                "duplicate operator {}",
                op.op_type
            );
        }
    }

    /// The markdown contains the heading, the table header, the exact `Add` row,
    /// and exactly one table row per operator.
    #[test]
    fn test_coverage_markdown() {
        let report = generate_coverage_report();
        let md = format_coverage_markdown(&report);
        assert!(md.contains("# ONNX Operator Coverage Report"));
        assert!(md.contains("| Operator |"));
        // Pin the exact `Add` row, including the status the registry reported.
        // The previous `contains("yes") || contains("no")` check was near-tautological.
        let add = report
            .iter()
            .find(|o| o.op_type == "Add")
            .expect("Add is listed in the coverage table");
        let status = if add.supported { "yes" } else { "no" };
        let row = format!("| Add | math | 1 | {} |\n", status);
        assert!(md.contains(&row), "missing row {:?}", row);
        // Header row + separator row + one row per operator.
        let rows = md.lines().filter(|l| l.starts_with('|')).count();
        assert_eq!(rows, report.len() + 2);
    }

    /// The summary lists the main categories of the built-in table.
    #[test]
    fn test_coverage_summary() {
        let report = generate_coverage_report();
        let summary = format_coverage_summary(&report);
        assert!(summary.contains("Category Coverage:"));
        assert!(summary.contains("math"));
        assert!(summary.contains("nn"));
        assert!(summary.contains("conv"));
        assert!(summary.contains("shape"));
    }

    /// Counts and ordering of the summary for a small handcrafted report.
    #[test]
    fn test_coverage_summary_counts() {
        let mk = |name: &str, cat: &str, supported: bool| OpCoverage {
            op_type: name.to_string(),
            supported,
            since_opset: 1,
            category: cat.to_string(),
        };
        let report = [
            mk("Relu", "nn", true),
            mk("Add", "math", true),
            mk("Sub", "math", false),
        ];
        let summary = format_coverage_summary(&report);
        assert!(summary.starts_with("Category Coverage:\n"));
        assert!(summary.contains("1/2 (50%)"));
        assert!(summary.contains("1/1 (100%)"));
        // Categories are listed in sorted order.
        assert!(summary.find("math").unwrap() < summary.find("nn").unwrap());
    }
}