Skip to main content

rlx_ir/
logical_kernel.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15//! One logical kernel, many backends — dispatch policy and registry.
16//!
17//! A **logical kernel** is a single [`OpKind`] (e.g. [`OpKind::GaussianSplatRender`]) with a
18//! documented semantic contract. Backends may provide a **native** implementation (fast path:
19//! custom thunk, MSL, MPS, etc.). When native is unavailable or [`KernelDispatchPolicy::ForceCommon`]
20//! is set, the compiler lowers to a **common** subgraph built only from primitive MIR ops so each
21//! backend schedules the same math through its usual fusion/GEMM/elementwise paths.
22//!
23//! Native kernels are never removed from backends; common lowering is additive.
24
25use crate::env;
26use crate::op::OpKind;
27
28pub mod splat_common;
29
/// Policy choosing between a backend's native kernel and the shared common-IR lowering.
///
/// The default is [`KernelDispatchPolicy::PreferNative`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum KernelDispatchPolicy {
    /// Use the native thunk when the backend lists the `OpKind` in its `supported_ops`; otherwise
    /// lower to common IR. An empty `supported_ops` list never triggers lowering.
    #[default]
    PreferNative,
    /// Lower every registered logical kernel to common IR (parity checks / minimal backends).
    ForceCommon,
    /// Do not common-lower by policy; legalization has to succeed with native ops alone
    /// (per-kind `KernelDispatchConfig::force_common_kinds` entries still apply).
    ForceNative,
}
41
42impl KernelDispatchPolicy {
43    pub fn from_env() -> Self {
44        let v = env::var("KERNEL_DISPATCH").or_else(|| env::var("RLX_KERNEL_DISPATCH"));
45        match v.as_deref() {
46            Some("common") | Some("force_common") | Some("ForceCommon") => Self::ForceCommon,
47            Some("native") | Some("force_native") | Some("ForceNative") => Self::ForceNative,
48            _ => Self::PreferNative,
49        }
50    }
51}
52
53/// Registered logical kernel: native [`OpKind`] plus optional common lower pass name.
54#[derive(Debug, Clone, Copy)]
55pub struct LogicalKernelEntry {
56    pub kind: OpKind,
57    /// Human-readable id (logging / docs).
58    pub name: &'static str,
59}
60
61/// Logical kernels that have a registered common IR body in `rlx-fusion`.
62pub fn registered_logical_kernels() -> &'static [LogicalKernelEntry] {
63    &[
64        LogicalKernelEntry {
65            kind: OpKind::GroupNorm,
66            name: "group_norm",
67        },
68        LogicalKernelEntry {
69            kind: OpKind::ResizeNearest2x,
70            name: "resize_nearest_2x",
71        },
72        LogicalKernelEntry {
73            kind: OpKind::GaussianSplatRender,
74            name: "gaussian_splat_render",
75        },
76        LogicalKernelEntry {
77            kind: OpKind::GaussianSplatRenderBackward,
78            name: "gaussian_splat_render_backward",
79        },
80    ]
81}
82
83/// Per-compile overrides on top of [`KernelDispatchPolicy`].
84#[derive(Debug, Clone, Copy, Default)]
85pub struct KernelDispatchConfig {
86    pub policy: KernelDispatchPolicy,
87    /// Always common-lower these kinds (e.g. splat on CPU while keeping native matmul).
88    pub force_common_kinds: &'static [OpKind],
89    /// Never common-lower these kinds (overrides `ForceCommon` for listed kinds).
90    pub force_native_kinds: &'static [OpKind],
91}
92
93impl KernelDispatchConfig {
94    pub fn new(policy: KernelDispatchPolicy) -> Self {
95        Self {
96            policy,
97            ..Self::default()
98        }
99    }
100
101    pub fn from_env() -> Self {
102        Self::new(KernelDispatchPolicy::from_env())
103    }
104}
105
106/// Whether `kind` should be common-lowered for this backend claim set and config.
107pub fn should_lower_to_common(
108    kind: OpKind,
109    supported: &[OpKind],
110    config: KernelDispatchConfig,
111) -> bool {
112    if !registered_logical_kernels().iter().any(|e| e.kind == kind) {
113        return false;
114    }
115    if config.force_native_kinds.contains(&kind) {
116        return false;
117    }
118    if config.force_common_kinds.contains(&kind) {
119        return true;
120    }
121    match config.policy {
122        KernelDispatchPolicy::ForceCommon => true,
123        KernelDispatchPolicy::ForceNative => false,
124        KernelDispatchPolicy::PreferNative => !supported.is_empty() && !supported.contains(&kind),
125    }
126}
127
128/// Op kinds that appear in the graph and may need common lowering.
129pub fn logical_kinds_in_graph(
130    graph: &crate::Graph,
131    supported: &[OpKind],
132    config: KernelDispatchConfig,
133) -> Vec<OpKind> {
134    let mut kinds = Vec::new();
135    for node in graph.nodes() {
136        let k = node.op.kind();
137        if should_lower_to_common(k, supported, config) && !kinds.contains(&k) {
138            kinds.push(k);
139        }
140    }
141    kinds
142}