rlx_flow/blocks/
bind_decode.rs1use anyhow::Result;
7use rlx_ir::hir::{HirModule, HirNodeId, HirOp};
8
9use crate::context::{DecodeBindings, FlowCtx};
/// Flow stage that resolves the decode-time graph inputs and records them as
/// `DecodeBindings` in the flow context (see `BindDecodeInputsStage::emit`).
#[derive(Clone, Debug)]
pub struct BindDecodeInputsStage {
    /// Number of layers. `emit` resolves one `past_k_{i}` / `past_v_{i}`
    /// input pair for each `i` in `0..num_layers`.
    pub num_layers: usize,
    /// When `true`, `emit` additionally requires an input named `mask`;
    /// when `false`, the mask binding is left as `None`.
    pub use_custom_mask: bool,
}
15
16impl BindDecodeInputsStage {
17 pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
18 let cos = find_input(ctx.hir(), "rope_cos")?;
19 let sin = find_input(ctx.hir(), "rope_sin")?;
20 let mask = if self.use_custom_mask {
21 Some(find_input(ctx.hir(), "mask")?)
22 } else {
23 None
24 };
25 let mut past_k = Vec::with_capacity(self.num_layers);
26 let mut past_v = Vec::with_capacity(self.num_layers);
27 for i in 0..self.num_layers {
28 past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
29 past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
30 }
31 ctx.state.decode = Some(DecodeBindings {
32 cos,
33 sin,
34 mask,
35 past_k,
36 past_v,
37 });
38 Ok(())
39 }
40}
41
42fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
43 for node in hir.nodes() {
44 if let HirOp::Input { name: n } = &node.op {
45 if n == name {
46 return Ok(node.id);
47 }
48 }
49 }
50 Err(anyhow::anyhow!("decode flow missing input: {name}"))
51}