Skip to main content

kyu_expression/
function_registry.rs

1//! Function registry — scalar and aggregate function signatures.
2//!
3//! Pre-populated with built-in functions at startup. Resolves function calls
4//! by name + argument types, selecting the best overload via implicit cast cost.
5
6use hashbrown::HashMap;
7use kyu_common::{KyuError, KyuResult};
8use kyu_types::LogicalType;
9use kyu_types::type_utils::implicit_cast_cost;
10use smol_str::SmolStr;
11
12use crate::bound_expr::FunctionId;
13
/// Evaluation model of a registered function: row-at-a-time or aggregate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FunctionKind {
    /// Evaluated row-at-a-time (e.g. `abs`, `upper`).
    Scalar,
    /// Combines values across rows (e.g. `count`, `sum`).
    Aggregate,
}
20
21/// A function signature: expected arg types and return type.
22#[derive(Clone, Debug)]
23pub struct FunctionSignature {
24    pub id: FunctionId,
25    pub name: SmolStr,
26    pub kind: FunctionKind,
27    pub param_types: Vec<LogicalType>,
28    pub variadic: bool,
29    pub return_type: LogicalType,
30}
31
32/// Registry of all known functions.
33///
34/// Populated once at startup; immutable during query processing.
35/// Functions are indexed by `FunctionId` (O(1) lookup) and by name
36/// (case-insensitive, O(1) via HashMap).
37pub struct FunctionRegistry {
38    signatures: Vec<FunctionSignature>,
39    name_index: HashMap<SmolStr, Vec<usize>>,
40}
41
42impl Default for FunctionRegistry {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl FunctionRegistry {
49    pub fn new() -> Self {
50        Self {
51            signatures: Vec::new(),
52            name_index: HashMap::new(),
53        }
54    }
55
56    /// Create a registry pre-populated with built-in functions.
57    pub fn with_builtins() -> Self {
58        let mut reg = Self::new();
59        register_builtins(&mut reg);
60        reg
61    }
62
63    /// Register a function signature. Returns the assigned FunctionId.
64    pub fn register(
65        &mut self,
66        name: &str,
67        kind: FunctionKind,
68        param_types: Vec<LogicalType>,
69        variadic: bool,
70        return_type: LogicalType,
71    ) -> FunctionId {
72        let id = FunctionId(self.signatures.len() as u32);
73        let lower_name = SmolStr::new(name.to_lowercase());
74        let sig = FunctionSignature {
75            id,
76            name: lower_name.clone(),
77            kind,
78            param_types,
79            variadic,
80            return_type,
81        };
82        let idx = self.signatures.len();
83        self.signatures.push(sig);
84        self.name_index.entry(lower_name).or_default().push(idx);
85        id
86    }
87
88    /// Resolve a function call: name + actual arg types → best matching signature.
89    ///
90    /// Considers implicit coercions. Returns the matching signature.
91    pub fn resolve(&self, name: &str, arg_types: &[LogicalType]) -> KyuResult<&FunctionSignature> {
92        let lower = name.to_lowercase();
93        let overloads = self
94            .name_index
95            .get(lower.as_str())
96            .ok_or_else(|| KyuError::Binder(format!("unknown function '{name}'")))?;
97
98        let mut best: Option<(usize, u32)> = None; // (index, total_cost)
99
100        for &idx in overloads {
101            let sig = &self.signatures[idx];
102            if let Some(cost) = match_cost(sig, arg_types) {
103                match best {
104                    None => best = Some((idx, cost)),
105                    Some((_, best_cost)) if cost < best_cost => {
106                        best = Some((idx, cost));
107                    }
108                    _ => {}
109                }
110            }
111        }
112
113        match best {
114            Some((idx, _)) => Ok(&self.signatures[idx]),
115            None => {
116                let type_names: Vec<_> = arg_types
117                    .iter()
118                    .map(|t| t.type_name().to_string())
119                    .collect();
120                Err(KyuError::Binder(format!(
121                    "no matching overload for {}({})",
122                    name,
123                    type_names.join(", "),
124                )))
125            }
126        }
127    }
128
129    /// Look up by FunctionId (O(1)).
130    pub fn get(&self, id: FunctionId) -> Option<&FunctionSignature> {
131        self.signatures.get(id.0 as usize)
132    }
133
134    /// Number of registered functions.
135    pub fn len(&self) -> usize {
136        self.signatures.len()
137    }
138
139    /// Whether the registry is empty.
140    pub fn is_empty(&self) -> bool {
141        self.signatures.is_empty()
142    }
143}
144
145/// Compute the total implicit cast cost for matching `arg_types` against `sig`.
146/// Returns `None` if args don't match.
147fn match_cost(sig: &FunctionSignature, arg_types: &[LogicalType]) -> Option<u32> {
148    if sig.variadic {
149        if arg_types.len() < sig.param_types.len() {
150            return None;
151        }
152    } else if arg_types.len() != sig.param_types.len() {
153        return None;
154    }
155
156    let mut total = 0u32;
157
158    // Match declared params.
159    for (param, arg) in sig.param_types.iter().zip(arg_types.iter()) {
160        if matches!(param, LogicalType::Any) {
161            // Any accepts anything at cost 0.
162            continue;
163        }
164        let cost = implicit_cast_cost(arg, param)?;
165        total += cost;
166    }
167
168    // For variadic, extra args are accepted at cost 0 (type checking is loose).
169    Some(total)
170}
171
172fn register_builtins(reg: &mut FunctionRegistry) {
173    use FunctionKind::{Aggregate, Scalar};
174    use LogicalType::*;
175
176    // Numeric scalar functions — register common overloads.
177    for ty in &[Int64, Double] {
178        reg.register("abs", Scalar, vec![ty.clone()], false, ty.clone());
179    }
180    reg.register("floor", Scalar, vec![Double], false, Double);
181    reg.register("ceil", Scalar, vec![Double], false, Double);
182    reg.register("round", Scalar, vec![Double], false, Double);
183    reg.register("sqrt", Scalar, vec![Double], false, Double);
184    reg.register("log", Scalar, vec![Double], false, Double);
185    reg.register("log2", Scalar, vec![Double], false, Double);
186    reg.register("log10", Scalar, vec![Double], false, Double);
187    reg.register("sin", Scalar, vec![Double], false, Double);
188    reg.register("cos", Scalar, vec![Double], false, Double);
189    reg.register("tan", Scalar, vec![Double], false, Double);
190    reg.register("sign", Scalar, vec![Int64], false, Int64);
191    reg.register("sign", Scalar, vec![Double], false, Int64);
192    reg.register("greatest", Scalar, vec![Any], true, Any);
193    reg.register("least", Scalar, vec![Any], true, Any);
194
195    // String scalar functions.
196    reg.register("lower", Scalar, vec![String], false, String);
197    reg.register("upper", Scalar, vec![String], false, String);
198    reg.register("length", Scalar, vec![String], false, Int64);
199    reg.register("size", Scalar, vec![String], false, Int64);
200    reg.register("trim", Scalar, vec![String], false, String);
201    reg.register("ltrim", Scalar, vec![String], false, String);
202    reg.register("rtrim", Scalar, vec![String], false, String);
203    reg.register("reverse", Scalar, vec![String], false, String);
204    reg.register(
205        "substring",
206        Scalar,
207        vec![String, Int64, Int64],
208        false,
209        String,
210    );
211    reg.register("left", Scalar, vec![String, Int64], false, String);
212    reg.register("right", Scalar, vec![String, Int64], false, String);
213    reg.register(
214        "replace",
215        Scalar,
216        vec![String, String, String],
217        false,
218        String,
219    );
220    reg.register("concat", Scalar, vec![String], true, String);
221    reg.register("lpad", Scalar, vec![String, Int64, String], false, String);
222    reg.register("rpad", Scalar, vec![String, Int64, String], false, String);
223
224    // Conversion functions.
225    reg.register("tostring", Scalar, vec![Any], false, String);
226    reg.register("tostring", Scalar, vec![String], false, String);
227    reg.register("tointeger", Scalar, vec![Any], false, Int64);
228    reg.register("tofloat", Scalar, vec![Any], false, Double);
229    reg.register("toboolean", Scalar, vec![Any], false, Bool);
230
231    // Utility.
232    reg.register("coalesce", Scalar, vec![Any], true, Any);
233    reg.register("typeof", Scalar, vec![Any], false, String);
234    reg.register("hash", Scalar, vec![Any], false, Int64);
235
236    // List functions.
237    reg.register(
238        "range",
239        Scalar,
240        vec![Int64, Int64],
241        false,
242        List(Box::new(Int64)),
243    );
244    reg.register("size", Scalar, vec![List(Box::new(Any))], false, Int64);
245    reg.register("length", Scalar, vec![List(Box::new(Any))], false, Int64);
246
247    // JSON functions.
248    reg.register("json_extract", Scalar, vec![String, String], false, String);
249    reg.register("json_valid", Scalar, vec![String], false, Bool);
250    reg.register("json_type", Scalar, vec![String], false, String);
251    reg.register(
252        "json_keys",
253        Scalar,
254        vec![String],
255        false,
256        List(Box::new(String)),
257    );
258    reg.register("json_array_length", Scalar, vec![String], false, Int64);
259    reg.register("json_contains", Scalar, vec![String, String], false, Bool);
260    reg.register(
261        "json_set",
262        Scalar,
263        vec![String, String, String],
264        false,
265        String,
266    );
267
268    // Aggregate functions.
269    reg.register("count", Aggregate, vec![Any], false, Int64);
270    reg.register("sum", Aggregate, vec![Int64], false, Int64);
271    reg.register("sum", Aggregate, vec![Double], false, Double);
272    reg.register("avg", Aggregate, vec![Int64], false, Double);
273    reg.register("avg", Aggregate, vec![Double], false, Double);
274    reg.register("min", Aggregate, vec![Any], false, Any);
275    reg.register("max", Aggregate, vec![Any], false, Any);
276    reg.register("collect", Aggregate, vec![Any], false, List(Box::new(Any)));
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_registry() {
        let reg = FunctionRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn register_and_get() {
        let mut reg = FunctionRegistry::new();
        let id = reg.register(
            "foo",
            FunctionKind::Scalar,
            vec![LogicalType::Int64],
            false,
            LogicalType::Int64,
        );
        assert_eq!(id.0, 0);

        let sig = reg.get(id).unwrap();
        assert_eq!(sig.name.as_str(), "foo");
        assert_eq!(sig.return_type, LogicalType::Int64);
        assert_eq!(sig.kind, FunctionKind::Scalar);
    }

    #[test]
    fn register_stores_lowercase_name() {
        let mut reg = FunctionRegistry::new();
        let id = reg.register(
            "FooBar",
            FunctionKind::Scalar,
            vec![LogicalType::Int64],
            false,
            LogicalType::Int64,
        );
        assert_eq!(reg.get(id).unwrap().name.as_str(), "foobar");
        let sig = reg.resolve("FOOBAR", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.id.0, id.0);
    }

    #[test]
    fn resolve_exact_match() {
        let reg = FunctionRegistry::with_builtins();
        let sig = reg.resolve("abs", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_case_insensitive() {
        let reg = FunctionRegistry::with_builtins();
        let sig = reg.resolve("ABS", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.name.as_str(), "abs");
    }

    #[test]
    fn resolve_with_implicit_coercion() {
        let reg = FunctionRegistry::with_builtins();
        // abs(Int32) should match abs(Int64) via implicit cast.
        let sig = reg.resolve("abs", &[LogicalType::Int32]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_best_overload() {
        let reg = FunctionRegistry::with_builtins();
        // abs(Double) should prefer abs(Double) over abs(Int64).
        let sig = reg.resolve("abs", &[LogicalType::Double]).unwrap();
        assert_eq!(sig.return_type, LogicalType::Double);
    }

    #[test]
    fn resolve_unknown_function() {
        let reg = FunctionRegistry::with_builtins();
        let result = reg.resolve("nonexistent", &[LogicalType::Int64]);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_unknown_function_reports_name() {
        let reg = FunctionRegistry::with_builtins();
        let result = reg.resolve("nonexistent", &[LogicalType::Int64]);
        assert!(matches!(&result, Err(KyuError::Binder(msg)) if msg.contains("nonexistent")));
    }

    #[test]
    fn resolve_wrong_arg_count() {
        let reg = FunctionRegistry::with_builtins();
        let result = reg.resolve("abs", &[]);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_rejects_extra_args_for_fixed_arity() {
        let reg = FunctionRegistry::with_builtins();
        let result = reg.resolve("abs", &[LogicalType::Int64, LogicalType::Int64]);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_aggregate() {
        let reg = FunctionRegistry::with_builtins();
        let sig = reg.resolve("count", &[LogicalType::Int64]).unwrap();
        assert_eq!(sig.kind, FunctionKind::Aggregate);
        assert_eq!(sig.return_type, LogicalType::Int64);
    }

    #[test]
    fn resolve_string_function() {
        let reg = FunctionRegistry::with_builtins();
        let sig = reg.resolve("upper", &[LogicalType::String]).unwrap();
        assert_eq!(sig.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_multi_arg_function() {
        let reg = FunctionRegistry::with_builtins();
        let sig = reg
            .resolve(
                "substring",
                &[LogicalType::String, LogicalType::Int64, LogicalType::Int64],
            )
            .unwrap();
        assert_eq!(sig.return_type, LogicalType::String);
    }

    #[test]
    fn resolve_variadic_function() {
        let reg = FunctionRegistry::with_builtins();
        // coalesce(Any...) — accepts any number of args >= 1.
        let sig = reg
            .resolve(
                "coalesce",
                &[LogicalType::Int64, LogicalType::Int64, LogicalType::Int64],
            )
            .unwrap();
        assert_eq!(sig.name.as_str(), "coalesce");
    }

    #[test]
    fn resolve_variadic_accepts_exactly_declared_params() {
        let reg = FunctionRegistry::with_builtins();
        let sig = reg.resolve("coalesce", &[LogicalType::Int64]).unwrap();
        assert!(sig.variadic);
    }

    #[test]
    fn resolve_variadic_requires_declared_params() {
        let reg = FunctionRegistry::with_builtins();
        // coalesce declares one parameter, so at least one argument is required.
        assert!(reg.resolve("coalesce", &[]).is_err());
    }

    #[test]
    fn builtins_populated() {
        let reg = FunctionRegistry::with_builtins();
        assert!(reg.len() > 20);
    }

    #[test]
    fn builtin_ids_match_positions() {
        let reg = FunctionRegistry::with_builtins();
        for i in 0..reg.len() {
            let sig = reg.get(FunctionId(i as u32)).unwrap();
            assert_eq!(sig.id.0 as usize, i);
        }
    }

    #[test]
    fn function_id_sequential() {
        let mut reg = FunctionRegistry::new();
        let id0 = reg.register("a", FunctionKind::Scalar, vec![], false, LogicalType::Bool);
        let id1 = reg.register("b", FunctionKind::Scalar, vec![], false, LogicalType::Bool);
        assert_eq!(id0.0, 0);
        assert_eq!(id1.0, 1);
    }

    #[test]
    fn get_nonexistent_id() {
        let reg = FunctionRegistry::new();
        assert!(reg.get(FunctionId(999)).is_none());
    }
}