1use std::path::Path;
19use std::sync::Arc;
20
21use crate::blocks::{
22 AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
23 BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
24 GatherFromInputStage, GatherLastTokenStage, GeluFfnStage, LayerNormStage, LinearStage,
25 LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
26 NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage, ResidualAddStage,
27 ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec, SelfAttnPrefillStage,
28 SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed, llama_prefill_layer_fused,
29 nomic_vision_layer_fused,
30};
31use crate::escape::Emit;
32use crate::flow::ModelFlow;
33use crate::layer::LayerStack;
34use crate::profile::CompileProfile;
35use crate::side::SideOutputs;
36use crate::stage::FlowStage;
37use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
38use crate::value::FlowValue;
39
40impl ModelFlow {
41 pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
43 self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
44 self
45 }
46
47 pub fn profile_encoder(mut self) -> Self {
49 self.profile = CompileProfile::encoder();
50 self
51 }
52
53 pub fn gather_from_input(
55 mut self,
56 input_name: impl Into<String>,
57 weight_key: impl Into<String>,
58 ) -> Self {
59 self.stages
60 .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
61 input_name, weight_key, 0,
62 )));
63 self
64 }
65
66 pub fn gather_add(
68 mut self,
69 input_name: impl Into<String>,
70 weight_key: impl Into<String>,
71 ) -> Self {
72 self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
73 input_name, weight_key, 0,
74 )));
75 self
76 }
77
78 pub fn layer_norm(
80 mut self,
81 gamma_key: impl Into<String>,
82 beta_key: impl Into<String>,
83 eps: f32,
84 ) -> Self {
85 self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
86 gamma_key, beta_key, eps,
87 )));
88 self
89 }
90
91 pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
93 self.stages
94 .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
95 self
96 }
97
98 pub fn repeat_nomic_layers(
100 self,
101 count: usize,
102 hidden_size: usize,
103 num_heads: usize,
104 head_dim: usize,
105 eps: f32,
106 ) -> Self {
107 self.repeat_layers(count, move |i| FlowStage::Named {
108 name: format!("layer{i}"),
109 inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
110 NomicEncoderLayerSpec::hf(
111 format!("encoder.layers.{i}"),
112 hidden_size,
113 num_heads,
114 head_dim,
115 eps,
116 ),
117 ))),
118 })
119 }
120
121 pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
123 self.stages
124 .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
125 spec,
126 )));
127 self
128 }
129
130 pub fn repeat_bert_layers(
132 self,
133 count: usize,
134 prefix: impl Into<String>,
135 qkv_style: BertQkvStyle,
136 hidden_size: usize,
137 num_heads: usize,
138 eps: f32,
139 ) -> Self {
140 let prefix = prefix.into();
141 self.repeat_layers(count, move |i| {
142 let lp = if prefix.is_empty() {
143 format!("encoder.layer.{i}")
144 } else {
145 format!("{prefix}.encoder.layer.{i}")
146 };
147 FlowStage::Named {
148 name: format!("layer{i}"),
149 inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
150 BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
151 lp,
152 qkv_style,
153 hidden_size,
154 num_heads,
155 eps,
156 )),
157 )),
158 }
159 })
160 }
161
162 pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
164 self.stages
165 .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
166 self
167 }
168
169 pub fn repeat_dinov2_layers(
171 self,
172 count: usize,
173 hidden_size: usize,
174 num_heads: usize,
175 eps: f32,
176 ) -> Self {
177 self.repeat_layers(count, move |i| {
178 dinov2_layer_fused(i, hidden_size, num_heads, eps)
179 })
180 }
181
182 pub fn repeat_vision_layers(
184 self,
185 count: usize,
186 hidden_size: usize,
187 num_heads: usize,
188 eps: f32,
189 ) -> Self {
190 self.repeat_layers(count, move |i| {
191 nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
192 })
193 }
194
195 pub fn repeat_siglip_layers(
197 self,
198 count: usize,
199 hidden_size: usize,
200 num_heads: usize,
201 eps: f32,
202 ) -> Self {
203 self.repeat_layers(count, move |i| {
204 nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
205 })
206 }
207
208 pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
210 self.stages
211 .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
212 batch, hidden,
213 )));
214 self
215 }
216
217 pub fn profile_prefill(mut self) -> Self {
219 self.profile = CompileProfile::llama32_prefill();
220 self
221 }
222
223 pub fn profile_decode(mut self) -> Self {
225 self.profile = CompileProfile::llama32_decode();
226 self
227 }
228
229 pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
231 self.stages
232 .push(FlowStage::Embed(EmbedStage::token(weight_key)));
233 self
234 }
235
236 pub fn token_embed(self) -> Self {
238 self.embed("model.embed_tokens.weight")
239 }
240
241 pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
243 self.stages.push(FlowStage::RopeTables(tables));
244 self
245 }
246
247 pub fn zero_beta(self, len: usize) -> Self {
249 self.zero_beta_named("zero_beta", len)
250 }
251
252 pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
253 self.stages.push(FlowStage::ZeroBeta {
254 name: name.into(),
255 len,
256 });
257 self
258 }
259
260 pub fn bind_decode_inputs(mut self, num_layers: usize, custom_mask: bool) -> Self {
262 self.stages
263 .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
264 num_layers,
265 use_custom_mask: custom_mask,
266 }));
267 self
268 }
269
270 pub fn repeat_layers(
272 mut self,
273 count: usize,
274 stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
275 ) -> Self {
276 self.stages
277 .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
278 self
279 }
280
281 pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
283 self.stages.push(FlowStage::Named {
284 name: name.into(),
285 inner: Arc::new(inner),
286 });
287 self
288 }
289
290 pub fn layer(
292 self,
293 name: impl Into<String>,
294 build: impl FnOnce(LayerStack) -> LayerStack,
295 ) -> Self {
296 self.raw_stage(build(LayerStack::named(name)).build())
297 }
298
299 pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
301 self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
302 }
303
304 pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
306 self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
307 }
308
309 pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
310 self.stages
311 .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
312 self
313 }
314
315 pub fn residual_save(mut self) -> Self {
316 self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
317 self
318 }
319
320 pub fn residual_add(mut self) -> Self {
321 self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
322 self
323 }
324
325 pub fn swiglu(
326 mut self,
327 gate_key: impl Into<String>,
328 up_key: impl Into<String>,
329 down_key: impl Into<String>,
330 ) -> Self {
331 self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
332 gate_key, up_key, down_key,
333 )));
334 self
335 }
336
337 pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
338 self.stages
339 .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
340 self
341 }
342
343 pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
344 self.stages
345 .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
346 self
347 }
348
349 pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
350 self.stages.push(FlowStage::GdnScan(stage));
351 self
352 }
353
354 pub fn store_stream(mut self, name: impl Into<String>) -> Self {
355 self.stages
356 .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
357 self
358 }
359
360 pub fn load_stream(mut self, name: impl Into<String>) -> Self {
361 self.stages
362 .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
363 self
364 }
365
366 pub fn bind_inputs_to_streams(
370 mut self,
371 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
372 ) -> Self {
373 let pairs: Vec<(String, String)> = pairs
374 .into_iter()
375 .map(|(input, stream)| (input.into(), stream.into()))
376 .collect();
377 self.stages.push(FlowStage::Custom(CustomStage::named(
378 "bind_inputs_to_streams",
379 move |emit, primary| {
380 let primary = primary.ok_or_else(|| {
381 anyhow::anyhow!("bind_inputs_to_streams requires primary input")
382 })?;
383 for (input_name, stream_name) in &pairs {
384 let value = emit.flow_input(input_name)?;
385 emit.state.streams.insert(stream_name.clone(), value);
386 }
387 Ok(Some(primary))
388 },
389 )));
390 self
391 }
392
393 pub fn dual_stream<F>(
394 mut self,
395 name: impl Into<String>,
396 stream_a: impl Into<String>,
397 stream_b: impl Into<String>,
398 f: F,
399 ) -> Self
400 where
401 F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
402 + Send
403 + Sync
404 + 'static,
405 {
406 self.stages.push(FlowStage::DualStream(DualStreamStage::new(
407 name, stream_a, stream_b, f,
408 )));
409 self
410 }
411
412 pub fn plugin<F>(mut self, f: F) -> Self
413 where
414 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
415 + Send
416 + Sync
417 + 'static,
418 {
419 self.stages.push(crate::plugin::plugin(f));
420 self
421 }
422
423 pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
424 where
425 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
426 + Send
427 + Sync
428 + 'static,
429 {
430 self.stages.push(crate::plugin::plugin_named(name, f));
431 self
432 }
433
434 pub fn hidden_states(self) -> Self {
436 self.output("hidden")
437 }
438
439 pub fn llama_decoder_layer(
441 self,
442 layer_idx: usize,
443 spec: crate::blocks::LlamaDecoderSpec,
444 ) -> Self {
445 self.named_layer(
446 format!("layer{layer_idx}"),
447 FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
448 )
449 }
450
451 pub fn llama_decode_layer(
453 self,
454 layer_idx: usize,
455 spec: crate::blocks::LlamaDecodeLayerSpec,
456 kv_out: SideOutputs,
457 ) -> Self {
458 self.named_layer(
459 format!("layer{layer_idx}"),
460 FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
461 layer_idx,
462 spec,
463 kv_out.inner(),
464 )),
465 )
466 }
467
468 pub fn llama_kv_tap(
470 mut self,
471 layer_idx: usize,
472 head_dim: usize,
473 eps: f32,
474 sink: &SideOutputs,
475 ) -> Self {
476 self.stages
477 .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
478 layer_idx,
479 head_dim,
480 eps,
481 sink.inner(),
482 )));
483 self
484 }
485
486 pub fn final_norm(self, eps: f32) -> Self {
488 self.rms_norm("model.norm.weight", eps)
489 }
490
491 pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
492 self.stages
493 .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
494 self
495 }
496
497 pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
499 self.stages
500 .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
501 batch,
502 )));
503 self
504 }
505
506 pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
508 self.stages.push(FlowStage::GatherLastToken(
509 GatherLastTokenStage::static_last(batch, seq),
510 ));
511 self
512 }
513
514 pub fn lm_head(
516 mut self,
517 vocab_size: usize,
518 hidden_size: usize,
519 tie_word_embeddings: bool,
520 ) -> Self {
521 let stage = if tie_word_embeddings {
522 LmHeadStage::tied(vocab_size, hidden_size)
523 } else {
524 LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
525 };
526 self.stages.push(FlowStage::LmHead(stage));
527 self.output("logits")
528 }
529
530 pub fn raw_stage(mut self, stage: FlowStage) -> Self {
532 self.stages.push(stage);
533 self
534 }
535
536 pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
538 self.stages.extend(stages);
539 self
540 }
541
542 pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
544 self.stages
545 .push(FlowStage::Sequence(stages.into_iter().collect()));
546 self
547 }
548
549 pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
551 if cond { f(self) } else { self }
552 }
553
554 pub fn custom<F>(mut self, f: F) -> Self
556 where
557 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
558 + Send
559 + Sync
560 + 'static,
561 {
562 self.stages.push(FlowStage::Custom(CustomStage::new(f)));
563 self
564 }
565
566 pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
568 where
569 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
570 + Send
571 + Sync
572 + 'static,
573 {
574 self.stages
575 .push(FlowStage::Custom(CustomStage::named(name, f)));
576 self
577 }
578
579 pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
581 f(self)
582 }
583}