1use std::collections::{HashMap, HashSet};
21use std::fmt::Write as _;
22
23use rlx_ir::logical_kernel::{
24 self, KernelDispatchConfig, KernelDispatchPolicy, registered_logical_kernels,
25 should_lower_to_common,
26};
27use rlx_ir::{Graph, NodeId, OpKind};
28
29use crate::legalize::legalize_for_backend;
30use crate::rewrite::rewrite_for_backend_with_config;
31
/// How nodes of one op kind end up being handled for a backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchPath {
    /// Handled as-is by the backend.
    Native,
    /// Lowered through the portable common-IR body.
    CommonIr,
    /// Present before the rewrite pass but no longer present afterwards.
    Rewritten,
    /// Not handled by any of the paths above.
    Unsupported,
}

impl DispatchPath {
    /// Returns the short lowercase tag for this path.
    ///
    /// The tags are `"native"`, `"common-ir"`, `"rewritten"` and `"unsupported"`.
    pub fn label(self) -> &'static str {
        match self {
            DispatchPath::Native => "native",
            DispatchPath::CommonIr => "common-ir",
            DispatchPath::Rewritten => "rewritten",
            DispatchPath::Unsupported => "unsupported",
        }
    }
}
55
/// Per-op-kind entry of a [`KernelDispatchReport`].
#[derive(Debug, Clone)]
pub struct KindDispatchSummary {
    /// The op kind this entry describes.
    pub kind: OpKind,
    /// Number of nodes of this kind, counted on the graph before any rewrite.
    pub node_count: usize,
    /// The dispatch path assigned to this kind.
    pub path: DispatchPath,
    /// Name of the registered logical kernel for this kind, if one is registered.
    pub logical_name: Option<&'static str>,
}
65
/// Summary of how a graph's op kinds are dispatched for one backend.
#[derive(Debug, Clone)]
pub struct KernelDispatchReport {
    /// Backend name as passed in by the caller.
    pub backend_name: String,
    /// Dispatch policy taken from the `KernelDispatchConfig` in use.
    pub policy: KernelDispatchPolicy,
    /// Length of the backend's supported-op list (zero means the list was empty).
    pub supported_claim_count: usize,
    /// One entry per op kind found in the graph, sorted by the kind's `Debug` name.
    pub summaries: Vec<KindDispatchSummary>,
    /// Kinds reported by `logical_kernel::logical_kinds_in_graph` for this graph.
    pub common_lowered_kinds: Vec<OpKind>,
    /// Nodes rejected by legalization after the rewrite; empty when no legalization was run.
    pub still_unsupported: Vec<(NodeId, OpKind)>,
    /// `true` when `still_unsupported` is empty.
    pub compile_ready: bool,
}
80
81fn logical_name(kind: OpKind) -> Option<&'static str> {
82 registered_logical_kernels()
83 .iter()
84 .find(|e| e.kind == kind)
85 .map(|e| e.name)
86}
87
88fn count_kinds(graph: &Graph) -> HashMap<OpKind, usize> {
89 let mut m = HashMap::new();
90 for node in graph.nodes() {
91 *m.entry(node.op.kind()).or_default() += 1;
92 }
93 m
94}
95
96fn classify_kind(
97 kind: OpKind,
98 supported: &[OpKind],
99 config: KernelDispatchConfig,
100 common_set: &HashSet<OpKind>,
101 before: &HashMap<OpKind, usize>,
102 after: &HashMap<OpKind, usize>,
103 unsupported_kinds: &HashSet<OpKind>,
104) -> DispatchPath {
105 if should_lower_to_common(kind, supported, config) || common_set.contains(&kind) {
106 return DispatchPath::CommonIr;
107 }
108 if unsupported_kinds.contains(&kind) {
109 return DispatchPath::Unsupported;
110 }
111 if before.contains_key(&kind) && !after.contains_key(&kind) {
112 return DispatchPath::Rewritten;
113 }
114 if supported.is_empty() || supported.contains(&kind) {
115 return DispatchPath::Native;
116 }
117 if after.contains_key(&kind) {
118 return DispatchPath::Native;
119 }
120 DispatchPath::Unsupported
121}
122
123pub fn analyze_dispatch(
125 graph: &Graph,
126 backend_name: &str,
127 supported: &[OpKind],
128 config: KernelDispatchConfig,
129) -> KernelDispatchReport {
130 let before = count_kinds(graph);
131 let common_lowered = logical_kernel::logical_kinds_in_graph(graph, supported, config);
132 let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
133 let unsupported_set = HashSet::new();
134
135 let mut summaries: Vec<KindDispatchSummary> = before
136 .iter()
137 .map(|(&kind, &node_count)| {
138 let path = classify_kind(
139 kind,
140 supported,
141 config,
142 &common_set,
143 &before,
144 &before,
145 &unsupported_set,
146 );
147 KindDispatchSummary {
148 kind,
149 node_count,
150 path,
151 logical_name: logical_name(kind),
152 }
153 })
154 .collect();
155 summaries.sort_by_key(|s| format!("{:?}", s.kind));
156
157 KernelDispatchReport {
158 backend_name: backend_name.to_string(),
159 policy: config.policy,
160 supported_claim_count: supported.len(),
161 summaries,
162 common_lowered_kinds: common_lowered,
163 still_unsupported: Vec::new(),
164 compile_ready: true,
166 }
167}
168
169pub fn prepare_graph_for_backend_with_report(
171 graph: Graph,
172 backend_name: &str,
173 supported: &[OpKind],
174 config: KernelDispatchConfig,
175) -> (Graph, KernelDispatchReport) {
176 let before = count_kinds(&graph);
177 let common_lowered = logical_kernel::logical_kinds_in_graph(&graph, supported, config);
178 let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
179
180 let rewritten = rewrite_for_backend_with_config(graph, supported, config);
181 let after = count_kinds(&rewritten);
182 let still_unsupported = legalize_for_backend(&rewritten, supported)
183 .err()
184 .unwrap_or_default();
185 let unsupported_set: HashSet<OpKind> = still_unsupported.iter().map(|(_, k)| *k).collect();
186
187 let mut summaries: Vec<KindDispatchSummary> = before
188 .iter()
189 .map(|(&kind, &node_count)| {
190 let path = classify_kind(
191 kind,
192 supported,
193 config,
194 &common_set,
195 &before,
196 &after,
197 &unsupported_set,
198 );
199 KindDispatchSummary {
200 kind,
201 node_count,
202 path,
203 logical_name: logical_name(kind),
204 }
205 })
206 .collect();
207 summaries.sort_by_key(|s| format!("{:?}", s.kind));
208
209 let compile_ready = still_unsupported.is_empty();
210 let report = KernelDispatchReport {
211 backend_name: backend_name.to_string(),
212 policy: config.policy,
213 supported_claim_count: supported.len(),
214 summaries,
215 common_lowered_kinds: common_lowered,
216 still_unsupported,
217 compile_ready,
218 };
219 (rewritten, report)
220}
221
222pub fn format_dispatch_report(report: &KernelDispatchReport) -> String {
224 let mut s = String::new();
225 let _ = writeln!(
226 s,
227 "rlx dispatch report — backend {:?}, policy {:?}, supported_ops claim={}",
228 report.backend_name, report.policy, report.supported_claim_count
229 );
230 if report.supported_claim_count == 0 {
231 let _ = writeln!(
232 s,
233 " (empty claim = legalize accepts all kinds; native/common split is advisory only)"
234 );
235 }
236
237 if !report.common_lowered_kinds.is_empty() {
238 let _ = writeln!(
239 s,
240 " common-ir lowering (portable, add to supported_ops for native fast path):"
241 );
242 for kind in &report.common_lowered_kinds {
243 let name = logical_name(*kind).unwrap_or("?");
244 let _ = writeln!(s, " - {kind:?} ({name})");
245 }
246 }
247
248 let mut by_path: [Vec<&KindDispatchSummary>; 4] = [vec![], vec![], vec![], vec![]];
249 for sum in &report.summaries {
250 let idx = match sum.path {
251 DispatchPath::Native => 0,
252 DispatchPath::CommonIr => 1,
253 DispatchPath::Rewritten => 2,
254 DispatchPath::Unsupported => 3,
255 };
256 by_path[idx].push(sum);
257 }
258
259 for (label, entries) in [
260 ("native", &by_path[0]),
261 ("common-ir", &by_path[1]),
262 ("rewritten", &by_path[2]),
263 ("unsupported", &by_path[3]),
264 ] {
265 if entries.is_empty() {
266 continue;
267 }
268 let _ = writeln!(s, " {label}:");
269 for e in entries {
270 let extra = e
271 .logical_name
272 .map(|n| format!(" logical={n}"))
273 .unwrap_or_default();
274 let _ = writeln!(s, " - {:?} ×{} nodes{extra}", e.kind, e.node_count);
275 }
276 }
277
278 if !report.still_unsupported.is_empty() {
279 let _ = writeln!(
280 s,
281 " still unsupported after rewrite ({} node(s)) — compile will fail:",
282 report.still_unsupported.len()
283 );
284 for (id, kind) in &report.still_unsupported {
285 let _ = writeln!(s, " - node {id:?}: {kind:?}");
286 }
287 let _ = writeln!(
288 s,
289 " Fix: implement native thunk + add to Backend::supported_ops, or add a \
290 rewrite/common body in rlx-fusion."
291 );
292 } else {
293 let _ = writeln!(s, " compile-ready: yes");
294 }
295
296 s
297}
298
299pub fn maybe_log_dispatch_report(report: &KernelDispatchReport) {
301 if rlx_ir::env::flag("RLX_DISPATCH_REPORT") || rlx_ir::env::flag("RLX_VERBOSE") {
302 eprintln!("{}", format_dispatch_report(report));
303 }
304}
305
// Unit tests for dispatch analysis and graph preparation.
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// A `GaussianSplatRender` node whose kind is missing from the supported list must show up
    /// in `common_lowered_kinds` and be classified as `DispatchPath::CommonIr`.
    #[test]
    fn common_lowered_when_not_in_supported() {
        use rlx_ir::ops::splat::{GaussianSplatInputs, GaussianSplatRenderParams};
        let mut g = Graph::new("splat");
        let n = 2usize;
        let f = DType::F32;
        let positions = g.input("pos", Shape::new(&[n, 3], f));
        let scales = g.input("scale", Shape::new(&[n, 3], f));
        let rotations = g.input("rot", Shape::new(&[n, 4], f));
        let opacities = g.input("opa", Shape::new(&[n], f));
        let colors = g.input("col", Shape::new(&[n, 3], f));
        let sh_coeffs = g.input("sh", Shape::new(&[n, 3], f));
        let meta = g.input("meta", Shape::new(&[23], f));
        let out = g.gaussian_splat_render(
            GaussianSplatInputs {
                positions,
                scales,
                rotations,
                opacities,
                colors,
                sh_coeffs,
                meta,
            },
            GaussianSplatRenderParams {
                width: 8,
                height: 8,
                ..Default::default()
            },
        );
        g.set_outputs(vec![out]);

        // The supported list deliberately omits `GaussianSplatRender`.
        let supported = &[OpKind::Input, OpKind::Param, OpKind::MatMul];
        let report = analyze_dispatch(&g, "test", supported, KernelDispatchConfig::default());
        assert!(
            report
                .common_lowered_kinds
                .contains(&OpKind::GaussianSplatRender)
        );
        assert!(
            report
                .summaries
                .iter()
                .any(|s| s.kind == OpKind::GaussianSplatRender && s.path == DispatchPath::CommonIr)
        );
    }

    /// A `FusedMatMulBiasAct` node that the backend does not list as supported must be gone
    /// from the rewritten graph, be reported as `DispatchPath::Rewritten`, and leave the report
    /// compile-ready.
    #[test]
    fn prepare_marks_rewritten_fused_op() {
        let f = DType::F32;
        let mut g = Graph::new("fused");
        let x = g.input("x", Shape::new(&[2, 8], f));
        let w = g.param("w", Shape::new(&[8, 4], f));
        let b = g.param("b", Shape::new(&[4], f));
        let out = g.fused_matmul_bias_act(x, w, b, None, Shape::new(&[2, 4], f));
        g.set_outputs(vec![out]);

        // The supported list deliberately omits `FusedMatMulBiasAct`.
        let supported = &[
            OpKind::Input,
            OpKind::Param,
            OpKind::MatMul,
            OpKind::Binary,
            OpKind::Expand,
            OpKind::Activation,
        ];
        let (rewritten, report) = prepare_graph_for_backend_with_report(
            g,
            "cpu",
            supported,
            KernelDispatchConfig::default(),
        );
        assert!(report.compile_ready);
        assert!(
            !rewritten
                .nodes()
                .iter()
                .any(|n| n.op.kind() == OpKind::FusedMatMulBiasAct)
        );
        assert!(report.summaries.iter().any(|s| {
            s.kind == OpKind::FusedMatMulBiasAct && s.path == DispatchPath::Rewritten
        }));
    }
}
392}