Skip to main content

rlx_flow/
context.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Flow context — internal HIR emission surface (not for model authors).
5
6use std::collections::HashMap;
7
8use anyhow::Result;
9use rlx_ir::hir::{HirModule, HirNodeId};
10use rlx_ir::{DType, GraphModule, Shape};
11
12use crate::profile::CompileProfile;
13use crate::value::FlowValue;
14use crate::weight::WeightSource;
15
/// HIR node handles for the `q`, `k`, `v`, `g` and `beta` inputs of a
/// `Op::GatedDeltaNet` / carry scan.
///
/// The most recently published set is kept in [`FlowState::gdn`].
#[derive(Debug, Clone, Copy)]
pub struct GdnInputSlots {
    pub q: HirNodeId,
    pub k: HirNodeId,
    pub v: HirNodeId,
    pub g: HirNodeId,
    pub beta: HirNodeId,
}
25
/// Cross-stage shared handles (RoPE tables, zero-beta, tied embed, …).
///
/// `FlowCtx` holds this as `&mut FlowState`, so values written by one stage are
/// visible to any later stage built with the same state. Every field starts out
/// empty (`None` / empty map) via `Default`.
#[derive(Debug, Default)]
pub struct FlowState {
    /// RoPE cosine table node; `None` until a stage sets it.
    pub rope_cos: Option<HirNodeId>,
    /// RoPE sine table node; `None` until a stage sets it.
    pub rope_sin: Option<HirNodeId>,
    /// Shared "zero-beta" node (presumably a zero-filled beta/bias param — confirm).
    pub zero_beta: Option<HirNodeId>,
    /// Embedding weight node, kept for reuse (presumably for a tied output
    /// projection, per the "tied embed" note above — confirm).
    pub embed_weight: Option<HirNodeId>,
    /// Shape of the current hidden-state tensor (assumed from the name — TODO confirm
    /// which stage writes it).
    pub hidden_shape: Option<Shape>,
    /// KV-cache decode inputs; see [`DecodeBindings`].
    pub decode: Option<DecodeBindings>,
    /// Node saved for a later residual connection (assumed from the name — confirm).
    pub residual_skip: Option<HirNodeId>,
    /// Shape paired with `residual_skip` (assumed — confirm).
    pub residual_shape: Option<Shape>,
    /// Named tensor streams (`img`, `txt`, …) for multi-stream models.
    pub streams: HashMap<String, FlowValue>,
    /// Graph inputs beyond the primary tensor flow (`encoder`, `temb`, …).
    pub inputs: HashMap<String, (HirNodeId, Shape)>,
    /// Named scalar/tensor node refs (RoPE tables, mod params, carry state, …).
    pub named: HashMap<String, HirNodeId>,
    /// Last-published GDN q/k/v/g/beta handles for [`crate::blocks::GdnScanStage`].
    pub gdn: Option<GdnInputSlots>,
    /// Reuse param nodes when multiple stages in one layer load the same key
    /// (e.g. [`crate::blocks::LlamaKvTapStage`] + fused decoder).
    ///
    /// Filled by `FlowCtx::load_param`. The map key is the weight key, with a `\0t`
    /// suffix for transposed loads, so each `(key, transpose)` pair is cached separately.
    pub loaded_params: HashMap<String, HirNodeId>,
}
49
/// KV-cache decode inputs bound by [`crate::blocks::BindDecodeInputsStage`].
///
/// Stored in [`FlowState::decode`].
#[derive(Debug, Clone)]
pub struct DecodeBindings {
    /// Cosine table input for the decode step (presumably RoPE — confirm).
    pub cos: HirNodeId,
    /// Sine table input for the decode step (presumably RoPE — confirm).
    pub sin: HirNodeId,
    /// Optional mask input; `None` when no mask was bound.
    pub mask: Option<HirNodeId>,
    /// Past-key cache inputs (presumably one entry per layer — TODO confirm).
    pub past_k: Vec<HirNodeId>,
    /// Past-value cache inputs (presumably one entry per layer — TODO confirm).
    pub past_v: Vec<HirNodeId>,
}
59
/// Internal builder context. Blocks emit through this — tier-2 via [`crate::escape::Emit`].
///
/// Owns the [`GraphModule`] being built and borrows the parameter store, weight
/// source, compile profile and cross-stage [`FlowState`] for the lifetime `'a`.
pub struct FlowCtx<'a> {
    /// Module under construction; the helper methods expect it to be at the HIR stage.
    pub(crate) module: GraphModule,
    /// Param name → `f32` data; `load_param` and `synth_param` insert into it.
    pub(crate) params: &'a mut HashMap<String, Vec<f32>>,
    /// Source that `load_param` pulls weight data and shapes from (via `take`).
    pub(crate) weights: &'a mut dyn WeightSource,
    /// Compile profile; not read by the methods in this file.
    pub(crate) profile: &'a CompileProfile,
    /// Cross-stage shared handles, including the `load_param` node cache.
    pub(crate) state: &'a mut FlowState,
}
68
69impl FlowCtx<'_> {
70    pub fn hir(&mut self) -> &mut HirModule {
71        self.module
72            .as_hir_mut()
73            .expect("flow context requires HIR stage")
74    }
75
76    pub fn node_shape(&self, id: HirNodeId) -> Result<Shape> {
77        Ok(self
78            .module
79            .as_hir()
80            .ok_or_else(|| anyhow::anyhow!("flow context requires HIR stage"))?
81            .node(id)
82            .shape
83            .clone())
84    }
85
86    pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
87        let cache_key = param_cache_key(key, transpose);
88        if let Some(&id) = self.state.loaded_params.get(&cache_key) {
89            return Ok(id);
90        }
91        let (data, shape) = self.weights.take(key, transpose)?;
92        let ir_shape = Shape::new(&shape, DType::F32);
93        let id = self.hir().param(key, ir_shape);
94        self.params.insert(key.to_string(), data);
95        self.state.loaded_params.insert(cache_key, id);
96        Ok(id)
97    }
98
99    pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
100        let id = self.hir().param(name, shape);
101        self.params.insert(name.to_string(), data);
102        id
103    }
104
105    pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
106        self.synth_param(name, vec![0f32; len], Shape::new(&[len], DType::F32))
107    }
108
109    pub fn input(&mut self, name: &str, shape: Shape) -> HirNodeId {
110        self.hir().input(name, shape)
111    }
112
113    pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
114        FlowValue::new(id, shape)
115    }
116}
117
/// Builds the [`FlowState::loaded_params`] cache key for a weight load.
///
/// An untransposed load uses `key` verbatim. A transposed load appends a `\0t`
/// marker, so the two variants of the same weight occupy distinct cache slots.
fn param_cache_key(key: &str, transpose: bool) -> String {
    let marker = if transpose { "\0t" } else { "" };
    let mut cache_key = String::with_capacity(key.len() + marker.len());
    cache_key.push_str(key);
    cache_key.push_str(marker);
    cache_key
}