flows_datafusion/
sum_column.rs1use arrow_array::{ArrayRef, RecordBatch};
4use async_flow::{Inputs, Output, Port, Result};
5use datafusion_common::ScalarValue;
6
7pub async fn sum_column(
13 column: usize,
14 mut inputs: Inputs<RecordBatch>,
15 output: Output<ScalarValue>,
16) -> Result {
17 let mut result: ScalarValue = ScalarValue::Null;
18
19 while let Some(input) = inputs.recv().await? {
20 if input.num_rows() == 0 {
21 continue; }
23
24 let column_array = input.column(column);
25 let Some(column_sum) = sum_array(column_array) else {
26 continue; };
28
29 result = if result.is_null() {
30 column_sum
31 } else {
32 result.add(column_sum).unwrap()
33 }
34 }
35
36 if !output.is_closed() {
37 output.send(result).await?;
38 }
39
40 Ok(())
41}
42
43pub fn sum_array(array: &ArrayRef) -> Option<ScalarValue> {
44 use arrow_arith::aggregate::sum;
45 use arrow_array::{cast::AsArray, types::*};
46 use arrow_schema::DataType::*;
47 Some(match array.data_type() {
48 Int8 => ScalarValue::from(sum(array.as_primitive::<Int8Type>())),
49 Int16 => ScalarValue::from(sum(array.as_primitive::<Int16Type>())),
50 Int32 => ScalarValue::from(sum(array.as_primitive::<Int32Type>())),
51 Int64 => ScalarValue::from(sum(array.as_primitive::<Int64Type>())),
52 UInt8 => ScalarValue::from(sum(array.as_primitive::<UInt8Type>())),
53 UInt16 => ScalarValue::from(sum(array.as_primitive::<UInt16Type>())),
54 UInt32 => ScalarValue::from(sum(array.as_primitive::<UInt32Type>())),
55 UInt64 => ScalarValue::from(sum(array.as_primitive::<UInt64Type>())),
56 Float16 => ScalarValue::from(sum(array.as_primitive::<Float16Type>())),
57 Float32 => ScalarValue::from(sum(array.as_primitive::<Float32Type>())),
58 Float64 => ScalarValue::from(sum(array.as_primitive::<Float64Type>())),
59 _ => return None,
61 })
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Two identical batches whose Int32 column `a` sums to 15 each should
    /// produce a single output of 30.
    #[tokio::test]
    async fn test_sum_column_i32() -> Result<(), Box<dyn Error>> {
        let mut in_ = Channel::bounded(10);
        let mut out = Channel::oneshot();
        let summer = tokio::spawn(sum_column(0, in_.rx, out.tx));

        in_.tx.send(sample_data()).await?;
        in_.tx.send(sample_data()).await?;
        in_.tx.close();

        // Surface a panic in the spawned task (JoinError) or an error returned
        // by the flow itself, instead of silently discarding them.
        let flow_result = summer.await?;
        assert!(flow_result.is_ok(), "sum_column returned an error");

        let outputs = out.rx.recv_all().await?;
        assert_eq!(outputs.len(), 1);
        assert_eq!(outputs[0], ScalarValue::from(30i32));

        Ok(())
    }

    /// A small batch with an Int32, a nullable Float64, and a Utf8 column.
    fn sample_data() -> RecordBatch {
        record_batch!(
            ("a", Int32, [1, 2, 3, 4, 5]),
            ("b", Float64, [Some(4.0), None, Some(5.0), None, None]),
            ("c", Utf8, ["alpha", "beta", "gamma", "", ""])
        )
        .unwrap()
    }
}