Skip to main content

rlx_compile/
dispatch_report.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend lowering transparency — which ops run native, via common IR, or are missing.
17//!
18//! Use before or during compile to see what will be fast vs decomposed vs blocking.
19
20use std::collections::{HashMap, HashSet};
21use std::fmt::Write as _;
22
23use rlx_ir::logical_kernel::{
24    self, KernelDispatchConfig, KernelDispatchPolicy, registered_logical_kernels,
25    should_lower_to_common,
26};
27use rlx_ir::{Graph, NodeId, OpKind};
28
29use crate::legalize::legalize_for_backend;
30use crate::rewrite::rewrite_for_backend_with_config;
31
/// Route by which a logical or fused op ends up in the backend executable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchPath {
    /// The kind is claimed in `supported_ops` (or the backend claims no ops at all).
    Native,
    /// A registered logical kernel lowered to primitive MIR — portable, often slower.
    CommonIr,
    /// Eliminated by a structural rewrite (unfuse, LowerDotGeneral, …) into other kinds.
    Rewritten,
    /// Still absent from `supported_ops` after rewriting — compilation will fail.
    Unsupported,
}

impl DispatchPath {
    /// Stable, lowercase, kebab-case name for this path, suitable for logs and reports.
    pub fn label(self) -> &'static str {
        match self {
            DispatchPath::CommonIr => "common-ir",
            DispatchPath::Native => "native",
            DispatchPath::Rewritten => "rewritten",
            DispatchPath::Unsupported => "unsupported",
        }
    }
}
55
56/// Per-`OpKind` summary for one graph + backend claim set.
57#[derive(Debug, Clone)]
58pub struct KindDispatchSummary {
59    pub kind: OpKind,
60    pub node_count: usize,
61    pub path: DispatchPath,
62    /// Set when [`DispatchPath::CommonIr`] (see [`registered_logical_kernels`]).
63    pub logical_name: Option<&'static str>,
64}
65
66/// Full report after rewrite + legalization probe (same path as [`crate::rewrite::legalize_or_rewrite_for_backend_with_config`]).
67#[derive(Debug, Clone)]
68pub struct KernelDispatchReport {
69    pub backend_name: String,
70    pub policy: KernelDispatchPolicy,
71    /// Length of `supported_ops` slice (0 = accept all kinds at legalize).
72    pub supported_claim_count: usize,
73    pub summaries: Vec<KindDispatchSummary>,
74    /// Kinds that will use common IR lowering on this compile.
75    pub common_lowered_kinds: Vec<OpKind>,
76    /// Offenders after all rewrites (empty when compile-ready).
77    pub still_unsupported: Vec<(NodeId, OpKind)>,
78    pub compile_ready: bool,
79}
80
81fn logical_name(kind: OpKind) -> Option<&'static str> {
82    registered_logical_kernels()
83        .iter()
84        .find(|e| e.kind == kind)
85        .map(|e| e.name)
86}
87
88fn count_kinds(graph: &Graph) -> HashMap<OpKind, usize> {
89    let mut m = HashMap::new();
90    for node in graph.nodes() {
91        *m.entry(node.op.kind()).or_default() += 1;
92    }
93    m
94}
95
96fn classify_kind(
97    kind: OpKind,
98    supported: &[OpKind],
99    config: KernelDispatchConfig,
100    common_set: &HashSet<OpKind>,
101    before: &HashMap<OpKind, usize>,
102    after: &HashMap<OpKind, usize>,
103    unsupported_kinds: &HashSet<OpKind>,
104) -> DispatchPath {
105    if should_lower_to_common(kind, supported, config) || common_set.contains(&kind) {
106        return DispatchPath::CommonIr;
107    }
108    if unsupported_kinds.contains(&kind) {
109        return DispatchPath::Unsupported;
110    }
111    if before.contains_key(&kind) && !after.contains_key(&kind) {
112        return DispatchPath::Rewritten;
113    }
114    if supported.is_empty() || supported.contains(&kind) {
115        return DispatchPath::Native;
116    }
117    if after.contains_key(&kind) {
118        return DispatchPath::Native;
119    }
120    DispatchPath::Unsupported
121}
122
123/// Analyze the graph **before** rewrite (static — does not run unfuse passes).
124pub fn analyze_dispatch(
125    graph: &Graph,
126    backend_name: &str,
127    supported: &[OpKind],
128    config: KernelDispatchConfig,
129) -> KernelDispatchReport {
130    let before = count_kinds(graph);
131    let common_lowered = logical_kernel::logical_kinds_in_graph(graph, supported, config);
132    let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
133    let unsupported_set = HashSet::new();
134
135    let mut summaries: Vec<KindDispatchSummary> = before
136        .iter()
137        .map(|(&kind, &node_count)| {
138            let path = classify_kind(
139                kind,
140                supported,
141                config,
142                &common_set,
143                &before,
144                &before,
145                &unsupported_set,
146            );
147            KindDispatchSummary {
148                kind,
149                node_count,
150                path,
151                logical_name: logical_name(kind),
152            }
153        })
154        .collect();
155    summaries.sort_by_key(|s| format!("{:?}", s.kind));
156
157    KernelDispatchReport {
158        backend_name: backend_name.to_string(),
159        policy: config.policy,
160        supported_claim_count: supported.len(),
161        summaries,
162        common_lowered_kinds: common_lowered,
163        still_unsupported: Vec::new(),
164        // Static probe only — common-ir is compile-ready; use prepare_* for hard failures.
165        compile_ready: true,
166    }
167}
168
169/// Rewrite toward `supported`, then report native / common / rewritten / missing.
170pub fn prepare_graph_for_backend_with_report(
171    graph: Graph,
172    backend_name: &str,
173    supported: &[OpKind],
174    config: KernelDispatchConfig,
175) -> (Graph, KernelDispatchReport) {
176    let before = count_kinds(&graph);
177    let common_lowered = logical_kernel::logical_kinds_in_graph(&graph, supported, config);
178    let common_set: HashSet<OpKind> = common_lowered.iter().copied().collect();
179
180    let rewritten = rewrite_for_backend_with_config(graph, supported, config);
181    let after = count_kinds(&rewritten);
182    let still_unsupported = legalize_for_backend(&rewritten, supported)
183        .err()
184        .unwrap_or_default();
185    let unsupported_set: HashSet<OpKind> = still_unsupported.iter().map(|(_, k)| *k).collect();
186
187    let mut summaries: Vec<KindDispatchSummary> = before
188        .iter()
189        .map(|(&kind, &node_count)| {
190            let path = classify_kind(
191                kind,
192                supported,
193                config,
194                &common_set,
195                &before,
196                &after,
197                &unsupported_set,
198            );
199            KindDispatchSummary {
200                kind,
201                node_count,
202                path,
203                logical_name: logical_name(kind),
204            }
205        })
206        .collect();
207    summaries.sort_by_key(|s| format!("{:?}", s.kind));
208
209    let compile_ready = still_unsupported.is_empty();
210    let report = KernelDispatchReport {
211        backend_name: backend_name.to_string(),
212        policy: config.policy,
213        supported_claim_count: supported.len(),
214        summaries,
215        common_lowered_kinds: common_lowered,
216        still_unsupported,
217        compile_ready,
218    };
219    (rewritten, report)
220}
221
222/// Human-readable report for logs / CI / REPL.
223pub fn format_dispatch_report(report: &KernelDispatchReport) -> String {
224    let mut s = String::new();
225    let _ = writeln!(
226        s,
227        "rlx dispatch report — backend {:?}, policy {:?}, supported_ops claim={}",
228        report.backend_name, report.policy, report.supported_claim_count
229    );
230    if report.supported_claim_count == 0 {
231        let _ = writeln!(
232            s,
233            "  (empty claim = legalize accepts all kinds; native/common split is advisory only)"
234        );
235    }
236
237    if !report.common_lowered_kinds.is_empty() {
238        let _ = writeln!(
239            s,
240            "  common-ir lowering (portable, add to supported_ops for native fast path):"
241        );
242        for kind in &report.common_lowered_kinds {
243            let name = logical_name(*kind).unwrap_or("?");
244            let _ = writeln!(s, "    - {kind:?} ({name})");
245        }
246    }
247
248    let mut by_path: [Vec<&KindDispatchSummary>; 4] = [vec![], vec![], vec![], vec![]];
249    for sum in &report.summaries {
250        let idx = match sum.path {
251            DispatchPath::Native => 0,
252            DispatchPath::CommonIr => 1,
253            DispatchPath::Rewritten => 2,
254            DispatchPath::Unsupported => 3,
255        };
256        by_path[idx].push(sum);
257    }
258
259    for (label, entries) in [
260        ("native", &by_path[0]),
261        ("common-ir", &by_path[1]),
262        ("rewritten", &by_path[2]),
263        ("unsupported", &by_path[3]),
264    ] {
265        if entries.is_empty() {
266            continue;
267        }
268        let _ = writeln!(s, "  {label}:");
269        for e in entries {
270            let extra = e
271                .logical_name
272                .map(|n| format!(" logical={n}"))
273                .unwrap_or_default();
274            let _ = writeln!(s, "    - {:?} ×{} nodes{extra}", e.kind, e.node_count);
275        }
276    }
277
278    if !report.still_unsupported.is_empty() {
279        let _ = writeln!(
280            s,
281            "  still unsupported after rewrite ({} node(s)) — compile will fail:",
282            report.still_unsupported.len()
283        );
284        for (id, kind) in &report.still_unsupported {
285            let _ = writeln!(s, "    - node {id:?}: {kind:?}");
286        }
287        let _ = writeln!(
288            s,
289            "  Fix: implement native thunk + add to Backend::supported_ops, or add a \
290             rewrite/common body in rlx-fusion."
291        );
292    } else {
293        let _ = writeln!(s, "  compile-ready: yes");
294    }
295
296    s
297}
298
299/// Print when `RLX_VERBOSE=1` or `RLX_DISPATCH_REPORT=1`.
300pub fn maybe_log_dispatch_report(report: &KernelDispatchReport) {
301    if rlx_ir::env::flag("RLX_DISPATCH_REPORT") || rlx_ir::env::flag("RLX_VERBOSE") {
302        eprintln!("{}", format_dispatch_report(report));
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// Claim set for the fused-matmul graph: the primitives the unfuse rewrite produces,
    /// but not `FusedMatMulBiasAct` itself.
    const FUSED_SUPPORTED: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::MatMul,
        OpKind::Binary,
        OpKind::Expand,
        OpKind::Activation,
    ];

    /// Build a graph with a single fused matmul+bias(+act) node over `x[2,8]`, `w[8,4]`, `b[4]`.
    fn fused_matmul_graph() -> Graph {
        let f = DType::F32;
        let mut g = Graph::new("fused");
        let x = g.input("x", Shape::new(&[2, 8], f));
        let w = g.param("w", Shape::new(&[8, 4], f));
        let b = g.param("b", Shape::new(&[4], f));
        let out = g.fused_matmul_bias_act(x, w, b, None, Shape::new(&[2, 4], f));
        g.set_outputs(vec![out]);
        g
    }

    #[test]
    fn common_lowered_when_not_in_supported() {
        use rlx_ir::ops::splat::{GaussianSplatInputs, GaussianSplatRenderParams};
        let mut g = Graph::new("splat");
        let n = 2usize;
        let f = DType::F32;
        let positions = g.input("pos", Shape::new(&[n, 3], f));
        let scales = g.input("scale", Shape::new(&[n, 3], f));
        let rotations = g.input("rot", Shape::new(&[n, 4], f));
        let opacities = g.input("opa", Shape::new(&[n], f));
        let colors = g.input("col", Shape::new(&[n, 3], f));
        let sh_coeffs = g.input("sh", Shape::new(&[n, 3], f));
        let meta = g.input("meta", Shape::new(&[23], f));
        let out = g.gaussian_splat_render(
            GaussianSplatInputs {
                positions,
                scales,
                rotations,
                opacities,
                colors,
                sh_coeffs,
                meta,
            },
            GaussianSplatRenderParams {
                width: 8,
                height: 8,
                ..Default::default()
            },
        );
        g.set_outputs(vec![out]);

        let supported = &[OpKind::Input, OpKind::Param, OpKind::MatMul];
        let report = analyze_dispatch(&g, "test", supported, KernelDispatchConfig::default());
        assert!(
            report
                .common_lowered_kinds
                .contains(&OpKind::GaussianSplatRender)
        );
        assert!(
            report
                .summaries
                .iter()
                .any(|s| s.kind == OpKind::GaussianSplatRender && s.path == DispatchPath::CommonIr)
        );
    }

    #[test]
    fn prepare_marks_rewritten_fused_op() {
        let (rewritten, report) = prepare_graph_for_backend_with_report(
            fused_matmul_graph(),
            "cpu",
            FUSED_SUPPORTED,
            KernelDispatchConfig::default(),
        );
        assert!(report.compile_ready);
        assert!(
            !rewritten
                .nodes()
                .iter()
                .any(|n| n.op.kind() == OpKind::FusedMatMulBiasAct)
        );
        assert!(report.summaries.iter().any(|s| {
            s.kind == OpKind::FusedMatMulBiasAct && s.path == DispatchPath::Rewritten
        }));
    }

    #[test]
    fn dispatch_path_labels_are_stable() {
        assert_eq!(DispatchPath::Native.label(), "native");
        assert_eq!(DispatchPath::CommonIr.label(), "common-ir");
        assert_eq!(DispatchPath::Rewritten.label(), "rewritten");
        assert_eq!(DispatchPath::Unsupported.label(), "unsupported");
    }

    #[test]
    fn format_report_lists_rewritten_kind_and_readiness() {
        let (_, report) = prepare_graph_for_backend_with_report(
            fused_matmul_graph(),
            "cpu",
            FUSED_SUPPORTED,
            KernelDispatchConfig::default(),
        );
        let text = format_dispatch_report(&report);
        assert!(text.starts_with("rlx dispatch report — backend \"cpu\""));
        assert!(text.contains("supported_ops claim=6"));
        assert!(text.contains("\n  rewritten:\n"));
        assert!(text.contains(&format!(
            "    - {:?} ×1 nodes",
            OpKind::FusedMatMulBiasAct
        )));
        assert!(text.contains("  compile-ready: yes"));
        assert!(!text.contains("still unsupported after rewrite"));
    }
}