1use std::sync::{OnceLock, RwLock};
10
11use crate::hir::HirModule;
12
13pub type HirExtensionFn = fn(&mut HirModule);
15
16struct Registry {
17 entries: Vec<(&'static str, HirExtensionFn)>,
18}
19
20impl Registry {
21 const fn new() -> Self {
22 Self {
23 entries: Vec::new(),
24 }
25 }
26}
27
28static REGISTRY: OnceLock<RwLock<Registry>> = OnceLock::new();
29
30fn registry() -> &'static RwLock<Registry> {
31 REGISTRY.get_or_init(|| RwLock::new(Registry::new()))
32}
33
34pub fn register_hir_extension(name: &'static str, f: HirExtensionFn) {
36 let mut reg = registry().write().expect("hir extension registry");
37 if reg.entries.iter().any(|(n, _)| *n == name) {
38 reg.entries.retain(|(n, _)| *n != name);
39 }
40 reg.entries.push((name, f));
41}
42
43pub fn registered_hir_extensions() -> Vec<&'static str> {
45 registry()
46 .read()
47 .expect("hir extension registry")
48 .entries
49 .iter()
50 .map(|(n, _)| *n)
51 .collect()
52}
53
54pub fn apply_hir_extensions(hir: &mut HirModule) {
56 let fns: Vec<HirExtensionFn> = registry()
57 .read()
58 .expect("hir extension registry")
59 .entries
60 .iter()
61 .map(|(_, f)| *f)
62 .collect();
63 for f in fns {
64 f(hir);
65 }
66}
67
68pub fn apply_hir_extensions_named(hir: &mut HirModule, names: &[&str]) {
70 let reg = registry().read().expect("hir extension registry");
71 for (name, f) in ®.entries {
72 if names.contains(name) {
73 f(hir);
74 }
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule, Shape};

    /// Test extension: renames the first output node, if there is one, to
    /// `"extended"`.
    fn tag_outputs(hir: &mut HirModule) {
        let Some(&first) = hir.outputs.first() else {
            return;
        };
        hir.node_mut(first).name = Some("extended".into());
    }

    // NOTE(review): the registry is process-global, so "test_tag" stays
    // registered for any other test that applies all extensions.
    #[test]
    fn extension_runs_on_module() {
        register_hir_extension("test_tag", tag_outputs);

        let mut module = HirModule::new("ext");
        let input = {
            let mut builder = HirMut::new(&mut module);
            builder.input("x", Shape::new(&[2], DType::F32))
        };
        module.set_outputs(vec![input]);

        apply_hir_extensions_named(&mut module, &["test_tag"]);

        let tagged = module.node(module.outputs[0]);
        assert_eq!(tagged.name.as_deref(), Some("extended"));
    }
}