Skip to main content

formualizer_eval/builtins/math/
combinatorics.rs

1use super::super::utils::{ARG_NUM_LENIENT_ONE, ARG_NUM_LENIENT_TWO, coerce_num};
2use crate::args::ArgSchema;
3use crate::function::Function;
4use crate::traits::{ArgumentHandle, CalcValue, FunctionContext};
5use formualizer_common::{ExcelError, LiteralValue};
6use formualizer_macros::func_caps;
7
8/// FACT(number) - Returns the factorial of a number
9#[derive(Debug)]
10pub struct FactFn;
11impl Function for FactFn {
12    func_caps!(PURE);
13    fn name(&self) -> &'static str {
14        "FACT"
15    }
16    fn min_args(&self) -> usize {
17        1
18    }
19    fn arg_schema(&self) -> &'static [ArgSchema] {
20        &ARG_NUM_LENIENT_ONE[..]
21    }
22    fn eval<'a, 'b, 'c>(
23        &self,
24        args: &'c [ArgumentHandle<'a, 'b>],
25        _: &dyn FunctionContext<'b>,
26    ) -> Result<CalcValue<'b>, ExcelError> {
27        let v = args[0].value()?.into_literal();
28        let n = match v {
29            LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
30            other => coerce_num(&other)?,
31        };
32
33        // Excel truncates to integer
34        let n = n.trunc() as i64;
35
36        if n < 0 {
37            return Ok(CalcValue::Scalar(
38                LiteralValue::Error(ExcelError::new_num()),
39            ));
40        }
41
42        // Factorial calculation (Excel supports up to 170!)
43        if n > 170 {
44            return Ok(CalcValue::Scalar(
45                LiteralValue::Error(ExcelError::new_num()),
46            ));
47        }
48
49        let mut result = 1.0_f64;
50        for i in 2..=(n as u64) {
51            result *= i as f64;
52        }
53
54        Ok(CalcValue::Scalar(LiteralValue::Number(result)))
55    }
56}
57
58/// GCD(number1, [number2], ...) - Returns the greatest common divisor
59#[derive(Debug)]
60pub struct GcdFn;
61impl Function for GcdFn {
62    func_caps!(PURE);
63    fn name(&self) -> &'static str {
64        "GCD"
65    }
66    fn min_args(&self) -> usize {
67        1
68    }
69    fn variadic(&self) -> bool {
70        true
71    }
72    fn arg_schema(&self) -> &'static [ArgSchema] {
73        &ARG_NUM_LENIENT_TWO[..]
74    }
75    fn eval<'a, 'b, 'c>(
76        &self,
77        args: &'c [ArgumentHandle<'a, 'b>],
78        _: &dyn FunctionContext<'b>,
79    ) -> Result<CalcValue<'b>, ExcelError> {
80        fn gcd(a: u64, b: u64) -> u64 {
81            if b == 0 { a } else { gcd(b, a % b) }
82        }
83
84        let mut result: Option<u64> = None;
85
86        for arg in args {
87            let v = arg.value()?.into_literal();
88            let n = match v {
89                LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
90                other => coerce_num(&other)?,
91            };
92
93            // Excel truncates and requires non-negative
94            let n = n.trunc();
95            if n < 0.0 || n > 9.99999999e9 {
96                return Ok(CalcValue::Scalar(
97                    LiteralValue::Error(ExcelError::new_num()),
98                ));
99            }
100            let n = n as u64;
101
102            result = Some(match result {
103                None => n,
104                Some(r) => gcd(r, n),
105            });
106        }
107
108        Ok(CalcValue::Scalar(LiteralValue::Number(
109            result.unwrap_or(0) as f64
110        )))
111    }
112}
113
114/// LCM(number1, [number2], ...) - Returns the least common multiple
115#[derive(Debug)]
116pub struct LcmFn;
117impl Function for LcmFn {
118    func_caps!(PURE);
119    fn name(&self) -> &'static str {
120        "LCM"
121    }
122    fn min_args(&self) -> usize {
123        1
124    }
125    fn variadic(&self) -> bool {
126        true
127    }
128    fn arg_schema(&self) -> &'static [ArgSchema] {
129        &ARG_NUM_LENIENT_TWO[..]
130    }
131    fn eval<'a, 'b, 'c>(
132        &self,
133        args: &'c [ArgumentHandle<'a, 'b>],
134        _: &dyn FunctionContext<'b>,
135    ) -> Result<CalcValue<'b>, ExcelError> {
136        fn gcd(a: u64, b: u64) -> u64 {
137            if b == 0 { a } else { gcd(b, a % b) }
138        }
139        fn lcm(a: u64, b: u64) -> u64 {
140            if a == 0 || b == 0 {
141                0
142            } else {
143                (a / gcd(a, b)) * b
144            }
145        }
146
147        let mut result: Option<u64> = None;
148
149        for arg in args {
150            let v = arg.value()?.into_literal();
151            let n = match v {
152                LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
153                other => coerce_num(&other)?,
154            };
155
156            let n = n.trunc();
157            if n < 0.0 || n > 9.99999999e9 {
158                return Ok(CalcValue::Scalar(
159                    LiteralValue::Error(ExcelError::new_num()),
160                ));
161            }
162            let n = n as u64;
163
164            result = Some(match result {
165                None => n,
166                Some(r) => lcm(r, n),
167            });
168        }
169
170        Ok(CalcValue::Scalar(LiteralValue::Number(
171            result.unwrap_or(0) as f64
172        )))
173    }
174}
175
176/// COMBIN(n, k) - Returns the number of combinations
177#[derive(Debug)]
178pub struct CombinFn;
179impl Function for CombinFn {
180    func_caps!(PURE);
181    fn name(&self) -> &'static str {
182        "COMBIN"
183    }
184    fn min_args(&self) -> usize {
185        2
186    }
187    fn arg_schema(&self) -> &'static [ArgSchema] {
188        &ARG_NUM_LENIENT_TWO[..]
189    }
190    fn eval<'a, 'b, 'c>(
191        &self,
192        args: &'c [ArgumentHandle<'a, 'b>],
193        _: &dyn FunctionContext<'b>,
194    ) -> Result<CalcValue<'b>, ExcelError> {
195        // Check minimum required arguments
196        if args.len() < 2 {
197            return Ok(CalcValue::Scalar(LiteralValue::Error(
198                ExcelError::new_value(),
199            )));
200        }
201
202        let n_val = args[0].value()?.into_literal();
203        let k_val = args[1].value()?.into_literal();
204
205        let n = match n_val {
206            LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
207            other => coerce_num(&other)?,
208        };
209        let k = match k_val {
210            LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
211            other => coerce_num(&other)?,
212        };
213
214        let n = n.trunc() as i64;
215        let k = k.trunc() as i64;
216
217        if n < 0 || k < 0 || k > n {
218            return Ok(CalcValue::Scalar(
219                LiteralValue::Error(ExcelError::new_num()),
220            ));
221        }
222
223        // Calculate C(n, k) = n! / (k! * (n-k)!)
224        // Use the more efficient formula: C(n, k) = product of (n-i)/(i+1) for i in 0..k
225        let k = k.min(n - k) as u64; // Use symmetry for efficiency
226        let n = n as u64;
227
228        let mut result = 1.0_f64;
229        for i in 0..k {
230            result = result * (n - i) as f64 / (i + 1) as f64;
231        }
232
233        Ok(CalcValue::Scalar(LiteralValue::Number(result.round())))
234    }
235}
236
237/// PERMUT(n, k) - Returns the number of permutations
238#[derive(Debug)]
239pub struct PermutFn;
240impl Function for PermutFn {
241    func_caps!(PURE);
242    fn name(&self) -> &'static str {
243        "PERMUT"
244    }
245    fn min_args(&self) -> usize {
246        2
247    }
248    fn arg_schema(&self) -> &'static [ArgSchema] {
249        &ARG_NUM_LENIENT_TWO[..]
250    }
251    fn eval<'a, 'b, 'c>(
252        &self,
253        args: &'c [ArgumentHandle<'a, 'b>],
254        _: &dyn FunctionContext<'b>,
255    ) -> Result<CalcValue<'b>, ExcelError> {
256        // Check minimum required arguments
257        if args.len() < 2 {
258            return Ok(CalcValue::Scalar(LiteralValue::Error(
259                ExcelError::new_value(),
260            )));
261        }
262
263        let n_val = args[0].value()?.into_literal();
264        let k_val = args[1].value()?.into_literal();
265
266        let n = match n_val {
267            LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
268            other => coerce_num(&other)?,
269        };
270        let k = match k_val {
271            LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
272            other => coerce_num(&other)?,
273        };
274
275        let n = n.trunc() as i64;
276        let k = k.trunc() as i64;
277
278        if n < 0 || k < 0 || k > n {
279            return Ok(CalcValue::Scalar(
280                LiteralValue::Error(ExcelError::new_num()),
281            ));
282        }
283
284        // P(n, k) = n! / (n-k)! = n * (n-1) * ... * (n-k+1)
285        let mut result = 1.0_f64;
286        for i in 0..k {
287            result *= (n - i) as f64;
288        }
289
290        Ok(CalcValue::Scalar(LiteralValue::Number(result)))
291    }
292}
293
294pub fn register_builtins() {
295    use std::sync::Arc;
296    crate::function_registry::register_function(Arc::new(FactFn));
297    crate::function_registry::register_function(Arc::new(GcdFn));
298    crate::function_registry::register_function(Arc::new(LcmFn));
299    crate::function_registry::register_function(Arc::new(CombinFn));
300    crate::function_registry::register_function(Arc::new(PermutFn));
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};
    use std::sync::Arc;

    /// Builds an interpreter over the given test workbook.
    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Wraps a literal value in an AST literal node.
    fn lit(v: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(v), None)
    }

    /// Shorthand for a numeric literal node.
    fn num(x: f64) -> ASTNode {
        lit(LiteralValue::Number(x))
    }

    /// FACT(5) is 120.
    #[test]
    fn fact_basic() {
        let wb = TestWorkbook::new().with_function(Arc::new(FactFn));
        let ctx = interp(&wb);
        let arg = num(5.0);
        let func = ctx.context.get_function("", "FACT").unwrap();
        let got = func
            .dispatch(
                &[ArgumentHandle::new(&arg, &ctx)],
                &ctx.function_context(None),
            )
            .unwrap()
            .into_literal();
        assert_eq!(got, LiteralValue::Number(120.0));
    }

    /// FACT(0) is 1.
    #[test]
    fn fact_zero() {
        let wb = TestWorkbook::new().with_function(Arc::new(FactFn));
        let ctx = interp(&wb);
        let arg = num(0.0);
        let func = ctx.context.get_function("", "FACT").unwrap();
        let got = func
            .dispatch(
                &[ArgumentHandle::new(&arg, &ctx)],
                &ctx.function_context(None),
            )
            .unwrap()
            .into_literal();
        assert_eq!(got, LiteralValue::Number(1.0));
    }

    /// GCD(12, 18) is 6.
    #[test]
    fn gcd_basic() {
        let wb = TestWorkbook::new().with_function(Arc::new(GcdFn));
        let ctx = interp(&wb);
        let first = num(12.0);
        let second = num(18.0);
        let func = ctx.context.get_function("", "GCD").unwrap();
        let got = func
            .dispatch(
                &[
                    ArgumentHandle::new(&first, &ctx),
                    ArgumentHandle::new(&second, &ctx),
                ],
                &ctx.function_context(None),
            )
            .unwrap()
            .into_literal();
        assert_eq!(got, LiteralValue::Number(6.0));
    }

    /// LCM(4, 6) is 12.
    #[test]
    fn lcm_basic() {
        let wb = TestWorkbook::new().with_function(Arc::new(LcmFn));
        let ctx = interp(&wb);
        let first = num(4.0);
        let second = num(6.0);
        let func = ctx.context.get_function("", "LCM").unwrap();
        let got = func
            .dispatch(
                &[
                    ArgumentHandle::new(&first, &ctx),
                    ArgumentHandle::new(&second, &ctx),
                ],
                &ctx.function_context(None),
            )
            .unwrap()
            .into_literal();
        assert_eq!(got, LiteralValue::Number(12.0));
    }

    /// COMBIN(5, 2) is 10.
    #[test]
    fn combin_basic() {
        let wb = TestWorkbook::new().with_function(Arc::new(CombinFn));
        let ctx = interp(&wb);
        let total = num(5.0);
        let chosen = num(2.0);
        let func = ctx.context.get_function("", "COMBIN").unwrap();
        let got = func
            .dispatch(
                &[
                    ArgumentHandle::new(&total, &ctx),
                    ArgumentHandle::new(&chosen, &ctx),
                ],
                &ctx.function_context(None),
            )
            .unwrap()
            .into_literal();
        assert_eq!(got, LiteralValue::Number(10.0));
    }

    /// PERMUT(5, 2) is 20.
    #[test]
    fn permut_basic() {
        let wb = TestWorkbook::new().with_function(Arc::new(PermutFn));
        let ctx = interp(&wb);
        let total = num(5.0);
        let chosen = num(2.0);
        let func = ctx.context.get_function("", "PERMUT").unwrap();
        let got = func
            .dispatch(
                &[
                    ArgumentHandle::new(&total, &ctx),
                    ArgumentHandle::new(&chosen, &ctx),
                ],
                &ctx.function_context(None),
            )
            .unwrap()
            .into_literal();
        assert_eq!(got, LiteralValue::Number(20.0));
    }
}