1use crate::error::{Error, Result};
23use crate::tensor::{TensorBlock, TensorDtype, TensorShape};
24use arrow_array::{
25 Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, Int8Array,
26 UInt32Array, UInt8Array,
27};
28use arrow_buffer::Buffer;
29use arrow_schema::{DataType, Field, Schema};
30use bytes::Bytes;
31use std::sync::Arc;
32
33pub trait TensorBlockArrowExt {
35 fn to_arrow_array(&self) -> Result<ArrayRef>;
37
38 fn to_arrow_field(&self, name: &str) -> Field;
40
41 fn to_arrow_schema(&self, field_name: &str) -> Schema;
43}
44
45impl TensorBlockArrowExt for TensorBlock {
46 fn to_arrow_array(&self) -> Result<ArrayRef> {
47 let metadata = self.metadata();
48 let data = self.data();
49
50 match metadata.dtype {
51 TensorDtype::F32 => {
52 let buffer = Buffer::from(data.clone());
53 let array = Float32Array::new(buffer.into(), None);
54 Ok(Arc::new(array) as ArrayRef)
55 }
56 TensorDtype::F64 => {
57 let buffer = Buffer::from(data.clone());
58 let array = Float64Array::new(buffer.into(), None);
59 Ok(Arc::new(array) as ArrayRef)
60 }
61 TensorDtype::I8 => {
62 let buffer = Buffer::from(data.clone());
63 let array = Int8Array::new(buffer.into(), None);
64 Ok(Arc::new(array) as ArrayRef)
65 }
66 TensorDtype::I32 => {
67 let buffer = Buffer::from(data.clone());
68 let array = Int32Array::new(buffer.into(), None);
69 Ok(Arc::new(array) as ArrayRef)
70 }
71 TensorDtype::I64 => {
72 let buffer = Buffer::from(data.clone());
73 let array = Int64Array::new(buffer.into(), None);
74 Ok(Arc::new(array) as ArrayRef)
75 }
76 TensorDtype::U8 => {
77 let buffer = Buffer::from(data.clone());
78 let array = UInt8Array::new(buffer.into(), None);
79 Ok(Arc::new(array) as ArrayRef)
80 }
81 TensorDtype::U32 => {
82 let buffer = Buffer::from(data.clone());
83 let array = UInt32Array::new(buffer.into(), None);
84 Ok(Arc::new(array) as ArrayRef)
85 }
86 TensorDtype::Bool => {
87 let bytes: Vec<u8> = data.to_vec();
89 let array = BooleanArray::from(bytes.iter().map(|&b| b != 0).collect::<Vec<_>>());
90 Ok(Arc::new(array) as ArrayRef)
91 }
92 TensorDtype::F16 => {
93 Err(Error::InvalidInput(
95 "F16 not directly supported by Arrow, use F32 instead".to_string(),
96 ))
97 }
98 }
99 }
100
101 fn to_arrow_field(&self, name: &str) -> Field {
102 let metadata = self.metadata();
103 let arrow_dtype = tensor_dtype_to_arrow(&metadata.dtype);
104 Field::new(name, arrow_dtype, false)
105 }
106
107 fn to_arrow_schema(&self, field_name: &str) -> Schema {
108 Schema::new(vec![self.to_arrow_field(field_name)])
109 }
110}
111
112pub fn arrow_dtype_to_tensor(dtype: &DataType) -> Result<TensorDtype> {
114 match dtype {
115 DataType::Float32 => Ok(TensorDtype::F32),
116 DataType::Float64 => Ok(TensorDtype::F64),
117 DataType::Int8 => Ok(TensorDtype::I8),
118 DataType::Int32 => Ok(TensorDtype::I32),
119 DataType::Int64 => Ok(TensorDtype::I64),
120 DataType::UInt8 => Ok(TensorDtype::U8),
121 DataType::UInt32 => Ok(TensorDtype::U32),
122 DataType::Boolean => Ok(TensorDtype::Bool),
123 _ => Err(Error::InvalidInput(format!(
124 "Unsupported Arrow dtype: {:?}",
125 dtype
126 ))),
127 }
128}
129
130pub fn tensor_dtype_to_arrow(dtype: &TensorDtype) -> DataType {
132 match dtype {
133 TensorDtype::F32 => DataType::Float32,
134 TensorDtype::F64 => DataType::Float64,
135 TensorDtype::I8 => DataType::Int8,
136 TensorDtype::I32 => DataType::Int32,
137 TensorDtype::I64 => DataType::Int64,
138 TensorDtype::U8 => DataType::UInt8,
139 TensorDtype::U32 => DataType::UInt32,
140 TensorDtype::Bool => DataType::Boolean,
141 TensorDtype::F16 => DataType::Float32, }
143}
144
145pub fn arrow_to_tensor_block(array: &dyn Array, shape: TensorShape) -> Result<TensorBlock> {
147 let dtype = arrow_dtype_to_tensor(array.data_type())?;
148
149 let data = match array.data_type() {
151 DataType::Float32 => {
152 let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
153 let buffer = arr.values();
154 let byte_slice = unsafe {
156 std::slice::from_raw_parts(
157 buffer.as_ptr() as *const u8,
158 buffer.len() * std::mem::size_of::<f32>(),
159 )
160 };
161 Bytes::copy_from_slice(byte_slice)
162 }
163 DataType::Float64 => {
164 let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
165 let buffer = arr.values();
166 let byte_slice = unsafe {
167 std::slice::from_raw_parts(
168 buffer.as_ptr() as *const u8,
169 buffer.len() * std::mem::size_of::<f64>(),
170 )
171 };
172 Bytes::copy_from_slice(byte_slice)
173 }
174 DataType::Int8 => {
175 let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
176 let buffer = arr.values();
177 let byte_slice =
178 unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u8, buffer.len()) };
179 Bytes::copy_from_slice(byte_slice)
180 }
181 DataType::Int32 => {
182 let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
183 let buffer = arr.values();
184 let byte_slice = unsafe {
185 std::slice::from_raw_parts(
186 buffer.as_ptr() as *const u8,
187 buffer.len() * std::mem::size_of::<i32>(),
188 )
189 };
190 Bytes::copy_from_slice(byte_slice)
191 }
192 DataType::Int64 => {
193 let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
194 let buffer = arr.values();
195 let byte_slice = unsafe {
196 std::slice::from_raw_parts(
197 buffer.as_ptr() as *const u8,
198 buffer.len() * std::mem::size_of::<i64>(),
199 )
200 };
201 Bytes::copy_from_slice(byte_slice)
202 }
203 DataType::UInt8 => {
204 let arr = array.as_any().downcast_ref::<UInt8Array>().unwrap();
205 let buffer = arr.values();
206 Bytes::copy_from_slice(buffer.as_ref())
207 }
208 DataType::UInt32 => {
209 let arr = array.as_any().downcast_ref::<UInt32Array>().unwrap();
210 let buffer = arr.values();
211 let byte_slice = unsafe {
212 std::slice::from_raw_parts(
213 buffer.as_ptr() as *const u8,
214 buffer.len() * std::mem::size_of::<u32>(),
215 )
216 };
217 Bytes::copy_from_slice(byte_slice)
218 }
219 DataType::Boolean => {
220 let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
221 let bytes: Vec<u8> = (0..arr.len()).map(|i| arr.value(i) as u8).collect();
222 Bytes::from(bytes)
223 }
224 _ => {
225 return Err(Error::InvalidInput(format!(
226 "Unsupported Arrow dtype: {:?}",
227 array.data_type()
228 )))
229 }
230 };
231
232 TensorBlock::new(data, shape, dtype)
233}
234
235#[allow(dead_code)]
237pub fn tensors_to_record_batch(
238 tensors: Vec<(&str, &TensorBlock)>,
239) -> Result<arrow_array::RecordBatch> {
240 let mut fields = Vec::new();
241 let mut arrays: Vec<ArrayRef> = Vec::new();
242
243 for (name, tensor) in tensors {
244 fields.push(tensor.to_arrow_field(name));
245 arrays.push(tensor.to_arrow_array()?);
246 }
247
248 let schema = Arc::new(Schema::new(fields));
249 arrow_array::RecordBatch::try_new(schema, arrays)
250 .map_err(|e| Error::InvalidInput(format!("Failed to create RecordBatch: {}", e)))
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_to_arrow_f32() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let bytes = Bytes::from(
            data.iter()
                .flat_map(|&f| f.to_le_bytes())
                .collect::<Vec<u8>>(),
        );

        let tensor =
            TensorBlock::new(bytes, TensorShape::new(vec![2, 2]), TensorDtype::F32).unwrap();

        let arrow_array = tensor.to_arrow_array().unwrap();
        let f32_array = arrow_array.as_any().downcast_ref::<Float32Array>().unwrap();

        assert_eq!(f32_array.len(), 4);
        assert_eq!(f32_array.value(0), 1.0);
        assert_eq!(f32_array.value(1), 2.0);
        assert_eq!(f32_array.value(2), 3.0);
        assert_eq!(f32_array.value(3), 4.0);
    }

    #[test]
    fn test_arrow_to_tensor_f32() {
        let arrow_array = Float32Array::from(vec![1.0f32, 2.0, 3.0, 4.0]);
        let tensor = arrow_to_tensor_block(&arrow_array, TensorShape::new(vec![2, 2])).unwrap();

        assert_eq!(tensor.element_count(), 4);
        assert_eq!(tensor.metadata().dtype, TensorDtype::F32);
    }

    #[test]
    fn test_tensor_to_arrow_i32() {
        let data = [1i32, 2, 3, 4];
        let bytes = Bytes::from(
            data.iter()
                .flat_map(|&i| i.to_le_bytes())
                .collect::<Vec<u8>>(),
        );

        let tensor = TensorBlock::new(bytes, TensorShape::new(vec![4]), TensorDtype::I32).unwrap();

        let arrow_array = tensor.to_arrow_array().unwrap();
        let i32_array = arrow_array.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(i32_array.len(), 4);
        assert_eq!(i32_array.value(0), 1);
        assert_eq!(i32_array.value(3), 4);
    }

    #[test]
    fn test_dtype_conversions() {
        assert_eq!(tensor_dtype_to_arrow(&TensorDtype::F32), DataType::Float32);
        assert_eq!(tensor_dtype_to_arrow(&TensorDtype::I64), DataType::Int64);
        assert_eq!(tensor_dtype_to_arrow(&TensorDtype::Bool), DataType::Boolean);

        assert_eq!(
            arrow_dtype_to_tensor(&DataType::Float32).unwrap(),
            TensorDtype::F32
        );
        assert_eq!(
            arrow_dtype_to_tensor(&DataType::Int64).unwrap(),
            TensorDtype::I64
        );
    }

    #[test]
    fn test_dtype_mapping_roundtrips_for_arrow_native_dtypes() {
        // F16 is deliberately excluded: it is described as Float32, which maps back to F32.
        let dtypes = [
            TensorDtype::F32,
            TensorDtype::F64,
            TensorDtype::I8,
            TensorDtype::I32,
            TensorDtype::I64,
            TensorDtype::U8,
            TensorDtype::U32,
            TensorDtype::Bool,
        ];
        for dtype in dtypes {
            let arrow = tensor_dtype_to_arrow(&dtype);
            assert_eq!(arrow_dtype_to_tensor(&arrow).unwrap(), dtype);
        }
    }

    #[test]
    fn test_f16_is_described_as_float32_but_not_converted() {
        assert_eq!(tensor_dtype_to_arrow(&TensorDtype::F16), DataType::Float32);

        // Four 2-byte half-precision elements.
        let tensor = TensorBlock::new(
            Bytes::from(vec![0u8; 8]),
            TensorShape::new(vec![4]),
            TensorDtype::F16,
        )
        .unwrap();
        assert!(tensor.to_arrow_array().is_err());
    }

    #[test]
    fn test_unsupported_arrow_dtype_is_rejected() {
        assert!(arrow_dtype_to_tensor(&DataType::Utf8).is_err());

        let int16 = arrow_array::Int16Array::from(vec![1i16, 2]);
        assert!(arrow_to_tensor_block(&int16, TensorShape::new(vec![2])).is_err());
    }

    #[test]
    fn test_arrow_schema_generation() {
        let data = Bytes::from(vec![0u8; 16]);
        let tensor = TensorBlock::new(data, TensorShape::new(vec![4]), TensorDtype::F32).unwrap();

        let schema = tensor.to_arrow_schema("tensor_data");
        assert_eq!(schema.fields().len(), 1);
        assert_eq!(schema.field(0).name(), "tensor_data");
        assert_eq!(schema.field(0).data_type(), &DataType::Float32);
    }

    /// Arrow -> tensor -> Arrow preserves values. (`arrow_to_tensor_block` copies the values,
    /// so this is a round-trip test, not a zero-copy test.)
    #[test]
    fn test_arrow_tensor_arrow_roundtrip() {
        let original_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arrow_array = Float32Array::from(original_data.clone());

        let tensor = arrow_to_tensor_block(&arrow_array, TensorShape::new(vec![2, 3])).unwrap();

        let arrow_back = tensor.to_arrow_array().unwrap();
        let f32_back = arrow_back.as_any().downcast_ref::<Float32Array>().unwrap();

        assert_eq!(f32_back.len(), original_data.len());
        for (i, &expected) in original_data.iter().enumerate() {
            assert_eq!(f32_back.value(i), expected);
        }
    }

    #[test]
    fn test_sliced_arrow_array_respects_offset() {
        let full = Float32Array::from(vec![0.0f32, 1.0, 2.0, 3.0, 4.0]);
        let window = full.slice(1, 3);

        let tensor = arrow_to_tensor_block(&window, TensorShape::new(vec![3])).unwrap();
        assert_eq!(tensor.element_count(), 3);

        let arrow_back = tensor.to_arrow_array().unwrap();
        let f32_back = arrow_back.as_any().downcast_ref::<Float32Array>().unwrap();
        assert_eq!(f32_back.len(), 3);
        assert_eq!(f32_back.value(0), 1.0);
        assert_eq!(f32_back.value(1), 2.0);
        assert_eq!(f32_back.value(2), 3.0);
    }

    #[test]
    fn test_bool_roundtrip() {
        let arrow_array = BooleanArray::from(vec![true, false, true]);
        let tensor = arrow_to_tensor_block(&arrow_array, TensorShape::new(vec![3])).unwrap();
        assert_eq!(tensor.metadata().dtype, TensorDtype::Bool);

        let arrow_back = tensor.to_arrow_array().unwrap();
        let bool_back = arrow_back.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(bool_back.len(), 3);
        assert!(bool_back.value(0));
        assert!(!bool_back.value(1));
        assert!(bool_back.value(2));
    }

    #[test]
    fn test_tensor_to_arrow_field() {
        let data = Bytes::from(vec![0u8; 64]);
        let tensor = TensorBlock::new(data, TensorShape::new(vec![8]), TensorDtype::I64).unwrap();

        let field = tensor.to_arrow_field("my_tensor");
        assert_eq!(field.name(), "my_tensor");
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable());
    }
}