1use std::collections::HashMap;
19
20use anyhow::Result;
21use rlx_ir::hir::HirModule;
22use rlx_ir::{Graph, GraphModule, GraphStage, HirNodeId, Shape, hir_to_graph};
23
24use crate::context::{FlowCtx, FlowState};
25use crate::execution::ModelExecutionConfig;
26use crate::extension::FlowExtensionPlan;
27use crate::profile::CompileProfile;
28use crate::stage::FlowStage;
29use crate::value::FlowValue;
30use crate::weight::WeightSource;
31
/// Declarative builder that assembles a model graph from named inputs and an
/// ordered list of [`FlowStage`]s.
///
/// Configure it with the chaining builder methods, then turn it into a
/// [`BuiltModel`] with `ModelFlow::build`.
#[derive(Debug)]
pub struct ModelFlow {
    /// Name given to the graph module created at build time.
    name: String,
    /// Compile profile; its fusion policy is attached to the module at build time.
    pub(crate) profile: CompileProfile,
    /// Declared graph inputs in declaration order. Only the first one seeds the
    /// value passed to the first stage.
    inputs: Vec<(String, Shape)>,
    /// Stages emitted in insertion order during the build.
    pub(crate) stages: Vec<FlowStage>,
    /// Names reported for the built model's outputs (`["output"]` by default).
    output_names: Vec<String>,
    /// Additional HIR nodes exported after the primary output.
    extra_outputs: Vec<HirNodeId>,
    /// Extension plan applied to the HIR module once its outputs are set.
    extension_plan: FlowExtensionPlan,
}
44
45impl ModelFlow {
46 pub fn new(name: impl Into<String>) -> Self {
47 Self {
48 name: name.into(),
49 profile: CompileProfile::default(),
50 inputs: Vec::new(),
51 stages: Vec::new(),
52 output_names: vec!["output".into()],
53 extra_outputs: Vec::new(),
54 extension_plan: FlowExtensionPlan::default(),
55 }
56 }
57
58 pub fn with_extensions(mut self, plan: FlowExtensionPlan) -> Self {
60 self.extension_plan = plan;
61 self
62 }
63
64 pub fn input(mut self, name: impl Into<String>, shape: Shape) -> Self {
67 self.inputs.push((name.into(), shape));
68 self
69 }
70
71 pub fn with_profile(mut self, profile: CompileProfile) -> Self {
72 self.profile = profile;
73 self
74 }
75
76 pub fn profile(&self) -> &CompileProfile {
77 &self.profile
78 }
79
80 pub fn stage(mut self, stage: FlowStage) -> Self {
81 self.stages.push(stage);
82 self
83 }
84
85 pub fn output(mut self, name: impl Into<String>) -> Self {
86 self.output_names = vec![name.into()];
87 self
88 }
89
90 pub fn outputs(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
91 self.output_names = names.into_iter().map(Into::into).collect();
92 self
93 }
94
95 pub fn with_extra_outputs(mut self, ids: Vec<HirNodeId>) -> Self {
97 self.extra_outputs = ids;
98 self
99 }
100
101 pub fn from_recipe(recipe: &impl crate::recipe::ModelRecipe) -> Self {
103 recipe.assemble()
104 }
105
106 pub fn build(self, weights: &mut dyn WeightSource) -> Result<BuiltModel> {
107 let mut module =
108 GraphModule::hir(&self.name).with_fusion_policy(self.profile.fusion_policy());
109 let mut params = HashMap::new();
110 let mut state = FlowState::default();
111 let mut ctx = FlowCtx {
112 module,
113 params: &mut params,
114 weights,
115 profile: &self.profile,
116 state: &mut state,
117 };
118
119 let mut value: Option<FlowValue> = None;
120 for (i, (name, shape)) in self.inputs.iter().enumerate() {
121 let id = ctx.input(name, shape.clone());
122 ctx.state.inputs.insert(name.clone(), (id, shape.clone()));
123 if i == 0 {
124 value = Some(ctx.wrap(id, shape.clone()));
125 }
126 }
127 for stage in &self.stages {
128 value = stage.emit(&mut ctx, value)?;
129 }
130
131 let primary = value.ok_or_else(|| anyhow::anyhow!("ModelFlow produced no output"))?;
132 let mut outputs = vec![primary.id];
133 outputs.extend(self.extra_outputs);
134
135 ctx.module.set_outputs(outputs);
136 module = ctx.module;
137 if let Some(hir) = module.as_hir_mut() {
138 self.extension_plan.apply(hir);
139 }
140
141 Ok(BuiltModel {
142 module,
143 params,
144 typed_params: Vec::new(),
145 profile: self.profile,
146 output_names: self.output_names,
147 primary_shape: primary.shape,
148 })
149 }
150
151 pub fn build_with(
155 self,
156 weights: &mut dyn WeightSource,
157 _gguf_packed: Option<&crate::GgufPackedParams>,
158 ) -> Result<BuiltModel> {
159 self.build(weights)
160 }
161}
162
/// Output of a model build: the graph module together with its parameters,
/// compile profile, and output metadata.
#[derive(Debug, Clone)]
pub struct BuiltModel {
    /// The graph module (`ModelFlow::build` creates it with `GraphModule::hir`).
    pub module: GraphModule,
    /// `f32` parameters keyed by name.
    pub params: HashMap<String, Vec<f32>>,
    /// Parameters stored as raw bytes with an explicit dtype. Every constructor
    /// in this file leaves this empty.
    pub typed_params: Vec<(String, Vec<u8>, rlx_ir::DType)>,
    /// Compile profile associated with the model.
    pub profile: CompileProfile,
    /// Names reported for the model outputs.
    output_names: Vec<String>,
    /// Shape of the primary (first) output.
    primary_shape: Shape,
}
174
175impl BuiltModel {
176 pub fn with_execution_config(mut self, config: &ModelExecutionConfig) -> Self {
178 self.profile = config.compile_profile();
179 self
180 }
181
182 pub fn from_hir(hir: HirModule, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
184 let primary = hir
185 .outputs
186 .first()
187 .copied()
188 .ok_or_else(|| anyhow::anyhow!("from_hir: module has no outputs"))?;
189 let primary_shape = hir.node(primary).shape.clone();
190 Ok(Self {
191 module: GraphModule::from_hir(hir),
192 params,
193 typed_params: Vec::new(),
194 profile: CompileProfile::default(),
195 output_names: vec!["output".into()],
196 primary_shape,
197 })
198 }
199
200 pub fn from_graph(graph: Graph, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
202 let primary = graph
203 .outputs
204 .first()
205 .copied()
206 .ok_or_else(|| anyhow::anyhow!("from_graph: graph has no outputs"))?;
207 let primary_shape = graph.node(primary).shape.clone();
208 Ok(Self {
209 module: GraphModule::from_graph(graph),
210 params,
211 typed_params: Vec::new(),
212 profile: CompileProfile::default(),
213 output_names: vec!["output".into()],
214 primary_shape,
215 })
216 }
217
218 pub fn profile(&self) -> &CompileProfile {
219 &self.profile
220 }
221
222 pub fn params(&self) -> &HashMap<String, Vec<f32>> {
223 &self.params
224 }
225
226 pub fn primary_shape(&self) -> &Shape {
227 &self.primary_shape
228 }
229
230 pub fn output_names(&self) -> &[String] {
231 &self.output_names
232 }
233
234 pub fn into_graph_parts(mut self) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
236 let params = std::mem::take(&mut self.params);
237 let graph = self.into_graph()?;
238 Ok((graph, params))
239 }
240
241 pub fn into_graph_module(self) -> GraphModule {
242 self.module
243 }
244
245 pub fn into_hir(self) -> Option<HirModule> {
246 self.module.into_hir()
247 }
248
249 pub fn into_graph(self) -> Result<Graph> {
250 if self.module.stage() == GraphStage::Hir {
251 let hir = self
252 .module
253 .into_hir()
254 .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
255 hir_to_graph(hir).map_err(Into::into)
256 } else {
257 self.module.into_graph().map_err(Into::into)
258 }
259 }
260
261 pub fn lower(self) -> Result<GraphModule> {
262 self.module.lower().map_err(Into::into)
263 }
264
265 pub fn with_extra_hir_outputs(mut self, extra: impl IntoIterator<Item = HirNodeId>) -> Self {
267 let primary = self.module.as_hir().expect("HIR stage").outputs[0];
268 let mut outputs = vec![primary];
269 outputs.extend(extra);
270 self.module.set_outputs(outputs);
271 self
272 }
273
274 pub fn into_parts(self) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
276 let params = self.params.clone();
277 let hir = self
278 .into_hir()
279 .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
280 Ok((hir, params))
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::LayerStack;
    use crate::weight::MapWeights;
    use rlx_ir::{DType, Shape};

    /// Weight source holding a 2x2 identity matrix under `embed.weight`.
    fn identity_embed_weights() -> MapWeights {
        let mut weights = MapWeights::default();
        weights.insert("embed.weight", vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        weights
    }

    /// The `[1, 2]` F32 id input shape shared by the embedding tests.
    fn ids_shape() -> Shape {
        Shape::new(&[1, 2], DType::F32)
    }

    #[test]
    fn minimal_embed_flow() {
        let mut weights = identity_embed_weights();
        let hir = ModelFlow::new("smoke")
            .input("ids", ids_shape())
            .embed("embed.weight")
            .build(&mut weights)
            .unwrap()
            .into_hir()
            .unwrap();
        assert!(hir.len() >= 3);
    }

    #[test]
    fn custom_stage_passthrough() {
        let mut weights = identity_embed_weights();
        let built = ModelFlow::new("custom")
            .input("ids", ids_shape())
            .embed("embed.weight")
            .custom(|_emit, input| Ok(input))
            .build(&mut weights)
            .unwrap();
        assert_eq!(built.primary_shape().rank(), 3);
    }

    #[test]
    fn layer_stack_builds_sequence() {
        let mut weights = MapWeights::default();
        weights.insert("ln.weight", vec![1.0; 4], vec![4]);

        let block = LayerStack::named("block").rms_norm("ln.weight", 1e-5).build();

        let hir = ModelFlow::new("stack")
            .input("x", Shape::new(&[1, 2, 4], DType::F32))
            .zero_beta(4)
            .raw_stage(block)
            .build(&mut weights)
            .unwrap()
            .into_hir()
            .unwrap();
        assert!(hir.len() >= 4);
    }

    #[test]
    fn when_conditional_embed() {
        let mut weights = identity_embed_weights();
        let embedded = ModelFlow::new("cond")
            .input("ids", ids_shape())
            .when(true, |f| f.embed("embed.weight"))
            .build(&mut weights)
            .unwrap()
            .into_hir()
            .unwrap();
        assert!(embedded.len() >= 3);

        // With the branch disabled, the module should hold a single node.
        let mut fresh_weights = identity_embed_weights();
        let skipped = ModelFlow::new("cond")
            .input("ids", ids_shape())
            .when(false, |f| f.embed("embed.weight"))
            .build(&mut fresh_weights)
            .unwrap()
            .into_hir()
            .unwrap();
        assert_eq!(skipped.len(), 1);
    }
}