Skip to main content

rlx_ir/
logical_kernel.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! One logical kernel, many backends — dispatch policy and registry.
//!
//! A **logical kernel** is a single [`OpKind`] (e.g. [`OpKind::GaussianSplatRender`]) with a
//! documented semantic contract. Backends may provide a **native** implementation (the fast
//! path: a custom thunk, MSL, MPS, etc.). When no native implementation is available, or
//! [`KernelDispatchPolicy::ForceCommon`] is set, the compiler lowers the op to a **common**
//! subgraph built only from primitive MIR ops, so every backend schedules the same math
//! through its usual fusion/GEMM/elementwise paths.
//!
//! Native kernels are never removed from backends; common lowering is purely additive.
23//! Native kernels are never removed from backends; common lowering is additive.
24
25use crate::env;
26use crate::op::OpKind;
27
28pub mod splat_common;
29
/// Global choice between a backend's native kernel and the shared common-IR body.
///
/// Per-kind overrides in [`KernelDispatchConfig`] are checked before this policy is consulted.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum KernelDispatchPolicy {
    /// Use the native thunk when the backend's `supported_ops` claims the `OpKind`; otherwise
    /// lower to common IR. An empty claim set lowers nothing. This is the default.
    #[default]
    PreferNative,
    /// Lower every registered logical kernel to common IR (parity testing / minimal backends).
    ForceCommon,
    /// Never common-lower by policy; legalization must succeed with native ops only.
    ForceNative,
}
41
42impl KernelDispatchPolicy {
43    pub fn from_env() -> Self {
44        let v = env::var("KERNEL_DISPATCH").or_else(|| env::var("RLX_KERNEL_DISPATCH"));
45        match v.as_deref() {
46            Some("common") | Some("force_common") | Some("ForceCommon") => Self::ForceCommon,
47            Some("native") | Some("force_native") | Some("ForceNative") => Self::ForceNative,
48            _ => Self::PreferNative,
49        }
50    }
51}
52
53/// Registered logical kernel: native [`OpKind`] plus optional common lower pass name.
54#[derive(Debug, Clone, Copy)]
55pub struct LogicalKernelEntry {
56    pub kind: OpKind,
57    /// Human-readable id (logging / docs).
58    pub name: &'static str,
59}
60
61/// Logical kernels that have a registered common IR body in `rlx-fusion`.
62pub fn registered_logical_kernels() -> &'static [LogicalKernelEntry] {
63    &[
64        LogicalKernelEntry {
65            kind: OpKind::GroupNorm,
66            name: "group_norm",
67        },
68        LogicalKernelEntry {
69            kind: OpKind::BatchNormInference,
70            name: "batch_norm_inference",
71        },
72        LogicalKernelEntry {
73            kind: OpKind::ResizeNearest2x,
74            name: "resize_nearest_2x",
75        },
76        LogicalKernelEntry {
77            kind: OpKind::GaussianSplatRender,
78            name: "gaussian_splat_render",
79        },
80        LogicalKernelEntry {
81            kind: OpKind::GaussianSplatRenderBackward,
82            name: "gaussian_splat_render_backward",
83        },
84    ]
85}
86
87/// Per-compile overrides on top of [`KernelDispatchPolicy`].
88#[derive(Debug, Clone, Copy, Default)]
89pub struct KernelDispatchConfig {
90    pub policy: KernelDispatchPolicy,
91    /// Always common-lower these kinds (e.g. splat on CPU while keeping native matmul).
92    pub force_common_kinds: &'static [OpKind],
93    /// Never common-lower these kinds (overrides `ForceCommon` for listed kinds).
94    pub force_native_kinds: &'static [OpKind],
95}
96
97impl KernelDispatchConfig {
98    pub fn new(policy: KernelDispatchPolicy) -> Self {
99        Self {
100            policy,
101            ..Self::default()
102        }
103    }
104
105    pub fn from_env() -> Self {
106        Self::new(KernelDispatchPolicy::from_env())
107    }
108}
109
110/// Whether `kind` should be common-lowered for this backend claim set and config.
111pub fn should_lower_to_common(
112    kind: OpKind,
113    supported: &[OpKind],
114    config: KernelDispatchConfig,
115) -> bool {
116    if !registered_logical_kernels().iter().any(|e| e.kind == kind) {
117        return false;
118    }
119    if config.force_native_kinds.contains(&kind) {
120        return false;
121    }
122    if config.force_common_kinds.contains(&kind) {
123        return true;
124    }
125    match config.policy {
126        KernelDispatchPolicy::ForceCommon => true,
127        KernelDispatchPolicy::ForceNative => false,
128        KernelDispatchPolicy::PreferNative => !supported.is_empty() && !supported.contains(&kind),
129    }
130}
131
132/// Op kinds that appear in the graph and may need common lowering.
133pub fn logical_kinds_in_graph(
134    graph: &crate::Graph,
135    supported: &[OpKind],
136    config: KernelDispatchConfig,
137) -> Vec<OpKind> {
138    let mut kinds = Vec::new();
139    for node in graph.nodes() {
140        let k = node.op.kind();
141        if should_lower_to_common(k, supported, config) && !kinds.contains(&k) {
142            kinds.push(k);
143        }
144    }
145    kinds
146}