1use std::path::Path;
7use std::sync::Arc;
8
9use crate::blocks::{
10 AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
11 BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
12 GatherFromInputStage, GatherLastTokenStage, GeluFfnStage, LayerNormStage, LinearStage,
13 LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
14 NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage, ResidualAddStage,
15 ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec, SelfAttnPrefillStage,
16 SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed, llama_prefill_layer_fused,
17 nomic_vision_layer_fused,
18};
19use crate::escape::Emit;
20use crate::flow::ModelFlow;
21use crate::layer::LayerStack;
22use crate::profile::CompileProfile;
23use crate::side::SideOutputs;
24use crate::stage::FlowStage;
25use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
26use crate::value::FlowValue;
27
28impl ModelFlow {
29 pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
31 self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
32 self
33 }
34
35 pub fn profile_encoder(mut self) -> Self {
37 self.profile = CompileProfile::encoder();
38 self
39 }
40
41 pub fn gather_from_input(
43 mut self,
44 input_name: impl Into<String>,
45 weight_key: impl Into<String>,
46 ) -> Self {
47 self.stages
48 .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
49 input_name, weight_key, 0,
50 )));
51 self
52 }
53
54 pub fn gather_add(
56 mut self,
57 input_name: impl Into<String>,
58 weight_key: impl Into<String>,
59 ) -> Self {
60 self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
61 input_name, weight_key, 0,
62 )));
63 self
64 }
65
66 pub fn layer_norm(
68 mut self,
69 gamma_key: impl Into<String>,
70 beta_key: impl Into<String>,
71 eps: f32,
72 ) -> Self {
73 self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
74 gamma_key, beta_key, eps,
75 )));
76 self
77 }
78
79 pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
81 self.stages
82 .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
83 self
84 }
85
86 pub fn repeat_nomic_layers(
88 self,
89 count: usize,
90 hidden_size: usize,
91 num_heads: usize,
92 head_dim: usize,
93 eps: f32,
94 ) -> Self {
95 self.repeat_layers(count, move |i| FlowStage::Named {
96 name: format!("layer{i}"),
97 inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
98 NomicEncoderLayerSpec::hf(
99 format!("encoder.layers.{i}"),
100 hidden_size,
101 num_heads,
102 head_dim,
103 eps,
104 ),
105 ))),
106 })
107 }
108
109 pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
111 self.stages
112 .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
113 spec,
114 )));
115 self
116 }
117
118 pub fn repeat_bert_layers(
120 self,
121 count: usize,
122 prefix: impl Into<String>,
123 qkv_style: BertQkvStyle,
124 hidden_size: usize,
125 num_heads: usize,
126 eps: f32,
127 ) -> Self {
128 let prefix = prefix.into();
129 self.repeat_layers(count, move |i| {
130 let lp = if prefix.is_empty() {
131 format!("encoder.layer.{i}")
132 } else {
133 format!("{prefix}.encoder.layer.{i}")
134 };
135 FlowStage::Named {
136 name: format!("layer{i}"),
137 inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
138 BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
139 lp,
140 qkv_style,
141 hidden_size,
142 num_heads,
143 eps,
144 )),
145 )),
146 }
147 })
148 }
149
150 pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
152 self.stages
153 .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
154 self
155 }
156
157 pub fn repeat_dinov2_layers(
159 self,
160 count: usize,
161 hidden_size: usize,
162 num_heads: usize,
163 eps: f32,
164 ) -> Self {
165 self.repeat_layers(count, move |i| {
166 dinov2_layer_fused(i, hidden_size, num_heads, eps)
167 })
168 }
169
170 pub fn repeat_vision_layers(
172 self,
173 count: usize,
174 hidden_size: usize,
175 num_heads: usize,
176 eps: f32,
177 ) -> Self {
178 self.repeat_layers(count, move |i| {
179 nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
180 })
181 }
182
183 pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
185 self.stages
186 .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
187 batch, hidden,
188 )));
189 self
190 }
191
192 pub fn profile_prefill(mut self) -> Self {
194 self.profile = CompileProfile::llama32_prefill();
195 self
196 }
197
198 pub fn profile_decode(mut self) -> Self {
200 self.profile = CompileProfile::llama32_decode();
201 self
202 }
203
204 pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
206 self.stages
207 .push(FlowStage::Embed(EmbedStage::token(weight_key)));
208 self
209 }
210
211 pub fn token_embed(self) -> Self {
213 self.embed("model.embed_tokens.weight")
214 }
215
216 pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
218 self.stages.push(FlowStage::RopeTables(tables));
219 self
220 }
221
222 pub fn zero_beta(self, len: usize) -> Self {
224 self.zero_beta_named("zero_beta", len)
225 }
226
227 pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
228 self.stages.push(FlowStage::ZeroBeta {
229 name: name.into(),
230 len,
231 });
232 self
233 }
234
235 pub fn bind_decode_inputs(mut self, num_layers: usize, custom_mask: bool) -> Self {
237 self.stages
238 .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
239 num_layers,
240 use_custom_mask: custom_mask,
241 }));
242 self
243 }
244
245 pub fn repeat_layers(
247 mut self,
248 count: usize,
249 stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
250 ) -> Self {
251 self.stages
252 .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
253 self
254 }
255
256 pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
258 self.stages.push(FlowStage::Named {
259 name: name.into(),
260 inner: Arc::new(inner),
261 });
262 self
263 }
264
265 pub fn layer(
267 self,
268 name: impl Into<String>,
269 build: impl FnOnce(LayerStack) -> LayerStack,
270 ) -> Self {
271 self.raw_stage(build(LayerStack::named(name)).build())
272 }
273
274 pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
276 self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
277 }
278
279 pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
281 self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
282 }
283
284 pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
285 self.stages
286 .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
287 self
288 }
289
290 pub fn residual_save(mut self) -> Self {
291 self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
292 self
293 }
294
295 pub fn residual_add(mut self) -> Self {
296 self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
297 self
298 }
299
300 pub fn swiglu(
301 mut self,
302 gate_key: impl Into<String>,
303 up_key: impl Into<String>,
304 down_key: impl Into<String>,
305 ) -> Self {
306 self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
307 gate_key, up_key, down_key,
308 )));
309 self
310 }
311
312 pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
313 self.stages
314 .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
315 self
316 }
317
318 pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
319 self.stages
320 .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
321 self
322 }
323
324 pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
325 self.stages.push(FlowStage::GdnScan(stage));
326 self
327 }
328
329 pub fn store_stream(mut self, name: impl Into<String>) -> Self {
330 self.stages
331 .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
332 self
333 }
334
335 pub fn load_stream(mut self, name: impl Into<String>) -> Self {
336 self.stages
337 .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
338 self
339 }
340
341 pub fn bind_inputs_to_streams(
345 mut self,
346 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
347 ) -> Self {
348 let pairs: Vec<(String, String)> = pairs
349 .into_iter()
350 .map(|(input, stream)| (input.into(), stream.into()))
351 .collect();
352 self.stages.push(FlowStage::Custom(CustomStage::named(
353 "bind_inputs_to_streams",
354 move |emit, primary| {
355 let primary = primary.ok_or_else(|| {
356 anyhow::anyhow!("bind_inputs_to_streams requires primary input")
357 })?;
358 for (input_name, stream_name) in &pairs {
359 let value = emit.flow_input(input_name)?;
360 emit.state.streams.insert(stream_name.clone(), value);
361 }
362 Ok(Some(primary))
363 },
364 )));
365 self
366 }
367
368 pub fn dual_stream<F>(
369 mut self,
370 name: impl Into<String>,
371 stream_a: impl Into<String>,
372 stream_b: impl Into<String>,
373 f: F,
374 ) -> Self
375 where
376 F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
377 + Send
378 + Sync
379 + 'static,
380 {
381 self.stages.push(FlowStage::DualStream(DualStreamStage::new(
382 name, stream_a, stream_b, f,
383 )));
384 self
385 }
386
387 pub fn plugin<F>(mut self, f: F) -> Self
388 where
389 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
390 + Send
391 + Sync
392 + 'static,
393 {
394 self.stages.push(crate::plugin::plugin(f));
395 self
396 }
397
398 pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
399 where
400 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
401 + Send
402 + Sync
403 + 'static,
404 {
405 self.stages.push(crate::plugin::plugin_named(name, f));
406 self
407 }
408
409 pub fn hidden_states(self) -> Self {
411 self.output("hidden")
412 }
413
414 pub fn llama_decoder_layer(
416 self,
417 layer_idx: usize,
418 spec: crate::blocks::LlamaDecoderSpec,
419 ) -> Self {
420 self.named_layer(
421 format!("layer{layer_idx}"),
422 FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
423 )
424 }
425
426 pub fn llama_decode_layer(
428 self,
429 layer_idx: usize,
430 spec: crate::blocks::LlamaDecodeLayerSpec,
431 kv_out: SideOutputs,
432 ) -> Self {
433 self.named_layer(
434 format!("layer{layer_idx}"),
435 FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
436 layer_idx,
437 spec,
438 kv_out.inner(),
439 )),
440 )
441 }
442
443 pub fn llama_kv_tap(
445 mut self,
446 layer_idx: usize,
447 head_dim: usize,
448 eps: f32,
449 sink: &SideOutputs,
450 ) -> Self {
451 self.stages
452 .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
453 layer_idx,
454 head_dim,
455 eps,
456 sink.inner(),
457 )));
458 self
459 }
460
461 pub fn final_norm(self, eps: f32) -> Self {
463 self.rms_norm("model.norm.weight", eps)
464 }
465
466 pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
467 self.stages
468 .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
469 self
470 }
471
472 pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
474 self.stages
475 .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
476 batch,
477 )));
478 self
479 }
480
481 pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
483 self.stages.push(FlowStage::GatherLastToken(
484 GatherLastTokenStage::static_last(batch, seq),
485 ));
486 self
487 }
488
489 pub fn lm_head(
491 mut self,
492 vocab_size: usize,
493 hidden_size: usize,
494 tie_word_embeddings: bool,
495 ) -> Self {
496 let stage = if tie_word_embeddings {
497 LmHeadStage::tied(vocab_size, hidden_size)
498 } else {
499 LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
500 };
501 self.stages.push(FlowStage::LmHead(stage));
502 self.output("logits")
503 }
504
505 pub fn raw_stage(mut self, stage: FlowStage) -> Self {
507 self.stages.push(stage);
508 self
509 }
510
511 pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
513 self.stages.extend(stages);
514 self
515 }
516
517 pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
519 self.stages
520 .push(FlowStage::Sequence(stages.into_iter().collect()));
521 self
522 }
523
524 pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
526 if cond { f(self) } else { self }
527 }
528
529 pub fn custom<F>(mut self, f: F) -> Self
531 where
532 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
533 + Send
534 + Sync
535 + 'static,
536 {
537 self.stages.push(FlowStage::Custom(CustomStage::new(f)));
538 self
539 }
540
541 pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
543 where
544 F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
545 + Send
546 + Sync
547 + 'static,
548 {
549 self.stages
550 .push(FlowStage::Custom(CustomStage::named(name, f)));
551 self
552 }
553
554 pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
556 f(self)
557 }
558}