Skip to main content

nodedb_sql/functions/
registry.rs

//! Built-in function registry for SQL planning.
//!
//! Tracks known functions, their categories, and whether they trigger
//! special engine routing (e.g., `vector_distance` → `SearchTrigger::VectorSearch`).
5
/// Category a registered SQL function belongs to.
///
/// The registry compares against this value to answer its
/// aggregate / window classification queries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FunctionCategory {
    /// Scalar function, e.g. `abs` or `coalesce` in the built-in table.
    Scalar,
    /// Aggregate function, e.g. `count` or `sum` in the built-in table.
    Aggregate,
    /// Window function, e.g. `row_number` or `rank` in the built-in table.
    Window,
}
13
/// Special engine routing a function call can request from the planner.
///
/// Functions without any special routing carry [`SearchTrigger::None`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchTrigger {
    /// No special routing.
    None,
    /// Vector search routing (assigned to `vector_distance`).
    VectorSearch,
    /// Multi-vector search routing (assigned to `multi_vector_search`).
    MultiVectorSearch,
    /// Text search routing (assigned to `bm25_score` and `search_score`).
    TextSearch,
    /// Hybrid search routing (assigned to `rrf_score`).
    HybridSearch,
    /// Text match routing (assigned to `text_match`).
    TextMatch,
    /// Spatial distance-within routing (assigned to `st_dwithin`).
    SpatialDWithin,
    /// Spatial containment routing (assigned to `st_contains`).
    SpatialContains,
    /// Spatial intersection routing (assigned to `st_intersects`).
    SpatialIntersects,
    /// Spatial within routing (assigned to `st_within`).
    SpatialWithin,
    /// Time-bucket routing (assigned to `time_bucket`).
    TimeBucket,
}
29
30/// Metadata about a known function.
31#[derive(Debug, Clone)]
32pub struct FunctionMeta {
33    pub name: &'static str,
34    pub category: FunctionCategory,
35    pub min_args: usize,
36    pub max_args: usize,
37    pub search_trigger: SearchTrigger,
38}
39
40/// The function registry.
41pub struct FunctionRegistry {
42    functions: Vec<FunctionMeta>,
43}
44
45impl FunctionRegistry {
46    /// Create the default registry with all built-in functions.
47    pub fn new() -> Self {
48        Self {
49            functions: builtin_functions(),
50        }
51    }
52
53    /// Look up a function by name (case-insensitive).
54    pub fn lookup(&self, name: &str) -> Option<&FunctionMeta> {
55        let lower = name.to_lowercase();
56        self.functions.iter().find(|f| f.name == lower)
57    }
58
59    /// Check if a function triggers special search routing.
60    pub fn search_trigger(&self, name: &str) -> SearchTrigger {
61        self.lookup(name)
62            .map(|f| f.search_trigger)
63            .unwrap_or(SearchTrigger::None)
64    }
65
66    /// Check if a function is an aggregate.
67    pub fn is_aggregate(&self, name: &str) -> bool {
68        self.lookup(name)
69            .is_some_and(|f| f.category == FunctionCategory::Aggregate)
70    }
71
72    /// Check if a function is a window function.
73    pub fn is_window(&self, name: &str) -> bool {
74        self.lookup(name)
75            .is_some_and(|f| f.category == FunctionCategory::Window)
76    }
77}
78
79impl Default for FunctionRegistry {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85fn s(
86    name: &'static str,
87    cat: FunctionCategory,
88    min: usize,
89    max: usize,
90    trigger: SearchTrigger,
91) -> FunctionMeta {
92    FunctionMeta {
93        name,
94        category: cat,
95        min_args: min,
96        max_args: max,
97        search_trigger: trigger,
98    }
99}
100
101fn builtin_functions() -> Vec<FunctionMeta> {
102    use FunctionCategory::*;
103    use SearchTrigger::*;
104
105    vec![
106        // ── Standard aggregates ──
107        s("count", Aggregate, 0, 1, None),
108        s("sum", Aggregate, 1, 1, None),
109        s("avg", Aggregate, 1, 1, None),
110        s("min", Aggregate, 1, 1, None),
111        s("max", Aggregate, 1, 1, None),
112        // ── Standard window ──
113        s("row_number", Window, 0, 0, None),
114        s("rank", Window, 0, 0, None),
115        s("dense_rank", Window, 0, 0, None),
116        s("lag", Window, 1, 3, None),
117        s("lead", Window, 1, 3, None),
118        s("first_value", Window, 1, 1, None),
119        s("last_value", Window, 1, 1, None),
120        s("nth_value", Window, 2, 2, None),
121        // ── Vector search ──
122        s("vector_distance", Scalar, 2, 3, VectorSearch),
123        s("multi_vector_search", Scalar, 1, 2, MultiVectorSearch),
124        s("multi_vector_score", Scalar, 3, 3, None),
125        s("sparse_score", Scalar, 3, 3, None),
126        // ── Text search ──
127        s("bm25_score", Scalar, 2, 2, TextSearch),
128        s("search_score", Scalar, 2, 2, TextSearch),
129        s("text_match", Scalar, 2, 3, TextMatch),
130        // ── Hybrid search ──
131        s("rrf_score", Scalar, 2, 4, HybridSearch),
132        // ── Spatial ──
133        s("st_dwithin", Scalar, 3, 3, SpatialDWithin),
134        s("st_contains", Scalar, 2, 2, SpatialContains),
135        s("st_intersects", Scalar, 2, 2, SpatialIntersects),
136        s("st_within", Scalar, 2, 2, SpatialWithin),
137        s("st_distance", Scalar, 2, 2, None),
138        s("st_point", Scalar, 2, 2, None),
139        // ── Timeseries ──
140        s("time_bucket", Scalar, 2, 2, TimeBucket),
141        // ── Timeseries aggregates ──
142        s("ts_percentile", Aggregate, 2, 2, None),
143        s("ts_stddev", Aggregate, 1, 1, None),
144        s("ts_correlate", Aggregate, 2, 2, None),
145        // ── Timeseries window ──
146        s("ts_rate", Window, 1, 1, None),
147        s("ts_derivative", Window, 1, 1, None),
148        s("ts_moving_avg", Window, 2, 2, None),
149        s("ts_ema", Window, 2, 2, None),
150        s("ts_delta", Window, 1, 1, None),
151        s("ts_interpolate", Window, 1, 1, None),
152        s("ts_lag", Window, 1, 3, None),
153        s("ts_lead", Window, 1, 3, None),
154        s("ts_rank", Window, 0, 0, None),
155        // ── Approximate aggregates ──
156        s("approx_count_distinct", Aggregate, 1, 1, None),
157        s("approx_percentile", Aggregate, 2, 2, None),
158        s("approx_topk", Aggregate, 2, 2, None),
159        s("approx_count", Aggregate, 1, 1, None),
160        // ── Document helpers ──
161        s("doc_get", Scalar, 2, 3, None),
162        s("doc_exists", Scalar, 2, 2, None),
163        s("doc_array_contains", Scalar, 3, 3, None),
164        s("nav", Scalar, 2, 2, None),
165        // ── Utility ──
166        s("chunk_text", Scalar, 2, 3, None),
167        s("currency", Scalar, 1, 2, None),
168        s("distribute", Scalar, 2, 3, None),
169        s("allocate", Scalar, 2, 3, None),
170        s("resolve_permission", Scalar, 2, 3, None),
171        // ── Standard scalar ──
172        s("coalesce", Scalar, 1, 255, None),
173        s("nullif", Scalar, 2, 2, None),
174        s("abs", Scalar, 1, 1, None),
175        s("ceil", Scalar, 1, 1, None),
176        s("floor", Scalar, 1, 1, None),
177        s("round", Scalar, 1, 2, None),
178        s("lower", Scalar, 1, 1, None),
179        s("upper", Scalar, 1, 1, None),
180        s("length", Scalar, 1, 1, None),
181        s("trim", Scalar, 1, 1, None),
182        s("substring", Scalar, 2, 3, None),
183        s("concat", Scalar, 1, 255, None),
184        s("replace", Scalar, 3, 3, None),
185        s("now", Scalar, 0, 0, None),
186        s("current_timestamp", Scalar, 0, 0, None),
187        s("make_array", Scalar, 0, 255, None),
188    ]
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks category and search-trigger lookups against the built-in table.
    #[test]
    fn lookup_builtin() {
        let registry = FunctionRegistry::new();

        // Aggregate detection, including an upper-case spelling of the name.
        for aggregate in ["COUNT", "sum"] {
            assert!(registry.is_aggregate(aggregate));
        }
        assert!(!registry.is_aggregate("vector_distance"));

        assert!(registry.is_window("row_number"));

        // Functions that request special engine routing.
        let expected_triggers = [
            ("vector_distance", SearchTrigger::VectorSearch),
            ("st_dwithin", SearchTrigger::SpatialDWithin),
            ("time_bucket", SearchTrigger::TimeBucket),
        ];
        for (name, trigger) in expected_triggers {
            assert_eq!(registry.search_trigger(name), trigger);
        }
    }
}