// formualizer_eval/builtins/math/aggregate.rs

1use super::super::utils::{ARG_RANGE_NUM_LENIENT_ONE, coerce_num};
2use crate::args::ArgSchema;
3use crate::function::{FnFoldCtx, Function};
4use crate::traits::{ArgumentHandle, FunctionContext};
5use formualizer_common::{ExcelError, LiteralValue};
6use formualizer_macros::func_caps;
7
8/* ─────────────────────────── SUM() ──────────────────────────── */
9
10#[derive(Debug)]
11pub struct SumFn;
12
13impl Function for SumFn {
14    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK, PARALLEL_ARGS);
15
16    fn name(&self) -> &'static str {
17        "SUM"
18    }
19    fn min_args(&self) -> usize {
20        0
21    }
22    fn variadic(&self) -> bool {
23        true
24    }
25    fn arg_schema(&self) -> &'static [ArgSchema] {
26        &ARG_RANGE_NUM_LENIENT_ONE[..]
27    }
28
29    fn eval_scalar<'a, 'b>(
30        &self,
31        args: &'a [ArgumentHandle<'a, 'b>],
32        _ctx: &dyn FunctionContext,
33    ) -> Result<LiteralValue, ExcelError> {
34        let mut total = 0.0;
35        for arg in args {
36            // Try to get a range/view first. If that fails, fall back to a single value.
37            if let Ok(view) = arg.range_view() {
38                view.for_each_cell(&mut |v| {
39                    match v {
40                        LiteralValue::Error(e) => return Err(e.clone()),
41                        _ => {
42                            if let Ok(n) = coerce_num(v) {
43                                total += n;
44                            }
45                        }
46                    }
47                    Ok(())
48                })?;
49            } else {
50                // Fallback for arguments that are not ranges but might be single values or errors.
51                match arg.value()?.as_ref() {
52                    LiteralValue::Error(e) => return Ok(LiteralValue::Error(e.clone())),
53                    v => total += coerce_num(v)?,
54                }
55            }
56        }
57        Ok(LiteralValue::Number(total))
58    }
59
60    fn eval_fold(&self, f: &mut dyn FnFoldCtx) -> Option<Result<LiteralValue, ExcelError>> {
61        let mut acc = 0.0f64;
62        // Stream numeric chunks using the fold context. Use a moderate default chunk size.
63        let mut cb = |chunk: crate::stripes::NumericChunk| -> Result<(), ExcelError> {
64            for &n in chunk.data {
65                acc += n;
66            }
67            Ok(())
68        };
69        if let Err(e) = f.for_each_numeric_chunk(4096, &mut cb) {
70            return Some(Ok(LiteralValue::Error(e)));
71        }
72        let out = LiteralValue::Number(acc);
73        f.write_result(out.clone());
74        Some(Ok(out))
75    }
76}
77
78/* ─────────────────────────── COUNT() ──────────────────────────── */
79
80#[derive(Debug)]
81pub struct CountFn;
82
83impl Function for CountFn {
84    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
85
86    fn name(&self) -> &'static str {
87        "COUNT"
88    }
89    fn min_args(&self) -> usize {
90        0
91    }
92    fn variadic(&self) -> bool {
93        true
94    }
95    fn arg_schema(&self) -> &'static [ArgSchema] {
96        &ARG_RANGE_NUM_LENIENT_ONE[..]
97    }
98
99    fn eval_scalar<'a, 'b>(
100        &self,
101        args: &'a [ArgumentHandle<'a, 'b>],
102        _ctx: &dyn FunctionContext,
103    ) -> Result<LiteralValue, ExcelError> {
104        let mut count: i64 = 0;
105        for arg in args {
106            if let Ok(view) = arg.range_view() {
107                view.for_each_cell(&mut |v| {
108                    if !matches!(v, LiteralValue::Empty) && coerce_num(v).is_ok() {
109                        count += 1;
110                    }
111                    Ok(())
112                })?;
113            } else {
114                match arg.value()?.as_ref() {
115                    LiteralValue::Error(e) => return Ok(LiteralValue::Error(e.clone())),
116                    v => {
117                        if !matches!(v, LiteralValue::Empty) && coerce_num(v).is_ok() {
118                            count += 1;
119                        }
120                    }
121                }
122            }
123        }
124        Ok(LiteralValue::Number(count as f64))
125    }
126
127    fn eval_fold(&self, f: &mut dyn FnFoldCtx) -> Option<Result<LiteralValue, ExcelError>> {
128        let mut cnt: i64 = 0;
129        let mut cb = |chunk: crate::stripes::NumericChunk| -> Result<(), ExcelError> {
130            // Empty cells are excluded at packing time; all values here are numerics
131            cnt += chunk.data.len() as i64;
132            Ok(())
133        };
134        if let Err(e) = f.for_each_numeric_chunk(4096, &mut cb) {
135            return Some(Ok(LiteralValue::Error(e)));
136        }
137        let out = LiteralValue::Number(cnt as f64);
138        f.write_result(out.clone());
139        Some(Ok(out))
140    }
141}
142
143/* ─────────────────────────── AVERAGE() ──────────────────────────── */
144
145#[derive(Debug)]
146pub struct AverageFn;
147
148impl Function for AverageFn {
149    func_caps!(PURE, REDUCTION, NUMERIC_ONLY, STREAM_OK);
150
151    fn name(&self) -> &'static str {
152        "AVERAGE"
153    }
154    fn min_args(&self) -> usize {
155        1
156    }
157    fn variadic(&self) -> bool {
158        true
159    }
160    fn arg_schema(&self) -> &'static [ArgSchema] {
161        &ARG_RANGE_NUM_LENIENT_ONE[..]
162    }
163
164    fn eval_scalar<'a, 'b>(
165        &self,
166        args: &'a [ArgumentHandle<'a, 'b>],
167        _ctx: &dyn FunctionContext,
168    ) -> Result<LiteralValue, ExcelError> {
169        let mut sum = 0.0f64;
170        let mut cnt: i64 = 0;
171        for arg in args {
172            if let Ok(view) = arg.range_view() {
173                view.for_each_cell(&mut |v| {
174                    if let Ok(n) = coerce_num(v) {
175                        sum += n;
176                        cnt += 1;
177                    }
178                    Ok(())
179                })?;
180            } else {
181                match arg.value()?.as_ref() {
182                    LiteralValue::Error(e) => return Ok(LiteralValue::Error(e.clone())),
183                    v => {
184                        if let Ok(n) = coerce_num(v) {
185                            sum += n;
186                            cnt += 1;
187                        }
188                    }
189                }
190            }
191        }
192        if cnt == 0 {
193            return Ok(LiteralValue::Error(ExcelError::from_error_string(
194                "#DIV/0!",
195            )));
196        }
197        Ok(LiteralValue::Number(sum / (cnt as f64)))
198    }
199
200    fn eval_fold(&self, f: &mut dyn FnFoldCtx) -> Option<Result<LiteralValue, ExcelError>> {
201        let mut sum = 0.0f64;
202        let mut cnt: i64 = 0;
203        let mut cb = |chunk: crate::stripes::NumericChunk| -> Result<(), ExcelError> {
204            for &n in chunk.data {
205                sum += n;
206                cnt += 1;
207            }
208            Ok(())
209        };
210        if let Err(e) = f.for_each_numeric_chunk(4096, &mut cb) {
211            return Some(Ok(LiteralValue::Error(e)));
212        }
213        if cnt == 0 {
214            let e = ExcelError::new_div();
215            f.write_result(LiteralValue::Error(e.clone()));
216            return Some(Ok(LiteralValue::Error(e)));
217        }
218        let out = LiteralValue::Number(sum / (cnt as f64));
219        f.write_result(out.clone());
220        Some(Ok(out))
221    }
222}
223
224/* ──────────────────────── SUMPRODUCT() ───────────────────────── */
225
226#[derive(Debug)]
227pub struct SumProductFn;
228
229impl Function for SumProductFn {
230    // Pure reduction over arrays; uses broadcasting and lenient coercion
231    func_caps!(PURE, REDUCTION);
232
233    fn name(&self) -> &'static str {
234        "SUMPRODUCT"
235    }
236    fn min_args(&self) -> usize {
237        1
238    }
239    fn variadic(&self) -> bool {
240        true
241    }
242    fn arg_schema(&self) -> &'static [ArgSchema] {
243        // Accept ranges or scalars; numeric lenient coercion
244        &ARG_RANGE_NUM_LENIENT_ONE[..]
245    }
246
247    fn eval_scalar<'a, 'b>(
248        &self,
249        args: &'a [ArgumentHandle<'a, 'b>],
250        _ctx: &dyn FunctionContext,
251    ) -> Result<LiteralValue, ExcelError> {
252        use crate::broadcast::{broadcast_shape, project_index};
253
254        if args.is_empty() {
255            return Ok(LiteralValue::Number(0.0));
256        }
257
258        // Helper: materialize an argument to a 2D array of LiteralValue
259        let to_array = |ah: &ArgumentHandle| -> Result<Vec<Vec<LiteralValue>>, ExcelError> {
260            if let Ok(rv) = ah.range_view() {
261                let mut rows: Vec<Vec<LiteralValue>> = Vec::new();
262                rv.for_each_row(&mut |row| {
263                    rows.push(row.to_vec());
264                    Ok(())
265                })?;
266                Ok(rows)
267            } else {
268                let v = ah.value()?;
269                Ok(match v.as_ref() {
270                    LiteralValue::Array(arr) => arr.clone(),
271                    other => vec![vec![other.clone()]],
272                })
273            }
274        };
275
276        // Collect arrays and shapes
277        let mut arrays: Vec<Vec<Vec<LiteralValue>>> = Vec::with_capacity(args.len());
278        let mut shapes: Vec<(usize, usize)> = Vec::with_capacity(args.len());
279        for a in args.iter() {
280            let arr = to_array(a)?;
281            let shape = (arr.len(), arr.first().map(|r| r.len()).unwrap_or(0));
282            arrays.push(arr);
283            shapes.push(shape);
284        }
285
286        // Compute broadcast target shape across all args
287        let target = match broadcast_shape(&shapes) {
288            Ok(s) => s,
289            Err(_) => {
290                return Ok(LiteralValue::Error(ExcelError::new_value()));
291            }
292        };
293
294        // Iterate target shape, multiply coerced values across args, sum total
295        let mut total = 0.0f64;
296        for r in 0..target.0 {
297            for c in 0..target.1 {
298                let mut prod = 1.0f64;
299                for (arr, &shape) in arrays.iter().zip(shapes.iter()) {
300                    let (rr, cc) = project_index((r, c), shape);
301                    let lv = arr
302                        .get(rr)
303                        .and_then(|row| row.get(cc))
304                        .cloned()
305                        .unwrap_or(LiteralValue::Empty);
306                    match lv {
307                        LiteralValue::Error(e) => return Ok(LiteralValue::Error(e)),
308                        _ => match super::super::utils::coerce_num(&lv) {
309                            Ok(n) => {
310                                prod *= n;
311                            }
312                            Err(_) => {
313                                // Non-numeric -> treated as 0 in SUMPRODUCT
314                                prod *= 0.0;
315                            }
316                        },
317                    }
318                }
319                total += prod;
320            }
321        }
322        Ok(LiteralValue::Number(total))
323    }
324}
325
#[cfg(test)]
mod tests_sumproduct {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wraps a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Wraps a 2-D array literal in an AST node.
    fn arr(rows: Vec<Vec<LiteralValue>>) -> ASTNode {
        lit(LiteralValue::Array(rows))
    }

