Skip to main content

rlx_ir/
hir_extension.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Retroactive HIR extensions (Slang `extension` declarations).
5//!
//! Third-party or arch-specific crates register transforms that run on a built
//! [`HirModule`] before it is lowered to MIR, without editing core block definitions.
8
9use std::sync::{OnceLock, RwLock};
10
11use crate::hir::HirModule;
12
13/// Transform applied after model flow build, before MIR lower.
14pub type HirExtensionFn = fn(&mut HirModule);
15
16struct Registry {
17    entries: Vec<(&'static str, HirExtensionFn)>,
18}
19
20impl Registry {
21    const fn new() -> Self {
22        Self {
23            entries: Vec::new(),
24        }
25    }
26}
27
28static REGISTRY: OnceLock<RwLock<Registry>> = OnceLock::new();
29
30fn registry() -> &'static RwLock<Registry> {
31    REGISTRY.get_or_init(|| RwLock::new(Registry::new()))
32}
33
34/// Register a named extension (call from `init` or model crate startup).
35pub fn register_hir_extension(name: &'static str, f: HirExtensionFn) {
36    let mut reg = registry().write().expect("hir extension registry");
37    if reg.entries.iter().any(|(n, _)| *n == name) {
38        reg.entries.retain(|(n, _)| *n != name);
39    }
40    reg.entries.push((name, f));
41}
42
43/// Registered extension names in registration order.
44pub fn registered_hir_extensions() -> Vec<&'static str> {
45    registry()
46        .read()
47        .expect("hir extension registry")
48        .entries
49        .iter()
50        .map(|(n, _)| *n)
51        .collect()
52}
53
54/// Apply all registered extensions in order.
55pub fn apply_hir_extensions(hir: &mut HirModule) {
56    let fns: Vec<HirExtensionFn> = registry()
57        .read()
58        .expect("hir extension registry")
59        .entries
60        .iter()
61        .map(|(_, f)| *f)
62        .collect();
63    for f in fns {
64        f(hir);
65    }
66}
67
68/// Apply only extensions whose names appear in `names`.
69pub fn apply_hir_extensions_named(hir: &mut HirModule, names: &[&str]) {
70    let reg = registry().read().expect("hir extension registry");
71    for (name, f) in &reg.entries {
72        if names.contains(name) {
73            f(hir);
74        }
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule, Shape};

    /// Test extension: renames the first output node to `"extended"`.
    fn tag_outputs(hir: &mut HirModule) {
        if let Some(id) = hir.outputs.first().copied() {
            hir.node_mut(id).name = Some("extended".into());
        }
    }

    #[test]
    fn extension_runs_on_module() {
        register_hir_extension("test_tag", tag_outputs);
        let mut hir = HirModule::new("ext");
        let mut gb = HirMut::new(&mut hir);
        let x = gb.input("x", Shape::new(&[2], DType::F32));
        hir.set_outputs(vec![x]);
        apply_hir_extensions_named(&mut hir, &["test_tag"]);
        let out = hir.node(hir.outputs[0]);
        assert_eq!(out.name.as_deref(), Some("extended"));
    }

    #[test]
    fn re_registering_a_name_keeps_one_entry() {
        // The registry is process-global and shared with other tests, so use a
        // name no other test registers and count only that name.
        register_hir_extension("test_dedup", tag_outputs);
        register_hir_extension("test_dedup", tag_outputs);
        let hits = registered_hir_extensions()
            .into_iter()
            .filter(|&n| n == "test_dedup")
            .count();
        assert_eq!(hits, 1);
    }
}
101}