Skip to main content

flows_arrow/
slice_rows.rs

// This is free and unencumbered software released into the public domain.
2
3use arrow_array::RecordBatch;
4use async_flow::{Inputs, Outputs, Port, Result};
5
6/// A block that applies offsets/limits to batches of rows.
7#[allow(unused)]
8pub async fn slice_rows(
9    mut offset: usize,
10    mut limit: Option<usize>,
11    mut inputs: Inputs<RecordBatch>,
12    outputs: Outputs<RecordBatch>,
13) -> Result {
14    let mut total_rows = 0;
15
16    while let Some(input) = inputs.recv().await? {
17        if input.num_rows() == 0 {
18            continue; // skip empty batches
19        }
20
21        let batch_len = input.num_rows();
22        total_rows += batch_len;
23
24        let output = match (offset, limit) {
25            (0, Some(0)) => RecordBatch::new_empty(input.schema()),
26            (o, Some(0)) => {
27                offset -= batch_len.min(o);
28                RecordBatch::new_empty(input.schema())
29            },
30
31            (0, None) => input,
32            (0, Some(n)) if n <= batch_len => {
33                limit = Some(0);
34                if n == batch_len {
35                    input
36                } else {
37                    input.slice(0, n)
38                }
39            },
40            (0, Some(n)) if n > batch_len => {
41                limit = Some(n - batch_len);
42                input
43            },
44
45            (o, None) if o <= batch_len => {
46                offset -= batch_len.min(o);
47                input.slice(o, batch_len - o)
48            },
49            (o, None) if o > batch_len => {
50                offset -= batch_len;
51                RecordBatch::new_empty(input.schema())
52            },
53
54            (o, Some(n)) if o + n <= batch_len => {
55                offset = 0;
56                limit = Some(0);
57                input.slice(o, n)
58            },
59            (o, Some(n)) if o >= batch_len => {
60                offset -= batch_len;
61                RecordBatch::new_empty(input.schema())
62            },
63            (o, Some(n)) if o < batch_len => {
64                let output_len = batch_len - o;
65                offset -= o;
66                limit = Some(n - output_len);
67                input.slice(o, output_len)
68            },
69
70            (_, _) => unreachable!(),
71        };
72
73        if !outputs.is_closed() {
74            outputs.send(output).await?;
75        }
76    }
77
78    Ok(())
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{boxed::Box, vec, vec::Vec};
    use arrow_array::record_batch;
    use async_flow::{Channel, InputPort};
    use core::error::Error;

    /// Checks `slice_rows` against a table of `(offset, limit, expected rows)`
    /// cases. The input is always the values `0..=29`, delivered as three
    /// batches of ten rows each (see `exec_slice_rows`).
    #[tokio::test]
    async fn test_slice_rows() -> Result<(), Box<dyn Error>> {
        let cases: Vec<(usize, Option<usize>, Vec<i32>)> = vec![
            (0, Some(0), Vec::new()),
            (10, Some(0), Vec::new()),
            (0, None, (0..=29).collect()),
            (0, Some(1), vec![0]),
            (0, Some(10), (0..=9).collect()),
            (0, Some(11), (0..=10).collect()),
            (0, Some(21), (0..=20).collect()),
            (1, Some(1), vec![1]),
            (9, Some(1), vec![9]),
            (9, Some(2), vec![9, 10]),
            (9, Some(3), vec![9, 10, 11]),
            (9, Some(12), (9..=20).collect()),
            (10, Some(1), vec![10]),
            (19, Some(2), vec![19, 20]),
            (19, Some(3), vec![19, 20, 21]),
            (29, Some(1), vec![29]),
            (29, Some(2), vec![29]),
        ];

        for (offset, limit, expected) in cases {
            let actual = exec_slice_rows(offset, limit).await?;
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    /// Runs `slice_rows` with the given `offset`/`limit` over three ten-row
    /// `Int32` batches holding `0..=29` and returns the values of all output
    /// batches concatenated in order.
    async fn exec_slice_rows(
        offset: usize,
        limit: Option<usize>,
    ) -> Result<Vec<i32>, Box<dyn Error>> {
        let mut in_ = Channel::bounded(10);
        let mut out = Channel::bounded(10);
        let slicer = tokio::spawn(slice_rows(offset, limit, in_.rx, out.tx));

        // Feed the three input batches, then close the sender so the block
        // sees the end of its input.
        let batches = [
            record_batch!(("n", Int32, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))?,
            record_batch!(("n", Int32, [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]))?,
            record_batch!(("n", Int32, [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]))?,
        ];
        for batch in batches {
            in_.tx.send(batch).await?;
        }
        in_.tx.close();

        // Wait for the block to finish; its result is intentionally ignored.
        let _ = slicer.await;

        // One output batch (possibly empty) is expected per input batch.
        let outputs = out.rx.recv_all().await?;
        assert_eq!(outputs.len(), 3);

        let schema = outputs[0].schema();
        let merged = arrow_select::concat::concat_batches(&schema, &outputs).unwrap();
        let values = merged
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int32Array>()
            .unwrap();

        Ok(values.values().to_vec())
    }
}