1use std::path::Path;
7use std::sync::Arc;
8
9use crate::blocks::{
10 AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
11 BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
12 GatherFromInputStage, GatherLastTokenStage, GeluFfnStage, LayerNormStage, LinearStage,
13 LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
14 NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage, ResidualAddStage,
15 ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec, SelfAttnPrefillStage,
16 SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed, llama_prefill_layer_fused,
17 nomic_vision_layer_fused,
18};
19use crate::escape::Emit;
20use crate::flow::ModelFlow;
21use crate::layer::LayerStack;
22use crate::profile::CompileProfile;
23use crate::side::SideOutputs;
24use crate::stage::FlowStage;
25use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
26use crate::value::FlowValue;
27
28impl ModelFlow {
29 pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
31 self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
32 self
33 }
34
35 pub fn profile_encoder(mut self) -> Self {
37 self.profile = CompileProfile::encoder();
38 self
39 }
40
41 pub fn gather_from_input(
43 mut self,
44 input_name: impl Into<String>,
45 weight_key: impl Into<String>,
46 ) -> Self {
47 self.stages
48 .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
49 input_name, weight_key, 0,
50 )));
51 self
52 }
53
54 pub fn gather_add(
56 mut self,
57 input_name: impl Into<String>,
58 weight_key: impl Into<String>,
59 ) -> Self {
60 self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
61 input_name, weight_key, 0,
62 )));
63 self
64 }
65
66 pub fn layer_norm(
68 mut self,
69 gamma_key: impl Into<String>,
70 beta_key: impl Into<String>,
71 eps: f32,
72 ) -> Self {
73 self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
74 gamma_key, beta_key, eps,
75 )));
76 self
77 }
78
79 pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
81 self.stages
82 .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
83 self
84 }
85
86 pub fn repeat_nomic_layers(
88 self,
89 count: usize,
90 hidden_size: usize,
91 num_heads: usize,
92 head_dim: usize,
93 eps: f32,
94 ) -> Self {
95 self.repeat_layers(count, move |i| FlowStage::Named {
96 name: format!("layer{i}"),
97 inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
98 NomicEncoderLayerSpec::hf(
99 format!("encoder.layers.{i}"),
100 hidden_size,
101 num_heads,
102 head_dim,
103 eps,
104 ),
105 ))),
106 })
107 }
108
109 pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
111 self.stages
112 .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
113 spec,
114 )));
115 self
116 }
117
118 pub fn repeat_bert_layers(
120 self,
121 count: usize,
122 prefix: impl Into<String>,
123 qkv_style: BertQkvStyle,
124 hidden_size: usize,
125 num_heads: usize,
126 eps: f32,
127 ) -> Self {
128 let prefix = prefix.into();
129 self.repeat_layers(count, move |i| {
130 let lp = if prefix.is_empty() {
131 format!("encoder.layer.{i}")
132 } else {
133 format!("{prefix}.encoder.layer.{i}")
134 };
135 FlowStage::Named {
136 name: format!("layer{i}"),
137 inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
138 BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
139 lp,
140 qkv_style,
141 hidden_size,
142 num_heads,
143 eps,
144 )),
145 )),
146 }
147 })
148 }
149
150 pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
152 self.stages
153 .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
154 self
155 }
156
157 pub fn repeat_dinov2_layers(
159 self,
160 count: usize,
161 hidden_size: usize,
162 num_heads: usize,
163 eps: f32,
164 ) -> Self {
165 self.repeat_layers(count, move |i| {
166 dinov2_layer_fused(i, hidden_size, num_heads, eps)
167 })
168 }
169
170 pub fn repeat_vision_layers(
172 self,
173 count: usize,
174 hidden_size: usize,
175 num_heads: usize,
176 eps: f32,
177 ) -> Self {
178 self.repeat_layers(count, move |i| {
179 nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
180 })
181 }
182
183 pub fn repeat_siglip_layers(
185 self,
186 count: usize,
187 hidden_size: usize,
188 num_heads: usize,
189 eps: f32,
190 ) -> Self {
191 self.repeat_layers(count, move |i| {
192 nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
193 })
194 }
195
196 pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
198 self.stages
199 .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
200 batch, hidden,
201 )));
202 self
203 }
204
205 pub fn profile_prefill(mut self) -> Self {
207 self.profile = CompileProfile::llama32_prefill();
208 self
209 }
210
211 pub fn profile_decode(mut self) -> Self {
213 self.profile = CompileProfile::llama32_decode();
214 self
215 }
216
217 pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
219 self.stages
220 .push(FlowStage::Embed(EmbedStage::token(weight_key)));
221 self
222 }
223
224 pub fn token_embed(self) -> Self {
226 self.embed("model.embed_tokens.weight")
227 }
228
229 pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
231 self.stages.push(FlowStage::RopeTables(tables));
232 self
233 }
234
235 pub fn zero_beta(self, len: usize) -> Self {
237 self.zero_beta_named("zero_beta", len)
238 }
239
240 pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
241 self.stages.push(FlowStage::ZeroBeta {
242 name: name.into(),
243 len,
244 });
245 self
246 }
247
248 pub fn bind_decode_inputs(mut self, num_layers: usize, custom_mask: bool) -> Self {
250 self.stages
251 .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
252 num_layers,
253 use_custom_mask: custom_mask,
254 }));
255 self
256 }
257
258 pub fn repeat_layers(
260 mut self,
261 count: usize,
262 stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
263 ) -> Self {
264 self.stages
265 .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
266 self
267 }
268
269 pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
271 self.stages.push(FlowStage::Named {
272 name: name.into(),
273 inner: Arc::new(inner),
274 });
275 self
276 }
277
278 pub fn layer(
280 self,
281 name: impl Into<String>,
282 build: impl FnOnce(LayerStack) -> LayerStack,
283 ) -> Self {
284 self.raw_stage(build(LayerStack::named(name)).build())
285 }
286
287 pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
289 self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
290 }
291
292 pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
294 self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
295 }
296
297 pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
298 self.stages
299 .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
300 self
301 }
302
303 pub fn residual_save(mut self) -> Self {
304 self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
305 self
306 }
307
308 pub fn residual_add(mut self) -> Self {
309 self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
310 self
311 }
312
313 pub fn swiglu(
314 mut self,
315 gate_key: impl Into<String>,
316 up_key: impl Into<String>,
317 down_key: impl Into<String>,
318 ) -> Self {
319 self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
320 gate_key, up_key, down_key,
321 )));
322 self
323 }
324
325 pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
326 self.stages
327 .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
328 self
329 }
330
331 pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
332 self.stages
333 .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
334 self
335 }
336
337 pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
338 self.stages.push(FlowStage::GdnScan(stage));
339 self
340 }
341
342 pub fn store_stream(mut self, name: impl Into<String>) -> Self {
343 self.stages
344 .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
345 self
346 }
347
348 pub fn load_stream(mut self, name: impl Into<String>) -> Self {
349 self.stages
350 .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
351 self
352 }
353
354 pub fn bind_inputs_to_streams(
358 mut self,
359 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
360 ) -> Self {
361 let pairs: Vec<(String, String)> = pairs
362 .into_iter()
363 .map(|(input, stream)| (input.into(), stream.into()))
364 .collect();
365 self.stages.push(FlowStage::Custom(CustomStage::named(
366 "bind_inputs_to_streams",
367 move |emit, primary| {
368 let primary = primary.ok_or_else(|| {
369 anyhow::anyhow!("bind_inputs_to_streams requires primary input")
370 })?;
371 for (input_name, stream_name) in &pairs {
372 let value = emit.flow_input(input_name)?;
373 emit.state.streams.insert(stream_name.clone(), value);
374 }
375 Ok(Some(primary))
376 },
377 )));
378 self
379 }
380
381 pub fn dual_stream<F>(
382 mut self,
383 name: impl Into<String>,
384 stream_a: impl Into<String>,
385 stream_b: impl Into<String>,
386 f: F,
387 ) -> Self
388 where
389 F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
390 + Send
391 + Sync
392 + 'static,
393 {
394 self.stages.push(FlowStage::DualStream(DualStreamStage::new(
395 name, stream_a, stream_b, f,
396 )));
397 self
398 }
399
400 pub fn plugin<F>(mut self, f: F) -> Self
401 where
402 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
403 + Send
404 + Sync
405 + 'static,
406 {
407 self.stages.push(crate::plugin::plugin(f));
408 self
409 }
410
411 pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
412 where
413 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
414 + Send
415 + Sync
416 + 'static,
417 {
418 self.stages.push(crate::plugin::plugin_named(name, f));
419 self
420 }
421
422 pub fn hidden_states(self) -> Self {
424 self.output("hidden")
425 }
426
427 pub fn llama_decoder_layer(
429 self,
430 layer_idx: usize,
431 spec: crate::blocks::LlamaDecoderSpec,
432 ) -> Self {
433 self.named_layer(
434 format!("layer{layer_idx}"),
435 FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
436 )
437 }
438
439 pub fn llama_decode_layer(
441 self,
442 layer_idx: usize,
443 spec: crate::blocks::LlamaDecodeLayerSpec,
444 kv_out: SideOutputs,
445 ) -> Self {
446 self.named_layer(
447 format!("layer{layer_idx}"),
448 FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
449 layer_idx,
450 spec,
451 kv_out.inner(),
452 )),
453 )
454 }
455
456 pub fn llama_kv_tap(
458 mut self,
459 layer_idx: usize,
460 head_dim: usize,
461 eps: f32,
462 sink: &SideOutputs,
463 ) -> Self {
464 self.stages
465 .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
466 layer_idx,
467 head_dim,
468 eps,
469 sink.inner(),
470 )));
471 self
472 }
473
474 pub fn final_norm(self, eps: f32) -> Self {
476 self.rms_norm("model.norm.weight", eps)
477 }
478
479 pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
480 self.stages
481 .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
482 self
483 }
484
485 pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
487 self.stages
488 .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
489 batch,
490 )));
491 self
492 }
493
494 pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
496 self.stages.push(FlowStage::GatherLastToken(
497 GatherLastTokenStage::static_last(batch, seq),
498 ));
499 self
500 }
501
502 pub fn lm_head(
504 mut self,
505 vocab_size: usize,
506 hidden_size: usize,
507 tie_word_embeddings: bool,
508 ) -> Self {
509 let stage = if tie_word_embeddings {
510 LmHeadStage::tied(vocab_size, hidden_size)
511 } else {
512 LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
513 };
514 self.stages.push(FlowStage::LmHead(stage));
515 self.output("logits")
516 }
517
518 pub fn raw_stage(mut self, stage: FlowStage) -> Self {
520 self.stages.push(stage);
521 self
522 }
523
524 pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
526 self.stages.extend(stages);
527 self
528 }
529
530 pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
532 self.stages
533 .push(FlowStage::Sequence(stages.into_iter().collect()));
534 self
535 }
536
537 pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
539 if cond { f(self) } else { self }
540 }
541
542 pub fn custom<F>(mut self, f: F) -> Self
544 where
545 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
546 + Send
547 + Sync
548 + 'static,
549 {
550 self.stages.push(FlowStage::Custom(CustomStage::new(f)));
551 self
552 }
553
554 pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
556 where
557 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
558 + Send
559 + Sync
560 + 'static,
561 {
562 self.stages
563 .push(FlowStage::Custom(CustomStage::named(name, f)));
564 self
565 }
566
567 pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
569 f(self)
570 }
571}