1use std::collections::HashMap;
7
8use anyhow::Result;
9use rlx_ir::hir::HirModule;
10use rlx_ir::{Graph, GraphModule, GraphStage, HirNodeId, Shape, hir_to_graph};
11
12use crate::context::{FlowCtx, FlowState};
13use crate::execution::ModelExecutionConfig;
14use crate::extension::FlowExtensionPlan;
15use crate::profile::CompileProfile;
16use crate::stage::FlowStage;
17use crate::value::FlowValue;
18use crate::weight::WeightSource;
19
20#[derive(Debug)]
22pub struct ModelFlow {
23 name: String,
24 pub(crate) profile: CompileProfile,
25 inputs: Vec<(String, Shape)>,
27 pub(crate) stages: Vec<FlowStage>,
28 output_names: Vec<String>,
29 extra_outputs: Vec<HirNodeId>,
30 extension_plan: FlowExtensionPlan,
31}
32
33impl ModelFlow {
34 pub fn new(name: impl Into<String>) -> Self {
35 Self {
36 name: name.into(),
37 profile: CompileProfile::default(),
38 inputs: Vec::new(),
39 stages: Vec::new(),
40 output_names: vec!["output".into()],
41 extra_outputs: Vec::new(),
42 extension_plan: FlowExtensionPlan::default(),
43 }
44 }
45
46 pub fn with_extensions(mut self, plan: FlowExtensionPlan) -> Self {
48 self.extension_plan = plan;
49 self
50 }
51
52 pub fn input(mut self, name: impl Into<String>, shape: Shape) -> Self {
55 self.inputs.push((name.into(), shape));
56 self
57 }
58
59 pub fn with_profile(mut self, profile: CompileProfile) -> Self {
60 self.profile = profile;
61 self
62 }
63
64 pub fn profile(&self) -> &CompileProfile {
65 &self.profile
66 }
67
68 pub fn stage(mut self, stage: FlowStage) -> Self {
69 self.stages.push(stage);
70 self
71 }
72
73 pub fn output(mut self, name: impl Into<String>) -> Self {
74 self.output_names = vec![name.into()];
75 self
76 }
77
78 pub fn outputs(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
79 self.output_names = names.into_iter().map(Into::into).collect();
80 self
81 }
82
83 pub fn with_extra_outputs(mut self, ids: Vec<HirNodeId>) -> Self {
85 self.extra_outputs = ids;
86 self
87 }
88
89 pub fn from_recipe(recipe: &impl crate::recipe::ModelRecipe) -> Self {
91 recipe.assemble()
92 }
93
94 pub fn build(self, weights: &mut dyn WeightSource) -> Result<BuiltModel> {
95 let mut module =
96 GraphModule::hir(&self.name).with_fusion_policy(self.profile.fusion_policy());
97 let mut params = HashMap::new();
98 let mut state = FlowState::default();
99 let mut ctx = FlowCtx {
100 module,
101 params: &mut params,
102 weights,
103 profile: &self.profile,
104 state: &mut state,
105 };
106
107 let mut value: Option<FlowValue> = None;
108 for (i, (name, shape)) in self.inputs.iter().enumerate() {
109 let id = ctx.input(name, shape.clone());
110 ctx.state.inputs.insert(name.clone(), (id, shape.clone()));
111 if i == 0 {
112 value = Some(ctx.wrap(id, shape.clone()));
113 }
114 }
115 for stage in &self.stages {
116 value = stage.emit(&mut ctx, value)?;
117 }
118
119 let primary = value.ok_or_else(|| anyhow::anyhow!("ModelFlow produced no output"))?;
120 let mut outputs = vec![primary.id];
121 outputs.extend(self.extra_outputs);
122
123 ctx.module.set_outputs(outputs);
124 module = ctx.module;
125 if let Some(hir) = module.as_hir_mut() {
126 self.extension_plan.apply(hir);
127 }
128
129 Ok(BuiltModel {
130 module,
131 params,
132 typed_params: Vec::new(),
133 profile: self.profile,
134 output_names: self.output_names,
135 primary_shape: primary.shape,
136 })
137 }
138
139 pub fn build_with(
143 self,
144 weights: &mut dyn WeightSource,
145 _gguf_packed: Option<&crate::GgufPackedParams>,
146 ) -> Result<BuiltModel> {
147 self.build(weights)
148 }
149}
150
151#[derive(Debug, Clone)]
153pub struct BuiltModel {
154 pub module: GraphModule,
155 pub params: HashMap<String, Vec<f32>>,
156 pub typed_params: Vec<(String, Vec<u8>, rlx_ir::DType)>,
158 pub profile: CompileProfile,
159 output_names: Vec<String>,
160 primary_shape: Shape,
161}
162
163impl BuiltModel {
164 pub fn with_execution_config(mut self, config: &ModelExecutionConfig) -> Self {
166 self.profile = config.compile_profile();
167 self
168 }
169
170 pub fn from_hir(hir: HirModule, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
172 let primary = hir
173 .outputs
174 .first()
175 .copied()
176 .ok_or_else(|| anyhow::anyhow!("from_hir: module has no outputs"))?;
177 let primary_shape = hir.node(primary).shape.clone();
178 Ok(Self {
179 module: GraphModule::from_hir(hir),
180 params,
181 typed_params: Vec::new(),
182 profile: CompileProfile::default(),
183 output_names: vec!["output".into()],
184 primary_shape,
185 })
186 }
187
188 pub fn from_graph(graph: Graph, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
190 let primary = graph
191 .outputs
192 .first()
193 .copied()
194 .ok_or_else(|| anyhow::anyhow!("from_graph: graph has no outputs"))?;
195 let primary_shape = graph.node(primary).shape.clone();
196 Ok(Self {
197 module: GraphModule::from_graph(graph),
198 params,
199 typed_params: Vec::new(),
200 profile: CompileProfile::default(),
201 output_names: vec!["output".into()],
202 primary_shape,
203 })
204 }
205
206 pub fn profile(&self) -> &CompileProfile {
207 &self.profile
208 }
209
210 pub fn params(&self) -> &HashMap<String, Vec<f32>> {
211 &self.params
212 }
213
214 pub fn primary_shape(&self) -> &Shape {
215 &self.primary_shape
216 }
217
218 pub fn output_names(&self) -> &[String] {
219 &self.output_names
220 }
221
222 pub fn into_graph_parts(self) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
224 let params = self.params.clone();
225 let graph = self.into_graph()?;
226 Ok((graph, params))
227 }
228
229 pub fn into_graph_module(self) -> GraphModule {
230 self.module
231 }
232
233 pub fn into_hir(self) -> Option<HirModule> {
234 self.module.into_hir()
235 }
236
237 pub fn into_graph(self) -> Result<Graph> {
238 if self.module.stage() == GraphStage::Hir {
239 let hir = self
240 .module
241 .into_hir()
242 .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
243 hir_to_graph(hir).map_err(Into::into)
244 } else {
245 self.module.into_graph().map_err(Into::into)
246 }
247 }
248
249 pub fn lower(self) -> Result<GraphModule> {
250 self.module.lower().map_err(Into::into)
251 }
252
253 pub fn with_extra_hir_outputs(mut self, extra: impl IntoIterator<Item = HirNodeId>) -> Self {
255 let primary = self.module.as_hir().expect("HIR stage").outputs[0];
256 let mut outputs = vec![primary];
257 outputs.extend(extra);
258 self.module.set_outputs(outputs);
259 self
260 }
261
262 pub fn into_parts(self) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
264 let params = self.params.clone();
265 let hir = self
266 .into_hir()
267 .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
268 Ok((hir, params))
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::LayerStack;
    use crate::weight::MapWeights;
    use rlx_ir::{DType, Shape};

    /// Fresh weight map with `embed.weight` registered as `[1.0, 0.0, 0.0, 1.0]`;
    /// the trailing `[2, 2]` argument is presumably the tensor dims.
    fn embed_weights() -> MapWeights {
        let mut weights = MapWeights::default();
        weights.insert("embed.weight", vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        weights
    }

    /// A single input followed by an embed stage builds into at least three HIR nodes.
    #[test]
    fn minimal_embed_flow() {
        let mut weights = embed_weights();

        let built = ModelFlow::new("smoke")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .embed("embed.weight")
            .build(&mut weights)
            .unwrap();

        let hir = built.into_hir().unwrap();
        assert!(hir.len() >= 3);
    }

    /// A custom stage that returns its input unchanged leaves a rank-3 primary shape.
    #[test]
    fn custom_stage_passthrough() {
        let mut weights = embed_weights();

        let built = ModelFlow::new("custom")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .embed("embed.weight")
            .custom(|_emit, input| Ok(input))
            .build(&mut weights)
            .unwrap();

        assert_eq!(built.primary_shape().rank(), 3);
    }

    /// A stage produced by `LayerStack` can be attached with `raw_stage` and built.
    #[test]
    fn layer_stack_builds_sequence() {
        let mut weights = MapWeights::default();
        weights.insert("ln.weight", vec![1.0; 4], vec![4]);

        let block = LayerStack::named("block")
            .rms_norm("ln.weight", 1e-5)
            .build();

        let built = ModelFlow::new("stack")
            .input("x", Shape::new(&[1, 2, 4], DType::F32))
            .zero_beta(4)
            .raw_stage(block)
            .build(&mut weights)
            .unwrap();

        assert!(built.into_hir().unwrap().len() >= 4);
    }

    /// `when(true, ..)` applies the closure, `when(false, ..)` leaves the flow as is.
    #[test]
    fn when_conditional_embed() {
        // Builds the same flow with the embed stage toggled and reports the HIR size.
        let hir_len = |enabled: bool| {
            let mut weights = embed_weights();
            ModelFlow::new("cond")
                .input("ids", Shape::new(&[1, 2], DType::F32))
                .when(enabled, |f| f.embed("embed.weight"))
                .build(&mut weights)
                .unwrap()
                .into_hir()
                .unwrap()
                .len()
        };

        assert!(hir_len(true) >= 3);
        // With the stage skipped, a single node remains (presumably just the input).
        assert_eq!(hir_len(false), 1);
    }
}