1use std::sync::{OnceLock, RwLock};
22
23use crate::hir::HirModule;
24
25pub type HirExtensionFn = fn(&mut HirModule);
27
28struct Registry {
29 entries: Vec<(&'static str, HirExtensionFn)>,
30}
31
32impl Registry {
33 const fn new() -> Self {
34 Self {
35 entries: Vec::new(),
36 }
37 }
38}
39
40static REGISTRY: OnceLock<RwLock<Registry>> = OnceLock::new();
41
42fn registry() -> &'static RwLock<Registry> {
43 REGISTRY.get_or_init(|| RwLock::new(Registry::new()))
44}
45
46pub fn register_hir_extension(name: &'static str, f: HirExtensionFn) {
48 let mut reg = registry().write().expect("hir extension registry");
49 if reg.entries.iter().any(|(n, _)| *n == name) {
50 reg.entries.retain(|(n, _)| *n != name);
51 }
52 reg.entries.push((name, f));
53}
54
55pub fn registered_hir_extensions() -> Vec<&'static str> {
57 registry()
58 .read()
59 .expect("hir extension registry")
60 .entries
61 .iter()
62 .map(|(n, _)| *n)
63 .collect()
64}
65
66pub fn apply_hir_extensions(hir: &mut HirModule) {
68 let fns: Vec<HirExtensionFn> = registry()
69 .read()
70 .expect("hir extension registry")
71 .entries
72 .iter()
73 .map(|(_, f)| *f)
74 .collect();
75 for f in fns {
76 f(hir);
77 }
78}
79
80pub fn apply_hir_extensions_named(hir: &mut HirModule, names: &[&str]) {
82 let reg = registry().read().expect("hir extension registry");
83 for (name, f) in ®.entries {
84 if names.contains(name) {
85 f(hir);
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule, Shape};

    /// Test extension: renames the module's first output node to `"extended"`, if there is one.
    fn tag_outputs(hir: &mut HirModule) {
        let Some(&first_output) = hir.outputs.first() else {
            return;
        };
        hir.node_mut(first_output).name = Some("extended".into());
    }

    /// Registering an extension and applying it by name mutates the module as expected.
    #[test]
    fn extension_runs_on_module() {
        register_hir_extension("test_tag", tag_outputs);

        let mut module = HirModule::new("ext");
        let mut builder = HirMut::new(&mut module);
        let input = builder.input("x", Shape::new(&[2], DType::F32));
        module.set_outputs(vec![input]);

        apply_hir_extensions_named(&mut module, &["test_tag"]);

        let tagged = module.node(module.outputs[0]);
        assert_eq!(tagged.name.as_deref(), Some("extended"));
    }
}