Skip to main content

formualizer_eval/builtins/math/
aggregate.rs

1use super::super::utils::{ARG_RANGE_NUM_LENIENT_ONE, coerce_num};
2use crate::args::ArgSchema;
3use crate::engine::VisibilityMaskMode;
4use crate::function::Function;
5use crate::function_contract::FunctionDependencyContract;
6use crate::traits::{ArgumentHandle, FunctionContext};
7use arrow_array::Array;
8use formualizer_common::{ExcelError, ExcelErrorKind, LiteralValue};
9use formualizer_macros::func_caps;
10
11/* ─────────────────────────── SUM() ──────────────────────────── */
12
13#[derive(Debug)]
14pub struct SumFn;
15
16/// Adds numeric values across scalars and ranges.
17///
18/// `SUM` evaluates all arguments, coercing text to numbers where possible,
19/// and returns the total. Blank cells and logical values in ranges are ignored.
20///
21/// # Remarks
22/// - If any argument evaluates to an error, `SUM` propagates the first error it encounters.
23/// - Unparseable text literals (e.g., `"foo"`) will result in a `#VALUE!` error.
24///
25/// # Examples
26///
27/// ```yaml,sandbox
28/// title: "Basic scalar addition"
29/// formula: "=SUM(10, 20, 5)"
30/// expected: 35
31/// ```
32///
33/// ```yaml,sandbox
34/// title: "Summing a range"
35/// grid:
36///   A1: 10
37///   A2: 20
38///   A3: "N/A"
39/// formula: "=SUM(A1:A3)"
40/// expected: 30
41/// ```
42///
43/// ```yaml,docs
44/// related:
45///   - SUMIF
46///   - SUMIFS
47///   - SUMPRODUCT
48///   - AVERAGE
49/// faq:
50///   - q: "Why does SUM return #VALUE! for some text arguments?"
51///     a: "Direct scalar text that cannot be parsed as a number raises #VALUE! during coercion."
52///   - q: "Do text and logical values inside ranges get added?"
53///     a: "No. In ranged inputs, only numeric cells contribute to the total."
54/// ```
55///
56/// [formualizer-docgen:schema:start]
57/// Name: SUM
58/// Type: SumFn
59/// Min args: 0
60/// Max args: variadic
61/// Variadic: true
62/// Signature: SUM(arg1...: number@range)
63/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
64/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS
65/// [formualizer-docgen:schema:end]
66impl Function for SumFn {
67    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS);
68
69    fn name(&self) -> &'static str {
70        "SUM"
71    }
72    fn min_args(&self) -> usize {
73        0
74    }
75    fn variadic(&self) -> bool {
76        true
77    }
78    fn dependency_contract(&self, arity: usize) -> Option<FunctionDependencyContract> {
79        FunctionDependencyContract::static_reduction(arity, self.min_args())
80    }
81    fn arg_schema(&self) -> &'static [ArgSchema] {
82        &ARG_RANGE_NUM_LENIENT_ONE[..]
83    }
84
85    fn eval<'a, 'b, 'c>(
86        &self,
87        args: &'c [ArgumentHandle<'a, 'b>],
88        ctx: &dyn FunctionContext<'b>,
89    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
90        let mut total = 0.0;
91        for arg in args {
92            if let Ok(view) = arg.range_view() {
93                // Propagate errors from range first
94                for res in view.errors_slices() {
95                    let (_, _, err_cols) = res?;
96                    for col in err_cols {
97                        if col.null_count() < col.len() {
98                            for i in 0..col.len() {
99                                if !col.is_null(i) {
100                                    return Ok(crate::traits::CalcValue::Scalar(
101                                        LiteralValue::Error(ExcelError::new(
102                                            crate::arrow_store::unmap_error_code(col.value(i)),
103                                        )),
104                                    ));
105                                }
106                            }
107                        }
108                    }
109                }
110
111                for res in view.numbers_slices() {
112                    let (_, _, num_cols) = res?;
113                    for col in num_cols {
114                        total +=
115                            arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
116                    }
117                }
118            } else {
119                let v = arg.value()?.into_literal();
120                match v {
121                    LiteralValue::Error(e) => {
122                        return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
123                    }
124                    v => total += coerce_num(&v)?,
125                }
126            }
127        }
128        Ok(crate::traits::CalcValue::Scalar(
129            super::super::utils::aggregate_result(total),
130        ))
131    }
132}
133
134/* ─────────────────────────── COUNT() ──────────────────────────── */
135
136#[derive(Debug)]
137pub struct CountFn;
138
139/// Counts numeric values across scalars and ranges.
140///
141/// `COUNT` evaluates all arguments and counts how many are numeric values.
142/// Numbers, dates, and text representations of numbers (when supplied directly) are counted.
143///
144/// # Remarks
145/// - Text values inside ranges are ignored and not counted.
146/// - Blank cells and logical values in ranges are ignored.
147///
148/// # Examples
149///
150/// ```yaml,sandbox
151/// title: "Counting mixed scalar inputs"
152/// formula: "=COUNT(1, \"x\", 2, 3)"
153/// expected: 3
154/// ```
155///
156/// ```yaml,sandbox
157/// title: "Counting in a range"
158/// grid:
159///   A1: 10
160///   A2: "foo"
161///   A3: 20
162/// formula: "=COUNT(A1:A3)"
163/// expected: 2
164/// ```
165///
166/// ```yaml,docs
167/// related:
168///   - COUNTA
169///   - COUNTBLANK
170///   - COUNTIF
171///   - COUNTIFS
172/// faq:
173///   - q: "Why doesn't COUNT include text in a range?"
174///     a: "COUNT only counts numeric values; text cells in ranges are ignored."
175///   - q: "Can direct text like \"12\" be counted?"
176///     a: "Yes. Direct scalar arguments are coerced and counted when they parse as numbers."
177/// ```
178///
179/// [formualizer-docgen:schema:start]
180/// Name: COUNT
181/// Type: CountFn
182/// Min args: 0
183/// Max args: variadic
184/// Variadic: true
185/// Signature: COUNT(arg1...: number@range)
186/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
187/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
188/// [formualizer-docgen:schema:end]
189impl Function for CountFn {
190    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
191
192    fn name(&self) -> &'static str {
193        "COUNT"
194    }
195    fn min_args(&self) -> usize {
196        0
197    }
198    fn variadic(&self) -> bool {
199        true
200    }
201    fn dependency_contract(&self, arity: usize) -> Option<FunctionDependencyContract> {
202        FunctionDependencyContract::static_reduction(arity, self.min_args())
203    }
204    fn arg_schema(&self) -> &'static [ArgSchema] {
205        &ARG_RANGE_NUM_LENIENT_ONE[..]
206    }
207
208    fn eval<'a, 'b, 'c>(
209        &self,
210        args: &'c [ArgumentHandle<'a, 'b>],
211        _: &dyn FunctionContext<'b>,
212    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
213        let mut count: i64 = 0;
214        for arg in args {
215            if let Ok(view) = arg.range_view() {
216                for res in view.numbers_slices() {
217                    let (_, _, num_cols) = res?;
218                    for col in num_cols {
219                        count += (col.len() - col.null_count()) as i64;
220                    }
221                }
222            } else {
223                let v = arg.value()?.into_literal();
224                if let LiteralValue::Error(e) = v {
225                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
226                }
227                if !matches!(v, LiteralValue::Empty) && coerce_num(&v).is_ok() {
228                    count += 1;
229                }
230            }
231        }
232        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
233            count as f64,
234        )))
235    }
236}
237
238/* ─────────────────────────── AVERAGE() ──────────────────────────── */
239
240#[derive(Debug)]
241pub struct AverageFn;
242
243/// Returns the arithmetic mean of numeric values across scalars and ranges.
244///
245/// `AVERAGE` sums numeric inputs and divides by the count of numeric values that participated.
246///
247/// # Remarks
248/// - Errors in any scalar argument or referenced range propagate immediately.
249/// - In ranges, only numeric/date-time serial values are included; text and blanks are ignored.
250/// - Scalar arguments use lenient number coercion with locale support.
251/// - If no numeric values are found, `AVERAGE` returns `#DIV/0!`.
252///
253/// # Examples
254///
255/// ```yaml,sandbox
256/// title: "Average of scalar values"
257/// formula: "=AVERAGE(10, 20, 5)"
258/// expected: 11.666666666666666
259/// ```
260///
261/// ```yaml,sandbox
262/// title: "Average over a mixed range"
263/// grid:
264///   A1: 10
265///   A2: "x"
266///   A3: 20
267/// formula: "=AVERAGE(A1:A3)"
268/// expected: 15
269/// ```
270///
271/// ```yaml,sandbox
272/// title: "No numeric values returns divide-by-zero"
273/// formula: "=AVERAGE(\"x\", \"\")"
274/// expected: "#DIV/0!"
275/// ```
276///
277/// ```yaml,docs
278/// related:
279///   - SUM
280///   - COUNT
281///   - AVERAGEIF
282///   - AVERAGEIFS
283/// faq:
284///   - q: "When does AVERAGE return #DIV/0!?"
285///     a: "It returns #DIV/0! when no numeric values are found after filtering/coercion."
286///   - q: "Do text cells in ranges affect the denominator?"
287///     a: "No. Only numeric values are counted toward the divisor."
288/// ```
289///
290/// [formualizer-docgen:schema:start]
291/// Name: AVERAGE
292/// Type: AverageFn
293/// Min args: 1
294/// Max args: variadic
295/// Variadic: true
296/// Signature: AVERAGE(arg1...: number@range)
297/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
298/// Caps: PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK
299/// [formualizer-docgen:schema:end]
300impl Function for AverageFn {
301    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
302
303    fn name(&self) -> &'static str {
304        "AVERAGE"
305    }
306    fn min_args(&self) -> usize {
307        1
308    }
309    fn variadic(&self) -> bool {
310        true
311    }
312    fn dependency_contract(&self, arity: usize) -> Option<FunctionDependencyContract> {
313        FunctionDependencyContract::static_reduction(arity, self.min_args())
314    }
315    fn arg_schema(&self) -> &'static [ArgSchema] {
316        &ARG_RANGE_NUM_LENIENT_ONE[..]
317    }
318
319    fn eval<'a, 'b, 'c>(
320        &self,
321        args: &'c [ArgumentHandle<'a, 'b>],
322        ctx: &dyn FunctionContext<'b>,
323    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
324        let mut sum = 0.0f64;
325        let mut cnt: i64 = 0;
326        for arg in args {
327            if let Ok(view) = arg.range_view() {
328                // Propagate errors from range first
329                for res in view.errors_slices() {
330                    let (_, _, err_cols) = res?;
331                    for col in err_cols {
332                        if col.null_count() < col.len() {
333                            for i in 0..col.len() {
334                                if !col.is_null(i) {
335                                    return Ok(crate::traits::CalcValue::Scalar(
336                                        LiteralValue::Error(ExcelError::new(
337                                            crate::arrow_store::unmap_error_code(col.value(i)),
338                                        )),
339                                    ));
340                                }
341                            }
342                        }
343                    }
344                }
345
346                for res in view.numbers_slices() {
347                    let (_, _, num_cols) = res?;
348                    for col in num_cols {
349                        sum += arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
350                        cnt += (col.len() - col.null_count()) as i64;
351                    }
352                }
353            } else {
354                let v = arg.value()?.into_literal();
355                if let LiteralValue::Error(e) = v {
356                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
357                }
358                if let Ok(n) = crate::coercion::to_number_lenient_with_locale(&v, &ctx.locale()) {
359                    sum += n;
360                    cnt += 1;
361                }
362            }
363        }
364        if cnt == 0 {
365            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
366                ExcelError::new_div(),
367            )));
368        }
369        Ok(crate::traits::CalcValue::Scalar(
370            super::super::utils::aggregate_result(sum / (cnt as f64)),
371        ))
372    }
373}
374
375/* ──────────────────────── SUMPRODUCT() ───────────────────────── */
376
377#[derive(Debug)]
378pub struct SumProductFn;
379
380/// Multiplies aligned values across arrays and returns the sum of those products.
381///
382/// `SUMPRODUCT` supports scalar or range inputs, applies broadcast semantics, and accumulates
383/// the product for each aligned cell position.
384///
385/// # Remarks
386/// - Input shapes must be broadcast-compatible; otherwise `SUMPRODUCT` returns `#VALUE!`.
387/// - Non-numeric values are treated as `0` during multiplication.
388/// - Any explicit error value in the inputs propagates immediately.
389///
390/// # Examples
391///
392/// ```yaml,sandbox
393/// title: "Pairwise sum of products"
394/// formula: "=SUMPRODUCT({1,2,3}, {4,5,6})"
395/// expected: 32
396/// ```
397///
398/// ```yaml,sandbox
399/// title: "Range-based sumproduct"
400/// grid:
401///   A1: 2
402///   A2: 3
403///   A3: 4
404///   B1: 10
405///   B2: 20
406///   B3: 30
407/// formula: "=SUMPRODUCT(A1:A3, B1:B3)"
408/// expected: 200
409/// ```
410///
411/// ```yaml,sandbox
412/// title: "Text entries contribute zero"
413/// formula: "=SUMPRODUCT({1,\"x\",3}, {1,1,1})"
414/// expected: 4
415/// ```
416///
417/// ```yaml,docs
418/// related:
419///   - SUM
420///   - PRODUCT
421///   - MMULT
422///   - SUMIFS
423/// faq:
424///   - q: "Why does SUMPRODUCT return #VALUE! with some array shapes?"
425///     a: "The argument arrays must be broadcast-compatible; incompatible shapes raise #VALUE!."
426///   - q: "How are text values handled in multiplication?"
427///     a: "Non-numeric values are treated as 0, unless an explicit error is present."
428/// ```
429///
430/// [formualizer-docgen:schema:start]
431/// Name: SUMPRODUCT
432/// Type: SumProductFn
433/// Min args: 1
434/// Max args: variadic
435/// Variadic: true
436/// Signature: SUMPRODUCT(arg1...: number@range)
437/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
438/// Caps: PURE, REDUCTION
439/// [formualizer-docgen:schema:end]
440impl Function for SumProductFn {
441    // Pure reduction over arrays; uses broadcasting and lenient coercion
442    func_caps!(PURE, REDUCTION);
443
444    fn name(&self) -> &'static str {
445        "SUMPRODUCT"
446    }
447    fn min_args(&self) -> usize {
448        1
449    }
450    fn variadic(&self) -> bool {
451        true
452    }
453    fn arg_schema(&self) -> &'static [ArgSchema] {
454        // Accept ranges or scalars; numeric lenient coercion
455        &ARG_RANGE_NUM_LENIENT_ONE[..]
456    }
457
458    fn eval<'a, 'b, 'c>(
459        &self,
460        args: &'c [ArgumentHandle<'a, 'b>],
461        _: &dyn FunctionContext<'b>,
462    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
463        use crate::broadcast::{broadcast_shape, project_index};
464
465        if args.is_empty() {
466            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(0.0)));
467        }
468
469        // Helper: materialize an argument to a 2D array of LiteralValue
470        let to_array = |ah: &ArgumentHandle| -> Result<Vec<Vec<LiteralValue>>, ExcelError> {
471            if let Ok(rv) = ah.range_view() {
472                let mut rows: Vec<Vec<LiteralValue>> = Vec::new();
473                rv.for_each_row(&mut |row| {
474                    rows.push(row.to_vec());
475                    Ok(())
476                })?;
477                Ok(rows)
478            } else {
479                let v = ah.value()?.into_literal();
480                Ok(match v {
481                    LiteralValue::Array(arr) => arr,
482                    other => vec![vec![other]],
483                })
484            }
485        };
486
487        // Collect arrays and shapes
488        let mut arrays: Vec<Vec<Vec<LiteralValue>>> = Vec::with_capacity(args.len());
489        let mut shapes: Vec<(usize, usize)> = Vec::with_capacity(args.len());
490        for a in args.iter() {
491            let arr = to_array(a)?;
492            let shape = (arr.len(), arr.first().map(|r| r.len()).unwrap_or(0));
493            arrays.push(arr);
494            shapes.push(shape);
495        }
496
497        // Compute broadcast target shape across all args
498        let target = match broadcast_shape(&shapes) {
499            Ok(s) => s,
500            Err(_) => {
501                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
502                    ExcelError::new_value(),
503                )));
504            }
505        };
506
507        // Iterate target shape, multiply coerced values across args, sum total
508        let mut total = 0.0f64;
509        for r in 0..target.0 {
510            for c in 0..target.1 {
511                let mut prod = 1.0f64;
512                for (arr, &shape) in arrays.iter().zip(shapes.iter()) {
513                    let (rr, cc) = project_index((r, c), shape);
514                    let lv = arr
515                        .get(rr)
516                        .and_then(|row| row.get(cc))
517                        .cloned()
518                        .unwrap_or(LiteralValue::Empty);
519                    match lv {
520                        LiteralValue::Error(e) => {
521                            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
522                        }
523                        _ => match super::super::utils::coerce_num(&lv) {
524                            Ok(n) => {
525                                prod *= n;
526                            }
527                            Err(_) => {
528                                // Non-numeric -> treated as 0 in SUMPRODUCT
529                                prod *= 0.0;
530                            }
531                        },
532                    }
533                }
534                total += prod;
535            }
536        }
537        Ok(crate::traits::CalcValue::Scalar(
538            super::super::utils::aggregate_result(total),
539        ))
540    }
541}
542
#[cfg(test)]
mod tests_sumproduct {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wraps a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Builds an array-literal node from its rows.
    fn grid(rows: Vec<Vec<LiteralValue>>) -> ASTNode {
        lit(LiteralValue::Array(rows))
    }

