/// Broad class of a builtin SQL function.
///
/// Derives `Hash` in addition to equality so categories can be used as
/// `HashMap`/`HashSet` keys (e.g. grouping functions by category).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FunctionCategory {
    /// Plain scalar function (e.g. `abs`, `lower`, `vector_distance`).
    Scalar,
    /// Aggregate function (e.g. `count`, `sum`, `approx_topk`).
    Aggregate,
    /// Window function (e.g. `row_number`, `lag`, `ts_moving_avg`).
    Window,
}
13
/// Tag marking builtin functions that are tied to a specialised search or
/// time-bucketing feature.
///
/// `None` is the default. It is also what the registry reports for functions
/// without a trigger and for unknown names.
/// NOTE(review): how each trigger is consumed lives outside this file.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SearchTrigger {
    /// No associated search feature.
    #[default]
    None,
    /// Set by the builtin `vector_distance`.
    VectorSearch,
    /// Set by the builtin `multi_vector_search`.
    MultiVectorSearch,
    /// Set by the builtins `bm25_score` and `search_score`.
    TextSearch,
    /// Set by the builtin `rrf_score`.
    HybridSearch,
    /// Set by the builtin `text_match`.
    TextMatch,
    /// Set by the builtin `st_dwithin`.
    SpatialDWithin,
    /// Set by the builtin `st_contains`.
    SpatialContains,
    /// Set by the builtin `st_intersects`.
    SpatialIntersects,
    /// Set by the builtin `st_within`.
    SpatialWithin,
    /// Set by the builtin `time_bucket`.
    TimeBucket,
}
29
30#[derive(Debug, Clone)]
32pub struct FunctionMeta {
33 pub name: &'static str,
34 pub category: FunctionCategory,
35 pub min_args: usize,
36 pub max_args: usize,
37 pub search_trigger: SearchTrigger,
38}
39
40pub struct FunctionRegistry {
42 functions: Vec<FunctionMeta>,
43}
44
45impl FunctionRegistry {
46 pub fn new() -> Self {
48 Self {
49 functions: builtin_functions(),
50 }
51 }
52
53 pub fn lookup(&self, name: &str) -> Option<&FunctionMeta> {
55 let lower = name.to_lowercase();
56 self.functions.iter().find(|f| f.name == lower)
57 }
58
59 pub fn search_trigger(&self, name: &str) -> SearchTrigger {
61 self.lookup(name)
62 .map(|f| f.search_trigger)
63 .unwrap_or(SearchTrigger::None)
64 }
65
66 pub fn is_aggregate(&self, name: &str) -> bool {
68 self.lookup(name)
69 .is_some_and(|f| f.category == FunctionCategory::Aggregate)
70 }
71
72 pub fn is_window(&self, name: &str) -> bool {
74 self.lookup(name)
75 .is_some_and(|f| f.category == FunctionCategory::Window)
76 }
77}
78
79impl Default for FunctionRegistry {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85fn s(
86 name: &'static str,
87 cat: FunctionCategory,
88 min: usize,
89 max: usize,
90 trigger: SearchTrigger,
91) -> FunctionMeta {
92 FunctionMeta {
93 name,
94 category: cat,
95 min_args: min,
96 max_args: max,
97 search_trigger: trigger,
98 }
99}
100
101fn builtin_functions() -> Vec<FunctionMeta> {
102 use FunctionCategory::*;
103 use SearchTrigger::*;
104
105 vec![
106 s("count", Aggregate, 0, 1, None),
108 s("sum", Aggregate, 1, 1, None),
109 s("avg", Aggregate, 1, 1, None),
110 s("min", Aggregate, 1, 1, None),
111 s("max", Aggregate, 1, 1, None),
112 s("row_number", Window, 0, 0, None),
114 s("rank", Window, 0, 0, None),
115 s("dense_rank", Window, 0, 0, None),
116 s("lag", Window, 1, 3, None),
117 s("lead", Window, 1, 3, None),
118 s("first_value", Window, 1, 1, None),
119 s("last_value", Window, 1, 1, None),
120 s("nth_value", Window, 2, 2, None),
121 s("vector_distance", Scalar, 2, 3, VectorSearch),
123 s("multi_vector_search", Scalar, 1, 2, MultiVectorSearch),
124 s("multi_vector_score", Scalar, 3, 3, None),
125 s("sparse_score", Scalar, 3, 3, None),
126 s("bm25_score", Scalar, 2, 2, TextSearch),
128 s("search_score", Scalar, 2, 2, TextSearch),
129 s("text_match", Scalar, 2, 3, TextMatch),
130 s("rrf_score", Scalar, 2, 4, HybridSearch),
132 s("st_dwithin", Scalar, 3, 3, SpatialDWithin),
134 s("st_contains", Scalar, 2, 2, SpatialContains),
135 s("st_intersects", Scalar, 2, 2, SpatialIntersects),
136 s("st_within", Scalar, 2, 2, SpatialWithin),
137 s("st_distance", Scalar, 2, 2, None),
138 s("st_point", Scalar, 2, 2, None),
139 s("time_bucket", Scalar, 2, 2, TimeBucket),
141 s("ts_percentile", Aggregate, 2, 2, None),
143 s("ts_stddev", Aggregate, 1, 1, None),
144 s("ts_correlate", Aggregate, 2, 2, None),
145 s("ts_rate", Window, 1, 1, None),
147 s("ts_derivative", Window, 1, 1, None),
148 s("ts_moving_avg", Window, 2, 2, None),
149 s("ts_ema", Window, 2, 2, None),
150 s("ts_delta", Window, 1, 1, None),
151 s("ts_interpolate", Window, 1, 1, None),
152 s("ts_lag", Window, 1, 3, None),
153 s("ts_lead", Window, 1, 3, None),
154 s("ts_rank", Window, 0, 0, None),
155 s("approx_count_distinct", Aggregate, 1, 1, None),
157 s("approx_percentile", Aggregate, 2, 2, None),
158 s("approx_topk", Aggregate, 2, 2, None),
159 s("approx_count", Aggregate, 1, 1, None),
160 s("doc_get", Scalar, 2, 3, None),
162 s("doc_exists", Scalar, 2, 2, None),
163 s("doc_array_contains", Scalar, 3, 3, None),
164 s("nav", Scalar, 2, 2, None),
165 s("chunk_text", Scalar, 2, 3, None),
167 s("currency", Scalar, 1, 2, None),
168 s("distribute", Scalar, 2, 3, None),
169 s("allocate", Scalar, 2, 3, None),
170 s("resolve_permission", Scalar, 2, 3, None),
171 s("coalesce", Scalar, 1, 255, None),
173 s("nullif", Scalar, 2, 2, None),
174 s("abs", Scalar, 1, 1, None),
175 s("ceil", Scalar, 1, 1, None),
176 s("floor", Scalar, 1, 1, None),
177 s("round", Scalar, 1, 2, None),
178 s("lower", Scalar, 1, 1, None),
179 s("upper", Scalar, 1, 1, None),
180 s("length", Scalar, 1, 1, None),
181 s("trim", Scalar, 1, 1, None),
182 s("substring", Scalar, 2, 3, None),
183 s("concat", Scalar, 1, 255, None),
184 s("replace", Scalar, 3, 3, None),
185 s("now", Scalar, 0, 0, None),
186 s("current_timestamp", Scalar, 0, 0, None),
187 s("make_array", Scalar, 0, 255, None),
188 ]
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn lookup_builtin() {
        let reg = FunctionRegistry::new();
        assert!(reg.is_aggregate("COUNT"));
        assert!(reg.is_aggregate("sum"));
        assert!(!reg.is_aggregate("vector_distance"));
        assert!(reg.is_window("row_number"));
        assert_eq!(
            reg.search_trigger("vector_distance"),
            SearchTrigger::VectorSearch
        );
        assert_eq!(
            reg.search_trigger("st_dwithin"),
            SearchTrigger::SpatialDWithin
        );
        assert_eq!(reg.search_trigger("time_bucket"), SearchTrigger::TimeBucket);
    }

    #[test]
    fn lookup_is_case_insensitive() {
        let reg = FunctionRegistry::new();
        assert_eq!(reg.lookup("Row_Number").map(|f| f.name), Some("row_number"));
        assert_eq!(reg.lookup("ST_CONTAINS").map(|f| f.name), Some("st_contains"));
        assert!(reg.is_window("LAG"));
    }

    #[test]
    fn unknown_function_falls_back() {
        let reg = FunctionRegistry::new();
        assert!(reg.lookup("no_such_function").is_none());
        assert!(reg.lookup("").is_none());
        assert!(!reg.is_aggregate("no_such_function"));
        assert!(!reg.is_window("no_such_function"));
        assert_eq!(reg.search_trigger("no_such_function"), SearchTrigger::None);
    }

    /// Lookup lowercases/case-folds the query and returns the first match, so
    /// every builtin name must be lowercase and unique, and each arity range
    /// must be non-empty.
    #[test]
    fn builtin_table_is_well_formed() {
        let mut seen = HashSet::new();
        for f in builtin_functions() {
            assert!(seen.insert(f.name), "duplicate builtin `{}`", f.name);
            assert_eq!(
                f.name,
                f.name.to_lowercase(),
                "builtin `{}` must be lowercase",
                f.name
            );
            assert!(
                f.min_args <= f.max_args,
                "builtin `{}` has min_args > max_args",
                f.name
            );
        }
    }
}