Skip to main content

formualizer_eval/builtins/math/
aggregate.rs

1use super::super::utils::{ARG_RANGE_NUM_LENIENT_ONE, coerce_num};
2use crate::args::ArgSchema;
3use crate::engine::VisibilityMaskMode;
4use crate::function::Function;
5use crate::traits::{ArgumentHandle, FunctionContext};
6use arrow_array::Array;
7use formualizer_common::{ExcelError, ExcelErrorKind, LiteralValue};
8use formualizer_macros::func_caps;
9
10/* ─────────────────────────── SUM() ──────────────────────────── */
11
12#[derive(Debug)]
13pub struct SumFn;
14
15/// Adds numeric values across scalars and ranges.
16///
17/// `SUM` evaluates all arguments, coercing text to numbers where possible,
18/// and returns the total. Blank cells and logical values in ranges are ignored.
19///
20/// # Remarks
21/// - If any argument evaluates to an error, `SUM` propagates the first error it encounters.
22/// - Unparseable text literals (e.g., `"foo"`) will result in a `#VALUE!` error.
23///
24/// # Examples
25///
26/// ```yaml,sandbox
27/// title: "Basic scalar addition"
28/// formula: "=SUM(10, 20, 5)"
29/// expected: 35
30/// ```
31///
32/// ```yaml,sandbox
33/// title: "Summing a range"
34/// grid:
35///   A1: 10
36///   A2: 20
37///   A3: "N/A"
38/// formula: "=SUM(A1:A3)"
39/// expected: 30
40/// ```
41///
42/// ```yaml,docs
43/// related:
44///   - SUMIF
45///   - SUMIFS
46///   - SUMPRODUCT
47///   - AVERAGE
48/// faq:
49///   - q: "Why does SUM return #VALUE! for some text arguments?"
50///     a: "Direct scalar text that cannot be parsed as a number raises #VALUE! during coercion."
51///   - q: "Do text and logical values inside ranges get added?"
52///     a: "No. In ranged inputs, only numeric cells contribute to the total."
53/// ```
54///
55/// [formualizer-docgen:schema:start]
56/// Name: SUM
57/// Type: SumFn
58/// Min args: 0
59/// Max args: variadic
60/// Variadic: true
61/// Signature: SUM(arg1...: number@range)
62/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
63/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS
64/// [formualizer-docgen:schema:end]
65impl Function for SumFn {
66    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS);
67
68    fn name(&self) -> &'static str {
69        "SUM"
70    }
71    fn min_args(&self) -> usize {
72        0
73    }
74    fn variadic(&self) -> bool {
75        true
76    }
77    fn arg_schema(&self) -> &'static [ArgSchema] {
78        &ARG_RANGE_NUM_LENIENT_ONE[..]
79    }
80
81    fn eval<'a, 'b, 'c>(
82        &self,
83        args: &'c [ArgumentHandle<'a, 'b>],
84        ctx: &dyn FunctionContext<'b>,
85    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
86        let mut total = 0.0;
87        for arg in args {
88            if let Ok(view) = arg.range_view() {
89                // Propagate errors from range first
90                for res in view.errors_slices() {
91                    let (_, _, err_cols) = res?;
92                    for col in err_cols {
93                        if col.null_count() < col.len() {
94                            for i in 0..col.len() {
95                                if !col.is_null(i) {
96                                    return Ok(crate::traits::CalcValue::Scalar(
97                                        LiteralValue::Error(ExcelError::new(
98                                            crate::arrow_store::unmap_error_code(col.value(i)),
99                                        )),
100                                    ));
101                                }
102                            }
103                        }
104                    }
105                }
106
107                for res in view.numbers_slices() {
108                    let (_, _, num_cols) = res?;
109                    for col in num_cols {
110                        total +=
111                            arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
112                    }
113                }
114            } else {
115                let v = arg.value()?.into_literal();
116                match v {
117                    LiteralValue::Error(e) => {
118                        return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
119                    }
120                    v => total += coerce_num(&v)?,
121                }
122            }
123        }
124        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
125            total,
126        )))
127    }
128}
129
130/* ─────────────────────────── COUNT() ──────────────────────────── */
131
132#[derive(Debug)]
133pub struct CountFn;
134
135/// Counts numeric values across scalars and ranges.
136///
137/// `COUNT` evaluates all arguments and counts how many are numeric values.
138/// Numbers, dates, and text representations of numbers (when supplied directly) are counted.
139///
140/// # Remarks
141/// - Text values inside ranges are ignored and not counted.
142/// - Blank cells and logical values in ranges are ignored.
143///
144/// # Examples
145///
146/// ```yaml,sandbox
147/// title: "Counting mixed scalar inputs"
148/// formula: "=COUNT(1, \"x\", 2, 3)"
149/// expected: 3
150/// ```
151///
152/// ```yaml,sandbox
153/// title: "Counting in a range"
154/// grid:
155///   A1: 10
156///   A2: "foo"
157///   A3: 20
158/// formula: "=COUNT(A1:A3)"
159/// expected: 2
160/// ```
161///
162/// ```yaml,docs
163/// related:
164///   - COUNTA
165///   - COUNTBLANK
166///   - COUNTIF
167///   - COUNTIFS
168/// faq:
169///   - q: "Why doesn't COUNT include text in a range?"
170///     a: "COUNT only counts numeric values; text cells in ranges are ignored."
171///   - q: "Can direct text like \"12\" be counted?"
172///     a: "Yes. Direct scalar arguments are coerced and counted when they parse as numbers."
173/// ```
174///
175/// [formualizer-docgen:schema:start]
176/// Name: COUNT
177/// Type: CountFn
178/// Min args: 0
179/// Max args: variadic
180/// Variadic: true
181/// Signature: COUNT(arg1...: number@range)
182/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
183/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
184/// [formualizer-docgen:schema:end]
185impl Function for CountFn {
186    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
187
188    fn name(&self) -> &'static str {
189        "COUNT"
190    }
191    fn min_args(&self) -> usize {
192        0
193    }
194    fn variadic(&self) -> bool {
195        true
196    }
197    fn arg_schema(&self) -> &'static [ArgSchema] {
198        &ARG_RANGE_NUM_LENIENT_ONE[..]
199    }
200
201    fn eval<'a, 'b, 'c>(
202        &self,
203        args: &'c [ArgumentHandle<'a, 'b>],
204        _: &dyn FunctionContext<'b>,
205    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
206        let mut count: i64 = 0;
207        for arg in args {
208            if let Ok(view) = arg.range_view() {
209                for res in view.numbers_slices() {
210                    let (_, _, num_cols) = res?;
211                    for col in num_cols {
212                        count += (col.len() - col.null_count()) as i64;
213                    }
214                }
215            } else {
216                let v = arg.value()?.into_literal();
217                if let LiteralValue::Error(e) = v {
218                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
219                }
220                if !matches!(v, LiteralValue::Empty) && coerce_num(&v).is_ok() {
221                    count += 1;
222                }
223            }
224        }
225        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
226            count as f64,
227        )))
228    }
229}
230
231/* ─────────────────────────── AVERAGE() ──────────────────────────── */
232
233#[derive(Debug)]
234pub struct AverageFn;
235
236/// Returns the arithmetic mean of numeric values across scalars and ranges.
237///
238/// `AVERAGE` sums numeric inputs and divides by the count of numeric values that participated.
239///
240/// # Remarks
241/// - Errors in any scalar argument or referenced range propagate immediately.
242/// - In ranges, only numeric/date-time serial values are included; text and blanks are ignored.
243/// - Scalar arguments use lenient number coercion with locale support.
244/// - If no numeric values are found, `AVERAGE` returns `#DIV/0!`.
245///
246/// # Examples
247///
248/// ```yaml,sandbox
249/// title: "Average of scalar values"
250/// formula: "=AVERAGE(10, 20, 5)"
251/// expected: 11.666666666666666
252/// ```
253///
254/// ```yaml,sandbox
255/// title: "Average over a mixed range"
256/// grid:
257///   A1: 10
258///   A2: "x"
259///   A3: 20
260/// formula: "=AVERAGE(A1:A3)"
261/// expected: 15
262/// ```
263///
264/// ```yaml,sandbox
265/// title: "No numeric values returns divide-by-zero"
266/// formula: "=AVERAGE(\"x\", \"\")"
267/// expected: "#DIV/0!"
268/// ```
269///
270/// ```yaml,docs
271/// related:
272///   - SUM
273///   - COUNT
274///   - AVERAGEIF
275///   - AVERAGEIFS
276/// faq:
277///   - q: "When does AVERAGE return #DIV/0!?"
278///     a: "It returns #DIV/0! when no numeric values are found after filtering/coercion."
279///   - q: "Do text cells in ranges affect the denominator?"
280///     a: "No. Only numeric values are counted toward the divisor."
281/// ```
282///
283/// [formualizer-docgen:schema:start]
284/// Name: AVERAGE
285/// Type: AverageFn
286/// Min args: 1
287/// Max args: variadic
288/// Variadic: true
289/// Signature: AVERAGE(arg1...: number@range)
290/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
291/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
292/// [formualizer-docgen:schema:end]
293impl Function for AverageFn {
294    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
295
296    fn name(&self) -> &'static str {
297        "AVERAGE"
298    }
299    fn min_args(&self) -> usize {
300        1
301    }
302    fn variadic(&self) -> bool {
303        true
304    }
305    fn arg_schema(&self) -> &'static [ArgSchema] {
306        &ARG_RANGE_NUM_LENIENT_ONE[..]
307    }
308
309    fn eval<'a, 'b, 'c>(
310        &self,
311        args: &'c [ArgumentHandle<'a, 'b>],
312        ctx: &dyn FunctionContext<'b>,
313    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
314        let mut sum = 0.0f64;
315        let mut cnt: i64 = 0;
316        for arg in args {
317            if let Ok(view) = arg.range_view() {
318                // Propagate errors from range first
319                for res in view.errors_slices() {
320                    let (_, _, err_cols) = res?;
321                    for col in err_cols {
322                        if col.null_count() < col.len() {
323                            for i in 0..col.len() {
324                                if !col.is_null(i) {
325                                    return Ok(crate::traits::CalcValue::Scalar(
326                                        LiteralValue::Error(ExcelError::new(
327                                            crate::arrow_store::unmap_error_code(col.value(i)),
328                                        )),
329                                    ));
330                                }
331                            }
332                        }
333                    }
334                }
335
336                for res in view.numbers_slices() {
337                    let (_, _, num_cols) = res?;
338                    for col in num_cols {
339                        sum += arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
340                        cnt += (col.len() - col.null_count()) as i64;
341                    }
342                }
343            } else {
344                let v = arg.value()?.into_literal();
345                if let LiteralValue::Error(e) = v {
346                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
347                }
348                if let Ok(n) = crate::coercion::to_number_lenient_with_locale(&v, &ctx.locale()) {
349                    sum += n;
350                    cnt += 1;
351                }
352            }
353        }
354        if cnt == 0 {
355            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
356                ExcelError::new_div(),
357            )));
358        }
359        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
360            sum / (cnt as f64),
361        )))
362    }
363}
364
365/* ──────────────────────── SUMPRODUCT() ───────────────────────── */
366
367#[derive(Debug)]
368pub struct SumProductFn;
369
370/// Multiplies aligned values across arrays and returns the sum of those products.
371///
372/// `SUMPRODUCT` supports scalar or range inputs, applies broadcast semantics, and accumulates
373/// the product for each aligned cell position.
374///
375/// # Remarks
376/// - Input shapes must be broadcast-compatible; otherwise `SUMPRODUCT` returns `#VALUE!`.
377/// - Non-numeric values are treated as `0` during multiplication.
378/// - Any explicit error value in the inputs propagates immediately.
379///
380/// # Examples
381///
382/// ```yaml,sandbox
383/// title: "Pairwise sum of products"
384/// formula: "=SUMPRODUCT({1,2,3}, {4,5,6})"
385/// expected: 32
386/// ```
387///
388/// ```yaml,sandbox
389/// title: "Range-based sumproduct"
390/// grid:
391///   A1: 2
392///   A2: 3
393///   A3: 4
394///   B1: 10
395///   B2: 20
396///   B3: 30
397/// formula: "=SUMPRODUCT(A1:A3, B1:B3)"
398/// expected: 200
399/// ```
400///
401/// ```yaml,sandbox
402/// title: "Text entries contribute zero"
403/// formula: "=SUMPRODUCT({1,\"x\",3}, {1,1,1})"
404/// expected: 4
405/// ```
406///
407/// ```yaml,docs
408/// related:
409///   - SUM
410///   - PRODUCT
411///   - MMULT
412///   - SUMIFS
413/// faq:
414///   - q: "Why does SUMPRODUCT return #VALUE! with some array shapes?"
415///     a: "The argument arrays must be broadcast-compatible; incompatible shapes raise #VALUE!."
416///   - q: "How are text values handled in multiplication?"
417///     a: "Non-numeric values are treated as 0, unless an explicit error is present."
418/// ```
419///
420/// [formualizer-docgen:schema:start]
421/// Name: SUMPRODUCT
422/// Type: SumProductFn
423/// Min args: 1
424/// Max args: variadic
425/// Variadic: true
426/// Signature: SUMPRODUCT(arg1...: number@range)
427/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
428/// Caps: PURE, REDUCTION
429/// [formualizer-docgen:schema:end]
430impl Function for SumProductFn {
431    // Pure reduction over arrays; uses broadcasting and lenient coercion
432    func_caps!(PURE, REDUCTION);
433
434    fn name(&self) -> &'static str {
435        "SUMPRODUCT"
436    }
437    fn min_args(&self) -> usize {
438        1
439    }
440    fn variadic(&self) -> bool {
441        true
442    }
443    fn arg_schema(&self) -> &'static [ArgSchema] {
444        // Accept ranges or scalars; numeric lenient coercion
445        &ARG_RANGE_NUM_LENIENT_ONE[..]
446    }
447
448    fn eval<'a, 'b, 'c>(
449        &self,
450        args: &'c [ArgumentHandle<'a, 'b>],
451        _: &dyn FunctionContext<'b>,
452    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
453        use crate::broadcast::{broadcast_shape, project_index};
454
455        if args.is_empty() {
456            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(0.0)));
457        }
458
459        // Helper: materialize an argument to a 2D array of LiteralValue
460        let to_array = |ah: &ArgumentHandle| -> Result<Vec<Vec<LiteralValue>>, ExcelError> {
461            if let Ok(rv) = ah.range_view() {
462                let mut rows: Vec<Vec<LiteralValue>> = Vec::new();
463                rv.for_each_row(&mut |row| {
464                    rows.push(row.to_vec());
465                    Ok(())
466                })?;
467                Ok(rows)
468            } else {
469                let v = ah.value()?.into_literal();
470                Ok(match v {
471                    LiteralValue::Array(arr) => arr,
472                    other => vec![vec![other]],
473                })
474            }
475        };
476
477        // Collect arrays and shapes
478        let mut arrays: Vec<Vec<Vec<LiteralValue>>> = Vec::with_capacity(args.len());
479        let mut shapes: Vec<(usize, usize)> = Vec::with_capacity(args.len());
480        for a in args.iter() {
481            let arr = to_array(a)?;
482            let shape = (arr.len(), arr.first().map(|r| r.len()).unwrap_or(0));
483            arrays.push(arr);
484            shapes.push(shape);
485        }
486
487        // Compute broadcast target shape across all args
488        let target = match broadcast_shape(&shapes) {
489            Ok(s) => s,
490            Err(_) => {
491                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
492                    ExcelError::new_value(),
493                )));
494            }
495        };
496
497        // Iterate target shape, multiply coerced values across args, sum total
498        let mut total = 0.0f64;
499        for r in 0..target.0 {
500            for c in 0..target.1 {
501                let mut prod = 1.0f64;
502                for (arr, &shape) in arrays.iter().zip(shapes.iter()) {
503                    let (rr, cc) = project_index((r, c), shape);
504                    let lv = arr
505                        .get(rr)
506                        .and_then(|row| row.get(cc))
507                        .cloned()
508                        .unwrap_or(LiteralValue::Empty);
509                    match lv {
510                        LiteralValue::Error(e) => {
511                            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
512                        }
513                        _ => match super::super::utils::coerce_num(&lv) {
514                            Ok(n) => {
515                                prod *= n;
516                            }
517                            Err(_) => {
518                                // Non-numeric -> treated as 0 in SUMPRODUCT
519                                prod *= 0.0;
520                            }
521                        },
522                    }
523                }
524                total += prod;
525            }
526        }
527        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
528            total,
529        )))
530    }
531}
532
#[cfg(test)]
mod tests_sumproduct {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Workbook with only `SUMPRODUCT` registered.
    fn workbook() -> TestWorkbook {
        TestWorkbook::new().with_function(std::sync::Arc::new(SumProductFn))
    }

