Skip to main content

rlx_flow/
escape.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Tier-2 escape hatch — custom HIR emission when blocks are not enough yet.
5//!
6//! Prefer adding a reusable block under `blocks/` over long-lived custom closures.
7//! Custom stages are for one-off arch experiments and novel subgraphs.
8
9use std::collections::HashMap;
10
11use anyhow::Result;
12use rlx_ir::hir::HirModule;
13use rlx_ir::{GraphModule, HirNodeId, Shape};
14
15use crate::context::{FlowCtx, FlowState};
16use crate::profile::CompileProfile;
17use crate::value::FlowValue;
18use crate::weight::WeightSource;
19
20/// Mutable emission context for custom stages (tier 2).
21pub struct Emit<'a> {
22    pub module: &'a mut GraphModule,
23    pub params: &'a mut HashMap<String, Vec<f32>>,
24    pub weights: &'a mut dyn WeightSource,
25    pub state: &'a mut FlowState,
26    pub profile: &'a CompileProfile,
27}
28
29impl<'a> Emit<'a> {
30    pub(crate) fn from_ctx(ctx: &'a mut FlowCtx<'_>) -> Self {
31        Self {
32            module: &mut ctx.module,
33            params: ctx.params,
34            weights: ctx.weights,
35            state: ctx.state,
36            profile: ctx.profile,
37        }
38    }
39
40    pub fn hir(&mut self) -> &mut HirModule {
41        self.module
42            .as_hir_mut()
43            .expect("flow context requires HIR stage")
44    }
45
46    pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
47        let cache_key = if transpose {
48            format!("{key}\0t")
49        } else {
50            key.to_string()
51        };
52        if let Some(&id) = self.state.loaded_params.get(&cache_key) {
53            return Ok(id);
54        }
55        let (data, shape) = self.weights.take(key, transpose)?;
56        let ir_shape = Shape::new(&shape, rlx_ir::DType::F32);
57        let id = self.hir().param(key, ir_shape);
58        self.params.insert(key.to_string(), data);
59        self.state.loaded_params.insert(cache_key, id);
60        Ok(id)
61    }
62
63    pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
64        let id = self.hir().param(name, shape);
65        self.params.insert(name.to_string(), data);
66        id
67    }
68
69    pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
70        self.synth_param(
71            name,
72            vec![0f32; len],
73            Shape::new(&[len], rlx_ir::DType::F32),
74        )
75    }
76
77    pub fn hir_and_params(&mut self) -> (&mut HirModule, &mut HashMap<String, Vec<f32>>) {
78        (
79            self.module
80                .as_hir_mut()
81                .expect("flow context requires HIR stage"),
82            self.params,
83        )
84    }
85
86    pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
87        FlowValue::new(id, shape)
88    }
89
90    /// Look up a declared graph input (see [`FlowState::inputs`]).
91    pub fn flow_input(&self, name: &str) -> Result<FlowValue> {
92        let (id, shape) = self
93            .state
94            .inputs
95            .get(name)
96            .ok_or_else(|| anyhow::anyhow!("flow input missing `{name}`"))?;
97        Ok(FlowValue::new(*id, shape.clone()))
98    }
99
100    pub fn set_named(&mut self, key: impl Into<String>, id: HirNodeId) {
101        self.state.named.insert(key.into(), id);
102    }
103
104    pub fn named(&self, key: &str) -> Result<HirNodeId> {
105        self.state
106            .named
107            .get(key)
108            .copied()
109            .ok_or_else(|| anyhow::anyhow!("named flow handle missing `{key}`"))
110    }
111}
112
113pub use crate::context::DecodeBindings;