flows_arrow/
count_rows.rs1use arrow_array::RecordBatch;
4use async_flow::{Inputs, Outputs, Result};
5
6pub async fn count_rows(
8 mut batches: Inputs<RecordBatch>,
9 counts: Outputs<usize>,
10 total: Outputs<usize>,
11) -> Result {
12 let mut total_rows = 0;
13
14 while let Some(batch) = batches.recv().await? {
15 let batch_rows = batch.num_rows();
16 total_rows += batch_rows;
17
18 if !counts.is_closed() {
19 counts.send(batch_rows).await?;
20 }
21 }
22
23 if !total.is_closed() {
24 total.send(total_rows).await?;
25 }
26 Ok(())
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use arrow_array::{Float32Array, Int32Array};
    use async_flow::bounded;
    use std::sync::Arc;

    /// Feeds two 3-row batches through `count_rows` and checks both per-batch
    /// counts and the final total.
    #[tokio::test]
    async fn test_count_rows() {
        let col_1 = Arc::new(Int32Array::from_iter([1, 2, 3])) as _;
        let col_2 = Arc::new(Float32Array::from_iter([1., 6.3, 4.])) as _;
        let batch = RecordBatch::try_from_iter(vec![("col_1", col_1), ("col_2", col_2)]).unwrap();

        let (mut batches_tx, batches_rx) = bounded(10);
        let (counts_tx, mut counts_rx) = bounded(10);
        let (total_tx, mut total_rx) = bounded(10);

        let counter = tokio::spawn(count_rows(batches_rx, counts_tx, total_tx));

        batches_tx.send(batch.clone()).await.unwrap();
        batches_tx.send(batch.clone()).await.unwrap();
        batches_tx.close();

        // Surface a panic or an `Err` from the task instead of discarding it.
        let outcome = counter.await.expect("count_rows task panicked");
        assert!(outcome.is_ok(), "count_rows returned an error");

        assert_eq!(counts_rx.recv().await.unwrap(), Some(3));
        assert_eq!(counts_rx.recv().await.unwrap(), Some(3));

        assert_eq!(total_rx.recv().await.unwrap(), Some(6));
    }

    /// With no input batches, `count_rows` still emits a total of zero.
    #[tokio::test]
    async fn test_count_rows_empty() {
        let (mut batches_tx, batches_rx) = bounded(10);
        // Keep the receiver alive so the counts output is not seen as closed.
        let (counts_tx, _counts_rx) = bounded(10);
        let (total_tx, mut total_rx) = bounded(10);

        let counter = tokio::spawn(count_rows(batches_rx, counts_tx, total_tx));
        batches_tx.close();

        let outcome = counter.await.expect("count_rows task panicked");
        assert!(outcome.is_ok(), "count_rows returned an error");

        assert_eq!(total_rx.recv().await.unwrap(), Some(0));
    }
}