reifydb_routine/procedure/
registry.rs1use std::{
5 collections::HashMap,
6 mem,
7 ops::Deref,
8 sync::{Arc, Mutex},
9};
10
11use reifydb_catalog::materialized::MaterializedCatalog;
12use reifydb_type::value::sumtype::VariantRef;
13
14use super::Procedure;
15
16type ProcedureFactory = Arc<dyn Fn() -> Box<dyn Procedure> + Send + Sync>;
17
18#[derive(Clone)]
19pub struct Procedures(Arc<ProceduresInner>);
20
21impl Procedures {
22 pub fn empty() -> Procedures {
23 Procedures::builder().build()
24 }
25
26 pub fn builder() -> ProceduresBuilder {
27 ProceduresBuilder {
28 procedures: HashMap::new(),
29 deferred_handlers: Vec::new(),
30 }
31 }
32}
33
34impl Deref for Procedures {
35 type Target = ProceduresInner;
36
37 fn deref(&self) -> &Self::Target {
38 &self.0
39 }
40}
41
42struct RegistryState {
43 procedures: HashMap<String, ProcedureFactory>,
44 resolved_handlers: HashMap<VariantRef, Vec<ProcedureFactory>>,
45 deferred_handlers: Vec<(String, ProcedureFactory)>,
46}
47
48pub struct ProceduresInner {
49 state: Arc<Mutex<RegistryState>>,
50}
51
52impl Clone for ProceduresInner {
53 fn clone(&self) -> Self {
54 Self {
55 state: Arc::clone(&self.state),
56 }
57 }
58}
59
60impl ProceduresInner {
61 pub fn get_procedure(&self, name: &str) -> Option<Box<dyn Procedure>> {
62 self.state.lock().unwrap().procedures.get(name).map(|f| f())
63 }
64
65 pub fn has_procedure(&self, name: &str) -> bool {
66 self.state.lock().unwrap().procedures.contains_key(name)
67 }
68
69 pub fn get_handlers(&self, catalog: &MaterializedCatalog, variant: VariantRef) -> Vec<Box<dyn Procedure>> {
70 let mut state = self.state.lock().unwrap();
71 if !state.deferred_handlers.is_empty() {
72 let deferred = mem::take(&mut state.deferred_handlers);
73 let mut still_deferred = Vec::new();
74 for (path, factory) in deferred {
75 match resolve_event_path(&path, catalog) {
76 Ok(resolved) => {
77 state.resolved_handlers.entry(resolved).or_default().push(factory);
78 }
79 Err(_) => still_deferred.push((path, factory)),
80 }
81 }
82 state.deferred_handlers = still_deferred;
83 }
84 state.resolved_handlers
85 .get(&variant)
86 .map(|factories| factories.iter().map(|f| f()).collect())
87 .unwrap_or_default()
88 }
89}
90
91pub struct ProceduresBuilder {
92 procedures: HashMap<String, ProcedureFactory>,
93 deferred_handlers: Vec<(String, ProcedureFactory)>,
94}
95
96impl ProceduresBuilder {
97 pub fn with_procedure<F, P>(mut self, name: &str, init: F) -> Self
98 where
99 F: Fn() -> P + Send + Sync + 'static,
100 P: Procedure + 'static,
101 {
102 self.procedures.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>));
103
104 self
105 }
106
107 pub fn with_handler<F, P>(mut self, event_path: &str, init: F) -> Self
112 where
113 F: Fn() -> P + Send + Sync + 'static,
114 P: Procedure + 'static,
115 {
116 self.deferred_handlers
117 .push((event_path.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>)));
118 self
119 }
120
121 pub fn build(self) -> Procedures {
122 Procedures(Arc::new(ProceduresInner {
123 state: Arc::new(Mutex::new(RegistryState {
124 procedures: self.procedures,
125 resolved_handlers: HashMap::new(),
126 deferred_handlers: self.deferred_handlers,
127 })),
128 }))
129 }
130}
131
132fn resolve_event_path(path: &str, catalog: &MaterializedCatalog) -> Result<VariantRef, String> {
133 let parts: Vec<&str> = path.split("::").collect();
134 if parts.len() != 3 {
135 return Err(format!(
136 "Invalid event path '{}': expected format 'namespace::event_name::VariantName'",
137 path
138 ));
139 }
140 let (namespace_name, event_name, variant_name) = (parts[0], parts[1], parts[2]);
141
142 let namespace = catalog
143 .find_namespace_by_name(namespace_name)
144 .ok_or_else(|| format!("Namespace '{}' not found", namespace_name))?;
145
146 let sumtype = catalog
147 .find_sumtype_by_name(namespace.id(), event_name)
148 .ok_or_else(|| format!("SumType '{}' not found in namespace '{}'", event_name, namespace_name))?;
149
150 let variant_name_lower = variant_name.to_lowercase();
151 let variant = sumtype.variants.iter().find(|v| v.name == variant_name_lower).ok_or_else(|| {
152 format!("Variant '{}' not found in sumtype '{}::{}'", variant_name, namespace_name, event_name)
153 })?;
154
155 Ok(VariantRef {
156 sumtype_id: sumtype.id,
157 variant_tag: variant.tag,
158 })
159}