Skip to main content

flows_arrow/
count_rows.rs

// This is free and unencumbered software released into the public domain.
2
3use arrow_array::RecordBatch;
4use async_flow::{Inputs, Output, Outputs, Port, Result};
5
6/// A block that outputs row counts of input record batches.
7pub async fn count_rows(
8    mut batches: Inputs<RecordBatch>,
9    counts: Outputs<usize>,
10    total: Output<usize>,
11) -> Result {
12    let mut total_rows = 0;
13
14    while let Some(batch) = batches.recv().await? {
15        let batch_rows = batch.num_rows();
16        total_rows += batch_rows;
17
18        if !counts.is_closed() {
19            counts.send(batch_rows).await?;
20        }
21    }
22
23    if !total.is_closed() {
24        total.send(total_rows).await?;
25    }
26    Ok(())
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Two 3-row batches should yield two per-batch counts of 3 and a total of 6.
    #[tokio::test]
    async fn test_count_rows() -> Result<(), Box<dyn Error>> {
        let mut batches = Channel::bounded(10);
        let mut counts = Channel::bounded(10);
        let mut total = Channel::oneshot();
        let counter = tokio::spawn(count_rows(batches.rx, counts.tx, total.tx));

        let batch = record_batch!(
            ("a", Int32, [1, 2, 3]),
            ("b", Float64, [Some(4.0), None, Some(5.0)]),
            ("c", Utf8, ["alpha", "beta", "gamma"])
        )?;
        batches.tx.send(batch.clone()).await?;
        batches.tx.send(batch).await?;
        batches.tx.close();

        // Fail the test if the task panicked (outer `?`, JoinError) or if
        // `count_rows` itself returned an error (inner `?`), instead of
        // discarding both as `let _ = tokio::join!(counter)` did.
        counter.await??;

        let counts = counts.rx.recv_all().await?;
        assert_eq!(counts.len(), 2);
        for count in counts {
            assert_eq!(count, 3);
        }

        assert_eq!(total.rx.recv().await?, Some(6));

        Ok(())
    }
}