Skip to main content

rlx_flow/
dsl.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Fluent builder methods on [`ModelFlow`] — sugar over [`FlowStage`].

6use std::path::Path;
7use std::sync::Arc;
8
9use crate::blocks::{
10    AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
11    BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
12    GatherFromInputStage, GatherLastTokenStage, GeluFfnStage, LayerNormStage, LinearStage,
13    LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
14    NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage, ResidualAddStage,
15    ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec, SelfAttnPrefillStage,
16    SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed, llama_prefill_layer_fused,
17    nomic_vision_layer_fused,
18};
19use crate::escape::Emit;
20use crate::flow::ModelFlow;
21use crate::layer::LayerStack;
22use crate::profile::CompileProfile;
23use crate::side::SideOutputs;
24use crate::stage::FlowStage;
25use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
26use crate::value::FlowValue;
27
28impl ModelFlow {
29    /// Load tier-1 profile from a `*.rlx.toml` file (falls back to default on error).
30    pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
31        self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
32        self
33    }
34
35    /// Encoder / embedding model defaults (Direct lowering, no KV fusion).
36    pub fn profile_encoder(mut self) -> Self {
37        self.profile = CompileProfile::encoder();
38        self
39    }
40
41    /// Gather rows from a side input into the primary flow (starts embedding stack).
42    pub fn gather_from_input(
43        mut self,
44        input_name: impl Into<String>,
45        weight_key: impl Into<String>,
46    ) -> Self {
47        self.stages
48            .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
49                input_name, weight_key, 0,
50            )));
51        self
52    }
53
54    /// Add an embedding looked up from a side input.
55    pub fn gather_add(
56        mut self,
57        input_name: impl Into<String>,
58        weight_key: impl Into<String>,
59    ) -> Self {
60        self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
61            input_name, weight_key, 0,
62        )));
63        self
64    }
65
66    /// LayerNorm with separate gamma/beta weights.
67    pub fn layer_norm(
68        mut self,
69        gamma_key: impl Into<String>,
70        beta_key: impl Into<String>,
71        eps: f32,
72    ) -> Self {
73        self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
74            gamma_key, beta_key, eps,
75        )));
76        self
77    }
78
79    /// BERT-style GELU FFN under a layer prefix.
80    pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
81        self.stages
82            .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
83        self
84    }
85
86    /// Repeat NomicBERT encoder layers.
87    pub fn repeat_nomic_layers(
88        self,
89        count: usize,
90        hidden_size: usize,
91        num_heads: usize,
92        head_dim: usize,
93        eps: f32,
94    ) -> Self {
95        self.repeat_layers(count, move |i| FlowStage::Named {
96            name: format!("layer{i}"),
97            inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
98                NomicEncoderLayerSpec::hf(
99                    format!("encoder.layers.{i}"),
100                    hidden_size,
101                    num_heads,
102                    head_dim,
103                    eps,
104                ),
105            ))),
106        })
107    }
108
109    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
110    pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
111        self.stages
112            .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
113                spec,
114            )));
115        self
116    }
117
118    /// Repeat BERT encoder layers with auto-named prefixes.
119    pub fn repeat_bert_layers(
120        self,
121        count: usize,
122        prefix: impl Into<String>,
123        qkv_style: BertQkvStyle,
124        hidden_size: usize,
125        num_heads: usize,
126        eps: f32,
127    ) -> Self {
128        let prefix = prefix.into();
129        self.repeat_layers(count, move |i| {
130            let lp = if prefix.is_empty() {
131                format!("encoder.layer.{i}")
132            } else {
133                format!("{prefix}.encoder.layer.{i}")
134            };
135            FlowStage::Named {
136                name: format!("layer{i}"),
137                inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
138                    BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
139                        lp,
140                        qkv_style,
141                        hidden_size,
142                        num_heads,
143                        eps,
144                    )),
145                )),
146            }
147        })
148    }
149
150    /// Synthesize an all-ones attention mask for vision encoders (no padding).
151    pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
152        self.stages
153            .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
154        self
155    }
156
157    /// Repeat DINOv2 ViT encoder blocks.
158    pub fn repeat_dinov2_layers(
159        self,
160        count: usize,
161        hidden_size: usize,
162        num_heads: usize,
163        eps: f32,
164    ) -> Self {
165        self.repeat_layers(count, move |i| {
166            dinov2_layer_fused(i, hidden_size, num_heads, eps)
167        })
168    }
169
170    /// Repeat NomicVision encoder blocks.
171    pub fn repeat_vision_layers(
172        self,
173        count: usize,
174        hidden_size: usize,
175        num_heads: usize,
176        eps: f32,
177    ) -> Self {
178        self.repeat_layers(count, move |i| {
179            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
180        })
181    }
182
183    /// Compatibility shim: repeat SigLIP-style vision layers.
184    pub fn repeat_siglip_layers(
185        self,
186        count: usize,
187        hidden_size: usize,
188        num_heads: usize,
189        eps: f32,
190    ) -> Self {
191        self.repeat_layers(count, move |i| {
192            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
193        })
194    }
195
196    /// Pool CLS token: `[batch, seq, hidden]` → `[batch, hidden]`.
197    pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
198        self.stages
199            .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
200                batch, hidden,
201            )));
202        self
203    }
204
205    /// Fusion-first prefill defaults.
206    pub fn profile_prefill(mut self) -> Self {
207        self.profile = CompileProfile::llama32_prefill();
208        self
209    }
210
211    /// Decode / KV-cache defaults (`Fusable` lowering).
212    pub fn profile_decode(mut self) -> Self {
213        self.profile = CompileProfile::llama32_decode();
214        self
215    }
216
217    /// Token embedding (`model.embed_tokens.weight` by default).
218    pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
219        self.stages
220            .push(FlowStage::Embed(EmbedStage::token(weight_key)));
221        self
222    }
223
224    /// HuggingFace-style token embedding table.
225    pub fn token_embed(self) -> Self {
226        self.embed("model.embed_tokens.weight")
227    }
228
229    /// Precomputed RoPE sin/cos tables stored as params.
230    pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
231        self.stages.push(FlowStage::RopeTables(tables));
232        self
233    }
234
235    /// Rank-1 zero vector for RMSNorm beta slots (LLaMA has no beta).
236    pub fn zero_beta(self, len: usize) -> Self {
237        self.zero_beta_named("zero_beta", len)
238    }
239
240    pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
241        self.stages.push(FlowStage::ZeroBeta {
242            name: name.into(),
243            len,
244        });
245        self
246    }
247
248    /// Bind decode inputs (call after declaring `rope_cos`, `past_k_*`, …).
249    pub fn bind_decode_inputs(mut self, num_layers: usize, custom_mask: bool) -> Self {
250        self.stages
251            .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
252                num_layers,
253                use_custom_mask: custom_mask,
254            }));
255        self
256    }
257
258    /// Repeat a per-layer stage `count` times (layer index passed to closure).
259    pub fn repeat_layers(
260        mut self,
261        count: usize,
262        stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
263    ) -> Self {
264        self.stages
265            .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
266        self
267    }
268
269    /// Named decoder layer (shows up in fusion / inspect dumps).
270    pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
271        self.stages.push(FlowStage::Named {
272            name: name.into(),
273            inner: Arc::new(inner),
274        });
275        self
276    }
277
278    /// Build a named layer from a [`LayerStack`] closure.
279    pub fn layer(
280        self,
281        name: impl Into<String>,
282        build: impl FnOnce(LayerStack) -> LayerStack,
283    ) -> Self {
284        self.raw_stage(build(LayerStack::named(name)).build())
285    }
286
287    /// Fused LLaMA prefill layer (default fast path).
288    pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
289        self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
290    }
291
292    /// Composed LLaMA prefill layer (small blocks — customize via [`LayerStack`]).
293    pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
294        self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
295    }
296
297    pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
298        self.stages
299            .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
300        self
301    }
302
303    pub fn residual_save(mut self) -> Self {
304        self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
305        self
306    }
307
308    pub fn residual_add(mut self) -> Self {
309        self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
310        self
311    }
312
313    pub fn swiglu(
314        mut self,
315        gate_key: impl Into<String>,
316        up_key: impl Into<String>,
317        down_key: impl Into<String>,
318    ) -> Self {
319        self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
320            gate_key, up_key, down_key,
321        )));
322        self
323    }
324
325    pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
326        self.stages
327            .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
328        self
329    }
330
331    pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
332        self.stages
333            .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
334        self
335    }
336
337    pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
338        self.stages.push(FlowStage::GdnScan(stage));
339        self
340    }
341
342    pub fn store_stream(mut self, name: impl Into<String>) -> Self {
343        self.stages
344            .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
345        self
346    }
347
348    pub fn load_stream(mut self, name: impl Into<String>) -> Self {
349        self.stages
350            .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
351        self
352    }
353
354    /// Bind declared graph inputs into named streams (multi-input models).
355    ///
356    /// Example: FLUX `.bind_inputs_to_streams(&[("hidden", "img"), ("encoder", "txt")])`.
357    pub fn bind_inputs_to_streams(
358        mut self,
359        pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
360    ) -> Self {
361        let pairs: Vec<(String, String)> = pairs
362            .into_iter()
363            .map(|(input, stream)| (input.into(), stream.into()))
364            .collect();
365        self.stages.push(FlowStage::Custom(CustomStage::named(
366            "bind_inputs_to_streams",
367            move |emit, primary| {
368                let primary = primary.ok_or_else(|| {
369                    anyhow::anyhow!("bind_inputs_to_streams requires primary input")
370                })?;
371                for (input_name, stream_name) in &pairs {
372                    let value = emit.flow_input(input_name)?;
373                    emit.state.streams.insert(stream_name.clone(), value);
374                }
375                Ok(Some(primary))
376            },
377        )));
378        self
379    }
380
381    pub fn dual_stream<F>(
382        mut self,
383        name: impl Into<String>,
384        stream_a: impl Into<String>,
385        stream_b: impl Into<String>,
386        f: F,
387    ) -> Self
388    where
389        F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
390            + Send
391            + Sync
392            + 'static,
393    {
394        self.stages.push(FlowStage::DualStream(DualStreamStage::new(
395            name, stream_a, stream_b, f,
396        )));
397        self
398    }
399
400    pub fn plugin<F>(mut self, f: F) -> Self
401    where
402        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
403            + Send
404            + Sync
405            + 'static,
406    {
407        self.stages.push(crate::plugin::plugin(f));
408        self
409    }
410
411    pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
412    where
413        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
414            + Send
415            + Sync
416            + 'static,
417    {
418        self.stages.push(crate::plugin::plugin_named(name, f));
419        self
420    }
421
422    /// Hidden states output (no LM head).
423    pub fn hidden_states(self) -> Self {
424        self.output("hidden")
425    }
426
427    /// LLaMA prefill decoder block at `layer_idx`.
428    pub fn llama_decoder_layer(
429        self,
430        layer_idx: usize,
431        spec: crate::blocks::LlamaDecoderSpec,
432    ) -> Self {
433        self.named_layer(
434            format!("layer{layer_idx}"),
435            FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
436        )
437    }
438
439    /// LLaMA decode block with KV-cache concat.
440    pub fn llama_decode_layer(
441        self,
442        layer_idx: usize,
443        spec: crate::blocks::LlamaDecodeLayerSpec,
444        kv_out: SideOutputs,
445    ) -> Self {
446        self.named_layer(
447            format!("layer{layer_idx}"),
448            FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
449                layer_idx,
450                spec,
451                kv_out.inner(),
452            )),
453        )
454    }
455
456    /// Side-effect K/V tap before a prefill layer (exports cache tensors).
457    pub fn llama_kv_tap(
458        mut self,
459        layer_idx: usize,
460        head_dim: usize,
461        eps: f32,
462        sink: &SideOutputs,
463    ) -> Self {
464        self.stages
465            .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
466                layer_idx,
467                head_dim,
468                eps,
469                sink.inner(),
470            )));
471        self
472    }
473
474    /// Final RMSNorm before LM head (`model.norm.weight` by default).
475    pub fn final_norm(self, eps: f32) -> Self {
476        self.rms_norm("model.norm.weight", eps)
477    }
478
479    pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
480        self.stages
481            .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
482        self
483    }
484
485    /// Gather last token (dynamic `last_token_idx` input).
486    pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
487        self.stages
488            .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
489                batch,
490            )));
491        self
492    }
493
494    /// Gather last token at fixed sequence length.
495    pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
496        self.stages.push(FlowStage::GatherLastToken(
497            GatherLastTokenStage::static_last(batch, seq),
498        ));
499        self
500    }
501
502    /// Causal LM head — tied or separate weights.
503    pub fn lm_head(
504        mut self,
505        vocab_size: usize,
506        hidden_size: usize,
507        tie_word_embeddings: bool,
508    ) -> Self {
509        let stage = if tie_word_embeddings {
510            LmHeadStage::tied(vocab_size, hidden_size)
511        } else {
512            LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
513        };
514        self.stages.push(FlowStage::LmHead(stage));
515        self.output("logits")
516    }
517
518    /// Tier-2 escape hatch — append a raw stage.
519    pub fn raw_stage(mut self, stage: FlowStage) -> Self {
520        self.stages.push(stage);
521        self
522    }
523
524    /// Append multiple raw stages in order.
525    pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
526        self.stages.extend(stages);
527        self
528    }
529
530    /// Run a list of stages as one nested sequence (side-effect stages allowed).
531    pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
532        self.stages
533            .push(FlowStage::Sequence(stages.into_iter().collect()));
534        self
535    }
536
537    /// Conditionally transform the builder (e.g. optional vision tower).
538    pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
539        if cond { f(self) } else { self }
540    }
541
542    /// Tier-2 custom subgraph — prefer promoting repeated patterns to blocks.
543    pub fn custom<F>(mut self, f: F) -> Self
544    where
545        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
546            + Send
547            + Sync
548            + 'static,
549    {
550        self.stages.push(FlowStage::Custom(CustomStage::new(f)));
551        self
552    }
553
554    /// Named custom subgraph (shows up in fusion / inspect dumps).
555    pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
556    where
557        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
558            + Send
559            + Sync
560            + 'static,
561    {
562        self.stages
563            .push(FlowStage::Custom(CustomStage::named(name, f)));
564        self
565    }
566
567    /// Patch the builder after preset assembly (arch recipes, Llama32Flow hooks).
568    pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
569        f(self)
570    }
571}