rlx_flow/blocks/
bind_decode.rs1use anyhow::Result;
19use rlx_ir::hir::{HirModule, HirNodeId, HirOp};
20
21use crate::context::{DecodeBindings, FlowCtx};
/// Flow stage that binds the graph inputs needed by a decode step.
///
/// Running the stage (see `emit`) looks up the RoPE tables, the optional
/// attention mask and the per-layer KV-cache inputs, then records them as
/// `DecodeBindings` on the flow state.
#[derive(Clone, Debug)]
pub struct BindDecodeInputsStage {
    /// Number of layers; one `past_k_{i}` and one `past_v_{i}` input is
    /// required for every `i` in `0..num_layers`.
    pub num_layers: usize,
    /// When `true`, a graph input named `mask` must exist and is bound;
    /// when `false`, no mask is bound.
    pub use_custom_mask: bool,
}
27
28impl BindDecodeInputsStage {
29 pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
30 let cos = ctx
31 .state
32 .rope_cos
33 .or_else(|| find_input(ctx.hir(), "rope_cos").ok())
34 .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_cos"))?;
35 let sin = ctx
36 .state
37 .rope_sin
38 .or_else(|| find_input(ctx.hir(), "rope_sin").ok())
39 .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_sin"))?;
40 let mask = if self.use_custom_mask {
41 Some(find_input(ctx.hir(), "mask")?)
42 } else {
43 None
44 };
45 let mut past_k = Vec::with_capacity(self.num_layers);
46 let mut past_v = Vec::with_capacity(self.num_layers);
47 for i in 0..self.num_layers {
48 past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
49 past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
50 }
51 ctx.state.decode = Some(DecodeBindings {
52 cos,
53 sin,
54 mask,
55 past_k,
56 past_v,
57 });
58 Ok(())
59 }
60}
61
62fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
63 for node in hir.nodes() {
64 if let HirOp::Input { name: n } = &node.op {
65 if n == name {
66 return Ok(node.id);
67 }
68 }
69 }
70 Err(anyhow::anyhow!("decode flow missing input: {name}"))
71}