rlx_ir/logical_kernel.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15//! One logical kernel, many backends — dispatch policy and registry.
16//!
17//! A **logical kernel** is a single [`OpKind`] (e.g. [`OpKind::GaussianSplatRender`]) with a
18//! documented semantic contract. Backends may provide a **native** implementation (fast path:
19//! custom thunk, MSL, MPS, etc.). When native is unavailable or [`KernelDispatchPolicy::ForceCommon`]
20//! is set, the compiler lowers to a **common** subgraph built only from primitive MIR ops so each
21//! backend schedules the same math through its usual fusion/GEMM/elementwise paths.
22//!
23//! Native kernels are never removed from backends; common lowering is additive.
24
25use crate::env;
26use crate::op::OpKind;
27
28pub mod splat_common;
29
/// When to use native backend kernels vs the shared IR common body.
///
/// The policy only ever affects kinds listed in [`registered_logical_kernels`]. The per-kind
/// override lists in [`KernelDispatchConfig`] take precedence over it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KernelDispatchPolicy {
    /// Native thunk when the backend lists the `OpKind` in its `supported_ops`; otherwise
    /// common-lower it.
    ///
    /// An empty `supported_ops` list never triggers common lowering under this policy.
    #[default]
    PreferNative,
    /// Lower registered logical kernels to common IR (parity / minimal backends), except kinds
    /// listed in `KernelDispatchConfig::force_native_kinds`.
    ForceCommon,
    /// Never common-lower via the policy; legalization must succeed with native ops only.
    /// Kinds in `KernelDispatchConfig::force_common_kinds` are still common-lowered.
    ForceNative,
}
41
42impl KernelDispatchPolicy {
43 pub fn from_env() -> Self {
44 let v = env::var("KERNEL_DISPATCH").or_else(|| env::var("RLX_KERNEL_DISPATCH"));
45 match v.as_deref() {
46 Some("common") | Some("force_common") | Some("ForceCommon") => Self::ForceCommon,
47 Some("native") | Some("force_native") | Some("ForceNative") => Self::ForceNative,
48 _ => Self::PreferNative,
49 }
50 }
51}
52
53/// Registered logical kernel: native [`OpKind`] plus optional common lower pass name.
54#[derive(Debug, Clone, Copy)]
55pub struct LogicalKernelEntry {
56 pub kind: OpKind,
57 /// Human-readable id (logging / docs).
58 pub name: &'static str,
59}
60
61/// Logical kernels that have a registered common IR body in `rlx-fusion`.
62pub fn registered_logical_kernels() -> &'static [LogicalKernelEntry] {
63 &[
64 LogicalKernelEntry {
65 kind: OpKind::GroupNorm,
66 name: "group_norm",
67 },
68 LogicalKernelEntry {
69 kind: OpKind::ResizeNearest2x,
70 name: "resize_nearest_2x",
71 },
72 LogicalKernelEntry {
73 kind: OpKind::GaussianSplatRender,
74 name: "gaussian_splat_render",
75 },
76 LogicalKernelEntry {
77 kind: OpKind::GaussianSplatRenderBackward,
78 name: "gaussian_splat_render_backward",
79 },
80 ]
81}
82
83/// Per-compile overrides on top of [`KernelDispatchPolicy`].
84#[derive(Debug, Clone, Copy, Default)]
85pub struct KernelDispatchConfig {
86 pub policy: KernelDispatchPolicy,
87 /// Always common-lower these kinds (e.g. splat on CPU while keeping native matmul).
88 pub force_common_kinds: &'static [OpKind],
89 /// Never common-lower these kinds (overrides `ForceCommon` for listed kinds).
90 pub force_native_kinds: &'static [OpKind],
91}
92
93impl KernelDispatchConfig {
94 pub fn new(policy: KernelDispatchPolicy) -> Self {
95 Self {
96 policy,
97 ..Self::default()
98 }
99 }
100
101 pub fn from_env() -> Self {
102 Self::new(KernelDispatchPolicy::from_env())
103 }
104}
105
106/// Whether `kind` should be common-lowered for this backend claim set and config.
107pub fn should_lower_to_common(
108 kind: OpKind,
109 supported: &[OpKind],
110 config: KernelDispatchConfig,
111) -> bool {
112 if !registered_logical_kernels().iter().any(|e| e.kind == kind) {
113 return false;
114 }
115 if config.force_native_kinds.contains(&kind) {
116 return false;
117 }
118 if config.force_common_kinds.contains(&kind) {
119 return true;
120 }
121 match config.policy {
122 KernelDispatchPolicy::ForceCommon => true,
123 KernelDispatchPolicy::ForceNative => false,
124 KernelDispatchPolicy::PreferNative => !supported.is_empty() && !supported.contains(&kind),
125 }
126}
127
128/// Op kinds that appear in the graph and may need common lowering.
129pub fn logical_kinds_in_graph(
130 graph: &crate::Graph,
131 supported: &[OpKind],
132 config: KernelDispatchConfig,
133) -> Vec<OpKind> {
134 let mut kinds = Vec::new();
135 for node in graph.nodes() {
136 let k = node.op.kind();
137 if should_lower_to_common(k, supported, config) && !kinds.contains(&k) {
138 kinds.push(k);
139 }
140 }
141 kinds
142}