flows_datafusion/
sum_column.rs

1// This is free and unencumbered software released into the public domain.
2
3use arrow_array::{ArrayRef, RecordBatch};
4use async_flow::{Inputs, Output, Port, Result};
5use datafusion_common::ScalarValue;
6
7/// A block that outputs the sum of the values in a given column.
8///
9/// Panics in case the specified column index is out of bounds.
10/// Outputs `ScalarValue::Null` in case the specified column has a non-numeric
11/// datatype.
12pub async fn sum_column(
13    column: usize,
14    mut inputs: Inputs<RecordBatch>,
15    output: Output<ScalarValue>,
16) -> Result {
17    let mut result: ScalarValue = ScalarValue::Null;
18
19    while let Some(input) = inputs.recv().await? {
20        if input.num_rows() == 0 {
21            continue; // skip empty batches
22        }
23
24        let column_array = input.column(column);
25        let Some(column_sum) = sum_array(column_array) else {
26            continue; // skip unsupported datatypes
27        };
28
29        result = if result.is_null() {
30            column_sum
31        } else {
32            result.add(column_sum).unwrap()
33        }
34    }
35
36    if !output.is_closed() {
37        output.send(result).await?;
38    }
39
40    Ok(())
41}
42
43pub fn sum_array(array: &ArrayRef) -> Option<ScalarValue> {
44    use arrow_arith::aggregate::sum;
45    use arrow_array::{cast::AsArray, types::*};
46    use arrow_schema::DataType::*;
47    Some(match array.data_type() {
48        Int8 => ScalarValue::from(sum(array.as_primitive::<Int8Type>())),
49        Int16 => ScalarValue::from(sum(array.as_primitive::<Int16Type>())),
50        Int32 => ScalarValue::from(sum(array.as_primitive::<Int32Type>())),
51        Int64 => ScalarValue::from(sum(array.as_primitive::<Int64Type>())),
52        UInt8 => ScalarValue::from(sum(array.as_primitive::<UInt8Type>())),
53        UInt16 => ScalarValue::from(sum(array.as_primitive::<UInt16Type>())),
54        UInt32 => ScalarValue::from(sum(array.as_primitive::<UInt32Type>())),
55        UInt64 => ScalarValue::from(sum(array.as_primitive::<UInt64Type>())),
56        Float16 => ScalarValue::from(sum(array.as_primitive::<Float16Type>())),
57        Float32 => ScalarValue::from(sum(array.as_primitive::<Float32Type>())),
58        Float64 => ScalarValue::from(sum(array.as_primitive::<Float64Type>())),
59        // TODO: Decimal32, Decimal64, Decimal128, Decimal256
60        _ => return None,
61    })
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Two copies of the sample batch, whose Int32 column `a` sums to 15
    /// each, should yield a single output of 30.
    #[tokio::test]
    async fn test_sum_column_i32() -> Result<(), Box<dyn Error>> {
        let mut in_ = Channel::bounded(10);
        let mut out = Channel::oneshot();
        let summer = tokio::spawn(sum_column(0, in_.rx, out.tx));

        in_.tx.send(sample_data()).await?;
        in_.tx.send(sample_data()).await?;
        in_.tx.close();

        // Wait for the block to finish. A panic inside the task surfaces as a
        // `JoinError` via `?`, and an error returned by the block fails the
        // test here instead of being silently discarded.
        let block_result = summer.await?;
        assert!(block_result.is_ok(), "sum_column returned an error");

        let outputs = out.rx.recv_all().await?;
        assert_eq!(outputs.len(), 1);
        assert_eq!(outputs[0], ScalarValue::from(30i32));

        Ok(())
    }

    /// A small batch with an Int32 column `a`, a partially-null Float64
    /// column `b`, and a Utf8 column `c`.
    fn sample_data() -> RecordBatch {
        record_batch!(
            ("a", Int32, [1, 2, 3, 4, 5]),
            ("b", Float64, [Some(4.0), None, Some(5.0), None, None]),
            ("c", Utf8, ["alpha", "beta", "gamma", "", ""])
        )
        .unwrap()
    }
}