Skip to main content

reifydb_engine/procedure/
registry.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	collections::HashMap,
6	mem,
7	ops::Deref,
8	sync::{Arc, Mutex},
9};
10
11use reifydb_catalog::materialized::MaterializedCatalog;
12use reifydb_type::value::sumtype::SumTypeId;
13
14use super::Procedure;
15
16type ProcedureFactory = Arc<dyn Fn() -> Box<dyn Procedure> + Send + Sync>;
17
18#[derive(Clone)]
19pub struct Procedures(Arc<ProceduresInner>);
20
21impl Procedures {
22	pub fn empty() -> Procedures {
23		Procedures::builder().build()
24	}
25
26	pub fn builder() -> ProceduresBuilder {
27		ProceduresBuilder {
28			procedures: HashMap::new(),
29			deferred_handlers: Vec::new(),
30		}
31	}
32}
33
34impl Deref for Procedures {
35	type Target = ProceduresInner;
36
37	fn deref(&self) -> &Self::Target {
38		&self.0
39	}
40}
41
42struct RegistryState {
43	procedures: HashMap<String, ProcedureFactory>,
44	resolved_handlers: HashMap<(SumTypeId, u8), Vec<ProcedureFactory>>,
45	deferred_handlers: Vec<(String, ProcedureFactory)>,
46}
47
48pub struct ProceduresInner {
49	state: Arc<Mutex<RegistryState>>,
50}
51
52impl Clone for ProceduresInner {
53	fn clone(&self) -> Self {
54		Self {
55			state: Arc::clone(&self.state),
56		}
57	}
58}
59
60impl ProceduresInner {
61	pub fn get_procedure(&self, name: &str) -> Option<Box<dyn Procedure>> {
62		self.state.lock().unwrap().procedures.get(name).map(|f| f())
63	}
64
65	pub fn has_procedure(&self, name: &str) -> bool {
66		self.state.lock().unwrap().procedures.contains_key(name)
67	}
68
69	pub fn get_handlers(
70		&self,
71		catalog: &MaterializedCatalog,
72		sumtype_id: SumTypeId,
73		variant_tag: u8,
74	) -> Vec<Box<dyn Procedure>> {
75		let mut state = self.state.lock().unwrap();
76		if !state.deferred_handlers.is_empty() {
77			let deferred = mem::take(&mut state.deferred_handlers);
78			let mut still_deferred = Vec::new();
79			for (path, factory) in deferred {
80				match resolve_event_path(&path, catalog) {
81					Ok((sid, tag)) => {
82						state.resolved_handlers.entry((sid, tag)).or_default().push(factory);
83					}
84					Err(_) => still_deferred.push((path, factory)),
85				}
86			}
87			state.deferred_handlers = still_deferred;
88		}
89		state.resolved_handlers
90			.get(&(sumtype_id, variant_tag))
91			.map(|factories| factories.iter().map(|f| f()).collect())
92			.unwrap_or_default()
93	}
94}
95
96pub struct ProceduresBuilder {
97	procedures: HashMap<String, ProcedureFactory>,
98	deferred_handlers: Vec<(String, ProcedureFactory)>,
99}
100
101impl ProceduresBuilder {
102	pub fn with_procedure<F, P>(mut self, name: &str, init: F) -> Self
103	where
104		F: Fn() -> P + Send + Sync + 'static,
105		P: Procedure + 'static,
106	{
107		self.procedures.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>));
108
109		self
110	}
111
112	/// Register an event handler by path.
113	///
114	/// `event_path` uses the format `"namespace::event_name::VariantName"`.
115	/// The handler is resolved lazily on first dispatch.
116	pub fn with_handler<F, P>(mut self, event_path: &str, init: F) -> Self
117	where
118		F: Fn() -> P + Send + Sync + 'static,
119		P: Procedure + 'static,
120	{
121		self.deferred_handlers
122			.push((event_path.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>)));
123		self
124	}
125
126	pub fn build(self) -> Procedures {
127		Procedures(Arc::new(ProceduresInner {
128			state: Arc::new(Mutex::new(RegistryState {
129				procedures: self.procedures,
130				resolved_handlers: HashMap::new(),
131				deferred_handlers: self.deferred_handlers,
132			})),
133		}))
134	}
135}
136
137fn resolve_event_path(path: &str, catalog: &MaterializedCatalog) -> Result<(SumTypeId, u8), String> {
138	let parts: Vec<&str> = path.split("::").collect();
139	if parts.len() != 3 {
140		return Err(format!(
141			"Invalid event path '{}': expected format 'namespace::event_name::VariantName'",
142			path
143		));
144	}
145	let (namespace_name, event_name, variant_name) = (parts[0], parts[1], parts[2]);
146
147	let namespace_def = catalog
148		.find_namespace_by_name(namespace_name)
149		.ok_or_else(|| format!("Namespace '{}' not found", namespace_name))?;
150
151	let sumtype_def = catalog
152		.find_sumtype_by_name(namespace_def.id, event_name)
153		.ok_or_else(|| format!("SumType '{}' not found in namespace '{}'", event_name, namespace_name))?;
154
155	let variant_name_lower = variant_name.to_lowercase();
156	let variant = sumtype_def.variants.iter().find(|v| v.name == variant_name_lower).ok_or_else(|| {
157		format!("Variant '{}' not found in sumtype '{}::{}'", variant_name, namespace_name, event_name)
158	})?;
159
160	Ok((sumtype_def.id, variant.tag))
161}