1use hashbrown::HashMap;
7use kyu_common::{KyuError, KyuResult};
8use kyu_types::LogicalType;
9use kyu_types::type_utils::implicit_cast_cost;
10use smol_str::SmolStr;
11
12use crate::bound_expr::FunctionId;
13
/// Category of a registered function.
///
/// Every [`FunctionSignature`] carries one of these values so that callers
/// can tell the two categories apart once a call has been resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FunctionKind {
    /// The function was registered as a scalar function.
    Scalar,
    /// The function was registered as an aggregate function.
    Aggregate,
}
20
21#[derive(Clone, Debug)]
23pub struct FunctionSignature {
24 pub id: FunctionId,
25 pub name: SmolStr,
26 pub kind: FunctionKind,
27 pub param_types: Vec<LogicalType>,
28 pub variadic: bool,
29 pub return_type: LogicalType,
30}
31
32pub struct FunctionRegistry {
38 signatures: Vec<FunctionSignature>,
39 name_index: HashMap<SmolStr, Vec<usize>>,
40}
41
42impl Default for FunctionRegistry {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl FunctionRegistry {
49 pub fn new() -> Self {
50 Self {
51 signatures: Vec::new(),
52 name_index: HashMap::new(),
53 }
54 }
55
56 pub fn with_builtins() -> Self {
58 let mut reg = Self::new();
59 register_builtins(&mut reg);
60 reg
61 }
62
63 pub fn register(
65 &mut self,
66 name: &str,
67 kind: FunctionKind,
68 param_types: Vec<LogicalType>,
69 variadic: bool,
70 return_type: LogicalType,
71 ) -> FunctionId {
72 let id = FunctionId(self.signatures.len() as u32);
73 let lower_name = SmolStr::new(name.to_lowercase());
74 let sig = FunctionSignature {
75 id,
76 name: lower_name.clone(),
77 kind,
78 param_types,
79 variadic,
80 return_type,
81 };
82 let idx = self.signatures.len();
83 self.signatures.push(sig);
84 self.name_index.entry(lower_name).or_default().push(idx);
85 id
86 }
87
88 pub fn resolve(&self, name: &str, arg_types: &[LogicalType]) -> KyuResult<&FunctionSignature> {
92 let lower = name.to_lowercase();
93 let overloads = self
94 .name_index
95 .get(lower.as_str())
96 .ok_or_else(|| KyuError::Binder(format!("unknown function '{name}'")))?;
97
98 let mut best: Option<(usize, u32)> = None; for &idx in overloads {
101 let sig = &self.signatures[idx];
102 if let Some(cost) = match_cost(sig, arg_types) {
103 match best {
104 None => best = Some((idx, cost)),
105 Some((_, best_cost)) if cost < best_cost => {
106 best = Some((idx, cost));
107 }
108 _ => {}
109 }
110 }
111 }
112
113 match best {
114 Some((idx, _)) => Ok(&self.signatures[idx]),
115 None => {
116 let type_names: Vec<_> = arg_types
117 .iter()
118 .map(|t| t.type_name().to_string())
119 .collect();
120 Err(KyuError::Binder(format!(
121 "no matching overload for {}({})",
122 name,
123 type_names.join(", "),
124 )))
125 }
126 }
127 }
128
129 pub fn get(&self, id: FunctionId) -> Option<&FunctionSignature> {
131 self.signatures.get(id.0 as usize)
132 }
133
134 pub fn len(&self) -> usize {
136 self.signatures.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.signatures.is_empty()
142 }
143}
144
145fn match_cost(sig: &FunctionSignature, arg_types: &[LogicalType]) -> Option<u32> {
148 if sig.variadic {
149 if arg_types.len() < sig.param_types.len() {
150 return None;
151 }
152 } else if arg_types.len() != sig.param_types.len() {
153 return None;
154 }
155
156 let mut total = 0u32;
157
158 for (param, arg) in sig.param_types.iter().zip(arg_types.iter()) {
160 if matches!(param, LogicalType::Any) {
161 continue;
163 }
164 let cost = implicit_cast_cost(arg, param)?;
165 total += cost;
166 }
167
168 Some(total)
170}
171
172fn register_builtins(reg: &mut FunctionRegistry) {
173 use FunctionKind::{Aggregate, Scalar};
174 use LogicalType::*;
175
176 for ty in &[Int64, Double] {
178 reg.register("abs", Scalar, vec![ty.clone()], false, ty.clone());
179 }
180 reg.register("floor", Scalar, vec![Double], false, Double);
181 reg.register("ceil", Scalar, vec![Double], false, Double);
182 reg.register("round", Scalar, vec![Double], false, Double);
183 reg.register("sqrt", Scalar, vec![Double], false, Double);
184 reg.register("log", Scalar, vec![Double], false, Double);
185 reg.register("log2", Scalar, vec![Double], false, Double);
186 reg.register("log10", Scalar, vec![Double], false, Double);
187 reg.register("sin", Scalar, vec![Double], false, Double);
188 reg.register("cos", Scalar, vec![Double], false, Double);
189 reg.register("tan", Scalar, vec![Double], false, Double);
190 reg.register("sign", Scalar, vec![Int64], false, Int64);
191 reg.register("sign", Scalar, vec![Double], false, Int64);
192 reg.register("greatest", Scalar, vec![Any], true, Any);
193 reg.register("least", Scalar, vec![Any], true, Any);
194
195 reg.register("lower", Scalar, vec![String], false, String);
197 reg.register("upper", Scalar, vec![String], false, String);
198 reg.register("length", Scalar, vec![String], false, Int64);
199 reg.register("size", Scalar, vec![String], false, Int64);
200 reg.register("trim", Scalar, vec![String], false, String);
201 reg.register("ltrim", Scalar, vec![String], false, String);
202 reg.register("rtrim", Scalar, vec![String], false, String);
203 reg.register("reverse", Scalar, vec![String], false, String);
204 reg.register(
205 "substring",
206 Scalar,
207 vec![String, Int64, Int64],
208 false,
209 String,
210 );
211 reg.register("left", Scalar, vec![String, Int64], false, String);
212 reg.register("right", Scalar, vec![String, Int64], false, String);
213 reg.register(
214 "replace",
215 Scalar,
216 vec![String, String, String],
217 false,
218 String,
219 );
220 reg.register("concat", Scalar, vec![String], true, String);
221 reg.register("lpad", Scalar, vec![String, Int64, String], false, String);
222 reg.register("rpad", Scalar, vec![String, Int64, String], false, String);
223
224 reg.register("tostring", Scalar, vec![Any], false, String);
226 reg.register("tostring", Scalar, vec![String], false, String);
227 reg.register("tointeger", Scalar, vec![Any], false, Int64);
228 reg.register("tofloat", Scalar, vec![Any], false, Double);
229 reg.register("toboolean", Scalar, vec![Any], false, Bool);
230
231 reg.register("coalesce", Scalar, vec![Any], true, Any);
233 reg.register("typeof", Scalar, vec![Any], false, String);
234 reg.register("hash", Scalar, vec![Any], false, Int64);
235
236 reg.register(
238 "range",
239 Scalar,
240 vec![Int64, Int64],
241 false,
242 List(Box::new(Int64)),
243 );
244 reg.register("size", Scalar, vec![List(Box::new(Any))], false, Int64);
245 reg.register("length", Scalar, vec![List(Box::new(Any))], false, Int64);
246
247 reg.register("json_extract", Scalar, vec![String, String], false, String);
249 reg.register("json_valid", Scalar, vec![String], false, Bool);
250 reg.register("json_type", Scalar, vec![String], false, String);
251 reg.register(
252 "json_keys",
253 Scalar,
254 vec![String],
255 false,
256 List(Box::new(String)),
257 );
258 reg.register("json_array_length", Scalar, vec![String], false, Int64);
259 reg.register("json_contains", Scalar, vec![String, String], false, Bool);
260 reg.register(
261 "json_set",
262 Scalar,
263 vec![String, String, String],
264 false,
265 String,
266 );
267
268 reg.register("count", Aggregate, vec![Any], false, Int64);
270 reg.register("sum", Aggregate, vec![Int64], false, Int64);
271 reg.register("sum", Aggregate, vec![Double], false, Double);
272 reg.register("avg", Aggregate, vec![Int64], false, Double);
273 reg.register("avg", Aggregate, vec![Double], false, Double);
274 reg.register("min", Aggregate, vec![Any], false, Any);
275 reg.register("max", Aggregate, vec![Any], false, Any);
276 reg.register("collect", Aggregate, vec![Any], false, List(Box::new(Any)));
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a registry loaded with the built-in functions.
    fn builtins() -> FunctionRegistry {
        FunctionRegistry::with_builtins()
    }

    #[test]
    fn empty_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn register_and_get() {
        let mut registry = FunctionRegistry::new();
        let params = vec![LogicalType::Int64];
        let id = registry.register(
            "foo",
            FunctionKind::Scalar,
            params,
            false,
            LogicalType::Int64,
        );
        assert_eq!(id.0, 0);

        let stored = registry.get(id).unwrap();
        assert_eq!(stored.name.as_str(), "foo");
        assert_eq!(stored.return_type, LogicalType::Int64);
        assert_eq!(stored.kind, FunctionKind::Scalar);
    }

    #[test]
    fn resolve_exact_match() {
        let registry = builtins();
        let chosen = registry.resolve("abs", &[LogicalType::Int64]).unwrap();
        assert_eq!(chosen.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_case_insensitive() {
        let registry = builtins();
        let chosen = registry.resolve("ABS", &[LogicalType::Int64]).unwrap();
        assert_eq!(chosen.name.as_str(), "abs");
    }

    #[test]
    fn resolve_with_implicit_coercion() {
        let registry = builtins();
        // An Int32 argument is expected to land on the Int64 overload.
        let chosen = registry.resolve("abs", &[LogicalType::Int32]).unwrap();
        assert_eq!(chosen.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_best_overload() {
        let registry = builtins();
        // A Double argument is expected to select the Double overload.
        let chosen = registry.resolve("abs", &[LogicalType::Double]).unwrap();
        assert_eq!(chosen.return_type, LogicalType::Double);
    }

    #[test]
    fn resolve_unknown_function() {
        let registry = builtins();
        let outcome = registry.resolve("nonexistent", &[LogicalType::Int64]);
        assert!(outcome.is_err());
    }

    #[test]
    fn resolve_wrong_arg_count() {
        let registry = builtins();
        let outcome = registry.resolve("abs", &[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn resolve_aggregate() {
        let registry = builtins();
        let chosen = registry.resolve("count", &[LogicalType::Int64]).unwrap();
        assert_eq!(chosen.kind, FunctionKind::Aggregate);
        assert_eq!(chosen.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_string_function() {
        let registry = builtins();
        let chosen = registry.resolve("upper", &[LogicalType::String]).unwrap();
        assert_eq!(chosen.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_multi_arg_function() {
        let registry = builtins();
        let args = [LogicalType::String, LogicalType::Int64, LogicalType::Int64];
        let chosen = registry.resolve("substring", &args).unwrap();
        assert_eq!(chosen.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_variadic_function() {
        let registry = builtins();
        // Three arguments against a single declared parameter.
        let args = [LogicalType::Int64, LogicalType::Int64, LogicalType::Int64];
        let chosen = registry.resolve("coalesce", &args).unwrap();
        assert_eq!(chosen.name.as_str(), "coalesce");
    }

    #[test]
    fn builtins_populated() {
        let registry = builtins();
        assert!(registry.len() > 20);
    }

    #[test]
    fn function_id_sequential() {
        let mut registry = FunctionRegistry::new();
        let first = registry.register("a", FunctionKind::Scalar, vec![], false, LogicalType::Bool);
        let second = registry.register("b", FunctionKind::Scalar, vec![], false, LogicalType::Bool);
        assert_eq!(first.0, 0);
        assert_eq!(second.0, 1);
    }

    #[test]
    fn get_nonexistent_id() {
        let registry = FunctionRegistry::new();
        assert!(registry.get(FunctionId(999)).is_none());
    }
}