rlx_flow/blocks/
bind_decode.rs1use anyhow::Result;
7use rlx_ir::hir::{HirModule, HirNodeId, HirOp};
8
9use crate::context::{DecodeBindings, FlowCtx};
10#[derive(Debug, Clone)]
11pub struct BindDecodeInputsStage {
12 pub num_layers: usize,
13 pub use_custom_mask: bool,
14}
15
16impl BindDecodeInputsStage {
17 pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
18 let cos = ctx
19 .state
20 .rope_cos
21 .or_else(|| find_input(ctx.hir(), "rope_cos").ok())
22 .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_cos"))?;
23 let sin = ctx
24 .state
25 .rope_sin
26 .or_else(|| find_input(ctx.hir(), "rope_sin").ok())
27 .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_sin"))?;
28 let mask = if self.use_custom_mask {
29 Some(find_input(ctx.hir(), "mask")?)
30 } else {
31 None
32 };
33 let mut past_k = Vec::with_capacity(self.num_layers);
34 let mut past_v = Vec::with_capacity(self.num_layers);
35 for i in 0..self.num_layers {
36 past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
37 past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
38 }
39 ctx.state.decode = Some(DecodeBindings {
40 cos,
41 sin,
42 mask,
43 past_k,
44 past_v,
45 });
46 Ok(())
47 }
48}
49
50fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
51 for node in hir.nodes() {
52 if let HirOp::Input { name: n } = &node.op {
53 if n == name {
54 return Ok(node.id);
55 }
56 }
57 }
58 Err(anyhow::anyhow!("decode flow missing input: {name}"))
59}