// rlx_flow/blocks/bind_decode.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Bind decode graph inputs (RoPE slice, past K/V, optional mask).

18use anyhow::Result;
19use rlx_ir::hir::{HirModule, HirNodeId, HirOp};
20
21use crate::context::{DecodeBindings, FlowCtx};
/// Flow stage that resolves the decode graph's inputs (RoPE cos/sin, an optional
/// `mask`, and optional per-layer `past_k_{i}` / `past_v_{i}`) and records them in
/// the flow state as `DecodeBindings`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindDecodeInputsStage {
    /// Number of layers whose `past_k_{i}` / `past_v_{i}` inputs are looked up.
    /// The lookups happen only when `need_past_kv` is set.
    pub num_layers: usize,
    /// When set, the HIR input named `mask` is bound; otherwise no mask is bound.
    pub use_custom_mask: bool,
    /// When set, per-layer past K/V inputs are bound; otherwise both lists stay empty.
    pub need_past_kv: bool,
}

29impl BindDecodeInputsStage {
30    pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
31        let cos = ctx
32            .state
33            .rope_cos
34            .or_else(|| find_input(ctx.hir(), "rope_cos").ok())
35            .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_cos"))?;
36        let sin = ctx
37            .state
38            .rope_sin
39            .or_else(|| find_input(ctx.hir(), "rope_sin").ok())
40            .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_sin"))?;
41        let mask = if self.use_custom_mask {
42            Some(find_input(ctx.hir(), "mask")?)
43        } else {
44            None
45        };
46        let mut past_k = Vec::with_capacity(self.num_layers);
47        let mut past_v = Vec::with_capacity(self.num_layers);
48        if self.need_past_kv {
49            for i in 0..self.num_layers {
50                past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
51                past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
52            }
53        }
54        ctx.state.decode = Some(DecodeBindings {
55            cos,
56            sin,
57            mask,
58            past_k,
59            past_v,
60        });
61        Ok(())
62    }
63}
64
65fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
66    for node in hir.nodes() {
67        if let HirOp::Input { name: n } = &node.op {
68            if n == name {
69                return Ok(node.id);
70            }
71        }
72    }
73    Err(anyhow::anyhow!("decode flow missing input: {name}"))
74}