Skip to main content

kyu_copy/
arrow_reader.rs

1//! Arrow IPC file reader.
2
3use std::fs::File;
4
5use arrow::array::{
6    Array, AsArray, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
7    Int64Array,
8};
9use arrow::ipc::reader::FileReader;
10
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13
14use crate::DataReader;
15
16/// Reads rows from an Arrow IPC file (.arrow / .ipc).
17pub struct ArrowIpcReader {
18    schema: Vec<LogicalType>,
19    rows: std::vec::IntoIter<Vec<TypedValue>>,
20}
21
22impl ArrowIpcReader {
23    /// Open an Arrow IPC file and read all record batches into memory.
24    pub fn open(path: &str, schema: &[LogicalType]) -> KyuResult<Self> {
25        let file =
26            File::open(path).map_err(|e| KyuError::Copy(format!("cannot open '{path}': {e}")))?;
27
28        let reader = FileReader::try_new(file, None)
29            .map_err(|e| KyuError::Copy(format!("invalid Arrow IPC file '{path}': {e}")))?;
30
31        let mut all_rows = Vec::new();
32
33        for batch_result in reader {
34            let batch =
35                batch_result.map_err(|e| KyuError::Copy(format!("Arrow IPC batch error: {e}")))?;
36
37            let num_rows = batch.num_rows();
38            let num_cols = schema.len().min(batch.num_columns());
39
40            for row_idx in 0..num_rows {
41                let mut row = Vec::with_capacity(schema.len());
42                for (col_idx, col_type) in schema.iter().enumerate().take(num_cols) {
43                    let col = batch.column(col_idx);
44                    let value = extract_value(col.as_ref(), row_idx, col_type)?;
45                    row.push(value);
46                }
47                for _ in num_cols..schema.len() {
48                    row.push(TypedValue::Null);
49                }
50                all_rows.push(row);
51            }
52        }
53
54        Ok(Self {
55            schema: schema.to_vec(),
56            rows: all_rows.into_iter(),
57        })
58    }
59}
60
61impl DataReader for ArrowIpcReader {
62    fn schema(&self) -> &[LogicalType] {
63        &self.schema
64    }
65}
66
67impl Iterator for ArrowIpcReader {
68    type Item = KyuResult<Vec<TypedValue>>;
69
70    fn next(&mut self) -> Option<Self::Item> {
71        self.rows.next().map(Ok)
72    }
73}
74
75/// Extract a TypedValue from an Arrow array at the given row index.
76fn extract_value(
77    array: &dyn Array,
78    row: usize,
79    target_type: &LogicalType,
80) -> KyuResult<TypedValue> {
81    if array.is_null(row) {
82        return Ok(TypedValue::Null);
83    }
84
85    match target_type {
86        LogicalType::Int8 => {
87            let arr = array
88                .as_any()
89                .downcast_ref::<Int8Array>()
90                .ok_or_else(|| KyuError::Copy("expected Int8 column in Arrow IPC".into()))?;
91            Ok(TypedValue::Int8(arr.value(row)))
92        }
93        LogicalType::Int16 => {
94            let arr = array
95                .as_any()
96                .downcast_ref::<Int16Array>()
97                .ok_or_else(|| KyuError::Copy("expected Int16 column in Arrow IPC".into()))?;
98            Ok(TypedValue::Int16(arr.value(row)))
99        }
100        LogicalType::Int32 => {
101            let arr = array
102                .as_any()
103                .downcast_ref::<Int32Array>()
104                .ok_or_else(|| KyuError::Copy("expected Int32 column in Arrow IPC".into()))?;
105            Ok(TypedValue::Int32(arr.value(row)))
106        }
107        LogicalType::Int64 | LogicalType::Serial => {
108            let arr = array
109                .as_any()
110                .downcast_ref::<Int64Array>()
111                .ok_or_else(|| KyuError::Copy("expected Int64 column in Arrow IPC".into()))?;
112            Ok(TypedValue::Int64(arr.value(row)))
113        }
114        LogicalType::Float => {
115            let arr = array
116                .as_any()
117                .downcast_ref::<Float32Array>()
118                .ok_or_else(|| KyuError::Copy("expected Float32 column in Arrow IPC".into()))?;
119            Ok(TypedValue::Float(arr.value(row)))
120        }
121        LogicalType::Double => {
122            let arr = array
123                .as_any()
124                .downcast_ref::<Float64Array>()
125                .ok_or_else(|| KyuError::Copy("expected Float64 column in Arrow IPC".into()))?;
126            Ok(TypedValue::Double(arr.value(row)))
127        }
128        LogicalType::Bool => {
129            let arr = array
130                .as_any()
131                .downcast_ref::<BooleanArray>()
132                .ok_or_else(|| KyuError::Copy("expected Boolean column in Arrow IPC".into()))?;
133            Ok(TypedValue::Bool(arr.value(row)))
134        }
135        LogicalType::String => {
136            let arr = array.as_string::<i32>();
137            Ok(TypedValue::String(smol_str::SmolStr::new(arr.value(row))))
138        }
139        _ => Err(KyuError::Copy(format!(
140            "unsupported type {} for Arrow IPC import",
141            target_type.type_name()
142        ))),
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::ipc::writer::FileWriter;
    use arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    /// Writes a one-batch Arrow IPC file named `name` into `dir` with a
    /// non-nullable Int64 `id` column and Utf8 `name` column (three rows),
    /// returning the file's path as a string.
    fn write_test_arrow(dir: &std::path::Path, name: &str) -> String {
        let target = dir.join(name);

        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let arrow_schema = Arc::new(Schema::new(fields));

        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![10, 20, 30])),
            Arc::new(StringArray::from(vec!["X", "Y", "Z"])),
        ];
        let batch = RecordBatch::try_new(Arc::clone(&arrow_schema), columns).unwrap();

        let out = File::create(&target).unwrap();
        let mut writer = FileWriter::try_new(out, &arrow_schema).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();

        target.to_str().unwrap().to_owned()
    }

    #[test]
    fn read_arrow_ipc_basic() {
        // Scratch directory under the system temp dir; creation and cleanup
        // are best-effort.
        let scratch = std::env::temp_dir().join("kyu_arrow_reader_test");
        let _ = std::fs::create_dir_all(&scratch);
        let file_path = write_test_arrow(&scratch, "test.arrow");

        let column_types = [LogicalType::Int64, LogicalType::String];
        let rows = ArrowIpcReader::open(&file_path, &column_types)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert_eq!(rows.len(), 3);

        let first = &rows[0];
        assert_eq!(first[0], TypedValue::Int64(10));
        assert_eq!(first[1], TypedValue::String(smol_str::SmolStr::new("X")));

        let last = &rows[2];
        assert_eq!(last[0], TypedValue::Int64(30));

        let _ = std::fs::remove_dir_all(&scratch);
    }
}