    /// Registers SUMPRODUCT in a fresh workbook, dispatches it over `nodes`,
    /// and returns the resulting literal.
    fn eval_sumproduct(nodes: &[ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumProductFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();
        let sumproduct = interp.context.get_function("", "SUMPRODUCT").unwrap();
        let fctx = interp.function_context(None);
        let outcome = sumproduct.dispatch(&args, &fctx).unwrap().into_literal();
        outcome
    }

    #[test]
    fn sumproduct_basic_pairwise() {
        // {1,2,3} * {4,5,6} = 1*4 + 2*5 + 3*6 = 32
        let lhs = grid(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let rhs = grid(vec![vec![
            LiteralValue::Int(4),
            LiteralValue::Int(5),
            LiteralValue::Int(6),
        ]]);
        assert_eq!(eval_sumproduct(&[lhs, rhs]), LiteralValue::Number(32.0));
    }

    #[test]
    fn sumproduct_variadic_three_arrays() {
        // {1,2} * {3,4} * {2,2} = (1*3*2) + (2*4*2) = 6 + 16 = 22
        let first = grid(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let second = grid(vec![vec![LiteralValue::Int(3), LiteralValue::Int(4)]]);
        let third = grid(vec![vec![LiteralValue::Int(2), LiteralValue::Int(2)]]);
        assert_eq!(
            eval_sumproduct(&[first, second, third]),
            LiteralValue::Number(22.0)
        );
    }

    #[test]
    fn sumproduct_broadcast_scalar_over_array() {
        // {1,2,3} * 10 => (1*10 + 2*10 + 3*10) = 60
        let values = grid(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let scale = lit(LiteralValue::Number(10.0));
        assert_eq!(eval_sumproduct(&[values, scale]), LiteralValue::Number(60.0));
    }

    #[test]
    fn sumproduct_2d_arrays_broadcast_rows_cols() {
        // A = [[1,2],[3,4]] (2x2), B = [[10,20]] (1x2) broadcast across rows:
        // 1*10 + 2*20 + 3*10 + 4*20 = 160
        let matrix = grid(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]);
        let row = grid(vec![vec![LiteralValue::Int(10), LiteralValue::Int(20)]]);
        assert_eq!(eval_sumproduct(&[matrix, row]), LiteralValue::Number(160.0));
    }

    #[test]
    fn sumproduct_non_numeric_treated_as_zero() {
        // {1,"x",3} * {1,1,1} => 1*1 + 0*1 + 3*1 = 4
        let mixed = grid(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]);
        let ones = grid(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(1),
            LiteralValue::Int(1),
        ]]);
        assert_eq!(eval_sumproduct(&[mixed, ones]), LiteralValue::Number(4.0));
    }

