1use nodedb_types::columnar::ColumnType;
10
11use super::builtins::builtin_functions;
12
/// A `major.minor.patch` version stamp, used by [`FunctionMeta::since`] to
/// record when a function became available.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order, so
/// versions order lexicographically by `(major, minor, patch)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Version {
    /// Major component.
    pub major: u8,
    /// Minor component.
    pub minor: u8,
    /// Patch component.
    pub patch: u8,
}

impl Version {
    /// Builds a version from its three components; usable in `const` contexts.
    pub const fn new(major: u8, minor: u8, patch: u8) -> Self {
        Self {
            major,
            minor,
            patch,
        }
    }
}

impl std::fmt::Display for Version {
    /// Formats as `major.minor.patch`, e.g. `0.1.0`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
    }
}
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub struct ArgTypeSpec {
34 pub name: &'static str,
36 pub accepted: &'static [ColumnType],
38 pub variadic: bool,
40}
41
/// Broad classification of a registered function.
///
/// `FunctionRegistry::is_aggregate` and `FunctionRegistry::is_window` test
/// for the `Aggregate` and `Window` variants respectively.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum FunctionCategory {
    /// Neither an aggregate nor a window function.
    Scalar,
    /// Aggregate function (e.g. `count`, `sum`).
    Aggregate,
    /// Window function (e.g. `row_number`).
    Window,
}
49
/// Identifies which specialised search/operator path, if any, a function call
/// is associated with (inferred from the variant names — confirm against the
/// planner).
///
/// `FunctionRegistry::search_trigger` reports `SearchTrigger::None` both for
/// functions without a trigger and for names that are not registered.
///
/// Variant order is kept stable because it fixes the discriminant values.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SearchTrigger {
    /// No special handling.
    None,
    /// e.g. `vector_distance`.
    VectorSearch,
    MultiVectorSearch,
    TextSearch,
    HybridSearch,
    TextMatch,
    /// e.g. `st_dwithin`.
    SpatialDWithin,
    SpatialContains,
    SpatialIntersects,
    SpatialWithin,
    /// e.g. `time_bucket`.
    TimeBucket,
    ArraySlice,
    ArrayProject,
    ArrayAgg,
    ArrayElementwise,
    ArrayFlush,
    ArrayCompact,
    GraphSearch,
}
84
85#[derive(Debug, Clone)]
87pub struct FunctionMeta {
88 pub name: &'static str,
89 pub category: FunctionCategory,
90 pub min_args: usize,
91 pub max_args: usize,
92 pub search_trigger: SearchTrigger,
93 pub return_type: Option<ColumnType>,
96 pub arg_types: &'static [ArgTypeSpec],
98 pub since: Version,
100}
101
102pub struct FunctionRegistry {
104 functions: Vec<FunctionMeta>,
105}
106
107impl FunctionRegistry {
108 pub fn new() -> Self {
110 Self {
111 functions: builtin_functions(),
112 }
113 }
114
115 pub fn lookup(&self, name: &str) -> Option<&FunctionMeta> {
117 let lower = name.to_lowercase();
118 self.functions.iter().find(|f| f.name == lower)
119 }
120
121 pub fn search_trigger(&self, name: &str) -> SearchTrigger {
123 self.lookup(name)
124 .map(|f| f.search_trigger)
125 .unwrap_or(SearchTrigger::None)
126 }
127
128 pub fn is_aggregate(&self, name: &str) -> bool {
130 self.lookup(name)
131 .is_some_and(|f| f.category == FunctionCategory::Aggregate)
132 }
133
134 pub fn is_window(&self, name: &str) -> bool {
136 self.lookup(name)
137 .is_some_and(|f| f.category == FunctionCategory::Window)
138 }
139}
140
141impl Default for FunctionRegistry {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lookup_builtin() {
        let reg = FunctionRegistry::new();
        assert!(reg.is_aggregate("COUNT"));
        assert!(reg.is_aggregate("sum"));
        assert!(!reg.is_aggregate("vector_distance"));
        assert!(reg.is_window("row_number"));
        assert_eq!(
            reg.search_trigger("vector_distance"),
            SearchTrigger::VectorSearch
        );
        assert_eq!(
            reg.search_trigger("st_dwithin"),
            SearchTrigger::SpatialDWithin
        );
        assert_eq!(reg.search_trigger("time_bucket"), SearchTrigger::TimeBucket);
    }

    #[test]
    fn unknown_name_has_no_meta_and_no_trigger() {
        let reg = FunctionRegistry::new();
        assert!(reg.lookup("definitely_not_a_registered_fn").is_none());
        assert_eq!(
            reg.search_trigger("definitely_not_a_registered_fn"),
            SearchTrigger::None
        );
        assert!(!reg.is_aggregate("definitely_not_a_registered_fn"));
        assert!(!reg.is_window("definitely_not_a_registered_fn"));
    }

    #[test]
    fn chunk_text_renamed_to_ndb_chunk_text() {
        let reg = FunctionRegistry::new();
        assert!(
            reg.lookup("ndb_chunk_text").is_some(),
            "ndb_chunk_text must be registered"
        );
        assert!(
            reg.lookup("chunk_text").is_none(),
            "old name chunk_text must not exist"
        );
    }

    #[test]
    fn unimplemented_functions_removed() {
        let reg = FunctionRegistry::new();
        assert!(reg.lookup("currency").is_none());
        assert!(reg.lookup("distribute").is_none());
        assert!(reg.lookup("allocate").is_none());
        assert!(reg.lookup("resolve_permission").is_none());
        assert!(reg.lookup("convert_currency").is_none());
    }

    #[test]
    fn all_builtins_have_since_set() {
        let reg = FunctionRegistry::new();
        let v0_1_0 = Version::new(0, 1, 0);
        for f in &reg.functions {
            assert_eq!(
                f.since, v0_1_0,
                "function '{}' must have since = Version::new(0, 1, 0)",
                f.name
            );
        }
    }

    #[test]
    fn all_builtins_arg_counts_consistent() {
        let reg = FunctionRegistry::new();
        for f in &reg.functions {
            let n = f.arg_types.len();
            let last_variadic = f.arg_types.last().is_some_and(|a| a.variadic);
            if last_variadic {
                // A trailing variadic spec can absorb any number of extra
                // arguments, so only the lower bound is checked against the
                // number of declared specs.
                assert!(
                    f.min_args <= n,
                    "function '{}': min_args ({}) > arg_types.len() ({}) with variadic last arg",
                    f.name,
                    f.min_args,
                    n
                );
            } else {
                // Without a variadic tail, the declared specs must fit
                // between min_args and max_args.
                assert!(
                    f.min_args <= n,
                    "function '{}': min_args ({}) > arg_types.len() ({})",
                    f.name,
                    f.min_args,
                    n
                );
                assert!(
                    n <= f.max_args,
                    "function '{}': arg_types.len() ({}) > max_args ({})",
                    f.name,
                    n,
                    f.max_args
                );
            }
        }
    }

    #[test]
    fn return_type_spot_checks() {
        let reg = FunctionRegistry::new();
        assert_eq!(
            reg.lookup("now").and_then(|f| f.return_type),
            Some(ColumnType::Timestamptz),
            "now() must return Timestamptz"
        );
        assert_eq!(
            reg.lookup("count").and_then(|f| f.return_type),
            Some(ColumnType::Int64),
            "count must return Int64"
        );
        assert_eq!(
            reg.lookup("doc_exists").and_then(|f| f.return_type),
            Some(ColumnType::Bool),
            "doc_exists must return Bool"
        );
        assert_eq!(
            reg.lookup("st_contains").and_then(|f| f.return_type),
            Some(ColumnType::Bool),
            "st_contains must return Bool"
        );
        assert_eq!(
            reg.lookup("pg_fts_match").and_then(|f| f.return_type),
            Some(ColumnType::Bool),
            "pg_fts_match must return Bool"
        );
        assert_eq!(
            reg.lookup("lower").and_then(|f| f.return_type),
            Some(ColumnType::String),
            "lower must return String"
        );
    }
}