1use std::collections::VecDeque;
23
24use rlx_ir::hir::HirModule;
25use rlx_ir::{BindingManifest, DimBinding, ModelComponent};
26use rlx_opt::CompileResult;
27
28use crate::stages;
29use crate::{CompileOptions, CompiledGraph, Device};
30
31pub struct ModelCompilePipeline {
33 device: Device,
34 capacity: usize,
35 template: Option<CompileResult>,
36 entries: Vec<(u64, CompiledGraph)>,
37 order: VecDeque<u64>,
38}
39
40impl ModelCompilePipeline {
41 pub fn new(device: Device) -> Self {
42 Self::with_capacity(device, 8)
43 }
44
45 pub fn with_capacity(device: Device, capacity: usize) -> Self {
46 assert!(capacity > 0, "ModelCompilePipeline capacity must be ≥ 1");
47 Self {
48 device,
49 capacity,
50 template: None,
51 entries: Vec::new(),
52 order: VecDeque::new(),
53 }
54 }
55
56 pub fn device(&self) -> Device {
57 self.device
58 }
59
60 pub fn has_template(&self) -> bool {
61 self.template.is_some()
62 }
63
64 pub fn build_template<F>(
66 &mut self,
67 build_hir: F,
68 options: &CompileOptions,
69 ) -> Result<&CompileResult, rlx_ir::hir::LowerError>
70 where
71 F: FnOnce() -> HirModule,
72 {
73 if self.template.is_none() {
74 let pipe = stages::pipeline_for(self.device, options);
75 self.template = Some(pipe.compile_hir(build_hir())?);
76 }
77 Ok(self.template.as_ref().expect("template set"))
78 }
79
80 pub fn template_binding_manifest(&self) -> BindingManifest {
81 let template = self.template.as_ref().expect("call build_template first");
82 BindingManifest::from_lir(&template.lir)
83 }
84
85 pub fn specialize_template(
87 &self,
88 binding: &DimBinding,
89 options: &CompileOptions,
90 ) -> CompileResult {
91 let template = self
92 .template
93 .as_ref()
94 .expect("call build_template before specialize_template");
95 let pipe = stages::pipeline_for(self.device, options);
96 template.specialize(&pipe, binding)
97 }
98
99 pub fn compile_lir(
101 &self,
102 specialized: CompileResult,
103 options: &CompileOptions,
104 ) -> CompiledGraph {
105 let backend = crate::registry::backend_for(self.device).expect("backend registered");
106 let executable = backend.compile_lir(specialized.lir, options);
107 CompiledGraph::new(executable, self.device)
108 }
109
110 pub fn get_or_compile<F>(
112 &mut self,
113 key: u64,
114 binding: &DimBinding,
115 build_hir: F,
116 options: &CompileOptions,
117 ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError>
118 where
119 F: FnOnce() -> HirModule,
120 {
121 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
122 return Ok(&mut self.entries[idx].1);
123 }
124 let mut template_opts = options.clone();
125 template_opts.dim_binding = None;
126 self.build_template(build_hir, &template_opts)?;
127 let specialized = self.specialize_template(binding, &template_opts);
128 let mut compile_opts = options.clone();
129 compile_opts.dim_binding = None;
130 let compiled = self.compile_lir(specialized, &compile_opts);
131
132 if self.entries.len() >= self.capacity
133 && let Some(evict) = self.order.pop_front()
134 {
135 self.entries.retain(|(k, _)| *k != evict);
136 }
137 self.entries.push((key, compiled));
138 self.order.push_back(key);
139 Ok(&mut self.entries.last_mut().unwrap().1)
140 }
141
142 pub fn binding_manifest_for_binding(
144 &self,
145 binding: &DimBinding,
146 options: &CompileOptions,
147 ) -> BindingManifest {
148 let specialized = self.specialize_template(binding, options);
149 BindingManifest::from_lir(&specialized.lir)
150 }
151
152 pub fn binding_manifest_for_component(
154 &self,
155 component: &ModelComponent,
156 options: &CompileOptions,
157 ) -> BindingManifest {
158 self.binding_manifest_for_binding(&component.dim_binding(), options)
159 }
160
161 pub fn get_or_compile_component<F>(
163 &mut self,
164 component: &ModelComponent,
165 build_hir: F,
166 options: &CompileOptions,
167 ) -> Result<(&mut CompiledGraph, BindingManifest), rlx_ir::hir::LowerError>
168 where
169 F: FnOnce() -> HirModule,
170 {
171 let key = component.cache_key();
172 let binding = component.dim_binding();
173 let manifest = self.binding_manifest_for_component(component, options);
174 let compiled = self.get_or_compile(key, &binding, build_hir, options)?;
175 Ok((compiled, manifest))
176 }
177
178 pub fn contains(&self, key: u64) -> bool {
179 self.entries.iter().any(|(k, _)| *k == key)
180 }
181
182 pub fn len(&self) -> usize {
183 self.entries.len()
184 }
185
186 pub fn is_empty(&self) -> bool {
187 self.entries.is_empty()
188 }
189
190 pub fn template_result(&self) -> Option<&CompileResult> {
192 self.template.as_ref()
193 }
194
195 pub fn ensure_template<F: FnOnce() -> HirModule>(
197 &mut self,
198 build_hir: F,
199 options: &CompileOptions,
200 ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
201 self.build_template(build_hir, options)
202 }
203
204 pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
206 &mut self,
207 aot: &crate::AotCache,
208 disk_base: &str,
209 key: u64,
210 binding: &DimBinding,
211 build_hir: F,
212 options: &CompileOptions,
213 ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
214 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
215 return Ok(&mut self.entries[idx].1);
216 }
217 let device = self.device;
218 let template = self.ensure_template(build_hir, options)?;
219 let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
220 if self.entries.len() >= self.capacity
221 && let Some(evict_key) = self.order.pop_front()
222 {
223 self.entries.retain(|(k, _)| *k != evict_key);
224 }
225 self.entries.push((key, compiled));
226 self.order.push_back(key);
227 Ok(&mut self.entries.last_mut().unwrap().1)
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, Shape};

    /// Builds the graph `y = linear(x, w)` with an input `x` of shape
    /// `[1, 8, 4]`, a param `w` of shape `[4, 2]` and an output of shape
    /// `[1, 8, 2]`, all `F32`.
    fn linear_module() -> HirModule {
        let mut module = HirModule::new("dyn");
        let mut builder = HirMut::new(&mut module);
        let input = builder.input("x", Shape::new(&[1, 8, 4], DType::F32));
        let weight = builder.param("w", Shape::new(&[4, 2], DType::F32));
        let output = module.linear(input, weight, None, None, Shape::new(&[1, 8, 2], DType::F32));
        module.set_outputs(vec![output]);
        module
    }

    #[test]
    fn template_specialize_compile_smoke() {
        let mut pipeline = ModelCompilePipeline::new(Device::Cpu);
        let options = CompileOptions::new();

        pipeline.build_template(linear_module, &options).unwrap();

        // An empty binding specializes the template without overriding any dims.
        let empty_binding = rlx_ir::DimBinding::new();
        let specialized = pipeline.specialize_template(&empty_binding, &options);
        let manifest = BindingManifest::from_lir(&specialized.lir);
        assert_eq!(manifest.params[0].name, "w");

        let mut graph = pipeline.compile_lir(specialized, &options);
        graph.set_param("w", &[0.0f32; 8]);
        let results = graph.run(&[("x", &[0.0f32; 32])]);
        assert_eq!(results.len(), 1);
    }
}