    /// Wraps a 2D array literal in an AST node.
    fn arr(vals: Vec<Vec<LiteralValue>>) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Array(vals)), None)
    }

    /// Wraps a number literal in an AST node.
    fn num(n: f64) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None)
    }

    /// Dispatches `SUMPRODUCT` over `nodes` and returns the resulting literal.
    fn run(ctx: &crate::interpreter::Interpreter<'_>, nodes: &[&ASTNode]) -> LiteralValue {
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, ctx))
            .collect();
        let f = ctx.context.get_function("", "SUMPRODUCT").unwrap();
        f.dispatch(&args, &ctx.function_context(None))
            .unwrap()
            .into_literal()
    }

    #[test]
    fn sumproduct_basic_pairwise() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // {1,2,3} * {4,5,6} = 1*4 + 2*5 + 3*6 = 32
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(4),
            LiteralValue::Int(5),
            LiteralValue::Int(6),
        ]]);
        assert_eq!(run(&ctx, &[&a, &b]), LiteralValue::Number(32.0));
    }

    #[test]
    fn sumproduct_variadic_three_arrays() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // {1,2} * {3,4} * {2,2} = (1*3*2) + (2*4*2) = 6 + 16 = 22
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let b = arr(vec![vec![LiteralValue::Int(3), LiteralValue::Int(4)]]);
        let c = arr(vec![vec![LiteralValue::Int(2), LiteralValue::Int(2)]]);
        assert_eq!(run(&ctx, &[&a, &b, &c]), LiteralValue::Number(22.0));
    }

    #[test]
    fn sumproduct_broadcast_scalar_over_array() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // {1,2,3} * 10 => (1*10 + 2*10 + 3*10) = 60
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let s = num(10.0);
        assert_eq!(run(&ctx, &[&a, &s]), LiteralValue::Number(60.0));
    }

    #[test]
    fn sumproduct_2d_arrays_broadcast_rows_cols() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // A is 2x2, B is 1x2 -> broadcast B across rows
        // A = [[1,2],[3,4]], B = [[10,20]]
        // sum = 1*10 + 2*20 + 3*10 + 4*20 = 10 + 40 + 30 + 80 = 160
        let a = arr(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]);
        let b = arr(vec![vec![LiteralValue::Int(10), LiteralValue::Int(20)]]);
        assert_eq!(run(&ctx, &[&a, &b]), LiteralValue::Number(160.0));
    }

    #[test]
    fn sumproduct_non_numeric_treated_as_zero() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // {1,"x",3} * {1,1,1} => 1*1 + 0*1 + 3*1 = 4
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(1),
            LiteralValue::Int(1),
        ]]);
        assert_eq!(run(&ctx, &[&a, &b]), LiteralValue::Number(4.0));
    }

    #[test]
    fn sumproduct_error_in_input_propagates() {
        let wb = workbook();
        let ctx = wb.interpreter();
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let e = ASTNode::new(
            ASTNodeType::Literal(LiteralValue::Error(ExcelError::new_na())),
            None,
        );
        match run(&ctx, &[&a, &e]) {
            LiteralValue::Error(err) => assert_eq!(err, "#N/A"),
            v => panic!("expected error, got {v:?}"),
        }
    }

    #[test]
    fn sumproduct_incompatible_shapes_value_error() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // 1x3 and 1x2 -> #VALUE!
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![LiteralValue::Int(4), LiteralValue::Int(5)]]);
        match run(&ctx, &[&a, &b]) {
            LiteralValue::Error(e) => assert_eq!(e, "#VALUE!"),
            v => panic!("expected value error, got {v:?}"),
        }
    }
}
712
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::FnCaps;
    use crate::test_workbook::TestWorkbook;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// `SUM` must advertise every capability listed in its `func_caps!` declaration,
    /// including `PARALLEL_ARGS`, and must not be volatile or elementwise.
    #[test]
    fn test_sum_caps() {
        let caps = SumFn.caps();

        // Check that the expected capabilities are set
        for (name, flag) in [
            ("PURE", FnCaps::PURE),
            ("REDUCTION", FnCaps::REDUCTION),
            ("NUMERIC_ONLY", FnCaps::NUMERIC_ONLY),
            ("STREAM_OK", FnCaps::STREAM_OK),
            ("PARALLEL_ARGS", FnCaps::PARALLEL_ARGS),
        ] {
            assert!(caps.contains(flag), "SUM is missing capability {name}");
        }

        // Check that other caps are not set
        for (name, flag) in [
            ("VOLATILE", FnCaps::VOLATILE),
            ("ELEMENTWISE", FnCaps::ELEMENTWISE),
        ] {
            assert!(!caps.contains(flag), "SUM unexpectedly has capability {name}");
        }
    }

    /// `SUM(1, 2, 3)` dispatched through the registry yields 6.
    #[test]
    fn test_sum_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumFn));
        let ctx = interp(&wb);
        let fctx = ctx.function_context(None);

        // Build literal argument nodes, then handles that borrow them.
        let nodes: Vec<ASTNode> = [1.0, 2.0, 3.0]
            .iter()
            .map(|&n| ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None))
            .collect();
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &ctx))
            .collect();

        let sum_fn = ctx.context.get_function("", "SUM").unwrap();
        let result = sum_fn.dispatch(&args, &fctx).unwrap().into_literal();
        assert_eq!(result, LiteralValue::Number(6.0));
    }
}
770
#[cfg(test)]
mod tests_count {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Workbook with only `COUNT` registered.
    fn workbook() -> TestWorkbook {
        TestWorkbook::new().with_function(std::sync::Arc::new(CountFn))
    }

