Skip to main content

formualizer_eval/builtins/math/
aggregate.rs

1use super::super::utils::{ARG_RANGE_NUM_LENIENT_ONE, coerce_num};
2use crate::args::ArgSchema;
3use crate::function::Function;
4use crate::traits::{ArgumentHandle, FunctionContext};
5use arrow_array::Array;
6use formualizer_common::{ExcelError, LiteralValue};
7use formualizer_macros::func_caps;
8
9/* ─────────────────────────── SUM() ──────────────────────────── */
10
11#[derive(Debug)]
12pub struct SumFn;
13
14impl Function for SumFn {
15    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS);
16
17    fn name(&self) -> &'static str {
18        "SUM"
19    }
20    fn min_args(&self) -> usize {
21        0
22    }
23    fn variadic(&self) -> bool {
24        true
25    }
26    fn arg_schema(&self) -> &'static [ArgSchema] {
27        &ARG_RANGE_NUM_LENIENT_ONE[..]
28    }
29
30    fn eval<'a, 'b, 'c>(
31        &self,
32        args: &'c [ArgumentHandle<'a, 'b>],
33        ctx: &dyn FunctionContext<'b>,
34    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
35        let mut total = 0.0;
36        for arg in args {
37            if let Ok(view) = arg.range_view() {
38                // Propagate errors from range first
39                for res in view.errors_slices() {
40                    let (_, _, err_cols) = res?;
41                    for col in err_cols {
42                        if col.null_count() < col.len() {
43                            for i in 0..col.len() {
44                                if !col.is_null(i) {
45                                    return Ok(crate::traits::CalcValue::Scalar(
46                                        LiteralValue::Error(ExcelError::new(
47                                            crate::arrow_store::unmap_error_code(col.value(i)),
48                                        )),
49                                    ));
50                                }
51                            }
52                        }
53                    }
54                }
55
56                for res in view.numbers_slices() {
57                    let (_, _, num_cols) = res?;
58                    for col in num_cols {
59                        total +=
60                            arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
61                    }
62                }
63            } else {
64                let v = arg.value()?.into_literal();
65                match v {
66                    LiteralValue::Error(e) => {
67                        return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
68                    }
69                    v => total += coerce_num(&v)?,
70                }
71            }
72        }
73        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
74            total,
75        )))
76    }
77}
78
79/* ─────────────────────────── COUNT() ──────────────────────────── */
80
81#[derive(Debug)]
82pub struct CountFn;
83
84impl Function for CountFn {
85    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
86
87    fn name(&self) -> &'static str {
88        "COUNT"
89    }
90    fn min_args(&self) -> usize {
91        0
92    }
93    fn variadic(&self) -> bool {
94        true
95    }
96    fn arg_schema(&self) -> &'static [ArgSchema] {
97        &ARG_RANGE_NUM_LENIENT_ONE[..]
98    }
99
100    fn eval<'a, 'b, 'c>(
101        &self,
102        args: &'c [ArgumentHandle<'a, 'b>],
103        _: &dyn FunctionContext<'b>,
104    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
105        let mut count: i64 = 0;
106        for arg in args {
107            if let Ok(view) = arg.range_view() {
108                for res in view.numbers_slices() {
109                    let (_, _, num_cols) = res?;
110                    for col in num_cols {
111                        count += (col.len() - col.null_count()) as i64;
112                    }
113                }
114            } else {
115                let v = arg.value()?.into_literal();
116                if let LiteralValue::Error(e) = v {
117                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
118                }
119                if !matches!(v, LiteralValue::Empty) && coerce_num(&v).is_ok() {
120                    count += 1;
121                }
122            }
123        }
124        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
125            count as f64,
126        )))
127    }
128}
129
130/* ─────────────────────────── AVERAGE() ──────────────────────────── */
131
132#[derive(Debug)]
133pub struct AverageFn;
134
135impl Function for AverageFn {
136    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
137
138    fn name(&self) -> &'static str {
139        "AVERAGE"
140    }
141    fn min_args(&self) -> usize {
142        1
143    }
144    fn variadic(&self) -> bool {
145        true
146    }
147    fn arg_schema(&self) -> &'static [ArgSchema] {
148        &ARG_RANGE_NUM_LENIENT_ONE[..]
149    }
150
151    fn eval<'a, 'b, 'c>(
152        &self,
153        args: &'c [ArgumentHandle<'a, 'b>],
154        ctx: &dyn FunctionContext<'b>,
155    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
156        let mut sum = 0.0f64;
157        let mut cnt: i64 = 0;
158        for arg in args {
159            if let Ok(view) = arg.range_view() {
160                // Propagate errors from range first
161                for res in view.errors_slices() {
162                    let (_, _, err_cols) = res?;
163                    for col in err_cols {
164                        if col.null_count() < col.len() {
165                            for i in 0..col.len() {
166                                if !col.is_null(i) {
167                                    return Ok(crate::traits::CalcValue::Scalar(
168                                        LiteralValue::Error(ExcelError::new(
169                                            crate::arrow_store::unmap_error_code(col.value(i)),
170                                        )),
171                                    ));
172                                }
173                            }
174                        }
175                    }
176                }
177
178                for res in view.numbers_slices() {
179                    let (_, _, num_cols) = res?;
180                    for col in num_cols {
181                        sum += arrow::compute::kernels::aggregate::sum(col.as_ref()).unwrap_or(0.0);
182                        cnt += (col.len() - col.null_count()) as i64;
183                    }
184                }
185            } else {
186                let v = arg.value()?.into_literal();
187                if let LiteralValue::Error(e) = v {
188                    return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
189                }
190                if let Ok(n) = crate::coercion::to_number_lenient_with_locale(&v, &ctx.locale()) {
191                    sum += n;
192                    cnt += 1;
193                }
194            }
195        }
196        if cnt == 0 {
197            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
198                ExcelError::new_div(),
199            )));
200        }
201        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
202            sum / (cnt as f64),
203        )))
204    }
205}
206
207/* ──────────────────────── SUMPRODUCT() ───────────────────────── */
208
209#[derive(Debug)]
210pub struct SumProductFn;
211
212impl Function for SumProductFn {
213    // Pure reduction over arrays; uses broadcasting and lenient coercion
214    func_caps!(PURE, REDUCTION);
215
216    fn name(&self) -> &'static str {
217        "SUMPRODUCT"
218    }
219    fn min_args(&self) -> usize {
220        1
221    }
222    fn variadic(&self) -> bool {
223        true
224    }
225    fn arg_schema(&self) -> &'static [ArgSchema] {
226        // Accept ranges or scalars; numeric lenient coercion
227        &ARG_RANGE_NUM_LENIENT_ONE[..]
228    }
229
230    fn eval<'a, 'b, 'c>(
231        &self,
232        args: &'c [ArgumentHandle<'a, 'b>],
233        _: &dyn FunctionContext<'b>,
234    ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
235        use crate::broadcast::{broadcast_shape, project_index};
236
237        if args.is_empty() {
238            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(0.0)));
239        }
240
241        // Helper: materialize an argument to a 2D array of LiteralValue
242        let to_array = |ah: &ArgumentHandle| -> Result<Vec<Vec<LiteralValue>>, ExcelError> {
243            if let Ok(rv) = ah.range_view() {
244                let mut rows: Vec<Vec<LiteralValue>> = Vec::new();
245                rv.for_each_row(&mut |row| {
246                    rows.push(row.to_vec());
247                    Ok(())
248                })?;
249                Ok(rows)
250            } else {
251                let v = ah.value()?.into_literal();
252                Ok(match v {
253                    LiteralValue::Array(arr) => arr,
254                    other => vec![vec![other]],
255                })
256            }
257        };
258
259        // Collect arrays and shapes
260        let mut arrays: Vec<Vec<Vec<LiteralValue>>> = Vec::with_capacity(args.len());
261        let mut shapes: Vec<(usize, usize)> = Vec::with_capacity(args.len());
262        for a in args.iter() {
263            let arr = to_array(a)?;
264            let shape = (arr.len(), arr.first().map(|r| r.len()).unwrap_or(0));
265            arrays.push(arr);
266            shapes.push(shape);
267        }
268
269        // Compute broadcast target shape across all args
270        let target = match broadcast_shape(&shapes) {
271            Ok(s) => s,
272            Err(_) => {
273                return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
274                    ExcelError::new_value(),
275                )));
276            }
277        };
278
279        // Iterate target shape, multiply coerced values across args, sum total
280        let mut total = 0.0f64;
281        for r in 0..target.0 {
282            for c in 0..target.1 {
283                let mut prod = 1.0f64;
284                for (arr, &shape) in arrays.iter().zip(shapes.iter()) {
285                    let (rr, cc) = project_index((r, c), shape);
286                    let lv = arr
287                        .get(rr)
288                        .and_then(|row| row.get(cc))
289                        .cloned()
290                        .unwrap_or(LiteralValue::Empty);
291                    match lv {
292                        LiteralValue::Error(e) => {
293                            return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(e)));
294                        }
295                        _ => match super::super::utils::coerce_num(&lv) {
296                            Ok(n) => {
297                                prod *= n;
298                            }
299                            Err(_) => {
300                                // Non-numeric -> treated as 0 in SUMPRODUCT
301                                prod *= 0.0;
302                            }
303                        },
304                    }
305                }
306                total += prod;
307            }
308        }
309        Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
310            total,
311        )))
312    }
313}
314
#[cfg(test)]
mod tests_sumproduct {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wrap a 2-D grid of values in an array-literal AST node.
    fn arr(vals: Vec<Vec<LiteralValue>>) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Array(vals)), None)
    }

    /// Wrap a single number in a literal AST node.
    fn num(n: f64) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None)
    }

    /// Evaluate the registered SUMPRODUCT over `nodes` (in order) and return the result.
    fn run_sumproduct(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumProductFn));
        let interpreter = wb.interpreter();
        let handles: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &interpreter))
            .collect();
        let sumproduct = interpreter.context.get_function("", "SUMPRODUCT").unwrap();
        let fctx = interpreter.function_context(None);
        let outcome = sumproduct.dispatch(&handles, &fctx).unwrap().into_literal();
        outcome
    }

    #[test]
    fn sumproduct_basic_pairwise() {
        // {1,2,3} * {4,5,6} = 1*4 + 2*5 + 3*6 = 32
        let left = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let right = arr(vec![vec![
            LiteralValue::Int(4),
            LiteralValue::Int(5),
            LiteralValue::Int(6),
        ]]);
        assert_eq!(run_sumproduct(&[&left, &right]), LiteralValue::Number(32.0));
    }

    #[test]
    fn sumproduct_variadic_three_arrays() {
        // {1,2} * {3,4} * {2,2} = (1*3*2) + (2*4*2) = 6 + 16 = 22
        let first = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let second = arr(vec![vec![LiteralValue::Int(3), LiteralValue::Int(4)]]);
        let third = arr(vec![vec![LiteralValue::Int(2), LiteralValue::Int(2)]]);
        assert_eq!(
            run_sumproduct(&[&first, &second, &third]),
            LiteralValue::Number(22.0)
        );
    }

    #[test]
    fn sumproduct_broadcast_scalar_over_array() {
        // {1,2,3} * 10 => (1*10 + 2*10 + 3*10) = 60
        let row = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let scalar = num(10.0);
        assert_eq!(run_sumproduct(&[&row, &scalar]), LiteralValue::Number(60.0));
    }

    #[test]
    fn sumproduct_2d_arrays_broadcast_rows_cols() {
        // A = [[1,2],[3,4]] (2x2), B = [[10,20]] (1x2) broadcast across rows:
        // 1*10 + 2*20 + 3*10 + 4*20 = 10 + 40 + 30 + 80 = 160
        let grid = arr(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]);
        let row = arr(vec![vec![LiteralValue::Int(10), LiteralValue::Int(20)]]);
        assert_eq!(run_sumproduct(&[&grid, &row]), LiteralValue::Number(160.0));
    }

    #[test]
    fn sumproduct_non_numeric_treated_as_zero() {
        // {1,"x",3} * {1,1,1} => 1*1 + 0*1 + 3*1 = 4
        let mixed = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]);
        let ones = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(1),
            LiteralValue::Int(1),
        ]]);
        assert_eq!(run_sumproduct(&[&mixed, &ones]), LiteralValue::Number(4.0));
    }

    #[test]
    fn sumproduct_error_in_input_propagates() {
        let values = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let na = ASTNode::new(
            ASTNodeType::Literal(LiteralValue::Error(ExcelError::new_na())),
            None,
        );
        match run_sumproduct(&[&values, &na]) {
            LiteralValue::Error(err) => assert_eq!(err, "#N/A"),
            v => panic!("expected error, got {v:?}"),
        }
    }

    #[test]
    fn sumproduct_incompatible_shapes_value_error() {
        // 1x3 and 1x2 -> #VALUE!
        let three = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let two = arr(vec![vec![LiteralValue::Int(4), LiteralValue::Int(5)]]);
        match run_sumproduct(&[&three, &two]) {
            LiteralValue::Error(e) => assert_eq!(e, "#VALUE!"),
            v => panic!("expected value error, got {v:?}"),
        }
    }
}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::FnCaps;
    use crate::test_workbook::TestWorkbook;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    #[test]
    fn test_sum_caps() {
        let caps = SumFn.caps();

        // Capabilities SUM is expected to advertise.
        for required in [
            FnCaps::PURE,
            FnCaps::REDUCTION,
            FnCaps::NUMERIC_ONLY,
            FnCaps::STREAM_OK,
        ] {
            assert!(caps.contains(required));
        }

        // Capabilities SUM must not advertise.
        for forbidden in [FnCaps::VOLATILE, FnCaps::ELEMENTWISE] {
            assert!(!caps.contains(forbidden));
        }
    }

    #[test]
    fn test_sum_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumFn));
        let interpreter = wb.interpreter();
        let fctx = interpreter.function_context(None);

        // SUM(1, 2, 3) built from literal argument nodes.
        let nodes: Vec<ASTNode> = [1.0, 2.0, 3.0]
            .iter()
            .map(|&n| ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None))
            .collect();
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &interpreter))
            .collect();

        let sum_fn = interpreter.context.get_function("", "SUM").unwrap();
        let result = sum_fn.dispatch(&args, &fctx).unwrap().into_literal();
        assert_eq!(result, LiteralValue::Number(6.0));
    }
}
552
#[cfg(test)]
mod tests_count {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wrap a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Evaluate the registered COUNT over `nodes` (in order) and return the result.
    fn run_count(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountFn));
        let interpreter = wb.interpreter();
        let handles: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &interpreter))
            .collect();
        let count = interpreter.context.get_function("", "COUNT").unwrap();
        let fctx = interpreter.function_context(None);
        let outcome = count.dispatch(&handles, &fctx).unwrap().into_literal();
        outcome
    }

    #[test]
    fn count_numbers_ignores_text() {
        // COUNT({1,2,"x",3}) => 3
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]));
        assert_eq!(run_count(&[&node]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_multiple_args_and_scalars() {
        let a = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        let n1 = lit(LiteralValue::Int(10));
        let n2 = lit(LiteralValue::Text("n".into()));
        // Two from array + scalar 10 = 3
        assert_eq!(run_count(&[&a, &n1, &n2]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_direct_error_argument_propagates() {
        let err = lit(LiteralValue::Error(ExcelError::from_error_string("#DIV/0!")));
        match run_count(&[&err]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
628
#[cfg(test)]
mod tests_average {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wrap a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Evaluate the registered AVERAGE over `nodes` (in order) and return the result.
    fn run_average(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AverageFn));
        let interpreter = wb.interpreter();
        let handles: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &interpreter))
            .collect();
        let average = interpreter.context.get_function("", "AVERAGE").unwrap();
        let fctx = interpreter.function_context(None);
        let outcome = average.dispatch(&handles, &fctx).unwrap().into_literal();
        outcome
    }

    #[test]
    fn average_basic_numbers() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Int(4),
            LiteralValue::Int(6),
        ]]));
        assert_eq!(run_average(&[&node]), LiteralValue::Number(4.0));
    }

    #[test]
    fn average_mixed_with_text() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(6),
        ]]));
        // average of 2 and 6 = 4
        assert_eq!(run_average(&[&node]), LiteralValue::Number(4.0));
    }

    #[test]
    fn average_no_numeric_div0() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Text("a".into()),
            LiteralValue::Text("b".into()),
        ]]));
        match run_average(&[&node]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("expected #DIV/0!, got {v:?}"),
        }
    }

    #[test]
    fn average_direct_error_argument_propagates() {
        let err = lit(LiteralValue::Error(ExcelError::from_error_string("#DIV/0!")));
        match run_average(&[&err]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
720
721pub fn register_builtins() {
722    crate::function_registry::register_function(std::sync::Arc::new(SumProductFn));
723    crate::function_registry::register_function(std::sync::Arc::new(SumFn));
724    crate::function_registry::register_function(std::sync::Arc::new(CountFn));
725    crate::function_registry::register_function(std::sync::Arc::new(AverageFn));
726}