1use crate::Precision;
34use rlx_ir::OpKind;
35use rlx_ir::logical_kernel::{KernelDispatchConfig, KernelDispatchPolicy};
36use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
37use std::collections::HashMap;
38
39#[derive(Debug, Clone)]
42pub struct CompileOptions {
43 pub precision: Precision,
45 pub policy: Option<PrecisionPolicy>,
47 pub rng: rlx_ir::RngOptions,
49 pub dce: bool,
51 pub constant_folding: bool,
53 pub verbose: bool,
56 pub fusion_target: Option<FusionTarget>,
58 pub fusion_opts: FusionOptions,
60 pub arena_alignment: usize,
62 pub assert_fusion_clean: bool,
64 pub supported_ops: Option<&'static [OpKind]>,
67 pub dim_binding: Option<rlx_ir::DimBinding>,
69 pub param_bindings: Option<HashMap<String, Vec<f32>>>,
71 pub kernel_dispatch: KernelDispatchConfig,
73}
74
75impl Default for CompileOptions {
76 fn default() -> Self {
77 Self {
78 precision: Precision::F32,
79 policy: None,
80 rng: rlx_ir::RngOptions::default(),
81 dce: true,
82 constant_folding: true,
83 verbose: false,
84 fusion_target: None,
85 fusion_opts: FusionOptions::default(),
86 arena_alignment: 64,
87 assert_fusion_clean: false,
88 supported_ops: None,
89 dim_binding: None,
90 param_bindings: None,
91 kernel_dispatch: KernelDispatchConfig::from_env(),
92 }
93 }
94}
95
96impl CompileOptions {
97 pub fn new() -> Self {
98 Self::default()
99 }
100
101 pub fn precision(mut self, p: Precision) -> Self {
102 self.precision = p;
103 self
104 }
105 pub fn policy(mut self, p: PrecisionPolicy) -> Self {
106 self.policy = Some(p);
107 self
108 }
109 pub fn rng(mut self, rng: rlx_ir::RngOptions) -> Self {
110 self.rng = rng;
111 self
112 }
113 pub fn rng_backend(mut self, backend: rlx_ir::RngBackend) -> Self {
114 self.rng.backend = backend;
115 self
116 }
117 pub fn rng_seed(mut self, seed: u64) -> Self {
118 self.rng.seed = seed;
119 self
120 }
121 pub fn no_policy(mut self) -> Self {
122 self.policy = None;
123 self
124 }
125 pub fn with_dce(mut self, on: bool) -> Self {
126 self.dce = on;
127 self
128 }
129 pub fn with_constant_folding(mut self, on: bool) -> Self {
130 self.constant_folding = on;
131 self
132 }
133 pub fn with_verbose(mut self, on: bool) -> Self {
134 self.verbose = on;
135 self
136 }
137 pub fn fusion_target(mut self, target: FusionTarget) -> Self {
138 self.fusion_target = Some(target);
139 self
140 }
141 pub fn fusion_opts(mut self, opts: FusionOptions) -> Self {
142 self.fusion_opts = opts;
143 self
144 }
145 pub fn arena_alignment(mut self, bytes: usize) -> Self {
146 self.arena_alignment = bytes;
147 self
148 }
149 pub fn supported_ops(mut self, ops: &'static [OpKind]) -> Self {
150 self.supported_ops = Some(ops);
151 self
152 }
153 pub fn assert_fusion_clean(mut self, on: bool) -> Self {
154 self.assert_fusion_clean = on;
155 self
156 }
157 pub fn dim_binding(mut self, binding: rlx_ir::DimBinding) -> Self {
158 self.dim_binding = Some(binding);
159 self
160 }
161 pub fn param_bindings(mut self, bindings: HashMap<String, Vec<f32>>) -> Self {
162 self.param_bindings = Some(bindings);
163 self
164 }
165 pub fn kernel_dispatch(mut self, policy: KernelDispatchPolicy) -> Self {
166 self.kernel_dispatch.policy = policy;
167 self
168 }
169
170 pub fn kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
171 self.kernel_dispatch = config;
172 self
173 }
174
175 pub fn force_common_kinds(mut self, kinds: &'static [OpKind]) -> Self {
177 self.kernel_dispatch.force_common_kinds = kinds;
178 self
179 }
180
181 pub fn force_native_kinds(mut self, kinds: &'static [OpKind]) -> Self {
183 self.kernel_dispatch.force_native_kinds = kinds;
184 self
185 }
186}