formualizer_eval/builtins/math/
combinatorics.rs1use super::super::utils::{ARG_NUM_LENIENT_ONE, ARG_NUM_LENIENT_TWO, coerce_num};
2use crate::args::ArgSchema;
3use crate::function::Function;
4use crate::traits::{ArgumentHandle, CalcValue, FunctionContext};
5use formualizer_common::{ExcelError, LiteralValue};
6use formualizer_macros::func_caps;
7
8#[derive(Debug)]
10pub struct FactFn;
11impl Function for FactFn {
12 func_caps!(PURE);
13 fn name(&self) -> &'static str {
14 "FACT"
15 }
16 fn min_args(&self) -> usize {
17 1
18 }
19 fn arg_schema(&self) -> &'static [ArgSchema] {
20 &ARG_NUM_LENIENT_ONE[..]
21 }
22 fn eval<'a, 'b, 'c>(
23 &self,
24 args: &'c [ArgumentHandle<'a, 'b>],
25 _: &dyn FunctionContext<'b>,
26 ) -> Result<CalcValue<'b>, ExcelError> {
27 let v = args[0].value()?.into_literal();
28 let n = match v {
29 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
30 other => coerce_num(&other)?,
31 };
32
33 let n = n.trunc() as i64;
35
36 if n < 0 {
37 return Ok(CalcValue::Scalar(
38 LiteralValue::Error(ExcelError::new_num()),
39 ));
40 }
41
42 if n > 170 {
44 return Ok(CalcValue::Scalar(
45 LiteralValue::Error(ExcelError::new_num()),
46 ));
47 }
48
49 let mut result = 1.0_f64;
50 for i in 2..=(n as u64) {
51 result *= i as f64;
52 }
53
54 Ok(CalcValue::Scalar(LiteralValue::Number(result)))
55 }
56}
57
58#[derive(Debug)]
60pub struct GcdFn;
61impl Function for GcdFn {
62 func_caps!(PURE);
63 fn name(&self) -> &'static str {
64 "GCD"
65 }
66 fn min_args(&self) -> usize {
67 1
68 }
69 fn variadic(&self) -> bool {
70 true
71 }
72 fn arg_schema(&self) -> &'static [ArgSchema] {
73 &ARG_NUM_LENIENT_TWO[..]
74 }
75 fn eval<'a, 'b, 'c>(
76 &self,
77 args: &'c [ArgumentHandle<'a, 'b>],
78 _: &dyn FunctionContext<'b>,
79 ) -> Result<CalcValue<'b>, ExcelError> {
80 fn gcd(a: u64, b: u64) -> u64 {
81 if b == 0 { a } else { gcd(b, a % b) }
82 }
83
84 let mut result: Option<u64> = None;
85
86 for arg in args {
87 let v = arg.value()?.into_literal();
88 let n = match v {
89 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
90 other => coerce_num(&other)?,
91 };
92
93 let n = n.trunc();
95 if n < 0.0 || n > 9.99999999e9 {
96 return Ok(CalcValue::Scalar(
97 LiteralValue::Error(ExcelError::new_num()),
98 ));
99 }
100 let n = n as u64;
101
102 result = Some(match result {
103 None => n,
104 Some(r) => gcd(r, n),
105 });
106 }
107
108 Ok(CalcValue::Scalar(LiteralValue::Number(
109 result.unwrap_or(0) as f64
110 )))
111 }
112}
113
114#[derive(Debug)]
116pub struct LcmFn;
117impl Function for LcmFn {
118 func_caps!(PURE);
119 fn name(&self) -> &'static str {
120 "LCM"
121 }
122 fn min_args(&self) -> usize {
123 1
124 }
125 fn variadic(&self) -> bool {
126 true
127 }
128 fn arg_schema(&self) -> &'static [ArgSchema] {
129 &ARG_NUM_LENIENT_TWO[..]
130 }
131 fn eval<'a, 'b, 'c>(
132 &self,
133 args: &'c [ArgumentHandle<'a, 'b>],
134 _: &dyn FunctionContext<'b>,
135 ) -> Result<CalcValue<'b>, ExcelError> {
136 fn gcd(a: u64, b: u64) -> u64 {
137 if b == 0 { a } else { gcd(b, a % b) }
138 }
139 fn lcm(a: u64, b: u64) -> u64 {
140 if a == 0 || b == 0 {
141 0
142 } else {
143 (a / gcd(a, b)) * b
144 }
145 }
146
147 let mut result: Option<u64> = None;
148
149 for arg in args {
150 let v = arg.value()?.into_literal();
151 let n = match v {
152 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
153 other => coerce_num(&other)?,
154 };
155
156 let n = n.trunc();
157 if n < 0.0 || n > 9.99999999e9 {
158 return Ok(CalcValue::Scalar(
159 LiteralValue::Error(ExcelError::new_num()),
160 ));
161 }
162 let n = n as u64;
163
164 result = Some(match result {
165 None => n,
166 Some(r) => lcm(r, n),
167 });
168 }
169
170 Ok(CalcValue::Scalar(LiteralValue::Number(
171 result.unwrap_or(0) as f64
172 )))
173 }
174}
175
176#[derive(Debug)]
178pub struct CombinFn;
179impl Function for CombinFn {
180 func_caps!(PURE);
181 fn name(&self) -> &'static str {
182 "COMBIN"
183 }
184 fn min_args(&self) -> usize {
185 2
186 }
187 fn arg_schema(&self) -> &'static [ArgSchema] {
188 &ARG_NUM_LENIENT_TWO[..]
189 }
190 fn eval<'a, 'b, 'c>(
191 &self,
192 args: &'c [ArgumentHandle<'a, 'b>],
193 _: &dyn FunctionContext<'b>,
194 ) -> Result<CalcValue<'b>, ExcelError> {
195 if args.len() < 2 {
197 return Ok(CalcValue::Scalar(LiteralValue::Error(
198 ExcelError::new_value(),
199 )));
200 }
201
202 let n_val = args[0].value()?.into_literal();
203 let k_val = args[1].value()?.into_literal();
204
205 let n = match n_val {
206 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
207 other => coerce_num(&other)?,
208 };
209 let k = match k_val {
210 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
211 other => coerce_num(&other)?,
212 };
213
214 let n = n.trunc() as i64;
215 let k = k.trunc() as i64;
216
217 if n < 0 || k < 0 || k > n {
218 return Ok(CalcValue::Scalar(
219 LiteralValue::Error(ExcelError::new_num()),
220 ));
221 }
222
223 let k = k.min(n - k) as u64; let n = n as u64;
227
228 let mut result = 1.0_f64;
229 for i in 0..k {
230 result = result * (n - i) as f64 / (i + 1) as f64;
231 }
232
233 Ok(CalcValue::Scalar(LiteralValue::Number(result.round())))
234 }
235}
236
237#[derive(Debug)]
239pub struct PermutFn;
240impl Function for PermutFn {
241 func_caps!(PURE);
242 fn name(&self) -> &'static str {
243 "PERMUT"
244 }
245 fn min_args(&self) -> usize {
246 2
247 }
248 fn arg_schema(&self) -> &'static [ArgSchema] {
249 &ARG_NUM_LENIENT_TWO[..]
250 }
251 fn eval<'a, 'b, 'c>(
252 &self,
253 args: &'c [ArgumentHandle<'a, 'b>],
254 _: &dyn FunctionContext<'b>,
255 ) -> Result<CalcValue<'b>, ExcelError> {
256 if args.len() < 2 {
258 return Ok(CalcValue::Scalar(LiteralValue::Error(
259 ExcelError::new_value(),
260 )));
261 }
262
263 let n_val = args[0].value()?.into_literal();
264 let k_val = args[1].value()?.into_literal();
265
266 let n = match n_val {
267 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
268 other => coerce_num(&other)?,
269 };
270 let k = match k_val {
271 LiteralValue::Error(e) => return Ok(CalcValue::Scalar(LiteralValue::Error(e))),
272 other => coerce_num(&other)?,
273 };
274
275 let n = n.trunc() as i64;
276 let k = k.trunc() as i64;
277
278 if n < 0 || k < 0 || k > n {
279 return Ok(CalcValue::Scalar(
280 LiteralValue::Error(ExcelError::new_num()),
281 ));
282 }
283
284 let mut result = 1.0_f64;
286 for i in 0..k {
287 result *= (n - i) as f64;
288 }
289
290 Ok(CalcValue::Scalar(LiteralValue::Number(result)))
291 }
292}
293
294pub fn register_builtins() {
295 use std::sync::Arc;
296 crate::function_registry::register_function(Arc::new(FactFn));
297 crate::function_registry::register_function(Arc::new(GcdFn));
298 crate::function_registry::register_function(Arc::new(LcmFn));
299 crate::function_registry::register_function(Arc::new(CombinFn));
300 crate::function_registry::register_function(Arc::new(PermutFn));
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};
    use std::sync::Arc;

    /// Builds an interpreter over the given test workbook.
    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Wraps a literal value in a literal AST node.
    fn lit(v: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(v), None)
    }

    /// FACT(5) == 120.
    #[test]
    fn fact_basic() {
        let workbook = TestWorkbook::new().with_function(Arc::new(FactFn));
        let interpreter = interp(&workbook);
        let five = lit(LiteralValue::Number(5.0));
        let func = interpreter.context.get_function("", "FACT").unwrap();

        let handles = [ArgumentHandle::new(&five, &interpreter)];
        let outcome = func
            .dispatch(&handles, &interpreter.function_context(None))
            .unwrap()
            .into_literal();

        assert_eq!(outcome, LiteralValue::Number(120.0));
    }

    /// FACT(0) == 1.
    #[test]
    fn fact_zero() {
        let workbook = TestWorkbook::new().with_function(Arc::new(FactFn));
        let interpreter = interp(&workbook);
        let zero = lit(LiteralValue::Number(0.0));
        let func = interpreter.context.get_function("", "FACT").unwrap();

        let handles = [ArgumentHandle::new(&zero, &interpreter)];
        let outcome = func
            .dispatch(&handles, &interpreter.function_context(None))
            .unwrap()
            .into_literal();

        assert_eq!(outcome, LiteralValue::Number(1.0));
    }

    /// GCD(12, 18) == 6.
    #[test]
    fn gcd_basic() {
        let workbook = TestWorkbook::new().with_function(Arc::new(GcdFn));
        let interpreter = interp(&workbook);
        let first = lit(LiteralValue::Number(12.0));
        let second = lit(LiteralValue::Number(18.0));
        let func = interpreter.context.get_function("", "GCD").unwrap();

        let handles = [
            ArgumentHandle::new(&first, &interpreter),
            ArgumentHandle::new(&second, &interpreter),
        ];
        let outcome = func
            .dispatch(&handles, &interpreter.function_context(None))
            .unwrap()
            .into_literal();

        assert_eq!(outcome, LiteralValue::Number(6.0));
    }

    /// LCM(4, 6) == 12.
    #[test]
    fn lcm_basic() {
        let workbook = TestWorkbook::new().with_function(Arc::new(LcmFn));
        let interpreter = interp(&workbook);
        let first = lit(LiteralValue::Number(4.0));
        let second = lit(LiteralValue::Number(6.0));
        let func = interpreter.context.get_function("", "LCM").unwrap();

        let handles = [
            ArgumentHandle::new(&first, &interpreter),
            ArgumentHandle::new(&second, &interpreter),
        ];
        let outcome = func
            .dispatch(&handles, &interpreter.function_context(None))
            .unwrap()
            .into_literal();

        assert_eq!(outcome, LiteralValue::Number(12.0));
    }

    /// COMBIN(5, 2) == 10.
    #[test]
    fn combin_basic() {
        let workbook = TestWorkbook::new().with_function(Arc::new(CombinFn));
        let interpreter = interp(&workbook);
        let total = lit(LiteralValue::Number(5.0));
        let chosen = lit(LiteralValue::Number(2.0));
        let func = interpreter.context.get_function("", "COMBIN").unwrap();

        let handles = [
            ArgumentHandle::new(&total, &interpreter),
            ArgumentHandle::new(&chosen, &interpreter),
        ];
        let outcome = func
            .dispatch(&handles, &interpreter.function_context(None))
            .unwrap()
            .into_literal();

        assert_eq!(outcome, LiteralValue::Number(10.0));
    }

    /// PERMUT(5, 2) == 20.
    #[test]
    fn permut_basic() {
        let workbook = TestWorkbook::new().with_function(Arc::new(PermutFn));
        let interpreter = interp(&workbook);
        let total = lit(LiteralValue::Number(5.0));
        let chosen = lit(LiteralValue::Number(2.0));
        let func = interpreter.context.get_function("", "PERMUT").unwrap();

        let handles = [
            ArgumentHandle::new(&total, &interpreter),
            ArgumentHandle::new(&chosen, &interpreter),
        ];
        let outcome = func
            .dispatch(&handles, &interpreter.function_context(None))
            .unwrap()
            .into_literal();

        assert_eq!(outcome, LiteralValue::Number(20.0));
    }
}