flows_datafusion/
avg_column.rs

1// This is free and unencumbered software released into the public domain.
2
3use super::sum_array;
4use arrow_array::RecordBatch;
5use async_flow::{Inputs, Output, Port, Result};
6use datafusion_common::ScalarValue;
7
8/// A block that outputs the average of all values in a given column.
9///
10/// Panics in case the specified column index is out of bounds.
11/// Outputs `ScalarValue::Null` in case the specified column has a non-numeric
12/// datatype.
13pub async fn avg_column(
14    column: usize,
15    mut inputs: Inputs<RecordBatch>,
16    output: Output<ScalarValue>,
17) -> Result {
18    let mut tally: ScalarValue = ScalarValue::Null;
19    let mut count: usize = 0;
20
21    while let Some(input) = inputs.recv().await? {
22        if input.num_rows() == 0 {
23            continue; // skip empty batches
24        }
25
26        let column_array = input.column(column);
27        let column_len = column_array.len() - column_array.null_count();
28        if column_len == 0 {
29            continue; // skip null-only batches
30        }
31
32        let Some(column_sum) = sum_array(column_array) else {
33            continue; // skip unsupported datatypes
34        };
35
36        let column_avg = avg(column_sum, column_len).unwrap();
37
38        tally = if tally.is_null() {
39            column_avg
40        } else {
41            tally.add(column_avg).unwrap()
42        };
43        count += 1;
44    }
45
46    let result = if count == 0 {
47        ScalarValue::Null
48    } else {
49        avg(tally, count).unwrap()
50    };
51
52    if !output.is_closed() {
53        output.send(result).await?;
54    }
55
56    Ok(())
57}
58
59pub fn avg(sum: ScalarValue, len: usize) -> Option<ScalarValue> {
60    assert!(len > 0);
61    use arrow_schema::DataType::*;
62    let sum = sum.cast_to(&Float64).unwrap();
63    let len = ScalarValue::Float64(Some(len as f64));
64    sum.div(len).ok()
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Builds a five-row batch with an `Int32` column `a`, a partly-null
    /// `Float64` column `b`, and a `Utf8` column `c`.
    fn sample_data() -> RecordBatch {
        record_batch!(
            ("a", Int32, [1, 2, 3, 4, 5]),
            ("b", Float64, [Some(4.0), None, Some(5.0), None, None]),
            ("c", Utf8, ["alpha", "beta", "gamma", "", ""])
        )
        .unwrap()
    }

    /// Feeds the sample batch twice and expects the average of column `a`.
    #[tokio::test]
    async fn test_avg_column_i32() -> Result<(), Box<dyn Error>> {
        let mut batches = Channel::bounded(10);
        let mut averages = Channel::oneshot();
        let task = tokio::spawn(avg_column(0, batches.rx, averages.tx));

        for _ in 0..2 {
            batches.tx.send(sample_data()).await?;
        }
        // Closing the sender ends the block's input loop.
        batches.tx.close();

        // Wait for the block to finish; its outcome is checked via the output.
        let _ = task.await;

        let received = averages.rx.recv_all().await?;
        assert_eq!(received.len(), 1);
        assert_eq!(received[0], ScalarValue::from(3.0));

        Ok(())
    }
}