1use std::collections::VecDeque;
11
12use rlx_ir::hir::HirModule;
13use rlx_ir::{BindingManifest, DimBinding, ModelComponent};
14use rlx_opt::CompileResult;
15
16use crate::stages;
17use crate::{CompileOptions, CompiledGraph, Device};
18
19pub struct ModelCompilePipeline {
21 device: Device,
22 capacity: usize,
23 template: Option<CompileResult>,
24 entries: Vec<(u64, CompiledGraph)>,
25 order: VecDeque<u64>,
26}
27
28impl ModelCompilePipeline {
29 pub fn new(device: Device) -> Self {
30 Self::with_capacity(device, 8)
31 }
32
33 pub fn with_capacity(device: Device, capacity: usize) -> Self {
34 assert!(capacity > 0, "ModelCompilePipeline capacity must be ≥ 1");
35 Self {
36 device,
37 capacity,
38 template: None,
39 entries: Vec::new(),
40 order: VecDeque::new(),
41 }
42 }
43
44 pub fn device(&self) -> Device {
45 self.device
46 }
47
48 pub fn has_template(&self) -> bool {
49 self.template.is_some()
50 }
51
52 pub fn build_template<F>(
54 &mut self,
55 build_hir: F,
56 options: &CompileOptions,
57 ) -> Result<&CompileResult, rlx_ir::hir::LowerError>
58 where
59 F: FnOnce() -> HirModule,
60 {
61 if self.template.is_none() {
62 let pipe = stages::pipeline_for(self.device, options);
63 self.template = Some(pipe.compile_hir(build_hir())?);
64 }
65 Ok(self.template.as_ref().expect("template set"))
66 }
67
68 pub fn template_binding_manifest(&self) -> BindingManifest {
69 let template = self.template.as_ref().expect("call build_template first");
70 BindingManifest::from_lir(&template.lir)
71 }
72
73 pub fn specialize_template(
75 &self,
76 binding: &DimBinding,
77 options: &CompileOptions,
78 ) -> CompileResult {
79 let template = self
80 .template
81 .as_ref()
82 .expect("call build_template before specialize_template");
83 let pipe = stages::pipeline_for(self.device, options);
84 template.specialize(&pipe, binding)
85 }
86
87 pub fn compile_lir(
89 &self,
90 specialized: CompileResult,
91 options: &CompileOptions,
92 ) -> CompiledGraph {
93 let backend = crate::registry::backend_for(self.device).expect("backend registered");
94 let executable = backend.compile_lir(specialized.lir, options);
95 CompiledGraph::new(executable, self.device)
96 }
97
98 pub fn get_or_compile<F>(
100 &mut self,
101 key: u64,
102 binding: &DimBinding,
103 build_hir: F,
104 options: &CompileOptions,
105 ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError>
106 where
107 F: FnOnce() -> HirModule,
108 {
109 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
110 return Ok(&mut self.entries[idx].1);
111 }
112 let mut template_opts = options.clone();
113 template_opts.dim_binding = None;
114 self.build_template(build_hir, &template_opts)?;
115 let specialized = self.specialize_template(binding, &template_opts);
116 let mut compile_opts = options.clone();
117 compile_opts.dim_binding = None;
118 let compiled = self.compile_lir(specialized, &compile_opts);
119
120 if self.entries.len() >= self.capacity
121 && let Some(evict) = self.order.pop_front()
122 {
123 self.entries.retain(|(k, _)| *k != evict);
124 }
125 self.entries.push((key, compiled));
126 self.order.push_back(key);
127 Ok(&mut self.entries.last_mut().unwrap().1)
128 }
129
130 pub fn binding_manifest_for_binding(
132 &self,
133 binding: &DimBinding,
134 options: &CompileOptions,
135 ) -> BindingManifest {
136 let specialized = self.specialize_template(binding, options);
137 BindingManifest::from_lir(&specialized.lir)
138 }
139
140 pub fn binding_manifest_for_component(
142 &self,
143 component: &ModelComponent,
144 options: &CompileOptions,
145 ) -> BindingManifest {
146 self.binding_manifest_for_binding(&component.dim_binding(), options)
147 }
148
149 pub fn get_or_compile_component<F>(
151 &mut self,
152 component: &ModelComponent,
153 build_hir: F,
154 options: &CompileOptions,
155 ) -> Result<(&mut CompiledGraph, BindingManifest), rlx_ir::hir::LowerError>
156 where
157 F: FnOnce() -> HirModule,
158 {
159 let key = component.cache_key();
160 let binding = component.dim_binding();
161 let manifest = self.binding_manifest_for_component(component, options);
162 let compiled = self.get_or_compile(key, &binding, build_hir, options)?;
163 Ok((compiled, manifest))
164 }
165
166 pub fn contains(&self, key: u64) -> bool {
167 self.entries.iter().any(|(k, _)| *k == key)
168 }
169
170 pub fn len(&self) -> usize {
171 self.entries.len()
172 }
173
174 pub fn is_empty(&self) -> bool {
175 self.entries.is_empty()
176 }
177
178 pub fn template_result(&self) -> Option<&CompileResult> {
180 self.template.as_ref()
181 }
182
183 pub fn ensure_template<F: FnOnce() -> HirModule>(
185 &mut self,
186 build_hir: F,
187 options: &CompileOptions,
188 ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
189 self.build_template(build_hir, options)
190 }
191
192 pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
194 &mut self,
195 aot: &crate::AotCache,
196 disk_base: &str,
197 key: u64,
198 binding: &DimBinding,
199 build_hir: F,
200 options: &CompileOptions,
201 ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
202 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
203 return Ok(&mut self.entries[idx].1);
204 }
205 let device = self.device;
206 let template = self.ensure_template(build_hir, options)?;
207 let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
208 if self.entries.len() >= self.capacity
209 && let Some(evict_key) = self.order.pop_front()
210 {
211 self.entries.retain(|(k, _)| *k != evict_key);
212 }
213 self.entries.push((key, compiled));
214 self.order.push_back(key);
215 Ok(&mut self.entries.last_mut().unwrap().1)
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, Shape};

    /// Builds a one-op module `y = linear(x, w)`.
    ///
    /// Shapes are `x: [1, 8, 4]`, `w: [4, 2]` and `y: [1, 8, 2]`, all `f32`.
    fn single_linear_module() -> HirModule {
        let mut module = HirModule::new("dyn");
        // Scope the builder so its mutable borrow of `module` ends before `linear` is called.
        let (input, weight) = {
            let mut builder = HirMut::new(&mut module);
            (
                builder.input("x", Shape::new(&[1, 8, 4], DType::F32)),
                builder.param("w", Shape::new(&[4, 2], DType::F32)),
            )
        };
        let output = module.linear(
            input,
            weight,
            None,
            None,
            Shape::new(&[1, 8, 2], DType::F32),
        );
        module.set_outputs(vec![output]);
        module
    }

    /// Exercises template build, specialization, manifest extraction, compile and run on CPU.
    #[test]
    fn template_specialize_compile_smoke() {
        let mut pipeline = ModelCompilePipeline::new(Device::Cpu);
        let opts = CompileOptions::new();

        pipeline
            .build_template(single_linear_module, &opts)
            .unwrap();

        let specialized = pipeline.specialize_template(&rlx_ir::DimBinding::new(), &opts);
        let manifest = BindingManifest::from_lir(&specialized.lir);
        assert_eq!(manifest.params[0].name, "w");

        let mut graph = pipeline.compile_lir(specialized, &opts);
        graph.set_param("w", &[0.0f32; 8]);
        let outputs = graph.run(&[("x", &[0.0f32; 32])]);
        assert_eq!(outputs.len(), 1);
    }
}