flows_arrow/
count_rows.rs

1// This is free and unencumbered software released into the public domain.
2
3use arrow_array::RecordBatch;
4use async_flow::{Inputs, Outputs, Result};
5
6/// A block that outputs row counts of input record batches.
7pub async fn count_rows(
8    mut batches: Inputs<RecordBatch>,
9    counts: Outputs<usize>,
10    total: Outputs<usize>,
11) -> Result {
12    let mut total_rows = 0;
13
14    while let Some(batch) = batches.recv().await? {
15        let batch_rows = batch.num_rows();
16        total_rows += batch_rows;
17
18        if !counts.is_closed() {
19            counts.send(batch_rows).await?;
20        }
21    }
22
23    if !total.is_closed() {
24        total.send(total_rows).await?;
25    }
26    Ok(())
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use arrow_array::{Float32Array, Int32Array};
    use std::sync::Arc;

    /// Feeds two 3-row batches through `count_rows` and checks both the
    /// per-batch counts and the final total.
    #[tokio::test]
    async fn test_count_rows() {
        use async_flow::bounded;

        let col_1 = Arc::new(Int32Array::from_iter([1, 2, 3])) as _;
        let col_2 = Arc::new(Float32Array::from_iter([1., 6.3, 4.])) as _;
        let batch = RecordBatch::try_from_iter(vec![("col_1", col_1), ("col_2", col_2)]).unwrap();

        let (mut batches_tx, batches_rx) = bounded(10);
        let (counts_tx, mut counts_rx) = bounded(10);
        let (total_tx, mut total_rx) = bounded(10);

        let counter = tokio::spawn(count_rows(batches_rx, counts_tx, total_tx));

        batches_tx.send(batch.clone()).await.unwrap();
        batches_tx.send(batch).await.unwrap();
        batches_tx.close();

        // Wait for the block to finish and surface both a panic in the task
        // and an error returned by `count_rows`. Discarding them would hide
        // the real failure behind a confusing mismatch on the receive side.
        counter
            .await
            .expect("count_rows task panicked")
            .expect("count_rows returned an error");

        assert_eq!(counts_rx.recv().await.unwrap(), Some(3));
        assert_eq!(counts_rx.recv().await.unwrap(), Some(3));

        assert_eq!(total_rx.recv().await.unwrap(), Some(6));
    }
}