Skip to main content

formualizer_eval/builtins/math/
aggregate.rs

1use super::super::utils::{ARG_RANGE_NUM_LENIENT_ONE, coerce_num};
2use crate::args::ArgSchema;
3use crate::engine::VisibilityMaskMode;
4use crate::function::Function;
5use crate::function_contract::FunctionDependencyContract;
6use crate::traits::{ArgumentHandle, FunctionContext};
7use arrow_array::Array;
8use formualizer_common::{ExcelError, ExcelErrorKind, LiteralValue};
9use formualizer_macros::func_caps;
10
11/* ─────────────────────────── SUM() ──────────────────────────── */
12
13#[derive(Debug)]
14pub struct SumFn;
15
16/// Adds numeric values across scalars and ranges.
17///
18/// `SUM` evaluates all arguments, coercing text to numbers where possible,
19/// and returns the total. Blank cells and logical values in ranges are ignored.
20///
21/// # Remarks
22/// - If any argument evaluates to an error, `SUM` propagates the first error it encounters.
23/// - Unparseable text literals (e.g., `"foo"`) will result in a `#VALUE!` error.
24///
25/// # Examples
26///
27/// ```yaml,sandbox
28/// title: "Basic scalar addition"
29/// formula: "=SUM(10, 20, 5)"
30/// expected: 35
31/// ```
32///
33/// ```yaml,sandbox
34/// title: "Summing a range"
35/// grid:
36///   A1: 10
37///   A2: 20
38///   A3: "N/A"
39/// formula: "=SUM(A1:A3)"
40/// expected: 30
41/// ```
42///
43/// ```yaml,docs
44/// related:
45///   - SUMIF
46///   - SUMIFS
47///   - SUMPRODUCT
48///   - AVERAGE
49/// faq:
50///   - q: "Why does SUM return #VALUE! for some text arguments?"
51///     a: "Direct scalar text that cannot be parsed as a number raises #VALUE! during coercion."
52///   - q: "Do text and logical values inside ranges get added?"
53///     a: "No. In ranged inputs, only numeric cells contribute to the total."
54/// ```
55///
56/// [formualizer-docgen:schema:start]
57/// Name: SUM
58/// Type: SumFn
59/// Min args: 0
60/// Max args: variadic
61/// Variadic: true
62/// Signature: SUM(arg1...: number@range)
63/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
64/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS
65/// [formualizer-docgen:schema:end]
66impl Function for SumFn {
67    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS);
68
69    fn name(&self) -> &'static str {
70        "SUM"
71    }
72    fn min_args(&self) -> usize {
73        0
74    }
75    fn variadic(&self) -> bool {
76        true
77    }
78    fn dependency_contract(&self, arity: usize) -> Option<FunctionDependencyContract> {
79        FunctionDependencyContract::static_reduction(arity, self.min_args())
80    }
81    fn arg_schema(&self) -> &'static [ArgSchema] {
82        &ARG_RANGE_NUM_LENIENT_ONE[..]
83    }
84
85    fn eval<'a, 'b, 'c>(
86        &self,
87        args: &'c [ArgumentHandle<'a, 'b>],
88        ctx: &dyn FunctionContext<'b>,
89    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
90        let mut total = 0.0;
91        for arg in args {
92            if let Ok(view) = arg.range_view() {
93                // Propagate errors from range first
94                for res in view.errors_slices() {
95                    let (_, _, err_cols) = res?;
96                    for col in err_cols {
97                        if col.null_count() < col.len() {
98                            for i in 0..col.len() {
99                                if !col.is_null(i) {
100                                    return Ok(crate::traits::CalcValue::Scalar(
101                                        LiteralValue::Error(ExcelError::new(
102                                            crate::arrow_store::unmap_error_code(col.value(i)),
103                                        )),
104                                    ));
105                                }
106                            }
107                        }
108                    }
109                }
110
111                for res in view.numbers_slices() {
112                    let (_, _, num_cols) = res?;
113                    for col in num_cols {
114                        total +=
115                            arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
116                    }
117                }
118            } else {
119                let v = arg.value()?.into_literal();
120                match v {
121                    LiteralValue::Error(e) => {
122                        return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
123                    }
124                    v => total += coerce_num(&v)?,
125                }
126            }
127        }
128        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
129            total,
130        )))
131    }
132}
133
134/* ─────────────────────────── COUNT() ──────────────────────────── */
135
136#[derive(Debug)]
137pub struct CountFn;
138
139/// Counts numeric values across scalars and ranges.
140///
141/// `COUNT` evaluates all arguments and counts how many are numeric values.
142/// Numbers, dates, and text representations of numbers (when supplied directly) are counted.
143///
144/// # Remarks
145/// - Text values inside ranges are ignored and not counted.
146/// - Blank cells and logical values in ranges are ignored.
147///
148/// # Examples
149///
150/// ```yaml,sandbox
151/// title: "Counting mixed scalar inputs"
152/// formula: "=COUNT(1, \"x\", 2, 3)"
153/// expected: 3
154/// ```
155///
156/// ```yaml,sandbox
157/// title: "Counting in a range"
158/// grid:
159///   A1: 10
160///   A2: "foo"
161///   A3: 20
162/// formula: "=COUNT(A1:A3)"
163/// expected: 2
164/// ```
165///
166/// ```yaml,docs
167/// related:
168///   - COUNTA
169///   - COUNTBLANK
170///   - COUNTIF
171///   - COUNTIFS
172/// faq:
173///   - q: "Why doesn't COUNT include text in a range?"
174///     a: "COUNT only counts numeric values; text cells in ranges are ignored."
175///   - q: "Can direct text like \"12\" be counted?"
176///     a: "Yes. Direct scalar arguments are coerced and counted when they parse as numbers."
177/// ```
178///
179/// [formualizer-docgen:schema:start]
180/// Name: COUNT
181/// Type: CountFn
182/// Min args: 0
183/// Max args: variadic
184/// Variadic: true
185/// Signature: COUNT(arg1...: number@range)
186/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
187/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
188/// [formualizer-docgen:schema:end]
189impl Function for CountFn {
190    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
191
192    fn name(&self) -> &'static str {
193        "COUNT"
194    }
195    fn min_args(&self) -> usize {
196        0
197    }
198    fn variadic(&self) -> bool {
199        true
200    }
201    fn dependency_contract(&self, arity: usize) -> Option<FunctionDependencyContract> {
202        FunctionDependencyContract::static_reduction(arity, self.min_args())
203    }
204    fn arg_schema(&self) -> &'static [ArgSchema] {
205        &ARG_RANGE_NUM_LENIENT_ONE[..]
206    }
207
208    fn eval<'a, 'b, 'c>(
209        &self,
210        args: &'c [ArgumentHandle<'a, 'b>],
211        _: &dyn FunctionContext<'b>,
212    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
213        let mut count: i64 = 0;
214        for arg in args {
215            if let Ok(view) = arg.range_view() {
216                for res in view.numbers_slices() {
217                    let (_, _, num_cols) = res?;
218                    for col in num_cols {
219                        count += (col.len() - col.null_count()) as i64;
220                    }
221                }
222            } else {
223                let v = arg.value()?.into_literal();
224                if let LiteralValue::Error(e) = v {
225                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
226                }
227                if !matches!(v, LiteralValue::Empty) && coerce_num(&v).is_ok() {
228                    count += 1;
229                }
230            }
231        }
232        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
233            count as f64,
234        )))
235    }
236}
237
238/* ─────────────────────────── AVERAGE() ──────────────────────────── */
239
240#[derive(Debug)]
241pub struct AverageFn;
242
243/// Returns the arithmetic mean of numeric values across scalars and ranges.
244///
245/// `AVERAGE` sums numeric inputs and divides by the count of numeric values that participated.
246///
247/// # Remarks
248/// - Errors in any scalar argument or referenced range propagate immediately.
249/// - In ranges, only numeric/date-time serial values are included; text and blanks are ignored.
250/// - Scalar arguments use lenient number coercion with locale support.
251/// - If no numeric values are found, `AVERAGE` returns `#DIV/0!`.
252///
253/// # Examples
254///
255/// ```yaml,sandbox
256/// title: "Average of scalar values"
257/// formula: "=AVERAGE(10, 20, 5)"
258/// expected: 11.666666666666666
259/// ```
260///
261/// ```yaml,sandbox
262/// title: "Average over a mixed range"
263/// grid:
264///   A1: 10
265///   A2: "x"
266///   A3: 20
267/// formula: "=AVERAGE(A1:A3)"
268/// expected: 15
269/// ```
270///
271/// ```yaml,sandbox
272/// title: "No numeric values returns divide-by-zero"
273/// formula: "=AVERAGE(\"x\", \"\")"
274/// expected: "#DIV/0!"
275/// ```
276///
277/// ```yaml,docs
278/// related:
279///   - SUM
280///   - COUNT
281///   - AVERAGEIF
282///   - AVERAGEIFS
283/// faq:
284///   - q: "When does AVERAGE return #DIV/0!?"
285///     a: "It returns #DIV/0! when no numeric values are found after filtering/coercion."
286///   - q: "Do text cells in ranges affect the denominator?"
287///     a: "No. Only numeric values are counted toward the divisor."
288/// ```
289///
290/// [formualizer-docgen:schema:start]
291/// Name: AVERAGE
292/// Type: AverageFn
293/// Min args: 1
294/// Max args: variadic
295/// Variadic: true
296/// Signature: AVERAGE(arg1...: number@range)
297/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
298/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
299/// [formualizer-docgen:schema:end]
300impl Function for AverageFn {
301    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
302
303    fn name(&self) -> &'static str {
304        "AVERAGE"
305    }
306    fn min_args(&self) -> usize {
307        1
308    }
309    fn variadic(&self) -> bool {
310        true
311    }
312    fn dependency_contract(&self, arity: usize) -> Option<FunctionDependencyContract> {
313        FunctionDependencyContract::static_reduction(arity, self.min_args())
314    }
315    fn arg_schema(&self) -> &'static [ArgSchema] {
316        &ARG_RANGE_NUM_LENIENT_ONE[..]
317    }
318
319    fn eval<'a, 'b, 'c>(
320        &self,
321        args: &'c [ArgumentHandle<'a, 'b>],
322        ctx: &dyn FunctionContext<'b>,
323    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
324        let mut sum = 0.0f64;
325        let mut cnt: i64 = 0;
326        for arg in args {
327            if let Ok(view) = arg.range_view() {
328                // Propagate errors from range first
329                for res in view.errors_slices() {
330                    let (_, _, err_cols) = res?;
331                    for col in err_cols {
332                        if col.null_count() < col.len() {
333                            for i in 0..col.len() {
334                                if !col.is_null(i) {
335                                    return Ok(crate::traits::CalcValue::Scalar(
336                                        LiteralValue::Error(ExcelError::new(
337                                            crate::arrow_store::unmap_error_code(col.value(i)),
338                                        )),
339                                    ));
340                                }
341                            }
342                        }
343                    }
344                }
345
346                for res in view.numbers_slices() {
347                    let (_, _, num_cols) = res?;
348                    for col in num_cols {
349                        sum += arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
350                        cnt += (col.len() - col.null_count()) as i64;
351                    }
352                }
353            } else {
354                let v = arg.value()?.into_literal();
355                if let LiteralValue::Error(e) = v {
356                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
357                }
358                if let Ok(n) = crate::coercion::to_number_lenient_with_locale(&v, &ctx.locale()) {
359                    sum += n;
360                    cnt += 1;
361                }
362            }
363        }
364        if cnt == 0 {
365            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
366                ExcelError::new_div(),
367            )));
368        }
369        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
370            sum / (cnt as f64),
371        )))
372    }
373}
374
375/* ──────────────────────── SUMPRODUCT() ───────────────────────── */
376
377#[derive(Debug)]
378pub struct SumProductFn;
379
380/// Multiplies aligned values across arrays and returns the sum of those products.
381///
382/// `SUMPRODUCT` supports scalar or range inputs, applies broadcast semantics, and accumulates
383/// the product for each aligned cell position.
384///
385/// # Remarks
386/// - Input shapes must be broadcast-compatible; otherwise `SUMPRODUCT` returns `#VALUE!`.
387/// - Non-numeric values are treated as `0` during multiplication.
388/// - Any explicit error value in the inputs propagates immediately.
389///
390/// # Examples
391///
392/// ```yaml,sandbox
393/// title: "Pairwise sum of products"
394/// formula: "=SUMPRODUCT({1,2,3}, {4,5,6})"
395/// expected: 32
396/// ```
397///
398/// ```yaml,sandbox
399/// title: "Range-based sumproduct"
400/// grid:
401///   A1: 2
402///   A2: 3
403///   A3: 4
404///   B1: 10
405///   B2: 20
406///   B3: 30
407/// formula: "=SUMPRODUCT(A1:A3, B1:B3)"
408/// expected: 200
409/// ```
410///
411/// ```yaml,sandbox
412/// title: "Text entries contribute zero"
413/// formula: "=SUMPRODUCT({1,\"x\",3}, {1,1,1})"
414/// expected: 4
415/// ```
416///
417/// ```yaml,docs
418/// related:
419///   - SUM
420///   - PRODUCT
421///   - MMULT
422///   - SUMIFS
423/// faq:
424///   - q: "Why does SUMPRODUCT return #VALUE! with some array shapes?"
425///     a: "The argument arrays must be broadcast-compatible; incompatible shapes raise #VALUE!."
426///   - q: "How are text values handled in multiplication?"
427///     a: "Non-numeric values are treated as 0, unless an explicit error is present."
428/// ```
429///
430/// [formualizer-docgen:schema:start]
431/// Name: SUMPRODUCT
432/// Type: SumProductFn
433/// Min args: 1
434/// Max args: variadic
435/// Variadic: true
436/// Signature: SUMPRODUCT(arg1...: number@range)
437/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
438/// Caps: PURE, REDUCTION
439/// [formualizer-docgen:schema:end]
440impl Function for SumProductFn {
441    // Pure reduction over arrays; uses broadcasting and lenient coercion
442    func_caps!(PURE, REDUCTION);
443
444    fn name(&self) -> &'static str {
445        "SUMPRODUCT"
446    }
447    fn min_args(&self) -> usize {
448        1
449    }
450    fn variadic(&self) -> bool {
451        true
452    }
453    fn arg_schema(&self) -> &'static [ArgSchema] {
454        // Accept ranges or scalars; numeric lenient coercion
455        &ARG_RANGE_NUM_LENIENT_ONE[..]
456    }
457
458    fn eval<'a, 'b, 'c>(
459        &self,
460        args: &'c [ArgumentHandle<'a, 'b>],
461        _: &dyn FunctionContext<'b>,
462    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
463        use crate::broadcast::{broadcast_shape, project_index};
464
465        if args.is_empty() {
466            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(0.0)));
467        }
468
469        // Helper: materialize an argument to a 2D array of LiteralValue
470        let to_array = |ah: &ArgumentHandle| -> Result<Vec<Vec<LiteralValue>>, ExcelError> {
471            if let Ok(rv) = ah.range_view() {
472                let mut rows: Vec<Vec<LiteralValue>> = Vec::new();
473                rv.for_each_row(&mut |row| {
474                    rows.push(row.to_vec());
475                    Ok(())
476                })?;
477                Ok(rows)
478            } else {
479                let v = ah.value()?.into_literal();
480                Ok(match v {
481                    LiteralValue::Array(arr) => arr,
482                    other => vec![vec![other]],
483                })
484            }
485        };
486
487        // Collect arrays and shapes
488        let mut arrays: Vec<Vec<Vec<LiteralValue>>> = Vec::with_capacity(args.len());
489        let mut shapes: Vec<(usize, usize)> = Vec::with_capacity(args.len());
490        for a in args.iter() {
491            let arr = to_array(a)?;
492            let shape = (arr.len(), arr.first().map(|r| r.len()).unwrap_or(0));
493            arrays.push(arr);
494            shapes.push(shape);
495        }
496
497        // Compute broadcast target shape across all args
498        let target = match broadcast_shape(&shapes) {
499            Ok(s) => s,
500            Err(_) => {
501                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
502                    ExcelError::new_value(),
503                )));
504            }
505        };
506
507        // Iterate target shape, multiply coerced values across args, sum total
508        let mut total = 0.0f64;
509        for r in 0..target.0 {
510            for c in 0..target.1 {
511                let mut prod = 1.0f64;
512                for (arr, &shape) in arrays.iter().zip(shapes.iter()) {
513                    let (rr, cc) = project_index((r, c), shape);
514                    let lv = arr
515                        .get(rr)
516                        .and_then(|row| row.get(cc))
517                        .cloned()
518                        .unwrap_or(LiteralValue::Empty);
519                    match lv {
520                        LiteralValue::Error(e) => {
521                            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
522                        }
523                        _ => match super::super::utils::coerce_num(&lv) {
524                            Ok(n) => {
525                                prod *= n;
526                            }
527                            Err(_) => {
528                                // Non-numeric -> treated as 0 in SUMPRODUCT
529                                prod *= 0.0;
530                            }
531                        },
532                    }
533                }
534                total += prod;
535            }
536        }
537        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
538            total,
539        )))
540    }
541}
542
#[cfg(test)]
mod tests_sumproduct {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Builds an array-literal AST node from rows of values.
    fn arr(vals: Vec<Vec<LiteralValue>>) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Array(vals)), None)
    }

    /// Builds a number-literal AST node.
    fn num(n: f64) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None)
    }

    /// Evaluates SUMPRODUCT over `nodes` in a fresh workbook and returns the literal result.
    fn eval_sumproduct(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumProductFn));
        let ctx = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &ctx))
            .collect();
        let f = ctx.context.get_function("", "SUMPRODUCT").unwrap();
        let fctx = ctx.function_context(None);
        f.dispatch(&args, &fctx).unwrap().into_literal()
    }

    #[test]
    fn sumproduct_basic_pairwise() {
        // {1,2,3} * {4,5,6} = 1*4 + 2*5 + 3*6 = 32
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(4),
            LiteralValue::Int(5),
            LiteralValue::Int(6),
        ]]);
        assert_eq!(eval_sumproduct(&[&a, &b]), LiteralValue::Number(32.0));
    }

    #[test]
    fn sumproduct_variadic_three_arrays() {
        // {1,2} * {3,4} * {2,2} = (1*3*2) + (2*4*2) = 6 + 16 = 22
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let b = arr(vec![vec![LiteralValue::Int(3), LiteralValue::Int(4)]]);
        let c = arr(vec![vec![LiteralValue::Int(2), LiteralValue::Int(2)]]);
        assert_eq!(eval_sumproduct(&[&a, &b, &c]), LiteralValue::Number(22.0));
    }

    #[test]
    fn sumproduct_broadcast_scalar_over_array() {
        // {1,2,3} * 10 => (1*10 + 2*10 + 3*10) = 60
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let s = num(10.0);
        assert_eq!(eval_sumproduct(&[&a, &s]), LiteralValue::Number(60.0));
    }

    #[test]
    fn sumproduct_2d_arrays_broadcast_rows_cols() {
        // A is 2x2, B is 1x2 -> B is broadcast across rows.
        // A = [[1,2],[3,4]], B = [[10,20]]
        // sum = 1*10 + 2*20 + 3*10 + 4*20 = 10 + 40 + 30 + 80 = 160
        let a = arr(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]);
        let b = arr(vec![vec![LiteralValue::Int(10), LiteralValue::Int(20)]]);
        assert_eq!(eval_sumproduct(&[&a, &b]), LiteralValue::Number(160.0));
    }

    #[test]
    fn sumproduct_non_numeric_treated_as_zero() {
        // {1,"x",3} * {1,1,1} => 1*1 + 0*1 + 3*1 = 4
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(1),
            LiteralValue::Int(1),
        ]]);
        assert_eq!(eval_sumproduct(&[&a, &b]), LiteralValue::Number(4.0));
    }

    #[test]
    fn sumproduct_error_in_input_propagates() {
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let e = ASTNode::new(
            ASTNodeType::Literal(LiteralValue::Error(ExcelError::new_na())),
            None,
        );
        match eval_sumproduct(&[&a, &e]) {
            LiteralValue::Error(err) => assert_eq!(err, "#N/A"),
            v => panic!("expected error, got {v:?}"),
        }
    }

    #[test]
    fn sumproduct_incompatible_shapes_value_error() {
        // 1x3 and 1x2 -> #VALUE!
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![LiteralValue::Int(4), LiteralValue::Int(5)]]);
        match eval_sumproduct(&[&a, &b]) {
            LiteralValue::Error(e) => assert_eq!(e, "#VALUE!"),
            v => panic!("expected value error, got {v:?}"),
        }
    }
}
722
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::FnCaps;
    use crate::test_workbook::TestWorkbook;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Wraps a number in a literal AST node.
    fn number_node(n: f64) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None)
    }

    #[test]
    fn test_sum_caps() {
        let caps = SumFn.caps();

        // Every capability declared through `func_caps!` on SumFn must be reported,
        // including PARALLEL_ARGS, which was previously unchecked.
        assert!(caps.contains(FnCaps::PURE));
        assert!(caps.contains(FnCaps::REDUCTION));
        assert!(caps.contains(FnCaps::NUMERIC_ONLY));
        assert!(caps.contains(FnCaps::STREAM_OK));
        assert!(caps.contains(FnCaps::PARALLEL_ARGS));

        // Capabilities SUM does not declare must stay unset.
        assert!(!caps.contains(FnCaps::VOLATILE));
        assert!(!caps.contains(FnCaps::ELEMENTWISE));
    }

    #[test]
    fn test_sum_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumFn));
        let ctx = interp(&wb);
        let fctx = ctx.function_context(None);

        // SUM(1, 2, 3) via hand-built argument handles.
        let nodes = [number_node(1.0), number_node(2.0), number_node(3.0)];
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &ctx))
            .collect();

        let sum_fn = ctx.context.get_function("", "SUM").unwrap();
        let result = sum_fn.dispatch(&args, &fctx).unwrap().into_literal();
        assert_eq!(result, LiteralValue::Number(6.0));
    }
}
780
#[cfg(test)]
mod tests_count {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::ASTNode;
    use formualizer_parse::parser::ASTNodeType;

