1use arrow::array::{
9 ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
10 UInt16Array, UInt32Array, UInt64Array, UInt8Array,
11};
12use arrow::buffer::Buffer;
13use arrow::datatypes::{DataType, Field, Schema};
14use arrow::ipc::reader::FileReader;
15use arrow::ipc::writer::FileWriter;
16use arrow::record_batch::RecordBatch;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::io::{Read, Seek, Write};
21use std::sync::Arc;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
25pub enum TensorDtype {
26 Float32,
27 Float64,
28 Int8,
29 Int16,
30 Int32,
31 Int64,
32 UInt8,
33 UInt16,
34 UInt32,
35 UInt64,
36 BFloat16,
37 Float16,
38}
39
40impl TensorDtype {
41 #[inline]
43 pub fn element_size(&self) -> usize {
44 match self {
45 TensorDtype::Float32 => 4,
46 TensorDtype::Float64 => 8,
47 TensorDtype::Int8 | TensorDtype::UInt8 => 1,
48 TensorDtype::Int16
49 | TensorDtype::UInt16
50 | TensorDtype::Float16
51 | TensorDtype::BFloat16 => 2,
52 TensorDtype::Int32 | TensorDtype::UInt32 => 4,
53 TensorDtype::Int64 | TensorDtype::UInt64 => 8,
54 }
55 }
56
57 #[inline]
59 pub fn to_arrow_type(&self) -> DataType {
60 match self {
61 TensorDtype::Float32 => DataType::Float32,
62 TensorDtype::Float64 => DataType::Float64,
63 TensorDtype::Int8 => DataType::Int8,
64 TensorDtype::Int16 => DataType::Int16,
65 TensorDtype::Int32 => DataType::Int32,
66 TensorDtype::Int64 => DataType::Int64,
67 TensorDtype::UInt8 => DataType::UInt8,
68 TensorDtype::UInt16 => DataType::UInt16,
69 TensorDtype::UInt32 => DataType::UInt32,
70 TensorDtype::UInt64 => DataType::UInt64,
71 TensorDtype::BFloat16 | TensorDtype::Float16 => DataType::UInt16,
73 }
74 }
75
76 pub fn parse(s: &str) -> Option<Self> {
78 match s.to_lowercase().as_str() {
79 "f32" | "float32" => Some(TensorDtype::Float32),
80 "f64" | "float64" => Some(TensorDtype::Float64),
81 "i8" | "int8" => Some(TensorDtype::Int8),
82 "i16" | "int16" => Some(TensorDtype::Int16),
83 "i32" | "int32" => Some(TensorDtype::Int32),
84 "i64" | "int64" => Some(TensorDtype::Int64),
85 "u8" | "uint8" => Some(TensorDtype::UInt8),
86 "u16" | "uint16" => Some(TensorDtype::UInt16),
87 "u32" | "uint32" => Some(TensorDtype::UInt32),
88 "u64" | "uint64" => Some(TensorDtype::UInt64),
89 "bf16" | "bfloat16" => Some(TensorDtype::BFloat16),
90 "f16" | "float16" => Some(TensorDtype::Float16),
91 _ => None,
92 }
93 }
94}
95
96impl std::fmt::Display for TensorDtype {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 TensorDtype::Float32 => write!(f, "float32"),
100 TensorDtype::Float64 => write!(f, "float64"),
101 TensorDtype::Int8 => write!(f, "int8"),
102 TensorDtype::Int16 => write!(f, "int16"),
103 TensorDtype::Int32 => write!(f, "int32"),
104 TensorDtype::Int64 => write!(f, "int64"),
105 TensorDtype::UInt8 => write!(f, "uint8"),
106 TensorDtype::UInt16 => write!(f, "uint16"),
107 TensorDtype::UInt32 => write!(f, "uint32"),
108 TensorDtype::UInt64 => write!(f, "uint64"),
109 TensorDtype::BFloat16 => write!(f, "bfloat16"),
110 TensorDtype::Float16 => write!(f, "float16"),
111 }
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct TensorMetadata {
118 pub name: String,
120 pub shape: Vec<usize>,
122 pub dtype: TensorDtype,
124 pub strides: Option<Vec<usize>>,
126 pub custom: HashMap<String, String>,
128}
129
130impl TensorMetadata {
131 pub fn new(name: String, shape: Vec<usize>, dtype: TensorDtype) -> Self {
133 Self {
134 name,
135 shape,
136 dtype,
137 strides: None,
138 custom: HashMap::new(),
139 }
140 }
141
142 pub fn with_strides(mut self, strides: Vec<usize>) -> Self {
144 self.strides = Some(strides);
145 self
146 }
147
148 pub fn with_custom(mut self, key: String, value: String) -> Self {
150 self.custom.insert(key, value);
151 self
152 }
153
154 #[inline]
156 pub fn numel(&self) -> usize {
157 self.shape.iter().product()
158 }
159
160 #[inline]
162 pub fn size_bytes(&self) -> usize {
163 self.numel() * self.dtype.element_size()
164 }
165
166 pub fn compute_strides(&self) -> Vec<usize> {
168 if self.shape.is_empty() {
169 return vec![];
170 }
171 let mut strides = vec![1; self.shape.len()];
172 for i in (0..self.shape.len() - 1).rev() {
173 strides[i] = strides[i + 1] * self.shape[i + 1];
174 }
175 strides
176 }
177
178 pub fn get_strides(&self) -> Vec<usize> {
180 self.strides
181 .clone()
182 .unwrap_or_else(|| self.compute_strides())
183 }
184}
185
186pub struct ArrowTensor {
188 pub metadata: TensorMetadata,
190 array: ArrayRef,
192}
193
194impl ArrowTensor {
195 pub fn from_slice_f32(name: &str, shape: Vec<usize>, data: &[f32]) -> Self {
197 let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Float32);
198 let array: ArrayRef = Arc::new(Float32Array::from(data.to_vec()));
199 Self { metadata, array }
200 }
201
202 pub fn from_slice_f64(name: &str, shape: Vec<usize>, data: &[f64]) -> Self {
204 let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Float64);
205 let array: ArrayRef = Arc::new(Float64Array::from(data.to_vec()));
206 Self { metadata, array }
207 }
208
209 pub fn from_slice_i32(name: &str, shape: Vec<usize>, data: &[i32]) -> Self {
211 let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Int32);
212 let array: ArrayRef = Arc::new(Int32Array::from(data.to_vec()));
213 Self { metadata, array }
214 }
215
216 pub fn from_slice_i64(name: &str, shape: Vec<usize>, data: &[i64]) -> Self {
218 let metadata = TensorMetadata::new(name.to_string(), shape, TensorDtype::Int64);
219 let array: ArrayRef = Arc::new(Int64Array::from(data.to_vec()));
220 Self { metadata, array }
221 }
222
223 #[inline]
225 pub fn as_slice_f32(&self) -> Option<&[f32]> {
226 self.array
227 .as_any()
228 .downcast_ref::<Float32Array>()
229 .map(|arr| arr.values().as_ref())
230 }
231
232 #[inline]
234 pub fn as_slice_f64(&self) -> Option<&[f64]> {
235 self.array
236 .as_any()
237 .downcast_ref::<Float64Array>()
238 .map(|arr| arr.values().as_ref())
239 }
240
241 #[inline]
243 pub fn as_slice_i32(&self) -> Option<&[i32]> {
244 self.array
245 .as_any()
246 .downcast_ref::<Int32Array>()
247 .map(|arr| arr.values().as_ref())
248 }
249
250 #[inline]
252 pub fn as_slice_i64(&self) -> Option<&[i64]> {
253 self.array
254 .as_any()
255 .downcast_ref::<Int64Array>()
256 .map(|arr| arr.values().as_ref())
257 }
258
259 pub fn as_bytes(&self) -> Vec<u8> {
261 let data = self.array.to_data();
262 if data.buffers().is_empty() {
263 Vec::new()
264 } else {
265 data.buffers()[0].as_slice().to_vec()
266 }
267 }
268
269 #[inline]
271 pub fn array(&self) -> &ArrayRef {
272 &self.array
273 }
274
275 #[inline]
277 pub fn len(&self) -> usize {
278 self.array.len()
279 }
280
281 #[inline]
283 pub fn is_empty(&self) -> bool {
284 self.array.is_empty()
285 }
286}
287
288pub struct ArrowTensorStore {
290 tensors: HashMap<String, ArrowTensor>,
292 schema: Option<Arc<Schema>>,
294}
295
296impl ArrowTensorStore {
297 pub fn new() -> Self {
299 Self {
300 tensors: HashMap::new(),
301 schema: None,
302 }
303 }
304
305 pub fn insert(&mut self, tensor: ArrowTensor) {
307 self.schema = None; self.tensors.insert(tensor.metadata.name.clone(), tensor);
309 }
310
311 #[inline]
313 pub fn get(&self, name: &str) -> Option<&ArrowTensor> {
314 self.tensors.get(name)
315 }
316
317 pub fn names(&self) -> Vec<&str> {
319 self.tensors.keys().map(|s| s.as_str()).collect()
320 }
321
322 #[inline]
324 pub fn len(&self) -> usize {
325 self.tensors.len()
326 }
327
328 #[inline]
330 pub fn is_empty(&self) -> bool {
331 self.tensors.is_empty()
332 }
333
334 pub fn build_schema(&mut self) -> Arc<Schema> {
336 if let Some(ref schema) = self.schema {
337 return schema.clone();
338 }
339
340 let fields: Vec<Field> = self
341 .tensors
342 .values()
343 .map(|t| {
344 let mut metadata = HashMap::new();
345 metadata.insert("shape".to_string(), format!("{:?}", t.metadata.shape));
346 metadata.insert("dtype".to_string(), t.metadata.dtype.to_string());
347 if let Some(ref strides) = t.metadata.strides {
348 metadata.insert("strides".to_string(), format!("{:?}", strides));
349 }
350 for (k, v) in &t.metadata.custom {
351 metadata.insert(k.clone(), v.clone());
352 }
353 Field::new(&t.metadata.name, t.metadata.dtype.to_arrow_type(), false)
354 .with_metadata(metadata)
355 })
356 .collect();
357
358 let schema = Arc::new(Schema::new(fields));
359 self.schema = Some(schema.clone());
360 schema
361 }
362
363 pub fn to_record_batch(&mut self) -> Result<RecordBatch, arrow::error::ArrowError> {
365 let schema = self.build_schema();
366 let columns: Vec<ArrayRef> = self.tensors.values().map(|t| t.array.clone()).collect();
367 RecordBatch::try_new(schema, columns)
368 }
369
370 pub fn write_ipc<W: Write>(&mut self, writer: W) -> Result<(), arrow::error::ArrowError> {
372 let batch = self.to_record_batch()?;
373 let schema = batch.schema();
374 let mut ipc_writer = FileWriter::try_new(writer, &schema)?;
375 ipc_writer.write(&batch)?;
376 ipc_writer.finish()?;
377 Ok(())
378 }
379
380 pub fn read_ipc<R: Read + Seek>(reader: R) -> Result<Self, arrow::error::ArrowError> {
382 let ipc_reader = FileReader::try_new(reader, None)?;
383 let schema = ipc_reader.schema();
384 let mut store = Self::new();
385
386 for batch_result in ipc_reader {
387 let batch = batch_result?;
388 for (i, field) in schema.fields().iter().enumerate() {
389 let array = batch.column(i).clone();
390 let shape = parse_shape_from_metadata(field.metadata());
391 let dtype = dtype_from_arrow(field.data_type());
392
393 let metadata = TensorMetadata::new(field.name().clone(), shape, dtype);
394 store
395 .tensors
396 .insert(field.name().clone(), ArrowTensor { metadata, array });
397 }
398 }
399
400 store.schema = Some(schema);
401 Ok(store)
402 }
403
404 pub fn to_bytes(&mut self) -> Result<Bytes, arrow::error::ArrowError> {
406 let mut buffer = Vec::new();
407 self.write_ipc(&mut buffer)?;
408 Ok(Bytes::from(buffer))
409 }
410
411 pub fn from_bytes(bytes: &[u8]) -> Result<Self, arrow::error::ArrowError> {
413 let cursor = std::io::Cursor::new(bytes);
414 Self::read_ipc(cursor)
415 }
416}
417
418impl Default for ArrowTensorStore {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
/// Reads the `shape` entry written by `ArrowTensorStore::build_schema` (e.g. `"[3, 4]"`).
///
/// All leading `[` and trailing `]` characters are stripped, and the remainder is split on
/// commas. If the key is missing, the list is empty (`"[]"`), or any entry fails to parse as
/// `usize`, an empty shape is returned.
fn parse_shape_from_metadata(metadata: &HashMap<String, String>) -> Vec<usize> {
    let raw = match metadata.get("shape") {
        Some(raw) => raw,
        None => return Vec::new(),
    };
    let inner = raw.trim_start_matches('[').trim_end_matches(']');

    let mut dims = Vec::new();
    for piece in inner.split(',') {
        match piece.trim().parse::<usize>() {
            Ok(dim) => dims.push(dim),
            // Any malformed entry invalidates the whole shape.
            Err(_) => return Vec::new(),
        }
    }
    dims
}
437
438fn dtype_from_arrow(dt: &DataType) -> TensorDtype {
440 match dt {
441 DataType::Float32 => TensorDtype::Float32,
442 DataType::Float64 => TensorDtype::Float64,
443 DataType::Int8 => TensorDtype::Int8,
444 DataType::Int16 => TensorDtype::Int16,
445 DataType::Int32 => TensorDtype::Int32,
446 DataType::Int64 => TensorDtype::Int64,
447 DataType::UInt8 => TensorDtype::UInt8,
448 DataType::UInt16 => TensorDtype::UInt16,
449 DataType::UInt32 => TensorDtype::UInt32,
450 DataType::UInt64 => TensorDtype::UInt64,
451 _ => TensorDtype::Float32, }
453}
454
/// Access to an object's raw byte representation.
///
/// NOTE(review): despite the trait's name, `get_bytes` returns an owned `Vec<u8>`, so
/// implementations necessarily copy.
pub trait ZeroCopyAccessor {
    /// Returns the raw bytes as an owned vector.
    fn get_bytes(&self) -> Vec<u8>;

    /// Number of bytes `get_bytes` returns.
    ///
    /// The default materializes `get_bytes()` just to measure it. Implementors that can
    /// compute the length directly may want to override this.
    fn len_bytes(&self) -> usize {
        let bytes = self.get_bytes();
        bytes.len()
    }
}
465
466impl ZeroCopyAccessor for ArrowTensor {
467 fn get_bytes(&self) -> Vec<u8> {
468 ArrowTensor::as_bytes(self)
469 }
470}
471
472#[allow(deprecated)]
474pub fn buffer_from_bytes(bytes: Bytes) -> Buffer {
475 Buffer::from(bytes)
476}
477
478#[allow(dead_code)]
480fn create_array_from_buffer(buffer: Buffer, dtype: TensorDtype, _len: usize) -> ArrayRef {
481 match dtype {
482 TensorDtype::Float32 => Arc::new(Float32Array::new(buffer.into(), None)) as ArrayRef,
483 TensorDtype::Float64 => Arc::new(Float64Array::new(buffer.into(), None)) as ArrayRef,
484 TensorDtype::Int8 => Arc::new(Int8Array::new(buffer.into(), None)) as ArrayRef,
485 TensorDtype::Int16 => Arc::new(Int16Array::new(buffer.into(), None)) as ArrayRef,
486 TensorDtype::Int32 => Arc::new(Int32Array::new(buffer.into(), None)) as ArrayRef,
487 TensorDtype::Int64 => Arc::new(Int64Array::new(buffer.into(), None)) as ArrayRef,
488 TensorDtype::UInt8 => Arc::new(UInt8Array::new(buffer.into(), None)) as ArrayRef,
489 TensorDtype::UInt16 => Arc::new(UInt16Array::new(buffer.into(), None)) as ArrayRef,
490 TensorDtype::UInt32 => Arc::new(UInt32Array::new(buffer.into(), None)) as ArrayRef,
491 TensorDtype::UInt64 => Arc::new(UInt64Array::new(buffer.into(), None)) as ArrayRef,
492 TensorDtype::Float16 | TensorDtype::BFloat16 => {
494 Arc::new(UInt16Array::new(buffer.into(), None)) as ArrayRef
495 }
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_metadata() {
        let meta = TensorMetadata::new(String::from("test"), vec![2, 3, 4], TensorDtype::Float32);
        // 2 * 3 * 4 = 24 elements of 4 bytes each, with row-major strides.
        assert_eq!(meta.numel(), 24);
        assert_eq!(meta.size_bytes(), 96);
        assert_eq!(meta.compute_strides(), vec![12, 4, 1]);
    }

    #[test]
    fn test_arrow_tensor_f32() {
        let values: Vec<f32> = (1..=6).map(|v| v as f32).collect();
        let tensor = ArrowTensor::from_slice_f32("weights", vec![2, 3], &values);

        assert_eq!(tensor.metadata.name, "weights");
        assert_eq!(tensor.metadata.shape, vec![2, 3]);
        assert_eq!(tensor.len(), values.len());

        let view = tensor
            .as_slice_f32()
            .expect("tensor was built from f32 data");
        assert_eq!(view, values.as_slice());
    }

    #[test]
    fn test_arrow_tensor_store() {
        let mut store = ArrowTensorStore::new();
        store.insert(ArrowTensor::from_slice_f32("layer1.weight", vec![4, 3], &[0.0; 12]));
        store.insert(ArrowTensor::from_slice_f32("layer2.weight", vec![2, 4], &[0.0; 8]));

        assert_eq!(store.len(), 2);
        for name in ["layer1.weight", "layer2.weight"].iter() {
            assert!(store.get(name).is_some(), "missing tensor {}", name);
        }
    }

    #[test]
    fn test_ipc_roundtrip() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let mut original = ArrowTensorStore::new();
        original.insert(ArrowTensor::from_slice_f32("test", vec![3, 4], &values));

        let encoded = original.to_bytes().unwrap();
        let decoded = ArrowTensorStore::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.len(), 1);
        let restored = decoded.get("test").unwrap();
        assert_eq!(restored.as_slice_f32().unwrap(), values.as_slice());
    }

    #[test]
    fn test_dtype_conversion() {
        assert_eq!(TensorDtype::Float32.to_arrow_type(), DataType::Float32);
        assert_eq!(TensorDtype::Int64.to_arrow_type(), DataType::Int64);
        assert_eq!(TensorDtype::Float32.element_size(), 4);
        assert_eq!(TensorDtype::Float64.element_size(), 8);
    }
}