    #[test]
    fn sumproduct_error_in_input_propagates() {
        let values = grid(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let failure = lit(LiteralValue::Error(ExcelError::new_na()));
        match eval_sumproduct(&[values, failure]) {
            LiteralValue::Error(err) => assert_eq!(err, "#N/A"),
            v => panic!("expected error, got {v:?}"),
        }
    }

    #[test]
    fn sumproduct_incompatible_shapes_value_error() {
        // 1x3 and 1x2 -> #VALUE!
        let wide = grid(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let narrow = grid(vec![vec![LiteralValue::Int(4), LiteralValue::Int(5)]]);
        match eval_sumproduct(&[wide, narrow]) {
            LiteralValue::Error(e) => assert_eq!(e, "#VALUE!"),
            v => panic!("expected value error, got {v:?}"),
        }
    }
}
722
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::FnCaps;
    use crate::test_workbook::TestWorkbook;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    #[test]
    fn test_sum_caps() {
        let caps = SumFn.caps();

        // Capabilities SUM is expected to advertise.
        for wanted in [
            FnCaps::PURE,
            FnCaps::REDUCTION,
            FnCaps::NUMERIC_ONLY,
            FnCaps::STREAM_OK,
        ] {
            assert!(caps.contains(wanted));
        }

        // Capabilities SUM must not advertise.
        for unwanted in [FnCaps::VOLATILE, FnCaps::ELEMENTWISE] {
            assert!(!caps.contains(unwanted));
        }
    }

    #[test]
    fn test_sum_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumFn));
        let interp = wb.interpreter();
        let fctx = interp.function_context(None);

        // SUM(1, 2, 3) built from literal AST nodes.
        let nodes: Vec<ASTNode> = [1.0, 2.0, 3.0]
            .iter()
            .map(|&n| ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None))
            .collect();
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();

        let sum_fn = interp.context.get_function("", "SUM").unwrap();
        let result = sum_fn.dispatch(&args, &fctx).unwrap().into_literal();
        assert_eq!(result, LiteralValue::Number(6.0));
    }
}
780
#[cfg(test)]
mod tests_count {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wraps a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Registers COUNT in a fresh workbook, dispatches it over `nodes`,
    /// and returns the resulting literal.
    fn count_of(nodes: &[ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();
        let count = interp.context.get_function("", "COUNT").unwrap();
        let fctx = interp.function_context(None);
        let outcome = count.dispatch(&args, &fctx).unwrap().into_literal();
        outcome
    }

    #[test]
    fn count_numbers_ignores_text() {
        // COUNT({1,2,"x",3}) => 3
        let values = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]));
        assert_eq!(count_of(&[values]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_multiple_args_and_scalars() {
        let pair = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        let ten = lit(LiteralValue::Int(10));
        let word = lit(LiteralValue::Text("n".into()));
        // Two from array + scalar 10 = 3
        assert_eq!(count_of(&[pair, ten, word]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_direct_error_argument_propagates() {
        let failure = lit(LiteralValue::Error(ExcelError::from_error_string("#DIV/0!")));
        match count_of(&[failure]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
856
#[cfg(test)]
mod tests_average {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::ASTNode;
    use formualizer_parse::parser::ASTNodeType;

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Evaluates `AVERAGE(argument)` in a fresh workbook and returns the result.
    fn average_of(argument: LiteralValue) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AverageFn));
        let ctx = interp(&wb);
        let node = ASTNode::new(ASTNodeType::Literal(argument), None);
        let args = vec![ArgumentHandle::new(&node, &ctx)];
        let average = ctx.context.get_function("", "AVERAGE").unwrap();
        let fctx = ctx.function_context(None);
        average.dispatch(&args, &fctx).unwrap().into_literal()
    }

    /// Builds a one-row array literal from `cells`.
    fn row(cells: Vec<LiteralValue>) -> LiteralValue {
        LiteralValue::Array(vec![cells])
    }

    #[test]
    fn average_basic_numbers() {
        let out = average_of(row(vec![
            LiteralValue::Int(2),
            LiteralValue::Int(4),
            LiteralValue::Int(6),
        ]));
        assert_eq!(out, LiteralValue::Number(4.0));
    }

    #[test]
    fn average_mixed_with_text() {
        // Text inside the array is skipped, leaving the mean of 2 and 6.
        let out = average_of(row(vec![
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(6),
        ]));
        assert_eq!(out, LiteralValue::Number(4.0));
    }

    #[test]
    fn average_no_numeric_div0() {
        let out = average_of(row(vec![
            LiteralValue::Text("a".into()),
            LiteralValue::Text("b".into()),
        ]));
        match out {
            LiteralValue::Error(err) => assert_eq!(err, "#DIV/0!"),
            other => panic!("expected #DIV/0!, got {other:?}"),
        }
    }

    #[test]
    fn average_direct_error_argument_propagates() {
        let div0 = LiteralValue::Error(ExcelError::from_error_string("#DIV/0!"));
        match average_of(div0) {
            LiteralValue::Error(err) => assert_eq!(err, "#DIV/0!"),
            other => panic!("unexpected {other:?}"),
        }
    }
}
948
/// Which rows of a referenced range an aggregation reads.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VisibilityPolicy {
    /// Read every row; no visibility mask is consulted.
    IncludeAll,
    /// Skip rows hidden manually or by a filter, as reported by the context's
    /// `VisibilityMaskMode::ExcludeManualOrFilterHidden` mask.
    ExcludeManualOrFilterHidden,
}
954
/// How an aggregation reacts to error values among its inputs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ErrorPolicy {
    /// Return the first error encountered (COUNTA counts it instead).
    Propagate,
    /// Skip error values entirely.
    Ignore,
}
960
/// Reduction shared by `SUBTOTAL` and `AGGREGATE`; the number in each variant's
/// doc is the function number that selects it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AggregateOp {
    /// 1: arithmetic mean of the numeric values.
    Average,
    /// 2: number of numeric values.
    Count,
    /// 3: number of non-empty values.
    CountA,
    /// 4: largest numeric value.
    Max,
    /// 5: smallest numeric value.
    Min,
    /// 6: product of the numeric values.
    Product,
    /// 7: sample standard deviation.
    StdevSample,
    /// 8: population standard deviation.
    StdevPopulation,
    /// 9: sum of the numeric values.
    Sum,
    /// 10: sample variance.
    VarSample,
    /// 11: population variance.
    VarPopulation,
}
975
976fn aggregate_op_from_function_num(function_num: i32) -> Option<AggregateOp> {
977    match function_num {
978        1 => Some(AggregateOp::Average),
979        2 => Some(AggregateOp::Count),
980        3 => Some(AggregateOp::CountA),
981        4 => Some(AggregateOp::Max),
982        5 => Some(AggregateOp::Min),
983        6 => Some(AggregateOp::Product),
984        7 => Some(AggregateOp::StdevSample),
985        8 => Some(AggregateOp::StdevPopulation),
986        9 => Some(AggregateOp::Sum),
987        10 => Some(AggregateOp::VarSample),
988        11 => Some(AggregateOp::VarPopulation),
989        _ => None,
990    }
991}
992
993fn parse_strict_int_arg(arg: &ArgumentHandle<'_, '_>) -> Result<i32, ExcelError> {
994    let raw = arg.value()?.into_literal();
995    if let LiteralValue::Error(e) = raw {
996        return Err(e);
997    }
998
999    let n = coerce_num(&raw)?;
1000    if !n.is_finite() {
1001        return Err(ExcelError::new_value());
1002    }
1003
1004    let rounded = n.round();
1005    if (n - rounded).abs() > 1e-9 {
1006        return Err(ExcelError::new_value());
1007    }
1008
1009    if rounded < i32::MIN as f64 || rounded > i32::MAX as f64 {
1010        return Err(ExcelError::new_value());
1011    }
1012
1013    Ok(rounded as i32)
1014}
1015
1016fn row_is_visible(mask: Option<&arrow_array::BooleanArray>, relative_row: usize) -> bool {
1017    let Some(mask) = mask else {
1018        return true;
1019    };
1020
1021    if relative_row >= mask.len() || mask.is_null(relative_row) {
1022        return true;
1023    }
1024
1025    mask.value(relative_row)
1026}
1027
1028fn numeric_from_range_value(value: &LiteralValue) -> Option<f64> {
1029    match value {
1030        LiteralValue::Number(n) => Some(*n),
1031        LiteralValue::Int(i) => Some(*i as f64),
1032        LiteralValue::Date(_)
1033        | LiteralValue::DateTime(_)
1034        | LiteralValue::Time(_)
1035        | LiteralValue::Duration(_) => coerce_num(value).ok(),
1036        _ => None,
1037    }
1038}
1039
/// Accumulates what an aggregation needs while walking its arguments.
#[derive(Default, Debug)]
struct AggregateCollector {
    /// Numeric inputs kept for the reduction. For COUNT only the number of
    /// entries matters, so scalar arguments may store placeholder values.
    numeric_values: Vec<f64>,
    /// Tally of non-empty values, used by COUNTA.
    counta: usize,
}
1045
1046impl AggregateCollector {
1047    fn collect_args<'a, 'b>(
1048        args: &[ArgumentHandle<'a, 'b>],
1049        start_idx: usize,
1050        ctx: &dyn FunctionContext<'b>,
1051        op: AggregateOp,
1052        visibility_policy: VisibilityPolicy,
1053        error_policy: ErrorPolicy,
1054    ) -> Result<Self, ExcelError> {
1055        let mut out = Self::default();
1056
1057        for arg in args.iter().skip(start_idx) {
1058            if let Ok(view) = arg.range_view() {
1059                out.collect_range_arg(&view, ctx, op, visibility_policy, error_policy)?;
1060            } else {
1061                out.consume_scalar_value(arg.value()?.into_literal(), op, error_policy)?;
1062            }
1063        }
1064
1065        Ok(out)
1066    }
1067
1068    fn collect_range_arg<'b>(
1069        &mut self,
1070        view: &crate::engine::range_view::RangeView<'_>,
1071        ctx: &dyn FunctionContext<'b>,
1072        op: AggregateOp,
1073        visibility_policy: VisibilityPolicy,
1074        error_policy: ErrorPolicy,
1075    ) -> Result<(), ExcelError> {
1076        let visibility_mask = match visibility_policy {
1077            VisibilityPolicy::IncludeAll => None,
1078            VisibilityPolicy::ExcludeManualOrFilterHidden => {
1079                ctx.get_row_visibility_mask(view, VisibilityMaskMode::ExcludeManualOrFilterHidden)
1080            }
1081        };
1082
1083        let (_, cols) = view.dims();
1084        if cols == 0 {
1085            return Ok(());
1086        }
1087
1088        for chunk in view.iter_row_chunks() {
1089            let chunk = chunk?;
1090            for row_offset in 0..chunk.row_len {
1091                let rel_row = chunk.row_start + row_offset;
1092                if !row_is_visible(visibility_mask.as_deref(), rel_row) {
1093                    continue;
1094                }
1095
1096                for col in 0..cols {
1097                    // Phase-1 contract: nested SUBTOTAL/AGGREGATE exclusion is deferred.
1098                    // Nested aggregate results are treated as ordinary scalar values.
1099                    self.consume_range_value(view.get_cell(rel_row, col), op, error_policy)?;
1100                }
1101            }
1102        }
1103
1104        Ok(())
1105    }
1106
1107    fn consume_range_value(
1108        &mut self,
1109        value: LiteralValue,
1110        op: AggregateOp,
1111        error_policy: ErrorPolicy,
1112    ) -> Result<(), ExcelError> {
1113        match value {
1114            LiteralValue::Error(e) => {
1115                if op == AggregateOp::CountA {
1116                    if error_policy == ErrorPolicy::Ignore {
1117                        return Ok(());
1118                    }
1119                    self.counta += 1;
1120                    return Ok(());
1121                }
1122                match error_policy {
1123                    ErrorPolicy::Propagate => Err(e),
1124                    ErrorPolicy::Ignore => Ok(()),
1125                }
1126            }
1127            LiteralValue::Empty => Ok(()),
1128            other => {
1129                self.counta += 1;
1130                if let Some(n) = numeric_from_range_value(&other) {
1131                    self.numeric_values.push(n);
1132                }
1133                Ok(())
1134            }
1135        }
1136    }
1137
1138    fn consume_scalar_value(
1139        &mut self,
1140        value: LiteralValue,
1141        op: AggregateOp,
1142        error_policy: ErrorPolicy,
1143    ) -> Result<(), ExcelError> {
1144        match value {
1145            LiteralValue::Error(e) => {
1146                if op == AggregateOp::CountA {
1147                    if error_policy == ErrorPolicy::Ignore {
1148                        return Ok(());
1149                    }
1150                    self.counta += 1;
1151                    return Ok(());
1152                }
1153                match error_policy {
1154                    ErrorPolicy::Propagate => Err(e),
1155                    ErrorPolicy::Ignore => Ok(()),
1156                }
1157            }
1158            LiteralValue::Array(rows) => {
1159                for row in rows {
1160                    for cell in row {
1161                        self.consume_range_value(cell, op, error_policy)?;
1162                    }
1163                }
1164                Ok(())
1165            }
1166            other => {
1167                match op {
1168                    AggregateOp::CountA => {
1169                        if !matches!(other, LiteralValue::Empty) {
1170                            self.counta += 1;
1171                        }
1172                    }
1173                    AggregateOp::Count => {
1174                        if !matches!(other, LiteralValue::Empty) && coerce_num(&other).is_ok() {
1175                            self.numeric_values.push(0.0);
1176                        }
1177                    }
1178                    _ => {
1179                        if let Ok(n) = coerce_num(&other) {
1180                            self.numeric_values.push(n);
1181                        }
1182                    }
1183                }
1184                Ok(())
1185            }
1186        }
1187    }
1188
1189    fn variance(values: &[f64], sample: bool) -> Result<f64, ExcelError> {
1190        let n = values.len();
1191        if sample {
1192            if n < 2 {
1193                return Err(ExcelError::new_div());
1194            }
1195        } else if n == 0 {
1196            return Err(ExcelError::new_div());
1197        }
1198
1199        let mean = values.iter().copied().sum::<f64>() / (n as f64);
1200        let mut ss = 0.0;
1201        for value in values {
1202            let d = *value - mean;
1203            ss += d * d;
1204        }
1205
1206        if sample {
1207            Ok(ss / ((n - 1) as f64))
1208        } else {
1209            Ok(ss / (n as f64))
1210        }
1211    }
1212
1213    fn finalize(self, op: AggregateOp) -> LiteralValue {
1214        use super::super::utils::aggregate_result;
1215        match op {
1216            AggregateOp::Average => {
1217                if self.numeric_values.is_empty() {
1218                    LiteralValue::Error(ExcelError::new_div())
1219                } else {
1220                    let sum = self.numeric_values.iter().copied().sum::<f64>();
1221                    aggregate_result(sum / (self.numeric_values.len() as f64))
1222                }
1223            }
1224            // Counts cannot overflow to non-finite; keep them branch-free.
1225            AggregateOp::Count => LiteralValue::Number(self.numeric_values.len() as f64),
1226            AggregateOp::CountA => LiteralValue::Number(self.counta as f64),
1227            AggregateOp::Max => aggregate_result(
1228                self.numeric_values
1229                    .iter()
1230                    .copied()
1231                    .reduce(f64::max)
1232                    .unwrap_or(0.0),
1233            ),
1234            AggregateOp::Min => aggregate_result(
1235                self.numeric_values
1236                    .iter()
1237                    .copied()
1238                    .reduce(f64::min)
1239                    .unwrap_or(0.0),
1240            ),
1241            AggregateOp::Product => {
1242                if self.numeric_values.is_empty() {
1243                    LiteralValue::Number(0.0)
1244                } else {
1245                    aggregate_result(self.numeric_values.iter().copied().product::<f64>())
1246                }
1247            }
1248            AggregateOp::StdevSample => match Self::variance(&self.numeric_values, true) {
1249                Ok(v) => aggregate_result(v.sqrt()),
1250                Err(e) => LiteralValue::Error(e),
1251            },
1252            AggregateOp::StdevPopulation => match Self::variance(&self.numeric_values, false) {
1253                Ok(v) => aggregate_result(v.sqrt()),
1254                Err(e) => LiteralValue::Error(e),
1255            },
1256            AggregateOp::Sum => aggregate_result(self.numeric_values.iter().copied().sum()),
1257            AggregateOp::VarSample => match Self::variance(&self.numeric_values, true) {
1258                Ok(v) => aggregate_result(v),
1259                Err(e) => LiteralValue::Error(e),
1260            },
1261            AggregateOp::VarPopulation => match Self::variance(&self.numeric_values, false) {
1262                Ok(v) => aggregate_result(v),
1263                Err(e) => LiteralValue::Error(e),
1264            },
1265        }
1266    }
1267}
1268
/// Unit type for the `SUBTOTAL` builtin; its behavior lives in its `Function` impl.
#[derive(Debug)]
pub struct SubtotalFn;
1271
1272/// [formualizer-docgen:schema:start]
1273/// Name: SUBTOTAL
1274/// Type: SubtotalFn
1275/// Min args: 2
1276/// Max args: variadic
1277/// Variadic: true
1278/// Signature: SUBTOTAL(arg1...: number@range)
1279/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
1280/// Caps: VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK
1281/// [formualizer-docgen:schema:end]
1282impl Function for SubtotalFn {
1283    func_caps!(VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
1284
1285    fn name(&self) -> &'static str {
1286        "SUBTOTAL"
1287    }
1288
1289    fn min_args(&self) -> usize {
1290        2
1291    }
1292
1293    fn variadic(&self) -> bool {
1294        true
1295    }
1296
1297    fn arg_schema(&self) -> &'static [ArgSchema] {
1298        &ARG_RANGE_NUM_LENIENT_ONE[..]
1299    }
1300
1301    fn eval<'a, 'b, 'c>(
1302        &self,
1303        args: &'c [ArgumentHandle<'a, 'b>],
1304        ctx: &dyn FunctionContext<'b>,
1305    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
1306        if args.len() < 2 {
1307            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1308                ExcelError::new_value(),
1309            )));
1310        }
1311
1312        let function_num = match parse_strict_int_arg(&args[0]) {
1313            Ok(v) => v,
1314            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1315        };
1316
1317        let (mapped_code, visibility) = if (1..=11).contains(&function_num) {
1318            (function_num, VisibilityPolicy::IncludeAll)
1319        } else if (101..=111).contains(&function_num) {
1320            (
1321                function_num - 100,
1322                VisibilityPolicy::ExcludeManualOrFilterHidden,
1323            )
1324        } else {
1325            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1326                ExcelError::new_value(),
1327            )));
1328        };
1329
1330        let Some(op) = aggregate_op_from_function_num(mapped_code) else {
1331            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1332                ExcelError::new_value(),
1333            )));
1334        };
1335
1336        let collected = match AggregateCollector::collect_args(
1337            args,
1338            1,
1339            ctx,
1340            op,
1341            visibility,
1342            ErrorPolicy::Propagate,
1343        ) {
1344            Ok(c) => c,
1345            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1346        };
1347
1348        Ok(crate::traits::CalcValue::Scalar(collected.finalize(op)))
1349    }
1350}
1351
/// Unit type for the `AGGREGATE` builtin; its behavior lives in its `Function` impl.
#[derive(Debug)]
pub struct AggregateFn;
1354
1355/// [formualizer-docgen:schema:start]
1356/// Name: AGGREGATE
1357/// Type: AggregateFn
1358/// Min args: 3
1359/// Max args: variadic
1360/// Variadic: true
1361/// Signature: AGGREGATE(arg1...: number@range)
1362/// Arg schema: arg1{kinds=number,required=true,shape=range,by_ref=false,coercion=NumberLenientText,max=None,repeating=None,default=false}
1363/// Caps: VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK
1364/// [formualizer-docgen:schema:end]
1365impl Function for AggregateFn {
1366    func_caps!(VOLATILE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
1367
1368    fn name(&self) -> &'static str {
1369        "AGGREGATE"
1370    }
1371
1372    fn min_args(&self) -> usize {
1373        3
1374    }
1375
1376    fn variadic(&self) -> bool {
1377        true
1378    }
1379
1380    fn arg_schema(&self) -> &'static [ArgSchema] {
1381        &ARG_RANGE_NUM_LENIENT_ONE[..]
1382    }
1383
1384    fn eval<'a, 'b, 'c>(
1385        &self,
1386        args: &'c [ArgumentHandle<'a, 'b>],
1387        ctx: &dyn FunctionContext<'b>,
1388    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
1389        if args.len() < 3 {
1390            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1391                ExcelError::new_value(),
1392            )));
1393        }
1394
1395        let function_num = match parse_strict_int_arg(&args[0]) {
1396            Ok(v) => v,
1397            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1398        };
1399
1400        let op = if (1..=11).contains(&function_num) {
1401            aggregate_op_from_function_num(function_num)
1402                .expect("validated AGGREGATE function_num maps to operation")
1403        } else if (12..=19).contains(&function_num) {
1404            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1405                ExcelError::new(ExcelErrorKind::NImpl),
1406            )));
1407        } else {
1408            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1409                ExcelError::new_value(),
1410            )));
1411        };
1412
1413        let options = match parse_strict_int_arg(&args[1]) {
1414            Ok(v) => v,
1415            Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1416        };
1417
1418        let (visibility, error_policy) = match options {
1419            0 => (VisibilityPolicy::IncludeAll, ErrorPolicy::Propagate),
1420            1 => (
1421                VisibilityPolicy::ExcludeManualOrFilterHidden,
1422                ErrorPolicy::Propagate,
1423            ),
1424            2 => (VisibilityPolicy::IncludeAll, ErrorPolicy::Ignore),
1425            3 => (
1426                VisibilityPolicy::ExcludeManualOrFilterHidden,
1427                ErrorPolicy::Ignore,
1428            ),
1429            4..=7 => {
1430                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1431                    ExcelError::new(ExcelErrorKind::NImpl),
1432                )));
1433            }
1434            _ => {
1435                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
1436                    ExcelError::new_value(),
1437                )));
1438            }
1439        };
1440
1441        let collected =
1442            match AggregateCollector::collect_args(args, 2, ctx, op, visibility, error_policy) {
1443                Ok(c) => c,
1444                Err(e) => return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e))),
1445            };
1446
1447        Ok(crate::traits::CalcValue::Scalar(collected.finalize(op)))
1448    }
1449}
1450
#[cfg(test)]
mod tests_subtotal_aggregate {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_common::{ExcelErrorKind, LiteralValue};
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }

    /// Wraps a literal in an AST node so it can be passed as an argument.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Builds a one-row array literal from `cells`.
    fn row(cells: Vec<LiteralValue>) -> LiteralValue {
        LiteralValue::Array(vec![cells])
    }

    /// Resolves `fn_name` in `ctx` and evaluates it against `nodes`.
    fn dispatch(
        ctx: &crate::interpreter::Interpreter<'_>,
        fn_name: &str,
        nodes: &[ASTNode],
    ) -> LiteralValue {
        let handles: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, ctx))
            .collect();
        let function = ctx.context.get_function("", fn_name).expect("function");
        let fctx = ctx.function_context(None);
        function
            .dispatch(&handles, &fctx)
            .expect("dispatch")
            .into_literal()
    }

    /// Evaluates `AGGREGATE(function_num, options, values)`.
    fn aggregate(
        ctx: &crate::interpreter::Interpreter<'_>,
        function_num: i64,
        options: i64,
        values: LiteralValue,
    ) -> LiteralValue {
        let nodes = [
            lit(LiteralValue::Int(function_num)),
            lit(LiteralValue::Int(options)),
            lit(values),
        ];
        dispatch(ctx, "AGGREGATE", &nodes)
    }

    fn assert_num_close(value: LiteralValue, expected: f64) {
        let actual = match value {
            LiteralValue::Number(n) => n,
            LiteralValue::Int(i) => i as f64,
            other => panic!("expected numeric {expected}, got {other:?}"),
        };
        assert!((actual - expected).abs() < 1e-9, "{actual} != {expected}");
    }

    fn assert_error_kind(value: LiteralValue, expected: ExcelErrorKind) {
        match value {
            LiteralValue::Error(err) => assert_eq!(err.kind, expected),
            other => panic!("expected error {expected:?}, got {other:?}"),
        }
    }

    #[test]
    fn subtotal_function_num_mapping_basics() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);
        let values = row(vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]);

        // Expected result for function numbers 1 through 11, in order.
        let expected = [
            2.0,                   // AVERAGE
            3.0,                   // COUNT
            3.0,                   // COUNTA
            3.0,                   // MAX
            1.0,                   // MIN
            6.0,                   // PRODUCT
            1.0,                   // sample stdev
            (2.0f64 / 3.0).sqrt(), // population stdev
            6.0,                   // SUM
            1.0,                   // sample variance
            2.0 / 3.0,             // population variance
        ];

        for (code, want) in (1i64..).zip(expected) {
            let nodes = [lit(LiteralValue::Int(code)), lit(values.clone())];
            assert_num_close(dispatch(&ctx, "SUBTOTAL", &nodes), want);
        }
    }

    #[test]
    fn subtotal_counta_counts_errors_as_non_empty() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);
        let cells = row(vec![
            LiteralValue::Int(1),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Text("x".into()),
            LiteralValue::Text("".into()),
        ]);

        let out = dispatch(&ctx, "SUBTOTAL", &[lit(LiteralValue::Int(3)), lit(cells)]);
        assert_num_close(out, 4.0);
    }

    #[test]
    fn subtotal_invalid_function_num_returns_value_error() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);
        let nodes = [
            lit(LiteralValue::Number(9.5)),
            lit(row(vec![LiteralValue::Int(1)])),
        ];

        assert_error_kind(dispatch(&ctx, "SUBTOTAL", &nodes), ExcelErrorKind::Value);
    }

    #[test]
    fn subtotal_requires_ref_argument() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SubtotalFn));
        let ctx = interp(&wb);

        let out = dispatch(&ctx, "SUBTOTAL", &[lit(LiteralValue::Int(9))]);
        assert_error_kind(out, ExcelErrorKind::Value);
    }

    #[test]
    fn aggregate_requires_options_and_ref_argument() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);
        let nodes = [lit(LiteralValue::Int(9)), lit(LiteralValue::Int(0))];

        assert_error_kind(dispatch(&ctx, "AGGREGATE", &nodes), ExcelErrorKind::Value);
    }

    #[test]
    fn aggregate_options_zero_to_three_control_error_behavior() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);
        let values = row(vec![
            LiteralValue::Int(10),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Int(30),
        ]);

        // Options 0 and 1 propagate the #DIV/0!; options 2 and 3 skip it.
        for (options, propagates) in [(0, true), (1, true), (2, false), (3, false)] {
            let out = aggregate(&ctx, 9, options, values.clone());
            if propagates {
                assert_error_kind(out, ExcelErrorKind::Div);
            } else {
                assert_num_close(out, 40.0);
            }
        }
    }

    #[test]
    fn aggregate_counta_option_ignore_errors_skips_error_values() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);
        let cells = row(vec![
            LiteralValue::Int(1),
            LiteralValue::Error(ExcelError::new_div()),
            LiteralValue::Text("x".into()),
        ]);

        assert_num_close(aggregate(&ctx, 3, 2, cells), 2.0);
    }

    #[test]
    fn aggregate_unsupported_option_returns_nimpl() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        let out = aggregate(&ctx, 9, 4, row(vec![LiteralValue::Int(1)]));
        assert_error_kind(out, ExcelErrorKind::NImpl);
    }

    #[test]
    fn aggregate_unsupported_function_num_returns_nimpl() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AggregateFn));
        let ctx = interp(&wb);

        let out = aggregate(&ctx, 12, 0, row(vec![LiteralValue::Int(1)]));
        assert_error_kind(out, ExcelErrorKind::NImpl);
    }
}
1688
1689pub fn register_builtins() {
1690    crate::function_registry::register_function(std::sync::Arc::new(SumProductFn));
1691    crate::function_registry::register_function(std::sync::Arc::new(SumFn));
1692    crate::function_registry::register_function(std::sync::Arc::new(CountFn));
1693    crate::function_registry::register_function(std::sync::Arc::new(AverageFn));
1694    crate::function_registry::register_function(std::sync::Arc::new(SubtotalFn));
1695    crate::function_registry::register_function(std::sync::Arc::new(AggregateFn));
1696}