Skip to main content

rlx_flow/
escape.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-2 escape hatch — custom HIR emission when blocks are not enough yet.
17//!
18//! Prefer adding a reusable block under `blocks/` over long-lived custom closures.
19//! Custom stages are for one-off arch experiments and novel subgraphs.
20
21use std::collections::HashMap;
22
23use anyhow::Result;
24use rlx_ir::hir::HirModule;
25use rlx_ir::{GraphModule, HirNodeId, Shape};
26
27use crate::context::{FlowCtx, FlowState};
28use crate::profile::CompileProfile;
29use crate::value::FlowValue;
30use crate::weight::WeightSource;
31
32/// Mutable emission context for custom stages (tier 2).
33pub struct Emit<'a> {
34    pub module: &'a mut GraphModule,
35    pub params: &'a mut HashMap<String, Vec<f32>>,
36    pub weights: &'a mut dyn WeightSource,
37    pub state: &'a mut FlowState,
38    pub profile: &'a CompileProfile,
39}
40
41impl<'a> Emit<'a> {
42    pub(crate) fn from_ctx(ctx: &'a mut FlowCtx<'_>) -> Self {
43        Self {
44            module: &mut ctx.module,
45            params: ctx.params,
46            weights: ctx.weights,
47            state: ctx.state,
48            profile: ctx.profile,
49        }
50    }
51
52    pub fn hir(&mut self) -> &mut HirModule {
53        self.module
54            .as_hir_mut()
55            .expect("flow context requires HIR stage")
56    }
57
58    pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
59        let cache_key = if transpose {
60            format!("{key}\0t")
61        } else {
62            key.to_string()
63        };
64        if let Some(&id) = self.state.loaded_params.get(&cache_key) {
65            return Ok(id);
66        }
67        let (data, shape) = self.weights.take(key, transpose)?;
68        let ir_shape = Shape::new(&shape, rlx_ir::DType::F32);
69        let id = self.hir().param(key, ir_shape);
70        self.params.insert(key.to_string(), data);
71        self.state.loaded_params.insert(cache_key, id);
72        Ok(id)
73    }
74
75    pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
76        let id = self.hir().param(name, shape);
77        self.params.insert(name.to_string(), data);
78        id
79    }
80
81    pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
82        self.synth_param(
83            name,
84            vec![0f32; len],
85            Shape::new(&[len], rlx_ir::DType::F32),
86        )
87    }
88
89    pub fn hir_and_params(&mut self) -> (&mut HirModule, &mut HashMap<String, Vec<f32>>) {
90        (
91            self.module
92                .as_hir_mut()
93                .expect("flow context requires HIR stage"),
94            self.params,
95        )
96    }
97
98    pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
99        FlowValue::new(id, shape)
100    }
101
102    /// Look up a declared graph input (see [`FlowState::inputs`]).
103    pub fn flow_input(&self, name: &str) -> Result<FlowValue> {
104        let (id, shape) = self
105            .state
106            .inputs
107            .get(name)
108            .ok_or_else(|| anyhow::anyhow!("flow input missing `{name}`"))?;
109        Ok(FlowValue::new(*id, shape.clone()))
110    }
111
112    pub fn set_named(&mut self, key: impl Into<String>, id: HirNodeId) {
113        self.state.named.insert(key.into(), id);
114    }
115
116    pub fn named(&self, key: &str) -> Result<HirNodeId> {
117        self.state
118            .named
119            .get(key)
120            .copied()
121            .ok_or_else(|| anyhow::anyhow!("named flow handle missing `{key}`"))
122    }
123}
124
125pub use crate::context::DecodeBindings;