Skip to main content

flows_datafusion/
min_column.rs

1// This is free and unencumbered software released into the public domain.
2
3use arrow_array::{ArrayRef, RecordBatch};
4use async_flow::{Inputs, Output, Port, Result};
5use datafusion_common::ScalarValue;
6
7/// A block that outputs the minimum of the values in a given column.
8///
9/// Panics in case the specified column index is out of bounds.
10/// Outputs `ScalarValue::Null` in case the specified column has a non-numeric
11/// datatype.
12pub async fn min_column(
13    column: usize,
14    mut inputs: Inputs<RecordBatch>,
15    output: Output<ScalarValue>,
16) -> Result {
17    let mut result: ScalarValue = ScalarValue::Null;
18
19    while let Some(input) = inputs.recv().await? {
20        if input.num_rows() == 0 {
21            continue; // skip empty batches
22        }
23
24        let column_array = input.column(column);
25        let Some(column_min) = min_array(column_array) else {
26            continue; // skip unsupported datatypes
27        };
28
29        if result.is_null() || column_min < result {
30            result = column_min;
31        }
32    }
33
34    if !output.is_closed() {
35        output.send(result).await?;
36    }
37
38    Ok(())
39}
40
41pub fn min_array(array: &ArrayRef) -> Option<ScalarValue> {
42    use arrow_arith::aggregate::min;
43    use arrow_array::{cast::AsArray, types::*};
44    use arrow_schema::DataType::*;
45    Some(match array.data_type() {
46        Int8 => ScalarValue::from(min(array.as_primitive::<Int8Type>())),
47        Int16 => ScalarValue::from(min(array.as_primitive::<Int16Type>())),
48        Int32 => ScalarValue::from(min(array.as_primitive::<Int32Type>())),
49        Int64 => ScalarValue::from(min(array.as_primitive::<Int64Type>())),
50        UInt8 => ScalarValue::from(min(array.as_primitive::<UInt8Type>())),
51        UInt16 => ScalarValue::from(min(array.as_primitive::<UInt16Type>())),
52        UInt32 => ScalarValue::from(min(array.as_primitive::<UInt32Type>())),
53        UInt64 => ScalarValue::from(min(array.as_primitive::<UInt64Type>())),
54        Float16 => ScalarValue::from(min(array.as_primitive::<Float16Type>())),
55        Float32 => ScalarValue::from(min(array.as_primitive::<Float32Type>())),
56        Float64 => ScalarValue::from(min(array.as_primitive::<Float64Type>())),
57        // TODO: Decimal32, Decimal64, Decimal128, Decimal256
58        _ => return None,
59    })
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    // `vec` is needed by the expansion of `record_batch!`.
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    #[tokio::test]
    async fn test_min_column_i32() -> Result<(), Box<dyn Error>> {
        assert_eq!(min_of_column(0, 2).await?, ScalarValue::from(1i32));
        Ok(())
    }

    #[tokio::test]
    async fn test_min_column_f64_ignores_nulls() -> Result<(), Box<dyn Error>> {
        assert_eq!(min_of_column(1, 2).await?, ScalarValue::from(4.0f64));
        Ok(())
    }

    #[tokio::test]
    async fn test_min_column_utf8_is_null() -> Result<(), Box<dyn Error>> {
        assert_eq!(min_of_column(2, 2).await?, ScalarValue::Null);
        Ok(())
    }

    /// Runs `min_column` on `column` over `batches` copies of `sample_data()`
    /// and returns the single value it outputs.
    async fn min_of_column(
        column: usize,
        batches: usize,
    ) -> Result<ScalarValue, Box<dyn Error>> {
        let mut in_ = Channel::bounded(10);
        let mut out = Channel::oneshot();
        let minner = tokio::spawn(min_column(column, in_.rx, out.tx));

        for _ in 0..batches {
            in_.tx.send(sample_data()).await?;
        }
        in_.tx.close();

        // Surface task panics (via `JoinError`) and block errors instead of
        // silently discarding them.
        assert!(minner.await?.is_ok(), "min_column returned an error");

        let outputs = out.rx.recv_all().await?;
        assert_eq!(outputs.len(), 1);
        Ok(outputs[0].clone())
    }

    fn sample_data() -> RecordBatch {
        record_batch!(
            ("a", Int32, [1, 2, 3, 4, 5]),
            ("b", Float64, [Some(4.0), None, Some(5.0), None, None]),
            ("c", Utf8, ["alpha", "beta", "gamma", "", ""])
        )
        .unwrap()
    }
}