    /// Wraps any literal in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Dispatches `COUNT` over `nodes` and returns the resulting literal.
    fn run_count(ctx: &crate::interpreter::Interpreter<'_>, nodes: &[&ASTNode]) -> LiteralValue {
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, ctx))
            .collect();
        let f = ctx.context.get_function("", "COUNT").unwrap();
        f.dispatch(&args, &ctx.function_context(None))
            .unwrap()
            .into_literal()
    }

    #[test]
    fn count_numbers_ignores_text() {
        let wb = workbook();
        let ctx = wb.interpreter();
        // COUNT({1,2,"x",3}) => 3
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]));
        assert_eq!(run_count(&ctx, &[&node]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_multiple_args_and_scalars() {
        let wb = workbook();
        let ctx = wb.interpreter();
        let a = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        let n1 = lit(LiteralValue::Int(10));
        let n2 = lit(LiteralValue::Text("n".into()));
        // Two from array + scalar 10 = 3
        assert_eq!(
            run_count(&ctx, &[&a, &n1, &n2]),
            LiteralValue::Number(3.0)
        );
    }

    #[test]
    fn count_direct_error_argument_propagates() {
        let wb = workbook();
        let ctx = wb.interpreter();
        let err = lit(LiteralValue::Error(ExcelError::from_error_string(
            "#DIV/0!",
        )));
        match run_count(&ctx, &[&err]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
846
#[cfg(test)]
mod tests_average {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::ASTNode;
    use formualizer_parse::parser::ASTNodeType;

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Workbook whose function registry contains AVERAGE.
    fn average_workbook() -> TestWorkbook {
        TestWorkbook::new().with_function(std::sync::Arc::new(AverageFn))
    }

    /// Dispatches AVERAGE with `value` as its single literal argument and returns the result.
    fn average_of(ctx: &crate::interpreter::Interpreter<'_>, value: LiteralValue) -> LiteralValue {
        let node = ASTNode::new(ASTNodeType::Literal(value), None);
        let args = vec![ArgumentHandle::new(&node, ctx)];
        let average = ctx.context.get_function("", "AVERAGE").unwrap();
        average
            .dispatch(&args, &ctx.function_context(None))
            .unwrap()
            .into_literal()
    }

    #[test]
    fn average_basic_numbers() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let row = vec![
            LiteralValue::Int(2),
            LiteralValue::Int(4),
            LiteralValue::Int(6),
        ];
        assert_eq!(
            average_of(&ctx, LiteralValue::Array(vec![row])),
            LiteralValue::Number(4.0)
        );
    }

    #[test]
    fn average_mixed_with_text() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let row = vec![
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(6),
        ];
        // Text inside an array is skipped, so this is the mean of 2 and 6.
        assert_eq!(
            average_of(&ctx, LiteralValue::Array(vec![row])),
            LiteralValue::Number(4.0)
        );
    }

    #[test]
    fn average_no_numeric_div0() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let row = vec![LiteralValue::Text("a".into()), LiteralValue::Text("b".into())];
        match average_of(&ctx, LiteralValue::Array(vec![row])) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("expected #DIV/0!, got {v:?}"),
        }
    }

    #[test]
    fn average_direct_error_argument_propagates() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let div0 = LiteralValue::Error(ExcelError::from_error_string("#DIV/0!"));
        match average_of(&ctx, div0) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
938
/// Selects which rows of a range a SUBTOTAL/AGGREGATE scan considers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VisibilityPolicy {
    /// Every row participates; no visibility mask is requested.
    IncludeAll,
    /// Rows that the context's visibility mask marks as hidden (manually or by a filter) are skipped.
    ExcludeManualOrFilterHidden,
}
944
/// Decides what happens when an error value is met while collecting arguments.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ErrorPolicy {
    /// The first error encountered becomes the result.
    Propagate,
    /// Error values are skipped as if absent.
    Ignore,
}
950
/// The eleven reductions shared by SUBTOTAL and AGGREGATE, listed in function-number order.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AggregateOp {
    /// Code 1.
    Average,
    /// Code 2.
    Count,
    /// Code 3.
    CountA,
    /// Code 4.
    Max,
    /// Code 5.
    Min,
    /// Code 6.
    Product,
    /// Code 7.
    StdevSample,
    /// Code 8.
    StdevPopulation,
    /// Code 9.
    Sum,
    /// Code 10.
    VarSample,
    /// Code 11.
    VarPopulation,
}

