1use std::fs::File;
4
5use arrow::array::{
6 Array, AsArray, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
7 Int64Array,
8};
9use arrow::ipc::reader::FileReader;
10
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13
14use crate::DataReader;
15
16pub struct ArrowIpcReader {
18 schema: Vec<LogicalType>,
19 rows: std::vec::IntoIter<Vec<TypedValue>>,
20}
21
22impl ArrowIpcReader {
23 pub fn open(path: &str, schema: &[LogicalType]) -> KyuResult<Self> {
25 let file =
26 File::open(path).map_err(|e| KyuError::Copy(format!("cannot open '{path}': {e}")))?;
27
28 let reader = FileReader::try_new(file, None)
29 .map_err(|e| KyuError::Copy(format!("invalid Arrow IPC file '{path}': {e}")))?;
30
31 let mut all_rows = Vec::new();
32
33 for batch_result in reader {
34 let batch =
35 batch_result.map_err(|e| KyuError::Copy(format!("Arrow IPC batch error: {e}")))?;
36
37 let num_rows = batch.num_rows();
38 let num_cols = schema.len().min(batch.num_columns());
39
40 for row_idx in 0..num_rows {
41 let mut row = Vec::with_capacity(schema.len());
42 for (col_idx, col_type) in schema.iter().enumerate().take(num_cols) {
43 let col = batch.column(col_idx);
44 let value = extract_value(col.as_ref(), row_idx, col_type)?;
45 row.push(value);
46 }
47 for _ in num_cols..schema.len() {
48 row.push(TypedValue::Null);
49 }
50 all_rows.push(row);
51 }
52 }
53
54 Ok(Self {
55 schema: schema.to_vec(),
56 rows: all_rows.into_iter(),
57 })
58 }
59}
60
61impl DataReader for ArrowIpcReader {
62 fn schema(&self) -> &[LogicalType] {
63 &self.schema
64 }
65}
66
67impl Iterator for ArrowIpcReader {
68 type Item = KyuResult<Vec<TypedValue>>;
69
70 fn next(&mut self) -> Option<Self::Item> {
71 self.rows.next().map(Ok)
72 }
73}
74
75fn extract_value(
77 array: &dyn Array,
78 row: usize,
79 target_type: &LogicalType,
80) -> KyuResult<TypedValue> {
81 if array.is_null(row) {
82 return Ok(TypedValue::Null);
83 }
84
85 match target_type {
86 LogicalType::Int8 => {
87 let arr = array
88 .as_any()
89 .downcast_ref::<Int8Array>()
90 .ok_or_else(|| KyuError::Copy("expected Int8 column in Arrow IPC".into()))?;
91 Ok(TypedValue::Int8(arr.value(row)))
92 }
93 LogicalType::Int16 => {
94 let arr = array
95 .as_any()
96 .downcast_ref::<Int16Array>()
97 .ok_or_else(|| KyuError::Copy("expected Int16 column in Arrow IPC".into()))?;
98 Ok(TypedValue::Int16(arr.value(row)))
99 }
100 LogicalType::Int32 => {
101 let arr = array
102 .as_any()
103 .downcast_ref::<Int32Array>()
104 .ok_or_else(|| KyuError::Copy("expected Int32 column in Arrow IPC".into()))?;
105 Ok(TypedValue::Int32(arr.value(row)))
106 }
107 LogicalType::Int64 | LogicalType::Serial => {
108 let arr = array
109 .as_any()
110 .downcast_ref::<Int64Array>()
111 .ok_or_else(|| KyuError::Copy("expected Int64 column in Arrow IPC".into()))?;
112 Ok(TypedValue::Int64(arr.value(row)))
113 }
114 LogicalType::Float => {
115 let arr = array
116 .as_any()
117 .downcast_ref::<Float32Array>()
118 .ok_or_else(|| KyuError::Copy("expected Float32 column in Arrow IPC".into()))?;
119 Ok(TypedValue::Float(arr.value(row)))
120 }
121 LogicalType::Double => {
122 let arr = array
123 .as_any()
124 .downcast_ref::<Float64Array>()
125 .ok_or_else(|| KyuError::Copy("expected Float64 column in Arrow IPC".into()))?;
126 Ok(TypedValue::Double(arr.value(row)))
127 }
128 LogicalType::Bool => {
129 let arr = array
130 .as_any()
131 .downcast_ref::<BooleanArray>()
132 .ok_or_else(|| KyuError::Copy("expected Boolean column in Arrow IPC".into()))?;
133 Ok(TypedValue::Bool(arr.value(row)))
134 }
135 LogicalType::String => {
136 let arr = array.as_string::<i32>();
137 Ok(TypedValue::String(smol_str::SmolStr::new(arr.value(row))))
138 }
139 _ => Err(KyuError::Copy(format!(
140 "unsupported type {} for Arrow IPC import",
141 target_type.type_name()
142 ))),
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::ipc::writer::FileWriter;
    use arrow::record_batch::RecordBatch;
    use std::path::PathBuf;
    use std::sync::Arc;

    /// Creates a scratch directory unique to this process and test.
    /// Concurrent `cargo test` runs and parallel tests therefore cannot
    /// delete each other's files.
    fn scratch_dir(test_name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "kyu_arrow_reader_test_{}_{test_name}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Writes a single-batch Arrow IPC file and returns its path.
    ///
    /// The file has columns `id: Int64` and `name: Utf8`, holding the rows
    /// (10, "X"), (20, "Y"), (30, "Z").
    fn write_test_arrow(dir: &std::path::Path, name: &str) -> String {
        let path = dir.join(name);
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids = Int64Array::from(vec![10, 20, 30]);
        let names = StringArray::from(vec!["X", "Y", "Z"]);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids), Arc::new(names)])
            .unwrap();

        let file = File::create(&path).unwrap();
        let mut writer = FileWriter::try_new(file, &schema).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();

        path.to_str().unwrap().to_string()
    }

    #[test]
    fn read_arrow_ipc_basic() {
        let dir = scratch_dir("basic");
        let path = write_test_arrow(&dir, "test.arrow");

        let schema = vec![LogicalType::Int64, LogicalType::String];
        let reader = ArrowIpcReader::open(&path, &schema).unwrap();
        let rows: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0][0], TypedValue::Int64(10));
        assert_eq!(rows[0][1], TypedValue::String(smol_str::SmolStr::new("X")));
        assert_eq!(rows[2][0], TypedValue::Int64(30));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn missing_columns_are_padded_with_null() {
        let dir = scratch_dir("padding");
        let path = write_test_arrow(&dir, "test.arrow");

        // The file has two columns; the third schema entry must become Null.
        let schema = vec![LogicalType::Int64, LogicalType::String, LogicalType::Bool];
        let reader = ArrowIpcReader::open(&path, &schema).unwrap();
        let rows: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows.len(), 3);
        assert!(rows.iter().all(|r| r.len() == 3 && r[2] == TypedValue::Null));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn mismatched_column_type_is_an_error() {
        let dir = scratch_dir("mismatch");
        let path = write_test_arrow(&dir, "test.arrow");

        // Column 0 is Int64 in the file but Int32 is requested.
        let schema = vec![LogicalType::Int32, LogicalType::String];
        assert!(ArrowIpcReader::open(&path, &schema).is_err());

        let _ = std::fs::remove_dir_all(&dir);
    }
}