1use hashbrown::HashMap;
7use kyu_common::{KyuError, KyuResult};
8use kyu_types::type_utils::implicit_cast_cost;
9use kyu_types::LogicalType;
10use smol_str::SmolStr;
11
12use crate::bound_expr::FunctionId;
13
/// Classifies a registered function.
///
/// Every [`FunctionSignature`] carries one of these tags so callers can tell
/// the two families of functions apart after resolution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FunctionKind {
    /// Tag used for built-ins such as `abs`, `upper` or `coalesce`.
    Scalar,
    /// Tag used for built-ins such as `count`, `sum` or `collect`.
    Aggregate,
}
20
21#[derive(Clone, Debug)]
23pub struct FunctionSignature {
24 pub id: FunctionId,
25 pub name: SmolStr,
26 pub kind: FunctionKind,
27 pub param_types: Vec<LogicalType>,
28 pub variadic: bool,
29 pub return_type: LogicalType,
30}
31
32pub struct FunctionRegistry {
38 signatures: Vec<FunctionSignature>,
39 name_index: HashMap<SmolStr, Vec<usize>>,
40}
41
42impl Default for FunctionRegistry {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl FunctionRegistry {
49 pub fn new() -> Self {
50 Self {
51 signatures: Vec::new(),
52 name_index: HashMap::new(),
53 }
54 }
55
56 pub fn with_builtins() -> Self {
58 let mut reg = Self::new();
59 register_builtins(&mut reg);
60 reg
61 }
62
63 pub fn register(&mut self, name: &str, kind: FunctionKind, param_types: Vec<LogicalType>, variadic: bool, return_type: LogicalType) -> FunctionId {
65 let id = FunctionId(self.signatures.len() as u32);
66 let lower_name = SmolStr::new(name.to_lowercase());
67 let sig = FunctionSignature {
68 id,
69 name: lower_name.clone(),
70 kind,
71 param_types,
72 variadic,
73 return_type,
74 };
75 let idx = self.signatures.len();
76 self.signatures.push(sig);
77 self.name_index
78 .entry(lower_name)
79 .or_default()
80 .push(idx);
81 id
82 }
83
84 pub fn resolve(
88 &self,
89 name: &str,
90 arg_types: &[LogicalType],
91 ) -> KyuResult<&FunctionSignature> {
92 let lower = name.to_lowercase();
93 let overloads = self
94 .name_index
95 .get(lower.as_str())
96 .ok_or_else(|| KyuError::Binder(format!("unknown function '{name}'")))?;
97
98 let mut best: Option<(usize, u32)> = None; for &idx in overloads {
101 let sig = &self.signatures[idx];
102 if let Some(cost) = match_cost(sig, arg_types) {
103 match best {
104 None => best = Some((idx, cost)),
105 Some((_, best_cost)) if cost < best_cost => {
106 best = Some((idx, cost));
107 }
108 _ => {}
109 }
110 }
111 }
112
113 match best {
114 Some((idx, _)) => Ok(&self.signatures[idx]),
115 None => {
116 let type_names: Vec<_> = arg_types.iter().map(|t| t.type_name().to_string()).collect();
117 Err(KyuError::Binder(format!(
118 "no matching overload for {}({})",
119 name,
120 type_names.join(", "),
121 )))
122 }
123 }
124 }
125
126 pub fn get(&self, id: FunctionId) -> Option<&FunctionSignature> {
128 self.signatures.get(id.0 as usize)
129 }
130
131 pub fn len(&self) -> usize {
133 self.signatures.len()
134 }
135
136 pub fn is_empty(&self) -> bool {
138 self.signatures.is_empty()
139 }
140}
141
142fn match_cost(sig: &FunctionSignature, arg_types: &[LogicalType]) -> Option<u32> {
145 if sig.variadic {
146 if arg_types.len() < sig.param_types.len() {
147 return None;
148 }
149 } else if arg_types.len() != sig.param_types.len() {
150 return None;
151 }
152
153 let mut total = 0u32;
154
155 for (param, arg) in sig.param_types.iter().zip(arg_types.iter()) {
157 if matches!(param, LogicalType::Any) {
158 continue;
160 }
161 let cost = implicit_cast_cost(arg, param)?;
162 total += cost;
163 }
164
165 Some(total)
167}
168
169fn register_builtins(reg: &mut FunctionRegistry) {
170 use FunctionKind::{Aggregate, Scalar};
171 use LogicalType::*;
172
173 for ty in &[Int64, Double] {
175 reg.register("abs", Scalar, vec![ty.clone()], false, ty.clone());
176 }
177 reg.register("floor", Scalar, vec![Double], false, Double);
178 reg.register("ceil", Scalar, vec![Double], false, Double);
179 reg.register("round", Scalar, vec![Double], false, Double);
180 reg.register("sqrt", Scalar, vec![Double], false, Double);
181 reg.register("log", Scalar, vec![Double], false, Double);
182 reg.register("log2", Scalar, vec![Double], false, Double);
183 reg.register("log10", Scalar, vec![Double], false, Double);
184 reg.register("sin", Scalar, vec![Double], false, Double);
185 reg.register("cos", Scalar, vec![Double], false, Double);
186 reg.register("tan", Scalar, vec![Double], false, Double);
187 reg.register("sign", Scalar, vec![Int64], false, Int64);
188 reg.register("sign", Scalar, vec![Double], false, Int64);
189 reg.register("greatest", Scalar, vec![Any], true, Any);
190 reg.register("least", Scalar, vec![Any], true, Any);
191
192 reg.register("lower", Scalar, vec![String], false, String);
194 reg.register("upper", Scalar, vec![String], false, String);
195 reg.register("length", Scalar, vec![String], false, Int64);
196 reg.register("size", Scalar, vec![String], false, Int64);
197 reg.register("trim", Scalar, vec![String], false, String);
198 reg.register("ltrim", Scalar, vec![String], false, String);
199 reg.register("rtrim", Scalar, vec![String], false, String);
200 reg.register("reverse", Scalar, vec![String], false, String);
201 reg.register(
202 "substring",
203 Scalar,
204 vec![String, Int64, Int64],
205 false,
206 String,
207 );
208 reg.register("left", Scalar, vec![String, Int64], false, String);
209 reg.register("right", Scalar, vec![String, Int64], false, String);
210 reg.register(
211 "replace",
212 Scalar,
213 vec![String, String, String],
214 false,
215 String,
216 );
217 reg.register("concat", Scalar, vec![String], true, String);
218 reg.register("lpad", Scalar, vec![String, Int64, String], false, String);
219 reg.register("rpad", Scalar, vec![String, Int64, String], false, String);
220
221 reg.register("tostring", Scalar, vec![Any], false, String);
223 reg.register("tostring", Scalar, vec![String], false, String);
224 reg.register("tointeger", Scalar, vec![Any], false, Int64);
225 reg.register("tofloat", Scalar, vec![Any], false, Double);
226 reg.register("toboolean", Scalar, vec![Any], false, Bool);
227
228 reg.register("coalesce", Scalar, vec![Any], true, Any);
230 reg.register("typeof", Scalar, vec![Any], false, String);
231 reg.register("hash", Scalar, vec![Any], false, Int64);
232
233 reg.register(
235 "range",
236 Scalar,
237 vec![Int64, Int64],
238 false,
239 List(Box::new(Int64)),
240 );
241 reg.register("size", Scalar, vec![List(Box::new(Any))], false, Int64);
242 reg.register("length", Scalar, vec![List(Box::new(Any))], false, Int64);
243
244 reg.register("json_extract", Scalar, vec![String, String], false, String);
246 reg.register("json_valid", Scalar, vec![String], false, Bool);
247 reg.register("json_type", Scalar, vec![String], false, String);
248 reg.register("json_keys", Scalar, vec![String], false, List(Box::new(String)));
249 reg.register("json_array_length", Scalar, vec![String], false, Int64);
250 reg.register("json_contains", Scalar, vec![String, String], false, Bool);
251 reg.register("json_set", Scalar, vec![String, String, String], false, String);
252
253 reg.register("count", Aggregate, vec![Any], false, Int64);
255 reg.register("sum", Aggregate, vec![Int64], false, Int64);
256 reg.register("sum", Aggregate, vec![Double], false, Double);
257 reg.register("avg", Aggregate, vec![Int64], false, Double);
258 reg.register("avg", Aggregate, vec![Double], false, Double);
259 reg.register("min", Aggregate, vec![Any], false, Any);
260 reg.register("max", Aggregate, vec![Any], false, Any);
261 reg.register("collect", Aggregate, vec![Any], false, List(Box::new(Any)));
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a registry pre-populated with the built-in functions.
    fn builtins() -> FunctionRegistry {
        FunctionRegistry::with_builtins()
    }

    #[test]
    fn empty_registry() {
        let registry = FunctionRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn register_and_get() {
        let mut registry = FunctionRegistry::new();
        let id = registry.register(
            "foo",
            FunctionKind::Scalar,
            vec![LogicalType::Int64],
            false,
            LogicalType::Int64,
        );
        assert_eq!(id.0, 0);

        let stored = registry.get(id).expect("signature was just registered");
        assert_eq!(stored.name.as_str(), "foo");
        assert_eq!(stored.return_type, LogicalType::Int64);
        assert_eq!(stored.kind, FunctionKind::Scalar);
    }

    #[test]
    fn resolve_exact_match() {
        let registry = builtins();
        let resolved = registry
            .resolve("abs", &[LogicalType::Int64])
            .expect("abs(Int64) should resolve");
        assert_eq!(resolved.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_case_insensitive() {
        let registry = builtins();
        let resolved = registry
            .resolve("ABS", &[LogicalType::Int64])
            .expect("lookup should ignore case");
        assert_eq!(resolved.name.as_str(), "abs");
    }

    #[test]
    fn resolve_with_implicit_coercion() {
        let registry = builtins();
        // There is no abs(Int32) overload; the Int64 overload is expected to
        // win, presumably because it is the cheapest implicit cast.
        let resolved = registry
            .resolve("abs", &[LogicalType::Int32])
            .expect("abs(Int32) should resolve through an implicit cast");
        assert_eq!(resolved.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_best_overload() {
        let registry = builtins();
        // A Double argument should select the Double overload, not Int64.
        let resolved = registry
            .resolve("abs", &[LogicalType::Double])
            .expect("abs(Double) should resolve");
        assert_eq!(resolved.return_type, LogicalType::Double);
    }

    #[test]
    fn resolve_unknown_function() {
        let registry = builtins();
        assert!(registry
            .resolve("nonexistent", &[LogicalType::Int64])
            .is_err());
    }

    #[test]
    fn resolve_wrong_arg_count() {
        let registry = builtins();
        assert!(registry.resolve("abs", &[]).is_err());
    }

    #[test]
    fn resolve_aggregate() {
        let registry = builtins();
        let resolved = registry
            .resolve("count", &[LogicalType::Int64])
            .expect("count(Int64) should resolve");
        assert_eq!(resolved.kind, FunctionKind::Aggregate);
        assert_eq!(resolved.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_string_function() {
        let registry = builtins();
        let resolved = registry
            .resolve("upper", &[LogicalType::String])
            .expect("upper(String) should resolve");
        assert_eq!(resolved.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_multi_arg_function() {
        let registry = builtins();
        let args = [LogicalType::String, LogicalType::Int64, LogicalType::Int64];
        let resolved = registry
            .resolve("substring", &args)
            .expect("substring(String, Int64, Int64) should resolve");
        assert_eq!(resolved.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_variadic_function() {
        let registry = builtins();
        // coalesce declares one parameter but is variadic, so three arguments
        // must be accepted.
        let args = [LogicalType::Int64, LogicalType::Int64, LogicalType::Int64];
        let resolved = registry
            .resolve("coalesce", &args)
            .expect("variadic coalesce should accept three arguments");
        assert_eq!(resolved.name.as_str(), "coalesce");
    }

    #[test]
    fn builtins_populated() {
        assert!(builtins().len() > 20);
    }

    #[test]
    fn function_id_sequential() {
        let mut registry = FunctionRegistry::new();
        let first = registry.register("a", FunctionKind::Scalar, vec![], false, LogicalType::Bool);
        let second = registry.register("b", FunctionKind::Scalar, vec![], false, LogicalType::Bool);
        assert_eq!(first.0, 0);
        assert_eq!(second.0, 1);
    }

    #[test]
    fn get_nonexistent_id() {
        let registry = FunctionRegistry::new();
        assert!(registry.get(FunctionId(999)).is_none());
    }
}