1use std::path::Path;
19
20use rlx_flow::CompileProfile;
21use rlx_flow::{
22 BuiltModel, FusionTargetKind, MixedPrecisionKind, ModelExecutionConfig, PrecisionKind,
23};
24use rlx_ir::logical_kernel::KernelDispatchConfig;
25use rlx_opt::{FusionOptions, FusionTarget, PrecisionPolicy};
26use rlx_runtime::Device;
27use rlx_runtime::{CompileOptions, ModelCompilePipeline, Precision, Session, stages};
28
29use crate::weight_loader::WeightLoader;
30
31pub struct WeightLoaderSource<'a>(pub &'a mut dyn WeightLoader);
33
34impl rlx_flow::WeightSource for WeightLoaderSource<'_> {
35 fn take(&mut self, key: &str, transpose: bool) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
36 if transpose {
37 self.0.take_transposed(key)
38 } else {
39 self.0.take(key)
40 }
41 }
42}
43
44pub fn load_compile_profile(path: &Path, default: CompileProfile) -> CompileProfile {
46 CompileProfile::from_toml_path(path).unwrap_or(default)
47}
48
49pub fn profile_near_weights(
51 weights: &Path,
52 profile_file: &str,
53 default: CompileProfile,
54) -> CompileProfile {
55 let dir = weights.parent().unwrap_or_else(|| Path::new("."));
56 load_compile_profile(&dir.join(profile_file), default)
57}
58
59pub fn apply_compile_profile(profile: &CompileProfile, opts: &mut CompileOptions) {
61 opts.dce = profile.passes.dce;
62 opts.constant_folding = profile.passes.constant_folding;
63 opts.verbose = profile.passes.verbose;
64 opts.assert_fusion_clean = profile.fusion.assert_clean;
65 opts.fusion_opts = FusionOptions {
66 skip_fusion: profile.fusion.skip,
67 unfuse_elementwise_regions: profile.backend.metal.unfuse_regions
68 || profile.backend.cpu.unfuse_regions,
69 ..FusionOptions::default()
70 };
71 if let Some(target) = fusion_target_from_profile(profile.fusion.target) {
72 opts.fusion_target = Some(target);
73 }
74 opts.precision = match profile.precision.compute {
75 PrecisionKind::F32 => Precision::F32,
76 PrecisionKind::F16 => Precision::F16,
77 PrecisionKind::Bf16 => Precision::F16, };
79 opts.policy = match profile.precision.mixed {
80 MixedPrecisionKind::None => None,
81 MixedPrecisionKind::Auto => Some(PrecisionPolicy::AutoMixed),
82 };
83}
84
85pub fn compile_options_dynamic(binding: rlx_ir::DimBinding) -> CompileOptions {
87 CompileOptions::new().dim_binding(binding)
88}
89
90pub fn compile_options_from_profile(
92 profile: &CompileProfile,
93 device: Device,
94 kernel_dispatch: KernelDispatchConfig,
95) -> CompileOptions {
96 let mut opts = CompileOptions::new();
97 apply_compile_profile(profile, &mut opts);
98 opts.kernel_dispatch = kernel_dispatch;
99 if opts.fusion_target.is_none() {
100 opts.fusion_target = Some(stages::fusion_target_for(device));
101 }
102 opts
103}
104
105pub fn compile_options_for_profile(profile: &CompileProfile, device: Device) -> CompileOptions {
107 compile_options_from_profile(profile, device, KernelDispatchConfig::default())
108}
109
110pub fn compile_options_sam_encoder(device: Device) -> CompileOptions {
112 compile_options_for_profile(&CompileProfile::sam_encoder(), device)
113}
114
115pub fn compile_options_sam3(device: Device) -> CompileOptions {
117 compile_options_for_profile(&CompileProfile::sam3(), device)
118}
119
120pub fn compile_options_sam2_memory_attention(device: Device) -> CompileOptions {
122 compile_options_for_profile(&CompileProfile::sam2_memory_attention(), device)
123}
124
125pub fn compile_graph_with_profile(
127 device: Device,
128 graph: rlx_ir::Graph,
129 profile: &CompileProfile,
130) -> anyhow::Result<rlx_runtime::CompiledGraph> {
131 use rlx_runtime::Session;
132 let opts = compile_options_for_profile(profile, device);
133 Ok(Session::new(device).compile_with(graph, &opts))
134}
135
136pub fn compile_graph_sam(
138 device: Device,
139 graph: rlx_ir::Graph,
140) -> anyhow::Result<rlx_runtime::CompiledGraph> {
141 compile_graph_with_profile(device, graph, &CompileProfile::sam_encoder())
142}
143
144pub fn compile_graph_encoder(
146 device: Device,
147 graph: rlx_ir::Graph,
148) -> anyhow::Result<rlx_runtime::CompiledGraph> {
149 compile_graph_with_profile(device, graph, &CompileProfile::encoder())
150}
151
152pub fn compile_graph_qwen3_prefill(
154 device: Device,
155 graph: rlx_ir::Graph,
156) -> anyhow::Result<rlx_runtime::CompiledGraph> {
157 compile_graph_with_profile(device, graph, &CompileProfile::qwen3_prefill())
158}
159
160pub fn compile_graph_qwen3_decode(
162 device: Device,
163 graph: rlx_ir::Graph,
164) -> anyhow::Result<rlx_runtime::CompiledGraph> {
165 compile_graph_with_profile(device, graph, &CompileProfile::qwen3_decode())
166}
167
168pub fn compile_graph_qwen35_prefill(
170 device: Device,
171 graph: rlx_ir::Graph,
172) -> anyhow::Result<rlx_runtime::CompiledGraph> {
173 compile_graph_with_profile(device, graph, &CompileProfile::qwen35_prefill())
174}
175
176pub fn compile_graph_qwen35_decode(
178 device: Device,
179 graph: rlx_ir::Graph,
180) -> anyhow::Result<rlx_runtime::CompiledGraph> {
181 compile_graph_with_profile(device, graph, &CompileProfile::qwen35_decode())
182}
183
184pub fn compile_graph_gemma_prefill(
186 device: Device,
187 graph: rlx_ir::Graph,
188) -> anyhow::Result<rlx_runtime::CompiledGraph> {
189 compile_graph_with_profile(device, graph, &CompileProfile::gemma_prefill())
190}
191
192pub fn compile_graph_gemma_decode(
194 device: Device,
195 graph: rlx_ir::Graph,
196) -> anyhow::Result<rlx_runtime::CompiledGraph> {
197 compile_graph_with_profile(device, graph, &CompileProfile::gemma_decode())
198}
199
200pub fn compile_graph_llama32_prefill(
202 device: Device,
203 graph: rlx_ir::Graph,
204) -> anyhow::Result<rlx_runtime::CompiledGraph> {
205 compile_graph_with_profile(device, graph, &CompileProfile::llama32_prefill())
206}
207
208pub fn compile_graph_llama32_decode(
210 device: Device,
211 graph: rlx_ir::Graph,
212) -> anyhow::Result<rlx_runtime::CompiledGraph> {
213 compile_graph_with_profile(device, graph, &CompileProfile::llama32_decode())
214}
215
216pub fn compile_graph_legacy(
218 device: Device,
219 graph: rlx_ir::Graph,
220) -> anyhow::Result<rlx_runtime::CompiledGraph> {
221 use rlx_runtime::{CompileOptions, Session};
222 Ok(Session::new(device).compile_with(graph, &CompileOptions::new()))
223}
224
225pub fn compile_hir_sam(
227 device: Device,
228 hir: rlx_ir::hir::HirModule,
229) -> anyhow::Result<rlx_runtime::CompiledGraph> {
230 compile_hir_with_profile(device, hir, &CompileProfile::sam_encoder())
231}
232
233pub fn compile_hir_sam3(
235 device: Device,
236 hir: rlx_ir::hir::HirModule,
237) -> anyhow::Result<rlx_runtime::CompiledGraph> {
238 compile_hir_with_profile(device, hir, &CompileProfile::sam3())
239}
240
241pub fn compile_hir_with_profile(
243 device: Device,
244 hir: rlx_ir::hir::HirModule,
245 profile: &CompileProfile,
246) -> anyhow::Result<rlx_runtime::CompiledGraph> {
247 use rlx_runtime::Session;
248 let opts = compile_options_for_profile(profile, device);
249 Ok(Session::new(device).compile_hir_with(hir, &opts)?)
250}
251
252pub fn compile_options_for(config: &ModelExecutionConfig) -> CompileOptions {
254 compile_options_from_profile(
255 &config.compile_profile(),
256 Device::Cpu,
257 config.component().kernel_dispatch,
258 )
259 .dim_binding(config.dim_binding())
260}
261
262pub fn compile_options_for_device(config: &ModelExecutionConfig, device: Device) -> CompileOptions {
264 compile_options_from_profile(
265 &config.compile_profile(),
266 device,
267 config.component().kernel_dispatch,
268 )
269 .dim_binding(config.dim_binding())
270}
271
272pub fn compile_built_with_config(
274 pipeline: &mut ModelCompilePipeline,
275 built: BuiltModel,
276 config: &ModelExecutionConfig,
277 options: &CompileOptions,
278) -> anyhow::Result<rlx_runtime::CompiledGraph> {
279 let key = config.cache_key();
280 let binding = config.dim_binding();
281 let device = pipeline.device();
282 let (hir, params) = built.into_parts()?;
283 if !pipeline.contains(key) {
286 pipeline.get_or_compile(key, &binding, || hir.clone(), options)?;
287 }
288 let mut compiled = if device == Device::Cpu {
289 pipeline
290 .get_or_compile(key, &binding, || hir.clone(), options)?
291 .clone()
292 } else {
293 Session::new(device).compile_hir_with(hir, options)?
294 };
295 for (name, data) in params {
296 compiled.set_param(&name, &data);
297 }
298 Ok(compiled)
299}
300
301fn fusion_target_from_profile(kind: FusionTargetKind) -> Option<FusionTarget> {
302 match kind {
303 FusionTargetKind::Auto => None,
304 FusionTargetKind::Cpu => Some(FusionTarget::Cpu),
305 FusionTargetKind::Metal => Some(FusionTarget::Metal),
306 FusionTargetKind::Mlx => Some(FusionTarget::Mlx),
307 FusionTargetKind::Cuda => Some(FusionTarget::Cuda),
308 FusionTargetKind::Rocm => Some(FusionTarget::Rocm),
309 FusionTargetKind::Wgpu => Some(FusionTarget::Wgpu),
310 FusionTargetKind::Tpu => Some(FusionTarget::Tpu),
311 }
312}