Skip to main content

formualizer_eval/builtins/math/
aggregate.rs

1use super::super::utils::{ARG_RANGE_NUM_LENIENT_ONE, coerce_num};
2use crate::args::ArgSchema;
3use crate::function::Function;
4use crate::traits::{ArgumentHandle, FunctionContext};
5use arrow_array::Array;
6use formualizer_common::{ExcelError, LiteralValue};
7use formualizer_macros::func_caps;
8
9/* ─────────────────────────── SUM() ──────────────────────────── */
10
11#[derive(Debug)]
12pub struct SumFn;
13
14/// Adds numeric values across scalars and ranges.
15///
16/// `SUM` evaluates all arguments, coercing text to numbers where possible,
17/// and returns the total. Blank cells and logical values in ranges are ignored.
18///
19/// # Remarks
20/// - If any argument evaluates to an error, `SUM` propagates the first error it encounters.
21/// - Unparseable text literals (e.g., `"foo"`) will result in a `#VALUE!` error.
22///
23/// # Examples
24///
25/// ```yaml,sandbox
26/// title: "Basic scalar addition"
27/// formula: "=SUM(10, 20, 5)"
28/// expected: 35
29/// ```
30///
31/// ```yaml,sandbox
32/// title: "Summing a range"
33/// grid:
34///   A1: 10
35///   A2: 20
36///   A3: "N/A"
37/// formula: "=SUM(A1:A3)"
38/// expected: 30
39/// ```
40///
41/// ```yaml,docs
42/// related:
43///   - SUMIF
44///   - SUMIFS
45///   - SUMPRODUCT
46///   - AVERAGE
47/// faq:
48///   - q: "Why does SUM return #VALUE! for some text arguments?"
49///     a: "Direct scalar text that cannot be parsed as a number raises #VALUE! during coercion."
50///   - q: "Do text and logical values inside ranges get added?"
51///     a: "No. In ranged inputs, only numeric cells contribute to the total."
52/// ```
53///
54/// [formualizer-docgen:schema:start]
55/// Name: SUM
56/// Type: SumFn
57/// Min args: 0
58/// Max args: variadic
59/// Variadic: true
60/// Signature: SUM(arg1...: number@range)
61/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
62/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS
63/// [formualizer-docgen:schema:end]
64impl Function for SumFn {
65    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS);
66
67    fn name(&self) -> &'static str {
68        "SUM"
69    }
70    fn min_args(&self) -> usize {
71        0
72    }
73    fn variadic(&self) -> bool {
74        true
75    }
76    fn arg_schema(&self) -> &'static [ArgSchema] {
77        &ARG_RANGE_NUM_LENIENT_ONE[..]
78    }
79
80    fn eval<'a, 'b, 'c>(
81        &self,
82        args: &'c [ArgumentHandle<'a, 'b>],
83        ctx: &dyn FunctionContext<'b>,
84    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
85        let mut total = 0.0;
86        for arg in args {
87            if let Ok(view) = arg.range_view() {
88                // Propagate errors from range first
89                for res in view.errors_slices() {
90                    let (_, _, err_cols) = res?;
91                    for col in err_cols {
92                        if col.null_count() < col.len() {
93                            for i in 0..col.len() {
94                                if !col.is_null(i) {
95                                    return Ok(crate::traits::CalcValue::Scalar(
96                                        LiteralValue::Error(ExcelError::new(
97                                            crate::arrow_store::unmap_error_code(col.value(i)),
98                                        )),
99                                    ));
100                                }
101                            }
102                        }
103                    }
104                }
105
106                for res in view.numbers_slices() {
107                    let (_, _, num_cols) = res?;
108                    for col in num_cols {
109                        total +=
110                            arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
111                    }
112                }
113            } else {
114                let v = arg.value()?.into_literal();
115                match v {
116                    LiteralValue::Error(e) => {
117                        return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
118                    }
119                    v => total += coerce_num(&v)?,
120                }
121            }
122        }
123        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
124            total,
125        )))
126    }
127}
128
129/* ─────────────────────────── COUNT() ──────────────────────────── */
130
131#[derive(Debug)]
132pub struct CountFn;
133
134/// Counts numeric values across scalars and ranges.
135///
136/// `COUNT` evaluates all arguments and counts how many are numeric values.
137/// Numbers, dates, and text representations of numbers (when supplied directly) are counted.
138///
139/// # Remarks
140/// - Text values inside ranges are ignored and not counted.
141/// - Blank cells and logical values in ranges are ignored.
142///
143/// # Examples
144///
145/// ```yaml,sandbox
146/// title: "Counting mixed scalar inputs"
147/// formula: "=COUNT(1, \"x\", 2, 3)"
148/// expected: 3
149/// ```
150///
151/// ```yaml,sandbox
152/// title: "Counting in a range"
153/// grid:
154///   A1: 10
155///   A2: "foo"
156///   A3: 20
157/// formula: "=COUNT(A1:A3)"
158/// expected: 2
159/// ```
160///
161/// ```yaml,docs
162/// related:
163///   - COUNTA
164///   - COUNTBLANK
165///   - COUNTIF
166///   - COUNTIFS
167/// faq:
168///   - q: "Why doesn't COUNT include text in a range?"
169///     a: "COUNT only counts numeric values; text cells in ranges are ignored."
170///   - q: "Can direct text like \"12\" be counted?"
171///     a: "Yes. Direct scalar arguments are coerced and counted when they parse as numbers."
172/// ```
173///
174/// [formualizer-docgen:schema:start]
175/// Name: COUNT
176/// Type: CountFn
177/// Min args: 0
178/// Max args: variadic
179/// Variadic: true
180/// Signature: COUNT(arg1...: number@range)
181/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
182/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
183/// [formualizer-docgen:schema:end]
184impl Function for CountFn {
185    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
186
187    fn name(&self) -> &'static str {
188        "COUNT"
189    }
190    fn min_args(&self) -> usize {
191        0
192    }
193    fn variadic(&self) -> bool {
194        true
195    }
196    fn arg_schema(&self) -> &'static [ArgSchema] {
197        &ARG_RANGE_NUM_LENIENT_ONE[..]
198    }
199
200    fn eval<'a, 'b, 'c>(
201        &self,
202        args: &'c [ArgumentHandle<'a, 'b>],
203        _: &dyn FunctionContext<'b>,
204    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
205        let mut count: i64 = 0;
206        for arg in args {
207            if let Ok(view) = arg.range_view() {
208                for res in view.numbers_slices() {
209                    let (_, _, num_cols) = res?;
210                    for col in num_cols {
211                        count += (col.len() - col.null_count()) as i64;
212                    }
213                }
214            } else {
215                let v = arg.value()?.into_literal();
216                if let LiteralValue::Error(e) = v {
217                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
218                }
219                if !matches!(v, LiteralValue::Empty) && coerce_num(&v).is_ok() {
220                    count += 1;
221                }
222            }
223        }
224        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
225            count as f64,
226        )))
227    }
228}
229
230/* ─────────────────────────── AVERAGE() ──────────────────────────── */
231
232#[derive(Debug)]
233pub struct AverageFn;
234
235/// Returns the arithmetic mean of numeric values across scalars and ranges.
236///
237/// `AVERAGE` sums numeric inputs and divides by the count of numeric values that participated.
238///
239/// # Remarks
240/// - Errors in any scalar argument or referenced range propagate immediately.
241/// - In ranges, only numeric/date-time serial values are included; text and blanks are ignored.
242/// - Scalar arguments use lenient number coercion with locale support.
243/// - If no numeric values are found, `AVERAGE` returns `#DIV/0!`.
244///
245/// # Examples
246///
247/// ```yaml,sandbox
248/// title: "Average of scalar values"
249/// formula: "=AVERAGE(10, 20, 5)"
250/// expected: 11.666666666666666
251/// ```
252///
253/// ```yaml,sandbox
254/// title: "Average over a mixed range"
255/// grid:
256///   A1: 10
257///   A2: "x"
258///   A3: 20
259/// formula: "=AVERAGE(A1:A3)"
260/// expected: 15
261/// ```
262///
263/// ```yaml,sandbox
264/// title: "No numeric values returns divide-by-zero"
265/// formula: "=AVERAGE(\"x\", \"\")"
266/// expected: "#DIV/0!"
267/// ```
268///
269/// ```yaml,docs
270/// related:
271///   - SUM
272///   - COUNT
273///   - AVERAGEIF
274///   - AVERAGEIFS
275/// faq:
276///   - q: "When does AVERAGE return #DIV/0!?"
277///     a: "It returns #DIV/0! when no numeric values are found after filtering/coercion."
278///   - q: "Do text cells in ranges affect the denominator?"
279///     a: "No. Only numeric values are counted toward the divisor."
280/// ```
281///
282/// [formualizer-docgen:schema:start]
283/// Name: AVERAGE
284/// Type: AverageFn
285/// Min args: 1
286/// Max args: variadic
287/// Variadic: true
288/// Signature: AVERAGE(arg1...: number@range)
289/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
290/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
291/// [formualizer-docgen:schema:end]
292impl Function for AverageFn {
293    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
294
295    fn name(&self) -> &'static str {
296        "AVERAGE"
297    }
298    fn min_args(&self) -> usize {
299        1
300    }
301    fn variadic(&self) -> bool {
302        true
303    }
304    fn arg_schema(&self) -> &'static [ArgSchema] {
305        &ARG_RANGE_NUM_LENIENT_ONE[..]
306    }
307
308    fn eval<'a, 'b, 'c>(
309        &self,
310        args: &'c [ArgumentHandle<'a, 'b>],
311        ctx: &dyn FunctionContext<'b>,
312    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
313        let mut sum = 0.0f64;
314        let mut cnt: i64 = 0;
315        for arg in args {
316            if let Ok(view) = arg.range_view() {
317                // Propagate errors from range first
318                for res in view.errors_slices() {
319                    let (_, _, err_cols) = res?;
320                    for col in err_cols {
321                        if col.null_count() < col.len() {
322                            for i in 0..col.len() {
323                                if !col.is_null(i) {
324                                    return Ok(crate::traits::CalcValue::Scalar(
325                                        LiteralValue::Error(ExcelError::new(
326                                            crate::arrow_store::unmap_error_code(col.value(i)),
327                                        )),
328                                    ));
329                                }
330                            }
331                        }
332                    }
333                }
334
335                for res in view.numbers_slices() {
336                    let (_, _, num_cols) = res?;
337                    for col in num_cols {
338                        sum += arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
339                        cnt += (col.len() - col.null_count()) as i64;
340                    }
341                }
342            } else {
343                let v = arg.value()?.into_literal();
344                if let LiteralValue::Error(e) = v {
345                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
346                }
347                if let Ok(n) = crate::coercion::to_number_lenient_with_locale(&v, &ctx.locale()) {
348                    sum += n;
349                    cnt += 1;
350                }
351            }
352        }
353        if cnt == 0 {
354            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
355                ExcelError::new_div(),
356            )));
357        }
358        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
359            sum / (cnt as f64),
360        )))
361    }
362}
363
364/* ──────────────────────── SUMPRODUCT() ───────────────────────── */
365
366#[derive(Debug)]
367pub struct SumProductFn;
368
369/// Multiplies aligned values across arrays and returns the sum of those products.
370///
371/// `SUMPRODUCT` supports scalar or range inputs, applies broadcast semantics, and accumulates
372/// the product for each aligned cell position.
373///
374/// # Remarks
375/// - Input shapes must be broadcast-compatible; otherwise `SUMPRODUCT` returns `#VALUE!`.
376/// - Non-numeric values are treated as `0` during multiplication.
377/// - Any explicit error value in the inputs propagates immediately.
378///
379/// # Examples
380///
381/// ```yaml,sandbox
382/// title: "Pairwise sum of products"
383/// formula: "=SUMPRODUCT({1,2,3}, {4,5,6})"
384/// expected: 32
385/// ```
386///
387/// ```yaml,sandbox
388/// title: "Range-based sumproduct"
389/// grid:
390///   A1: 2
391///   A2: 3
392///   A3: 4
393///   B1: 10
394///   B2: 20
395///   B3: 30
396/// formula: "=SUMPRODUCT(A1:A3, B1:B3)"
397/// expected: 200
398/// ```
399///
400/// ```yaml,sandbox
401/// title: "Text entries contribute zero"
402/// formula: "=SUMPRODUCT({1,\"x\",3}, {1,1,1})"
403/// expected: 4
404/// ```
405///
406/// ```yaml,docs
407/// related:
408///   - SUM
409///   - PRODUCT
410///   - MMULT
411///   - SUMIFS
412/// faq:
413///   - q: "Why does SUMPRODUCT return #VALUE! with some array shapes?"
414///     a: "The argument arrays must be broadcast-compatible; incompatible shapes raise #VALUE!."
415///   - q: "How are text values handled in multiplication?"
416///     a: "Non-numeric values are treated as 0, unless an explicit error is present."
417/// ```
418///
419/// [formualizer-docgen:schema:start]
420/// Name: SUMPRODUCT
421/// Type: SumProductFn
422/// Min args: 1
423/// Max args: variadic
424/// Variadic: true
425/// Signature: SUMPRODUCT(arg1...: number@range)
426/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
427/// Caps: PURE, REDUCTION
428/// [formualizer-docgen:schema:end]
429impl Function for SumProductFn {
430    // Pure reduction over arrays; uses broadcasting and lenient coercion
431    func_caps!(PURE, REDUCTION);
432
433    fn name(&self) -> &'static str {
434        "SUMPRODUCT"
435    }
436    fn min_args(&self) -> usize {
437        1
438    }
439    fn variadic(&self) -> bool {
440        true
441    }
442    fn arg_schema(&self) -> &'static [ArgSchema] {
443        // Accept ranges or scalars; numeric lenient coercion
444        &ARG_RANGE_NUM_LENIENT_ONE[..]
445    }
446
447    fn eval<'a, 'b, 'c>(
448        &self,
449        args: &'c [ArgumentHandle<'a, 'b>],
450        _: &dyn FunctionContext<'b>,
451    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
452        use crate::broadcast::{broadcast_shape, project_index};
453
454        if args.is_empty() {
455            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(0.0)));
456        }
457
458        // Helper: materialize an argument to a 2D array of LiteralValue
459        let to_array = |ah: &ArgumentHandle| -> Result<Vec<Vec<LiteralValue>>, ExcelError> {
460            if let Ok(rv) = ah.range_view() {
461                let mut rows: Vec<Vec<LiteralValue>> = Vec::new();
462                rv.for_each_row(&mut |row| {
463                    rows.push(row.to_vec());
464                    Ok(())
465                })?;
466                Ok(rows)
467            } else {
468                let v = ah.value()?.into_literal();
469                Ok(match v {
470                    LiteralValue::Array(arr) => arr,
471                    other => vec![vec![other]],
472                })
473            }
474        };
475
476        // Collect arrays and shapes
477        let mut arrays: Vec<Vec<Vec<LiteralValue>>> = Vec::with_capacity(args.len());
478        let mut shapes: Vec<(usize, usize)> = Vec::with_capacity(args.len());
479        for a in args.iter() {
480            let arr = to_array(a)?;
481            let shape = (arr.len(), arr.first().map(|r| r.len()).unwrap_or(0));
482            arrays.push(arr);
483            shapes.push(shape);
484        }
485
486        // Compute broadcast target shape across all args
487        let target = match broadcast_shape(&shapes) {
488            Ok(s) => s,
489            Err(_) => {
490                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
491                    ExcelError::new_value(),
492                )));
493            }
494        };
495
496        // Iterate target shape, multiply coerced values across args, sum total
497        let mut total = 0.0f64;
498        for r in 0..target.0 {
499            for c in 0..target.1 {
500                let mut prod = 1.0f64;
501                for (arr, &shape) in arrays.iter().zip(shapes.iter()) {
502                    let (rr, cc) = project_index((r, c), shape);
503                    let lv = arr
504                        .get(rr)
505                        .and_then(|row| row.get(cc))
506                        .cloned()
507                        .unwrap_or(LiteralValue::Empty);
508                    match lv {
509                        LiteralValue::Error(e) => {
510                            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
511                        }
512                        _ => match super::super::utils::coerce_num(&lv) {
513                            Ok(n) => {
514                                prod *= n;
515                            }
516                            Err(_) => {
517                                // Non-numeric -> treated as 0 in SUMPRODUCT
518                                prod *= 0.0;
519                            }
520                        },
521                    }
522                }
523                total += prod;
524            }
525        }
526        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
527            total,
528        )))
529    }
530}
531
#[cfg(test)]
mod tests_sumproduct {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wraps a 2D literal array in an AST literal node.
    fn arr(vals: Vec<Vec<LiteralValue>>) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Array(vals)), None)
    }

    /// Wraps a number in an AST literal node.
    fn num(n: f64) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None)
    }

    /// Registers SUMPRODUCT in a fresh test workbook, dispatches it over `nodes`, and
    /// returns the resulting literal.
    fn eval_sumproduct(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumProductFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .copied()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();
        let f = interp.context.get_function("", "SUMPRODUCT").unwrap();
        let fctx = interp.function_context(None);
        let result = f.dispatch(&args, &fctx).unwrap().into_literal();
        result
    }

    #[test]
    fn sumproduct_basic_pairwise() {
        // {1,2,3} * {4,5,6} = 1*4 + 2*5 + 3*6 = 32
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(4),
            LiteralValue::Int(5),
            LiteralValue::Int(6),
        ]]);
        assert_eq!(eval_sumproduct(&[&a, &b]), LiteralValue::Number(32.0));
    }

    #[test]
    fn sumproduct_variadic_three_arrays() {
        // {1,2} * {3,4} * {2,2} = (1*3*2) + (2*4*2) = 6 + 16 = 22
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let b = arr(vec![vec![LiteralValue::Int(3), LiteralValue::Int(4)]]);
        let c = arr(vec![vec![LiteralValue::Int(2), LiteralValue::Int(2)]]);
        assert_eq!(eval_sumproduct(&[&a, &b, &c]), LiteralValue::Number(22.0));
    }

    #[test]
    fn sumproduct_broadcast_scalar_over_array() {
        // {1,2,3} * 10 => (1*10 + 2*10 + 3*10) = 60
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let s = num(10.0);
        assert_eq!(eval_sumproduct(&[&a, &s]), LiteralValue::Number(60.0));
    }

    #[test]
    fn sumproduct_2d_arrays_broadcast_rows_cols() {
        // A is 2x2, B is 1x2 -> B is broadcast across rows
        // A = [[1,2],[3,4]], B = [[10,20]]
        // sum = 1*10 + 2*20 + 3*10 + 4*20 = 10 + 40 + 30 + 80 = 160
        let a = arr(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]);
        let b = arr(vec![vec![LiteralValue::Int(10), LiteralValue::Int(20)]]);
        assert_eq!(eval_sumproduct(&[&a, &b]), LiteralValue::Number(160.0));
    }

    #[test]
    fn sumproduct_non_numeric_treated_as_zero() {
        // {1,"x",3} * {1,1,1} => 1*1 + 0*1 + 3*1 = 4
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(1),
            LiteralValue::Int(1),
        ]]);
        assert_eq!(eval_sumproduct(&[&a, &b]), LiteralValue::Number(4.0));
    }

    #[test]
    fn sumproduct_error_in_input_propagates() {
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let e = ASTNode::new(
            ASTNodeType::Literal(LiteralValue::Error(ExcelError::new_na())),
            None,
        );
        match eval_sumproduct(&[&a, &e]) {
            LiteralValue::Error(err) => assert_eq!(err, "#N/A"),
            v => panic!("expected error, got {v:?}"),
        }
    }

    #[test]
    fn sumproduct_incompatible_shapes_value_error() {
        // 1x3 and 1x2 -> #VALUE!
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![LiteralValue::Int(4), LiteralValue::Int(5)]]);
        match eval_sumproduct(&[&a, &b]) {
            LiteralValue::Error(e) => assert_eq!(e, "#VALUE!"),
            v => panic!("expected value error, got {v:?}"),
        }
    }
}
711
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::FnCaps;
    use crate::test_workbook::TestWorkbook;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    #[test]
    fn test_sum_caps() {
        let caps = SumFn.caps();

        // Capabilities SUM is declared with.
        for expected in [
            FnCaps::PURE,
            FnCaps::REDUCTION,
            FnCaps::NUMERIC_ONLY,
            FnCaps::STREAM_OK,
        ] {
            assert!(caps.contains(expected));
        }

        // Capabilities SUM must not advertise.
        for absent in [FnCaps::VOLATILE, FnCaps::ELEMENTWISE] {
            assert!(!caps.contains(absent));
        }
    }

    #[test]
    fn test_sum_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumFn));
        let interp = wb.interpreter();
        let fctx = interp.function_context(None);

        // Literal number nodes standing in for parsed arguments.
        let number_node =
            |n: f64| ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None);
        let nodes = [number_node(1.0), number_node(2.0), number_node(3.0)];

        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();

        let sum_fn = interp.context.get_function("", "SUM").unwrap();
        let result = sum_fn.dispatch(&args, &fctx).unwrap().into_literal();
        assert_eq!(result, LiteralValue::Number(6.0));
    }
}
769
#[cfg(test)]
mod tests_count {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::ASTNode;
    use formualizer_parse::parser::ASTNodeType;