/// Maps a base function number (1 through 11) to its reduction; any other value gives `None`.
fn aggregate_op_from_function_num(function_num: i32) -> Option<AggregateOp> {
    // Indexed by `function_num - 1`.
    const BY_CODE: [AggregateOp; 11] = [
        AggregateOp::Average,
        AggregateOp::Count,
        AggregateOp::CountA,
        AggregateOp::Max,
        AggregateOp::Min,
        AggregateOp::Product,
        AggregateOp::StdevSample,
        AggregateOp::StdevPopulation,
        AggregateOp::Sum,
        AggregateOp::VarSample,
        AggregateOp::VarPopulation,
    ];

    // Converting through usize rejects negatives. checked_sub rejects 0 without overflowing.
    let slot = usize::try_from(function_num).ok()?.checked_sub(1)?;
    BY_CODE.get(slot).copied()
}
982
983fn parse_strict_int_arg(arg: &ArgumentHandle<'_, '_>) -> Result<i32, ExcelError> {
984    let raw = arg.value()?.into_literal();
985    if let LiteralValue::Error(e) = raw {
986        return Err(e);
987    }
988
989    let n = coerce_num(&raw)?;
990    if !n.is_finite() {
991        return Err(ExcelError::new_value());
992    }
993
994    let rounded = n.round();
995    if (n - rounded).abs() > 1e-9 {
996        return Err(ExcelError::new_value());
997    }
998
999    if rounded < i32::MIN as f64 || rounded > i32::MAX as f64 {
1000        return Err(ExcelError::new_value());
1001    }
1002
1003    Ok(rounded as i32)
1004}
1005
1006fn row_is_visible(mask: Option<&arrow_array::BooleanArray>, relative_row: usize) -> bool {
1007    let Some(mask) = mask else {
1008        return true;
1009    };
1010
1011    if relative_row >= mask.len() || mask.is_null(relative_row) {
1012        return true;
1013    }
1014
1015    mask.value(relative_row)
1016}
1017
1018fn numeric_from_range_value(value: &LiteralValue) -> Option<f64> {
1019    match value {
1020        LiteralValue::Number(n) => Some(*n),
1021        LiteralValue::Int(i) => Some(*i as f64),
1022        LiteralValue::Date(_)
1023        | LiteralValue::DateTime(_)
1024        | LiteralValue::Time(_)
1025        | LiteralValue::Duration(_) => coerce_num(value).ok(),
1026        _ => None,
1027    }
1028}
1029
/// Accumulator filled while walking SUBTOTAL/AGGREGATE arguments.
#[derive(Default, Debug)]
struct AggregateCollector {
    /// Numbers gathered for the numeric reductions. Its length is also the COUNT result.
    numeric_values: Vec<f64>,
    /// Number of non-empty values seen; used only by COUNTA.
    counta: usize,
}
1035
1036impl AggregateCollector {
1037    fn collect_args<'a, 'b>(
1038        args: &[ArgumentHandle<'a, 'b>],
1039        start_idx: usize,
1040        ctx: &dyn FunctionContext<'b>,
1041        op: AggregateOp,
1042        visibility_policy: VisibilityPolicy,
1043        error_policy: ErrorPolicy,
1044    ) -> Result<Self, ExcelError> {
1045        let mut out = Self::default();
1046
1047        for arg in args.iter().skip(start_idx) {
1048            if let Ok(view) = arg.range_view() {
1049                out.collect_range_arg(&view, ctx, op, visibility_policy, error_policy)?;
1050            } else {
1051                out.consume_scalar_value(arg.value()?.into_literal(), op, error_policy)?;
1052            }
1053        }
1054
1055        Ok(out)
1056    }
1057
1058    fn collect_range_arg<'b>(
1059        &mut self,
1060        view: &crate::engine::range_view::RangeView<'_>,
1061        ctx: &dyn FunctionContext<'b>,
1062        op: AggregateOp,
1063        visibility_policy: VisibilityPolicy,
1064        error_policy: ErrorPolicy,
1065    ) -> Result<(), ExcelError> {
1066        let visibility_mask = match visibility_policy {
1067            VisibilityPolicy::IncludeAll => None,
1068            VisibilityPolicy::ExcludeManualOrFilterHidden => {
1069                ctx.get_row_visibility_mask(view, VisibilityMaskMode::ExcludeManualOrFilterHidden)
1070            }
1071        };
1072
1073        let (_, cols) = view.dims();
1074        if cols == 0 {
1075            return Ok(());
1076        }
1077
1078        for chunk in view.iter_row_chunks() {
1079            let chunk = chunk?;
1080            for row_offset in 0..chunk.row_len {
1081                let rel_row = chunk.row_start + row_offset;
1082                if !row_is_visible(visibility_mask.as_deref(), rel_row) {
1083                    continue;
1084                }
1085
1086                for col in 0..cols {
1087                    // Phase-1 contract: nested SUBTOTAL/AGGREGATE exclusion is deferred.
1088                    // Nested aggregate results are treated as ordinary scalar values.
1089                    self.consume_range_value(view.get_cell(rel_row, col), op, error_policy)?;
1090                }
1091            }
1092        }
1093
1094        Ok(())
1095    }
1096
1097    fn consume_range_value(
1098        &mut self,
1099        value: LiteralValue,
1100        op: AggregateOp,
1101        error_policy: ErrorPolicy,
1102    ) -> Result<(), ExcelError> {
1103        match value {
1104            LiteralValue::Error(e) => {
1105                if op == AggregateOp::CountA {
1106                    if error_policy == ErrorPolicy::Ignore {
1107                        return Ok(());
1108                    }
1109                    self.counta += 1;
1110                    return Ok(());
1111                }
1112                match error_policy {
1113                    ErrorPolicy::Propagate => Err(e),
1114                    ErrorPolicy::Ignore => Ok(()),
1115                }
1116            }
1117            LiteralValue::Empty => Ok(()),
1118            other => {
1119                self.counta += 1;
1120                if let Some(n) = numeric_from_range_value(&other) {
1121                    self.numeric_values.push(n);
1122                }
1123                Ok(())
1124            }
1125        }
1126    }
1127
1128    fn consume_scalar_value(
1129        &mut self,
1130        value: LiteralValue,
1131        op: AggregateOp,
1132        error_policy: ErrorPolicy,
1133    ) -> Result<(), ExcelError> {
1134        match value {
1135            LiteralValue::Error(e) => {
1136                if op == AggregateOp::CountA {
1137                    if error_policy == ErrorPolicy::Ignore {
1138                        return Ok(());
1139                    }
1140                    self.counta += 1;
1141                    return Ok(());
1142                }
1143                match error_policy {
1144                    ErrorPolicy::Propagate => Err(e),
1145                    ErrorPolicy::Ignore => Ok(()),
1146                }
1147            }
1148            LiteralValue::Array(rows) => {
1149                for row in rows {
1150                    for cell in row {
1151                        self.consume_range_value(cell, op, error_policy)?;
1152                    }
1153                }
1154                Ok(())
1155            }
1156            other => {
1157                match op {
1158                    AggregateOp::CountA => {
1159                        if !matches!(other, LiteralValue::Empty) {
1160                            self.counta += 1;
1161                        }
1162                    }
1163                    AggregateOp::Count => {
1164                        if !matches!(other, LiteralValue::Empty) && coerce_num(&other).is_ok() {
1165                            self.numeric_values.push(0.0);
1166                        }
1167                    }
1168                    _ => {
1169                        if let Ok(n) = coerce_num(&other) {
1170                            self.numeric_values.push(n);
1171                        }
1172                    }
1173                }
1174                Ok(())
1175            }
1176        }
1177    }
1178
1179    fn variance(values: &[f64], sample: bool) -> Result<f64, ExcelError> {
1180        let n = values.len();
1181        if sample {
1182            if n < 2 {
1183                return Err(ExcelError::new_div());
1184            }
1185        } else if n == 0 {
1186            return Err(ExcelError::new_div());
1187        }
1188
1189        let mean = values.iter().copied().sum::<f64>() / (n as f64);
1190        let mut ss = 0.0;
1191        for value in values {
1192            let d = *value - mean;
1193            ss += d * d;
1194        }
1195
1196        if sample {
1197            Ok(ss / ((n - 1) as f64))
1198        } else {
1199            Ok(ss / (n as f64))
1200        }
1201    }
1202
1203    fn finalize(self, op: AggregateOp) -> LiteralValue {
1204        match op {
1205            AggregateOp::Average => {
1206                if self.numeric_values.is_empty() {
1207                    LiteralValue::Error(ExcelError::new_div())
1208                } else {
1209                    let sum = self.numeric_values.iter().copied().sum::<f64>();
1210                    LiteralValue::Number(sum / (self.numeric_values.len() as f64))
1211                }
1212            }
1213            AggregateOp::Count => LiteralValue::Number(self.numeric_values.len() as f64),
1214            AggregateOp::CountA => LiteralValue::Number(self.counta as f64),
1215            AggregateOp::Max => LiteralValue::Number(
1216                self.numeric_values
1217                    .iter()
1218                    .copied()
1219                    .reduce(f64::max)
1220                    .unwrap_or(0.0),
1221            ),
1222            AggregateOp::Min => LiteralValue::Number(
1223                self.numeric_values
1224                    .iter()
1225                    .copied()
1226                    .reduce(f64::min)
1227                    .unwrap_or(0.0),
1228            ),
1229            AggregateOp::Product => {
1230                if self.numeric_values.is_empty() {
1231                    LiteralValue::Number(0.0)
1232                } else {
1233                    LiteralValue::Number(self.numeric_values.iter().copied().product::<f64>())
1234                }
1235            }
1236            AggregateOp::StdevSample => match Self::variance(&self.numeric_values, true) {
1237                Ok(v) => LiteralValue::Number(v.sqrt()),
1238                Err(e) => LiteralValue::Error(e),
1239            },
1240            AggregateOp::StdevPopulation => match Self::variance(&self.numeric_values, false) {
1241                Ok(v) => LiteralValue::Number(v.sqrt()),
1242                Err(e) => LiteralValue::Error(e),
1243            },
1244            AggregateOp::Sum => LiteralValue::Number(self.numeric_values.iter().copied().sum()),
1245            AggregateOp::VarSample => match Self::variance(&self.numeric_values, true) {
1246                Ok(v) => LiteralValue::Number(v),
1247                Err(e) => LiteralValue::Error(e),
1248            },
1249            AggregateOp::VarPopulation => match Self::variance(&self.numeric_values, false) {
1250                Ok(v) => LiteralValue::Number(v),
1251                Err(e) => LiteralValue::Error(e),
1252            },
1253        }
1254    }
1255}
1256
1257#[derive(Debug)]
1258pub struct SubtotalFn;
1259
1260impl Function for SubtotalFn {
1261    func_caps!(VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
1262
1263    fn name(&self) -> &'static str {
1264        "SUBTOTAL"
1265    }
1266
1267    fn min_args(&self) -> usize {
1268        2
1269    }
1270
1271    fn variadic(&self) -> bool {
1272        true
1273    }
1274
1275    fn arg_schema(&self) -> &'static [ArgSchema] {
1276        &ARG_RANGE_NUM_LENIENT_ONE[..]
1277    }
1278
1279    fn eval<'a, 'b, 'c>(
1280        &self,
1281        args: &'c [ArgumentHandle<'a, 'b>],
1282        ctx: &dyn FunctionContext<'b>,
1283    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
1284        if args.len() < 2 {
1285            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1286                ExcelError::new_value(),
1287            )));
1288        }
1289
1290        let function_num = match parse_strict_int_arg(&args[0]) {
1291            Ok(v) => v,
1292            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1293        };
1294
1295        let (mapped_code, visibility) = if (1..=11).contains(&function_num) {
1296            (function_num, VisibilityPolicy::IncludeAll)
1297        } else if (101..=111).contains(&function_num) {
1298            (
1299                function_num - 100,
1300                VisibilityPolicy::ExcludeManualOrFilterHidden,
1301            )
1302        } else {
1303            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1304                ExcelError::new_value(),
1305            )));
1306        };
1307
1308        let Some(op) = aggregate_op_from_function_num(mapped_code) else {
1309            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1310                ExcelError::new_value(),
1311            )));
1312        };
1313
1314        let collected = match AggregateCollector::collect_args(
1315            args,
1316            1,
1317            ctx,
1318            op,
1319            visibility,
1320            ErrorPolicy::Propagate,
1321        ) {
1322            Ok(c) => c,
1323            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1324        };
1325
1326        Ok(crate::traits::CalcValue::Scalar(collected.finalize(op)))
1327    }
1328}
1329
1330#[derive(Debug)]
1331pub struct AggregateFn;
1332
1333impl Function for AggregateFn {
1334    func_caps!(VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
1335
1336    fn name(&self) -> &'static str {
1337        "AGGREGATE"
1338    }
1339
1340    fn min_args(&self) -> usize {
1341        3
1342    }
1343
1344    fn variadic(&self) -> bool {
1345        true
1346    }
1347
1348    fn arg_schema(&self) -> &'static [ArgSchema] {
1349        &ARG_RANGE_NUM_LENIENT_ONE[..]
1350    }
1351
1352    fn eval<'a, 'b, 'c>(
1353        &self,
1354        args: &'c [ArgumentHandle<'a, 'b>],
1355        ctx: &dyn FunctionContext<'b>,
1356    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
1357        if args.len() < 3 {
1358            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1359                ExcelError::new_value(),
1360            )));
1361        }
1362
1363        let function_num = match parse_strict_int_arg(&args[0]) {
1364            Ok(v) => v,
1365            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1366        };
1367
1368        let op = if (1..=11).contains(&function_num) {
1369            aggregate_op_from_function_num(function_num)
1370                .expect("validated AGGREGATE function_num maps to operation")
1371        } else if (12..=19).contains(&function_num) {
1372            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1373                ExcelError::new(ExcelErrorKind::NImpl),
1374            )));
1375        } else {
1376            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1377                ExcelError::new_value(),
1378            )));
1379        };
1380
1381        let options = match parse_strict_int_arg(&args[1]) {
1382            Ok(v) => v,
1383            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1384        };
1385
1386        let (visibility, error_policy) = match options {
1387            0 => (VisibilityPolicy::IncludeAll, ErrorPolicy::Propagate),
1388            1 => (
1389                VisibilityPolicy::ExcludeManualOrFilterHidden,
1390                ErrorPolicy::Propagate,
1391            ),
1392            2 => (VisibilityPolicy::IncludeAll, ErrorPolicy::Ignore),
1393            3 => (
1394                VisibilityPolicy::ExcludeManualOrFilterHidden,
1395                ErrorPolicy::Ignore,
1396            ),
1397            4..=7 => {
1398                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1399                    ExcelError::new(ExcelErrorKind::NImpl),
1400                )));
1401            }
1402            _ => {
1403                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1404                    ExcelError::new_value(),
1405                )));
1406            }
1407        };
1408
1409        let collected =
1410            match AggregateCollector::collect_args(args, 2, ctx, op, visibility, error_policy) {
1411                Ok(c) => c,
1412                Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1413            };
1414
1415        Ok(crate::traits::CalcValue::Scalar(collected.finalize(op)))
1416    }
1417}
1418
#[cfg(test)]
mod tests_subtotal_aggregate {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_common::{ExcelErrorKind, LiteralValue};
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Wraps a literal as an AST argument node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Dispatches `fn_name` over `nodes` and returns the resulting literal.
    fn dispatch(
        ctx: &crate::interpreter::Interpreter<'_>,
        fn_name: &str,
        nodes: &[ASTNode],
    ) -> LiteralValue {
        let args: Vec<_> = nodes.iter().map(|n| ArgumentHandle::new(n, ctx)).collect();
        let f = ctx.context.get_function("", fn_name).expect("function");
        f.dispatch(&args, &ctx.function_context(None))
            .expect("dispatch")
            .into_literal()
    }

