Skip to main content

kyu_expression/
function_registry.rs

1//! Function registry — scalar and aggregate function signatures.
2//!
3//! Pre-populated with built-in functions at startup. Resolves function calls
4//! by name + argument types, selecting the best overload via implicit cast cost.
5
6use hashbrown::HashMap;
7use kyu_common::{KyuError, KyuResult};
8use kyu_types::type_utils::implicit_cast_cost;
9use kyu_types::LogicalType;
10use smol_str::SmolStr;
11
12use crate::bound_expr::FunctionId;
13
/// Whether a function is scalar (row-at-a-time) or aggregate.
///
/// `Hash` is derived alongside `Eq` so kinds can key hash-based collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FunctionKind {
    /// Evaluated once per input row (e.g. `abs`, `upper`).
    Scalar,
    /// Computed over a group of rows (e.g. `count`, `sum`, `collect`).
    Aggregate,
}
20
21/// A function signature: expected arg types and return type.
22#[derive(Clone, Debug)]
23pub struct FunctionSignature {
24    pub id: FunctionId,
25    pub name: SmolStr,
26    pub kind: FunctionKind,
27    pub param_types: Vec<LogicalType>,
28    pub variadic: bool,
29    pub return_type: LogicalType,
30}
31
32/// Registry of all known functions.
33///
34/// Populated once at startup; immutable during query processing.
35/// Functions are indexed by `FunctionId` (O(1) lookup) and by name
36/// (case-insensitive, O(1) via HashMap).
37pub struct FunctionRegistry {
38    signatures: Vec<FunctionSignature>,
39    name_index: HashMap<SmolStr, Vec<usize>>,
40}
41
42impl Default for FunctionRegistry {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl FunctionRegistry {
49    pub fn new() -> Self {
50        Self {
51            signatures: Vec::new(),
52            name_index: HashMap::new(),
53        }
54    }
55
56    /// Create a registry pre-populated with built-in functions.
57    pub fn with_builtins() -> Self {
58        let mut reg = Self::new();
59        register_builtins(&mut reg);
60        reg
61    }
62
63    /// Register a function signature. Returns the assigned FunctionId.
64    pub fn register(&mut self, name: &str, kind: FunctionKind, param_types: Vec<LogicalType>, variadic: bool, return_type: LogicalType) -> FunctionId {
65        let id = FunctionId(self.signatures.len() as u32);
66        let lower_name = SmolStr::new(name.to_lowercase());
67        let sig = FunctionSignature {
68            id,
69            name: lower_name.clone(),
70            kind,
71            param_types,
72            variadic,
73            return_type,
74        };
75        let idx = self.signatures.len();
76        self.signatures.push(sig);
77        self.name_index
78            .entry(lower_name)
79            .or_default()
80            .push(idx);
81        id
82    }
83
84    /// Resolve a function call: name + actual arg types → best matching signature.
85    ///
86    /// Considers implicit coercions. Returns the matching signature.
87    pub fn resolve(
88        &self,
89        name: &str,
90        arg_types: &[LogicalType],
91    ) -> KyuResult<&FunctionSignature> {
92        let lower = name.to_lowercase();
93        let overloads = self
94            .name_index
95            .get(lower.as_str())
96            .ok_or_else(|| KyuError::Binder(format!("unknown function '{name}'")))?;
97
98        let mut best: Option<(usize, u32)> = None; // (index, total_cost)
99
100        for &idx in overloads {
101            let sig = &self.signatures[idx];
102            if let Some(cost) = match_cost(sig, arg_types) {
103                match best {
104                    None => best = Some((idx, cost)),
105                    Some((_, best_cost)) if cost < best_cost => {
106                        best = Some((idx, cost));
107                    }
108                    _ => {}
109                }
110            }
111        }
112
113        match best {
114            Some((idx, _)) => Ok(&self.signatures[idx]),
115            None => {
116                let type_names: Vec<_> = arg_types.iter().map(|t| t.type_name().to_string()).collect();
117                Err(KyuError::Binder(format!(
118                    "no matching overload for {}({})",
119                    name,
120                    type_names.join(", "),
121                )))
122            }
123        }
124    }
125
126    /// Look up by FunctionId (O(1)).
127    pub fn get(&self, id: FunctionId) -> Option<&FunctionSignature> {
128        self.signatures.get(id.0 as usize)
129    }
130
131    /// Number of registered functions.
132    pub fn len(&self) -> usize {
133        self.signatures.len()
134    }
135
136    /// Whether the registry is empty.
137    pub fn is_empty(&self) -> bool {
138        self.signatures.is_empty()
139    }
140}
141
142/// Compute the total implicit cast cost for matching `arg_types` against `sig`.
143/// Returns `None` if args don't match.
144fn match_cost(sig: &FunctionSignature, arg_types: &[LogicalType]) -> Option<u32> {
145    if sig.variadic {
146        if arg_types.len() < sig.param_types.len() {
147            return None;
148        }
149    } else if arg_types.len() != sig.param_types.len() {
150        return None;
151    }
152
153    let mut total = 0u32;
154
155    // Match declared params.
156    for (param, arg) in sig.param_types.iter().zip(arg_types.iter()) {
157        if matches!(param, LogicalType::Any) {
158            // Any accepts anything at cost 0.
159            continue;
160        }
161        let cost = implicit_cast_cost(arg, param)?;
162        total += cost;
163    }
164
165    // For variadic, extra args are accepted at cost 0 (type checking is loose).
166    Some(total)
167}
168
169fn register_builtins(reg: &mut FunctionRegistry) {
170    use FunctionKind::{Aggregate, Scalar};
171    use LogicalType::*;
172
173    // Numeric scalar functions — register common overloads.
174    for ty in &[Int64, Double] {
175        reg.register("abs", Scalar, vec![ty.clone()], false, ty.clone());
176    }
177    reg.register("floor", Scalar, vec![Double], false, Double);
178    reg.register("ceil", Scalar, vec![Double], false, Double);
179    reg.register("round", Scalar, vec![Double], false, Double);
180    reg.register("sqrt", Scalar, vec![Double], false, Double);
181    reg.register("log", Scalar, vec![Double], false, Double);
182    reg.register("log2", Scalar, vec![Double], false, Double);
183    reg.register("log10", Scalar, vec![Double], false, Double);
184    reg.register("sin", Scalar, vec![Double], false, Double);
185    reg.register("cos", Scalar, vec![Double], false, Double);
186    reg.register("tan", Scalar, vec![Double], false, Double);
187    reg.register("sign", Scalar, vec![Int64], false, Int64);
188    reg.register("sign", Scalar, vec![Double], false, Int64);
189    reg.register("greatest", Scalar, vec![Any], true, Any);
190    reg.register("least", Scalar, vec![Any], true, Any);
191
192    // String scalar functions.
193    reg.register("lower", Scalar, vec![String], false, String);
194    reg.register("upper", Scalar, vec![String], false, String);
195    reg.register("length", Scalar, vec![String], false, Int64);
196    reg.register("size", Scalar, vec![String], false, Int64);
197    reg.register("trim", Scalar, vec![String], false, String);
198    reg.register("ltrim", Scalar, vec![String], false, String);
199    reg.register("rtrim", Scalar, vec![String], false, String);
200    reg.register("reverse", Scalar, vec![String], false, String);
201    reg.register(
202        "substring",
203        Scalar,
204        vec![String, Int64, Int64],
205        false,
206        String,
207    );
208    reg.register("left", Scalar, vec![String, Int64], false, String);
209    reg.register("right", Scalar, vec![String, Int64], false, String);
210    reg.register(
211        "replace",
212        Scalar,
213        vec![String, String, String],
214        false,
215        String,
216    );
217    reg.register("concat", Scalar, vec![String], true, String);
218    reg.register("lpad", Scalar, vec![String, Int64, String], false, String);
219    reg.register("rpad", Scalar, vec![String, Int64, String], false, String);
220
221    // Conversion functions.
222    reg.register("tostring", Scalar, vec![Any], false, String);
223    reg.register("tostring", Scalar, vec![String], false, String);
224    reg.register("tointeger", Scalar, vec![Any], false, Int64);
225    reg.register("tofloat", Scalar, vec![Any], false, Double);
226    reg.register("toboolean", Scalar, vec![Any], false, Bool);
227
228    // Utility.
229    reg.register("coalesce", Scalar, vec![Any], true, Any);
230    reg.register("typeof", Scalar, vec![Any], false, String);
231    reg.register("hash", Scalar, vec![Any], false, Int64);
232
233    // List functions.
234    reg.register(
235        "range",
236        Scalar,
237        vec![Int64, Int64],
238        false,
239        List(Box::new(Int64)),
240    );
241    reg.register("size", Scalar, vec![List(Box::new(Any))], false, Int64);
242    reg.register("length", Scalar, vec![List(Box::new(Any))], false, Int64);
243
244    // JSON functions.
245    reg.register("json_extract", Scalar, vec![String, String], false, String);
246    reg.register("json_valid", Scalar, vec![String], false, Bool);
247    reg.register("json_type", Scalar, vec![String], false, String);
248    reg.register("json_keys", Scalar, vec![String], false, List(Box::new(String)));
249    reg.register("json_array_length", Scalar, vec![String], false, Int64);
250    reg.register("json_contains", Scalar, vec![String, String], false, Bool);
251    reg.register("json_set", Scalar, vec![String, String, String], false, String);
252
253    // Aggregate functions.
254    reg.register("count", Aggregate, vec![Any], false, Int64);
255    reg.register("sum", Aggregate, vec![Int64], false, Int64);
256    reg.register("sum", Aggregate, vec![Double], false, Double);
257    reg.register("avg", Aggregate, vec![Int64], false, Double);
258    reg.register("avg", Aggregate, vec![Double], false, Double);
259    reg.register("min", Aggregate, vec![Any], false, Any);
260    reg.register("max", Aggregate, vec![Any], false, Any);
261    reg.register("collect", Aggregate, vec![Any], false, List(Box::new(Any)));
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry preloaded with every built-in signature.
    fn builtins() -> FunctionRegistry {
        FunctionRegistry::with_builtins()
    }

    #[test]
    fn empty_registry() {
        let registry = FunctionRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn register_and_get() {
        let mut registry = FunctionRegistry::new();
        let id = registry.register(
            "foo",
            FunctionKind::Scalar,
            vec![LogicalType::Int64],
            false,
            LogicalType::Int64,
        );
        assert_eq!(id.0, 0);

        let sig = registry.get(id).expect("a freshly registered id must be retrievable");
        assert_eq!(sig.kind, FunctionKind::Scalar);
        assert_eq!(sig.return_type, LogicalType::Int64);
        assert_eq!(sig.name.as_str(), "foo");
    }

    #[test]
    fn resolve_exact_match() {
        let registry = builtins();
        let sig = registry.resolve("abs", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_case_insensitive() {
        let registry = builtins();
        let sig = registry.resolve("ABS", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.name.as_str(), "abs");
    }

    #[test]
    fn resolve_with_implicit_coercion() {
        // Int32 has no exact `abs` overload; it must widen to abs(Int64).
        let registry = builtins();
        let sig = registry.resolve("abs", &[LogicalType::Int32]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_best_overload() {
        // The exact abs(Double) overload must win over abs(Int64).
        let registry = builtins();
        let sig = registry.resolve("abs", &[LogicalType::Double]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Double);
    }

    #[test]
    fn resolve_unknown_function() {
        let registry = builtins();
        let outcome = registry.resolve("nonexistent", &[LogicalType::Int64]);
        assert!(matches!(outcome, Err(_)));
    }

    #[test]
    fn resolve_wrong_arg_count() {
        let registry = builtins();
        let outcome = registry.resolve("abs", &[]);
        assert!(matches!(outcome, Err(_)));
    }

    #[test]
    fn resolve_aggregate() {
        let registry = builtins();
        let sig = registry.resolve("count", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Int64);
        assert_eq!(sig.kind, FunctionKind::Aggregate);
    }

    #[test]
    fn resolve_string_function() {
        let registry = builtins();
        let sig = registry.resolve("upper", &[LogicalType::String]).unwrap();
        assert_eq!(sig.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_multi_arg_function() {
        let registry = builtins();
        let args = [LogicalType::String, LogicalType::Int64, LogicalType::Int64];
        let sig = registry.resolve("substring", &args).unwrap();
        assert_eq!(sig.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_variadic_function() {
        // coalesce(Any...) accepts any number of arguments >= 1.
        let registry = builtins();
        let args = [LogicalType::Int64, LogicalType::Int64, LogicalType::Int64];
        let sig = registry.resolve("coalesce", &args).unwrap();
        assert_eq!(sig.name.as_str(), "coalesce");
    }

    #[test]
    fn builtins_populated() {
        assert!(builtins().len() > 20);
    }

    #[test]
    fn function_id_sequential() {
        let mut registry = FunctionRegistry::new();
        let ids: Vec<u32> = ["a", "b"]
            .iter()
            .map(|&name| {
                registry
                    .register(name, FunctionKind::Scalar, vec![], false, LogicalType::Bool)
                    .0
            })
            .collect();
        assert_eq!(ids, vec![0, 1]);
    }

    #[test]
    fn get_nonexistent_id() {
        let registry = FunctionRegistry::new();
        assert!(registry.get(FunctionId(999)).is_none());
    }
}