Skip to main content

rlx_compile/
dispatch_report.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Backend lowering transparency — which ops run native, via common IR, or are missing.
5//!
6//! Use before or during compile to see what will be fast vs decomposed vs blocking.
7
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write as _;
10
11use rlx_ir::logical_kernel::{
12    self, KernelDispatchConfig, KernelDispatchPolicy, registered_logical_kernels,
13    should_lower_to_common,
14};
15use rlx_ir::{Graph, NodeId, OpKind};
16
17use crate::legalize::legalize_for_backend;
18use crate::rewrite::rewrite_for_backend_with_config;
19
/// The route by which a logical or fused op ends up in the backend executable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DispatchPath {
    /// The backend claims the kind in `supported_ops`, or it makes no op claim at all.
    Native,
    /// A registered logical kernel that is lowered to primitive MIR (portable, often slower).
    CommonIr,
    /// A structural rewrite (unfuse, LowerDotGeneral, …) replaced the op with other kinds.
    Rewritten,
    /// The kind is still missing from `supported_ops` after rewriting, so compile will fail.
    Unsupported,
}
32
33impl DispatchPath {
34    pub fn label(self) -> &'static str {
35        match self {
36            Self::Native => "native",
37            Self::CommonIr => "common-ir",
38            Self::Rewritten => "rewritten",
39            Self::Unsupported => "unsupported",
40        }
41    }
42}
43
44/// Per-`OpKind` summary for one graph + backend claim set.
45#[derive(Debug, Clone)]
46pub struct KindDispatchSummary {
47    pub kind: OpKind,
48    pub node_count: usize,
49    pub path: DispatchPath,
50    /// Set when [`DispatchPath::CommonIr`] (see [`registered_logical_kernels`]).
51    pub logical_name: Option<&'static str>,
52}
53
54/// Full report after rewrite + legalization probe (same path as [`crate::rewrite::legalize_or_rewrite_for_backend_with_config`]).
55#[derive(Debug, Clone)]
56pub struct KernelDispatchReport {
57    pub backend_name: String,
58    pub policy: KernelDispatchPolicy,
59    /// Length of `supported_ops` slice (0 = accept all kinds at legalize).
60    pub supported_claim_count: usize,
61    pub summaries: Vec<KindDispatchSummary>,
62    /// Kinds that will use common IR lowering on this compile.
63    pub common_lowered_kinds: Vec<OpKind>,
64    /// Offenders after all rewrites (empty when compile-ready).
65    pub still_unsupported: Vec<(NodeId, OpKind)>,
66    pub compile_ready: bool,
67}
68
69fn logical_name(kind: OpKind) -> Option<&'static str> {
70    registered_logical_kernels()
71        .iter()
72        .find(|e| e.kind == kind)
73        .map(|e| e.name)
74}
75
76fn count_kinds(graph: &Graph) -> HashMap<OpKind, usize> {
77    let mut m = HashMap::new();
78    for node in graph.nodes() {
79        *m.entry(node.op.kind()).or_default() += 1;
80    }
81    m
82}
83
84fn classify_kind(
85    kind: OpKind,
86    supported: &[OpKind],
87    config: KernelDispatchConfig,
88    common_set: &HashSet<OpKind>,
89    before: &HashMap<OpKind, usize>,
90    after: &HashMap<OpKind, usize>,
91    unsupported_kinds: &HashSet<OpKind>,
92) -> DispatchPath {
93    if should_lower_to_common(kind, supported, config) || common_set.contains(&kind) {
94        return DispatchPath::CommonIr;
95    }
96    if unsupported_kinds.contains(&kind) {
97        return DispatchPath::Unsupported;
98    }
99    if before.contains_key(&kind) && !after.contains_key(&kind) {
100        return DispatchPath::Rewritten;
101    }
102    if supported.is_empty() || supported.contains(&kind) {
103        return DispatchPath::Native;
104    }
105    if after.contains_key(&kind) {
106        return DispatchPath::Native;
107    }
108    DispatchPath::Unsupported
109}
110
111/// Analyze the graph **before** rewrite (static — does not run unfuse passes).
112pub fn analyze_dispatch(
113    graph: &Graph,
114    backend_name: &str,
115    supported: &[OpKind],
116    config: KernelDispatchConfig,
117) -> KernelDispatchReport {
118    let before = count_kinds(graph);
119    let common_lowered = logical_kernel::logical_kinds_in_graph(graph, supported, config);
120    let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
121    let unsupported_set = HashSet::new();
122
123    let mut summaries: Vec<KindDispatchSummary> = before
124        .iter()
125        .map(|(&kind, &node_count)| {
126            let path = classify_kind(
127                kind,
128                supported,
129                config,
130                &common_set,
131                &before,
132                &before,
133                &unsupported_set,
134            );
135            KindDispatchSummary {
136                kind,
137                node_count,
138                path,
139                logical_name: logical_name(kind),
140            }
141        })
142        .collect();
143    summaries.sort_by_key(|s| format!("{:?}", s.kind));
144
145    KernelDispatchReport {
146        backend_name: backend_name.to_string(),
147        policy: config.policy,
148        supported_claim_count: supported.len(),
149        summaries,
150        common_lowered_kinds: common_lowered,
151        still_unsupported: Vec::new(),
152        // Static probe only — common-ir is compile-ready; use prepare_* for hard failures.
153        compile_ready: true,
154    }
155}
156
157/// Rewrite toward `supported`, then report native / common / rewritten / missing.
158pub fn prepare_graph_for_backend_with_report(
159    graph: Graph,
160    backend_name: &str,
161    supported: &[OpKind],
162    config: KernelDispatchConfig,
163) -> (Graph, KernelDispatchReport) {
164    let before = count_kinds(&graph);
165    let common_lowered = logical_kernel::logical_kinds_in_graph(&graph, supported, config);
166    let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
167
168    let rewritten = rewrite_for_backend_with_config(graph, supported, config);
169    let after = count_kinds(&rewritten);
170    let still_unsupported = legalize_for_backend(&rewritten, supported)
171        .err()
172        .unwrap_or_default();
173    let unsupported_set: HashSet<OpKind> = still_unsupported.iter().map(|(_, k)| *k).collect();
174
175    let mut summaries: Vec<KindDispatchSummary> = before
176        .iter()
177        .map(|(&kind, &node_count)| {
178            let path = classify_kind(
179                kind,
180                supported,
181                config,
182                &common_set,
183                &before,
184                &after,
185                &unsupported_set,
186            );
187            KindDispatchSummary {
188                kind,
189                node_count,
190                path,
191                logical_name: logical_name(kind),
192            }
193        })
194        .collect();
195    summaries.sort_by_key(|s| format!("{:?}", s.kind));
196
197    let compile_ready = still_unsupported.is_empty();
198    let report = KernelDispatchReport {
199        backend_name: backend_name.to_string(),
200        policy: config.policy,
201        supported_claim_count: supported.len(),
202        summaries,
203        common_lowered_kinds: common_lowered,
204        still_unsupported,
205        compile_ready,
206    };
207    (rewritten, report)
208}
209
210/// Human-readable report for logs / CI / REPL.
211pub fn format_dispatch_report(report: &KernelDispatchReport) -> String {
212    let mut s = String::new();
213    let _ = writeln!(
214        s,
215        "rlx dispatch report — backend {:?}, policy {:?}, supported_ops claim={}",
216        report.backend_name, report.policy, report.supported_claim_count
217    );
218    if report.supported_claim_count == 0 {
219        let _ = writeln!(
220            s,
221            "  (empty claim = legalize accepts all kinds; native/common split is advisory only)"
222        );
223    }
224
225    if !report.common_lowered_kinds.is_empty() {
226        let _ = writeln!(
227            s,
228            "  common-ir lowering (portable, add to supported_ops for native fast path):"
229        );
230        for kind in &report.common_lowered_kinds {
231            let name = logical_name(*kind).unwrap_or("?");
232            let _ = writeln!(s, "    - {kind:?} ({name})");
233        }
234    }
235
236    let mut by_path: [Vec<&KindDispatchSummary>; 4] = [vec![], vec![], vec![], vec![]];
237    for sum in &report.summaries {
238        let idx = match sum.path {
239            DispatchPath::Native => 0,
240            DispatchPath::CommonIr => 1,
241            DispatchPath::Rewritten => 2,
242            DispatchPath::Unsupported => 3,
243        };
244        by_path[idx].push(sum);
245    }
246
247    for (label, entries) in [
248        ("native", &by_path[0]),
249        ("common-ir", &by_path[1]),
250        ("rewritten", &by_path[2]),
251        ("unsupported", &by_path[3]),
252    ] {
253        if entries.is_empty() {
254            continue;
255        }
256        let _ = writeln!(s, "  {label}:");
257        for e in entries {
258            let extra = e
259                .logical_name
260                .map(|n| format!(" logical={n}"))
261                .unwrap_or_default();
262            let _ = writeln!(s, "    - {:?} ×{} nodes{extra}", e.kind, e.node_count);
263        }
264    }
265
266    if !report.still_unsupported.is_empty() {
267        let _ = writeln!(
268            s,
269            "  still unsupported after rewrite ({} node(s)) — compile will fail:",
270            report.still_unsupported.len()
271        );
272        for (id, kind) in &report.still_unsupported {
273            let _ = writeln!(s, "    - node {id:?}: {kind:?}");
274        }
275        let _ = writeln!(
276            s,
277            "  Fix: implement native thunk + add to Backend::supported_ops, or add a \
278             rewrite/common body in rlx-fusion."
279        );
280    } else {
281        let _ = writeln!(s, "  compile-ready: yes");
282    }
283
284    s
285}
286
287/// Print when `RLX_VERBOSE=1` or `RLX_DISPATCH_REPORT=1`.
288pub fn maybe_log_dispatch_report(report: &KernelDispatchReport) {
289    if rlx_ir::env::flag("RLX_DISPATCH_REPORT") || rlx_ir::env::flag("RLX_VERBOSE") {
290        eprintln!("{}", format_dispatch_report(report));
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// A Gaussian-splat render op that the backend does not claim must show up both in
    /// `common_lowered_kinds` and as a `CommonIr` summary.
    #[test]
    fn common_lowered_when_not_in_supported() {
        use rlx_ir::ops::splat::{GaussianSplatInputs, GaussianSplatRenderParams};

        let dtype = DType::F32;
        let count = 2usize;
        let mut graph = Graph::new("splat");

        // Inputs are created in the same order as the fields of `GaussianSplatInputs`.
        let positions = graph.input("pos", Shape::new(&[count, 3], dtype));
        let scales = graph.input("scale", Shape::new(&[count, 3], dtype));
        let rotations = graph.input("rot", Shape::new(&[count, 4], dtype));
        let opacities = graph.input("opa", Shape::new(&[count], dtype));
        let colors = graph.input("col", Shape::new(&[count, 3], dtype));
        let sh_coeffs = graph.input("sh", Shape::new(&[count, 3], dtype));
        let meta = graph.input("meta", Shape::new(&[23], dtype));

        let splat_inputs = GaussianSplatInputs {
            positions,
            scales,
            rotations,
            opacities,
            colors,
            sh_coeffs,
            meta,
        };
        let render_params = GaussianSplatRenderParams {
            width: 8,
            height: 8,
            ..Default::default()
        };
        let rendered = graph.gaussian_splat_render(splat_inputs, render_params);
        graph.set_outputs(vec![rendered]);

        // The claim deliberately leaves out `GaussianSplatRender`.
        let claim = [OpKind::Input, OpKind::Param, OpKind::MatMul];
        let report = analyze_dispatch(&graph, "test", &claim, KernelDispatchConfig::default());

        let listed_as_common = report
            .common_lowered_kinds
            .contains(&OpKind::GaussianSplatRender);
        assert!(listed_as_common);

        let summarized_as_common = report.summaries.iter().any(|summary| {
            summary.kind == OpKind::GaussianSplatRender && summary.path == DispatchPath::CommonIr
        });
        assert!(summarized_as_common);
    }

    /// A fused matmul+bias+activation op must disappear from the rewritten graph and be
    /// reported as `Rewritten` when the backend claims only the unfused pieces.
    #[test]
    fn prepare_marks_rewritten_fused_op() {
        let dtype = DType::F32;
        let mut graph = Graph::new("fused");
        let x = graph.input("x", Shape::new(&[2, 8], dtype));
        let w = graph.param("w", Shape::new(&[8, 4], dtype));
        let b = graph.param("b", Shape::new(&[4], dtype));
        let fused = graph.fused_matmul_bias_act(x, w, b, None, Shape::new(&[2, 4], dtype));
        graph.set_outputs(vec![fused]);

        let claim = [
            OpKind::Input,
            OpKind::Param,
            OpKind::MatMul,
            OpKind::Binary,
            OpKind::Expand,
            OpKind::Activation,
        ];
        let (rewritten, report) = prepare_graph_for_backend_with_report(
            graph,
            "cpu",
            &claim,
            KernelDispatchConfig::default(),
        );

        assert!(report.compile_ready);

        let fused_survives = rewritten
            .nodes()
            .iter()
            .any(|node| node.op.kind() == OpKind::FusedMatMulBiasAct);
        assert!(!fused_survives);

        let reported_as_rewritten = report.summaries.iter().any(|summary| {
            summary.kind == OpKind::FusedMatMulBiasAct && summary.path == DispatchPath::Rewritten
        });
        assert!(reported_as_rewritten);
    }
}