oar_ocr_core/utils/tensor.rs
1//! Tensor utility functions for converting between vectors and tensors.
2//!
3//! This module provides functions to convert between Rust vectors and
4//! multi-dimensional tensors (1D, 2D, 3D, and 4D). It also includes
5//! utility functions for tensor operations like slicing and stacking.
6
7use crate::core::OCRError;
8use crate::core::Tensor1D;
9use crate::core::Tensor2D;
10use crate::core::Tensor3D;
11use crate::core::Tensor4D;
12
13use ndarray::{Array2, Array3, Array4, ArrayD, Axis};
14
15/// Converts a 2D vector of f32 values into a 2D tensor.
16///
17/// # Arguments
18///
19/// * `data` - A slice of vectors containing f32 values.
20///
21/// # Returns
22///
23/// * `Ok(Tensor2D)` - A 2D tensor created from the input data.
24/// * `Err(OCRError)` - An error if the input data is invalid (e.g., empty, inconsistent row lengths).
25///
26/// # Examples
27///
28/// ```
29/// use oar_ocr_core::utils::tensor::vec_to_tensor2d;
30/// let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
31/// let tensor = vec_to_tensor2d(&data);
32/// ```
33pub fn vec_to_tensor2d(data: &[Vec<f32>]) -> Result<Tensor2D, OCRError> {
34 if data.is_empty() {
35 return Err(OCRError::InvalidInput {
36 message: "Empty data".to_string(),
37 });
38 }
39
40 let rows = data.len();
41 let cols = data[0].len();
42
43 if cols == 0 {
44 return Err(OCRError::InvalidInput {
45 message: "Cannot create tensor with zero-width columns".to_string(),
46 });
47 }
48
49 let total_size = rows
50 .checked_mul(cols)
51 .ok_or_else(|| OCRError::InvalidInput {
52 message: format!("Tensor dimensions ({rows}, {cols}) would cause integer overflow"),
53 })?;
54
55 const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
56 if total_size > MAX_TENSOR_ELEMENTS {
57 return Err(OCRError::InvalidInput {
58 message: format!(
59 "Tensor size {total_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
60 ),
61 });
62 }
63
64 for (i, row) in data.iter().enumerate() {
65 if row.len() != cols {
66 return Err(OCRError::InvalidInput {
67 message: format!(
68 "Inconsistent row lengths at row {}: expected {}, got {}",
69 i,
70 cols,
71 row.len()
72 ),
73 });
74 }
75 }
76
77 let flat_data: Vec<f32> = data.iter().flat_map(|row| row.iter().cloned()).collect();
78 let flat_data_len = flat_data.len();
79 Array2::from_shape_vec((rows, cols), flat_data).map_err(|e| {
80 OCRError::tensor_operation_error(
81 "vec_to_tensor2d",
82 &[rows, cols],
83 &[flat_data_len],
84 &format!(
85 "Failed to create 2D tensor from {} rows x {} cols data",
86 rows, cols
87 ),
88 e,
89 )
90 })
91}
92
93/// Converts a 2D tensor into a 2D vector of f32 values.
94///
95/// # Arguments
96///
97/// * `tensor` - A reference to a 2D tensor.
98///
99/// # Returns
100///
101/// * `Vec<Vec<f32>>` - A 2D vector created from the tensor.
102///
103/// # Examples
104///
105/// ```
106/// use oar_ocr_core::utils::tensor::{tensor2d_to_vec, vec_to_tensor2d};
107/// // Create a tensor first
108/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
109/// let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
110/// let tensor = vec_to_tensor2d(&data)?;
111/// let vec_data = tensor2d_to_vec(&tensor);
112/// # let _ = vec_data;
113/// # Ok(())
114/// # }
115/// ```
116pub fn tensor2d_to_vec(tensor: &Tensor2D) -> Vec<Vec<f32>> {
117 tensor.outer_iter().map(|row| row.to_vec()).collect()
118}
119
120/// Converts a 3D vector of f32 values into a 3D tensor.
121///
122/// # Arguments
123///
124/// * `data` - A slice of 3D vectors containing f32 values.
125///
126/// # Returns
127///
128/// * `Ok(Tensor3D)` - A 3D tensor created from the input data.
129/// * `Err(OCRError)` - An error if the input data is invalid (e.g., empty, inconsistent dimensions).
130///
131/// # Examples
132///
133/// ```
134/// use oar_ocr_core::utils::tensor::vec_to_tensor3d;
135/// let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
136/// let tensor = vec_to_tensor3d(&data);
137/// ```
138pub fn vec_to_tensor3d(data: &[Vec<Vec<f32>>]) -> Result<Tensor3D, OCRError> {
139 if data.is_empty() {
140 return Err(OCRError::InvalidInput {
141 message: "Empty data".to_string(),
142 });
143 }
144
145 let dim0 = data.len();
146 let dim1 = data[0].len();
147
148 let dim2 = if dim1 > 0 && !data[0].is_empty() {
149 data[0][0].len()
150 } else {
151 0
152 };
153
154 for (i, outer) in data.iter().enumerate() {
155 if outer.len() != dim1 {
156 return Err(OCRError::InvalidInput {
157 message: format!(
158 "Inconsistent dimension 1 at index {}: expected {}, got {}",
159 i,
160 dim1,
161 outer.len()
162 ),
163 });
164 }
165 for (j, inner) in outer.iter().enumerate() {
166 if inner.len() != dim2 {
167 return Err(OCRError::InvalidInput {
168 message: format!(
169 "Inconsistent dimension 2 at index [{}, {}]: expected {}, got {}",
170 i,
171 j,
172 dim2,
173 inner.len()
174 ),
175 });
176 }
177 }
178 }
179
180 let total_size = dim0
181 .checked_mul(dim1)
182 .and_then(|size| size.checked_mul(dim2))
183 .ok_or_else(|| OCRError::InvalidInput {
184 message: format!(
185 "Tensor dimensions ({dim0}, {dim1}, {dim2}) would cause integer overflow"
186 ),
187 })?;
188
189 const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
190 if total_size > MAX_TENSOR_ELEMENTS {
191 return Err(OCRError::InvalidInput {
192 message: format!(
193 "Tensor size {total_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
194 ),
195 });
196 }
197
198 let flat_data: Vec<f32> = data
199 .iter()
200 .flat_map(|slice| slice.iter().flat_map(|row| row.iter().cloned()))
201 .collect();
202 let flat_data_len = flat_data.len();
203
204 Array3::from_shape_vec((dim0, dim1, dim2), flat_data).map_err(|e| {
205 OCRError::tensor_operation_error(
206 "vec_to_tensor3d",
207 &[dim0, dim1, dim2],
208 &[flat_data_len],
209 &format!(
210 "Failed to create 3D tensor from {}x{}x{} data",
211 dim0, dim1, dim2
212 ),
213 e,
214 )
215 })
216}
217
218/// Converts a 3D tensor into a 3D vector of f32 values.
219///
220/// # Arguments
221///
222/// * `tensor` - A reference to a 3D tensor.
223///
224/// # Returns
225///
226/// * `Vec<Vec<Vec<f32>>>` - A 3D vector created from the tensor.
227///
228/// # Examples
229///
230/// ```
231/// use oar_ocr_core::utils::tensor::{tensor3d_to_vec, vec_to_tensor3d};
232/// // Create a tensor first
233/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
234/// let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
235/// let tensor = vec_to_tensor3d(&data)?;
236/// let vec_data = tensor3d_to_vec(&tensor);
237/// # let _ = vec_data;
238/// # Ok(())
239/// # }
240/// ```
241pub fn tensor3d_to_vec(tensor: &Tensor3D) -> Vec<Vec<Vec<f32>>> {
242 tensor
243 .outer_iter()
244 .map(|slice| slice.outer_iter().map(|row| row.to_vec()).collect())
245 .collect()
246}
247
248/// Converts a 4D vector of f32 values into a 4D tensor.
249///
250/// # Arguments
251///
252/// * `data` - A slice of 4D vectors containing f32 values.
253///
254/// # Returns
255///
256/// * `Ok(Tensor4D)` - A 4D tensor created from the input data.
257/// * `Err(OCRError)` - An error if the input data is invalid (e.g., empty, inconsistent dimensions).
258///
259/// # Examples
260///
261/// ```
262/// use oar_ocr_core::utils::tensor::vec_to_tensor4d;
263/// let data = vec![vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]]];
264/// let tensor = vec_to_tensor4d(&data);
265/// ```
266pub fn vec_to_tensor4d(data: &[Vec<Vec<Vec<f32>>>]) -> Result<Tensor4D, OCRError> {
267 if data.is_empty() {
268 return Err(OCRError::InvalidInput {
269 message: "Empty data".to_string(),
270 });
271 }
272
273 let dim0 = data.len();
274 let dim1 = data[0].len();
275
276 let dim2 = if dim1 > 0 && !data[0].is_empty() {
277 data[0][0].len()
278 } else {
279 0
280 };
281
282 let dim3 = if dim2 > 0 && !data[0].is_empty() && !data[0][0].is_empty() {
283 data[0][0][0].len()
284 } else {
285 0
286 };
287
288 for (i, outer) in data.iter().enumerate() {
289 if outer.len() != dim1 {
290 return Err(OCRError::InvalidInput {
291 message: format!(
292 "Inconsistent dimension 1 at index {}: expected {}, got {}",
293 i,
294 dim1,
295 outer.len()
296 ),
297 });
298 }
299 for (j, middle) in outer.iter().enumerate() {
300 if middle.len() != dim2 {
301 return Err(OCRError::InvalidInput {
302 message: format!(
303 "Inconsistent dimension 2 at index [{}, {}]: expected {}, got {}",
304 i,
305 j,
306 dim2,
307 middle.len()
308 ),
309 });
310 }
311 for (k, inner) in middle.iter().enumerate() {
312 if inner.len() != dim3 {
313 return Err(OCRError::InvalidInput {
314 message: format!(
315 "Inconsistent dimension 3 at index [{}, {}, {}]: expected {}, got {}",
316 i,
317 j,
318 k,
319 dim3,
320 inner.len()
321 ),
322 });
323 }
324 }
325 }
326 }
327
328 let total_size = dim0
329 .checked_mul(dim1)
330 .and_then(|size| size.checked_mul(dim2))
331 .and_then(|size| size.checked_mul(dim3))
332 .ok_or_else(|| OCRError::InvalidInput {
333 message: format!(
334 "Tensor dimensions ({dim0}, {dim1}, {dim2}, {dim3}) would cause integer overflow"
335 ),
336 })?;
337
338 const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
339 if total_size > MAX_TENSOR_ELEMENTS {
340 return Err(OCRError::InvalidInput {
341 message: format!(
342 "Tensor size {total_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
343 ),
344 });
345 }
346
347 Ok(Array4::from_shape_fn(
348 (dim0, dim1, dim2, dim3),
349 |(i, j, k, l)| data[i][j][k][l],
350 ))
351}
352
353/// Converts a 4D tensor into a 4D vector of f32 values.
354///
355/// # Arguments
356///
357/// * `tensor` - A reference to a 4D tensor.
358///
359/// # Returns
360///
361/// * `Vec<Vec<Vec<Vec<f32>>>>` - A 4D vector created from the tensor.
362///
363/// # Examples
364///
365/// ```
366/// use oar_ocr_core::utils::tensor::{tensor4d_to_vec, vec_to_tensor4d};
367/// // Create a tensor first
368/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
369/// let data = vec![vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]]];
370/// let tensor = vec_to_tensor4d(&data)?;
371/// let vec_data = tensor4d_to_vec(&tensor);
372/// # let _ = vec_data;
373/// # Ok(())
374/// # }
375/// ```
376pub fn tensor4d_to_vec(tensor: &Tensor4D) -> Vec<Vec<Vec<Vec<f32>>>> {
377 tensor
378 .outer_iter()
379 .map(|batch| {
380 batch
381 .outer_iter()
382 .map(|slice| slice.outer_iter().map(|row| row.to_vec()).collect())
383 .collect()
384 })
385 .collect()
386}
387
388/// Converts a 1D vector of f32 values into a 1D tensor with the specified shape.
389///
390/// # Arguments
391///
392/// * `data` - A vector of f32 values.
393/// * `shape` - A slice of usize values representing the shape of the tensor.
394///
395/// # Returns
396///
397/// * `Ok(Tensor1D)` - A 1D tensor created from the input data.
398/// * `Err(OCRError)` - An error if the input data is invalid.
399///
400/// # Examples
401///
402/// ```
403/// use oar_ocr_core::utils::tensor::vec_to_tensor1d;
404/// let data = vec![1.0, 2.0, 3.0, 4.0];
405/// let shape = &[2, 2];
406/// let tensor = vec_to_tensor1d(data, shape);
407/// ```
408pub fn vec_to_tensor1d(data: Vec<f32>, shape: &[usize]) -> Result<Tensor1D, OCRError> {
409 let data_len = data.len();
410 ArrayD::from_shape_vec(shape, data).map_err(|e| {
411 OCRError::tensor_operation_error(
412 "vec_to_tensor1d",
413 shape,
414 &[data_len],
415 &format!(
416 "Failed to create 1D tensor with shape {:?} from {} elements",
417 shape, data_len
418 ),
419 e,
420 )
421 })
422}
423
424/// Converts a 1D tensor into a 1D vector of f32 values.
425///
426/// # Arguments
427///
428/// * `tensor` - A reference to a 1D tensor.
429///
430/// # Returns
431///
432/// * `Vec<f32>` - A 1D vector created from the tensor.
433///
434/// # Examples
435///
436/// ```
437/// use oar_ocr_core::utils::tensor::{tensor1d_to_vec, vec_to_tensor1d};
438/// // Create a tensor first
439/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
440/// let data = vec![1.0, 2.0, 3.0, 4.0];
441/// let shape = &[4];
442/// let tensor = vec_to_tensor1d(data, shape)?;
443/// let vec_data = tensor1d_to_vec(&tensor);
444/// # let _ = vec_data;
445/// # Ok(())
446/// # }
447/// ```
448pub fn tensor1d_to_vec(tensor: &Tensor1D) -> Vec<f32> {
449 tensor.as_slice().unwrap_or(&[]).to_vec()
450}
451
452/// Extracts a 3D slice from a 4D tensor at the specified index.
453///
454/// # Arguments
455///
456/// * `tensor` - A reference to a 4D tensor.
457/// * `index` - The index of the slice to extract.
458///
459/// # Returns
460///
461/// * `Ok(Tensor3D)` - A 3D tensor slice extracted from the input tensor.
462/// * `Err(OCRError)` - An error if the index is out of bounds.
463///
464/// # Examples
465///
466/// ```
467/// use oar_ocr_core::utils::tensor::{tensor4d_slice, vec_to_tensor4d};
468/// // Create a tensor first
469/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
470/// let data = vec![vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]]];
471/// let tensor = vec_to_tensor4d(&data)?;
472/// let slice = tensor4d_slice(&tensor, 0)?;
473/// # let _ = slice;
474/// # Ok(())
475/// # }
476/// ```
477pub fn tensor4d_slice(tensor: &Tensor4D, index: usize) -> Result<Tensor3D, OCRError> {
478 if index >= tensor.shape()[0] {
479 return Err(OCRError::InvalidInput {
480 message: format!(
481 "Index {} out of bounds for tensor with shape {:?}",
482 index,
483 tensor.shape()
484 ),
485 });
486 }
487 Ok(tensor.index_axis(Axis(0), index).to_owned())
488}
489
490/// Extracts a 2D slice from a 3D tensor at the specified index.
491///
492/// # Arguments
493///
494/// * `tensor` - A reference to a 3D tensor.
495/// * `index` - The index of the slice to extract.
496///
497/// # Returns
498///
499/// * `Ok(Tensor2D)` - A 2D tensor slice extracted from the input tensor.
500/// * `Err(OCRError)` - An error if the index is out of bounds.
501///
502/// # Examples
503///
504/// ```
505/// use oar_ocr_core::utils::tensor::{tensor3d_slice, vec_to_tensor3d};
506/// // Create a tensor first
507/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
508/// let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
509/// let tensor = vec_to_tensor3d(&data)?;
510/// let slice = tensor3d_slice(&tensor, 0)?;
511/// # let _ = slice;
512/// # Ok(())
513/// # }
514/// ```
515pub fn tensor3d_slice(tensor: &Tensor3D, index: usize) -> Result<Tensor2D, OCRError> {
516 if index >= tensor.shape()[0] {
517 return Err(OCRError::InvalidInput {
518 message: format!(
519 "Index {} out of bounds for tensor with shape {:?}",
520 index,
521 tensor.shape()
522 ),
523 });
524 }
525 Ok(tensor.index_axis(Axis(0), index).to_owned())
526}
527
528/// Stacks a slice of 3D tensors into a single 4D tensor.
529///
530/// # Arguments
531///
532/// * `tensors` - A slice of 3D tensors to stack.
533///
534/// # Returns
535///
536/// * `Ok(Tensor4D)` - A 4D tensor created by stacking the input tensors.
537/// * `Err(OCRError)` - An error if the input tensors are invalid (e.g., empty, inconsistent shapes).
538///
539/// # Examples
540///
541/// ```
542/// use oar_ocr_core::utils::tensor::{stack_tensor3d, vec_to_tensor3d};
543/// // Create tensors first
544/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
545/// let data1 = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
546/// let data2 = vec![vec![vec![5.0, 6.0], vec![7.0, 8.0]]];
547/// let tensor1 = vec_to_tensor3d(&data1)?;
548/// let tensor2 = vec_to_tensor3d(&data2)?;
549/// let tensors = vec![tensor1, tensor2];
550/// let stacked_tensor = stack_tensor3d(&tensors)?;
551/// # let _ = stacked_tensor;
552/// # Ok(())
553/// # }
554/// ```
555pub fn stack_tensor3d(tensors: &[Tensor3D]) -> Result<Tensor4D, OCRError> {
556 if tensors.is_empty() {
557 return Err(OCRError::InvalidInput {
558 message: "No tensors to stack".to_string(),
559 });
560 }
561
562 let first_shape = tensors[0].shape();
563
564 if first_shape.contains(&0) {
565 return Err(OCRError::InvalidInput {
566 message: format!("Cannot stack tensors with zero dimensions: shape {first_shape:?}"),
567 });
568 }
569
570 for (i, tensor) in tensors.iter().enumerate() {
571 if tensor.is_empty() {
572 return Err(OCRError::InvalidInput {
573 message: format!("Tensor {i} is empty and cannot be stacked"),
574 });
575 }
576
577 if i > 0 && tensor.shape() != first_shape {
578 return Err(OCRError::InvalidInput {
579 message: format!(
580 "Tensor shape mismatch during 3D tensor stacking at index {}: expected {:?}, got {:?}",
581 i,
582 first_shape,
583 tensor.shape()
584 ),
585 });
586 }
587 }
588
589 let result_size = tensors
590 .len()
591 .checked_mul(first_shape.iter().product::<usize>())
592 .ok_or_else(|| OCRError::InvalidInput {
593 message: format!(
594 "Stacking {} tensors of shape {:?} would cause integer overflow",
595 tensors.len(),
596 first_shape
597 ),
598 })?;
599
600 const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
601 if result_size > MAX_TENSOR_ELEMENTS {
602 return Err(OCRError::InvalidInput {
603 message: format!(
604 "Stacked tensor size {result_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
605 ),
606 });
607 }
608
609 let views: Vec<_> = tensors.iter().map(|t| t.view()).collect();
610
611 ndarray::stack(Axis(0), &views).map_err(|e| {
612 OCRError::tensor_operation_error(
613 "stack_tensor3d",
614 &[
615 tensors.len(),
616 first_shape[0],
617 first_shape[1],
618 first_shape[2],
619 ],
620 &[result_size],
621 &format!(
622 "Failed to stack {} 3D tensors of shape {:?}",
623 tensors.len(),
624 first_shape
625 ),
626 e,
627 )
628 })
629}
630
631/// Stacks a slice of 2D tensors into a single 3D tensor.
632///
633/// # Arguments
634///
635/// * `tensors` - A slice of 2D tensors to stack.
636///
637/// # Returns
638///
639/// * `Ok(Tensor3D)` - A 3D tensor created by stacking the input tensors.
640/// * `Err(OCRError)` - An error if the input tensors are invalid (e.g., empty, inconsistent shapes).
641///
642/// # Examples
643///
644/// ```
645/// use oar_ocr_core::utils::tensor::{stack_tensor2d, vec_to_tensor2d};
646/// // Create tensors first
647/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
648/// let data1 = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
649/// let data2 = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
650/// let tensor1 = vec_to_tensor2d(&data1)?;
651/// let tensor2 = vec_to_tensor2d(&data2)?;
652/// let tensors = vec![tensor1, tensor2];
653/// let stacked_tensor = stack_tensor2d(&tensors)?;
654/// # let _ = stacked_tensor;
655/// # Ok(())
656/// # }
657/// ```
658pub fn stack_tensor2d(tensors: &[Tensor2D]) -> Result<Tensor3D, OCRError> {
659 if tensors.is_empty() {
660 return Err(OCRError::InvalidInput {
661 message: "No tensors to stack".to_string(),
662 });
663 }
664
665 let first_shape = tensors[0].shape();
666
667 if first_shape.contains(&0) {
668 return Err(OCRError::InvalidInput {
669 message: format!("Cannot stack tensors with zero dimensions: shape {first_shape:?}"),
670 });
671 }
672
673 for (i, tensor) in tensors.iter().enumerate() {
674 if tensor.is_empty() {
675 return Err(OCRError::InvalidInput {
676 message: format!("Tensor {i} is empty and cannot be stacked"),
677 });
678 }
679
680 if i > 0 && tensor.shape() != first_shape {
681 return Err(OCRError::InvalidInput {
682 message: format!(
683 "All tensors must have the same shape for stacking. Tensor 0 has shape {:?}, tensor {} has shape {:?}",
684 first_shape,
685 i,
686 tensor.shape()
687 ),
688 });
689 }
690 }
691
692 let result_size = tensors
693 .len()
694 .checked_mul(first_shape.iter().product::<usize>())
695 .ok_or_else(|| OCRError::InvalidInput {
696 message: format!(
697 "Stacking {} tensors of shape {:?} would cause integer overflow",
698 tensors.len(),
699 first_shape
700 ),
701 })?;
702
703 const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
704 if result_size > MAX_TENSOR_ELEMENTS {
705 return Err(OCRError::InvalidInput {
706 message: format!(
707 "Stacked tensor size {result_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
708 ),
709 });
710 }
711
712 let views: Vec<_> = tensors.iter().map(|t| t.view()).collect();
713
714 ndarray::stack(Axis(0), &views).map_err(OCRError::Tensor)
715}