Skip to main content

rlx_runtime/
options.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Unified compile options.
17//!
18//! Replaces the historical mix of `compile()`, `compile_with_precision()`,
19//! `compile_with_options()` with a single `Backend::compile(graph, &options)`.
20//! New compile-time knobs can be added to `CompileOptions` without
21//! changing the trait — backends just read what they care about.
22//!
23//! Builder-pattern API for ergonomics:
24//!
25//! ```rust,ignore
26//! let opts = CompileOptions::new()
27//!     .precision(Precision::F16)
28//!     .policy(PrecisionPolicy::AutoMixed)
29//!     .with_dce(true)
30//!     .with_constant_folding(true);
31//! ```
32
33use crate::Precision;
34use rlx_ir::OpKind;
35use rlx_ir::logical_kernel::{KernelDispatchConfig, KernelDispatchPolicy};
36use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
37use std::collections::HashMap;
38
39/// All knobs the compile pipeline understands.
40/// Add new fields here rather than introducing new compile entry points.
41#[derive(Debug, Clone)]
42pub struct CompileOptions {
43    /// Target numeric precision for execution. Default: F32.
44    pub precision: Precision,
45    /// Optional per-op precision policy (mixed precision rewrite).
46    pub policy: Option<PrecisionPolicy>,
47    /// RNG policy for in-graph [`Op::RngNormal`] / [`Op::RngUniform`] nodes.
48    pub rng: rlx_ir::RngOptions,
49    /// Run dead-code elimination as part of compile. Default: true.
50    pub dce: bool,
51    /// Run constant folding. Default: true (cheap, only helps).
52    pub constant_folding: bool,
53    /// Verbose pass logging. Equivalent to `RLX_VERBOSE=1` or
54    /// [`rlx_ir::env::set("RLX_VERBOSE", "1")`].
55    pub verbose: bool,
56    /// Override fusion pipeline target (default: inferred from device).
57    pub fusion_target: Option<FusionTarget>,
58    /// Per-target fusion toggles (Metal env overrides, skip fusion, …).
59    pub fusion_opts: FusionOptions,
60    /// Arena alignment for buffer planning. Default: 64.
61    pub arena_alignment: usize,
62    /// Panic at compile time if fusion diagnostics report missed patterns.
63    pub assert_fusion_clean: bool,
64    /// Backend op claim set for backend-aware fusion + post-fusion
65    /// legalization. Set by [`Backend::compile`] implementations.
66    pub supported_ops: Option<&'static [OpKind]>,
67    /// When set, specialize symbolic dims before backend lowering.
68    pub dim_binding: Option<rlx_ir::DimBinding>,
69    /// Bake fixed param tensors into constants before DCE / constant folding.
70    pub param_bindings: Option<HashMap<String, Vec<f32>>>,
71    /// Native vs common IR lowering ([`KernelDispatchConfig`], `RLX_KERNEL_DISPATCH=common`).
72    pub kernel_dispatch: KernelDispatchConfig,
73}
74
75impl Default for CompileOptions {
76    fn default() -> Self {
77        Self {
78            precision: Precision::F32,
79            policy: None,
80            rng: rlx_ir::RngOptions::default(),
81            dce: true,
82            constant_folding: true,
83            verbose: false,
84            fusion_target: None,
85            fusion_opts: FusionOptions::default(),
86            arena_alignment: 64,
87            assert_fusion_clean: false,
88            supported_ops: None,
89            dim_binding: None,
90            param_bindings: None,
91            kernel_dispatch: KernelDispatchConfig::from_env(),
92        }
93    }
94}
95
96impl CompileOptions {
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    pub fn precision(mut self, p: Precision) -> Self {
102        self.precision = p;
103        self
104    }
105    pub fn policy(mut self, p: PrecisionPolicy) -> Self {
106        self.policy = Some(p);
107        self
108    }
109    pub fn rng(mut self, rng: rlx_ir::RngOptions) -> Self {
110        self.rng = rng;
111        self
112    }
113    pub fn rng_backend(mut self, backend: rlx_ir::RngBackend) -> Self {
114        self.rng.backend = backend;
115        self
116    }
117    pub fn rng_seed(mut self, seed: u64) -> Self {
118        self.rng.seed = seed;
119        self
120    }
121    pub fn no_policy(mut self) -> Self {
122        self.policy = None;
123        self
124    }
125    pub fn with_dce(mut self, on: bool) -> Self {
126        self.dce = on;
127        self
128    }
129    pub fn with_constant_folding(mut self, on: bool) -> Self {
130        self.constant_folding = on;
131        self
132    }
133    pub fn with_verbose(mut self, on: bool) -> Self {
134        self.verbose = on;
135        self
136    }
137    pub fn fusion_target(mut self, target: FusionTarget) -> Self {
138        self.fusion_target = Some(target);
139        self
140    }
141    pub fn fusion_opts(mut self, opts: FusionOptions) -> Self {
142        self.fusion_opts = opts;
143        self
144    }
145    pub fn arena_alignment(mut self, bytes: usize) -> Self {
146        self.arena_alignment = bytes;
147        self
148    }
149    pub fn supported_ops(mut self, ops: &'static [OpKind]) -> Self {
150        self.supported_ops = Some(ops);
151        self
152    }
153    pub fn assert_fusion_clean(mut self, on: bool) -> Self {
154        self.assert_fusion_clean = on;
155        self
156    }
157    pub fn dim_binding(mut self, binding: rlx_ir::DimBinding) -> Self {
158        self.dim_binding = Some(binding);
159        self
160    }
161    pub fn param_bindings(mut self, bindings: HashMap<String, Vec<f32>>) -> Self {
162        self.param_bindings = Some(bindings);
163        self
164    }
165    pub fn kernel_dispatch(mut self, policy: KernelDispatchPolicy) -> Self {
166        self.kernel_dispatch.policy = policy;
167        self
168    }
169
170    pub fn kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
171        self.kernel_dispatch = config;
172        self
173    }
174
175    /// Force listed logical kernels to use common IR even when native is in `supported_ops`.
176    pub fn force_common_kinds(mut self, kinds: &'static [OpKind]) -> Self {
177        self.kernel_dispatch.force_common_kinds = kinds;
178        self
179    }
180
181    /// Keep listed logical kernels native even under `ForceCommon` / missing from `supported_ops`.
182    pub fn force_native_kinds(mut self, kinds: &'static [OpKind]) -> Self {
183        self.kernel_dispatch.force_native_kinds = kinds;
184        self
185    }
186}