rlx_runtime/options.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Unified compile options.
17//!
18//! Replaces the historical mix of `compile()`, `compile_with_precision()`,
19//! `compile_with_options()` with a single `Backend::compile(graph, &options)`.
20//! New compile-time knobs can be added to `CompileOptions` without
21//! changing the trait — backends just read what they care about.
22//!
23//! Builder-pattern API for ergonomics:
24//!
25//! ```rust,ignore
26//! let opts = CompileOptions::new()
27//! .precision(Precision::F16)
28//! .policy(PrecisionPolicy::AutoMixed)
29//! .with_dce(true)
30//! .with_constant_folding(true);
31//! ```
32
33use crate::Precision;
34use rlx_ir::OpKind;
35use rlx_ir::logical_kernel::{KernelDispatchConfig, KernelDispatchPolicy};
36use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
37use std::collections::HashMap;
38
39/// All knobs the compile pipeline understands.
40/// Add new fields here rather than introducing new compile entry points.
41#[derive(Debug, Clone)]
42pub struct CompileOptions {
43 /// Target numeric precision for execution. Default: F32.
44 pub precision: Precision,
45 /// Optional per-op precision policy (mixed precision rewrite).
46 pub policy: Option<PrecisionPolicy>,
47 /// Run dead-code elimination as part of compile. Default: true.
48 pub dce: bool,
49 /// Run constant folding. Default: true (cheap, only helps).
50 pub constant_folding: bool,
51 /// Verbose pass logging. Equivalent to `RLX_VERBOSE=1` or
52 /// [`rlx_ir::env::set("RLX_VERBOSE", "1")`].
53 pub verbose: bool,
54 /// Override fusion pipeline target (default: inferred from device).
55 pub fusion_target: Option<FusionTarget>,
56 /// Per-target fusion toggles (Metal env overrides, skip fusion, …).
57 pub fusion_opts: FusionOptions,
58 /// Arena alignment for buffer planning. Default: 64.
59 pub arena_alignment: usize,
60 /// Panic at compile time if fusion diagnostics report missed patterns.
61 pub assert_fusion_clean: bool,
62 /// Backend op claim set for backend-aware fusion + post-fusion
63 /// legalization. Set by [`Backend::compile`] implementations.
64 pub supported_ops: Option<&'static [OpKind]>,
65 /// When set, specialize symbolic dims before backend lowering.
66 pub dim_binding: Option<rlx_ir::DimBinding>,
67 /// Bake fixed param tensors into constants before DCE / constant folding.
68 pub param_bindings: Option<HashMap<String, Vec<f32>>>,
69 /// Native vs common IR lowering ([`KernelDispatchConfig`], `RLX_KERNEL_DISPATCH=common`).
70 pub kernel_dispatch: KernelDispatchConfig,
71}
72
73impl Default for CompileOptions {
74 fn default() -> Self {
75 Self {
76 precision: Precision::F32,
77 policy: None,
78 dce: true,
79 constant_folding: true,
80 verbose: false,
81 fusion_target: None,
82 fusion_opts: FusionOptions::default(),
83 arena_alignment: 64,
84 assert_fusion_clean: false,
85 supported_ops: None,
86 dim_binding: None,
87 param_bindings: None,
88 kernel_dispatch: KernelDispatchConfig::from_env(),
89 }
90 }
91}
92
93impl CompileOptions {
94 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn precision(mut self, p: Precision) -> Self {
99 self.precision = p;
100 self
101 }
102 pub fn policy(mut self, p: PrecisionPolicy) -> Self {
103 self.policy = Some(p);
104 self
105 }
106 pub fn no_policy(mut self) -> Self {
107 self.policy = None;
108 self
109 }
110 pub fn with_dce(mut self, on: bool) -> Self {
111 self.dce = on;
112 self
113 }
114 pub fn with_constant_folding(mut self, on: bool) -> Self {
115 self.constant_folding = on;
116 self
117 }
118 pub fn with_verbose(mut self, on: bool) -> Self {
119 self.verbose = on;
120 self
121 }
122 pub fn fusion_target(mut self, target: FusionTarget) -> Self {
123 self.fusion_target = Some(target);
124 self
125 }
126 pub fn fusion_opts(mut self, opts: FusionOptions) -> Self {
127 self.fusion_opts = opts;
128 self
129 }
130 pub fn arena_alignment(mut self, bytes: usize) -> Self {
131 self.arena_alignment = bytes;
132 self
133 }
134 pub fn supported_ops(mut self, ops: &'static [OpKind]) -> Self {
135 self.supported_ops = Some(ops);
136 self
137 }
138 pub fn assert_fusion_clean(mut self, on: bool) -> Self {
139 self.assert_fusion_clean = on;
140 self
141 }
142 pub fn dim_binding(mut self, binding: rlx_ir::DimBinding) -> Self {
143 self.dim_binding = Some(binding);
144 self
145 }
146 pub fn param_bindings(mut self, bindings: HashMap<String, Vec<f32>>) -> Self {
147 self.param_bindings = Some(bindings);
148 self
149 }
150 pub fn kernel_dispatch(mut self, policy: KernelDispatchPolicy) -> Self {
151 self.kernel_dispatch.policy = policy;
152 self
153 }
154
155 pub fn kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
156 self.kernel_dispatch = config;
157 self
158 }
159
160 /// Force listed logical kernels to use common IR even when native is in `supported_ops`.
161 pub fn force_common_kinds(mut self, kinds: &'static [OpKind]) -> Self {
162 self.kernel_dispatch.force_common_kinds = kinds;
163 self
164 }
165
166 /// Keep listed logical kernels native even under `ForceCommon` / missing from `supported_ops`.
167 pub fn force_native_kinds(mut self, kinds: &'static [OpKind]) -> Self {
168 self.kernel_dispatch.force_native_kinds = kinds;
169 self
170 }
171}