    /// Wraps a literal value in an AST literal node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Registers COUNT in a fresh test workbook, dispatches it over `nodes`, and returns the
    /// resulting literal.
    fn eval_count(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .copied()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();
        let f = interp.context.get_function("", "COUNT").unwrap();
        let fctx = interp.function_context(None);
        let result = f.dispatch(&args, &fctx).unwrap().into_literal();
        result
    }

    #[test]
    fn count_numbers_ignores_text() {
        // COUNT({1,2,"x",3}) => 3
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]));
        assert_eq!(eval_count(&[&node]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_multiple_args_and_scalars() {
        let n1 = lit(LiteralValue::Int(10));
        let n2 = lit(LiteralValue::Text("n".into()));
        let a = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        // Two from array + scalar 10 = 3
        assert_eq!(eval_count(&[&a, &n1, &n2]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_direct_error_argument_propagates() {
        let err = lit(LiteralValue::Error(ExcelError::from_error_string(
            "#DIV/0!",
        )));
        match eval_count(&[&err]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
845
#[cfg(test)]
mod tests_average {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::ASTNode;
    use formualizer_parse::parser::ASTNodeType;

    /// Wraps a literal value in an AST literal node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Registers AVERAGE in a fresh test workbook, dispatches it with `node` as the only
    /// argument, and returns the resulting literal.
    fn eval_average(node: &ASTNode) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AverageFn));
        let interp = wb.interpreter();
        let args = vec![ArgumentHandle::new(node, &interp)];
        let f = interp.context.get_function("", "AVERAGE").unwrap();
        let fctx = interp.function_context(None);
        let result = f.dispatch(&args, &fctx).unwrap().into_literal();
        result
    }

    #[test]
    fn average_basic_numbers() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Int(4),
            LiteralValue::Int(6),
        ]]));
        assert_eq!(eval_average(&node), LiteralValue::Number(4.0));
    }

    #[test]
    fn average_mixed_with_text() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(6),
        ]]));
        // average of 2 and 6 = 4
        assert_eq!(eval_average(&node), LiteralValue::Number(4.0));
    }

    #[test]
    fn average_no_numeric_div0() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Text("a".into()),
            LiteralValue::Text("b".into()),
        ]]));
        match eval_average(&node) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("expected #DIV/0!, got {v:?}"),
        }
    }

    #[test]
    fn average_direct_error_argument_propagates() {
        let err = lit(LiteralValue::Error(ExcelError::from_error_string(
            "#DIV/0!",
        )));
        match eval_average(&err) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
937
938pub fn register_builtins() {
939    crate::function_registry::register_function(std::sync::Arc::new(SumProductFn));
940    crate::function_registry::register_function(std::sync::Arc::new(SumFn));
941    crate::function_registry::register_function(std::sync::Arc::new(CountFn));
942    crate::function_registry::register_function(std::sync::Arc::new(AverageFn));
943}