flows_datafusion/
avg_column.rs1use super::sum_array;
4use arrow_array::RecordBatch;
5use async_flow::{Inputs, Output, Port, Result};
6use datafusion_common::ScalarValue;
7
8pub async fn avg_column(
14 column: usize,
15 mut inputs: Inputs<RecordBatch>,
16 output: Output<ScalarValue>,
17) -> Result {
18 let mut tally: ScalarValue = ScalarValue::Null;
19 let mut count: usize = 0;
20
21 while let Some(input) = inputs.recv().await? {
22 if input.num_rows() == 0 {
23 continue; }
25
26 let column_array = input.column(column);
27 let column_len = column_array.len() - column_array.null_count();
28 if column_len == 0 {
29 continue; }
31
32 let Some(column_sum) = sum_array(column_array) else {
33 continue; };
35
36 let column_avg = avg(column_sum, column_len).unwrap();
37
38 tally = if tally.is_null() {
39 column_avg
40 } else {
41 tally.add(column_avg).unwrap()
42 };
43 count += 1;
44 }
45
46 let result = if count == 0 {
47 ScalarValue::Null
48 } else {
49 avg(tally, count).unwrap()
50 };
51
52 if !output.is_closed() {
53 output.send(result).await?;
54 }
55
56 Ok(())
57}
58
59pub fn avg(sum: ScalarValue, len: usize) -> Option<ScalarValue> {
60 assert!(len > 0);
61 use arrow_schema::DataType::*;
62 let sum = sum.cast_to(&Float64).unwrap();
63 let len = ScalarValue::Float64(Some(len as f64));
64 sum.div(len).ok()
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Builds a five-row batch with an Int32 column `a`, a nullable Float64
    /// column `b`, and a Utf8 column `c`.
    fn sample_data() -> RecordBatch {
        let batch = record_batch!(
            ("a", Int32, [1, 2, 3, 4, 5]),
            ("b", Float64, [Some(4.0), None, Some(5.0), None, None]),
            ("c", Utf8, ["alpha", "beta", "gamma", "", ""])
        );
        batch.unwrap()
    }

    /// Feeds two identical batches through `avg_column` on column 0 and
    /// expects a single output equal to 3.0.
    #[tokio::test]
    async fn test_avg_column_i32() -> Result<(), Box<dyn Error>> {
        let mut source = Channel::bounded(10);
        let mut sink = Channel::oneshot();
        let task = tokio::spawn(avg_column(0, source.rx, sink.tx));

        // Send the same batch twice, then signal end of input.
        for _ in 0..2 {
            source.tx.send(sample_data()).await?;
        }
        source.tx.close();

        // Wait for the block to finish; its join result is deliberately ignored.
        let _ = task.await;

        let outputs = sink.rx.recv_all().await?;
        assert_eq!(outputs.len(), 1);
        assert_eq!(outputs[0], ScalarValue::from(3.0));

        Ok(())
    }
}