rlx_ir/logical_kernel.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15//! One logical kernel, many backends — dispatch policy and registry.
16//!
17//! A **logical kernel** is a single [`OpKind`] (e.g. [`OpKind::GaussianSplatRender`]) with a
18//! documented semantic contract. Backends may provide a **native** implementation (fast path:
19//! custom thunk, MSL, MPS, etc.). When native is unavailable or [`KernelDispatchPolicy::ForceCommon`]
20//! is set, the compiler lowers to a **common** subgraph built only from primitive MIR ops so each
21//! backend schedules the same math through its usual fusion/GEMM/elementwise paths.
22//!
23//! Native kernels are never removed from backends; common lowering is additive.
24
25use crate::env;
26use crate::op::OpKind;
27
28pub mod splat_common;
29
30/// When to use native backend kernels vs the shared IR common body.
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
32pub enum KernelDispatchPolicy {
33 /// Native thunk when `OpKind` is in the backend `supported_ops`; else common IR lower.
34 #[default]
35 PreferNative,
36 /// Always lower registered logical kernels to common IR (parity / minimal backends).
37 ForceCommon,
38 /// Never common-lower; legalization must succeed with native ops only.
39 ForceNative,
40}
41
42impl KernelDispatchPolicy {
43 pub fn from_env() -> Self {
44 let v = env::var("KERNEL_DISPATCH").or_else(|| env::var("RLX_KERNEL_DISPATCH"));
45 match v.as_deref() {
46 Some("common") | Some("force_common") | Some("ForceCommon") => Self::ForceCommon,
47 Some("native") | Some("force_native") | Some("ForceNative") => Self::ForceNative,
48 _ => Self::PreferNative,
49 }
50 }
51}
52
53/// Registered logical kernel: native [`OpKind`] plus optional common lower pass name.
54#[derive(Debug, Clone, Copy)]
55pub struct LogicalKernelEntry {
56 pub kind: OpKind,
57 /// Human-readable id (logging / docs).
58 pub name: &'static str,
59}
60
61/// Logical kernels that have a registered common IR body in `rlx-fusion`.
62pub fn registered_logical_kernels() -> &'static [LogicalKernelEntry] {
63 &[
64 LogicalKernelEntry {
65 kind: OpKind::GroupNorm,
66 name: "group_norm",
67 },
68 LogicalKernelEntry {
69 kind: OpKind::BatchNormInference,
70 name: "batch_norm_inference",
71 },
72 LogicalKernelEntry {
73 kind: OpKind::ResizeNearest2x,
74 name: "resize_nearest_2x",
75 },
76 LogicalKernelEntry {
77 kind: OpKind::GaussianSplatRender,
78 name: "gaussian_splat_render",
79 },
80 LogicalKernelEntry {
81 kind: OpKind::GaussianSplatRenderBackward,
82 name: "gaussian_splat_render_backward",
83 },
84 ]
85}
86
87/// Per-compile overrides on top of [`KernelDispatchPolicy`].
88#[derive(Debug, Clone, Copy, Default)]
89pub struct KernelDispatchConfig {
90 pub policy: KernelDispatchPolicy,
91 /// Always common-lower these kinds (e.g. splat on CPU while keeping native matmul).
92 pub force_common_kinds: &'static [OpKind],
93 /// Never common-lower these kinds (overrides `ForceCommon` for listed kinds).
94 pub force_native_kinds: &'static [OpKind],
95}
96
97impl KernelDispatchConfig {
98 pub fn new(policy: KernelDispatchPolicy) -> Self {
99 Self {
100 policy,
101 ..Self::default()
102 }
103 }
104
105 pub fn from_env() -> Self {
106 Self::new(KernelDispatchPolicy::from_env())
107 }
108}
109
110/// Whether `kind` should be common-lowered for this backend claim set and config.
111pub fn should_lower_to_common(
112 kind: OpKind,
113 supported: &[OpKind],
114 config: KernelDispatchConfig,
115) -> bool {
116 if !registered_logical_kernels().iter().any(|e| e.kind == kind) {
117 return false;
118 }
119 if config.force_native_kinds.contains(&kind) {
120 return false;
121 }
122 if config.force_common_kinds.contains(&kind) {
123 return true;
124 }
125 match config.policy {
126 KernelDispatchPolicy::ForceCommon => true,
127 KernelDispatchPolicy::ForceNative => false,
128 KernelDispatchPolicy::PreferNative => !supported.is_empty() && !supported.contains(&kind),
129 }
130}
131
132/// Op kinds that appear in the graph and may need common lowering.
133pub fn logical_kinds_in_graph(
134 graph: &crate::Graph,
135 supported: &[OpKind],
136 config: KernelDispatchConfig,
137) -> Vec<OpKind> {
138 let mut kinds = Vec::new();
139 for node in graph.nodes() {
140 let k = node.op.kind();
141 if should_lower_to_common(k, supported, config) && !kinds.contains(&k) {
142 kinds.push(k);
143 }
144 }
145 kinds
146}