1use std::collections::{HashMap, HashSet};
9use std::fmt::Write as _;
10
11use rlx_ir::logical_kernel::{
12 self, KernelDispatchConfig, KernelDispatchPolicy, registered_logical_kernels,
13 should_lower_to_common,
14};
15use rlx_ir::{Graph, NodeId, OpKind};
16
17use crate::legalize::legalize_for_backend;
18use crate::rewrite::rewrite_for_backend_with_config;
19
/// Execution route that dispatch analysis assigns to one op kind of a graph.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DispatchPath {
    /// Executed by the backend itself (claimed in `supported`, empty claim, or
    /// still present after rewriting without a legalization complaint).
    Native,
    /// Lowered to the portable common-IR body of its logical kernel.
    CommonIr,
    /// Present in the input graph but removed by backend rewriting.
    Rewritten,
    /// Rejected by legalization (or otherwise not executable) on this backend.
    Unsupported,
}

impl DispatchPath {
    /// Short lowercase tag for this path, as printed in dispatch reports.
    pub fn label(self) -> &'static str {
        match self {
            DispatchPath::Native => "native",
            DispatchPath::CommonIr => "common-ir",
            DispatchPath::Rewritten => "rewritten",
            DispatchPath::Unsupported => "unsupported",
        }
    }
}
43
44#[derive(Debug, Clone)]
46pub struct KindDispatchSummary {
47 pub kind: OpKind,
48 pub node_count: usize,
49 pub path: DispatchPath,
50 pub logical_name: Option<&'static str>,
52}
53
54#[derive(Debug, Clone)]
56pub struct KernelDispatchReport {
57 pub backend_name: String,
58 pub policy: KernelDispatchPolicy,
59 pub supported_claim_count: usize,
61 pub summaries: Vec<KindDispatchSummary>,
62 pub common_lowered_kinds: Vec<OpKind>,
64 pub still_unsupported: Vec<(NodeId, OpKind)>,
66 pub compile_ready: bool,
67}
68
69fn logical_name(kind: OpKind) -> Option<&'static str> {
70 registered_logical_kernels()
71 .iter()
72 .find(|e| e.kind == kind)
73 .map(|e| e.name)
74}
75
76fn count_kinds(graph: &Graph) -> HashMap<OpKind, usize> {
77 let mut m = HashMap::new();
78 for node in graph.nodes() {
79 *m.entry(node.op.kind()).or_default() += 1;
80 }
81 m
82}
83
84fn classify_kind(
85 kind: OpKind,
86 supported: &[OpKind],
87 config: KernelDispatchConfig,
88 common_set: &HashSet<OpKind>,
89 before: &HashMap<OpKind, usize>,
90 after: &HashMap<OpKind, usize>,
91 unsupported_kinds: &HashSet<OpKind>,
92) -> DispatchPath {
93 if should_lower_to_common(kind, supported, config) || common_set.contains(&kind) {
94 return DispatchPath::CommonIr;
95 }
96 if unsupported_kinds.contains(&kind) {
97 return DispatchPath::Unsupported;
98 }
99 if before.contains_key(&kind) && !after.contains_key(&kind) {
100 return DispatchPath::Rewritten;
101 }
102 if supported.is_empty() || supported.contains(&kind) {
103 return DispatchPath::Native;
104 }
105 if after.contains_key(&kind) {
106 return DispatchPath::Native;
107 }
108 DispatchPath::Unsupported
109}
110
111pub fn analyze_dispatch(
113 graph: &Graph,
114 backend_name: &str,
115 supported: &[OpKind],
116 config: KernelDispatchConfig,
117) -> KernelDispatchReport {
118 let before = count_kinds(graph);
119 let common_lowered = logical_kernel::logical_kinds_in_graph(graph, supported, config);
120 let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
121 let unsupported_set = HashSet::new();
122
123 let mut summaries: Vec<KindDispatchSummary> = before
124 .iter()
125 .map(|(&kind, &node_count)| {
126 let path = classify_kind(
127 kind,
128 supported,
129 config,
130 &common_set,
131 &before,
132 &before,
133 &unsupported_set,
134 );
135 KindDispatchSummary {
136 kind,
137 node_count,
138 path,
139 logical_name: logical_name(kind),
140 }
141 })
142 .collect();
143 summaries.sort_by_key(|s| format!("{:?}", s.kind));
144
145 KernelDispatchReport {
146 backend_name: backend_name.to_string(),
147 policy: config.policy,
148 supported_claim_count: supported.len(),
149 summaries,
150 common_lowered_kinds: common_lowered,
151 still_unsupported: Vec::new(),
152 compile_ready: true,
154 }
155}
156
157pub fn prepare_graph_for_backend_with_report(
159 graph: Graph,
160 backend_name: &str,
161 supported: &[OpKind],
162 config: KernelDispatchConfig,
163) -> (Graph, KernelDispatchReport) {
164 let before = count_kinds(&graph);
165 let common_lowered = logical_kernel::logical_kinds_in_graph(&graph, supported, config);
166 let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
167
168 let rewritten = rewrite_for_backend_with_config(graph, supported, config);
169 let after = count_kinds(&rewritten);
170 let still_unsupported = legalize_for_backend(&rewritten, supported)
171 .err()
172 .unwrap_or_default();
173 let unsupported_set: HashSet<OpKind> = still_unsupported.iter().map(|(_, k)| *k).collect();
174
175 let mut summaries: Vec<KindDispatchSummary> = before
176 .iter()
177 .map(|(&kind, &node_count)| {
178 let path = classify_kind(
179 kind,
180 supported,
181 config,
182 &common_set,
183 &before,
184 &after,
185 &unsupported_set,
186 );
187 KindDispatchSummary {
188 kind,
189 node_count,
190 path,
191 logical_name: logical_name(kind),
192 }
193 })
194 .collect();
195 summaries.sort_by_key(|s| format!("{:?}", s.kind));
196
197 let compile_ready = still_unsupported.is_empty();
198 let report = KernelDispatchReport {
199 backend_name: backend_name.to_string(),
200 policy: config.policy,
201 supported_claim_count: supported.len(),
202 summaries,
203 common_lowered_kinds: common_lowered,
204 still_unsupported,
205 compile_ready,
206 };
207 (rewritten, report)
208}
209
210pub fn format_dispatch_report(report: &KernelDispatchReport) -> String {
212 let mut s = String::new();
213 let _ = writeln!(
214 s,
215 "rlx dispatch report — backend {:?}, policy {:?}, supported_ops claim={}",
216 report.backend_name, report.policy, report.supported_claim_count
217 );
218 if report.supported_claim_count == 0 {
219 let _ = writeln!(
220 s,
221 " (empty claim = legalize accepts all kinds; native/common split is advisory only)"
222 );
223 }
224
225 if !report.common_lowered_kinds.is_empty() {
226 let _ = writeln!(
227 s,
228 " common-ir lowering (portable, add to supported_ops for native fast path):"
229 );
230 for kind in &report.common_lowered_kinds {
231 let name = logical_name(*kind).unwrap_or("?");
232 let _ = writeln!(s, " - {kind:?} ({name})");
233 }
234 }
235
236 let mut by_path: [Vec<&KindDispatchSummary>; 4] = [vec![], vec![], vec![], vec![]];
237 for sum in &report.summaries {
238 let idx = match sum.path {
239 DispatchPath::Native => 0,
240 DispatchPath::CommonIr => 1,
241 DispatchPath::Rewritten => 2,
242 DispatchPath::Unsupported => 3,
243 };
244 by_path[idx].push(sum);
245 }
246
247 for (label, entries) in [
248 ("native", &by_path[0]),
249 ("common-ir", &by_path[1]),
250 ("rewritten", &by_path[2]),
251 ("unsupported", &by_path[3]),
252 ] {
253 if entries.is_empty() {
254 continue;
255 }
256 let _ = writeln!(s, " {label}:");
257 for e in entries {
258 let extra = e
259 .logical_name
260 .map(|n| format!(" logical={n}"))
261 .unwrap_or_default();
262 let _ = writeln!(s, " - {:?} ×{} nodes{extra}", e.kind, e.node_count);
263 }
264 }
265
266 if !report.still_unsupported.is_empty() {
267 let _ = writeln!(
268 s,
269 " still unsupported after rewrite ({} node(s)) — compile will fail:",
270 report.still_unsupported.len()
271 );
272 for (id, kind) in &report.still_unsupported {
273 let _ = writeln!(s, " - node {id:?}: {kind:?}");
274 }
275 let _ = writeln!(
276 s,
277 " Fix: implement native thunk + add to Backend::supported_ops, or add a \
278 rewrite/common body in rlx-fusion."
279 );
280 } else {
281 let _ = writeln!(s, " compile-ready: yes");
282 }
283
284 s
285}
286
287pub fn maybe_log_dispatch_report(report: &KernelDispatchReport) {
289 if rlx_ir::env::flag("RLX_DISPATCH_REPORT") || rlx_ir::env::flag("RLX_VERBOSE") {
290 eprintln!("{}", format_dispatch_report(report));
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// A splat-render node on a backend that does not claim it must be routed
    /// through its common-IR body.
    #[test]
    fn common_lowered_when_not_in_supported() {
        use rlx_ir::ops::splat::{GaussianSplatInputs, GaussianSplatRenderParams};

        const SPLATS: usize = 2;
        let dt = DType::F32;
        let mut graph = Graph::new("splat");
        let positions = graph.input("pos", Shape::new(&[SPLATS, 3], dt));
        let scales = graph.input("scale", Shape::new(&[SPLATS, 3], dt));
        let rotations = graph.input("rot", Shape::new(&[SPLATS, 4], dt));
        let opacities = graph.input("opa", Shape::new(&[SPLATS], dt));
        let colors = graph.input("col", Shape::new(&[SPLATS, 3], dt));
        let sh_coeffs = graph.input("sh", Shape::new(&[SPLATS, 3], dt));
        let meta = graph.input("meta", Shape::new(&[23], dt));

        let inputs = GaussianSplatInputs {
            positions,
            scales,
            rotations,
            opacities,
            colors,
            sh_coeffs,
            meta,
        };
        let params = GaussianSplatRenderParams {
            width: 8,
            height: 8,
            ..Default::default()
        };
        let rendered = graph.gaussian_splat_render(inputs, params);
        graph.set_outputs(vec![rendered]);

        let claimed: &[OpKind] = &[OpKind::Input, OpKind::Param, OpKind::MatMul];
        let report = analyze_dispatch(&graph, "test", claimed, KernelDispatchConfig::default());

        let lowered = report
            .common_lowered_kinds
            .contains(&OpKind::GaussianSplatRender);
        assert!(lowered);

        let splat_routed_to_common = report.summaries.iter().any(|summary| {
            summary.kind == OpKind::GaussianSplatRender && summary.path == DispatchPath::CommonIr
        });
        assert!(splat_routed_to_common);
    }

    /// A fused matmul/bias/activation op on a backend that only claims the
    /// primitive ops must be rewritten away and reported as such.
    #[test]
    fn prepare_marks_rewritten_fused_op() {
        let dt = DType::F32;
        let mut graph = Graph::new("fused");
        let x = graph.input("x", Shape::new(&[2, 8], dt));
        let w = graph.param("w", Shape::new(&[8, 4], dt));
        let b = graph.param("b", Shape::new(&[4], dt));
        let fused = graph.fused_matmul_bias_act(x, w, b, None, Shape::new(&[2, 4], dt));
        graph.set_outputs(vec![fused]);

        let claimed: &[OpKind] = &[
            OpKind::Input,
            OpKind::Param,
            OpKind::MatMul,
            OpKind::Binary,
            OpKind::Expand,
            OpKind::Activation,
        ];
        let (rewritten, report) = prepare_graph_for_backend_with_report(
            graph,
            "cpu",
            claimed,
            KernelDispatchConfig::default(),
        );

        assert!(report.compile_ready);

        let fused_gone = rewritten
            .nodes()
            .iter()
            .all(|node| node.op.kind() != OpKind::FusedMatMulBiasAct);
        assert!(fused_gone);

        let reported_rewritten = report.summaries.iter().any(|summary| {
            summary.kind == OpKind::FusedMatMulBiasAct && summary.path == DispatchPath::Rewritten
        });
        assert!(reported_rewritten);
    }
}