// rlx_flow/blocks/bind_decode.rs
use anyhow::Result;
use rlx_ir::hir::{HirModule, HirNodeId, HirOp};

use crate::context::{DecodeBindings, FlowCtx};

/// Flow stage that locates the HIR inputs a decode step reads (`rope_cos`,
/// `rope_sin`, an optional `mask`, and optional per-layer `past_k_{i}` /
/// `past_v_{i}`) and records their ids as [`DecodeBindings`] on the flow state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindDecodeInputsStage {
    /// Number of layers whose `past_k_{i}` / `past_v_{i}` inputs are bound
    /// (`i` ranges over `0..num_layers`); only consulted when `need_past_kv`
    /// is set.
    pub num_layers: usize,
    /// When set, a HIR input named `mask` must exist and is bound; when clear,
    /// no mask is bound.
    pub use_custom_mask: bool,
    /// When set, every per-layer `past_k_{i}` / `past_v_{i}` input must exist
    /// and is bound; when clear, both past lists are left empty.
    pub need_past_kv: bool,
}

29impl BindDecodeInputsStage {
30 pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
31 let cos = ctx
32 .state
33 .rope_cos
34 .or_else(|| find_input(ctx.hir(), "rope_cos").ok())
35 .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_cos"))?;
36 let sin = ctx
37 .state
38 .rope_sin
39 .or_else(|| find_input(ctx.hir(), "rope_sin").ok())
40 .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_sin"))?;
41 let mask = if self.use_custom_mask {
42 Some(find_input(ctx.hir(), "mask")?)
43 } else {
44 None
45 };
46 let mut past_k = Vec::with_capacity(self.num_layers);
47 let mut past_v = Vec::with_capacity(self.num_layers);
48 if self.need_past_kv {
49 for i in 0..self.num_layers {
50 past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
51 past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
52 }
53 }
54 ctx.state.decode = Some(DecodeBindings {
55 cos,
56 sin,
57 mask,
58 past_k,
59 past_v,
60 });
61 Ok(())
62 }
63}
64
65fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
66 for node in hir.nodes() {
67 if let HirOp::Input { name: n } = &node.op {
68 if n == name {
69 return Ok(node.id);
70 }
71 }
72 }
73 Err(anyhow::anyhow!("decode flow missing input: {name}"))
74}