ipfrs_interface/arrow.rs

//! Apache Arrow integration for zero-copy data exchange
//!
//! This module provides Apache Arrow IPC format support for tensor data,
//! enabling efficient zero-copy data transfer for ML/data science workflows.
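//!
//! A minimal end-to-end sketch of the intended flow (marked `ignore`: the crate and
//! module paths in the `use` lines are assumptions based on this file's location):
//!
//! ```ignore
//! use ipfrs_interface::arrow::{record_batch_to_ipc_bytes, tensor_to_record_batch};
//! use ipfrs_interface::tensor::{TensorLayout, TensorMetadata};
//!
//! // Describe a 2x3 row-major f32 tensor (6 elements, 24 bytes).
//! let metadata = TensorMetadata {
//!     shape: vec![2, 3],
//!     dtype: "F32".to_string(),
//!     num_elements: 6,
//!     size_bytes: 24,
//!     layout: TensorLayout::RowMajor,
//! };
//! let raw: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
//!     .iter()
//!     .flat_map(|f| f.to_le_bytes())
//!     .collect();
//!
//! // Wrap the raw bytes in a single-column RecordBatch, then serialize to IPC.
//! let batch = tensor_to_record_batch(&metadata, &raw)?;
//! let ipc_bytes = record_batch_to_ipc_bytes(&batch)?;
//! ```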

use arrow::array::{
    ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, UInt16Array, UInt32Array,
    UInt64Array, UInt8Array,
};
use arrow::buffer::Buffer;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;
use bytes::Bytes;
use ipfrs_core::error::{Error, Result};
use std::sync::Arc;

use crate::tensor::TensorMetadata;

/// Convert tensor data to Apache Arrow RecordBatch
///
/// The tensor is represented as a single column in the RecordBatch,
/// with the column name "data" and appropriate Arrow data type.
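///
/// For example, a tensor with shape `[2, 3]` and dtype `"F32"` yields a batch with one
/// `Float32` column and six rows (one row per element); the shape itself is not stored
/// in the batch (see [`create_tensor_schema`] for shape metadata). Sketch (illustrative
/// only, hence `ignore`):
///
/// ```ignore
/// let batch = tensor_to_record_batch(&metadata, &raw_le_bytes)?;
/// assert_eq!(batch.num_columns(), 1);
/// assert_eq!(batch.num_rows(), 6);
/// ```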
pub fn tensor_to_record_batch(metadata: &TensorMetadata, data: &[u8]) -> Result<RecordBatch> {
    // Determine Arrow data type from tensor dtype string
    let arrow_dtype = match metadata.dtype.as_str() {
        "F32" | "f32" => DataType::Float32,
        "F64" | "f64" => DataType::Float64,
        "I32" | "i32" => DataType::Int32,
        "I64" | "i64" => DataType::Int64,
        "U8" | "u8" => DataType::UInt8,
        "U16" | "u16" => DataType::UInt16,
        "U32" | "u32" => DataType::UInt32,
        "U64" | "u64" => DataType::UInt64,
        _ => {
            return Err(Error::Internal(format!(
                "Unsupported dtype: {}",
                metadata.dtype
            )))
        }
    };

    // Create schema
    let schema = Schema::new(vec![Field::new("data", arrow_dtype.clone(), false)]);

    // Create array from raw data
    let array: ArrayRef = match metadata.dtype.as_str() {
        "F32" | "f32" => {
            let buffer = Buffer::from(data);
            Arc::new(Float32Array::new(buffer.into(), None))
        }
        "F64" | "f64" => {
            let buffer = Buffer::from(data);
            Arc::new(Float64Array::new(buffer.into(), None))
        }
        "I32" | "i32" => {
            let buffer = Buffer::from(data);
            Arc::new(Int32Array::new(buffer.into(), None))
        }
        "I64" | "i64" => {
            let buffer = Buffer::from(data);
            Arc::new(Int64Array::new(buffer.into(), None))
        }
        "U8" | "u8" => {
            let buffer = Buffer::from(data);
            Arc::new(UInt8Array::new(buffer.into(), None))
        }
        "U16" | "u16" => {
            let buffer = Buffer::from(data);
            Arc::new(UInt16Array::new(buffer.into(), None))
        }
        "U32" | "u32" => {
            let buffer = Buffer::from(data);
            Arc::new(UInt32Array::new(buffer.into(), None))
        }
        "U64" | "u64" => {
            let buffer = Buffer::from(data);
            Arc::new(UInt64Array::new(buffer.into(), None))
        }
        _ => {
            return Err(Error::Internal(format!(
                "Unsupported dtype: {}",
                metadata.dtype
            )))
        }
    };

    // Create record batch
    RecordBatch::try_new(Arc::new(schema), vec![array])
        .map_err(|e| Error::Internal(format!("Failed to create Arrow RecordBatch: {}", e)))
}

/// Serialize RecordBatch to Arrow IPC Stream format
///
/// Returns the serialized bytes that can be sent over HTTP
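///
/// The stream can be decoded again with `arrow::ipc::reader::StreamReader`. A round-trip
/// sketch (illustrative only, hence `ignore`):
///
/// ```ignore
/// use std::io::Cursor;
/// use arrow::ipc::reader::StreamReader;
///
/// let ipc_bytes = record_batch_to_ipc_bytes(&batch)?;
/// let mut reader = StreamReader::try_new(Cursor::new(ipc_bytes.as_ref()), None)?;
/// let decoded = reader.next().unwrap()?;
/// assert_eq!(decoded.num_rows(), batch.num_rows());
/// ```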
pub fn record_batch_to_ipc_bytes(batch: &RecordBatch) -> Result<Bytes> {
    let mut buffer = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buffer, &batch.schema())
            .map_err(|e| Error::Internal(format!("Failed to create Arrow StreamWriter: {}", e)))?;

        writer
            .write(batch)
            .map_err(|e| Error::Internal(format!("Failed to write Arrow batch: {}", e)))?;

        writer
            .finish()
            .map_err(|e| Error::Internal(format!("Failed to finish Arrow stream: {}", e)))?;
    }

    Ok(Bytes::from(buffer))
}

/// Create Arrow schema with metadata for tensor shape, dtype, and layout
///
/// This enriches the schema's single "data" field with custom metadata describing the
/// tensor's shape, element dtype, and memory layout.
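///
/// Sketch of the resulting field metadata (illustrative only, hence `ignore`; the
/// `tensor_layout` value assumes a derived `Debug` representation of the layout enum):
///
/// ```ignore
/// let schema = create_tensor_schema(&metadata)?;
/// let meta = schema.fields()[0].metadata();
/// // e.g. for shape [10, 20, 30], dtype "F64", row-major layout:
/// assert_eq!(meta.get("tensor_shape").unwrap(), "10,20,30");
/// assert_eq!(meta.get("tensor_dtype").unwrap(), "F64");
/// assert_eq!(meta.get("tensor_layout").unwrap(), "RowMajor");
/// ```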
pub fn create_tensor_schema(metadata: &TensorMetadata) -> Result<Schema> {
    let arrow_dtype = match metadata.dtype.as_str() {
        "F32" | "f32" => DataType::Float32,
        "F64" | "f64" => DataType::Float64,
        "I32" | "i32" => DataType::Int32,
        "I64" | "i64" => DataType::Int64,
        "U8" | "u8" => DataType::UInt8,
        "U16" | "u16" => DataType::UInt16,
        "U32" | "u32" => DataType::UInt32,
        "U64" | "u64" => DataType::UInt64,
        _ => {
            return Err(Error::Internal(format!(
                "Unsupported dtype: {}",
                metadata.dtype
            )))
        }
    };

    // Create field with metadata
    let mut field = Field::new("data", arrow_dtype, false);

    // Add tensor shape as metadata
    let shape_str = metadata
        .shape
        .iter()
        .map(|s| s.to_string())
        .collect::<Vec<_>>()
        .join(",");

    field = field.with_metadata(
        [
            ("tensor_shape".to_string(), shape_str),
            ("tensor_dtype".to_string(), metadata.dtype.clone()),
            (
                "tensor_layout".to_string(),
                format!("{:?}", metadata.layout),
            ),
        ]
        .into_iter()
        .collect(),
    );

    Ok(Schema::new(vec![field]))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorLayout;

    #[test]
    fn test_tensor_to_record_batch_f32() {
        let metadata = TensorMetadata {
            shape: vec![2, 3],
            dtype: "F32".to_string(),
            num_elements: 6,
            size_bytes: 24,
            layout: TensorLayout::RowMajor,
        };

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes = data
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 6);

        let array = batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(array.value(0), 1.0);
        assert_eq!(array.value(5), 6.0);
    }

    #[test]
    fn test_tensor_to_record_batch_i32() {
        let metadata = TensorMetadata {
            shape: vec![4],
            dtype: "I32".to_string(),
            num_elements: 4,
            size_bytes: 16,
            layout: TensorLayout::RowMajor,
        };

        let data: Vec<i32> = vec![10, 20, 30, 40];
        let bytes = data
            .iter()
            .flat_map(|i| i.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        assert_eq!(batch.num_rows(), 4);

        let array = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(array.value(0), 10);
        assert_eq!(array.value(3), 40);
    }

    #[test]
    fn test_record_batch_to_ipc_bytes() {
        let metadata = TensorMetadata {
            shape: vec![3],
            dtype: "F32".to_string(),
            num_elements: 3,
            size_bytes: 12,
            layout: TensorLayout::RowMajor,
        };

        let data: Vec<f32> = vec![1.0, 2.0, 3.0];
        let bytes = data
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        let ipc_bytes = record_batch_to_ipc_bytes(&batch).unwrap();

        // IPC format should have non-trivial size (header + data)
        assert!(ipc_bytes.len() > 50);
    }
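
    // Added round-trip sketch (not part of the original suite): checks that the IPC
    // bytes produced above decode back to the original values using arrow's
    // `ipc::reader::StreamReader`, which reads the Arrow IPC Stream format.
    #[test]
    fn test_ipc_round_trip() {
        use arrow::ipc::reader::StreamReader;
        use std::io::Cursor;

        let metadata = TensorMetadata {
            shape: vec![3],
            dtype: "F32".to_string(),
            num_elements: 3,
            size_bytes: 12,
            layout: TensorLayout::RowMajor,
        };
        let data: Vec<f32> = vec![1.0, 2.0, 3.0];
        let bytes = data
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        let ipc_bytes = record_batch_to_ipc_bytes(&batch).unwrap();

        // Decode the IPC stream and compare against the original values.
        let mut reader = StreamReader::try_new(Cursor::new(ipc_bytes.as_ref()), None).unwrap();
        let decoded = reader.next().unwrap().unwrap();
        assert_eq!(decoded.num_rows(), 3);

        let array = decoded
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(array.value(0), 1.0);
        assert_eq!(array.value(2), 3.0);
    }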

    #[test]
    fn test_create_tensor_schema() {
        let metadata = TensorMetadata {
            shape: vec![10, 20, 30],
            dtype: "F64".to_string(),
            num_elements: 6000,
            size_bytes: 48000,
            layout: TensorLayout::RowMajor,
        };

        let schema = create_tensor_schema(&metadata).unwrap();
        assert_eq!(schema.fields().len(), 1);

        let field = &schema.fields()[0];
        assert_eq!(field.name(), "data");
        assert_eq!(field.data_type(), &DataType::Float64);

        let meta = field.metadata();
        assert!(meta.contains_key("tensor_shape"));
        assert_eq!(meta.get("tensor_shape").unwrap(), "10,20,30");
        assert_eq!(meta.get("tensor_dtype").unwrap(), "F64");
    }

    #[test]
    fn test_all_dtypes() {
        let dtypes = vec!["F32", "F64", "I32", "I64", "U8", "U16", "U32", "U64"];

        for dtype in dtypes {
            let element_size = match dtype {
                "F32" | "I32" | "U32" => 4,
                "F64" | "I64" | "U64" => 8,
                "U8" => 1,
                "U16" => 2,
                _ => 4,
            };

            let metadata = TensorMetadata {
                shape: vec![4],
                dtype: dtype.to_string(),
                num_elements: 4,
                size_bytes: 4 * element_size,
                layout: TensorLayout::RowMajor,
            };

            let data = vec![0u8; metadata.size_bytes];
            let result = tensor_to_record_batch(&metadata, &data);
            assert!(result.is_ok(), "Failed for dtype: {}", dtype);
        }
    }
}