Skip to main content

reifydb_routine/procedure/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	collections::HashMap,
6	mem,
7	ops::Deref,
8	sync::{Arc, Mutex},
9};
10
11use reifydb_catalog::materialized::MaterializedCatalog;
12use reifydb_type::value::sumtype::VariantRef;
13
14use super::Procedure;
15
/// Shared, thread-safe factory that produces a new boxed [`Procedure`] each time it is called.
type ProcedureFactory = Arc<dyn Fn() -> Box<dyn Procedure> + Send + Sync>;
17
/// Handle to the procedure registry.
///
/// Wraps an `Arc<ProceduresInner>`, so cloning a `Procedures` value yields another handle to the
/// same inner registry rather than a copy of its contents.
#[derive(Clone)]
pub struct Procedures(Arc<ProceduresInner>);
20
21impl Procedures {
22	pub fn empty() -> Procedures {
23		Procedures::builder().build()
24	}
25
26	pub fn builder() -> ProceduresBuilder {
27		ProceduresBuilder {
28			procedures: HashMap::new(),
29			deferred_handlers: Vec::new(),
30		}
31	}
32}
33
34impl Deref for Procedures {
35	type Target = ProceduresInner;
36
37	fn deref(&self) -> &Self::Target {
38		&self.0
39	}
40}
41
/// Mutable registry contents; held behind the mutex in [`ProceduresInner`].
struct RegistryState {
	/// Procedure factories keyed by procedure name.
	procedures: HashMap<String, ProcedureFactory>,
	/// Handler factories grouped by the variant their event path resolved to.
	resolved_handlers: HashMap<VariantRef, Vec<ProcedureFactory>>,
	/// `(event_path, factory)` pairs whose event path has not been resolved to a variant yet.
	deferred_handlers: Vec<(String, ProcedureFactory)>,
}
47
/// Registry implementation that [`Procedures`] dereferences to.
///
/// All lookups go through the mutex-guarded [`RegistryState`].
pub struct ProceduresInner {
	/// Shared, mutex-guarded registry contents.
	state: Arc<Mutex<RegistryState>>,
}
51
52impl Clone for ProceduresInner {
53	fn clone(&self) -> Self {
54		Self {
55			state: Arc::clone(&self.state),
56		}
57	}
58}
59
60impl ProceduresInner {
61	pub fn get_procedure(&self, name: &str) -> Option<Box<dyn Procedure>> {
62		self.state.lock().unwrap().procedures.get(name).map(|f| f())
63	}
64
65	pub fn has_procedure(&self, name: &str) -> bool {
66		self.state.lock().unwrap().procedures.contains_key(name)
67	}
68
69	pub fn get_handlers(&self, catalog: &MaterializedCatalog, variant: VariantRef) -> Vec<Box<dyn Procedure>> {
70		let mut state = self.state.lock().unwrap();
71		if !state.deferred_handlers.is_empty() {
72			let deferred = mem::take(&mut state.deferred_handlers);
73			let mut still_deferred = Vec::new();
74			for (path, factory) in deferred {
75				match resolve_event_path(&path, catalog) {
76					Ok(resolved) => {
77						state.resolved_handlers.entry(resolved).or_default().push(factory);
78					}
79					Err(_) => still_deferred.push((path, factory)),
80				}
81			}
82			state.deferred_handlers = still_deferred;
83		}
84		state.resolved_handlers
85			.get(&variant)
86			.map(|factories| factories.iter().map(|f| f()).collect())
87			.unwrap_or_default()
88	}
89}
90
/// Builder that collects procedure and handler factories before producing a [`Procedures`] registry.
pub struct ProceduresBuilder {
	/// Procedure factories keyed by procedure name.
	procedures: HashMap<String, ProcedureFactory>,
	/// `(event_path, factory)` pairs; the path is stored as given and is not resolved by the builder.
	deferred_handlers: Vec<(String, ProcedureFactory)>,
}
95
96impl ProceduresBuilder {
97	pub fn with_procedure<F, P>(mut self, name: &str, init: F) -> Self
98	where
99		F: Fn() -> P + Send + Sync + 'static,
100		P: Procedure + 'static,
101	{
102		self.procedures.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>));
103
104		self
105	}
106
107	/// Register an event handler by path.
108	///
109	/// `event_path` uses the format `"namespace::event_name::VariantName"`.
110	/// The handler is resolved lazily on first dispatch.
111	pub fn with_handler<F, P>(mut self, event_path: &str, init: F) -> Self
112	where
113		F: Fn() -> P + Send + Sync + 'static,
114		P: Procedure + 'static,
115	{
116		self.deferred_handlers
117			.push((event_path.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>)));
118		self
119	}
120
121	pub fn build(self) -> Procedures {
122		Procedures(Arc::new(ProceduresInner {
123			state: Arc::new(Mutex::new(RegistryState {
124				procedures: self.procedures,
125				resolved_handlers: HashMap::new(),
126				deferred_handlers: self.deferred_handlers,
127			})),
128		}))
129	}
130}
131
132fn resolve_event_path(path: &str, catalog: &MaterializedCatalog) -> Result<VariantRef, String> {
133	let parts: Vec<&str> = path.split("::").collect();
134	if parts.len() != 3 {
135		return Err(format!(
136			"Invalid event path '{}': expected format 'namespace::event_name::VariantName'",
137			path
138		));
139	}
140	let (namespace_name, event_name, variant_name) = (parts[0], parts[1], parts[2]);
141
142	let namespace = catalog
143		.find_namespace_by_name(namespace_name)
144		.ok_or_else(|| format!("Namespace '{}' not found", namespace_name))?;
145
146	let sumtype = catalog
147		.find_sumtype_by_name(namespace.id(), event_name)
148		.ok_or_else(|| format!("SumType '{}' not found in namespace '{}'", event_name, namespace_name))?;
149
150	let variant_name_lower = variant_name.to_lowercase();
151	let variant = sumtype.variants.iter().find(|v| v.name == variant_name_lower).ok_or_else(|| {
152		format!("Variant '{}' not found in sumtype '{}::{}'", variant_name, namespace_name, event_name)
153	})?;
154
155	Ok(VariantRef {
156		sumtype_id: sumtype.id,
157		variant_tag: variant.tag,
158	})
159}