    /// Wraps a literal value in an AST node.
    fn literal(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Runs COUNT over `nodes` in a fresh workbook and returns the literal result.
    fn eval_count(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountFn));
        let ctx = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &ctx))
            .collect();
        let f = ctx.context.get_function("", "COUNT").unwrap();
        let fctx = ctx.function_context(None);
        f.dispatch(&args, &fctx).unwrap().into_literal()
    }

    #[test]
    fn count_numbers_ignores_text() {
        // COUNT({1,2,"x",3}) => 3
        let node = literal(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]));
        assert_eq!(eval_count(&[&node]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_multiple_args_and_scalars() {
        let a = literal(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        let n1 = literal(LiteralValue::Int(10));
        let n2 = literal(LiteralValue::Text("n".into()));
        // Two from the array + the scalar 10 = 3; the text scalar is not counted.
        assert_eq!(eval_count(&[&a, &n1, &n2]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_direct_error_argument_propagates() {
        let err = literal(LiteralValue::Error(ExcelError::from_error_string(
            "#DIV/0!",
        )));
        match eval_count(&[&err]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
856
#[cfg(test)]
mod tests_average {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::ASTNode;
    use formualizer_parse::parser::ASTNodeType;

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Workbook that registers only `AVERAGE`.
    fn average_workbook() -> TestWorkbook {
        TestWorkbook::new().with_function(std::sync::Arc::new(AverageFn))
    }

    /// Literal node holding a single-row inline array of `cells`.
    fn row_node(cells: Vec<LiteralValue>) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Array(vec![cells])), None)
    }

    /// Evaluates `AVERAGE(node)` and returns its scalar result.
    fn run_average(ctx: &crate::interpreter::Interpreter<'_>, node: &ASTNode) -> LiteralValue {
        let args = vec![ArgumentHandle::new(node, ctx)];
        let f = ctx.context.get_function("", "AVERAGE").unwrap();
        f.dispatch(&args, &ctx.function_context(None))
            .unwrap()
            .into_literal()
    }

    #[test]
    fn average_basic_numbers() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let node = row_node(vec![
            LiteralValue::Int(2),
            LiteralValue::Int(4),
            LiteralValue::Int(6),
        ]);
        assert_eq!(run_average(&ctx, &node), LiteralValue::Number(4.0));
    }

    #[test]
    fn average_mixed_with_text() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let node = row_node(vec![
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(6),
        ]);
        // The text element is skipped, leaving the mean of 2 and 6.
        assert_eq!(run_average(&ctx, &node), LiteralValue::Number(4.0));
    }

    #[test]
    fn average_no_numeric_div0() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let node = row_node(vec![
            LiteralValue::Text("a".into()),
            LiteralValue::Text("b".into()),
        ]);
        match run_average(&ctx, &node) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("expected #DIV/0!, got {v:?}"),
        }
    }

    #[test]
    fn average_direct_error_argument_propagates() {
        let wb = average_workbook();
        let ctx = interp(&wb);
        let div0 = ExcelError::from_error_string("#DIV/0!");
        let node = ASTNode::new(ASTNodeType::Literal(LiteralValue::Error(div0)), None);
        match run_average(&ctx, &node) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
948
/// Which rows of a range argument take part in an aggregation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VisibilityPolicy {
    /// Every row of the range is aggregated.
    IncludeAll,
    /// Rows that the row-visibility mask reports as hidden are skipped.
    ExcludeManualOrFilterHidden,
}
954
/// How error values met in the aggregated data are treated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ErrorPolicy {
    /// Error values become the result (individual operations may make exceptions).
    Propagate,
    /// Error values are skipped.
    Ignore,
}
960
/// Aggregation selected by the shared SUBTOTAL/AGGREGATE function numbers 1–11.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AggregateOp {
    /// Code 1: arithmetic mean.
    Average,
    /// Code 2: number of numeric values.
    Count,
    /// Code 3: number of non-empty values.
    CountA,
    /// Code 4: largest numeric value.
    Max,
    /// Code 5: smallest numeric value.
    Min,
    /// Code 6: product of numeric values.
    Product,
    /// Code 7: sample standard deviation.
    StdevSample,
    /// Code 8: population standard deviation.
    StdevPopulation,
    /// Code 9: sum of numeric values.
    Sum,
    /// Code 10: sample variance.
    VarSample,
    /// Code 11: population variance.
    VarPopulation,
}
975
976fn aggregate_op_from_function_num(function_num: i32) -> Option<AggregateOp> {
977    match function_num {
978        1 => Some(AggregateOp::Average),
979        2 => Some(AggregateOp::Count),
980        3 => Some(AggregateOp::CountA),
981        4 => Some(AggregateOp::Max),
982        5 => Some(AggregateOp::Min),
983        6 => Some(AggregateOp::Product),
984        7 => Some(AggregateOp::StdevSample),
985        8 => Some(AggregateOp::StdevPopulation),
986        9 => Some(AggregateOp::Sum),
987        10 => Some(AggregateOp::VarSample),
988        11 => Some(AggregateOp::VarPopulation),
989        _ => None,
990    }
991}
992
993fn parse_strict_int_arg(arg: &ArgumentHandle<'_, '_>) -> Result<i32, ExcelError> {
994    let raw = arg.value()?.into_literal();
995    if let LiteralValue::Error(e) = raw {
996        return Err(e);
997    }
998
999    let n = coerce_num(&raw)?;
1000    if !n.is_finite() {
1001        return Err(ExcelError::new_value());
1002    }
1003
1004    let rounded = n.round();
1005    if (n - rounded).abs() > 1e-9 {
1006        return Err(ExcelError::new_value());
1007    }
1008
1009    if rounded < i32::MIN as f64 || rounded > i32::MAX as f64 {
1010        return Err(ExcelError::new_value());
1011    }
1012
1013    Ok(rounded as i32)
1014}
1015
1016fn row_is_visible(mask: Option<&arrow_array::BooleanArray>, relative_row: usize) -> bool {
1017    let Some(mask) = mask else {
1018        return true;
1019    };
1020
1021    if relative_row >= mask.len() || mask.is_null(relative_row) {
1022        return true;
1023    }
1024
1025    mask.value(relative_row)
1026}
1027
1028fn numeric_from_range_value(value: &LiteralValue) -> Option<f64> {
1029    match value {
1030        LiteralValue::Number(n) => Some(*n),
1031        LiteralValue::Int(i) => Some(*i as f64),
1032        LiteralValue::Date(_)
1033        | LiteralValue::DateTime(_)
1034        | LiteralValue::Time(_)
1035        | LiteralValue::Duration(_) => coerce_num(value).ok(),
1036        _ => None,
1037    }
1038}
1039
/// Running state gathered while scanning SUBTOTAL/AGGREGATE arguments.
#[derive(Default, Debug)]
struct AggregateCollector {
    /// Numbers feeding the numeric operations (for COUNT only their count matters).
    numeric_values: Vec<f64>,
    /// Number of non-empty values seen, reported by COUNTA.
    counta: usize,
}
1045
1046impl AggregateCollector {
1047    fn collect_args<'a, 'b>(
1048        args: &[ArgumentHandle<'a, 'b>],
1049        start_idx: usize,
1050        ctx: &dyn FunctionContext<'b>,
1051        op: AggregateOp,
1052        visibility_policy: VisibilityPolicy,
1053        error_policy: ErrorPolicy,
1054    ) -> Result<Self, ExcelError> {
1055        let mut out = Self::default();
1056
1057        for arg in args.iter().skip(start_idx) {
1058            if let Ok(view) = arg.range_view() {
1059                out.collect_range_arg(&view, ctx, op, visibility_policy, error_policy)?;
1060            } else {
1061                out.consume_scalar_value(arg.value()?.into_literal(), op, error_policy)?;
1062            }
1063        }
1064
1065        Ok(out)
1066    }
1067
1068    fn collect_range_arg<'b>(
1069        &mut self,
1070        view: &crate::engine::range_view::RangeView<'_>,
1071        ctx: &dyn FunctionContext<'b>,
1072        op: AggregateOp,
1073        visibility_policy: VisibilityPolicy,
1074        error_policy: ErrorPolicy,
1075    ) -> Result<(), ExcelError> {
1076        let visibility_mask = match visibility_policy {
1077            VisibilityPolicy::IncludeAll => None,
1078            VisibilityPolicy::ExcludeManualOrFilterHidden => {
1079                ctx.get_row_visibility_mask(view, VisibilityMaskMode::ExcludeManualOrFilterHidden)
1080            }
1081        };
1082
1083        let (_, cols) = view.dims();
1084        if cols == 0 {
1085            return Ok(());
1086        }
1087
1088        for chunk in view.iter_row_chunks() {
1089            let chunk = chunk?;
1090            for row_offset in 0..chunk.row_len {
1091                let rel_row = chunk.row_start + row_offset;
1092                if !row_is_visible(visibility_mask.as_deref(), rel_row) {
1093                    continue;
1094                }
1095
1096                for col in 0..cols {
1097                    // Phase-1 contract: nested SUBTOTAL/AGGREGATE exclusion is deferred.
1098                    // Nested aggregate results are treated as ordinary scalar values.
1099                    self.consume_range_value(view.get_cell(rel_row, col), op, error_policy)?;
1100                }
1101            }
1102        }
1103
1104        Ok(())
1105    }
1106
1107    fn consume_range_value(
1108        &mut self,
1109        value: LiteralValue,
1110        op: AggregateOp,
1111        error_policy: ErrorPolicy,
1112    ) -> Result<(), ExcelError> {
1113        match value {
1114            LiteralValue::Error(e) => {
1115                if op == AggregateOp::CountA {
1116                    if error_policy == ErrorPolicy::Ignore {
1117                        return Ok(());
1118                    }
1119                    self.counta += 1;
1120                    return Ok(());
1121                }
1122                match error_policy {
1123                    ErrorPolicy::Propagate => Err(e),
1124                    ErrorPolicy::Ignore => Ok(()),
1125                }
1126            }
1127            LiteralValue::Empty => Ok(()),
1128            other => {
1129                self.counta += 1;
1130                if let Some(n) = numeric_from_range_value(&other) {
1131                    self.numeric_values.push(n);
1132                }
1133                Ok(())
1134            }
1135        }
1136    }
1137
1138    fn consume_scalar_value(
1139        &mut self,
1140        value: LiteralValue,
1141        op: AggregateOp,
1142        error_policy: ErrorPolicy,
1143    ) -> Result<(), ExcelError> {
1144        match value {
1145            LiteralValue::Error(e) => {
1146                if op == AggregateOp::CountA {
1147                    if error_policy == ErrorPolicy::Ignore {
1148                        return Ok(());
1149                    }
1150                    self.counta += 1;
1151                    return Ok(());
1152                }
1153                match error_policy {
1154                    ErrorPolicy::Propagate => Err(e),
1155                    ErrorPolicy::Ignore => Ok(()),
1156                }
1157            }
1158            LiteralValue::Array(rows) => {
1159                for row in rows {
1160                    for cell in row {
1161                        self.consume_range_value(cell, op, error_policy)?;
1162                    }
1163                }
1164                Ok(())
1165            }
1166            other => {
1167                match op {
1168                    AggregateOp::CountA => {
1169                        if !matches!(other, LiteralValue::Empty) {
1170                            self.counta += 1;
1171                        }
1172                    }
1173                    AggregateOp::Count => {
1174                        if !matches!(other, LiteralValue::Empty) && coerce_num(&other).is_ok() {
1175                            self.numeric_values.push(0.0);
1176                        }
1177                    }
1178                    _ => {
1179                        if let Ok(n) = coerce_num(&other) {
1180                            self.numeric_values.push(n);
1181                        }
1182                    }
1183                }
1184                Ok(())
1185            }
1186        }
1187    }
1188
1189    fn variance(values: &[f64], sample: bool) -> Result<f64, ExcelError> {
1190        let n = values.len();
1191        if sample {
1192            if n < 2 {
1193                return Err(ExcelError::new_div());
1194            }
1195        } else if n == 0 {
1196            return Err(ExcelError::new_div());
1197        }
1198
1199        let mean = values.iter().copied().sum::<f64>() / (n as f64);
1200        let mut ss = 0.0;
1201        for value in values {
1202            let d = *value - mean;
1203            ss += d * d;
1204        }
1205
1206        if sample {
1207            Ok(ss / ((n - 1) as f64))
1208        } else {
1209            Ok(ss / (n as f64))
1210        }
1211    }
1212
1213    fn finalize(self, op: AggregateOp) -> LiteralValue {
1214        match op {
1215            AggregateOp::Average => {
1216                if self.numeric_values.is_empty() {
1217                    LiteralValue::Error(ExcelError::new_div())
1218                } else {
1219                    let sum = self.numeric_values.iter().copied().sum::<f64>();
1220                    LiteralValue::Number(sum / (self.numeric_values.len() as f64))
1221                }
1222            }
1223            AggregateOp::Count => LiteralValue::Number(self.numeric_values.len() as f64),
1224            AggregateOp::CountA => LiteralValue::Number(self.counta as f64),
1225            AggregateOp::Max => LiteralValue::Number(
1226                self.numeric_values
1227                    .iter()
1228                    .copied()
1229                    .reduce(f64::max)
1230                    .unwrap_or(0.0),
1231            ),
1232            AggregateOp::Min => LiteralValue::Number(
1233                self.numeric_values
1234                    .iter()
1235                    .copied()
1236                    .reduce(f64::min)
1237                    .unwrap_or(0.0),
1238            ),
1239            AggregateOp::Product => {
1240                if self.numeric_values.is_empty() {
1241                    LiteralValue::Number(0.0)
1242                } else {
1243                    LiteralValue::Number(self.numeric_values.iter().copied().product::<f64>())
1244                }
1245            }
1246            AggregateOp::StdevSample => match Self::variance(&self.numeric_values, true) {
1247                Ok(v) => LiteralValue::Number(v.sqrt()),
1248                Err(e) => LiteralValue::Error(e),
1249            },
1250            AggregateOp::StdevPopulation => match Self::variance(&self.numeric_values, false) {
1251                Ok(v) => LiteralValue::Number(v.sqrt()),
1252                Err(e) => LiteralValue::Error(e),
1253            },
1254            AggregateOp::Sum => LiteralValue::Number(self.numeric_values.iter().copied().sum()),
1255            AggregateOp::VarSample => match Self::variance(&self.numeric_values, true) {
1256                Ok(v) => LiteralValue::Number(v),
1257                Err(e) => LiteralValue::Error(e),
1258            },
1259            AggregateOp::VarPopulation => match Self::variance(&self.numeric_values, false) {
1260                Ok(v) => LiteralValue::Number(v),
1261                Err(e) => LiteralValue::Error(e),
1262            },
1263        }
1264    }
1265}
1266
/// Builtin implementing `SUBTOTAL(function_num, ref1, ...)`.
#[derive(Debug)]
pub struct SubtotalFn;
1269
1270/// [formualizer-docgen:schema:start]
1271/// Name: SUBTOTAL
1272/// Type: SubtotalFn
1273/// Min args: 2
1274/// Max args: variadic
1275/// Variadic: true
1276/// Signature: SUBTOTAL(arg1...: number@range)
1277/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
1278/// Caps: VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK
1279/// [formualizer-docgen:schema:end]
1280impl Function for SubtotalFn {
1281    func_caps!(VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
1282
1283    fn name(&self) -> &'static str {
1284        "SUBTOTAL"
1285    }
1286
1287    fn min_args(&self) -> usize {
1288        2
1289    }
1290
1291    fn variadic(&self) -> bool {
1292        true
1293    }
1294
1295    fn arg_schema(&self) -> &'static [ArgSchema] {
1296        &ARG_RANGE_NUM_LENIENT_ONE[..]
1297    }
1298
1299    fn eval<'a, 'b, 'c>(
1300        &self,
1301        args: &'c [ArgumentHandle<'a, 'b>],
1302        ctx: &dyn FunctionContext<'b>,
1303    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
1304        if args.len() < 2 {
1305            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1306                ExcelError::new_value(),
1307            )));
1308        }
1309
1310        let function_num = match parse_strict_int_arg(&args[0]) {
1311            Ok(v) => v,
1312            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1313        };
1314
1315        let (mapped_code, visibility) = if (1..=11).contains(&function_num) {
1316            (function_num, VisibilityPolicy::IncludeAll)
1317        } else if (101..=111).contains(&function_num) {
1318            (
1319                function_num - 100,
1320                VisibilityPolicy::ExcludeManualOrFilterHidden,
1321            )
1322        } else {
1323            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1324                ExcelError::new_value(),
1325            )));
1326        };
1327
1328        let Some(op) = aggregate_op_from_function_num(mapped_code) else {
1329            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1330                ExcelError::new_value(),
1331            )));
1332        };
1333
1334        let collected = match AggregateCollector::collect_args(
1335            args,
1336            1,
1337            ctx,
1338            op,
1339            visibility,
1340            ErrorPolicy::Propagate,
1341        ) {
1342            Ok(c) => c,
1343            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1344        };
1345
1346        Ok(crate::traits::CalcValue::Scalar(collected.finalize(op)))
1347    }
1348}
1349
/// Builtin implementing `AGGREGATE(function_num, options, ref1, ...)`.
#[derive(Debug)]
pub struct AggregateFn;
1352
1353/// [formualizer-docgen:schema:start]
1354/// Name: AGGREGATE
1355/// Type: AggregateFn
1356/// Min args: 3
1357/// Max args: variadic
1358/// Variadic: true
1359/// Signature: AGGREGATE(arg1...: number@range)
1360/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
1361/// Caps: VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK
1362/// [formualizer-docgen:schema:end]
1363impl Function for AggregateFn {
1364    func_caps!(VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
1365
1366    fn name(&self) -> &'static str {
1367        "AGGREGATE"
1368    }
1369
1370    fn min_args(&self) -> usize {
1371        3
1372    }
1373
1374    fn variadic(&self) -> bool {
1375        true
1376    }
1377
1378    fn arg_schema(&self) -> &'static [ArgSchema] {
1379        &ARG_RANGE_NUM_LENIENT_ONE[..]
1380    }
1381
1382    fn eval<'a, 'b, 'c>(
1383        &self,
1384        args: &'c [ArgumentHandle<'a, 'b>],
1385        ctx: &dyn FunctionContext<'b>,
1386    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
1387        if args.len() < 3 {
1388            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1389                ExcelError::new_value(),
1390            )));
1391        }
1392
1393        let function_num = match parse_strict_int_arg(&args[0]) {
1394            Ok(v) => v,
1395            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1396        };
1397
1398        let op = if (1..=11).contains(&function_num) {
1399            aggregate_op_from_function_num(function_num)
1400                .expect("validated AGGREGATE function_num maps to operation")
1401        } else if (12..=19).contains(&function_num) {
1402            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1403                ExcelError::new(ExcelErrorKind::NImpl),
1404            )));
1405        } else {
1406            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1407                ExcelError::new_value(),
1408            )));
1409        };
1410
1411        let options = match parse_strict_int_arg(&args[1]) {
1412            Ok(v) => v,
1413            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1414        };
1415
1416        let (visibility, error_policy) = match options {
1417            0 => (VisibilityPolicy::IncludeAll, ErrorPolicy::Propagate),
1418            1 => (
1419                VisibilityPolicy::ExcludeManualOrFilterHidden,
1420                ErrorPolicy::Propagate,
1421            ),
1422            2 => (VisibilityPolicy::IncludeAll, ErrorPolicy::Ignore),
1423            3 => (
1424                VisibilityPolicy::ExcludeManualOrFilterHidden,
1425                ErrorPolicy::Ignore,
1426            ),
1427            4..=7 => {
1428                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1429                    ExcelError::new(ExcelErrorKind::NImpl),
1430                )));
1431            }
1432            _ => {
1433                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1434                    ExcelError::new_value(),
1435                )));
1436            }
1437        };
1438
1439        let collected =
1440            match AggregateCollector::collect_args(args, 2, ctx, op, visibility, error_policy) {
1441                Ok(c) => c,
1442                Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1443            };
1444
1445        Ok(crate::traits::CalcValue::Scalar(collected.finalize(op)))
1446    }
1447}
1448
#[cfg(test)]
mod tests_subtotal_aggregate {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_common::{ExcelErrorKind, LiteralValue};
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Workbook that registers only `SUBTOTAL`.
    fn subtotal_workbook() -> TestWorkbook {
        TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn))
    }

    /// Workbook that registers only `AGGREGATE`.
    fn aggregate_workbook() -> TestWorkbook {
        TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn))
    }

    /// Literal AST node wrapping `value`.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Single-row inline array literal.
    fn row(cells: Vec<LiteralValue>) -> LiteralValue {
        LiteralValue::Array(vec![cells])
    }

    /// Calls `fn_name` with `nodes` as its arguments and returns the scalar result.
    fn dispatch(
        ctx: &crate::interpreter::Interpreter<'_>,
        fn_name: &str,
        nodes: &[ASTNode],
    ) -> LiteralValue {
        let handles: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, ctx))
            .collect();
        let function = ctx.context.get_function("", fn_name).expect("function");
        function
            .dispatch(&handles, &ctx.function_context(None))
            .expect("dispatch")
            .into_literal()
    }

    /// Evaluates `AGGREGATE(function_num, options, {1})`.
    fn aggregate_over_one(
        ctx: &crate::interpreter::Interpreter<'_>,
        function_num: i64,
        options: i64,
    ) -> LiteralValue {
        dispatch(
            ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(function_num)),
                lit(LiteralValue::Int(options)),
                lit(row(vec![LiteralValue::Int(1)])),
            ],
        )
    }

    fn assert_num_close(value: LiteralValue, expected: f64) {
        match value {
            LiteralValue::Number(n) => assert!((n - expected).abs() < 1e-9, "{n} != {expected}"),
            LiteralValue::Int(i) => assert!(((i as f64) - expected).abs() < 1e-9),
            other => panic!("expected numeric {expected}, got {other:?}"),
        }
    }

    fn assert_error_kind(value: LiteralValue, expected: ExcelErrorKind) {
        match value {
            LiteralValue::Error(e) => assert_eq!(e.kind, expected),
            other => panic!("expected error {:?}, got {other:?}", expected),
        }
    }

    #[test]
    fn subtotal_function_num_mapping_basics() {
        let wb = subtotal_workbook();
        let ctx = interp(&wb);
        let values = row(vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]);

        // (function_num, expected result over {1, 2, 3})
        let expectations: [(i64, f64); 11] = [
            (1, 2.0),
            (2, 3.0),
            (3, 3.0),
            (4, 3.0),
            (5, 1.0),
            (6, 6.0),
            (7, 1.0),
            (8, (2.0f64 / 3.0).sqrt()),
            (9, 6.0),
            (10, 1.0),
            (11, 2.0 / 3.0),
        ];

        for (code, expected) in expectations {
            let out = dispatch(
                &ctx,
                "SUBTOTAL",
                &[lit(LiteralValue::Int(code)), lit(values.clone())],
            );
            assert_num_close(out, expected);
        }
    }

    #[test]
    fn subtotal_counta_counts_errors_as_non_empty() {
        let wb = subtotal_workbook();
        let ctx = interp(&wb);
        // 1, #DIV/0!, "x" and "" are all non-empty.
        let data = row(vec![
            LiteralValue::Int(1),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Text("x".into()),
            LiteralValue::Text("".into()),
        ]);

        let out = dispatch(&ctx, "SUBTOTAL", &[lit(LiteralValue::Int(3)), lit(data)]);
        assert_num_close(out, 4.0);
    }

    #[test]
    fn subtotal_invalid_function_num_returns_value_error() {
        let wb = subtotal_workbook();
        let ctx = interp(&wb);

        let out = dispatch(
            &ctx,
            "SUBTOTAL",
            &[
                lit(LiteralValue::Number(9.5)),
                lit(row(vec![LiteralValue::Int(1)])),
            ],
        );
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn subtotal_requires_ref_argument() {
        let wb = subtotal_workbook();
        let ctx = interp(&wb);

        let out = dispatch(&ctx, "SUBTOTAL", &[lit(LiteralValue::Int(9))]);
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn aggregate_requires_options_and_ref_argument() {
        let wb = aggregate_workbook();
        let ctx = interp(&wb);

        let out = dispatch(
            &ctx,
            "AGGREGATE",
            &[lit(LiteralValue::Int(9)), lit(LiteralValue::Int(0))],
        );
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn aggregate_options_zero_to_three_control_error_behavior() {
        let wb = aggregate_workbook();
        let ctx = interp(&wb);
        let values = row(vec![
            LiteralValue::Int(10),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Int(30),
        ]);

        for option in 0..=3i64 {
            let out = dispatch(
                &ctx,
                "AGGREGATE",
                &[
                    lit(LiteralValue::Int(9)),
                    lit(LiteralValue::Int(option)),
                    lit(values.clone()),
                ],
            );
            // Options 0 and 1 propagate the #DIV/0!; options 2 and 3 skip it.
            if option < 2 {
                assert_error_kind(out, ExcelErrorKind::Div);
            } else {
                assert_num_close(out, 40.0);
            }
        }
    }

    #[test]
    fn aggregate_counta_option_ignore_errors_skips_error_values() {
        let wb = aggregate_workbook();
        let ctx = interp(&wb);
        let data = row(vec![
            LiteralValue::Int(1),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Text("x".into()),
        ]);

        let out = dispatch(
            &ctx,
            "AGGREGATE",
            &[
                lit(LiteralValue::Int(3)),
                lit(LiteralValue::Int(2)),
                lit(data),
            ],
        );
        assert_num_close(out, 2.0);
    }

    #[test]
    fn aggregate_unsupported_option_returns_nimpl() {
        let wb = aggregate_workbook();
        let ctx = interp(&wb);

        assert_error_kind(aggregate_over_one(&ctx, 9, 4), ExcelErrorKind::NImpl);
    }

    #[test]
    fn aggregate_unsupported_function_num_returns_nimpl() {
        let wb = aggregate_workbook();
        let ctx = interp(&wb);

        assert_error_kind(aggregate_over_one(&ctx, 12, 0), ExcelErrorKind::NImpl);
    }
}
1686
1687pub fn register_builtins() {
1688    crate::function_registry::register_function(std::sync::Arc::new(SumProductFn));
1689    crate::function_registry::register_function(std::sync::Arc::new(SumFn));
1690    crate::function_registry::register_function(std::sync::Arc::new(CountFn));
1691    crate::function_registry::register_function(std::sync::Arc::new(AverageFn));
1692    crate::function_registry::register_function(std::sync::Arc::new(SubtotalFn));
1693    crate::function_registry::register_function(std::sync::Arc::new(AggregateFn));
1694}