Skip to main content

cel_core/ext/
math_ext.rs

1//! Math extension library for CEL.
2//!
3//! This module provides additional math functions beyond the CEL standard library,
4//! matching the cel-go math extension.
5//!
6//! # Functions
7//!
//! - `math.greatest(...)` - Returns the greatest of 1 to 6 numeric arguments, or of a list
//! - `math.least(...)` - Returns the least of 1 to 6 numeric arguments, or of a list
10//! - `math.ceil(double)` - Ceiling function
11//! - `math.floor(double)` - Floor function
12//! - `math.round(double)` - Round to nearest integer
13//! - `math.trunc(double)` - Truncate toward zero
14//! - `math.abs(number)` - Absolute value
15//! - `math.sign(number)` - Sign of number (-1, 0, or 1)
16//! - `math.isNaN(double)` - Check if NaN
17//! - `math.isInf(double)` - Check if infinite
18//! - `math.isFinite(double)` - Check if finite
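//! - `math.bitAnd(int, int)` / `math.bitAnd(uint, uint)` - Bitwise AND
//! - `math.bitOr(int, int)` / `math.bitOr(uint, uint)` - Bitwise OR
//! - `math.bitXor(int, int)` / `math.bitXor(uint, uint)` - Bitwise XOR
//! - `math.bitNot(int)` / `math.bitNot(uint)` - Bitwise NOT
//! - `math.bitShiftLeft(int, int)` / `math.bitShiftLeft(uint, int)` - Left shift
//! - `math.bitShiftRight(int, int)` / `math.bitShiftRight(uint, int)` - Right shift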
25
26use crate::types::{CelType, FunctionDecl, OverloadDecl};
27
28/// Returns the math extension library function declarations.
29pub fn math_extension() -> Vec<FunctionDecl> {
30    let mut funcs = Vec::new();
31
32    // math.greatest and math.least with multiple arities
33    funcs.push(build_minmax_function("math.greatest"));
34    funcs.push(build_minmax_function("math.least"));
35
36    // Simple double functions
37    for name in ["math.ceil", "math.floor", "math.round", "math.trunc"] {
38        funcs.push(FunctionDecl::new(name).with_overload(OverloadDecl::function(
39            &format!("{}_double", name.replace('.', "_")),
40            vec![CelType::Double],
41            CelType::Double,
42        )));
43    }
44
45    // math.abs
46    funcs.push(
47        FunctionDecl::new("math.abs")
48            .with_overload(OverloadDecl::function(
49                "math_abs_int",
50                vec![CelType::Int],
51                CelType::Int,
52            ))
53            .with_overload(OverloadDecl::function(
54                "math_abs_uint",
55                vec![CelType::UInt],
56                CelType::UInt,
57            ))
58            .with_overload(OverloadDecl::function(
59                "math_abs_double",
60                vec![CelType::Double],
61                CelType::Double,
62            )),
63    );
64
65    // math.sign
66    funcs.push(
67        FunctionDecl::new("math.sign")
68            .with_overload(OverloadDecl::function(
69                "math_sign_int",
70                vec![CelType::Int],
71                CelType::Int,
72            ))
73            .with_overload(OverloadDecl::function(
74                "math_sign_uint",
75                vec![CelType::UInt],
76                CelType::UInt,
77            ))
78            .with_overload(OverloadDecl::function(
79                "math_sign_double",
80                vec![CelType::Double],
81                CelType::Double,
82            )),
83    );
84
85    // math.isNaN, isInf, isFinite
86    funcs.push(FunctionDecl::new("math.isNaN").with_overload(OverloadDecl::function(
87        "math_isnan_double",
88        vec![CelType::Double],
89        CelType::Bool,
90    )));
91    funcs.push(FunctionDecl::new("math.isInf").with_overload(OverloadDecl::function(
92        "math_isinf_double",
93        vec![CelType::Double],
94        CelType::Bool,
95    )));
96    funcs.push(FunctionDecl::new("math.isFinite").with_overload(OverloadDecl::function(
97        "math_isfinite_double",
98        vec![CelType::Double],
99        CelType::Bool,
100    )));
101
102    // Bit operations
103    add_bit_operations(&mut funcs);
104
105    funcs
106}
107
108fn build_minmax_function(name: &str) -> FunctionDecl {
109    let base = name.replace('.', "_");
110    let mut decl = FunctionDecl::new(name);
111
112    // Unary (identity)
113    for (suffix, cel_type) in [
114        ("int", CelType::Int),
115        ("uint", CelType::UInt),
116        ("double", CelType::Double),
117    ] {
118        decl = decl.with_overload(OverloadDecl::function(
119            &format!("{}_{}", base, suffix),
120            vec![cel_type.clone()],
121            cel_type,
122        ));
123    }
124
125    // Binary same-type
126    for (suffix, cel_type) in [
127        ("int", CelType::Int),
128        ("uint", CelType::UInt),
129        ("double", CelType::Double),
130    ] {
131        decl = decl.with_overload(OverloadDecl::function(
132            &format!("{}_{}_{}", base, suffix, suffix),
133            vec![cel_type.clone(), cel_type.clone()],
134            cel_type,
135        ));
136    }
137
138    // Binary mixed-type -> Dyn
139    let types = [
140        ("int", CelType::Int),
141        ("uint", CelType::UInt),
142        ("double", CelType::Double),
143    ];
144    for (name1, type1) in &types {
145        for (name2, type2) in &types {
146            if name1 != name2 {
147                decl = decl.with_overload(OverloadDecl::function(
148                    &format!("{}_{}_{}", base, name1, name2),
149                    vec![type1.clone(), type2.clone()],
150                    CelType::Dyn,
151                ));
152            }
153        }
154    }
155
156    // Ternary and higher arities (3-6 args) for common use cases
157    for arity in 3..=6 {
158        // All same type
159        for (suffix, cel_type) in [
160            ("int", CelType::Int),
161            ("uint", CelType::UInt),
162            ("double", CelType::Double),
163        ] {
164            decl = decl.with_overload(OverloadDecl::function(
165                &format!("{}_{}{}", base, suffix, arity),
166                vec![cel_type.clone(); arity],
167                cel_type,
168            ));
169        }
170        // Mixed -> Dyn (just one overload for mixed types)
171        decl = decl.with_overload(OverloadDecl::function(
172            &format!("{}_dyn{}", base, arity),
173            vec![CelType::Dyn; arity],
174            CelType::Dyn,
175        ));
176    }
177
178    // List overloads
179    for (suffix, cel_type) in [
180        ("int", CelType::Int),
181        ("uint", CelType::UInt),
182        ("double", CelType::Double),
183    ] {
184        decl = decl.with_overload(OverloadDecl::function(
185            &format!("{}_list_{}", base, suffix),
186            vec![CelType::list(cel_type.clone())],
187            cel_type,
188        ));
189    }
190    decl = decl.with_overload(OverloadDecl::function(
191        &format!("{}_list_dyn", base),
192        vec![CelType::list(CelType::Dyn)],
193        CelType::Dyn,
194    ));
195
196    decl
197}
198
199fn add_bit_operations(funcs: &mut Vec<FunctionDecl>) {
200    // math.bitAnd, bitOr, bitXor (binary)
201    for op in ["bitAnd", "bitOr", "bitXor"] {
202        funcs.push(
203            FunctionDecl::new(&format!("math.{}", op))
204                .with_overload(OverloadDecl::function(
205                    &format!("math_{}_int_int", op.to_lowercase()),
206                    vec![CelType::Int, CelType::Int],
207                    CelType::Int,
208                ))
209                .with_overload(OverloadDecl::function(
210                    &format!("math_{}_uint_uint", op.to_lowercase()),
211                    vec![CelType::UInt, CelType::UInt],
212                    CelType::UInt,
213                )),
214        );
215    }
216
217    // math.bitNot (unary)
218    funcs.push(
219        FunctionDecl::new("math.bitNot")
220            .with_overload(OverloadDecl::function(
221                "math_bitnot_int",
222                vec![CelType::Int],
223                CelType::Int,
224            ))
225            .with_overload(OverloadDecl::function(
226                "math_bitnot_uint",
227                vec![CelType::UInt],
228                CelType::UInt,
229            )),
230    );
231
232    // math.bitShiftLeft, bitShiftRight
233    for op in ["bitShiftLeft", "bitShiftRight"] {
234        funcs.push(
235            FunctionDecl::new(&format!("math.{}", op))
236                .with_overload(OverloadDecl::function(
237                    &format!("math_{}_int_int", op.to_lowercase()),
238                    vec![CelType::Int, CelType::Int],
239                    CelType::Int,
240                ))
241                .with_overload(OverloadDecl::function(
242                    &format!("math_{}_uint_int", op.to_lowercase()),
243                    vec![CelType::UInt, CelType::Int],
244                    CelType::UInt,
245                )),
246        );
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every function listed in the module documentation.
    const EXPECTED_FUNCTIONS: [&str; 17] = [
        "math.greatest",
        "math.least",
        "math.ceil",
        "math.floor",
        "math.round",
        "math.trunc",
        "math.abs",
        "math.sign",
        "math.isNaN",
        "math.isInf",
        "math.isFinite",
        "math.bitAnd",
        "math.bitOr",
        "math.bitXor",
        "math.bitNot",
        "math.bitShiftLeft",
        "math.bitShiftRight",
    ];

    /// Looks up a declaration by name, panicking with the name if it is missing.
    fn find<'a>(funcs: &'a [FunctionDecl], name: &str) -> &'a FunctionDecl {
        funcs
            .iter()
            .find(|f| f.name == name)
            .unwrap_or_else(|| panic!("{} is not declared", name))
    }

    /// True if `func` has an overload with this id taking `arity` parameters.
    fn has_overload(func: &FunctionDecl, id: &str, arity: usize) -> bool {
        func.overloads
            .iter()
            .any(|o| o.id == id && o.params.len() == arity)
    }

    #[test]
    fn test_math_extension_has_functions() {
        let funcs = math_extension();
        // Exactly the documented set: nothing missing, extra, or duplicated.
        // (A `>=` count check would accept a duplicate standing in for a
        // missing function.)
        assert_eq!(funcs.len(), EXPECTED_FUNCTIONS.len());
        for expected in EXPECTED_FUNCTIONS {
            let count = funcs.iter().filter(|f| f.name == expected).count();
            assert_eq!(count, 1, "{} should be declared exactly once", expected);
        }
    }

    #[test]
    fn test_overload_ids_are_unique() {
        // Overload ids are assembled with `format!`; make sure no two
        // overloads anywhere in the extension collide.
        let mut seen = HashSet::new();
        for func in &math_extension() {
            for overload in &func.overloads {
                let id = overload.id.to_string();
                assert!(seen.insert(id), "duplicate overload id {}", overload.id);
            }
        }
    }

    #[test]
    fn test_math_greatest_overloads() {
        let funcs = math_extension();
        let greatest = find(&funcs, "math.greatest");

        // Unary overloads.
        assert!(has_overload(greatest, "math_greatest_int", 1));
        assert!(has_overload(greatest, "math_greatest_uint", 1));
        assert!(has_overload(greatest, "math_greatest_double", 1));

        // Binary same-type and mixed-type overloads.
        assert!(has_overload(greatest, "math_greatest_int_int", 2));
        assert!(has_overload(greatest, "math_greatest_int_double", 2));

        // Fixed-arity overloads, up to six arguments.
        assert!(has_overload(greatest, "math_greatest_int3", 3));
        assert!(has_overload(greatest, "math_greatest_dyn6", 6));

        // List overloads.
        assert!(has_overload(greatest, "math_greatest_list_int", 1));
        assert!(has_overload(greatest, "math_greatest_list_dyn", 1));
    }

    #[test]
    fn test_math_least_mirrors_greatest() {
        let funcs = math_extension();
        let greatest = find(&funcs, "math.greatest");
        let least = find(&funcs, "math.least");

        assert_eq!(least.overloads.len(), greatest.overloads.len());
        for overload in &greatest.overloads {
            let id = overload
                .id
                .to_string()
                .replacen("math_greatest", "math_least", 1);
            assert!(
                has_overload(least, &id, overload.params.len()),
                "math.least is missing {}",
                id
            );
        }
    }

    #[test]
    fn test_math_abs_overloads() {
        let funcs = math_extension();
        assert_eq!(find(&funcs, "math.abs").overloads.len(), 3); // int, uint, double
        assert_eq!(find(&funcs, "math.sign").overloads.len(), 3); // int, uint, double
    }

    #[test]
    fn test_bit_operations() {
        let funcs = math_extension();

        for name in ["math.bitAnd", "math.bitOr", "math.bitXor"] {
            assert_eq!(find(&funcs, name).overloads.len(), 2, "{}", name); // int, uint
        }

        let bit_not = find(&funcs, "math.bitNot");
        assert_eq!(bit_not.overloads.len(), 2); // int, uint
        assert!(has_overload(bit_not, "math_bitnot_uint", 1));

        // Shifts take an int shift amount for both int and uint values.
        let shl = find(&funcs, "math.bitShiftLeft");
        assert!(has_overload(shl, "math_bitshiftleft_int_int", 2));
        assert!(has_overload(shl, "math_bitshiftleft_uint_int", 2));
        let shr = find(&funcs, "math.bitShiftRight");
        assert!(has_overload(shr, "math_bitshiftright_int_int", 2));
        assert!(has_overload(shr, "math_bitshiftright_uint_int", 2));
    }

    #[test]
    fn test_all_functions_are_standalone() {
        let funcs = math_extension();
        for func in &funcs {
            for overload in &func.overloads {
                assert!(
                    !overload.is_member,
                    "Expected {} to be standalone, but it's a member function",
                    overload.id
                );
            }
        }
    }
}