1use super::super::utils::{ARG_ANY_ONE, coerce_num, criteria_match};
2use crate::args::ArgSchema;
3use crate::compute_prelude::{boolean, cmp, filter_array};
4use crate::function::Function;
5use crate::traits::{ArgumentHandle, FunctionContext};
6use arrow::compute::kernels::aggregate::sum_array;
7use arrow_array::types::Float64Type;
8use arrow_array::{Array as _, BooleanArray, Float64Array};
9use formualizer_common::{ExcelError, LiteralValue};
10use formualizer_macros::func_caps;
11
#[cfg(test)]
pub(crate) mod test_hooks {
    //! Test-only instrumentation: per-thread counters recording which branch of the
    //! cached-criteria-mask windowing logic in `eval_if_family` was taken
    //! (direct slice, partial slice + padding, or padding only).
    use std::cell::Cell;
    use std::thread::LocalKey;

    thread_local! {
        static CACHED_MASK_SLICE_FAST: Cell<usize> = const { Cell::new(0) };
        static CACHED_MASK_PAD_PARTIAL: Cell<usize> = const { Cell::new(0) };
        static CACHED_MASK_PAD_ALL_FILL: Cell<usize> = const { Cell::new(0) };
    }

    /// Reads the current value of one counter on this thread.
    fn read(key: &'static LocalKey<Cell<usize>>) -> usize {
        key.with(Cell::get)
    }

    /// Adds one to a counter on this thread.
    fn bump(key: &'static LocalKey<Cell<usize>>) {
        key.with(|cell| cell.set(cell.get() + 1));
    }

    /// Zeroes all three counters on the calling thread.
    pub fn reset_cached_mask_counters() {
        let keys = [
            &CACHED_MASK_SLICE_FAST,
            &CACHED_MASK_PAD_PARTIAL,
            &CACHED_MASK_PAD_ALL_FILL,
        ];
        for key in keys {
            key.with(|cell| cell.set(0));
        }
    }

    /// Returns `(slice_fast, pad_partial, pad_all_fill)` for the calling thread.
    pub fn cached_mask_counters() -> (usize, usize, usize) {
        (
            read(&CACHED_MASK_SLICE_FAST),
            read(&CACHED_MASK_PAD_PARTIAL),
            read(&CACHED_MASK_PAD_ALL_FILL),
        )
    }

    /// Records that a cached mask fully covered the requested row window.
    pub(crate) fn inc_slice_fast() {
        bump(&CACHED_MASK_SLICE_FAST);
    }

    /// Records that a cached mask covered only part of the window (tail padded).
    pub(crate) fn inc_pad_partial() {
        bump(&CACHED_MASK_PAD_PARTIAL);
    }

