Skip to main content

rlx_ir/
hir_extension.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Retroactive HIR extensions (Slang `extension` declarations).
17//!
18//! Third-party or arch-specific crates register transforms that run on a built
//! [`HirModule`] before it is lowered to MIR, without editing core block definitions.
20
21use std::sync::{OnceLock, RwLock};
22
23use crate::hir::HirModule;
24
25/// Transform applied after model flow build, before MIR lower.
26pub type HirExtensionFn = fn(&mut HirModule);
27
28struct Registry {
29    entries: Vec<(&'static str, HirExtensionFn)>,
30}
31
32impl Registry {
33    const fn new() -> Self {
34        Self {
35            entries: Vec::new(),
36        }
37    }
38}
39
40static REGISTRY: OnceLock<RwLock<Registry>> = OnceLock::new();
41
42fn registry() -> &'static RwLock<Registry> {
43    REGISTRY.get_or_init(|| RwLock::new(Registry::new()))
44}
45
46/// Register a named extension (call from `init` or model crate startup).
47pub fn register_hir_extension(name: &'static str, f: HirExtensionFn) {
48    let mut reg = registry().write().expect("hir extension registry");
49    if reg.entries.iter().any(|(n, _)| *n == name) {
50        reg.entries.retain(|(n, _)| *n != name);
51    }
52    reg.entries.push((name, f));
53}
54
55/// Registered extension names in registration order.
56pub fn registered_hir_extensions() -> Vec<&'static str> {
57    registry()
58        .read()
59        .expect("hir extension registry")
60        .entries
61        .iter()
62        .map(|(n, _)| *n)
63        .collect()
64}
65
66/// Apply all registered extensions in order.
67pub fn apply_hir_extensions(hir: &mut HirModule) {
68    let fns: Vec<HirExtensionFn> = registry()
69        .read()
70        .expect("hir extension registry")
71        .entries
72        .iter()
73        .map(|(_, f)| *f)
74        .collect();
75    for f in fns {
76        f(hir);
77    }
78}
79
80/// Apply only extensions whose names appear in `names`.
81pub fn apply_hir_extensions_named(hir: &mut HirModule, names: &[&str]) {
82    let reg = registry().read().expect("hir extension registry");
83    for (name, f) in &reg.entries {
84        if names.contains(name) {
85            f(hir);
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule, Shape};

    // NOTE(review): the registry is process-global and these tests never unregister.
    // Any other test in this crate that calls `apply_hir_extensions` will also run the
    // extensions registered here, including `tag_outputs`, which renames the first
    // output node. Keep test extensions uniquely named and as harmless as possible.

    /// Renames the first output node to `"extended"` so the test can observe that
    /// the extension ran.
    fn tag_outputs(hir: &mut HirModule) {
        if let Some(id) = hir.outputs.first().copied() {
            hir.node_mut(id).name = Some("extended".into());
        }
    }

    /// Leaves the module untouched; used where only registry bookkeeping matters.
    fn noop(_hir: &mut HirModule) {}

    #[test]
    fn extension_runs_on_module() {
        register_hir_extension("test_tag", tag_outputs);
        let mut hir = HirModule::new("ext");
        let mut gb = HirMut::new(&mut hir);
        let x = gb.input("x", Shape::new(&[2], DType::F32));
        hir.set_outputs(vec![x]);
        apply_hir_extensions_named(&mut hir, &["test_tag"]);
        let out = hir.node(hir.outputs[0]);
        assert_eq!(out.name.as_deref(), Some("extended"));
    }

    #[test]
    fn reregistering_a_name_replaces_it_and_moves_it_last() {
        register_hir_extension("test_rereg_first", noop);
        register_hir_extension("test_rereg_second", noop);
        register_hir_extension("test_rereg_first", noop);

        // Other tests may register concurrently, so only compare relative positions.
        let names = registered_hir_extensions();
        let occurrences = names.iter().filter(|n| **n == "test_rereg_first").count();
        assert_eq!(occurrences, 1);
        let position = |wanted: &str| names.iter().position(|n| *n == wanted);
        let first = position("test_rereg_first").expect("first name registered");
        let second = position("test_rereg_second").expect("second name registered");
        assert!(second < first);
    }
}