Skip to main content

rlx_qwen35/
execution.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Qwen3.5 execution variants — [`ModelExecutionConfig`] drives cache keys and [`DimBinding`].
17
18use std::path::PathBuf;
19
20use anyhow::{Context, Result};
21use rlx_flow::{BuiltModel, ModelExecutionConfig};
22use rlx_ir::hir::HirModule;
23use rlx_ir::{BindingManifest, CompilationMode};
24use rlx_runtime::{AotCache, CompileOptions, CompiledGraph, ModelCompilePipeline};
25
26/// Component compile pipeline + optional on-disk AOT LIR ([`CompilationMode::Aot`]).
27pub struct Qwen35CompileCache {
28    pub pipeline: ModelCompilePipeline,
29    aot: Option<AotCache>,
30}
31
32impl Qwen35CompileCache {
33    pub fn new(device: rlx_runtime::Device, capacity: usize) -> Self {
34        Self {
35            pipeline: ModelCompilePipeline::with_capacity(device, capacity),
36            aot: None,
37        }
38    }
39
40    /// Enable disk-backed LIR for [`CompilationMode::Aot`] variants.
41    pub fn with_aot(
42        device: rlx_runtime::Device,
43        capacity: usize,
44        root: impl Into<PathBuf>,
45    ) -> Self {
46        Self {
47            pipeline: ModelCompilePipeline::with_capacity(device, capacity),
48            aot: Some(AotCache::new(root)),
49        }
50    }
51
52    pub fn device(&self) -> rlx_runtime::Device {
53        self.pipeline.device()
54    }
55
56    pub fn contains(&self, config: &ModelExecutionConfig) -> bool {
57        self.pipeline.contains(config.cache_key())
58    }
59
60    pub fn len(&self) -> usize {
61        self.pipeline.len()
62    }
63
64    pub fn is_empty(&self) -> bool {
65        self.pipeline.is_empty()
66    }
67
68    pub fn has_template(&self) -> bool {
69        self.pipeline.has_template()
70    }
71
72    /// Binding layout for a variant (requires template built for that HIR family).
73    pub fn binding_manifest_for(
74        &self,
75        config: &ModelExecutionConfig,
76        options: &CompileOptions,
77    ) -> BindingManifest {
78        self.pipeline
79            .binding_manifest_for_component(config.component(), options)
80    }
81
82    /// Compile a tier-0 [`BuiltModel`] through this pipeline (profile + variant key).
83    pub fn compile_built(
84        &mut self,
85        built: BuiltModel,
86        config: &ModelExecutionConfig,
87        options: &CompileOptions,
88    ) -> Result<CompiledGraph> {
89        if config.component().compilation_mode == CompilationMode::Aot {
90            return self.compile_built_aot(built, config, options);
91        }
92        rlx_core::flow_bridge::compile_built_with_config(&mut self.pipeline, built, config, options)
93    }
94
95    /// [`CompilationMode::Aot`] — disk LIR via [`AotCache`] + pipeline specialize.
96    pub fn compile_built_aot(
97        &mut self,
98        built: BuiltModel,
99        config: &ModelExecutionConfig,
100        options: &CompileOptions,
101    ) -> Result<CompiledGraph> {
102        let aot = self
103            .aot
104            .as_ref()
105            .context("CompilationMode::Aot requires Qwen35CompileCache::with_aot(root)")?;
106        let key = config.cache_key();
107        let binding = config.dim_binding();
108        let disk_base = config.component().aot_disk_base();
109        let (hir, params) = built.into_parts()?;
110        let mut compiled = self
111            .pipeline
112            .get_or_specialize_aot(aot, &disk_base, key, &binding, || hir, options)
113            .map_err(|e| anyhow::anyhow!("{e}"))?
114            .clone();
115        for (name, data) in params {
116            compiled.set_param(&name, &data);
117        }
118        Ok(compiled)
119    }
120}
121
122/// Prefill-cache graph (symbolic `sym::SEQ`).
123pub fn prefill_config(batch: usize, seq: usize) -> ModelExecutionConfig {
124    ModelExecutionConfig::qwen35_prefill(batch, seq)
125}
126
127/// VLM hidden-state prefill-cache (same dim binding as text prefill).
128pub fn hidden_prefill_config(batch: usize, seq: usize) -> ModelExecutionConfig {
129    ModelExecutionConfig::qwen35_prefill(batch, seq)
130}
131
132/// Decode step (symbolic `sym::PAST_SEQ`, new tokens = 1).
133pub fn decode_config(batch: usize, past_seq: usize) -> ModelExecutionConfig {
134    ModelExecutionConfig::qwen35_decode(batch, past_seq)
135}
136
137/// Cache bucket key — use [`ModelExecutionConfig::cache_key`] (full component fingerprint).
138#[inline]
139pub fn cache_key_for_config(config: &ModelExecutionConfig) -> u64 {
140    config.cache_key()
141}
142
143/// Compile-once / specialize-at-runtime; upload params on first hit for this variant key.
144pub fn get_or_specialize_hir<'a, F>(
145    cache: &'a mut Qwen35CompileCache,
146    config: &ModelExecutionConfig,
147    build_hir: F,
148    on_first_hit: impl FnOnce(&mut CompiledGraph) -> Result<()>,
149) -> Result<&'a mut CompiledGraph>
150where
151    F: FnOnce() -> HirModule,
152{
153    get_or_specialize_hir_with_options(
154        cache,
155        config,
156        build_hir,
157        &rlx_core::flow_bridge::compile_options_for_device(config, cache.device()),
158        on_first_hit,
159    )
160}
161
162/// Like [`get_or_specialize_hir`] with explicit compile options (profile, dispatch, …).
163pub fn get_or_specialize_hir_with_options<'a, F>(
164    cache: &'a mut Qwen35CompileCache,
165    config: &ModelExecutionConfig,
166    build_hir: F,
167    options: &CompileOptions,
168    on_first_hit: impl FnOnce(&mut CompiledGraph) -> Result<()>,
169) -> Result<&'a mut CompiledGraph>
170where
171    F: FnOnce() -> HirModule,
172{
173    let key = config.cache_key();
174    let binding = config.dim_binding();
175    let first = !cache.pipeline.contains(key);
176
177    let compiled = match config.component().compilation_mode {
178        CompilationMode::Aot => {
179            let aot = cache.aot.as_ref().context(
180                "CompilationMode::Aot requires Qwen35CompileCache::with_aot(root) \
181                 (disk LIR under that directory)",
182            )?;
183            let disk_base = config.component().aot_disk_base();
184            cache
185                .pipeline
186                .get_or_specialize_aot(aot, &disk_base, key, &binding, build_hir, options)
187                .map_err(|e| anyhow::anyhow!("{e}"))?
188        }
189        CompilationMode::Eager | CompilationMode::Lazy => cache
190            .pipeline
191            .get_or_compile(key, &binding, build_hir, options)
192            .map_err(|e| anyhow::anyhow!("{e}"))?,
193    };
194
195    if first {
196        on_first_hit(compiled)?;
197    }
198    Ok(compiled)
199}
200
201/// Template → specialize → compile; returns compiled graph + binding manifest.
202pub fn get_or_specialize_component<'a, F>(
203    cache: &'a mut Qwen35CompileCache,
204    config: &ModelExecutionConfig,
205    build_hir: F,
206    options: &CompileOptions,
207    on_first_hit: impl FnOnce(&mut CompiledGraph) -> Result<()>,
208) -> Result<(&'a mut CompiledGraph, BindingManifest)>
209where
210    F: FnOnce() -> HirModule,
211{
212    let manifest = cache.binding_manifest_for(config, options);
213    let compiled =
214        get_or_specialize_hir_with_options(cache, config, build_hir, options, on_first_hit)?;
215    Ok((compiled, manifest))
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{CompilationMode, DType, HirModule, Shape};

    /// Smallest useful HIR for these tests: input `x` [1, 4], param `w` [4, 2],
    /// one `linear` op producing a [1, 2] output.
    fn tiny_hir() -> HirModule {
        let mut hir = HirModule::new("qwen35_aot");
        let mut gb = HirMut::new(&mut hir);
        let x = gb.input("x", Shape::new(&[1, 4], DType::F32));
        let w = gb.param("w", Shape::new(&[4, 2], DType::F32));
        let y = hir.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
        hir.set_outputs(vec![y]);
        hir
    }

    /// A `BuiltModel` compiled through the cache runs and yields a 2-element output.
    #[test]
    fn compile_built_via_pipeline() {
        use rlx_flow::BuiltModel;
        use std::collections::HashMap;

        let mut cache = Qwen35CompileCache::new(rlx_runtime::Device::Cpu, 4);
        let config = prefill_config(1, 4);
        let built = BuiltModel::from_hir(tiny_hir(), HashMap::new())
            .unwrap()
            .with_execution_config(&config);
        let opts =
            rlx_core::flow_bridge::compile_options_for_device(&config, rlx_runtime::Device::Cpu);
        let mut compiled = cache.compile_built(built, &config, &opts).unwrap();
        compiled.set_param("w", &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        let outs = compiled.run(&[("x", &[1.0f32, 2.0, 3.0, 4.0])]);
        assert_eq!(outs[0].len(), 2);
    }

    /// Different batch sizes must map to different cache keys.
    #[test]
    fn unified_cache_key_differs_batch() {
        let a = prefill_config(1, 8).cache_key();
        let b = prefill_config(2, 8).cache_key();
        assert_ne!(a, b);
        assert_eq!(cache_key_for_config(&prefill_config(1, 8)), a);
    }

    /// After the first specialize the pipeline reports a template, and the
    /// binding manifest can be queried.
    #[test]
    fn pipeline_has_template_after_specialize() {
        let mut cache = Qwen35CompileCache::new(rlx_runtime::Device::Cpu, 4);
        let config = prefill_config(1, 4);
        let opts = rlx_core::flow_bridge::compile_options_for_device(&config, cache.device());
        get_or_specialize_hir(&mut cache, &config, tiny_hir, |_| Ok(())).unwrap();
        assert!(cache.has_template());
        let _manifest = cache.binding_manifest_for(&config, &opts);
    }

    /// AOT mode must leave an LIR artifact for the component under the cache root.
    #[test]
    fn aot_mode_writes_lir_disk() {
        let dir = std::env::temp_dir().join(format!("qwen35_aot_{}", std::process::id()));
        // A previous run may have reused this PID and left files behind; clear them
        // so a stale artifact cannot satisfy the check below.
        std::fs::remove_dir_all(&dir).ok();

        let mut cache = Qwen35CompileCache::with_aot(rlx_runtime::Device::Cpu, 4, &dir);
        let config = prefill_config(1, 4).with_compilation_mode(CompilationMode::Aot);
        // Keep only success/failure so the temp dir is cleaned up before any panic.
        let specialized =
            get_or_specialize_hir(&mut cache, &config, tiny_hir, |_| Ok(())).map(|_| ());

        let disk = config.component().aot_disk_base();
        let wrote_lir = dir.join(format!("{disk}__0.lir.json")).exists() || {
            std::fs::read_dir(&dir)
                .ok()
                .map(|rd| {
                    rd.flatten()
                        .any(|e| e.file_name().to_string_lossy().contains(&disk))
                })
                .unwrap_or(false)
        };

        // Clean up first: the checks below panic on failure and would otherwise
        // leak the directory.
        std::fs::remove_dir_all(&dir).ok();

        specialized.unwrap();
        assert!(wrote_lir, "expected an AOT LIR artifact under the cache root");
    }
}