Skip to main content

rlx_flow/blocks/
bind_decode.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Bind decode graph inputs (RoPE slice, past K/V, optional mask).
5
6use anyhow::Result;
7use rlx_ir::hir::{HirModule, HirNodeId, HirOp};
8
9use crate::context::{DecodeBindings, FlowCtx};
/// Flow stage that binds the decode graph's inputs into the flow state.
///
/// The stage itself only carries configuration; the lookup work happens in
/// [`BindDecodeInputsStage::emit`].
#[derive(Debug, Clone)]
pub struct BindDecodeInputsStage {
    /// Number of layers: `emit` looks up one `past_k_{i}` and one `past_v_{i}`
    /// input for every `i` in `0..num_layers`.
    pub num_layers: usize,
    /// When `true`, `emit` requires an input named `mask`; when `false`, no
    /// mask is bound.
    pub use_custom_mask: bool,
}
15
16impl BindDecodeInputsStage {
17    pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
18        let cos = ctx
19            .state
20            .rope_cos
21            .or_else(|| find_input(ctx.hir(), "rope_cos").ok())
22            .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_cos"))?;
23        let sin = ctx
24            .state
25            .rope_sin
26            .or_else(|| find_input(ctx.hir(), "rope_sin").ok())
27            .ok_or_else(|| anyhow::anyhow!("decode flow missing rope_sin"))?;
28        let mask = if self.use_custom_mask {
29            Some(find_input(ctx.hir(), "mask")?)
30        } else {
31            None
32        };
33        let mut past_k = Vec::with_capacity(self.num_layers);
34        let mut past_v = Vec::with_capacity(self.num_layers);
35        for i in 0..self.num_layers {
36            past_k.push(find_input(ctx.hir(), &format!("past_k_{i}"))?);
37            past_v.push(find_input(ctx.hir(), &format!("past_v_{i}"))?);
38        }
39        ctx.state.decode = Some(DecodeBindings {
40            cos,
41            sin,
42            mask,
43            past_k,
44            past_v,
45        });
46        Ok(())
47    }
48}
49
50fn find_input(hir: &HirModule, name: &str) -> Result<HirNodeId> {
51    for node in hir.nodes() {
52        if let HirOp::Input { name: n } = &node.op {
53            if n == name {
54                return Ok(node.id);
55            }
56        }
57    }
58    Err(anyhow::anyhow!("decode flow missing input: {name}"))
59}