1use std::path::PathBuf;
19
20use anyhow::{Context, Result};
21use rlx_flow::{BuiltModel, ModelExecutionConfig};
22use rlx_ir::hir::HirModule;
23use rlx_ir::{BindingManifest, CompilationMode};
24use rlx_runtime::{AotCache, CompileOptions, CompiledGraph, ModelCompilePipeline};
25
26pub struct Qwen35CompileCache {
28 pub pipeline: ModelCompilePipeline,
29 aot: Option<AotCache>,
30}
31
32impl Qwen35CompileCache {
33 pub fn new(device: rlx_runtime::Device, capacity: usize) -> Self {
34 Self {
35 pipeline: ModelCompilePipeline::with_capacity(device, capacity),
36 aot: None,
37 }
38 }
39
40 pub fn with_aot(
42 device: rlx_runtime::Device,
43 capacity: usize,
44 root: impl Into<PathBuf>,
45 ) -> Self {
46 Self {
47 pipeline: ModelCompilePipeline::with_capacity(device, capacity),
48 aot: Some(AotCache::new(root)),
49 }
50 }
51
52 pub fn device(&self) -> rlx_runtime::Device {
53 self.pipeline.device()
54 }
55
56 pub fn contains(&self, config: &ModelExecutionConfig) -> bool {
57 self.pipeline.contains(config.cache_key())
58 }
59
60 pub fn len(&self) -> usize {
61 self.pipeline.len()
62 }
63
64 pub fn is_empty(&self) -> bool {
65 self.pipeline.is_empty()
66 }
67
68 pub fn has_template(&self) -> bool {
69 self.pipeline.has_template()
70 }
71
72 pub fn binding_manifest_for(
74 &self,
75 config: &ModelExecutionConfig,
76 options: &CompileOptions,
77 ) -> BindingManifest {
78 self.pipeline
79 .binding_manifest_for_component(config.component(), options)
80 }
81
82 pub fn compile_built(
84 &mut self,
85 built: BuiltModel,
86 config: &ModelExecutionConfig,
87 options: &CompileOptions,
88 ) -> Result<CompiledGraph> {
89 if config.component().compilation_mode == CompilationMode::Aot {
90 return self.compile_built_aot(built, config, options);
91 }
92 rlx_core::flow_bridge::compile_built_with_config(&mut self.pipeline, built, config, options)
93 }
94
95 pub fn compile_built_aot(
97 &mut self,
98 built: BuiltModel,
99 config: &ModelExecutionConfig,
100 options: &CompileOptions,
101 ) -> Result<CompiledGraph> {
102 let aot = self
103 .aot
104 .as_ref()
105 .context("CompilationMode::Aot requires Qwen35CompileCache::with_aot(root)")?;
106 let key = config.cache_key();
107 let binding = config.dim_binding();
108 let disk_base = config.component().aot_disk_base();
109 let (hir, params) = built.into_parts()?;
110 let mut compiled = self
111 .pipeline
112 .get_or_specialize_aot(aot, &disk_base, key, &binding, || hir, options)
113 .map_err(|e| anyhow::anyhow!("{e}"))?
114 .clone();
115 for (name, data) in params {
116 compiled.set_param(&name, &data);
117 }
118 Ok(compiled)
119 }
120}
121
122pub fn prefill_config(batch: usize, seq: usize) -> ModelExecutionConfig {
124 ModelExecutionConfig::qwen35_prefill(batch, seq)
125}
126
127pub fn hidden_prefill_config(batch: usize, seq: usize) -> ModelExecutionConfig {
129 ModelExecutionConfig::qwen35_prefill(batch, seq)
130}
131
132pub fn decode_config(batch: usize, past_seq: usize) -> ModelExecutionConfig {
134 ModelExecutionConfig::qwen35_decode(batch, past_seq)
135}
136
137#[inline]
139pub fn cache_key_for_config(config: &ModelExecutionConfig) -> u64 {
140 config.cache_key()
141}
142
143pub fn get_or_specialize_hir<'a, F>(
145 cache: &'a mut Qwen35CompileCache,
146 config: &ModelExecutionConfig,
147 build_hir: F,
148 on_first_hit: impl FnOnce(&mut CompiledGraph) -> Result<()>,
149) -> Result<&'a mut CompiledGraph>
150where
151 F: FnOnce() -> HirModule,
152{
153 get_or_specialize_hir_with_options(
154 cache,
155 config,
156 build_hir,
157 &rlx_core::flow_bridge::compile_options_for_device(config, cache.device()),
158 on_first_hit,
159 )
160}
161
162pub fn get_or_specialize_hir_with_options<'a, F>(
164 cache: &'a mut Qwen35CompileCache,
165 config: &ModelExecutionConfig,
166 build_hir: F,
167 options: &CompileOptions,
168 on_first_hit: impl FnOnce(&mut CompiledGraph) -> Result<()>,
169) -> Result<&'a mut CompiledGraph>
170where
171 F: FnOnce() -> HirModule,
172{
173 let key = config.cache_key();
174 let binding = config.dim_binding();
175 let first = !cache.pipeline.contains(key);
176
177 let compiled = match config.component().compilation_mode {
178 CompilationMode::Aot => {
179 let aot = cache.aot.as_ref().context(
180 "CompilationMode::Aot requires Qwen35CompileCache::with_aot(root) \
181 (disk LIR under that directory)",
182 )?;
183 let disk_base = config.component().aot_disk_base();
184 cache
185 .pipeline
186 .get_or_specialize_aot(aot, &disk_base, key, &binding, build_hir, options)
187 .map_err(|e| anyhow::anyhow!("{e}"))?
188 }
189 CompilationMode::Eager | CompilationMode::Lazy => cache
190 .pipeline
191 .get_or_compile(key, &binding, build_hir, options)
192 .map_err(|e| anyhow::anyhow!("{e}"))?,
193 };
194
195 if first {
196 on_first_hit(compiled)?;
197 }
198 Ok(compiled)
199}
200
201pub fn get_or_specialize_component<'a, F>(
203 cache: &'a mut Qwen35CompileCache,
204 config: &ModelExecutionConfig,
205 build_hir: F,
206 options: &CompileOptions,
207 on_first_hit: impl FnOnce(&mut CompiledGraph) -> Result<()>,
208) -> Result<(&'a mut CompiledGraph, BindingManifest)>
209where
210 F: FnOnce() -> HirModule,
211{
212 let manifest = cache.binding_manifest_for(config, options);
213 let compiled =
214 get_or_specialize_hir_with_options(cache, config, build_hir, options, on_first_hit)?;
215 Ok((compiled, manifest))
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{CompilationMode, DType, HirModule, Shape};

    /// Builds a minimal module: an F32 `[1, 4]` input `x`, an F32 `[4, 2]`
    /// parameter `w`, and one linear op whose `[1, 2]` result is the output.
    fn tiny_hir() -> HirModule {
        let mut hir = HirModule::new("qwen35_aot");
        let mut gb = HirMut::new(&mut hir);
        let x = gb.input("x", Shape::new(&[1, 4], DType::F32));
        let w = gb.param("w", Shape::new(&[4, 2], DType::F32));
        let y = hir.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
        hir.set_outputs(vec![y]);
        hir
    }

    /// `compile_built` on a non-AOT config yields a runnable graph.
    #[test]
    fn compile_built_via_pipeline() {
        use rlx_flow::BuiltModel;
        use std::collections::HashMap;

        let mut cache = Qwen35CompileCache::new(rlx_runtime::Device::Cpu, 4);
        let config = prefill_config(1, 4);
        let built = BuiltModel::from_hir(tiny_hir(), HashMap::new())
            .unwrap()
            .with_execution_config(&config);
        let opts =
            rlx_core::flow_bridge::compile_options_for_device(&config, rlx_runtime::Device::Cpu);
        let mut compiled = cache.compile_built(built, &config, &opts).unwrap();
        compiled.set_param("w", &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        let outs = compiled.run(&[("x", &[1.0f32, 2.0, 3.0, 4.0])]);
        assert_eq!(outs[0].len(), 2);
    }

    /// Different batch sizes produce different cache keys, and the free
    /// helper agrees with `ModelExecutionConfig::cache_key`.
    #[test]
    fn unified_cache_key_differs_batch() {
        let a = prefill_config(1, 8).cache_key();
        let b = prefill_config(2, 8).cache_key();
        assert_ne!(a, b);
        assert_eq!(cache_key_for_config(&prefill_config(1, 8)), a);
    }

    /// After one specialization the pipeline reports a template, and asking
    /// for a binding manifest does not panic.
    #[test]
    fn pipeline_has_template_after_specialize() {
        let mut cache = Qwen35CompileCache::new(rlx_runtime::Device::Cpu, 4);
        let config = prefill_config(1, 4);
        let opts = rlx_core::flow_bridge::compile_options_for_device(&config, cache.device());
        get_or_specialize_hir(&mut cache, &config, tiny_hir, |_| Ok(())).unwrap();
        assert!(cache.has_template());
        let _manifest = cache.binding_manifest_for(&config, &opts);
    }

    /// In AOT mode, specializing leaves an artifact named after the
    /// component's disk base inside the AOT root directory.
    #[test]
    fn aot_mode_writes_lir_disk() {
        /// Removes the wrapped directory on drop, so a panicking `unwrap` or
        /// assertion cannot leak the temp directory.
        struct RemoveDirOnDrop(std::path::PathBuf);

        impl Drop for RemoveDirOnDrop {
            fn drop(&mut self) {
                // Best-effort: the directory may never have been created.
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        let dir = std::env::temp_dir().join(format!("qwen35_aot_{}", std::process::id()));
        // A directory left behind by an earlier run that happened to get the
        // same PID would satisfy the checks below without this run writing
        // anything, so start from a clean slate.
        let _ = std::fs::remove_dir_all(&dir);
        // Declared before `cache`, so it is dropped (and the directory removed)
        // after the cache has gone away.
        let _cleanup = RemoveDirOnDrop(dir.clone());

        let mut cache = Qwen35CompileCache::with_aot(rlx_runtime::Device::Cpu, 4, &dir);
        let config = prefill_config(1, 4).with_compilation_mode(CompilationMode::Aot);
        get_or_specialize_hir(&mut cache, &config, tiny_hir, |_| Ok(())).unwrap();

        let disk = config.component().aot_disk_base();
        let expected_file_exists = dir.join(format!("{disk}__0.lir.json")).exists();
        // Fallback: accept any directory entry whose name mentions the disk base.
        let any_matching_entry = std::fs::read_dir(&dir)
            .map(|entries| {
                entries
                    .flatten()
                    .any(|entry| entry.file_name().to_string_lossy().contains(&disk))
            })
            .unwrap_or(false);
        assert!(expected_file_exists || any_matching_entry);
    }
}