Skip to main content

rlx_flow/
context.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Flow context — internal HIR emission surface (not for model authors).
17
18use std::collections::HashMap;
19
20use anyhow::Result;
21use rlx_ir::hir::{HirModule, HirNodeId};
22use rlx_ir::{DType, GraphModule, Shape};
23
24use crate::profile::CompileProfile;
25use crate::value::FlowValue;
26use crate::weight::WeightSource;
27
28/// Handles for a [`Op::GatedDeltaNet`] / carry scan.
29#[derive(Debug, Clone, Copy)]
30pub struct GdnInputSlots {
31    pub q: HirNodeId,
32    pub k: HirNodeId,
33    pub v: HirNodeId,
34    pub g: HirNodeId,
35    pub beta: HirNodeId,
36}
37
38/// Cross-stage shared handles (RoPE tables, zero-beta, tied embed, …).
39#[derive(Debug, Default)]
40pub struct FlowState {
41    pub rope_cos: Option<HirNodeId>,
42    pub rope_sin: Option<HirNodeId>,
43    pub zero_beta: Option<HirNodeId>,
44    pub embed_weight: Option<HirNodeId>,
45    pub hidden_shape: Option<Shape>,
46    pub decode: Option<DecodeBindings>,
47    pub residual_skip: Option<HirNodeId>,
48    pub residual_shape: Option<Shape>,
49    /// Named tensor streams (`img`, `txt`, …) for multi-stream models.
50    pub streams: HashMap<String, FlowValue>,
51    /// Graph inputs beyond the primary tensor flow (`encoder`, `temb`, …).
52    pub inputs: HashMap<String, (HirNodeId, Shape)>,
53    /// Named scalar/tensor node refs (RoPE tables, mod params, carry state, …).
54    pub named: HashMap<String, HirNodeId>,
55    /// Last-published GDN q/k/v/g/beta handles for [`crate::blocks::GdnScanStage`].
56    pub gdn: Option<GdnInputSlots>,
57    /// Reuse param nodes when multiple stages in one layer load the same key
58    /// (e.g. [`crate::blocks::LlamaKvTapStage`] + fused decoder).
59    pub loaded_params: HashMap<String, HirNodeId>,
60}
61
62/// KV-cache decode inputs bound by [`crate::blocks::BindDecodeInputsStage`].
63#[derive(Debug, Clone)]
64pub struct DecodeBindings {
65    pub cos: HirNodeId,
66    pub sin: HirNodeId,
67    pub mask: Option<HirNodeId>,
68    pub past_k: Vec<HirNodeId>,
69    pub past_v: Vec<HirNodeId>,
70}
71
72/// Internal builder context. Blocks emit through this — tier-2 via [`crate::escape::Emit`].
73pub struct FlowCtx<'a> {
74    pub(crate) module: GraphModule,
75    pub(crate) params: &'a mut HashMap<String, Vec<f32>>,
76    pub(crate) weights: &'a mut dyn WeightSource,
77    pub(crate) profile: &'a CompileProfile,
78    pub(crate) state: &'a mut FlowState,
79}
80
81impl FlowCtx<'_> {
82    pub fn hir(&mut self) -> &mut HirModule {
83        self.module
84            .as_hir_mut()
85            .expect("flow context requires HIR stage")
86    }
87
88    pub fn node_shape(&self, id: HirNodeId) -> Result<Shape> {
89        Ok(self
90            .module
91            .as_hir()
92            .ok_or_else(|| anyhow::anyhow!("flow context requires HIR stage"))?
93            .node(id)
94            .shape
95            .clone())
96    }
97
98    pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
99        let cache_key = param_cache_key(key, transpose);
100        if let Some(&id) = self.state.loaded_params.get(&cache_key) {
101            return Ok(id);
102        }
103        let (data, shape) = self.weights.take(key, transpose)?;
104        let ir_shape = Shape::new(&shape, DType::F32);
105        let id = self.hir().param(key, ir_shape);
106        self.params.insert(key.to_string(), data);
107        self.state.loaded_params.insert(cache_key, id);
108        Ok(id)
109    }
110
111    pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
112        let id = self.hir().param(name, shape);
113        self.params.insert(name.to_string(), data);
114        id
115    }
116
117    pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
118        self.synth_param(name, vec![0f32; len], Shape::new(&[len], DType::F32))
119    }
120
121    pub fn input(&mut self, name: &str, shape: Shape) -> HirNodeId {
122        self.hir().input(name, shape)
123    }
124
125    pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
126        FlowValue::new(id, shape)
127    }
128}
129
/// Builds the lookup key for the loaded-parameter cache.
///
/// Non-transposed loads use `key` unchanged; transposed loads append a NUL
/// byte followed by `t`, so the two variants of one weight never share a
/// cache entry.
// NOTE(review): relies on weight keys never ending in "\0t" themselves.
fn param_cache_key(key: &str, transpose: bool) -> String {
    if transpose {
        format!("{key}\0t")
    } else {
        key.to_string()
    }
}