    fn assert_num_close(value: LiteralValue, expected: f64) {
        match value {
            LiteralValue::Number(n) => assert!((n - expected).abs() < 1e-9, "{n} != {expected}"),
            LiteralValue::Int(i) => assert!(((i as f64) - expected).abs() < 1e-9),
            other => panic!("expected numeric {expected}, got {other:?}"),
        }
    }

    fn assert_error_kind(value: LiteralValue, expected: ExcelErrorKind) {
        match value {
            LiteralValue::Error(e) => assert_eq!(e.kind, expected),
            other => panic!("expected error {expected:?}, got {other:?}"),
        }
    }

    #[test]
    fn subtotal_function_num_mapping_basics() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);
        let values = LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);

        let cases: &[(i64, f64)] = &[
            (1, 2.0),
            (2, 3.0),
            (3, 3.0),
            (4, 3.0),
            (5, 1.0),
            (6, 6.0),
            (7, 1.0),
            (8, (2.0f64 / 3.0).sqrt()),
            (9, 6.0),
            (10, 1.0),
            (11, 2.0 / 3.0),
        ];

        for (code, expected) in cases {
            let args = vec![lit(LiteralValue::Int(*code)), lit(values.clone())];
            let out = dispatch(&ctx, "SUBTOTAL", &args);
            assert_num_close(out, *expected);
        }
    }

    #[test]
    fn subtotal_counta_counts_errors_as_non_empty() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);

        let args = vec![
            lit(LiteralValue::Int(3)),
            lit(LiteralValue::Array(vec![vec![
                LiteralValue::Int(1),
                LiteralValue::Error(ExcelError::new_div()),
                LiteralValue::Text("x".into()),
                LiteralValue::Text("".into()),
            ]])),
        ];
        let out = dispatch(&ctx, "SUBTOTAL", &args);
        assert_num_close(out, 4.0);
    }

    #[test]
    fn subtotal_sum_propagates_value_errors() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);

        let args = vec![
            lit(LiteralValue::Int(9)),
            lit(LiteralValue::Array(vec![vec![
                LiteralValue::Int(1),
                LiteralValue::Error(ExcelError::new_div()),
            ]])),
        ];
        let out = dispatch(&ctx, "SUBTOTAL", &args);
        assert_error_kind(out, ExcelErrorKind::Div);
    }

    #[test]
    fn subtotal_invalid_function_num_returns_value_error() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);

        let args = vec![
            lit(LiteralValue::Number(9.5)),
            lit(LiteralValue::Array(vec![vec![LiteralValue::Int(1)]])),
        ];
        let out = dispatch(&ctx, "SUBTOTAL", &args);
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn subtotal_out_of_range_function_nums_return_value_error() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);

        // Just outside both accepted bands: 1..=11 and 101..=111.
        for code in [0, 12, 100, 112] {
            let args = vec![
                lit(LiteralValue::Int(code)),
                lit(LiteralValue::Array(vec![vec![LiteralValue::Int(1)]])),
            ];
            let out = dispatch(&ctx, "SUBTOTAL", &args);
            assert_error_kind(out, ExcelErrorKind::Value);
        }
    }

    #[test]
    fn subtotal_requires_ref_argument() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);

        let out = dispatch(&ctx, "SUBTOTAL", &[lit(LiteralValue::Int(9))]);
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn aggregate_requires_options_and_ref_argument() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        let out = dispatch(
            &ctx,
            "AGGREGATE",
            &[lit(LiteralValue::Int(9)), lit(LiteralValue::Int(0))],
        );
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn aggregate_options_zero_to_three_control_error_behavior() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);
        let values = LiteralValue::Array(vec![vec![
            LiteralValue::Int(10),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Int(30),
        ]]);

        let opt0 = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(9)),
                lit(LiteralValue::Int(0)),
                lit(values.clone()),
            ],
        );
        assert_error_kind(opt0, ExcelErrorKind::Div);

        let opt1 = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(9)),
                lit(LiteralValue::Int(1)),
                lit(values.clone()),
            ],
        );
        assert_error_kind(opt1, ExcelErrorKind::Div);

        let opt2 = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(9)),
                lit(LiteralValue::Int(2)),
                lit(values.clone()),
            ],
        );
        assert_num_close(opt2, 40.0);

        let opt3 = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(9)),
                lit(LiteralValue::Int(3)),
                lit(values),
            ],
        );
        assert_num_close(opt3, 40.0);
    }

    #[test]
    fn aggregate_counta_option_ignore_errors_skips_error_values() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        let out = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(3)),
                lit(LiteralValue::Int(2)),
                lit(LiteralValue::Array(vec![vec![
                    LiteralValue::Int(1),
                    LiteralValue::Error(ExcelError::new_div()),
                    LiteralValue::Text("x".into()),
                ]])),
            ],
        );
        assert_num_close(out, 2.0);
    }

    #[test]
    fn aggregate_unsupported_option_returns_nimpl() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        let out = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(9)),
                lit(LiteralValue::Int(4)),
                lit(LiteralValue::Array(vec![vec![LiteralValue::Int(1)]])),
            ],
        );
        assert_error_kind(out, ExcelErrorKind::NImpl);
    }

    #[test]
    fn aggregate_unsupported_function_num_returns_nimpl() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        let out = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(12)),
                lit(LiteralValue::Int(0)),
                lit(LiteralValue::Array(vec![vec![LiteralValue::Int(1)]])),
            ],
        );
        assert_error_kind(out, ExcelErrorKind::NImpl);
    }

    #[test]
    fn aggregate_out_of_range_function_num_or_option_returns_value_error() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        // (function_num, options) pairs rejected outright, not reported as unimplemented.
        for (code, option) in [(0, 0), (20, 0), (9, 8), (9, -1)] {
            let out = dispatch(
                &ctx,
                "AGGREGATE",
                &[
                    lit(LiteralValue::Int(code)),
                    lit(LiteralValue::Int(option)),
                    lit(LiteralValue::Array(vec![vec![LiteralValue::Int(1)]])),
                ],
            );
            assert_error_kind(out, ExcelErrorKind::Value);
        }
    }
}
1656
1657pub fn register_builtins() {
1658    crate::function_registry::register_function(std::sync::Arc::new(SumProductFn));
1659    crate::function_registry::register_function(std::sync::Arc::new(SumFn));
1660    crate::function_registry::register_function(std::sync::Arc::new(CountFn));
1661    crate::function_registry::register_function(std::sync::Arc::new(AverageFn));
1662    crate::function_registry::register_function(std::sync::Arc::new(SubtotalFn));
1663    crate::function_registry::register_function(std::sync::Arc::new(AggregateFn));
1664}