Skip to main content

rlx_flow/
dsl.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Fluent builder methods on [`ModelFlow`] — sugar over [`FlowStage`].
5
6use std::path::Path;
7use std::sync::Arc;
8
9use crate::blocks::{
10    AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
11    BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
12    GatherFromInputStage, GatherLastTokenStage, GeluFfnStage, LayerNormStage, LinearStage,
13    LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
14    NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage, ResidualAddStage,
15    ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec, SelfAttnPrefillStage,
16    SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed, llama_prefill_layer_fused,
17    nomic_vision_layer_fused,
18};
19use crate::escape::Emit;
20use crate::flow::ModelFlow;
21use crate::layer::LayerStack;
22use crate::profile::CompileProfile;
23use crate::side::SideOutputs;
24use crate::stage::FlowStage;
25use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
26use crate::value::FlowValue;
27
28impl ModelFlow {
29    /// Load tier-1 profile from a `*.rlx.toml` file (falls back to default on error).
30    pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
31        self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
32        self
33    }
34
35    /// Encoder / embedding model defaults (Direct lowering, no KV fusion).
36    pub fn profile_encoder(mut self) -> Self {
37        self.profile = CompileProfile::encoder();
38        self
39    }
40
41    /// Gather rows from a side input into the primary flow (starts embedding stack).
42    pub fn gather_from_input(
43        mut self,
44        input_name: impl Into<String>,
45        weight_key: impl Into<String>,
46    ) -> Self {
47        self.stages
48            .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
49                input_name, weight_key, 0,
50            )));
51        self
52    }
53
54    /// Add an embedding looked up from a side input.
55    pub fn gather_add(
56        mut self,
57        input_name: impl Into<String>,
58        weight_key: impl Into<String>,
59    ) -> Self {
60        self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
61            input_name, weight_key, 0,
62        )));
63        self
64    }
65
66    /// LayerNorm with separate gamma/beta weights.
67    pub fn layer_norm(
68        mut self,
69        gamma_key: impl Into<String>,
70        beta_key: impl Into<String>,
71        eps: f32,
72    ) -> Self {
73        self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
74            gamma_key, beta_key, eps,
75        )));
76        self
77    }
78
79    /// BERT-style GELU FFN under a layer prefix.
80    pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
81        self.stages
82            .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
83        self
84    }
85
86    /// Repeat NomicBERT encoder layers.
87    pub fn repeat_nomic_layers(
88        self,
89        count: usize,
90        hidden_size: usize,
91        num_heads: usize,
92        head_dim: usize,
93        eps: f32,
94    ) -> Self {
95        self.repeat_layers(count, move |i| FlowStage::Named {
96            name: format!("layer{i}"),
97            inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
98                NomicEncoderLayerSpec::hf(
99                    format!("encoder.layers.{i}"),
100                    hidden_size,
101                    num_heads,
102                    head_dim,
103                    eps,
104                ),
105            ))),
106        })
107    }
108
109    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
110    pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
111        self.stages
112            .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
113                spec,
114            )));
115        self
116    }
117
118    /// Repeat BERT encoder layers with auto-named prefixes.
119    pub fn repeat_bert_layers(
120        self,
121        count: usize,
122        prefix: impl Into<String>,
123        qkv_style: BertQkvStyle,
124        hidden_size: usize,
125        num_heads: usize,
126        eps: f32,
127    ) -> Self {
128        let prefix = prefix.into();
129        self.repeat_layers(count, move |i| {
130            let lp = if prefix.is_empty() {
131                format!("encoder.layer.{i}")
132            } else {
133                format!("{prefix}.encoder.layer.{i}")
134            };
135            FlowStage::Named {
136                name: format!("layer{i}"),
137                inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
138                    BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
139                        lp,
140                        qkv_style,
141                        hidden_size,
142                        num_heads,
143                        eps,
144                    )),
145                )),
146            }
147        })
148    }
149
150    /// Synthesize an all-ones attention mask for vision encoders (no padding).
151    pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
152        self.stages
153            .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
154        self
155    }
156
157    /// Repeat DINOv2 ViT encoder blocks.
158    pub fn repeat_dinov2_layers(
159        self,
160        count: usize,
161        hidden_size: usize,
162        num_heads: usize,
163        eps: f32,
164    ) -> Self {
165        self.repeat_layers(count, move |i| {
166            dinov2_layer_fused(i, hidden_size, num_heads, eps)
167        })
168    }
169
170    /// Repeat NomicVision encoder blocks.
171    pub fn repeat_vision_layers(
172        self,
173        count: usize,
174        hidden_size: usize,
175        num_heads: usize,
176        eps: f32,
177    ) -> Self {
178        self.repeat_layers(count, move |i| {
179            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
180        })
181    }
182
183    /// Pool CLS token: `[batch, seq, hidden]` → `[batch, hidden]`.
184    pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
185        self.stages
186            .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
187                batch, hidden,
188            )));
189        self
190    }
191
192    /// Fusion-first prefill defaults.
193    pub fn profile_prefill(mut self) -> Self {
194        self.profile = CompileProfile::llama32_prefill();
195        self
196    }
197
198    /// Decode / KV-cache defaults (`Fusable` lowering).
199    pub fn profile_decode(mut self) -> Self {
200        self.profile = CompileProfile::llama32_decode();
201        self
202    }
203
204    /// Token embedding (`model.embed_tokens.weight` by default).
205    pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
206        self.stages
207            .push(FlowStage::Embed(EmbedStage::token(weight_key)));
208        self
209    }
210
211    /// HuggingFace-style token embedding table.
212    pub fn token_embed(self) -> Self {
213        self.embed("model.embed_tokens.weight")
214    }
215
216    /// Precomputed RoPE sin/cos tables stored as params.
217    pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
218        self.stages.push(FlowStage::RopeTables(tables));
219        self
220    }
221
222    /// Rank-1 zero vector for RMSNorm beta slots (LLaMA has no beta).
223    pub fn zero_beta(self, len: usize) -> Self {
224        self.zero_beta_named("zero_beta", len)
225    }
226
227    pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
228        self.stages.push(FlowStage::ZeroBeta {
229            name: name.into(),
230            len,
231        });
232        self
233    }
234
235    /// Bind decode inputs (call after declaring `rope_cos`, `past_k_*`, …).
236    pub fn bind_decode_inputs(mut self, num_layers: usize, custom_mask: bool) -> Self {
237        self.stages
238            .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
239                num_layers,
240                use_custom_mask: custom_mask,
241            }));
242        self
243    }
244
245    /// Repeat a per-layer stage `count` times (layer index passed to closure).
246    pub fn repeat_layers(
247        mut self,
248        count: usize,
249        stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
250    ) -> Self {
251        self.stages
252            .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
253        self
254    }
255
256    /// Named decoder layer (shows up in fusion / inspect dumps).
257    pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
258        self.stages.push(FlowStage::Named {
259            name: name.into(),
260            inner: Arc::new(inner),
261        });
262        self
263    }
264
265    /// Build a named layer from a [`LayerStack`] closure.
266    pub fn layer(
267        self,
268        name: impl Into<String>,
269        build: impl FnOnce(LayerStack) -> LayerStack,
270    ) -> Self {
271        self.raw_stage(build(LayerStack::named(name)).build())
272    }
273
274    /// Fused LLaMA prefill layer (default fast path).
275    pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
276        self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
277    }
278
279    /// Composed LLaMA prefill layer (small blocks — customize via [`LayerStack`]).
280    pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
281        self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
282    }
283
284    pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
285        self.stages
286            .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
287        self
288    }
289
290    pub fn residual_save(mut self) -> Self {
291        self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
292        self
293    }
294
295    pub fn residual_add(mut self) -> Self {
296        self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
297        self
298    }
299
300    pub fn swiglu(
301        mut self,
302        gate_key: impl Into<String>,
303        up_key: impl Into<String>,
304        down_key: impl Into<String>,
305    ) -> Self {
306        self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
307            gate_key, up_key, down_key,
308        )));
309        self
310    }
311
312    pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
313        self.stages
314            .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
315        self
316    }
317
318    pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
319        self.stages
320            .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
321        self
322    }
323
324    pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
325        self.stages.push(FlowStage::GdnScan(stage));
326        self
327    }
328
329    pub fn store_stream(mut self, name: impl Into<String>) -> Self {
330        self.stages
331            .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
332        self
333    }
334
335    pub fn load_stream(mut self, name: impl Into<String>) -> Self {
336        self.stages
337            .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
338        self
339    }
340
341    /// Bind declared graph inputs into named streams (multi-input models).
342    ///
343    /// Example: FLUX `.bind_inputs_to_streams(&[("hidden", "img"), ("encoder", "txt")])`.
344    pub fn bind_inputs_to_streams(
345        mut self,
346        pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
347    ) -> Self {
348        let pairs: Vec<(String, String)> = pairs
349            .into_iter()
350            .map(|(input, stream)| (input.into(), stream.into()))
351            .collect();
352        self.stages.push(FlowStage::Custom(CustomStage::named(
353            "bind_inputs_to_streams",
354            move |emit, primary| {
355                let primary = primary.ok_or_else(|| {
356                    anyhow::anyhow!("bind_inputs_to_streams requires primary input")
357                })?;
358                for (input_name, stream_name) in &pairs {
359                    let value = emit.flow_input(input_name)?;
360                    emit.state.streams.insert(stream_name.clone(), value);
361                }
362                Ok(Some(primary))
363            },
364        )));
365        self
366    }
367
368    pub fn dual_stream<F>(
369        mut self,
370        name: impl Into<String>,
371        stream_a: impl Into<String>,
372        stream_b: impl Into<String>,
373        f: F,
374    ) -> Self
375    where
376        F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
377            + Send
378            + Sync
379            + 'static,
380    {
381        self.stages.push(FlowStage::DualStream(DualStreamStage::new(
382            name, stream_a, stream_b, f,
383        )));
384        self
385    }
386
387    pub fn plugin<F>(mut self, f: F) -> Self
388    where
389        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
390            + Send
391            + Sync
392            + 'static,
393    {
394        self.stages.push(crate::plugin::plugin(f));
395        self
396    }
397
398    pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
399    where
400        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
401            + Send
402            + Sync
403            + 'static,
404    {
405        self.stages.push(crate::plugin::plugin_named(name, f));
406        self
407    }
408
409    /// Hidden states output (no LM head).
410    pub fn hidden_states(self) -> Self {
411        self.output("hidden")
412    }
413
414    /// LLaMA prefill decoder block at `layer_idx`.
415    pub fn llama_decoder_layer(
416        self,
417        layer_idx: usize,
418        spec: crate::blocks::LlamaDecoderSpec,
419    ) -> Self {
420        self.named_layer(
421            format!("layer{layer_idx}"),
422            FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
423        )
424    }
425
426    /// LLaMA decode block with KV-cache concat.
427    pub fn llama_decode_layer(
428        self,
429        layer_idx: usize,
430        spec: crate::blocks::LlamaDecodeLayerSpec,
431        kv_out: SideOutputs,
432    ) -> Self {
433        self.named_layer(
434            format!("layer{layer_idx}"),
435            FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
436                layer_idx,
437                spec,
438                kv_out.inner(),
439            )),
440        )
441    }
442
443    /// Side-effect K/V tap before a prefill layer (exports cache tensors).
444    pub fn llama_kv_tap(
445        mut self,
446        layer_idx: usize,
447        head_dim: usize,
448        eps: f32,
449        sink: &SideOutputs,
450    ) -> Self {
451        self.stages
452            .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
453                layer_idx,
454                head_dim,
455                eps,
456                sink.inner(),
457            )));
458        self
459    }
460
461    /// Final RMSNorm before LM head (`model.norm.weight` by default).
462    pub fn final_norm(self, eps: f32) -> Self {
463        self.rms_norm("model.norm.weight", eps)
464    }
465
466    pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
467        self.stages
468            .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
469        self
470    }
471
472    /// Gather last token (dynamic `last_token_idx` input).
473    pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
474        self.stages
475            .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
476                batch,
477            )));
478        self
479    }
480
481    /// Gather last token at fixed sequence length.
482    pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
483        self.stages.push(FlowStage::GatherLastToken(
484            GatherLastTokenStage::static_last(batch, seq),
485        ));
486        self
487    }
488
489    /// Causal LM head — tied or separate weights.
490    pub fn lm_head(
491        mut self,
492        vocab_size: usize,
493        hidden_size: usize,
494        tie_word_embeddings: bool,
495    ) -> Self {
496        let stage = if tie_word_embeddings {
497            LmHeadStage::tied(vocab_size, hidden_size)
498        } else {
499            LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
500        };
501        self.stages.push(FlowStage::LmHead(stage));
502        self.output("logits")
503    }
504
505    /// Tier-2 escape hatch — append a raw stage.
506    pub fn raw_stage(mut self, stage: FlowStage) -> Self {
507        self.stages.push(stage);
508        self
509    }
510
511    /// Append multiple raw stages in order.
512    pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
513        self.stages.extend(stages);
514        self
515    }
516
517    /// Run a list of stages as one nested sequence (side-effect stages allowed).
518    pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
519        self.stages
520            .push(FlowStage::Sequence(stages.into_iter().collect()));
521        self
522    }
523
524    /// Conditionally transform the builder (e.g. optional vision tower).
525    pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
526        if cond { f(self) } else { self }
527    }
528
529    /// Tier-2 custom subgraph — prefer promoting repeated patterns to blocks.
530    pub fn custom<F>(mut self, f: F) -> Self
531    where
532        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
533            + Send
534            + Sync
535            + 'static,
536    {
537        self.stages.push(FlowStage::Custom(CustomStage::new(f)));
538        self
539    }
540
541    /// Named custom subgraph (shows up in fusion / inspect dumps).
542    pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
543    where
544        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
545            + Send
546            + Sync
547            + 'static,
548    {
549        self.stages
550            .push(FlowStage::Custom(CustomStage::named(name, f)));
551        self
552    }
553
554    /// Patch the builder after preset assembly (arch recipes, Llama32Flow hooks).
555    pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
556        f(self)
557    }
558}