1use arrow::array::{
7 ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, UInt16Array, UInt32Array,
8 UInt64Array, UInt8Array,
9};
10use arrow::buffer::Buffer;
11use arrow::datatypes::{DataType, Field, Schema};
12use arrow::ipc::writer::StreamWriter;
13use arrow::record_batch::RecordBatch;
14use bytes::Bytes;
15use ipfrs_core::error::{Error, Result};
16use std::sync::Arc;
17
18use crate::tensor::TensorMetadata;
19
20pub fn tensor_to_record_batch(metadata: &TensorMetadata, data: &[u8]) -> Result<RecordBatch> {
25 let arrow_dtype = match metadata.dtype.as_str() {
27 "F32" | "f32" => DataType::Float32,
28 "F64" | "f64" => DataType::Float64,
29 "I32" | "i32" => DataType::Int32,
30 "I64" | "i64" => DataType::Int64,
31 "U8" | "u8" => DataType::UInt8,
32 "U16" | "u16" => DataType::UInt16,
33 "U32" | "u32" => DataType::UInt32,
34 "U64" | "u64" => DataType::UInt64,
35 _ => {
36 return Err(Error::Internal(format!(
37 "Unsupported dtype: {}",
38 metadata.dtype
39 )))
40 }
41 };
42
43 let schema = Schema::new(vec![Field::new("data", arrow_dtype.clone(), false)]);
45
46 let array: ArrayRef = match metadata.dtype.as_str() {
48 "F32" | "f32" => {
49 let buffer = Buffer::from(data);
50 Arc::new(Float32Array::new(buffer.into(), None))
51 }
52 "F64" | "f64" => {
53 let buffer = Buffer::from(data);
54 Arc::new(Float64Array::new(buffer.into(), None))
55 }
56 "I32" | "i32" => {
57 let buffer = Buffer::from(data);
58 Arc::new(Int32Array::new(buffer.into(), None))
59 }
60 "I64" | "i64" => {
61 let buffer = Buffer::from(data);
62 Arc::new(Int64Array::new(buffer.into(), None))
63 }
64 "U8" | "u8" => {
65 let buffer = Buffer::from(data);
66 Arc::new(UInt8Array::new(buffer.into(), None))
67 }
68 "U16" | "u16" => {
69 let buffer = Buffer::from(data);
70 Arc::new(UInt16Array::new(buffer.into(), None))
71 }
72 "U32" | "u32" => {
73 let buffer = Buffer::from(data);
74 Arc::new(UInt32Array::new(buffer.into(), None))
75 }
76 "U64" | "u64" => {
77 let buffer = Buffer::from(data);
78 Arc::new(UInt64Array::new(buffer.into(), None))
79 }
80 _ => {
81 return Err(Error::Internal(format!(
82 "Unsupported dtype: {}",
83 metadata.dtype
84 )))
85 }
86 };
87
88 RecordBatch::try_new(Arc::new(schema), vec![array])
90 .map_err(|e| Error::Internal(format!("Failed to create Arrow RecordBatch: {}", e)))
91}
92
93pub fn record_batch_to_ipc_bytes(batch: &RecordBatch) -> Result<Bytes> {
97 let mut buffer = Vec::new();
98 {
99 let mut writer = StreamWriter::try_new(&mut buffer, &batch.schema())
100 .map_err(|e| Error::Internal(format!("Failed to create Arrow StreamWriter: {}", e)))?;
101
102 writer
103 .write(batch)
104 .map_err(|e| Error::Internal(format!("Failed to write Arrow batch: {}", e)))?;
105
106 writer
107 .finish()
108 .map_err(|e| Error::Internal(format!("Failed to finish Arrow stream: {}", e)))?;
109 }
110
111 Ok(Bytes::from(buffer))
112}
113
114pub fn create_tensor_schema(metadata: &TensorMetadata) -> Result<Schema> {
118 let arrow_dtype = match metadata.dtype.as_str() {
119 "F32" | "f32" => DataType::Float32,
120 "F64" | "f64" => DataType::Float64,
121 "I32" | "i32" => DataType::Int32,
122 "I64" | "i64" => DataType::Int64,
123 "U8" | "u8" => DataType::UInt8,
124 "U16" | "u16" => DataType::UInt16,
125 "U32" | "u32" => DataType::UInt32,
126 "U64" | "u64" => DataType::UInt64,
127 _ => {
128 return Err(Error::Internal(format!(
129 "Unsupported dtype: {}",
130 metadata.dtype
131 )))
132 }
133 };
134
135 let mut field = Field::new("data", arrow_dtype, false);
137
138 let shape_str = metadata
140 .shape
141 .iter()
142 .map(|s| s.to_string())
143 .collect::<Vec<_>>()
144 .join(",");
145
146 field = field.with_metadata(
147 [
148 ("tensor_shape".to_string(), shape_str),
149 ("tensor_dtype".to_string(), metadata.dtype.clone()),
150 (
151 "tensor_layout".to_string(),
152 format!("{:?}", metadata.layout),
153 ),
154 ]
155 .into_iter()
156 .collect(),
157 );
158
159 Ok(Schema::new(vec![field]))
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorLayout;
    use arrow::ipc::reader::StreamReader;

    #[test]
    fn test_tensor_to_record_batch_f32() {
        let metadata = TensorMetadata {
            shape: vec![2, 3],
            dtype: "F32".to_string(),
            num_elements: 6,
            size_bytes: 24,
            layout: TensorLayout::RowMajor,
        };

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes = data
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 6);

        let array = batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(array.value(0), 1.0);
        assert_eq!(array.value(5), 6.0);
    }

    #[test]
    fn test_tensor_to_record_batch_i32() {
        let metadata = TensorMetadata {
            shape: vec![4],
            dtype: "I32".to_string(),
            num_elements: 4,
            size_bytes: 16,
            layout: TensorLayout::RowMajor,
        };

        let data: Vec<i32> = vec![10, 20, 30, 40];
        let bytes = data
            .iter()
            .flat_map(|i| i.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        assert_eq!(batch.num_rows(), 4);

        let array = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(array.value(0), 10);
        assert_eq!(array.value(3), 40);
    }

    #[test]
    fn test_record_batch_to_ipc_bytes() {
        let metadata = TensorMetadata {
            shape: vec![3],
            dtype: "F32".to_string(),
            num_elements: 3,
            size_bytes: 12,
            layout: TensorLayout::RowMajor,
        };

        let data: Vec<f32> = vec![1.0, 2.0, 3.0];
        let bytes = data
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect::<Vec<u8>>();

        let batch = tensor_to_record_batch(&metadata, &bytes).unwrap();
        let ipc_bytes = record_batch_to_ipc_bytes(&batch).unwrap();
        assert!(!ipc_bytes.is_empty());

        // The encoded stream must decode back into exactly the batch we wrote.
        let mut reader = StreamReader::try_new(&ipc_bytes[..], None).unwrap();
        assert_eq!(reader.schema(), batch.schema());

        let decoded = reader
            .next()
            .expect("stream should contain one batch")
            .unwrap();
        assert!(
            reader.next().is_none(),
            "stream should contain exactly one batch"
        );

        assert_eq!(decoded.num_rows(), 3);
        let array = decoded
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(array.value(0), 1.0);
        assert_eq!(array.value(1), 2.0);
        assert_eq!(array.value(2), 3.0);
    }

    #[test]
    fn test_create_tensor_schema() {
        let metadata = TensorMetadata {
            shape: vec![10, 20, 30],
            dtype: "F64".to_string(),
            num_elements: 6000,
            size_bytes: 48000,
            layout: TensorLayout::RowMajor,
        };

        let schema = create_tensor_schema(&metadata).unwrap();
        assert_eq!(schema.fields().len(), 1);

        let field = &schema.fields()[0];
        assert_eq!(field.name(), "data");
        assert_eq!(field.data_type(), &DataType::Float64);

        let meta = field.metadata();
        assert!(meta.contains_key("tensor_shape"));
        assert_eq!(meta.get("tensor_shape").unwrap(), "10,20,30");
        assert_eq!(meta.get("tensor_dtype").unwrap(), "F64");
        assert!(meta.contains_key("tensor_layout"));
    }

    #[test]
    fn test_all_dtypes() {
        let dtypes = [
            ("F32", 4),
            ("F64", 8),
            ("I32", 4),
            ("I64", 8),
            ("U8", 1),
            ("U16", 2),
            ("U32", 4),
            ("U64", 8),
        ];

        for (dtype, element_size) in dtypes {
            // Both the upper- and lower-case spellings are accepted.
            for spelling in [dtype.to_string(), dtype.to_ascii_lowercase()] {
                let metadata = TensorMetadata {
                    shape: vec![4],
                    dtype: spelling.clone(),
                    num_elements: 4,
                    size_bytes: 4 * element_size,
                    layout: TensorLayout::RowMajor,
                };

                let data = vec![0u8; metadata.size_bytes];
                let batch = tensor_to_record_batch(&metadata, &data)
                    .unwrap_or_else(|e| panic!("Failed for dtype {}: {:?}", spelling, e));
                assert_eq!(batch.num_rows(), 4, "Wrong row count for dtype: {}", spelling);

                // The batch's column type must match the standalone schema.
                let schema = create_tensor_schema(&metadata).unwrap();
                assert_eq!(
                    schema.field(0).data_type(),
                    batch.schema().field(0).data_type(),
                    "Schema mismatch for dtype: {}",
                    spelling
                );
            }
        }
    }

    #[test]
    fn test_unsupported_dtype_is_rejected() {
        let metadata = TensorMetadata {
            shape: vec![2],
            dtype: "BF16".to_string(),
            num_elements: 2,
            size_bytes: 4,
            layout: TensorLayout::RowMajor,
        };

        assert!(tensor_to_record_batch(&metadata, &[0u8; 4]).is_err());
        assert!(create_tensor_schema(&metadata).is_err());
    }
}