    /// Records that the window lay entirely past the cached mask (all padding).
    pub(crate) fn inc_pad_all_fill() {
        bump(&CACHED_MASK_PAD_ALL_FILL);
    }
}
45
/// Which reduction `eval_if_family` applies to the cells that pass every criterion.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AggregationType {
    /// Add up the matching target values (SUMIF / SUMIFS).
    Sum,
    /// Count the matching cells (COUNTIF / COUNTIFS).
    Count,
    /// Mean of the matching numeric target values (AVERAGEIF / AVERAGEIFS).
    Average,
}
68
69fn eval_if_family<'a, 'b>(
70 args: &[ArgumentHandle<'a, 'b>],
71 ctx: &dyn FunctionContext<'b>,
72 agg_type: AggregationType,
73 multi: bool,
74) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
75 let mut sum_view: Option<crate::engine::range_view::RangeView<'_>> = None;
76 let mut sum_scalar: Option<LiteralValue> = None;
77 let mut crit_specs = Vec::new();
78
79 if !multi {
80 if args.len() < 2 || args.len() > 3 {
82 return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
83 ExcelError::new_value().with_message(format!(
84 "Function expects 2 or 3 arguments, got {}",
85 args.len()
86 )),
87 )));
88 }
89 let pred = crate::args::parse_criteria(&args[1].value()?.into_literal())?;
90 let crit_rv = args[0].range_view().ok();
91 let crit_val = if crit_rv.is_none() {
92 Some(args[0].value()?.into_literal())
93 } else {
94 None
95 };
96 crit_specs.push((crit_rv, pred, crit_val));
97
98 if agg_type != AggregationType::Count {
99 if args.len() == 3 {
100 if let Ok(v) = args[2].range_view() {
101 let crit_dims = crit_specs[0].0.as_ref().map(|v| v.dims()).unwrap_or((1, 1));
102 sum_view = Some(v.expand_to(crit_dims.0, crit_dims.1));
103 } else {
104 sum_scalar = Some(args[2].value()?.into_literal());
105 }
106 } else {
107 if let Ok(v) = args[0].range_view() {
109 sum_view = Some(v);
110 } else {
111 sum_scalar = Some(args[0].value()?.into_literal());
112 }
113 }
114 }
115 } else {
116 if agg_type == AggregationType::Count {
118 if args.len() < 2 || !args.len().is_multiple_of(2) {
119 return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
120 ExcelError::new_value().with_message(format!(
121 "COUNTIFS expects N pairs (criteria_range, criteria); got {} args",
122 args.len()
123 )),
124 )));
125 }
126 for i in (0..args.len()).step_by(2) {
127 let mut rv = args[i].range_view().ok();
128 let mut val: Option<LiteralValue> = None;
129
130 if let Some(ref view) = rv {
132 let (r, c) = view.dims();
133 if r == 1 && c == 1 {
134 val = Some(view.as_1x1().unwrap_or(LiteralValue::Empty));
135 rv = None;
136 }
137 }
138
139 if val.is_none() && rv.is_none() {
140 val = Some(args[i].value()?.into_literal());
141 }
142
143 let pred = crate::args::parse_criteria(&args[i + 1].value()?.into_literal())?;
144 crit_specs.push((rv, pred, val));
145 }
146 } else {
147 if args.len() < 3 || !(args.len() - 1).is_multiple_of(2) {
148 return Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
149 ExcelError::new_value().with_message(format!(
150 "Function expects 1 target_range followed by N pairs (criteria_range, criteria); got {} args",
151 args.len()
152 )),
153 )));
154 }
155 if let Ok(v) = args[0].range_view() {
156 sum_view = Some(v);
157 } else {
158 sum_scalar = Some(args[0].value()?.into_literal());
159 }
160 for i in (1..args.len()).step_by(2) {
161 let mut rv = args[i].range_view().ok();
162 let mut val: Option<LiteralValue> = None;
163
164 if let Some(ref view) = rv {
166 let (r, c) = view.dims();
167 if r == 1 && c == 1 {
168 val = Some(view.as_1x1().unwrap_or(LiteralValue::Empty));
169 rv = None;
170 }
171 }
172
173 if val.is_none() && rv.is_none() {
174 val = Some(args[i].value()?.into_literal());
175 }
176
177 let pred = crate::args::parse_criteria(&args[i + 1].value()?.into_literal())?;
178 crit_specs.push((rv, pred, val));
179 }
180 }
181 }
182
183 let mut dims = (1usize, 1usize);
185 if let Some(ref sv) = sum_view {
186 dims = sv.dims();
187 }
188 for (rv, _, _) in &crit_specs {
189 if let Some(v) = rv {
190 let vd = v.dims();
191 dims.0 = dims.0.max(vd.0);
192 dims.1 = dims.1.max(vd.1);
193 }
194 }
195
196 let mut total_sum = 0.0f64;
201 let mut total_count = 0i64;
202
203 let driver = sum_view
205 .as_ref()
206 .or_else(|| crit_specs.iter().find_map(|(rv, _, _)| rv.as_ref()));
207
208 if let Some(drv) = driver {
209 let driver = if !multi && crit_specs[0].0.is_some() {
214 crit_specs[0].0.as_ref().unwrap()
215 } else {
216 drv
217 };
218
219 for res in driver.iter_row_chunks() {
220 let cs = res?;
221 let row_start = cs.row_start;
222 let row_len = cs.row_len;
223 if row_len == 0 {
224 continue;
225 }
226
227 let mut crit_num_slices = Vec::with_capacity(crit_specs.len());
229 let mut crit_text_slices = Vec::with_capacity(crit_specs.len());
230 for (rv, _, _) in &crit_specs {
231 if let Some(v) = rv {
232 crit_num_slices.push(Some(v.slice_numbers(row_start, row_len)));
233 crit_text_slices.push(Some(v.slice_lowered_text(row_start, row_len)));
234 } else {
235 crit_num_slices.push(None);
236 crit_text_slices.push(None);
237 }
238 }
239
240 let sum_slices = sum_view
241 .as_ref()
242 .map(|v| v.slice_numbers(row_start, row_len));
243
244 for c in 0..dims.1 {
245 let mut mask_opt: Option<BooleanArray> = None;
246 let mut impossible = false;
247
248 for (j, (_, pred, scalar_val)) in crit_specs.iter().enumerate() {
249 if crit_specs[j].0.is_none() {
250 if let Some(sv) = scalar_val {
251 if !criteria_match(pred, sv) {
252 impossible = true;
253 break;
254 }
255 continue;
256 }
257 if !criteria_match(pred, &LiteralValue::Empty) {
258 impossible = true;
259 break;
260 }
261 continue;
262 }
263
264 let cur_cached = if let Some(ref view) = crit_specs[j].0 {
266 ctx.get_criteria_mask(view, c, pred).map(|m| {
267 let fill = criteria_match(pred, &LiteralValue::Empty);
268 let m_len = m.len();
269
270 if row_start + row_len <= m_len {
274 #[cfg(test)]
275 test_hooks::inc_slice_fast();
276 let sl = m.slice(row_start, row_len);
277 return sl
278 .as_any()
279 .downcast_ref::<arrow_array::BooleanArray>()
280 .expect("cached criteria mask slice downcast")
281 .clone();
282 }
283
284 let mut bb =
285 arrow_array::builder::BooleanBuilder::with_capacity(row_len);
286 if row_start < m_len {
287 #[cfg(test)]
288 test_hooks::inc_pad_partial();
289 let take_len = row_len.min(m_len - row_start);
290 let sl = m.slice(row_start, take_len);
291 let ba = sl
292 .as_any()
293 .downcast_ref::<arrow_array::BooleanArray>()
294 .expect("cached criteria mask slice downcast");
295 bb.append_array(ba);
296 bb.append_n(row_len - take_len, fill);
297 } else {
298 #[cfg(test)]
299 test_hooks::inc_pad_all_fill();
300 bb.append_n(row_len, fill);
301 }
302
303 bb.finish()
304 })
305 } else {
306 None
307 };
308
309 if let Some(cm) = cur_cached {
310 mask_opt = Some(match mask_opt {
311 None => cm,
312 Some(prev) => boolean::and_kleene(&prev, &cm).unwrap(),
313 });
314 continue;
315 }
316
317 let num_col = crit_num_slices[j]
319 .as_ref()
320 .and_then(|cols| cols.get(c).and_then(|a| a.as_ref()));
321 let text_col = crit_text_slices[j]
322 .as_ref()
323 .and_then(|cols| cols.get(c).and_then(|a| a.as_ref()));
324
325 let m = match (pred, num_col, text_col) {
326 (crate::args::CriteriaPredicate::Gt(n), Some(nc), _) => {
327 cmp::gt(nc.as_ref(), &Float64Array::new_scalar(*n)).unwrap()
328 }
329 (crate::args::CriteriaPredicate::Ge(n), Some(nc), _) => {
330 cmp::gt_eq(nc.as_ref(), &Float64Array::new_scalar(*n)).unwrap()
331 }
332 (crate::args::CriteriaPredicate::Lt(n), Some(nc), _) => {
333 cmp::lt(nc.as_ref(), &Float64Array::new_scalar(*n)).unwrap()
334 }
335 (crate::args::CriteriaPredicate::Le(n), Some(nc), _) => {
336 cmp::lt_eq(nc.as_ref(), &Float64Array::new_scalar(*n)).unwrap()
337 }
338 (crate::args::CriteriaPredicate::Eq(v), nc, tc) => {
339 match v {
340 LiteralValue::Number(x) => {
341 let nx = *x;
342 if let Some(nc) = nc {
343 cmp::eq(nc.as_ref(), &Float64Array::new_scalar(nx)).unwrap()
344 } else {
345 BooleanArray::new_null(row_len)
346 }
347 }
348 LiteralValue::Int(x) => {
349 let nx = *x as f64;
350 if let Some(nc) = nc {
351 cmp::eq(nc.as_ref(), &Float64Array::new_scalar(nx)).unwrap()
352 } else {
353 BooleanArray::new_null(row_len)
354 }
355 }
356 _ => {
357 let mut bb =
359 arrow_array::builder::BooleanBuilder::with_capacity(
360 row_len,
361 );
362 let view = crit_specs[j].0.as_ref().unwrap();
363 for i in 0..row_len {
364 bb.append_value(criteria_match(
365 pred,
366 &view.get_cell(row_start + i, c),
367 ));
368 }
369 bb.finish()
370 }
371 }
372 }
373 (crate::args::CriteriaPredicate::Ne(v), nc, tc) => match v {
374 LiteralValue::Number(x) => {
375 let nx = *x;
376 if let Some(nc) = nc {
377 cmp::neq(nc.as_ref(), &Float64Array::new_scalar(nx)).unwrap()
378 } else {
379 BooleanArray::from(vec![true; row_len])
380 }
381 }
382 LiteralValue::Int(x) => {
383 let nx = *x as f64;
384 if let Some(nc) = nc {
385 cmp::neq(nc.as_ref(), &Float64Array::new_scalar(nx)).unwrap()
386 } else {
387 BooleanArray::from(vec![true; row_len])
388 }
389 }
390 _ => {
391 let mut bb =
392 arrow_array::builder::BooleanBuilder::with_capacity(row_len);
393 let view = crit_specs[j].0.as_ref().unwrap();
394 for i in 0..row_len {
395 bb.append_value(criteria_match(
396 pred,
397 &view.get_cell(row_start + i, c),
398 ));
399 }
400 bb.finish()
401 }
402 },
403 (crate::args::CriteriaPredicate::TextLike { .. }, _, _) => {
404 let mut bb =
405 arrow_array::builder::BooleanBuilder::with_capacity(row_len);
406 let view = crit_specs[j].0.as_ref().unwrap();
407 for i in 0..row_len {
408 bb.append_value(criteria_match(
409 pred,
410 &view.get_cell(row_start + i, c),
411 ));
412 }
413 bb.finish()
414 }
415 _ => {
416 let mut bb =
418 arrow_array::builder::BooleanBuilder::with_capacity(row_len);
419 if let Some(ref view) = crit_specs[j].0 {
420 for i in 0..row_len {
421 bb.append_value(criteria_match(
422 pred,
423 &view.get_cell(row_start + i, c),
424 ));
425 }
426 } else {
427 let val = scalar_val.as_ref().unwrap_or(&LiteralValue::Empty);
428 let matches = criteria_match(pred, val);
429 for _ in 0..row_len {
430 bb.append_value(matches);
431 }
432 }
433 bb.finish()
434 }
435 };
436
437 mask_opt = Some(match mask_opt {
438 None => m,
439 Some(prev) => boolean::and_kleene(&prev, &m).unwrap(),
440 });
441 }
442
443 if impossible {
444 continue;
445 }
446
447 match mask_opt {
448 Some(mask) => {
449 if agg_type == AggregationType::Count {
450 total_count += (0..mask.len())
451 .filter(|&i| mask.is_valid(i) && mask.value(i))
452 .count() as i64;
453 } else {
454 let target_col = sum_slices
455 .as_ref()
456 .and_then(|cols| cols.get(c).and_then(|a| a.as_ref()));
457 if let Some(tc) = target_col {
458 let filtered = filter_array(tc.as_ref(), &mask).unwrap();
459 let f64_arr =
460 filtered.as_any().downcast_ref::<Float64Array>().unwrap();
461 if let Some(s) = sum_array::<Float64Type, _>(f64_arr) {
462 total_sum += s;
463 }
464 total_count += f64_arr.len() as i64 - f64_arr.null_count() as i64;
465 } else if let Some(ref s) = sum_scalar
466 && let Ok(n) = coerce_num(s)
467 {
468 let count = (0..mask.len())
469 .filter(|&i| mask.is_valid(i) && mask.value(i))
470 .count() as i64;
471 total_sum += n * count as f64;
472 total_count += count;
473 }
474 }
475 }
476 None => {
477 if agg_type == AggregationType::Count {
479 total_count += row_len as i64;
480 } else {
481 let target_col = sum_slices
482 .as_ref()
483 .and_then(|cols| cols.get(c).and_then(|a| a.as_ref()));
484 if let Some(tc) = target_col {
485 if let Some(s) = sum_array::<Float64Type, _>(tc.as_ref()) {
486 total_sum += s;
487 }
488 total_count += tc.len() as i64 - tc.null_count() as i64;
489 } else if let Some(ref s) = sum_scalar
490 && let Ok(n) = coerce_num(s)
491 {
492 total_sum += n * row_len as f64;
493 total_count += row_len as i64;
494 }
495 }
496 }
497 }
498 }
499 }
500 } else {
501 let mut all_match = true;
503 for (_, pred, scalar_val) in &crit_specs {
504 let val = scalar_val.as_ref().unwrap_or(&LiteralValue::Empty);
505 if !criteria_match(pred, val) {
506 all_match = false;
507 break;
508 }
509 }
510 if all_match {
511 if agg_type == AggregationType::Count {
512 total_count = (dims.0 * dims.1) as i64;
513 } else if let Some(ref s) = sum_scalar
514 && let Ok(n) = coerce_num(s)
515 {
516 total_sum = n * (dims.0 * dims.1) as f64;
517 total_count = (dims.0 * dims.1) as i64;
518 }
519 }
520 }
521
522 match agg_type {
523 AggregationType::Sum => Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
524 total_sum,
525 ))),
526 AggregationType::Count => Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
527 total_count as f64,
528 ))),
529 AggregationType::Average => {
530 if total_count == 0 {
531 Ok(crate::traits::CalcValue::Scalar(LiteralValue::Error(
532 ExcelError::new_div(),
533 )))
534 } else {
535 Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
536 total_sum / total_count as f64,
537 )))
538 }
539 }
540 }
541}
542
543#[derive(Debug)]
545pub struct AverageIfFn;
546impl Function for AverageIfFn {
547 func_caps!(
548 PURE,
549 REDUCTION,
550 WINDOWED,
551 STREAM_OK,
552 PARALLEL_ARGS,
553 PARALLEL_CHUNKS
554 );
555 fn name(&self) -> &'static str {
556 "AVERAGEIF"
557 }
558 fn min_args(&self) -> usize {
559 2
560 }
561 fn variadic(&self) -> bool {
562 true
563 }
564 fn arg_schema(&self) -> &'static [ArgSchema] {
565 &ARG_ANY_ONE[..]
566 }
567 fn eval<'a, 'b, 'c>(
568 &self,
569 args: &'c [ArgumentHandle<'a, 'b>],
570 ctx: &dyn FunctionContext<'b>,
571 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
572 eval_if_family(args, ctx, AggregationType::Average, false)
573 }
574}
575
576#[derive(Debug)]
578pub struct SumIfFn;
579impl Function for SumIfFn {
580 func_caps!(
581 PURE,
582 REDUCTION,
583 WINDOWED,
584 STREAM_OK,
585 PARALLEL_ARGS,
586 PARALLEL_CHUNKS
587 );
588 fn name(&self) -> &'static str {
589 "SUMIF"
590 }
591 fn min_args(&self) -> usize {
592 2
593 }
594 fn variadic(&self) -> bool {
595 true
596 }
597 fn arg_schema(&self) -> &'static [ArgSchema] {
598 &ARG_ANY_ONE[..]
599 }
600 fn eval<'a, 'b, 'c>(
601 &self,
602 args: &'c [ArgumentHandle<'a, 'b>],
603 ctx: &dyn FunctionContext<'b>,
604 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
605 eval_if_family(args, ctx, AggregationType::Sum, false)
606 }
607}
608
609#[derive(Debug)]
611pub struct CountIfFn;
612impl Function for CountIfFn {
613 func_caps!(
614 PURE,
615 REDUCTION,
616 WINDOWED,
617 STREAM_OK,
618 PARALLEL_ARGS,
619 PARALLEL_CHUNKS
620 );
621 fn name(&self) -> &'static str {
622 "COUNTIF"
623 }
624 fn min_args(&self) -> usize {
625 2
626 }
627 fn variadic(&self) -> bool {
628 false
629 }
630 fn arg_schema(&self) -> &'static [ArgSchema] {
631 &ARG_ANY_ONE[..]
632 }
633 fn eval<'a, 'b, 'c>(
634 &self,
635 args: &'c [ArgumentHandle<'a, 'b>],
636 ctx: &dyn FunctionContext<'b>,
637 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
638 eval_if_family(args, ctx, AggregationType::Count, false)
639 }
640}
641
642#[derive(Debug)]
644pub struct SumIfsFn; impl Function for SumIfsFn {
646 func_caps!(
647 PURE,
648 REDUCTION,
649 WINDOWED,
650 STREAM_OK,
651 PARALLEL_ARGS,
652 PARALLEL_CHUNKS
653 );
654 fn name(&self) -> &'static str {
655 "SUMIFS"
656 }
657 fn min_args(&self) -> usize {
658 3
659 }
660 fn variadic(&self) -> bool {
661 true
662 }
663 fn arg_schema(&self) -> &'static [ArgSchema] {
664 &ARG_ANY_ONE[..]
665 }
666 fn eval<'a, 'b, 'c>(
667 &self,
668 args: &'c [ArgumentHandle<'a, 'b>],
669 ctx: &dyn FunctionContext<'b>,
670 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
671 eval_if_family(args, ctx, AggregationType::Sum, true)
672 }
673}
674
675#[derive(Debug)]
677pub struct CountIfsFn; impl Function for CountIfsFn {
679 func_caps!(
680 PURE,
681 REDUCTION,
682 WINDOWED,
683 STREAM_OK,
684 PARALLEL_ARGS,
685 PARALLEL_CHUNKS
686 );
687 fn name(&self) -> &'static str {
688 "COUNTIFS"
689 }
690 fn min_args(&self) -> usize {
691 2
692 }
693 fn variadic(&self) -> bool {
694 true
695 }
696 fn arg_schema(&self) -> &'static [ArgSchema] {
697 &ARG_ANY_ONE[..]
698 }
699 fn eval<'a, 'b, 'c>(
700 &self,
701 args: &'c [ArgumentHandle<'a, 'b>],
702 ctx: &dyn FunctionContext<'b>,
703 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
704 eval_if_family(args, ctx, AggregationType::Count, true)
705 }
706}
707
708#[derive(Debug)]
710pub struct AverageIfsFn;
711impl Function for AverageIfsFn {
712 func_caps!(
713 PURE,
714 REDUCTION,
715 WINDOWED,
716 STREAM_OK,
717 PARALLEL_ARGS,
718 PARALLEL_CHUNKS
719 );
720 fn name(&self) -> &'static str {
721 "AVERAGEIFS"
722 }
723 fn min_args(&self) -> usize {
724 3
725 }
726 fn variadic(&self) -> bool {
727 true
728 }
729 fn arg_schema(&self) -> &'static [ArgSchema] {
730 &ARG_ANY_ONE[..]
731 }
732 fn eval<'a, 'b, 'c>(
733 &self,
734 args: &'c [ArgumentHandle<'a, 'b>],
735 ctx: &dyn FunctionContext<'b>,
736 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
737 eval_if_family(args, ctx, AggregationType::Average, true)
738 }
739}
740
741#[derive(Debug)]
743pub struct CountAFn; impl Function for CountAFn {
745 func_caps!(PURE, REDUCTION);
746 fn name(&self) -> &'static str {
747 "COUNTA"
748 }
749 fn min_args(&self) -> usize {
750 1
751 }
752 fn variadic(&self) -> bool {
753 true
754 }
755 fn arg_schema(&self) -> &'static [ArgSchema] {
756 &ARG_ANY_ONE[..]
757 }
758 fn eval<'a, 'b, 'c>(
759 &self,
760 args: &'c [ArgumentHandle<'a, 'b>],
761 _ctx: &dyn FunctionContext<'b>,
762 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
763 let mut cnt = 0i64;
764 for a in args {
765 if let Ok(view) = a.range_view() {
766 for res in view.type_tags_slices() {
767 let (_, _, tag_cols) = res?;
768 for col in tag_cols {
769 for i in 0..col.len() {
770 if col.value(i) != crate::arrow_store::TypeTag::Empty as u8 {
771 cnt += 1;
772 }
773 }
774 }
775 }
776 } else {
777 let v = a.value()?.into_literal();
778 if !matches!(v, LiteralValue::Empty) {
779 cnt += 1;
780 }
781 }
782 }
783 Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
784 cnt as f64,
785 )))
786 }
787}
788
789#[derive(Debug)]
791pub struct CountBlankFn; impl Function for CountBlankFn {
793 func_caps!(PURE, REDUCTION);
794 fn name(&self) -> &'static str {
795 "COUNTBLANK"
796 }
797 fn min_args(&self) -> usize {
798 1
799 }
800 fn variadic(&self) -> bool {
801 true
802 }
803 fn arg_schema(&self) -> &'static [ArgSchema] {
804 &ARG_ANY_ONE[..]
805 }
806 fn eval<'a, 'b, 'c>(
807 &self,
808 args: &'c [ArgumentHandle<'a, 'b>],
809 _ctx: &dyn FunctionContext<'b>,
810 ) -> Result<crate::traits::CalcValue<'b>, ExcelError> {
811 let mut cnt = 0i64;
812 for a in args {
813 if let Ok(view) = a.range_view() {
814 let mut tag_it = view.type_tags_slices();
815 let mut text_it = view.text_slices();
816
817 while let (Some(tag_res), Some(text_res)) = (tag_it.next(), text_it.next()) {
818 let (_, _, tag_cols) = tag_res?;
819 let (_, _, text_cols) = text_res?;
820
821 for (tc, xc) in tag_cols.into_iter().zip(text_cols.into_iter()) {
822 let text_arr = xc
823 .as_any()
824 .downcast_ref::<arrow_array::StringArray>()
825 .unwrap();
826 for i in 0..tc.len() {
827 let is_blank = tc.value(i) == crate::arrow_store::TypeTag::Empty as u8
828 || (tc.value(i) == crate::arrow_store::TypeTag::Text as u8
829 && !text_arr.is_null(i)
830 && text_arr.value(i).is_empty());
831 if is_blank {
832 cnt += 1;
833 }
834 }
835 }
836 }
837 } else {
838 let v = a.value()?.into_literal();
839 match v {
840 LiteralValue::Empty => cnt += 1,
841 LiteralValue::Text(s) if s.is_empty() => cnt += 1,
842 _ => {}
843 }
844 }
845 }
846 Ok(crate::traits::CalcValue::Scalar(LiteralValue::Number(
847 cnt as f64,
848 )))
849 }
850}
851
852pub fn register_builtins() {
853 use std::sync::Arc;
854 crate::function_registry::register_function(Arc::new(SumIfFn));
855 crate::function_registry::register_function(Arc::new(CountIfFn));
856 crate::function_registry::register_function(Arc::new(AverageIfFn));
857 crate::function_registry::register_function(Arc::new(SumIfsFn));
858 crate::function_registry::register_function(Arc::new(CountIfsFn));
859 crate::function_registry::register_function(Arc::new(AverageIfsFn));
860 crate::function_registry::register_function(Arc::new(CountAFn));
861 crate::function_registry::register_function(Arc::new(CountBlankFn));
862}
863
#[cfg(test)]
mod tests {
    // Unit tests for the conditional aggregation and counting builtins. Each test wraps
    // literal values in AST nodes, dispatches the function through a `TestWorkbook`
    // interpreter, and checks the scalar result.
    use super::*;
    use crate::test_workbook::TestWorkbook;
    use crate::traits::ArgumentHandle;
    use formualizer_common::LiteralValue;
    use formualizer_parse::parser::{ASTNode, ASTNodeType};
    // Builds the interpreter used to create argument handles and function contexts.
    fn interp(wb: &TestWorkbook) -> crate::interpreter::Interpreter<'_> {
        wb.interpreter()
    }
    // Wraps a literal value in an AST node so it can be passed as a function argument.
    fn lit(v: LiteralValue) -> ASTNode {
        ASTNode::new(ASTNodeType::Literal(v), None)
    }

    // SUMIF without a sum_range sums the criteria range itself: ">1" keeps 2 and 3 -> 5.
    #[test]
    fn sumif_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfFn));
        let ctx = interp(&wb);
        let range = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]));
        let crit = lit(LiteralValue::Text(">1".into()));
        let args = vec![
            ArgumentHandle::new(&range, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIF").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(5.0)
        );
    }

    // "=1" matches positions 0 and 2 of the criteria range -> 10 + 30 from sum_range.
    #[test]
    fn sumif_with_sum_range() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfFn));
        let ctx = interp(&wb);
        let range = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(0),
            LiteralValue::Int(1),
        ]]));
        let sum_range = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(10),
            LiteralValue::Int(20),
            LiteralValue::Int(30),
        ]]));
        let crit = lit(LiteralValue::Text("=1".into()));
        let args = vec![
            ArgumentHandle::new(&range, &ctx),
            ArgumentHandle::new(&crit, &ctx),
            ArgumentHandle::new(&sum_range, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIF").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(40.0)
        );
    }

    // A 3x2 criteria range against a 2x2 sum range is not an error. Every criteria cell
    // matches, and the sum covers all four real sum cells: 1 + 2 + 3 + 4 = 10.
    #[test]
    fn sumif_mismatched_ranges_now_pad_with_empty() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfFn));
        let ctx = interp(&wb);
        let sum = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]));
        let crit_range = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(1)],
            vec![LiteralValue::Int(1), LiteralValue::Int(1)],
            vec![LiteralValue::Int(1), LiteralValue::Int(1)],
        ]));
        let crit = lit(LiteralValue::Text("=1".into()));
        let args = vec![
            ArgumentHandle::new(&crit_range, &ctx),
            ArgumentHandle::new(&crit, &ctx),
            ArgumentHandle::new(&sum, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIF").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(10.0)
        );
    }

    // Wildcard criterion "al*" matches "alpha" and "alphabet" but not "beta".
    #[test]
    fn countif_text_wildcard() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountIfFn));
        let ctx = interp(&wb);
        let rng = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Text("alpha".into()),
            LiteralValue::Text("beta".into()),
            LiteralValue::Text("alphabet".into()),
        ]]));
        let crit = lit(LiteralValue::Text("al*".into()));
        let args = vec![
            ArgumentHandle::new(&rng, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "COUNTIF").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(2.0)
        );
    }

    // Only position 2 is both "Bellevue" and ">=4" beds, so the sum is 30.
    #[test]
    fn sumifs_multiple_criteria() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfsFn));
        let ctx = interp(&wb);
        let sum = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(10),
            LiteralValue::Int(20),
            LiteralValue::Int(30),
            LiteralValue::Int(40),
        ]]));
        let city = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Text("Bellevue".into()),
            LiteralValue::Text("Issaquah".into()),
            LiteralValue::Text("Bellevue".into()),
            LiteralValue::Text("Issaquah".into()),
        ]]));
        let beds = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(2),
            LiteralValue::Int(3),
            LiteralValue::Int(4),
            LiteralValue::Int(5),
        ]]));
        let c_city = lit(LiteralValue::Text("Bellevue".into()));
        let c_beds = lit(LiteralValue::Text(">=4".into()));
        let args = vec![
            ArgumentHandle::new(&sum, &ctx),
            ArgumentHandle::new(&city, &ctx),
            ArgumentHandle::new(&c_city, &ctx),
            ArgumentHandle::new(&beds, &ctx),
            ArgumentHandle::new(&c_beds, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(30.0)
        );
    }

    // Only position 2 is both "a" and ">1", so the count is 1.
    #[test]
    fn countifs_basic() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountIfsFn));
        let ctx = interp(&wb);
        let city = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Text("a".into()),
            LiteralValue::Text("b".into()),
            LiteralValue::Text("a".into()),
        ]]));
        let beds = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]]));
        let c_city = lit(LiteralValue::Text("a".into()));
        let c_beds = lit(LiteralValue::Text(">1".into()));
        let args = vec![
            ArgumentHandle::new(&city, &ctx),
            ArgumentHandle::new(&c_city, &ctx),
            ArgumentHandle::new(&beds, &ctx),
            ArgumentHandle::new(&c_beds, &ctx),
        ];
        let f = ctx.context.get_function("", "COUNTIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(1.0)
        );
    }

    // No criteria cell is ">0", so there is nothing to average -> #DIV/0!.
    #[test]
    fn averageifs_div0() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AverageIfsFn));
        let ctx = interp(&wb);
        let avg = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
        ]]));
        let crit_rng = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(0),
            LiteralValue::Int(0),
        ]]));
        let crit = lit(LiteralValue::Text(">0".into()));
        let args = vec![
            ArgumentHandle::new(&avg, &ctx),
            ArgumentHandle::new(&crit_rng, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "AVERAGEIFS").unwrap();
        match f
            .dispatch(&args, &ctx.function_context(None))
            .unwrap()
            .into_literal()
        {
            LiteralValue::Error(e) => assert_eq!(e, "#DIV/0!"),
            _ => panic!("expected div0"),
        }
    }

    // For [Empty, "", 5]: COUNTA counts "" and 5 (2); COUNTBLANK counts Empty and "" (2).
    #[test]
    fn counta_and_countblank() {
        let wb = TestWorkbook::new()
            .with_function(std::sync::Arc::new(CountAFn))
            .with_function(std::sync::Arc::new(CountBlankFn));
        let ctx = interp(&wb);
        let arr = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Empty,
            LiteralValue::Text("".into()),
            LiteralValue::Int(5),
        ]]));
        let args = vec![ArgumentHandle::new(&arr, &ctx)];
        let counta = ctx.context.get_function("", "COUNTA").unwrap();
        let countblank = ctx.context.get_function("", "COUNTBLANK").unwrap();
        assert_eq!(
            counta
                .dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(2.0)
        );
        assert_eq!(
            countblank
                .dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(2.0)
        );
    }

    // A criterion supplied as a 1x1 array broadcasts over the 2-row range: only "A" -> 10.
    #[test]
    fn sumifs_broadcasts_1x1_criteria_over_range() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfsFn));
        let ctx = interp(&wb);
        let sum = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(10)],
            vec![LiteralValue::Int(20)],
        ]));
        let tags = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Text("A".into())],
            vec![LiteralValue::Text("B".into())],
        ]));
        let c_tag = lit(LiteralValue::Array(vec![vec![LiteralValue::Text(
            "A".into(),
        )]]));
        let args = vec![
            ArgumentHandle::new(&sum, &ctx),
            ArgumentHandle::new(&tags, &ctx),
            ArgumentHandle::new(&c_tag, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(10.0)
        );
    }

    // A 1x1-array criterion ">=3" applied across a row of 1..=4 counts 3 and 4 -> 2.
    #[test]
    fn countifs_broadcasts_1x1_criteria_over_row() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountIfsFn));
        let ctx = interp(&wb);
        let nums = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
            LiteralValue::Int(4),
        ]]));
        let crit = lit(LiteralValue::Array(vec![vec![LiteralValue::Text(
            ">=3".into(),
        )]]));
        let args = vec![
            ArgumentHandle::new(&nums, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "COUNTIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None))
                .unwrap()
                .into_literal(),
            LiteralValue::Number(2.0)
        );
    }

    // Empty target and criteria arrays yield 0 rather than an error.
    #[test]
    fn sumifs_empty_ranges_with_1x1_criteria_produce_zero() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfsFn));
        let ctx = interp(&wb);
        let empty = lit(LiteralValue::Array(Vec::new()));
        let crit = lit(LiteralValue::Array(vec![vec![LiteralValue::Text(
            "X".into(),
        )]]));
        let args = vec![
            ArgumentHandle::new(&empty, &ctx),
            ArgumentHandle::new(&empty, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None)).unwrap(),
            LiteralValue::Number(0.0)
        );
    }

    // SUMIFS with a 2x2 target and a 3x2 all-matching criteria range sums the 2x2 target
    // (10); the shape mismatch is not an error.
    #[test]
    fn sumifs_mismatched_ranges_now_pad_with_empty() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfsFn));
        let ctx = interp(&wb);
        let sum = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(2)],
            vec![LiteralValue::Int(3), LiteralValue::Int(4)],
        ]));
        let crit_range = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1), LiteralValue::Int(1)],
            vec![LiteralValue::Int(1), LiteralValue::Int(1)],
            vec![LiteralValue::Int(1), LiteralValue::Int(1)],
        ]));
        let crit = lit(LiteralValue::Text("=1".into()));
        let args = vec![
            ArgumentHandle::new(&sum, &ctx),
            ArgumentHandle::new(&crit_range, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None)).unwrap(),
            LiteralValue::Number(10.0)
        );
    }

    // Criteria ranges of 2 and 3 rows: only the 2 rows present in both can satisfy both
    // "=1" criteria, so the count is 2.
    #[test]
    fn countifs_mismatched_ranges_pad_and_broadcast() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(CountIfsFn));
        let ctx = interp(&wb);
        let r1 = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1)],
            vec![LiteralValue::Int(1)],
        ]));
        let c1 = lit(LiteralValue::Text("=1".into()));
        let r2 = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1)],
            vec![LiteralValue::Int(1)],
            vec![LiteralValue::Int(1)],
        ]));
        let c2 = lit(LiteralValue::Text("=1".into()));
        let args = vec![
            ArgumentHandle::new(&r1, &ctx),
            ArgumentHandle::new(&c1, &ctx),
            ArgumentHandle::new(&r2, &ctx),
            ArgumentHandle::new(&c2, &ctx),
        ];
        let f = ctx.context.get_function("", "COUNTIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None)).unwrap(),
            LiteralValue::Number(2.0)
        );
    }

    // Rows 0 and 1 match "=1" and have targets 10 and 20; row 2 does not match -> 15.
    #[test]
    fn averageifs_mismatched_ranges_pad() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(AverageIfsFn));
        let ctx = interp(&wb);
        let avg = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(10)],
            vec![LiteralValue::Int(20)],
        ]));
        let r1 = lit(LiteralValue::Array(vec![
            vec![LiteralValue::Int(1)],
            vec![LiteralValue::Int(1)],
            vec![LiteralValue::Int(2)],
        ]));
        let c1 = lit(LiteralValue::Text("=1".into()));
        let args = vec![
            ArgumentHandle::new(&avg, &ctx),
            ArgumentHandle::new(&r1, &ctx),
            ArgumentHandle::new(&c1, &ctx),
        ];
        let f = ctx.context.get_function("", "AVERAGEIFS").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None)).unwrap(),
            LiteralValue::Number(15.0)
        );
    }

    // ">1e3" parses as > 1000: 1000 and 999 are excluded, leaving 1500.
    #[test]
    fn criteria_scientific_notation() {
        let wb = TestWorkbook::new().with_function(std::sync::Arc::new(SumIfFn));
        let ctx = interp(&wb);
        let nums = lit(LiteralValue::Array(vec![vec![
            LiteralValue::Number(1000.0),
            LiteralValue::Number(1500.0),
            LiteralValue::Number(999.0),
        ]]));
        let crit = lit(LiteralValue::Text(">1e3".into()));
        let args = vec![
            ArgumentHandle::new(&nums, &ctx),
            ArgumentHandle::new(&crit, &ctx),
        ];
        let f = ctx.context.get_function("", "SUMIF").unwrap();
        assert_eq!(
            f.dispatch(&args, &ctx.function_context(None)).unwrap(),
            LiteralValue::Number(1500.0)
        );
    }
}