    /// Wraps a number literal in an AST node.
    fn num(n: f64) -> ASTNode {
        lit(LiteralValue::Number(n))
    }

    /// Dispatches SUMPRODUCT over `nodes` in a workbook that registers only SUMPRODUCT,
    /// through the interpreter's function context, and unwraps the result.
    fn sumproduct(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumProductFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &interp))
            .collect();
        let f = interp.context.get_function("", "SUMPRODUCT").unwrap();
        let fctx = interp.function_context(None);
        f.dispatch(&args, &fctx).unwrap()
    }

    #[test]
    fn sumproduct_basic_pairwise() {
        // {1,2,3} * {4,5,6} = 1*4 + 2*5 + 3*6 = 32
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(4),
            LiteralValue::Int(5),
            LiteralValue::Int(6),
        ]]);
        assert_eq!(sumproduct(&[&a, &b]), LiteralValue::Number(32.0));
    }

    #[test]
    fn sumproduct_variadic_three_arrays() {
        // {1,2} * {3,4} * {2,2} = (1*3*2) + (2*4*2) = 6 + 16 = 22
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let b = arr(vec![vec![LiteralValue::Int(3), LiteralValue::Int(4)]]);
        let c = arr(vec![vec![LiteralValue::Int(2), LiteralValue::Int(2)]]);
        assert_eq!(sumproduct(&[&a, &b, &c]), LiteralValue::Number(22.0));
    }

    #[test]
    fn sumproduct_broadcast_scalar_over_array() {
        // {1,2,3} * 10 => (1*10 + 2*10 + 3*10) = 60
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let s = num(10.0);
        assert_eq!(sumproduct(&[&a, &s]), LiteralValue::Number(60.0));
    }

    #[test]
    fn sumproduct_2d_arrays_broadcast_rows_cols() {
        // A = [[1,2],[3,4]] (2x2), B = [[10,20]] (1x2) broadcast across rows:
        // 1*10 + 2*20 + 3*10 + 4*20 = 10 + 40 + 30 + 80 = 160
        let a = arr(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]);
        let b = arr(vec![vec![LiteralValue::Int(10), LiteralValue::Int(20)]]);
        assert_eq!(sumproduct(&[&a, &b]), LiteralValue::Number(160.0));
    }

    #[test]
    fn sumproduct_non_numeric_treated_as_zero() {
        // {1,"x",3} * {1,1,1} => 1*1 + 0*1 + 3*1 = 4
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(1),
            LiteralValue::Int(1),
        ]]);
        assert_eq!(sumproduct(&[&a, &b]), LiteralValue::Number(4.0));
    }

    #[test]
    fn sumproduct_error_in_input_propagates() {
        let a = arr(vec![vec![LiteralValue::Int(1), LiteralValue::Int(2)]]);
        let e = lit(LiteralValue::Error(ExcelError::new_na()));
        match sumproduct(&[&a, &e]) {
            LiteralValue::Error(err) => assert_eq!(err, "#N/A"),
            v => panic!("expected error, got {v:?}"),
        }
    }

    #[test]
    fn sumproduct_incompatible_shapes_value_error() {
        // 1x3 and 1x2 -> #VALUE!
        let a = arr(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]);
        let b = arr(vec![vec![LiteralValue::Int(4), LiteralValue::Int(5)]]);
        match sumproduct(&[&a, &b]) {
            LiteralValue::Error(e) => assert_eq!(e, "#VALUE!"),
            v => panic!("expected value error, got {v:?}"),
        }
    }
}
487
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::FnCaps;
    use crate::test_workbook::TestWorkbook;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    #[test]
    fn test_sum_caps() {
        let sum_fn = SumFn;
        let caps = sum_fn.caps();

        // Capabilities SUM must advertise.
        for required in [
            FnCaps::PURE,
            FnCaps::REDUCTION,
            FnCaps::NUMERIC_ONLY,
            FnCaps::STREAM_OK,
        ] {
            assert!(caps.contains(required));
        }

        // Capabilities SUM must not advertise.
        for forbidden in [FnCaps::VOLATILE, FnCaps::ELEMENTWISE] {
            assert!(!caps.contains(forbidden));
        }
    }

    #[test]
    fn test_sum_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumFn));
        let interp = wb.interpreter();
        let fctx = crate::traits::DefaultFunctionContext::new(interp.context, None);

        // SUM(1, 2, 3) built from hand-made literal nodes.
        let nodes: Vec<ASTNode> = [1.0, 2.0, 3.0]
            .iter()
            .map(|&n| ASTNode::new(ASTNodeType::Literal(LiteralValue::Number(n)), None))
            .collect();
        let args: Vec<_> = nodes
            .iter()
            .map(|node| ArgumentHandle::new(node, &interp))
            .collect();

        let sum = interp.context.get_function("", "SUM").unwrap();
        assert_eq!(sum.dispatch(&args, &fctx).unwrap(), LiteralValue::Number(6.0));
    }
}
545
#[cfg(test)]
mod tests_count {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Wraps a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Dispatches COUNT over `nodes` in a workbook that registers only COUNT, through a
    /// `DefaultFunctionContext`, and unwraps the result.
    fn count(nodes: &[&ASTNode]) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &interp))
            .collect();
        let f = interp.context.get_function("", "COUNT").unwrap();
        let fctx = crate::traits::DefaultFunctionContext::new(interp.context, None);
        f.dispatch(&args, &fctx).unwrap()
    }

    #[test]
    fn count_numbers_ignores_text() {
        // COUNT({1,2,"x",3}) => 3
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(3),
        ]]));
        assert_eq!(count(&[&node]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_multiple_args_and_scalars() {
        // Two from the array + the scalar 10 = 3; the text scalar is not counted.
        let array = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        let ten = lit(LiteralValue::Int(10));
        let text = lit(LiteralValue::Text("n".into()));
        assert_eq!(count(&[&array, &ten, &text]), LiteralValue::Number(3.0));
    }

    #[test]
    fn count_direct_error_argument_propagates() {
        let err = lit(LiteralValue::Error(ExcelError::from_error_string(
            "#DIV/0!",
        )));
        match count(&[&err]) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
615
#[cfg(test)]
mod tests_average {
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_parse::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};

    /// Which function-context constructor a test dispatches AVERAGE through.
    #[derive(Clone, Copy)]
    enum Via {
        /// `Interpreter::function_context(None)`
        InterpreterCtx,
        /// `DefaultFunctionContext::new(context, None)`
        DefaultCtx,
    }

    /// Wraps a literal value in an AST node.
    fn lit(value: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(value), None)
    }

    /// Dispatches AVERAGE over `nodes` in a workbook that registers only AVERAGE, through the
    /// context selected by `via`, and unwraps the result.
    fn average(nodes: &[&ASTNode], via: Via) -> LiteralValue {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AverageFn));
        let interp = wb.interpreter();
        let args: Vec<_> = nodes
            .iter()
            .map(|&node| ArgumentHandle::new(node, &interp))
            .collect();
        let f = interp.context.get_function("", "AVERAGE").unwrap();
        let outcome = match via {
            Via::InterpreterCtx => {
                let fctx = interp.function_context(None);
                f.dispatch(&args, &fctx)
            }
            Via::DefaultCtx => {
                let fctx = crate::traits::DefaultFunctionContext::new(interp.context, None);
                f.dispatch(&args, &fctx)
            }
        };
        outcome.unwrap()
    }

    #[test]
    fn average_basic_numbers() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Int(4),
            LiteralValue::Int(6),
        ]]));
        assert_eq!(
            average(&[&node], Via::InterpreterCtx),
            LiteralValue::Number(4.0)
        );
    }

    #[test]
    fn average_mixed_with_text() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Text("x".into()),
            LiteralValue::Int(6),
        ]]));
        // The text entry is skipped: average of 2 and 6 = 4
        assert_eq!(
            average(&[&node], Via::InterpreterCtx),
            LiteralValue::Number(4.0)
        );
    }

    #[test]
    fn average_no_numeric_div0() {
        let node = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Text("a".into()),
            LiteralValue::Text("b".into()),
        ]]));
        match average(&[&node], Via::DefaultCtx) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("expected #DIV/0!, got {v:?}"),
        }
    }

    #[test]
    fn average_direct_error_argument_propagates() {
        let err = lit(LiteralValue::Error(ExcelError::from_error_string(
            "#DIV/0!",
        )));
        match average(&[&err], Via::DefaultCtx) {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            v => panic!("unexpected {v:?}"),
        }
    }
}
703
704pub fn register_builtins() {
705    crate::function_registry::register_function(std::sync::Arc::new(SumProductFn));
706    crate::function_registry::register_function(std::sync::Arc::new(SumFn));
707    crate::function_registry::register_function(std::sync::Arc::new(CountFn));
708    crate::function_registry::register_function(std::sync::Arc::new(AverageFn));
709}