use std::{
collections::HashMap,
mem,
ops::Deref,
sync::{Arc, Mutex},
};
use reifydb_catalog::materialized::MaterializedCatalog;
use reifydb_type::value::sumtype::VariantRef;
use super::Procedure;
type ProcedureFactory = Arc<dyn Fn() -> Box<dyn Procedure> + Send + Sync>;
#[derive(Clone)]
pub struct Procedures(Arc<ProceduresInner>);
impl Procedures {
pub fn empty() -> Procedures {
Procedures::builder().configure()
}
pub fn builder() -> ProceduresConfigurator {
ProceduresConfigurator {
procedures: HashMap::new(),
deferred_handlers: Vec::new(),
}
}
}
impl Deref for Procedures {
type Target = ProceduresInner;
fn deref(&self) -> &Self::Target {
&self.0
}
}
struct RegistryState {
procedures: HashMap<String, ProcedureFactory>,
resolved_handlers: HashMap<VariantRef, Vec<ProcedureFactory>>,
deferred_handlers: Vec<(String, ProcedureFactory)>,
}
pub struct ProceduresInner {
state: Arc<Mutex<RegistryState>>,
}
impl Clone for ProceduresInner {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}
impl ProceduresInner {
pub fn get_procedure(&self, name: &str) -> Option<Box<dyn Procedure>> {
self.state.lock().unwrap().procedures.get(name).map(|f| f())
}
pub fn has_procedure(&self, name: &str) -> bool {
self.state.lock().unwrap().procedures.contains_key(name)
}
pub fn procedure_names(&self) -> Vec<String> {
self.state.lock().unwrap().procedures.keys().cloned().collect()
}
pub fn get_handlers(&self, catalog: &MaterializedCatalog, variant: VariantRef) -> Vec<Box<dyn Procedure>> {
let mut state = self.state.lock().unwrap();
if !state.deferred_handlers.is_empty() {
let deferred = mem::take(&mut state.deferred_handlers);
let mut still_deferred = Vec::new();
for (path, factory) in deferred {
match resolve_event_path(&path, catalog) {
Ok(resolved) => {
state.resolved_handlers.entry(resolved).or_default().push(factory);
}
Err(_) => still_deferred.push((path, factory)),
}
}
state.deferred_handlers = still_deferred;
}
state.resolved_handlers
.get(&variant)
.map(|factories| factories.iter().map(|f| f()).collect())
.unwrap_or_default()
}
}
pub struct ProceduresConfigurator {
procedures: HashMap<String, ProcedureFactory>,
deferred_handlers: Vec<(String, ProcedureFactory)>,
}
impl ProceduresConfigurator {
pub fn with_procedure<F, P>(mut self, name: &str, init: F) -> Self
where
F: Fn() -> P + Send + Sync + 'static,
P: Procedure + 'static,
{
self.procedures.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>));
self
}
pub fn with_handler<F, P>(mut self, event_path: &str, init: F) -> Self
where
F: Fn() -> P + Send + Sync + 'static,
P: Procedure + 'static,
{
self.deferred_handlers
.push((event_path.to_string(), Arc::new(move || Box::new(init()) as Box<dyn Procedure>)));
self
}
pub fn configure(self) -> Procedures {
Procedures(Arc::new(ProceduresInner {
state: Arc::new(Mutex::new(RegistryState {
procedures: self.procedures,
resolved_handlers: HashMap::new(),
deferred_handlers: self.deferred_handlers,
})),
}))
}
}
fn resolve_event_path(path: &str, catalog: &MaterializedCatalog) -> Result<VariantRef, String> {
let parts: Vec<&str> = path.split("::").collect();
if parts.len() != 3 {
return Err(format!(
"Invalid event path '{}': expected format 'namespace::event_name::VariantName'",
path
));
}
let (namespace_name, event_name, variant_name) = (parts[0], parts[1], parts[2]);
let namespace = catalog
.find_namespace_by_name(namespace_name)
.ok_or_else(|| format!("Namespace '{}' not found", namespace_name))?;
let sumtype = catalog
.find_sumtype_by_name(namespace.id(), event_name)
.ok_or_else(|| format!("SumType '{}' not found in namespace '{}'", event_name, namespace_name))?;
let variant_name_lower = variant_name.to_lowercase();
let variant = sumtype.variants.iter().find(|v| v.name == variant_name_lower).ok_or_else(|| {
format!("Variant '{}' not found in sumtype '{}::{}'", variant_name, namespace_name, event_name)
})?;
Ok(VariantRef {
sumtype_id: sumtype.id,
variant_tag: variant.tag,
})
}