flows_arrow/
count_rows.rs1use arrow_array::RecordBatch;
4use async_flow::{Inputs, Output, Outputs, Port, Result};
5
6pub async fn count_rows(
8 mut batches: Inputs<RecordBatch>,
9 counts: Outputs<usize>,
10 total: Output<usize>,
11) -> Result {
12 let mut total_rows = 0;
13
14 while let Some(batch) = batches.recv().await? {
15 let batch_rows = batch.num_rows();
16 total_rows += batch_rows;
17
18 if !counts.is_closed() {
19 counts.send(batch_rows).await?;
20 }
21 }
22
23 if !total.is_closed() {
24 total.send(total_rows).await?;
25 }
26 Ok(())
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Two 3-row batches should produce per-batch counts of 3 and 3, and a total of 6.
    #[tokio::test]
    async fn test_count_rows() -> Result<(), Box<dyn Error>> {
        let mut batches = Channel::bounded(10);
        let mut counts = Channel::bounded(10);
        let mut total = Channel::oneshot();
        let counter = tokio::spawn(count_rows(batches.rx, counts.tx, total.tx));

        let batch = record_batch!(
            ("a", Int32, [1, 2, 3]),
            ("b", Float64, [Some(4.0), None, Some(5.0)]),
            ("c", Utf8, ["alpha", "beta", "gamma"])
        )?;
        batches.tx.send(batch.clone()).await?;
        batches.tx.send(batch).await?;
        batches.tx.close();

        // Await the task directly: a panic inside it surfaces as a `JoinError` via `?`,
        // and an error returned by `count_rows` fails the assertion below, instead of
        // both being discarded.
        let outcome = counter.await?;
        assert!(outcome.is_ok(), "count_rows returned an error");

        let counts = counts.rx.recv_all().await?;
        assert_eq!(counts.len(), 2);
        for count in counts {
            assert_eq!(count, 3);
        }

        assert_eq!(total.rx.recv().await?, Some(6));

        Ok(())
    }

    /// With no input batches, no per-batch counts are emitted and the total is 0.
    #[tokio::test]
    async fn test_count_rows_empty() -> Result<(), Box<dyn Error>> {
        let mut batches = Channel::bounded(1);
        let mut counts = Channel::bounded(1);
        let mut total = Channel::oneshot();
        let counter = tokio::spawn(count_rows(batches.rx, counts.tx, total.tx));

        batches.tx.close();

        let outcome = counter.await?;
        assert!(outcome.is_ok(), "count_rows returned an error");

        let counts = counts.rx.recv_all().await?;
        assert_eq!(counts.len(), 0);

        assert_eq!(total.rx.recv().await?, Some(0));

        Ok(())
    }
}