Skip to main content

rlx_flow/blocks/
bind_decode.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Bind decode graph inputs (RoPE slice, past K/V, optional mask).
17
18use anyhow::Result;
19use rlx_ir::hir::{HirModule, HirNodeId, HirOp};
20
21use crate::context::{DecodeBindings, FlowCtx};
/// Flow stage that binds the decode graph's named HIR inputs (RoPE cos/sin,
/// optional attention mask, and per-layer past K/V) into `ctx.state.decode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindDecodeInputsStage {
    /// Number of layers; one `past_k_{i}` / `past_v_{i}` input pair is looked
    /// up for each `i` in `0..num_layers`.
    pub num_layers: usize,
    /// When `true`, an HIR input named `mask` is required and bound; when
    /// `false`, no mask is bound.
    pub use_custom_mask: bool,
}
27
28impl BindDecodeInputsStage {
29    pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
30        let cos = ctx
31            .state
32            .rope_cos
33            .or_else(|| find_input(ctx.hir(), "rope_cos").ok())
34            .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_cos"))?;
35        let sin = ctx
36            .state
37            .rope_sin
38            .or_else(|| find_input(ctx.hir(), "rope_sin").ok())
39            .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_sin"))?;
40        let mask = if self.use_custom_mask {
41            Some(find_input(ctx.hir(), "mask")?)
42        } else {
43            None
44        };
45        let mut past_k = Vec::with_capacity(self.num_layers);
46        let mut past_v = Vec::with_capacity(self.num_layers);
47        for i in 0..self.num_layers {
48            past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
49            past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
50        }
51        ctx.state.decode = Some(DecodeBindings {
52            cos,
53            sin,
54            mask,
55            past_k,
56            past_v,
57        });
58        Ok(())
59    }
60}
61
62fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
63    for node in hir.nodes() {
64        if let HirOp::Input { name: n } = &node.op {
65            if n == name {
66                return Ok(node.id);
67            }
68        }
69    }
70    Err(anyhow::anyhow!("decode flow missing input: {name}"))
71}