use std::collections::{HashMap, HashSet};
use std::fmt::Write as _;
use rlx_ir::logical_kernel::{
self, KernelDispatchConfig, KernelDispatchPolicy, registered_logical_kernels,
should_lower_to_common,
};
use rlx_ir::{Graph, NodeId, OpKind};
use crate::legalize::legalize_for_backend;
use crate::rewrite::rewrite_for_backend_with_config;
/// How a single op kind is handled when a graph is prepared for a backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchPath {
    /// The backend runs the kind directly.
    Native,
    /// The kind is lowered through the common IR instead of a native kernel.
    CommonIr,
    /// The kind was present before the backend rewrite and is gone afterwards.
    Rewritten,
    /// The kind cannot be handled by the backend.
    Unsupported,
}

impl DispatchPath {
    /// Returns the short lowercase name used for this path in textual reports.
    pub fn label(self) -> &'static str {
        match self {
            DispatchPath::Native => "native",
            DispatchPath::CommonIr => "common-ir",
            DispatchPath::Rewritten => "rewritten",
            DispatchPath::Unsupported => "unsupported",
        }
    }
}
/// Per-op-kind entry of a [`KernelDispatchReport`].
#[derive(Debug, Clone)]
pub struct KindDispatchSummary {
/// The op kind this entry describes.
pub kind: OpKind,
/// Number of nodes of this kind counted in the graph as it was handed in
/// (before any backend rewrite).
pub node_count: usize,
/// The dispatch path this kind was classified into.
pub path: DispatchPath,
/// Name of the registered logical kernel for this kind, if one is registered.
pub logical_name: Option<&'static str>,
}
/// Summary of how the op kinds of one graph dispatch on one backend.
#[derive(Debug, Clone)]
pub struct KernelDispatchReport {
/// Name of the backend the report was produced for.
pub backend_name: String,
/// Dispatch policy taken from the `KernelDispatchConfig` used for the analysis.
pub policy: KernelDispatchPolicy,
/// Length of the backend's supported-op list; `0` means the list was empty.
pub supported_claim_count: usize,
/// One entry per op kind found in the input graph, sorted by the kind's `Debug` name.
pub summaries: Vec<KindDispatchSummary>,
/// Kinds reported by `logical_kernel::logical_kinds_in_graph` for this graph.
pub common_lowered_kinds: Vec<OpKind>,
/// Nodes rejected by legalization after the rewrite; empty when legalization
/// succeeded or was not run.
pub still_unsupported: Vec<(NodeId, OpKind)>,
/// `true` when `still_unsupported` is empty.
pub compile_ready: bool,
}
/// Looks up the name of the first registered logical kernel whose kind equals
/// `kind`; returns `None` when no registered kernel matches.
fn logical_name(kind: OpKind) -> Option<&'static str> {
    for entry in registered_logical_kernels().iter() {
        if entry.kind == kind {
            return Some(entry.name);
        }
    }
    None
}
/// Builds a histogram mapping each op kind present in `graph` to the number of
/// nodes of that kind.
fn count_kinds(graph: &Graph) -> HashMap<OpKind, usize> {
    graph
        .nodes()
        .iter()
        .fold(HashMap::<OpKind, usize>::new(), |mut histogram, node| {
            *histogram.entry(node.op.kind()).or_default() += 1;
            histogram
        })
}
/// Decides which [`DispatchPath`] applies to `kind`.
///
/// The checks run in priority order; the first one that matches wins:
/// 1. common-IR lowering (per `should_lower_to_common` or membership in `common_set`),
/// 2. membership in `unsupported_kinds`,
/// 3. present in `before` but absent from `after` (removed by the rewrite),
/// 4. native: `supported` is empty, lists `kind`, or `kind` is still present in `after`,
/// 5. otherwise unsupported.
fn classify_kind(
    kind: OpKind,
    supported: &[OpKind],
    config: KernelDispatchConfig,
    common_set: &HashSet<OpKind>,
    before: &HashMap<OpKind, usize>,
    after: &HashMap<OpKind, usize>,
    unsupported_kinds: &HashSet<OpKind>,
) -> DispatchPath {
    let lowered_to_common =
        should_lower_to_common(kind, supported, config) || common_set.contains(&kind);

    if lowered_to_common {
        DispatchPath::CommonIr
    } else if unsupported_kinds.contains(&kind) {
        DispatchPath::Unsupported
    } else if before.contains_key(&kind) && !after.contains_key(&kind) {
        DispatchPath::Rewritten
    } else if supported.is_empty() || supported.contains(&kind) || after.contains_key(&kind) {
        // An empty claim list is treated as "backend accepts everything".
        DispatchPath::Native
    } else {
        DispatchPath::Unsupported
    }
}
/// Produces a dispatch report for `graph` without modifying it.
///
/// No rewrite or legalization is run here. The same kind histogram is used as
/// both the "before" and "after" snapshot and the unsupported set is empty, so
/// every kind is classified as either [`DispatchPath::CommonIr`] or
/// [`DispatchPath::Native`]. The returned report always has an empty
/// `still_unsupported` list and `compile_ready == true`.
///
/// * `graph` - graph to inspect.
/// * `backend_name` - copied verbatim into the report.
/// * `supported` - op kinds the backend claims to support.
/// * `config` - dispatch configuration; its `policy` is recorded in the report.
///
/// Summaries are sorted by the `Debug` rendering of their kind.
pub fn analyze_dispatch(
    graph: &Graph,
    backend_name: &str,
    supported: &[OpKind],
    config: KernelDispatchConfig,
) -> KernelDispatchReport {
    let before = count_kinds(graph);
    let common_lowered = logical_kernel::logical_kinds_in_graph(graph, supported, config);
    let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
    // Nothing is legalized in this mode, so no kind is known to be unsupported.
    let unsupported_set = HashSet::new();

    let mut summaries: Vec<KindDispatchSummary> = before
        .iter()
        .map(|(&kind, &node_count)| KindDispatchSummary {
            kind,
            node_count,
            // `before` doubles as the "after" snapshot: the graph is not rewritten.
            path: classify_kind(
                kind,
                supported,
                config,
                &common_set,
                &before,
                &before,
                &unsupported_set,
            ),
            logical_name: logical_name(kind),
        })
        .collect();
    // The key allocates a `String`; cache it so it is built once per entry
    // instead of once per comparison. The sort stays stable.
    summaries.sort_by_cached_key(|s| format!("{:?}", s.kind));

    KernelDispatchReport {
        backend_name: backend_name.to_string(),
        policy: config.policy,
        supported_claim_count: supported.len(),
        summaries,
        common_lowered_kinds: common_lowered,
        still_unsupported: Vec::new(),
        compile_ready: true,
    }
}
/// Rewrites `graph` for a backend and reports how each original op kind was handled.
///
/// Steps:
/// 1. count op kinds and collect the common-lowered kinds on the incoming graph,
/// 2. run `rewrite_for_backend_with_config`, consuming the graph,
/// 3. count op kinds again on the rewritten graph,
/// 4. run `legalize_for_backend`; its error payload (if any) becomes
///    `still_unsupported`, otherwise that list is empty,
/// 5. classify every kind that was present *before* the rewrite.
///
/// * `graph` - graph to rewrite (taken by value).
/// * `backend_name` - copied verbatim into the report.
/// * `supported` - op kinds the backend claims to support.
/// * `config` - dispatch configuration; its `policy` is recorded in the report.
///
/// Returns the rewritten graph together with the report. `compile_ready` is
/// `true` exactly when legalization reported no unsupported nodes. Summaries
/// are sorted by the `Debug` rendering of their kind.
pub fn prepare_graph_for_backend_with_report(
    graph: Graph,
    backend_name: &str,
    supported: &[OpKind],
    config: KernelDispatchConfig,
) -> (Graph, KernelDispatchReport) {
    // Snapshot the input graph before the rewrite consumes it.
    let before = count_kinds(&graph);
    let common_lowered = logical_kernel::logical_kinds_in_graph(&graph, supported, config);
    let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();

    let rewritten = rewrite_for_backend_with_config(graph, supported, config);
    let after = count_kinds(&rewritten);

    let still_unsupported = legalize_for_backend(&rewritten, supported)
        .err()
        .unwrap_or_default();
    let unsupported_set: HashSet<OpKind> = still_unsupported.iter().map(|(_, k)| *k).collect();

    let mut summaries: Vec<KindDispatchSummary> = before
        .iter()
        .map(|(&kind, &node_count)| KindDispatchSummary {
            kind,
            node_count,
            path: classify_kind(
                kind,
                supported,
                config,
                &common_set,
                &before,
                &after,
                &unsupported_set,
            ),
            logical_name: logical_name(kind),
        })
        .collect();
    // The key allocates a `String`; cache it so it is built once per entry
    // instead of once per comparison. The sort stays stable.
    summaries.sort_by_cached_key(|s| format!("{:?}", s.kind));

    let compile_ready = still_unsupported.is_empty();
    let report = KernelDispatchReport {
        backend_name: backend_name.to_string(),
        policy: config.policy,
        supported_claim_count: supported.len(),
        summaries,
        common_lowered_kinds: common_lowered,
        still_unsupported,
        compile_ready,
    };
    (rewritten, report)
}
pub fn format_dispatch_report(report: &KernelDispatchReport) -> String {
let mut s = String::new();
let _ = writeln!(
s,
"rlx dispatch report — backend {:?}, policy {:?}, supported_ops claim={}",
report.backend_name, report.policy, report.supported_claim_count
);
if report.supported_claim_count == 0 {
let _ = writeln!(
s,
" (empty claim = legalize accepts all kinds; native/common split is advisory only)"
);
}
if !report.common_lowered_kinds.is_empty() {
let _ = writeln!(
s,
" common-ir lowering (portable, add to supported_ops for native fast path):"
);
for kind in &report.common_lowered_kinds {
let name = logical_name(*kind).unwrap_or("?");
let _ = writeln!(s, " - {kind:?} ({name})");
}
}
let mut by_path: [Vec<&KindDispatchSummary>; 4] = [vec![], vec![], vec![], vec![]];
for sum in &report.summaries {
let idx = match sum.path {
DispatchPath::Native => 0,
DispatchPath::CommonIr => 1,
DispatchPath::Rewritten => 2,
DispatchPath::Unsupported => 3,
};
by_path[idx].push(sum);
}
for (label, entries) in [
("native", &by_path[0]),
("common-ir", &by_path[1]),
("rewritten", &by_path[2]),
("unsupported", &by_path[3]),
] {
if entries.is_empty() {
continue;
}
let _ = writeln!(s, " {label}:");
for e in entries {
let extra = e
.logical_name
.map(|n| format!(" logical={n}"))
.unwrap_or_default();
let _ = writeln!(s, " - {:?} ×{} nodes{extra}", e.kind, e.node_count);
}
}
if !report.still_unsupported.is_empty() {
let _ = writeln!(
s,
" still unsupported after rewrite ({} node(s)) — compile will fail:",
report.still_unsupported.len()
);
for (id, kind) in &report.still_unsupported {
let _ = writeln!(s, " - node {id:?}: {kind:?}");
}
let _ = writeln!(
s,
" Fix: implement native thunk + add to Backend::supported_ops, or add a \
rewrite/common body in rlx-fusion."
);
} else {
let _ = writeln!(s, " compile-ready: yes");
}
s
}
/// Prints the formatted report to stderr when either the `RLX_DISPATCH_REPORT`
/// or the `RLX_VERBOSE` flag is reported as set by `rlx_ir::env::flag`;
/// otherwise does nothing.
pub fn maybe_log_dispatch_report(report: &KernelDispatchReport) {
    let logging_enabled =
        rlx_ir::env::flag("RLX_DISPATCH_REPORT") || rlx_ir::env::flag("RLX_VERBOSE");
    if !logging_enabled {
        return;
    }
    eprintln!("{}", format_dispatch_report(report));
}
// Unit tests for dispatch analysis and graph preparation reporting.
#[cfg(test)]
mod tests {
use super::*;
use rlx_ir::*;
// A logical kernel kind that the backend does not list as supported must be
// reported as common-IR lowered by the read-only analysis.
#[test]
fn common_lowered_when_not_in_supported() {
use rlx_ir::ops::splat::{GaussianSplatInputs, GaussianSplatRenderParams};
// Build a graph whose only non-input op is a gaussian splat render.
let mut g = Graph::new("splat");
let n = 2usize;
let f = DType::F32;
let positions = g.input("pos", Shape::new(&[n, 3], f));
let scales = g.input("scale", Shape::new(&[n, 3], f));
let rotations = g.input("rot", Shape::new(&[n, 4], f));
let opacities = g.input("opa", Shape::new(&[n], f));
let colors = g.input("col", Shape::new(&[n, 3], f));
let sh_coeffs = g.input("sh", Shape::new(&[n, 3], f));
let meta = g.input("meta", Shape::new(&[23], f));
let out = g.gaussian_splat_render(
GaussianSplatInputs {
positions,
scales,
rotations,
opacities,
colors,
sh_coeffs,
meta,
},
GaussianSplatRenderParams {
width: 8,
height: 8,
..Default::default()
},
);
g.set_outputs(vec![out]);
// The supported list deliberately omits `GaussianSplatRender`.
let supported = &[OpKind::Input, OpKind::Param, OpKind::MatMul];
let report = analyze_dispatch(&g, "test", supported, KernelDispatchConfig::default());
// The kind must appear in the common-lowered list...
assert!(
report
.common_lowered_kinds
.contains(&OpKind::GaussianSplatRender)
);
// ...and its summary must be classified as `CommonIr`.
assert!(
report
.summaries
.iter()
.any(|s| s.kind == OpKind::GaussianSplatRender && s.path == DispatchPath::CommonIr)
);
}
// A fused op that is not in the supported list but disappears during the
// backend rewrite must be reported as `Rewritten`, and the result must be
// compile-ready.
#[test]
fn prepare_marks_rewritten_fused_op() {
let f = DType::F32;
let mut g = Graph::new("fused");
let x = g.input("x", Shape::new(&[2, 8], f));
let w = g.param("w", Shape::new(&[8, 4], f));
let b = g.param("b", Shape::new(&[4], f));
let out = g.fused_matmul_bias_act(x, w, b, None, Shape::new(&[2, 4], f));
g.set_outputs(vec![out]);
// `FusedMatMulBiasAct` itself is not listed as supported.
let supported = &[
OpKind::Input,
OpKind::Param,
OpKind::MatMul,
OpKind::Binary,
OpKind::Expand,
OpKind::Activation,
];
let (rewritten, report) = prepare_graph_for_backend_with_report(
g,
"cpu",
supported,
KernelDispatchConfig::default(),
);
assert!(report.compile_ready);
// The fused node must no longer exist in the rewritten graph.
assert!(
!rewritten
.nodes()
.iter()
.any(|n| n.op.kind() == OpKind::FusedMatMulBiasAct)
);
// The report must still list the original fused kind, marked as rewritten.
assert!(report.summaries.iter().any(|s| {
s.kind == OpKind::FusedMatMulBiasAct && s.path == DispatchPath::Rewritten
}));
}
}