Skip to main content

reifydb_routine/function/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::HashMap, ops::Deref, sync::Arc};
5
6use super::{AggregateFunction, GeneratorFunction, ScalarFunction};
7
8#[derive(Clone)]
9pub struct Functions(Arc<FunctionsInner>);
10
11impl Functions {
12	pub fn empty() -> Functions {
13		Functions::builder().build()
14	}
15
16	pub fn builder() -> FunctionsBuilder {
17		FunctionsBuilder(FunctionsInner {
18			scalars: HashMap::new(),
19			aggregates: HashMap::new(),
20			generators: HashMap::new(),
21		})
22	}
23}
24
25impl Deref for Functions {
26	type Target = FunctionsInner;
27
28	fn deref(&self) -> &Self::Target {
29		&self.0
30	}
31}
32
33#[derive(Clone)]
34pub struct FunctionsInner {
35	scalars: HashMap<String, Arc<dyn Fn() -> Box<dyn ScalarFunction> + Send + Sync>>,
36	aggregates: HashMap<String, Arc<dyn Fn() -> Box<dyn AggregateFunction> + Send + Sync>>,
37	generators: HashMap<String, Arc<dyn Fn() -> Box<dyn GeneratorFunction> + Send + Sync>>,
38}
39
40impl FunctionsInner {
41	pub fn get_aggregate(&self, name: &str) -> Option<Box<dyn AggregateFunction>> {
42		self.aggregates.get(name).map(|func| func())
43	}
44
45	pub fn get_scalar(&self, name: &str) -> Option<Box<dyn ScalarFunction>> {
46		self.scalars.get(name).map(|func| func())
47	}
48
49	pub fn get_generator(&self, name: &str) -> Option<Box<dyn GeneratorFunction>> {
50		self.generators.get(name).map(|func| func())
51	}
52
53	pub fn scalar_names(&self) -> Vec<&str> {
54		self.scalars.keys().map(|s| s.as_str()).collect()
55	}
56
57	pub fn aggregate_names(&self) -> Vec<&str> {
58		self.aggregates.keys().map(|s| s.as_str()).collect()
59	}
60
61	pub fn generator_names(&self) -> Vec<&str> {
62		self.generators.keys().map(|s| s.as_str()).collect()
63	}
64
65	pub fn get_scalar_factory(&self, name: &str) -> Option<Arc<dyn Fn() -> Box<dyn ScalarFunction> + Send + Sync>> {
66		self.scalars.get(name).cloned()
67	}
68
69	pub fn get_aggregate_factory(
70		&self,
71		name: &str,
72	) -> Option<Arc<dyn Fn() -> Box<dyn AggregateFunction> + Send + Sync>> {
73		self.aggregates.get(name).cloned()
74	}
75}
76
77pub struct FunctionsBuilder(FunctionsInner);
78
79impl FunctionsBuilder {
80	pub fn register_scalar<F, A>(mut self, name: &str, init: F) -> Self
81	where
82		F: Fn() -> A + Send + Sync + 'static,
83		A: ScalarFunction + 'static,
84	{
85		self.0.scalars.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn ScalarFunction>));
86
87		self
88	}
89
90	pub fn register_aggregate<F, A>(mut self, name: &str, init: F) -> Self
91	where
92		F: Fn() -> A + Send + Sync + 'static,
93		A: AggregateFunction + 'static,
94	{
95		self.0.aggregates
96			.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn AggregateFunction>));
97
98		self
99	}
100
101	pub fn register_generator<F, G>(mut self, name: &str, init: F) -> Self
102	where
103		F: Fn() -> G + Send + Sync + 'static,
104		G: GeneratorFunction + 'static,
105	{
106		self.0.generators
107			.insert(name.to_string(), Arc::new(move || Box::new(init()) as Box<dyn GeneratorFunction>));
108
109		self
110	}
111
112	pub fn build(self) -> Functions {
113		Functions(Arc::new(self.0))
114	}
115}