flows_arrow/
concat_batches.rs

// This is free and unencumbered software released into the public domain.
2
3use alloc::vec::Vec;
4use arrow_array::RecordBatch;
5use async_flow::{Inputs, Output, Port, Result};
6
7/// A block that concatenates input batches into a single output batch.
8pub async fn concat_batches(
9    mut inputs: Inputs<RecordBatch>,
10    output: Output<RecordBatch>,
11) -> Result {
12    let mut batches: Vec<RecordBatch> = Vec::new();
13
14    while let Some(batch) = inputs.recv().await? {
15        if batch.num_rows() == 0 && !batches.is_empty() {
16            continue; // skip empty batches after the first one
17        }
18        batches.push(batch);
19    }
20
21    if !batches.is_empty() {
22        let schema = batches[0].schema();
23        let batch = arrow_select::concat::concat_batches(&schema, &batches).unwrap();
24
25        if !output.is_closed() {
26            output.send(batch).await?;
27        }
28    }
29
30    Ok(())
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Two 10-row batches are combined into exactly one 20-row batch.
    #[tokio::test]
    async fn test_concat_batches() -> Result<(), Box<dyn Error>> {
        let mut in_ = Channel::bounded(1);
        let mut out = Channel::oneshot();
        let concatter = tokio::spawn(concat_batches(in_.rx, out.tx));

        let batch = record_batch!(("n", Int32, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))?;
        in_.tx.send(batch.clone()).await?;
        in_.tx.send(batch.clone()).await?;
        in_.tx.close();

        // Propagate a panic in the spawned task (JoinError) and fail if the
        // block itself returned an error, instead of discarding both.
        let result = concatter.await?;
        assert!(result.is_ok(), "concat_batches returned an error");

        let outputs = out.rx.recv_all().await?;
        assert_eq!(outputs.len(), 1);

        for output in outputs {
            assert_eq!(output.num_rows(), 20);
            assert_eq!(output.num_columns(), 1);
        }

        Ok(())
    }

    /// Empty batches (leading and trailing) do not disturb the result.
    #[tokio::test]
    async fn test_concat_batches_with_empty_batches() -> Result<(), Box<dyn Error>> {
        let mut in_ = Channel::bounded(1);
        let mut out = Channel::oneshot();
        let concatter = tokio::spawn(concat_batches(in_.rx, out.tx));

        let batch = record_batch!(("n", Int32, [1, 2, 3]))?;
        // A zero-length slice keeps the schema but carries no rows.
        let empty = batch.slice(0, 0);
        in_.tx.send(empty.clone()).await?;
        in_.tx.send(batch.clone()).await?;
        in_.tx.send(empty).await?;
        in_.tx.close();

        let result = concatter.await?;
        assert!(result.is_ok(), "concat_batches returned an error");

        let outputs = out.rx.recv_all().await?;
        assert_eq!(outputs.len(), 1);

        for output in outputs {
            assert_eq!(output.num_rows(), 3);
            assert_eq!(output.num_columns(), 1);
        }

        Ok(())
    }
}