Skip to main content

nodedb_sql/functions/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Built-in function registry for SQL planning.
4//!
5//! Tracks known functions, their categories, arg specs, return types,
6//! and whether they trigger special engine routing (e.g., vector_distance
7//! → VectorSearch).
8
9use nodedb_types::columnar::ColumnType;
10
11use super::builtins::builtin_functions;
12
/// A `major.minor.patch` triple recording the release in which a function
/// was introduced.
///
/// The derived ordering follows field declaration order, so two versions are
/// compared by `major` first, then `minor`, then `patch`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Version {
    /// Major component.
    pub major: u8,
    /// Minor component.
    pub minor: u8,
    /// Patch component.
    pub patch: u8,
}

impl Version {
    /// Assembles a version from its three components.
    ///
    /// Declared `const` so it can initialise constants and statics.
    pub const fn new(major: u8, minor: u8, patch: u8) -> Self {
        Version {
            major,
            minor,
            patch,
        }
    }
}
30
31/// Specification for a single function argument position.
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub struct ArgTypeSpec {
34    /// Argument name, for documentation and error messages.
35    pub name: &'static str,
36    /// Accepted column types. Empty slice means any type is accepted (wildcard).
37    pub accepted: &'static [ColumnType],
38    /// If true on the last argument, this argument may repeat zero or more times.
39    pub variadic: bool,
40}
41
/// The category a registered function belongs to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FunctionCategory {
    /// Scalar function.
    Scalar,
    /// Aggregate function; what `FunctionRegistry::is_aggregate` tests for.
    Aggregate,
    /// Window function; what `FunctionRegistry::is_window` tests for.
    Window,
}
49
/// The special engine routing, if any, that a function triggers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchTrigger {
    /// No special routing. `FunctionRegistry::search_trigger` also reports
    /// this for names that are not registered.
    None,
    /// Routing triggered by `vector_distance`.
    VectorSearch,
    MultiVectorSearch,
    TextSearch,
    HybridSearch,
    TextMatch,
    /// Routing triggered by `st_dwithin`.
    SpatialDWithin,
    SpatialContains,
    SpatialIntersects,
    SpatialWithin,
    /// Routing triggered by `time_bucket`.
    TimeBucket,
    /// Table-valued read on the array engine:
    /// `ARRAY_SLICE(name, slice_obj, attrs?, limit?)`.
    ArraySlice,
    /// Table-valued read on the array engine: `ARRAY_PROJECT(name, attrs)`.
    ArrayProject,
    /// Table-valued aggregate on the array engine:
    /// `ARRAY_AGG(name, attr, reducer, group_by_dim?)`.
    ArrayAgg,
    /// Table-valued elementwise operation on the array engine:
    /// `ARRAY_ELEMENTWISE(left, right, op, attr)`.
    ArrayElementwise,
    /// Array engine maintenance scalar returning BOOL: `ARRAY_FLUSH(name)`.
    ArrayFlush,
    /// Array engine maintenance scalar returning BOOL: `ARRAY_COMPACT(name)`.
    ArrayCompact,
    /// Graph-distance marker intercepted by the planner for three-source RRF
    /// fusion.
    ///
    /// `graph_score(node_id_col, seed_id, depth => N, label => 'edge_label')`
    /// is never evaluated per row: the hybrid planner pulls out its arguments
    /// and builds a graph BFS spec in the physical plan.
    GraphSearch,
}
84
85/// Metadata about a known function.
86#[derive(Debug, Clone)]
87pub struct FunctionMeta {
88    pub name: &'static str,
89    pub category: FunctionCategory,
90    pub min_args: usize,
91    pub max_args: usize,
92    pub search_trigger: SearchTrigger,
93    /// Static return type, when known at plan time. `None` means the type
94    /// is context-dependent or unknown (resolved at runtime).
95    pub return_type: Option<ColumnType>,
96    /// Per-position argument type specifications.
97    pub arg_types: &'static [ArgTypeSpec],
98    /// Version in which this function was introduced.
99    pub since: Version,
100}
101
102/// The function registry.
103pub struct FunctionRegistry {
104    functions: Vec<FunctionMeta>,
105}
106
107impl FunctionRegistry {
108    /// Create the default registry with all built-in functions.
109    pub fn new() -> Self {
110        Self {
111            functions: builtin_functions(),
112        }
113    }
114
115    /// Look up a function by name (case-insensitive).
116    pub fn lookup(&self, name: &str) -> Option<&FunctionMeta> {
117        let lower = name.to_lowercase();
118        self.functions.iter().find(|f| f.name == lower)
119    }
120
121    /// Check if a function triggers special search routing.
122    pub fn search_trigger(&self, name: &str) -> SearchTrigger {
123        self.lookup(name)
124            .map(|f| f.search_trigger)
125            .unwrap_or(SearchTrigger::None)
126    }
127
128    /// Check if a function is an aggregate.
129    pub fn is_aggregate(&self, name: &str) -> bool {
130        self.lookup(name)
131            .is_some_and(|f| f.category == FunctionCategory::Aggregate)
132    }
133
134    /// Check if a function is a window function.
135    pub fn is_window(&self, name: &str) -> bool {
136        self.lookup(name)
137            .is_some_and(|f| f.category == FunctionCategory::Window)
138    }
139}
140
141impl Default for FunctionRegistry {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Category predicates and search triggers resolve for well-known
    /// built-ins, whatever the case of the queried name.
    #[test]
    fn lookup_builtin() {
        let registry = FunctionRegistry::new();

        assert!(registry.is_aggregate("COUNT"));
        assert!(registry.is_aggregate("sum"));
        assert!(!registry.is_aggregate("vector_distance"));
        assert!(registry.is_window("row_number"));

        let vector = registry.search_trigger("vector_distance");
        assert_eq!(vector, SearchTrigger::VectorSearch);

        let dwithin = registry.search_trigger("st_dwithin");
        assert_eq!(dwithin, SearchTrigger::SpatialDWithin);

        let bucket = registry.search_trigger("time_bucket");
        assert_eq!(bucket, SearchTrigger::TimeBucket);
    }

    /// Only the new `ndb_chunk_text` name is registered; the old one is gone.
    #[test]
    fn chunk_text_renamed_to_ndb_chunk_text() {
        let registry = FunctionRegistry::new();

        let renamed = registry.lookup("ndb_chunk_text");
        assert!(renamed.is_some(), "ndb_chunk_text must be registered");

        let legacy = registry.lookup("chunk_text");
        assert!(legacy.is_none(), "old name chunk_text must not exist");
    }

    /// Functions that were never implemented must not be registered.
    #[test]
    fn unimplemented_functions_removed() {
        let registry = FunctionRegistry::new();

        assert!(registry.lookup("currency").is_none());
        assert!(registry.lookup("distribute").is_none());
        assert!(registry.lookup("allocate").is_none());
        assert!(registry.lookup("resolve_permission").is_none());
        assert!(registry.lookup("convert_currency").is_none());
    }

    /// Every built-in is tagged as introduced in version 0.1.0.
    #[test]
    fn all_builtins_have_since_set() {
        let registry = FunctionRegistry::new();
        let expected = Version::new(0, 1, 0);

        registry.functions.iter().for_each(|meta| {
            assert_eq!(
                meta.since, expected,
                "function '{}' must have since = Version::new(0, 1, 0)",
                meta.name
            );
        });
    }

    /// The declared arity bounds agree with the number of argument specs.
    #[test]
    fn all_builtins_arg_counts_consistent() {
        let registry = FunctionRegistry::new();

        for meta in &registry.functions {
            let spec_count = meta.arg_types.len();
            let ends_variadic = matches!(meta.arg_types.last(), Some(spec) if spec.variadic);

            if !ends_variadic {
                // Fixed arity: min_args <= spec_count <= max_args.
                assert!(
                    meta.min_args <= spec_count,
                    "function '{}': min_args ({}) > arg_types.len() ({})",
                    meta.name,
                    meta.min_args,
                    spec_count
                );
                assert!(
                    spec_count <= meta.max_args,
                    "function '{}': arg_types.len() ({}) > max_args ({})",
                    meta.name,
                    spec_count,
                    meta.max_args
                );
                continue;
            }

            // Variadic tail: only the lower bound is constrained.
            assert!(
                meta.min_args <= spec_count,
                "function '{}': min_args ({}) > arg_types.len() ({}) with variadic last arg",
                meta.name,
                meta.min_args,
                spec_count
            );
        }
    }

    /// Spot-checks the static return types of a handful of built-ins.
    #[test]
    fn return_type_spot_checks() {
        let registry = FunctionRegistry::new();

        // (function name, expected return type, failure message)
        let expectations = [
            (
                "now",
                ColumnType::Timestamptz,
                "now() must return Timestamptz",
            ),
            ("count", ColumnType::Int64, "count must return Int64"),
            ("doc_exists", ColumnType::Bool, "doc_exists must return Bool"),
            (
                "st_contains",
                ColumnType::Bool,
                "st_contains must return Bool",
            ),
            (
                "pg_fts_match",
                ColumnType::Bool,
                "pg_fts_match must return Bool",
            ),
            ("lower", ColumnType::String, "lower must return String"),
        ];

        for (name, expected, why) in expectations {
            let actual = registry.lookup(name).and_then(|meta| meta.return_type);
            assert_eq!(actual, Some(expected), "{}", why);
        }
    }
}