1use std::collections::HashMap;
7
8use anyhow::Result;
9use rlx_ir::hir::HirModule;
10use rlx_ir::{Graph, GraphModule, GraphStage, HirNodeId, Shape, hir_to_graph};
11
12use crate::context::{FlowCtx, FlowState};
13use crate::execution::ModelExecutionConfig;
14use crate::extension::FlowExtensionPlan;
15use crate::profile::CompileProfile;
16use crate::stage::FlowStage;
17use crate::value::FlowValue;
18use crate::weight::WeightSource;
19
20#[derive(Debug)]
22pub struct ModelFlow {
23 name: String,
24 pub(crate) profile: CompileProfile,
25 inputs: Vec<(String, Shape)>,
27 pub(crate) stages: Vec<FlowStage>,
28 output_names: Vec<String>,
29 extra_outputs: Vec<HirNodeId>,
30 extension_plan: FlowExtensionPlan,
31}
32
33impl ModelFlow {
34 pub fn new(name: impl Into<String>) -> Self {
35 Self {
36 name: name.into(),
37 profile: CompileProfile::default(),
38 inputs: Vec::new(),
39 stages: Vec::new(),
40 output_names: vec!["output".into()],
41 extra_outputs: Vec::new(),
42 extension_plan: FlowExtensionPlan::default(),
43 }
44 }
45
46 pub fn with_extensions(mut self, plan: FlowExtensionPlan) -> Self {
48 self.extension_plan = plan;
49 self
50 }
51
52 pub fn input(mut self, name: impl Into<String>, shape: Shape) -> Self {
55 self.inputs.push((name.into(), shape));
56 self
57 }
58
59 pub fn with_profile(mut self, profile: CompileProfile) -> Self {
60 self.profile = profile;
61 self
62 }
63
64 pub fn profile(&self) -> &CompileProfile {
65 &self.profile
66 }
67
68 pub fn stage(mut self, stage: FlowStage) -> Self {
69 self.stages.push(stage);
70 self
71 }
72
73 pub fn output(mut self, name: impl Into<String>) -> Self {
74 self.output_names = vec![name.into()];
75 self
76 }
77
78 pub fn outputs(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
79 self.output_names = names.into_iter().map(Into::into).collect();
80 self
81 }
82
83 pub fn with_extra_outputs(mut self, ids: Vec<HirNodeId>) -> Self {
85 self.extra_outputs = ids;
86 self
87 }
88
89 pub fn from_recipe(recipe: &impl crate::recipe::ModelRecipe) -> Self {
91 recipe.assemble()
92 }
93
94 pub fn build(self, weights: &mut dyn WeightSource) -> Result<BuiltModel> {
95 let mut module =
96 GraphModule::hir(&self.name).with_fusion_policy(self.profile.fusion_policy());
97 let mut params = HashMap::new();
98 let mut state = FlowState::default();
99 let mut ctx = FlowCtx {
100 module,
101 params: &mut params,
102 weights,
103 profile: &self.profile,
104 state: &mut state,
105 };
106
107 let mut value: Option<FlowValue> = None;
108 for (i, (name, shape)) in self.inputs.iter().enumerate() {
109 let id = ctx.input(name, shape.clone());
110 ctx.state.inputs.insert(name.clone(), (id, shape.clone()));
111 if i == 0 {
112 value = Some(ctx.wrap(id, shape.clone()));
113 }
114 }
115 for stage in &self.stages {
116 value = stage.emit(&mut ctx, value)?;
117 }
118
119 let primary = value.ok_or_else(|| anyhow::anyhow!("ModelFlow produced no output"))?;
120 let mut outputs = vec![primary.id];
121 outputs.extend(self.extra_outputs);
122
123 ctx.module.set_outputs(outputs);
124 module = ctx.module;
125 if let Some(hir) = module.as_hir_mut() {
126 self.extension_plan.apply(hir);
127 }
128
129 Ok(BuiltModel {
130 module,
131 params,
132 profile: self.profile,
133 output_names: self.output_names,
134 primary_shape: primary.shape,
135 })
136 }
137}
138
139#[derive(Debug, Clone)]
141pub struct BuiltModel {
142 pub module: GraphModule,
143 pub params: HashMap<String, Vec<f32>>,
144 pub profile: CompileProfile,
145 output_names: Vec<String>,
146 primary_shape: Shape,
147}
148
149impl BuiltModel {
150 pub fn with_execution_config(mut self, config: &ModelExecutionConfig) -> Self {
152 self.profile = config.compile_profile();
153 self
154 }
155
156 pub fn from_hir(hir: HirModule, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
158 let primary = hir
159 .outputs
160 .first()
161 .copied()
162 .ok_or_else(|| anyhow::anyhow!("from_hir: module has no outputs"))?;
163 let primary_shape = hir.node(primary).shape.clone();
164 Ok(Self {
165 module: GraphModule::from_hir(hir),
166 params,
167 profile: CompileProfile::default(),
168 output_names: vec!["output".into()],
169 primary_shape,
170 })
171 }
172
173 pub fn from_graph(graph: Graph, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
175 let primary = graph
176 .outputs
177 .first()
178 .copied()
179 .ok_or_else(|| anyhow::anyhow!("from_graph: graph has no outputs"))?;
180 let primary_shape = graph.node(primary).shape.clone();
181 Ok(Self {
182 module: GraphModule::from_graph(graph),
183 params,
184 profile: CompileProfile::default(),
185 output_names: vec!["output".into()],
186 primary_shape,
187 })
188 }
189
190 pub fn profile(&self) -> &CompileProfile {
191 &self.profile
192 }
193
194 pub fn params(&self) -> &HashMap<String, Vec<f32>> {
195 &self.params
196 }
197
198 pub fn primary_shape(&self) -> &Shape {
199 &self.primary_shape
200 }
201
202 pub fn output_names(&self) -> &[String] {
203 &self.output_names
204 }
205
206 pub fn into_graph_parts(self) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
208 let params = self.params.clone();
209 let graph = self.into_graph()?;
210 Ok((graph, params))
211 }
212
213 pub fn into_graph_module(self) -> GraphModule {
214 self.module
215 }
216
217 pub fn into_hir(self) -> Option<HirModule> {
218 self.module.into_hir()
219 }
220
221 pub fn into_graph(self) -> Result<Graph> {
222 if self.module.stage() == GraphStage::Hir {
223 let hir = self
224 .module
225 .into_hir()
226 .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
227 hir_to_graph(hir).map_err(Into::into)
228 } else {
229 self.module.into_graph().map_err(Into::into)
230 }
231 }
232
233 pub fn lower(self) -> Result<GraphModule> {
234 self.module.lower().map_err(Into::into)
235 }
236
237 pub fn with_extra_hir_outputs(mut self, extra: impl IntoIterator<Item = HirNodeId>) -> Self {
239 let primary = self.module.as_hir().expect("HIR stage").outputs[0];
240 let mut outputs = vec![primary];
241 outputs.extend(extra);
242 self.module.set_outputs(outputs);
243 self
244 }
245
246 pub fn into_parts(self) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
248 let params = self.params.clone();
249 let hir = self
250 .into_hir()
251 .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
252 Ok((hir, params))
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::LayerStack;
    use crate::weight::MapWeights;
    use rlx_ir::{DType, Shape};

    /// Weight map containing only a 2x2 identity `embed.weight`.
    fn identity_embed_weights() -> MapWeights {
        let mut weights = MapWeights::default();
        weights.insert("embed.weight", vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        weights
    }

    /// Shape used for the `ids` input in the embedding tests.
    fn ids_shape() -> Shape {
        Shape::new(&[1, 2], DType::F32)
    }

    #[test]
    fn minimal_embed_flow() {
        let mut weights = identity_embed_weights();

        let built = ModelFlow::new("smoke")
            .input("ids", ids_shape())
            .embed("embed.weight")
            .build(&mut weights)
            .unwrap();

        let hir = built.into_hir().unwrap();
        assert!(hir.len() >= 3);
    }

    #[test]
    fn custom_stage_passthrough() {
        let mut weights = identity_embed_weights();

        let built = ModelFlow::new("custom")
            .input("ids", ids_shape())
            .embed("embed.weight")
            .custom(|_emit, input| Ok(input))
            .build(&mut weights)
            .unwrap();

        assert_eq!(built.primary_shape().rank(), 3);
    }

    #[test]
    fn layer_stack_builds_sequence() {
        let mut weights = MapWeights::default();
        weights.insert("ln.weight", vec![1.0; 4], vec![4]);

        let block = LayerStack::named("block")
            .rms_norm("ln.weight", 1e-5)
            .build();

        let built = ModelFlow::new("stack")
            .input("x", Shape::new(&[1, 2, 4], DType::F32))
            .zero_beta(4)
            .raw_stage(block)
            .build(&mut weights)
            .unwrap();

        assert!(built.into_hir().unwrap().len() >= 4);
    }

    #[test]
    fn when_conditional_embed() {
        // Condition true: the embed stage is part of the flow.
        let mut enabled_weights = identity_embed_weights();
        let with_embed = ModelFlow::new("cond")
            .input("ids", ids_shape())
            .when(true, |f| f.embed("embed.weight"))
            .build(&mut enabled_weights)
            .unwrap();
        assert!(with_embed.into_hir().unwrap().len() >= 3);

        // Condition false: only the input node is expected in the module.
        let mut disabled_weights = identity_embed_weights();
        let skipped = ModelFlow::new("cond")
            .input("ids", ids_shape())
            .when(false, |f| f.embed("embed.weight"))
            .build(&mut disabled_weights)
            .unwrap();
        assert_eq!(skipped.into_hir().unwrap().len(), 1);
    }
}