ipfrs_tensorlogic/
arrow.rs

//! Apache Arrow integration for zero-copy tensor transport.
//!
//! Provides an Arrow memory layout for tensors, enabling:
//! - Zero-copy data access (the typed `as_slice_*` accessors borrow the Arrow buffer)
//! - Efficient columnar data formats
//! - Interoperability with the Arrow ecosystem

8use arrow::array::{
9    ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
10    UInt16Array, UInt32Array, UInt64Array, UInt8Array,
11};
12use arrow::buffer::Buffer;
13use arrow::datatypes::{DataType, Field, Schema};
14use arrow::ipc::reader::FileReader;
15use arrow::ipc::writer::FileWriter;
16use arrow::record_batch::RecordBatch;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::io::{Read, Seek, Write};
21use std::sync::Arc;
22
23/// Tensor data type
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
25pub enum TensorDtype {
26    Float32,
27    Float64,
28    Int8,
29    Int16,
30    Int32,
31    Int64,
32    UInt8,
33    UInt16,
34    UInt32,
35    UInt64,
36    BFloat16,
37    Float16,
38}
39
40impl TensorDtype {
41    /// Get the size in bytes of a single element
42    #[inline]
43    pub fn element_size(&self) -> usize {
44        match self {
45            TensorDtype::Float32 => 4,
46            TensorDtype::Float64 => 8,
47            TensorDtype::Int8 | TensorDtype::UInt8 => 1,
48            TensorDtype::Int16
49            | TensorDtype::UInt16
50            | TensorDtype::Float16
51            | TensorDtype::BFloat16 => 2,
52            TensorDtype::Int32 | TensorDtype::UInt32 => 4,
53            TensorDtype::Int64 | TensorDtype::UInt64 => 8,
54        }
55    }
56
57    /// Convert to Arrow DataType
58    #[inline]
59    pub fn to_arrow_type(&self) -> DataType {
60        match self {
61            TensorDtype::Float32 => DataType::Float32,
62            TensorDtype::Float64 => DataType::Float64,
63            TensorDtype::Int8 => DataType::Int8,
64            TensorDtype::Int16 => DataType::Int16,
65            TensorDtype::Int32 => DataType::Int32,
66            TensorDtype::Int64 => DataType::Int64,
67            TensorDtype::UInt8 => DataType::UInt8,
68            TensorDtype::UInt16 => DataType::UInt16,
69            TensorDtype::UInt32 => DataType::UInt32,
70            TensorDtype::UInt64 => DataType::UInt64,
71            // BFloat16 and Float16 are stored as UInt16 in Arrow
72            TensorDtype::BFloat16 | TensorDtype::Float16 => DataType::UInt16,
73        }
74    }
75
76    /// Get dtype from string representation
77    pub fn parse(s: &str) -> Option<Self> {
78        match s.to_lowercase().as_str() {
79            "f32" | "float32" => Some(TensorDtype::Float32),
80            "f64" | "float64" => Some(TensorDtype::Float64),
81            "i8" | "int8" => Some(TensorDtype::Int8),
82            "i16" | "int16" => Some(TensorDtype::Int16),
83            "i32" | "int32" => Some(TensorDtype::Int32),
84            "i64" | "int64" => Some(TensorDtype::Int64),
85            "u8" | "uint8" => Some(TensorDtype::UInt8),
86            "u16" | "uint16" => Some(TensorDtype::UInt16),
87            "u32" | "uint32" => Some(TensorDtype::UInt32),
88            "u64" | "uint64" => Some(TensorDtype::UInt64),
89            "bf16" | "bfloat16" => Some(TensorDtype::BFloat16),
90            "f16" | "float16" => Some(TensorDtype::Float16),
91            _ => None,
92        }
93    }
94}
95
96impl std::fmt::Display for TensorDtype {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        match self {
99            TensorDtype::Float32 => write!(f, "float32"),
100            TensorDtype::Float64 => write!(f, "float64"),
101            TensorDtype::Int8 => write!(f, "int8"),
102            TensorDtype::Int16 => write!(f, "int16"),
103            TensorDtype::Int32 => write!(f, "int32"),
104            TensorDtype::Int64 => write!(f, "int64"),
105            TensorDtype::UInt8 => write!(f, "uint8"),
106            TensorDtype::UInt16 => write!(f, "uint16"),
107            TensorDtype::UInt32 => write!(f, "uint32"),
108            TensorDtype::UInt64 => write!(f, "uint64"),
109            TensorDtype::BFloat16 => write!(f, "bfloat16"),
110            TensorDtype::Float16 => write!(f, "float16"),
111        }
112    }
113}
114
115/// Tensor metadata for self-describing tensors
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct TensorMetadata {
118    /// Tensor name
119    pub name: String,
120    /// Shape dimensions
121    pub shape: Vec<usize>,
122    /// Data type
123    pub dtype: TensorDtype,
124    /// Strides (in elements, not bytes)
125    pub strides: Option<Vec<usize>>,
126    /// Custom metadata fields
127    pub custom: HashMap<String, String>,
128}
129
130impl TensorMetadata {
131    /// Create new tensor metadata
132    pub fn new(name: String, shape: Vec<usize>, dtype: TensorDtype) -> Self {
133        Self {
134            name,
135            shape,
136            dtype,
137            strides: None,
138            custom: HashMap::new(),
139        }
140    }
141
142    /// Set strides
143    pub fn with_strides(mut self, strides: Vec<usize>) -> Self {
144        self.strides = Some(strides);
145        self
146    }
147
148    /// Add custom metadata
149    pub fn with_custom(mut self, key: String, value: String) -> Self {
150        self.custom.insert(key, value);
151        self
152    }
153
154    /// Get the number of elements
155    #[inline]
156    pub fn numel(&self) -> usize {
157        self.shape.iter().product()
158    }
159
160    /// Get the size in bytes
161    #[inline]
162    pub fn size_bytes(&self) -> usize {
163        self.numel() * self.dtype.element_size()
164    }
165
166    /// Compute default strides (row-major order)
167    pub fn compute_strides(&self) -> Vec<usize> {
168        if self.shape.is_empty() {
169            return vec![];
170        }
171        let mut strides = vec![1; self.shape.len()];
172        for i in (0..self.shape.len() - 1).rev() {
173            strides[i] = strides[i + 1] * self.shape[i + 1];
174        }
175        strides
176    }
177
178    /// Get strides (computed if not specified)
179    pub fn get_strides(&self) -> Vec<usize> {
180        self.strides
181            .clone()
182            .unwrap_or_else(|| self.compute_strides())
183    }
184}
185
186/// Arrow-backed tensor for zero-copy access
187pub struct ArrowTensor {
188    /// Tensor metadata
189    pub metadata: TensorMetadata,
190    /// Arrow array containing the data
191    array: ArrayRef,
192}
193
194impl ArrowTensor {
195    /// Create a new Arrow tensor from raw data
196    pub fn from_slice_f32(name: &str, shape: Vec<usize>, data: &[f32]) -> Self {
197        let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Float32);
198        let array: ArrayRef = Arc::new(Float32Array::from(data.to_vec()));
199        Self { metadata, array }
200    }
201
202    /// Create a new Arrow tensor from raw f64 data
203    pub fn from_slice_f64(name: &str, shape: Vec<usize>, data: &[f64]) -> Self {
204        let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Float64);
205        let array: ArrayRef = Arc::new(Float64Array::from(data.to_vec()));
206        Self { metadata, array }
207    }
208
209    /// Create from i32 data
210    pub fn from_slice_i32(name: &str, shape: Vec<usize>, data: &[i32]) -> Self {
211        let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Int32);
212        let array: ArrayRef = Arc::new(Int32Array::from(data.to_vec()));
213        Self { metadata, array }
214    }
215
216    /// Create from i64 data
217    pub fn from_slice_i64(name: &str, shape: Vec<usize>, data: &[i64]) -> Self {
218        let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Int64);
219        let array: ArrayRef = Arc::new(Int64Array::from(data.to_vec()));
220        Self { metadata, array }
221    }
222
223    /// Get zero-copy view of f32 data
224    #[inline]
225    pub fn as_slice_f32(&self) -> Option<&[f32]> {
226        self.array
227            .as_any()
228            .downcast_ref::<Float32Array>()
229            .map(|arr| arr.values().as_ref())
230    }
231
232    /// Get zero-copy view of f64 data
233    #[inline]
234    pub fn as_slice_f64(&self) -> Option<&[f64]> {
235        self.array
236            .as_any()
237            .downcast_ref::<Float64Array>()
238            .map(|arr| arr.values().as_ref())
239    }
240
241    /// Get zero-copy view of i32 data
242    #[inline]
243    pub fn as_slice_i32(&self) -> Option<&[i32]> {
244        self.array
245            .as_any()
246            .downcast_ref::<Int32Array>()
247            .map(|arr| arr.values().as_ref())
248    }
249
250    /// Get zero-copy view of i64 data
251    #[inline]
252    pub fn as_slice_i64(&self) -> Option<&[i64]> {
253        self.array
254            .as_any()
255            .downcast_ref::<Int64Array>()
256            .map(|arr| arr.values().as_ref())
257    }
258
259    /// Get raw bytes (copies data)
260    pub fn as_bytes(&self) -> Vec<u8> {
261        let data = self.array.to_data();
262        if data.buffers().is_empty() {
263            Vec::new()
264        } else {
265            data.buffers()[0].as_slice().to_vec()
266        }
267    }
268
269    /// Get the underlying Arrow array
270    #[inline]
271    pub fn array(&self) -> &ArrayRef {
272        &self.array
273    }
274
275    /// Get the number of elements
276    #[inline]
277    pub fn len(&self) -> usize {
278        self.array.len()
279    }
280
281    /// Check if empty
282    #[inline]
283    pub fn is_empty(&self) -> bool {
284        self.array.is_empty()
285    }
286}
287
288/// Collection of tensors stored in Arrow format
289pub struct ArrowTensorStore {
290    /// Tensors by name
291    tensors: HashMap<String, ArrowTensor>,
292    /// Schema for the tensor collection
293    schema: Option<Arc<Schema>>,
294}
295
296impl ArrowTensorStore {
297    /// Create a new empty store
298    pub fn new() -> Self {
299        Self {
300            tensors: HashMap::new(),
301            schema: None,
302        }
303    }
304
305    /// Add a tensor to the store
306    pub fn insert(&mut self, tensor: ArrowTensor) {
307        self.schema = None; // Invalidate schema
308        self.tensors.insert(tensor.metadata.name.clone(), tensor);
309    }
310
311    /// Get a tensor by name
312    #[inline]
313    pub fn get(&self, name: &str) -> Option<&ArrowTensor> {
314        self.tensors.get(name)
315    }
316
317    /// List all tensor names
318    pub fn names(&self) -> Vec<&str> {
319        self.tensors.keys().map(|s| s.as_str()).collect()
320    }
321
322    /// Get the number of tensors
323    #[inline]
324    pub fn len(&self) -> usize {
325        self.tensors.len()
326    }
327
328    /// Check if empty
329    #[inline]
330    pub fn is_empty(&self) -> bool {
331        self.tensors.is_empty()
332    }
333
334    /// Build Arrow schema for all tensors
335    pub fn build_schema(&mut self) -> Arc<Schema> {
336        if let Some(ref schema) = self.schema {
337            return schema.clone();
338        }
339
340        let fields: Vec<Field> = self
341            .tensors
342            .values()
343            .map(|t| {
344                let mut metadata = HashMap::new();
345                metadata.insert("shape".to_string(), format!("{:?}", t.metadata.shape));
346                metadata.insert("dtype".to_string(), t.metadata.dtype.to_string());
347                if let Some(ref strides) = t.metadata.strides {
348                    metadata.insert("strides".to_string(), format!("{:?}", strides));
349                }
350                for (k, v) in &t.metadata.custom {
351                    metadata.insert(k.clone(), v.clone());
352                }
353                Field::new(&t.metadata.name, t.metadata.dtype.to_arrow_type(), false)
354                    .with_metadata(metadata)
355            })
356            .collect();
357
358        let schema = Arc::new(Schema::new(fields));
359        self.schema = Some(schema.clone());
360        schema
361    }
362
363    /// Convert to RecordBatch for IPC
364    pub fn to_record_batch(&mut self) -> Result<RecordBatch, arrow::error::ArrowError> {
365        let schema = self.build_schema();
366        let columns: Vec<ArrayRef> = self.tensors.values().map(|t| t.array.clone()).collect();
367        RecordBatch::try_new(schema, columns)
368    }
369
370    /// Write to Arrow IPC format
371    pub fn write_ipc<W: Write>(&mut self, writer: W) -> Result<(), arrow::error::ArrowError> {
372        let batch = self.to_record_batch()?;
373        let schema = batch.schema();
374        let mut ipc_writer = FileWriter::try_new(writer, &schema)?;
375        ipc_writer.write(&batch)?;
376        ipc_writer.finish()?;
377        Ok(())
378    }
379
380    /// Read from Arrow IPC format
381    pub fn read_ipc<R: Read + Seek>(reader: R) -> Result<Self, arrow::error::ArrowError> {
382        let ipc_reader = FileReader::try_new(reader, None)?;
383        let schema = ipc_reader.schema();
384        let mut store = Self::new();
385
386        for batch_result in ipc_reader {
387            let batch = batch_result?;
388            for (i, field) in schema.fields().iter().enumerate() {
389                let array = batch.column(i).clone();
390                let shape = parse_shape_from_metadata(field.metadata());
391                let dtype = dtype_from_arrow(field.data_type());
392
393                let metadata = TensorMetadata::new(field.name().clone(), shape, dtype);
394                store
395                    .tensors
396                    .insert(field.name().clone(), ArrowTensor { metadata, array });
397            }
398        }
399
400        store.schema = Some(schema);
401        Ok(store)
402    }
403
404    /// Serialize to bytes (Arrow IPC format)
405    pub fn to_bytes(&mut self) -> Result<Bytes, arrow::error::ArrowError> {
406        let mut buffer = Vec::new();
407        self.write_ipc(&mut buffer)?;
408        Ok(Bytes::from(buffer))
409    }
410
411    /// Deserialize from bytes
412    pub fn from_bytes(bytes: &[u8]) -> Result<Self, arrow::error::ArrowError> {
413        let cursor = std::io::Cursor::new(bytes);
414        Self::read_ipc(cursor)
415    }
416}
417
418impl Default for ArrowTensorStore {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
/// Reads the `"shape"` entry (written as `"[1, 2, 3]"`) from field metadata.
///
/// Returns an empty vector when the key is missing, when the list is empty
/// (`"[]"`), or when any element fails to parse as `usize`.
fn parse_shape_from_metadata(metadata: &HashMap<String, String>) -> Vec<usize> {
    match metadata.get("shape") {
        None => Vec::new(),
        Some(raw) => {
            let body = raw.trim_start_matches('[').trim_end_matches(']');
            let mut dims = Vec::new();
            for piece in body.split(',') {
                match piece.trim().parse::<usize>() {
                    Ok(dim) => dims.push(dim),
                    // One bad element invalidates the whole shape.
                    Err(_) => return Vec::new(),
                }
            }
            dims
        }
    }
}
437
438/// Convert Arrow DataType to TensorDtype
439fn dtype_from_arrow(dt: &DataType) -> TensorDtype {
440    match dt {
441        DataType::Float32 => TensorDtype::Float32,
442        DataType::Float64 => TensorDtype::Float64,
443        DataType::Int8 => TensorDtype::Int8,
444        DataType::Int16 => TensorDtype::Int16,
445        DataType::Int32 => TensorDtype::Int32,
446        DataType::Int64 => TensorDtype::Int64,
447        DataType::UInt8 => TensorDtype::UInt8,
448        DataType::UInt16 => TensorDtype::UInt16,
449        DataType::UInt32 => TensorDtype::UInt32,
450        DataType::UInt64 => TensorDtype::UInt64,
451        _ => TensorDtype::Float32, // Default
452    }
453}
454
/// Byte-level access to a tensor's element storage.
///
/// Despite the name, `get_bytes` returns an owned `Vec<u8>`, so implementors
/// may copy. The default `len_bytes` calls `get_bytes` and therefore pays for
/// that copy too; override it when the length is cheaply available.
pub trait ZeroCopyAccessor {
    /// Returns the element bytes as an owned vector.
    fn get_bytes(&self) -> Vec<u8>;

    /// Number of bytes that `get_bytes` returns.
    fn len_bytes(&self) -> usize {
        let bytes = self.get_bytes();
        bytes.len()
    }
}
465
466impl ZeroCopyAccessor for ArrowTensor {
467    fn get_bytes(&self) -> Vec<u8> {
468        ArrowTensor::as_bytes(self)
469    }
470}
471
472/// Create Arrow buffer from raw bytes (zero-copy when possible)
473#[allow(deprecated)]
474pub fn buffer_from_bytes(bytes: Bytes) -> Buffer {
475    Buffer::from(bytes)
476}
477
478/// Create typed array from buffer
479#[allow(dead_code)]
480fn create_array_from_buffer(buffer: Buffer, dtype: TensorDtype, _len: usize) -> ArrayRef {
481    match dtype {
482        TensorDtype::Float32 => Arc::new(Float32Array::new(buffer.into(), None)) as ArrayRef,
483        TensorDtype::Float64 => Arc::new(Float64Array::new(buffer.into(), None)) as ArrayRef,
484        TensorDtype::Int8 => Arc::new(Int8Array::new(buffer.into(), None)) as ArrayRef,
485        TensorDtype::Int16 => Arc::new(Int16Array::new(buffer.into(), None)) as ArrayRef,
486        TensorDtype::Int32 => Arc::new(Int32Array::new(buffer.into(), None)) as ArrayRef,
487        TensorDtype::Int64 => Arc::new(Int64Array::new(buffer.into(), None)) as ArrayRef,
488        TensorDtype::UInt8 => Arc::new(UInt8Array::new(buffer.into(), None)) as ArrayRef,
489        TensorDtype::UInt16 => Arc::new(UInt16Array::new(buffer.into(), None)) as ArrayRef,
490        TensorDtype::UInt32 => Arc::new(UInt32Array::new(buffer.into(), None)) as ArrayRef,
491        TensorDtype::UInt64 => Arc::new(UInt64Array::new(buffer.into(), None)) as ArrayRef,
492        // Float16/BFloat16 stored as UInt16
493        TensorDtype::Float16 | TensorDtype::BFloat16 => {
494            Arc::new(UInt16Array::new(buffer.into(), None)) as ArrayRef
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_metadata() {
        let meta = TensorMetadata::new("test".to_string(), vec![2, 3, 4], TensorDtype::Float32);
        assert_eq!(meta.numel(), 24);
        assert_eq!(meta.size_bytes(), 96);
        assert_eq!(meta.compute_strides(), vec![12, 4, 1]);
    }

    #[test]
    fn test_scalar_metadata_and_explicit_strides() {
        let scalar = TensorMetadata::new("s".to_string(), vec![], TensorDtype::Int8);
        assert_eq!(scalar.numel(), 1);
        assert!(scalar.compute_strides().is_empty());
        assert!(scalar.get_strides().is_empty());

        let strided = TensorMetadata::new("t".to_string(), vec![2, 3], TensorDtype::Float64)
            .with_strides(vec![1, 2])
            .with_custom("layout".to_string(), "col-major".to_string());
        assert_eq!(strided.get_strides(), vec![1, 2]);
        assert_eq!(strided.size_bytes(), 48);
        assert_eq!(strided.custom.get("layout").map(String::as_str), Some("col-major"));
    }

    #[test]
    fn test_arrow_tensor_f32() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = ArrowTensor::from_slice_f32("weights", vec![2, 3], &data);

        assert_eq!(tensor.metadata.name, "weights");
        assert_eq!(tensor.metadata.shape, vec![2, 3]);
        assert_eq!(tensor.len(), 6);

        let slice = tensor.as_slice_f32().unwrap();
        assert_eq!(slice, &data);
    }

    #[test]
    fn test_as_bytes_length() {
        let tensor = ArrowTensor::from_slice_i32("ints", vec![3], &[1, 2, 3]);
        assert_eq!(tensor.as_bytes().len(), 12);
        assert_eq!(tensor.len_bytes(), 12);
        assert!(tensor.as_slice_f32().is_none());
    }

    #[test]
    fn test_arrow_tensor_store() {
        let mut store = ArrowTensorStore::new();

        let w1 = ArrowTensor::from_slice_f32("layer1.weight", vec![4, 3], &[0.0; 12]);
        let w2 = ArrowTensor::from_slice_f32("layer2.weight", vec![2, 4], &[0.0; 8]);

        store.insert(w1);
        store.insert(w2);

        assert_eq!(store.len(), 2);
        assert!(store.get("layer1.weight").is_some());
        assert!(store.get("layer2.weight").is_some());
    }

    #[test]
    fn test_ipc_roundtrip() {
        let mut store = ArrowTensorStore::new();
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        store.insert(ArrowTensor::from_slice_f32("test", vec![3, 4], &data));

        let bytes = store.to_bytes().unwrap();
        let loaded = ArrowTensorStore::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.len(), 1);
        let tensor = loaded.get("test").unwrap();
        assert_eq!(tensor.as_slice_f32().unwrap(), &data);
    }

    #[test]
    fn test_ipc_roundtrip_multiple_tensors() {
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
        let mut store = ArrowTensorStore::new();
        store.insert(ArrowTensor::from_slice_f32("a", vec![2, 2], &a));
        store.insert(ArrowTensor::from_slice_f32("b", vec![4], &b));

        let loaded = ArrowTensorStore::from_bytes(&store.to_bytes().unwrap()).unwrap();

        assert_eq!(loaded.len(), 2);
        let ta = loaded.get("a").unwrap();
        let tb = loaded.get("b").unwrap();
        assert_eq!(ta.as_slice_f32().unwrap(), a.as_slice());
        assert_eq!(tb.as_slice_f32().unwrap(), b.as_slice());
        assert_eq!(ta.metadata.shape, vec![2, 2]);
        assert_eq!(tb.metadata.shape, vec![4]);
    }

    #[test]
    fn test_dtype_conversion() {
        assert_eq!(TensorDtype::Float32.to_arrow_type(), DataType::Float32);
        assert_eq!(TensorDtype::Int64.to_arrow_type(), DataType::Int64);
        assert_eq!(TensorDtype::Float32.element_size(), 4);
        assert_eq!(TensorDtype::Float64.element_size(), 8);
    }

    #[test]
    fn test_dtype_parse_display_roundtrip() {
        let all = [
            TensorDtype::Float32,
            TensorDtype::Float64,
            TensorDtype::Int8,
            TensorDtype::Int16,
            TensorDtype::Int32,
            TensorDtype::Int64,
            TensorDtype::UInt8,
            TensorDtype::UInt16,
            TensorDtype::UInt32,
            TensorDtype::UInt64,
            TensorDtype::BFloat16,
            TensorDtype::Float16,
        ];
        for &dtype in all.iter() {
            let name = dtype.to_string();
            assert_eq!(TensorDtype::parse(&name), Some(dtype));
            assert_eq!(TensorDtype::parse(&name.to_uppercase()), Some(dtype));
        }
        assert_eq!(TensorDtype::parse("bf16"), Some(TensorDtype::BFloat16));
        assert_eq!(TensorDtype::parse("complex64"), None);
    }
}