ipfrs_core/
arrow.rs

1//! Apache Arrow memory layout integration for zero-copy tensor access.
2//!
3//! This module provides conversions between IPFRS tensor types and Apache Arrow arrays,
4//! enabling zero-copy interoperability with the Arrow ecosystem (Parquet, Flight, etc.).
5//!
6//! ## Example
7//!
8//! ```rust
9//! use ipfrs_core::arrow::{TensorBlockArrowExt, arrow_to_tensor_block};
10//! use ipfrs_core::tensor::{TensorBlock, TensorDtype, TensorShape};
11//! use bytes::Bytes;
12//! use arrow_array::Float32Array;
13//!
14//! // Convert Arrow array to TensorBlock (the values are copied into a new buffer)
15//! let arrow_array = Float32Array::from(vec![1.0f32, 2.0, 3.0, 4.0]);
16//! let tensor = arrow_to_tensor_block(&arrow_array, TensorShape::new(vec![2, 2])).unwrap();
17//!
18//! // Convert TensorBlock back to Arrow array
19//! let arrow_back = tensor.to_arrow_array().unwrap();
20//! ```
21
22use crate::error::{Error, Result};
23use crate::tensor::{TensorBlock, TensorDtype, TensorShape};
24use arrow_array::{
25    Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, Int8Array,
26    UInt32Array, UInt8Array,
27};
28use arrow_buffer::Buffer;
29use arrow_schema::{DataType, Field, Schema};
30use bytes::Bytes;
31use std::sync::Arc;
32
33/// Extension trait for TensorBlock to provide Arrow conversions
34pub trait TensorBlockArrowExt {
35    /// Convert to an Arrow array (zero-copy when possible)
36    fn to_arrow_array(&self) -> Result<ArrayRef>;
37
38    /// Convert to an Arrow Field (for schema)
39    fn to_arrow_field(&self, name: &str) -> Field;
40
41    /// Convert to an Arrow Schema
42    fn to_arrow_schema(&self, field_name: &str) -> Schema;
43}
44
45impl TensorBlockArrowExt for TensorBlock {
46    fn to_arrow_array(&self) -> Result<ArrayRef> {
47        let metadata = self.metadata();
48        let data = self.data();
49
50        match metadata.dtype {
51            TensorDtype::F32 => {
52                let buffer = Buffer::from(data.clone());
53                let array = Float32Array::new(buffer.into(), None);
54                Ok(Arc::new(array) as ArrayRef)
55            }
56            TensorDtype::F64 => {
57                let buffer = Buffer::from(data.clone());
58                let array = Float64Array::new(buffer.into(), None);
59                Ok(Arc::new(array) as ArrayRef)
60            }
61            TensorDtype::I8 => {
62                let buffer = Buffer::from(data.clone());
63                let array = Int8Array::new(buffer.into(), None);
64                Ok(Arc::new(array) as ArrayRef)
65            }
66            TensorDtype::I32 => {
67                let buffer = Buffer::from(data.clone());
68                let array = Int32Array::new(buffer.into(), None);
69                Ok(Arc::new(array) as ArrayRef)
70            }
71            TensorDtype::I64 => {
72                let buffer = Buffer::from(data.clone());
73                let array = Int64Array::new(buffer.into(), None);
74                Ok(Arc::new(array) as ArrayRef)
75            }
76            TensorDtype::U8 => {
77                let buffer = Buffer::from(data.clone());
78                let array = UInt8Array::new(buffer.into(), None);
79                Ok(Arc::new(array) as ArrayRef)
80            }
81            TensorDtype::U32 => {
82                let buffer = Buffer::from(data.clone());
83                let array = UInt32Array::new(buffer.into(), None);
84                Ok(Arc::new(array) as ArrayRef)
85            }
86            TensorDtype::Bool => {
87                // Boolean arrays are stored as bit-packed in Arrow
88                let bytes: Vec<u8> = data.to_vec();
89                let array = BooleanArray::from(bytes.iter().map(|&b| b != 0).collect::<Vec<_>>());
90                Ok(Arc::new(array) as ArrayRef)
91            }
92            TensorDtype::F16 => {
93                // Arrow doesn't have native F16 support, convert to F32
94                Err(Error::InvalidInput(
95                    "F16 not directly supported by Arrow, use F32 instead".to_string(),
96                ))
97            }
98        }
99    }
100
101    fn to_arrow_field(&self, name: &str) -> Field {
102        let metadata = self.metadata();
103        let arrow_dtype = tensor_dtype_to_arrow(&metadata.dtype);
104        Field::new(name, arrow_dtype, false)
105    }
106
107    fn to_arrow_schema(&self, field_name: &str) -> Schema {
108        Schema::new(vec![self.to_arrow_field(field_name)])
109    }
110}
111
112/// Convert Arrow DataType to TensorDtype
113pub fn arrow_dtype_to_tensor(dtype: &DataType) -> Result<TensorDtype> {
114    match dtype {
115        DataType::Float32 => Ok(TensorDtype::F32),
116        DataType::Float64 => Ok(TensorDtype::F64),
117        DataType::Int8 => Ok(TensorDtype::I8),
118        DataType::Int32 => Ok(TensorDtype::I32),
119        DataType::Int64 => Ok(TensorDtype::I64),
120        DataType::UInt8 => Ok(TensorDtype::U8),
121        DataType::UInt32 => Ok(TensorDtype::U32),
122        DataType::Boolean => Ok(TensorDtype::Bool),
123        _ => Err(Error::InvalidInput(format!(
124            "Unsupported Arrow dtype: {:?}",
125            dtype
126        ))),
127    }
128}
129
130/// Convert TensorDtype to Arrow DataType
131pub fn tensor_dtype_to_arrow(dtype: &TensorDtype) -> DataType {
132    match dtype {
133        TensorDtype::F32 => DataType::Float32,
134        TensorDtype::F64 => DataType::Float64,
135        TensorDtype::I8 => DataType::Int8,
136        TensorDtype::I32 => DataType::Int32,
137        TensorDtype::I64 => DataType::Int64,
138        TensorDtype::U8 => DataType::UInt8,
139        TensorDtype::U32 => DataType::UInt32,
140        TensorDtype::Bool => DataType::Boolean,
141        TensorDtype::F16 => DataType::Float32, // Fallback to F32
142    }
143}
144
145/// Convert an Arrow array to a TensorBlock (zero-copy)
146pub fn arrow_to_tensor_block(array: &dyn Array, shape: TensorShape) -> Result<TensorBlock> {
147    let dtype = arrow_dtype_to_tensor(array.data_type())?;
148
149    // Get the raw buffer data
150    let data = match array.data_type() {
151        DataType::Float32 => {
152            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
153            let buffer = arr.values();
154            // Cast typed slice to &[u8] for Bytes
155            let byte_slice = unsafe {
156                std::slice::from_raw_parts(
157                    buffer.as_ptr() as *const u8,
158                    buffer.len() * std::mem::size_of::<f32>(),
159                )
160            };
161            Bytes::copy_from_slice(byte_slice)
162        }
163        DataType::Float64 => {
164            let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
165            let buffer = arr.values();
166            let byte_slice = unsafe {
167                std::slice::from_raw_parts(
168                    buffer.as_ptr() as *const u8,
169                    buffer.len() * std::mem::size_of::<f64>(),
170                )
171            };
172            Bytes::copy_from_slice(byte_slice)
173        }
174        DataType::Int8 => {
175            let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
176            let buffer = arr.values();
177            let byte_slice =
178                unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u8, buffer.len()) };
179            Bytes::copy_from_slice(byte_slice)
180        }
181        DataType::Int32 => {
182            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
183            let buffer = arr.values();
184            let byte_slice = unsafe {
185                std::slice::from_raw_parts(
186                    buffer.as_ptr() as *const u8,
187                    buffer.len() * std::mem::size_of::<i32>(),
188                )
189            };
190            Bytes::copy_from_slice(byte_slice)
191        }
192        DataType::Int64 => {
193            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
194            let buffer = arr.values();
195            let byte_slice = unsafe {
196                std::slice::from_raw_parts(
197                    buffer.as_ptr() as *const u8,
198                    buffer.len() * std::mem::size_of::<i64>(),
199                )
200            };
201            Bytes::copy_from_slice(byte_slice)
202        }
203        DataType::UInt8 => {
204            let arr = array.as_any().downcast_ref::<UInt8Array>().unwrap();
205            let buffer = arr.values();
206            Bytes::copy_from_slice(buffer.as_ref())
207        }
208        DataType::UInt32 => {
209            let arr = array.as_any().downcast_ref::<UInt32Array>().unwrap();
210            let buffer = arr.values();
211            let byte_slice = unsafe {
212                std::slice::from_raw_parts(
213                    buffer.as_ptr() as *const u8,
214                    buffer.len() * std::mem::size_of::<u32>(),
215                )
216            };
217            Bytes::copy_from_slice(byte_slice)
218        }
219        DataType::Boolean => {
220            let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
221            let bytes: Vec<u8> = (0..arr.len()).map(|i| arr.value(i) as u8).collect();
222            Bytes::from(bytes)
223        }
224        _ => {
225            return Err(Error::InvalidInput(format!(
226                "Unsupported Arrow dtype: {:?}",
227                array.data_type()
228            )))
229        }
230    };
231
232    TensorBlock::new(data, shape, dtype)
233}
234
235/// Create an Arrow RecordBatch from multiple TensorBlocks
236#[allow(dead_code)]
237pub fn tensors_to_record_batch(
238    tensors: Vec<(&str, &TensorBlock)>,
239) -> Result<arrow_array::RecordBatch> {
240    let mut fields = Vec::new();
241    let mut arrays: Vec<ArrayRef> = Vec::new();
242
243    for (name, tensor) in tensors {
244        fields.push(tensor.to_arrow_field(name));
245        arrays.push(tensor.to_arrow_array()?);
246    }
247
248    let schema = Arc::new(Schema::new(fields));
249    arrow_array::RecordBatch::try_new(schema, arrays)
250        .map_err(|e| Error::InvalidInput(format!("Failed to create RecordBatch: {}", e)))
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `values` as consecutive little-endian `f32` bytes.
    fn f32_le_bytes(values: &[f32]) -> Bytes {
        let mut raw = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
        for value in values {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        Bytes::from(raw)
    }

    /// Encode `values` as consecutive little-endian `i32` bytes.
    fn i32_le_bytes(values: &[i32]) -> Bytes {
        let mut raw = Vec::with_capacity(values.len() * std::mem::size_of::<i32>());
        for value in values {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        Bytes::from(raw)
    }

    #[test]
    fn test_tensor_to_arrow_f32() {
        let expected = [1.0f32, 2.0, 3.0, 4.0];
        let tensor = TensorBlock::new(
            f32_le_bytes(&expected),
            TensorShape::new(vec![2, 2]),
            TensorDtype::F32,
        )
        .unwrap();

        let converted = tensor.to_arrow_array().unwrap();
        let floats = converted.as_any().downcast_ref::<Float32Array>().unwrap();

        assert_eq!(floats.len(), 4);
        for (index, &want) in expected.iter().enumerate() {
            assert_eq!(floats.value(index), want);
        }
    }

    #[test]
    fn test_arrow_to_tensor_f32() {
        let source = Float32Array::from(vec![1.0f32, 2.0, 3.0, 4.0]);
        let shape = TensorShape::new(vec![2, 2]);

        let tensor = arrow_to_tensor_block(&source, shape).unwrap();

        assert_eq!(tensor.element_count(), 4);
        assert_eq!(tensor.metadata().dtype, TensorDtype::F32);
    }

    #[test]
    fn test_tensor_to_arrow_i32() {
        let tensor = TensorBlock::new(
            i32_le_bytes(&[1, 2, 3, 4]),
            TensorShape::new(vec![4]),
            TensorDtype::I32,
        )
        .unwrap();

        let converted = tensor.to_arrow_array().unwrap();
        let ints = converted.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(ints.len(), 4);
        assert_eq!(ints.value(0), 1);
        assert_eq!(ints.value(3), 4);
    }

    #[test]
    fn test_dtype_conversions() {
        // Tensor dtype -> Arrow data type.
        let to_arrow = [
            (TensorDtype::F32, DataType::Float32),
            (TensorDtype::I64, DataType::Int64),
            (TensorDtype::Bool, DataType::Boolean),
        ];
        for (tensor_dtype, arrow_dtype) in to_arrow {
            assert_eq!(tensor_dtype_to_arrow(&tensor_dtype), arrow_dtype);
        }

        // Arrow data type -> tensor dtype.
        let to_tensor = [
            (DataType::Float32, TensorDtype::F32),
            (DataType::Int64, TensorDtype::I64),
        ];
        for (arrow_dtype, tensor_dtype) in to_tensor {
            assert_eq!(arrow_dtype_to_tensor(&arrow_dtype).unwrap(), tensor_dtype);
        }
    }

    #[test]
    fn test_arrow_schema_generation() {
        // 4 elements * 4 bytes per F32.
        let zeros = Bytes::from(vec![0u8; 16]);
        let tensor = TensorBlock::new(zeros, TensorShape::new(vec![4]), TensorDtype::F32).unwrap();

        let schema = tensor.to_arrow_schema("tensor_data");
        let only_field = schema.field(0);

        assert_eq!(schema.fields().len(), 1);
        assert_eq!(only_field.name(), "tensor_data");
        assert_eq!(only_field.data_type(), &DataType::Float32);
    }

    #[test]
    fn test_zero_copy_roundtrip() {
        // Arrow -> TensorBlock -> Arrow must preserve every element.
        let original = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let source = Float32Array::from(original.clone());

        let tensor = arrow_to_tensor_block(&source, TensorShape::new(vec![2, 3])).unwrap();
        let restored = tensor.to_arrow_array().unwrap();
        let floats = restored.as_any().downcast_ref::<Float32Array>().unwrap();

        assert_eq!(floats.len(), original.len());
        for (index, &want) in original.iter().enumerate() {
            assert_eq!(floats.value(index), want);
        }
    }

    #[test]
    fn test_tensor_to_arrow_field() {
        // 8 elements * 8 bytes per I64.
        let zeros = Bytes::from(vec![0u8; 64]);
        let tensor = TensorBlock::new(zeros, TensorShape::new(vec![8]), TensorDtype::I64).unwrap();

        let field = tensor.to_arrow_field("my_tensor");

        assert_eq!(field.name(), "my_tensor");
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable());
    }
}