reifydb_engine/procedure/
registry.rs1use std::{
5 collections::HashMap,
6 mem,
7 ops::Deref,
8 sync::{Arc, Mutex},
9};
10
11use reifydb_catalog::materialized::MaterializedCatalog;
12use reifydb_type::value::sumtype::SumTypeId;
13
14use super::Procedure;
15
16type ProcedureFactory = Arc<dyn Fn() -> Box<dyn Procedure> + Send + Sync>;
17
18#[derive(Clone)]
19pub struct Procedures(Arc<ProceduresInner>);
20
21impl Procedures {
22 pub fn empty() -> Procedures {
23 Procedures::builder().build()
24 }
25
26 pub fn builder() -> ProceduresBuilder {
27 ProceduresBuilder {
28 procedures: HashMap::new(),
29 deferred_handlers: Vec::new(),
30 }
31 }
32}
33
34impl Deref for Procedures {
35 type Target = ProceduresInner;
36
37 fn deref(&self) -> &Self::Target {
38 &self.0
39 }
40}
41
42struct RegistryState {
43 procedures: HashMap<String, ProcedureFactory>,
44 resolved_handlers: HashMap<(SumTypeId, u8), Vec<ProcedureFactory>>,
45 deferred_handlers: Vec<(String, ProcedureFactory)>,
46}
47
48pub struct ProceduresInner {
49 state: Arc<Mutex<RegistryState>>,
50}
51
52impl Clone for ProceduresInner {
53 fn clone(&self) -> Self {
54 Self {
55 state: Arc::clone(&self.state),
56 }
57 }
58}
59
60impl ProceduresInner {
61 pub fn get_procedure(&self, name: &str) -> Option<Box<dyn Procedure>> {
62 self.state.lock().unwrap().procedures.get(name).map(|f| f())
63 }
64
65 pub fn has_procedure(&self, name: &str) -> bool {
66 self.state.lock().unwrap().procedures.contains_key(name)
67 }
68
69 pub fn get_handlers(
70 &self,
71 catalog: &MaterializedCatalog,
72 sumtype_id: SumTypeId,
73 variant_tag: u8,
74 ) -> Vec<Box<dyn Procedure>> {
75 let mut state = self.state.lock().unwrap();
76 if !state.deferred_handlers.is_empty() {
77 let deferred = mem::take(&mut state.deferred_handlers);
78 let mut still_deferred = Vec::new();
79 for (path, factory) in deferred {
80 match resolve_event_path(&path, catalog) {
81 Ok((sid, tag)) => {
82 state.resolved_handlers.entry((sid, tag)).or_default().push(factory);
83 }
84 Err(_) => still_deferred.push((path, factory)),
85 }
86 }
87 state.deferred_handlers = still_deferred;
88 }
89 state.resolved_handlers
90 .get(&(sumtype_id, variant_tag))
91 .map(|factories| factories.iter().map(|f| f()).collect())
92 .unwrap_or_default()
93 }
94}
95
96pub struct ProceduresBuilder {
97 procedures: HashMap<String, ProcedureFactory>,
98 deferred_handlers: Vec<(String, ProcedureFactory)>,
99}
100
101impl ProceduresBuilder {
102 pub fn with_procedure<F, P>(mut self, name: &str, init: F) -> Self
103 where
104 F: Fn() -> P + Send + Sync + 'static,
105 P: Procedure + 'static,
106 {
107 self.procedures.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>));
108
109 self
110 }
111
112 pub fn with_handler<F, P>(mut self, event_path: &str, init: F) -> Self
117 where
118 F: Fn() -> P + Send + Sync + 'static,
119 P: Procedure + 'static,
120 {
121 self.deferred_handlers
122 .push((event_path.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>)));
123 self
124 }
125
126 pub fn build(self) -> Procedures {
127 Procedures(Arc::new(ProceduresInner {
128 state: Arc::new(Mutex::new(RegistryState {
129 procedures: self.procedures,
130 resolved_handlers: HashMap::new(),
131 deferred_handlers: self.deferred_handlers,
132 })),
133 }))
134 }
135}
136
137fn resolve_event_path(path: &str, catalog: &MaterializedCatalog) -> Result<(SumTypeId, u8), String> {
138 let parts: Vec<&str> = path.split("::").collect();
139 if parts.len() != 3 {
140 return Err(format!(
141 "Invalid event path '{}': expected format 'namespace::event_name::VariantName'",
142 path
143 ));
144 }
145 let (namespace_name, event_name, variant_name) = (parts[0], parts[1], parts[2]);
146
147 let namespace_def = catalog
148 .find_namespace_by_name(namespace_name)
149 .ok_or_else(|| format!("Namespace '{}' not found", namespace_name))?;
150
151 let sumtype_def = catalog
152 .find_sumtype_by_name(namespace_def.id, event_name)
153 .ok_or_else(|| format!("SumType '{}' not found in namespace '{}'", event_name, namespace_name))?;
154
155 let variant_name_lower = variant_name.to_lowercase();
156 let variant = sumtype_def.variants.iter().find(|v| v.name == variant_name_lower).ok_or_else(|| {
157 format!("Variant '{}' not found in sumtype '{}::{}'", variant_name, namespace_name, event_name)
158 })?;
159
160 Ok((sumtype_def.id, variant.tag))
161}