Skip to main content

oar_ocr_core/utils/
tensor.rs

//! Tensor utility functions for converting between vectors and tensors.
//!
//! This module provides functions to convert between Rust vectors and
//! multi-dimensional tensors (1D, 2D, 3D, and 4D). It also includes
//! utility functions for tensor operations like slicing and stacking.
6
7use crate::core::OCRError;
8use crate::core::Tensor1D;
9use crate::core::Tensor2D;
10use crate::core::Tensor3D;
11use crate::core::Tensor4D;
12
13use ndarray::{Array2, Array3, Array4, ArrayD, Axis};
14
15/// Converts a 2D vector of f32 values into a 2D tensor.
16///
17/// # Arguments
18///
19/// * `data` - A slice of vectors containing f32 values.
20///
21/// # Returns
22///
23/// * `Ok(Tensor2D)` - A 2D tensor created from the input data.
24/// * `Err(OCRError)` - An error if the input data is invalid (e.g., empty, inconsistent row lengths).
25///
26/// # Examples
27///
28/// ```
29/// use oar_ocr_core::utils::tensor::vec_to_tensor2d;
30/// let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
31/// let tensor = vec_to_tensor2d(&data);
32/// ```
33pub fn vec_to_tensor2d(data: &[Vec<f32>]) -> Result<Tensor2D, OCRError> {
34    if data.is_empty() {
35        return Err(OCRError::InvalidInput {
36            message: "Empty data".to_string(),
37        });
38    }
39
40    let rows = data.len();
41    let cols = data[0].len();
42
43    if cols == 0 {
44        return Err(OCRError::InvalidInput {
45            message: "Cannot create tensor with zero-width columns".to_string(),
46        });
47    }
48
49    let total_size = rows
50        .checked_mul(cols)
51        .ok_or_else(|| OCRError::InvalidInput {
52            message: format!("Tensor dimensions ({rows}, {cols}) would cause integer overflow"),
53        })?;
54
55    const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
56    if total_size > MAX_TENSOR_ELEMENTS {
57        return Err(OCRError::InvalidInput {
58            message: format!(
59                "Tensor size {total_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
60            ),
61        });
62    }
63
64    for (i, row) in data.iter().enumerate() {
65        if row.len() != cols {
66            return Err(OCRError::InvalidInput {
67                message: format!(
68                    "Inconsistent row lengths at row {}: expected {}, got {}",
69                    i,
70                    cols,
71                    row.len()
72                ),
73            });
74        }
75    }
76
77    let flat_data: Vec<f32> = data.iter().flat_map(|row| row.iter().cloned()).collect();
78    let flat_data_len = flat_data.len();
79    Array2::from_shape_vec((rows, cols), flat_data).map_err(|e| {
80        OCRError::tensor_operation_error(
81            "vec_to_tensor2d",
82            &[rows, cols],
83            &[flat_data_len],
84            &format!(
85                "Failed to create 2D tensor from {} rows x {} cols data",
86                rows, cols
87            ),
88            e,
89        )
90    })
91}
92
93/// Converts a 2D tensor into a 2D vector of f32 values.
94///
95/// # Arguments
96///
97/// * `tensor` - A reference to a 2D tensor.
98///
99/// # Returns
100///
101/// * `Vec<Vec<f32>>` - A 2D vector created from the tensor.
102///
103/// # Examples
104///
105/// ```
106/// use oar_ocr_core::utils::tensor::{tensor2d_to_vec, vec_to_tensor2d};
107/// // Create a tensor first
108/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
109/// let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
110/// let tensor = vec_to_tensor2d(&data)?;
111/// let vec_data = tensor2d_to_vec(&tensor);
112/// # let _ = vec_data;
113/// # Ok(())
114/// # }
115/// ```
116pub fn tensor2d_to_vec(tensor: &Tensor2D) -> Vec<Vec<f32>> {
117    tensor.outer_iter().map(|row| row.to_vec()).collect()
118}
119
120/// Converts a 3D vector of f32 values into a 3D tensor.
121///
122/// # Arguments
123///
124/// * `data` - A slice of 3D vectors containing f32 values.
125///
126/// # Returns
127///
128/// * `Ok(Tensor3D)` - A 3D tensor created from the input data.
129/// * `Err(OCRError)` - An error if the input data is invalid (e.g., empty, inconsistent dimensions).
130///
131/// # Examples
132///
133/// ```
134/// use oar_ocr_core::utils::tensor::vec_to_tensor3d;
135/// let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
136/// let tensor = vec_to_tensor3d(&data);
137/// ```
138pub fn vec_to_tensor3d(data: &[Vec<Vec<f32>>]) -> Result<Tensor3D, OCRError> {
139    if data.is_empty() {
140        return Err(OCRError::InvalidInput {
141            message: "Empty data".to_string(),
142        });
143    }
144
145    let dim0 = data.len();
146    let dim1 = data[0].len();
147
148    let dim2 = if dim1 > 0 && !data[0].is_empty() {
149        data[0][0].len()
150    } else {
151        0
152    };
153
154    for (i, outer) in data.iter().enumerate() {
155        if outer.len() != dim1 {
156            return Err(OCRError::InvalidInput {
157                message: format!(
158                    "Inconsistent dimension 1 at index {}: expected {}, got {}",
159                    i,
160                    dim1,
161                    outer.len()
162                ),
163            });
164        }
165        for (j, inner) in outer.iter().enumerate() {
166            if inner.len() != dim2 {
167                return Err(OCRError::InvalidInput {
168                    message: format!(
169                        "Inconsistent dimension 2 at index [{}, {}]: expected {}, got {}",
170                        i,
171                        j,
172                        dim2,
173                        inner.len()
174                    ),
175                });
176            }
177        }
178    }
179
180    let total_size = dim0
181        .checked_mul(dim1)
182        .and_then(|size| size.checked_mul(dim2))
183        .ok_or_else(|| OCRError::InvalidInput {
184            message: format!(
185                "Tensor dimensions ({dim0}, {dim1}, {dim2}) would cause integer overflow"
186            ),
187        })?;
188
189    const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
190    if total_size > MAX_TENSOR_ELEMENTS {
191        return Err(OCRError::InvalidInput {
192            message: format!(
193                "Tensor size {total_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
194            ),
195        });
196    }
197
198    let flat_data: Vec<f32> = data
199        .iter()
200        .flat_map(|slice| slice.iter().flat_map(|row| row.iter().cloned()))
201        .collect();
202    let flat_data_len = flat_data.len();
203
204    Array3::from_shape_vec((dim0, dim1, dim2), flat_data).map_err(|e| {
205        OCRError::tensor_operation_error(
206            "vec_to_tensor3d",
207            &[dim0, dim1, dim2],
208            &[flat_data_len],
209            &format!(
210                "Failed to create 3D tensor from {}x{}x{} data",
211                dim0, dim1, dim2
212            ),
213            e,
214        )
215    })
216}
217
218/// Converts a 3D tensor into a 3D vector of f32 values.
219///
220/// # Arguments
221///
222/// * `tensor` - A reference to a 3D tensor.
223///
224/// # Returns
225///
226/// * `Vec<Vec<Vec<f32>>>` - A 3D vector created from the tensor.
227///
228/// # Examples
229///
230/// ```
231/// use oar_ocr_core::utils::tensor::{tensor3d_to_vec, vec_to_tensor3d};
232/// // Create a tensor first
233/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
234/// let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
235/// let tensor = vec_to_tensor3d(&data)?;
236/// let vec_data = tensor3d_to_vec(&tensor);
237/// # let _ = vec_data;
238/// # Ok(())
239/// # }
240/// ```
241pub fn tensor3d_to_vec(tensor: &Tensor3D) -> Vec<Vec<Vec<f32>>> {
242    tensor
243        .outer_iter()
244        .map(|slice| slice.outer_iter().map(|row| row.to_vec()).collect())
245        .collect()
246}
247
248/// Converts a 4D vector of f32 values into a 4D tensor.
249///
250/// # Arguments
251///
252/// * `data` - A slice of 4D vectors containing f32 values.
253///
254/// # Returns
255///
256/// * `Ok(Tensor4D)` - A 4D tensor created from the input data.
257/// * `Err(OCRError)` - An error if the input data is invalid (e.g., empty, inconsistent dimensions).
258///
259/// # Examples
260///
261/// ```
262/// use oar_ocr_core::utils::tensor::vec_to_tensor4d;
263/// let data = vec![vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]]];
264/// let tensor = vec_to_tensor4d(&data);
265/// ```
266pub fn vec_to_tensor4d(data: &[Vec<Vec<Vec<f32>>>]) -> Result<Tensor4D, OCRError> {
267    if data.is_empty() {
268        return Err(OCRError::InvalidInput {
269            message: "Empty data".to_string(),
270        });
271    }
272
273    let dim0 = data.len();
274    let dim1 = data[0].len();
275
276    let dim2 = if dim1 > 0 && !data[0].is_empty() {
277        data[0][0].len()
278    } else {
279        0
280    };
281
282    let dim3 = if dim2 > 0 && !data[0].is_empty() && !data[0][0].is_empty() {
283        data[0][0][0].len()
284    } else {
285        0
286    };
287
288    for (i, outer) in data.iter().enumerate() {
289        if outer.len() != dim1 {
290            return Err(OCRError::InvalidInput {
291                message: format!(
292                    "Inconsistent dimension 1 at index {}: expected {}, got {}",
293                    i,
294                    dim1,
295                    outer.len()
296                ),
297            });
298        }
299        for (j, middle) in outer.iter().enumerate() {
300            if middle.len() != dim2 {
301                return Err(OCRError::InvalidInput {
302                    message: format!(
303                        "Inconsistent dimension 2 at index [{}, {}]: expected {}, got {}",
304                        i,
305                        j,
306                        dim2,
307                        middle.len()
308                    ),
309                });
310            }
311            for (k, inner) in middle.iter().enumerate() {
312                if inner.len() != dim3 {
313                    return Err(OCRError::InvalidInput {
314                        message: format!(
315                            "Inconsistent dimension 3 at index [{}, {}, {}]: expected {}, got {}",
316                            i,
317                            j,
318                            k,
319                            dim3,
320                            inner.len()
321                        ),
322                    });
323                }
324            }
325        }
326    }
327
328    let total_size = dim0
329        .checked_mul(dim1)
330        .and_then(|size| size.checked_mul(dim2))
331        .and_then(|size| size.checked_mul(dim3))
332        .ok_or_else(|| OCRError::InvalidInput {
333            message: format!(
334                "Tensor dimensions ({dim0}, {dim1}, {dim2}, {dim3}) would cause integer overflow"
335            ),
336        })?;
337
338    const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
339    if total_size > MAX_TENSOR_ELEMENTS {
340        return Err(OCRError::InvalidInput {
341            message: format!(
342                "Tensor size {total_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
343            ),
344        });
345    }
346
347    Ok(Array4::from_shape_fn(
348        (dim0, dim1, dim2, dim3),
349        |(i, j, k, l)| data[i][j][k][l],
350    ))
351}
352
353/// Converts a 4D tensor into a 4D vector of f32 values.
354///
355/// # Arguments
356///
357/// * `tensor` - A reference to a 4D tensor.
358///
359/// # Returns
360///
361/// * `Vec<Vec<Vec<Vec<f32>>>>` - A 4D vector created from the tensor.
362///
363/// # Examples
364///
365/// ```
366/// use oar_ocr_core::utils::tensor::{tensor4d_to_vec, vec_to_tensor4d};
367/// // Create a tensor first
368/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
369/// let data = vec![vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]]];
370/// let tensor = vec_to_tensor4d(&data)?;
371/// let vec_data = tensor4d_to_vec(&tensor);
372/// # let _ = vec_data;
373/// # Ok(())
374/// # }
375/// ```
376pub fn tensor4d_to_vec(tensor: &Tensor4D) -> Vec<Vec<Vec<Vec<f32>>>> {
377    tensor
378        .outer_iter()
379        .map(|batch| {
380            batch
381                .outer_iter()
382                .map(|slice| slice.outer_iter().map(|row| row.to_vec()).collect())
383                .collect()
384        })
385        .collect()
386}
387
388/// Converts a 1D vector of f32 values into a 1D tensor with the specified shape.
389///
390/// # Arguments
391///
392/// * `data` - A vector of f32 values.
393/// * `shape` - A slice of usize values representing the shape of the tensor.
394///
395/// # Returns
396///
397/// * `Ok(Tensor1D)` - A 1D tensor created from the input data.
398/// * `Err(OCRError)` - An error if the input data is invalid.
399///
400/// # Examples
401///
402/// ```
403/// use oar_ocr_core::utils::tensor::vec_to_tensor1d;
404/// let data = vec![1.0, 2.0, 3.0, 4.0];
405/// let shape = &[2, 2];
406/// let tensor = vec_to_tensor1d(data, shape);
407/// ```
408pub fn vec_to_tensor1d(data: Vec<f32>, shape: &[usize]) -> Result<Tensor1D, OCRError> {
409    let data_len = data.len();
410    ArrayD::from_shape_vec(shape, data).map_err(|e| {
411        OCRError::tensor_operation_error(
412            "vec_to_tensor1d",
413            shape,
414            &[data_len],
415            &format!(
416                "Failed to create 1D tensor with shape {:?} from {} elements",
417                shape, data_len
418            ),
419            e,
420        )
421    })
422}
423
424/// Converts a 1D tensor into a 1D vector of f32 values.
425///
426/// # Arguments
427///
428/// * `tensor` - A reference to a 1D tensor.
429///
430/// # Returns
431///
432/// * `Vec<f32>` - A 1D vector created from the tensor.
433///
434/// # Examples
435///
436/// ```
437/// use oar_ocr_core::utils::tensor::{tensor1d_to_vec, vec_to_tensor1d};
438/// // Create a tensor first
439/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
440/// let data = vec![1.0, 2.0, 3.0, 4.0];
441/// let shape = &[4];
442/// let tensor = vec_to_tensor1d(data, shape)?;
443/// let vec_data = tensor1d_to_vec(&tensor);
444/// # let _ = vec_data;
445/// # Ok(())
446/// # }
447/// ```
448pub fn tensor1d_to_vec(tensor: &Tensor1D) -> Vec<f32> {
449    tensor.as_slice().unwrap_or(&[]).to_vec()
450}
451
452/// Extracts a 3D slice from a 4D tensor at the specified index.
453///
454/// # Arguments
455///
456/// * `tensor` - A reference to a 4D tensor.
457/// * `index` - The index of the slice to extract.
458///
459/// # Returns
460///
461/// * `Ok(Tensor3D)` - A 3D tensor slice extracted from the input tensor.
462/// * `Err(OCRError)` - An error if the index is out of bounds.
463///
464/// # Examples
465///
466/// ```
467/// use oar_ocr_core::utils::tensor::{tensor4d_slice, vec_to_tensor4d};
468/// // Create a tensor first
469/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
470/// let data = vec![vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]]];
471/// let tensor = vec_to_tensor4d(&data)?;
472/// let slice = tensor4d_slice(&tensor, 0)?;
473/// # let _ = slice;
474/// # Ok(())
475/// # }
476/// ```
477pub fn tensor4d_slice(tensor: &Tensor4D, index: usize) -> Result<Tensor3D, OCRError> {
478    if index >= tensor.shape()[0] {
479        return Err(OCRError::InvalidInput {
480            message: format!(
481                "Index {} out of bounds for tensor with shape {:?}",
482                index,
483                tensor.shape()
484            ),
485        });
486    }
487    Ok(tensor.index_axis(Axis(0), index).to_owned())
488}
489
490/// Extracts a 2D slice from a 3D tensor at the specified index.
491///
492/// # Arguments
493///
494/// * `tensor` - A reference to a 3D tensor.
495/// * `index` - The index of the slice to extract.
496///
497/// # Returns
498///
499/// * `Ok(Tensor2D)` - A 2D tensor slice extracted from the input tensor.
500/// * `Err(OCRError)` - An error if the index is out of bounds.
501///
502/// # Examples
503///
504/// ```
505/// use oar_ocr_core::utils::tensor::{tensor3d_slice, vec_to_tensor3d};
506/// // Create a tensor first
507/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
508/// let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
509/// let tensor = vec_to_tensor3d(&data)?;
510/// let slice = tensor3d_slice(&tensor, 0)?;
511/// # let _ = slice;
512/// # Ok(())
513/// # }
514/// ```
515pub fn tensor3d_slice(tensor: &Tensor3D, index: usize) -> Result<Tensor2D, OCRError> {
516    if index >= tensor.shape()[0] {
517        return Err(OCRError::InvalidInput {
518            message: format!(
519                "Index {} out of bounds for tensor with shape {:?}",
520                index,
521                tensor.shape()
522            ),
523        });
524    }
525    Ok(tensor.index_axis(Axis(0), index).to_owned())
526}
527
528/// Stacks a slice of 3D tensors into a single 4D tensor.
529///
530/// # Arguments
531///
532/// * `tensors` - A slice of 3D tensors to stack.
533///
534/// # Returns
535///
536/// * `Ok(Tensor4D)` - A 4D tensor created by stacking the input tensors.
537/// * `Err(OCRError)` - An error if the input tensors are invalid (e.g., empty, inconsistent shapes).
538///
539/// # Examples
540///
541/// ```
542/// use oar_ocr_core::utils::tensor::{stack_tensor3d, vec_to_tensor3d};
543/// // Create tensors first
544/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
545/// let data1 = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
546/// let data2 = vec![vec![vec![5.0, 6.0], vec![7.0, 8.0]]];
547/// let tensor1 = vec_to_tensor3d(&data1)?;
548/// let tensor2 = vec_to_tensor3d(&data2)?;
549/// let tensors = vec![tensor1, tensor2];
550/// let stacked_tensor = stack_tensor3d(&tensors)?;
551/// # let _ = stacked_tensor;
552/// # Ok(())
553/// # }
554/// ```
555pub fn stack_tensor3d(tensors: &[Tensor3D]) -> Result<Tensor4D, OCRError> {
556    if tensors.is_empty() {
557        return Err(OCRError::InvalidInput {
558            message: "No tensors to stack".to_string(),
559        });
560    }
561
562    let first_shape = tensors[0].shape();
563
564    if first_shape.contains(&0) {
565        return Err(OCRError::InvalidInput {
566            message: format!("Cannot stack tensors with zero dimensions: shape {first_shape:?}"),
567        });
568    }
569
570    for (i, tensor) in tensors.iter().enumerate() {
571        if tensor.is_empty() {
572            return Err(OCRError::InvalidInput {
573                message: format!("Tensor {i} is empty and cannot be stacked"),
574            });
575        }
576
577        if i > 0 && tensor.shape() != first_shape {
578            return Err(OCRError::InvalidInput {
579                message: format!(
580                    "Tensor shape mismatch during 3D tensor stacking at index {}: expected {:?}, got {:?}",
581                    i,
582                    first_shape,
583                    tensor.shape()
584                ),
585            });
586        }
587    }
588
589    let result_size = tensors
590        .len()
591        .checked_mul(first_shape.iter().product::<usize>())
592        .ok_or_else(|| OCRError::InvalidInput {
593            message: format!(
594                "Stacking {} tensors of shape {:?} would cause integer overflow",
595                tensors.len(),
596                first_shape
597            ),
598        })?;
599
600    const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
601    if result_size > MAX_TENSOR_ELEMENTS {
602        return Err(OCRError::InvalidInput {
603            message: format!(
604                "Stacked tensor size {result_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
605            ),
606        });
607    }
608
609    let views: Vec<_> = tensors.iter().map(|t| t.view()).collect();
610
611    ndarray::stack(Axis(0), &views).map_err(|e| {
612        OCRError::tensor_operation_error(
613            "stack_tensor3d",
614            &[
615                tensors.len(),
616                first_shape[0],
617                first_shape[1],
618                first_shape[2],
619            ],
620            &[result_size],
621            &format!(
622                "Failed to stack {} 3D tensors of shape {:?}",
623                tensors.len(),
624                first_shape
625            ),
626            e,
627        )
628    })
629}
630
631/// Stacks a slice of 2D tensors into a single 3D tensor.
632///
633/// # Arguments
634///
635/// * `tensors` - A slice of 2D tensors to stack.
636///
637/// # Returns
638///
639/// * `Ok(Tensor3D)` - A 3D tensor created by stacking the input tensors.
640/// * `Err(OCRError)` - An error if the input tensors are invalid (e.g., empty, inconsistent shapes).
641///
642/// # Examples
643///
644/// ```
645/// use oar_ocr_core::utils::tensor::{stack_tensor2d, vec_to_tensor2d};
646/// // Create tensors first
647/// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
648/// let data1 = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
649/// let data2 = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
650/// let tensor1 = vec_to_tensor2d(&data1)?;
651/// let tensor2 = vec_to_tensor2d(&data2)?;
652/// let tensors = vec![tensor1, tensor2];
653/// let stacked_tensor = stack_tensor2d(&tensors)?;
654/// # let _ = stacked_tensor;
655/// # Ok(())
656/// # }
657/// ```
658pub fn stack_tensor2d(tensors: &[Tensor2D]) -> Result<Tensor3D, OCRError> {
659    if tensors.is_empty() {
660        return Err(OCRError::InvalidInput {
661            message: "No tensors to stack".to_string(),
662        });
663    }
664
665    let first_shape = tensors[0].shape();
666
667    if first_shape.contains(&0) {
668        return Err(OCRError::InvalidInput {
669            message: format!("Cannot stack tensors with zero dimensions: shape {first_shape:?}"),
670        });
671    }
672
673    for (i, tensor) in tensors.iter().enumerate() {
674        if tensor.is_empty() {
675            return Err(OCRError::InvalidInput {
676                message: format!("Tensor {i} is empty and cannot be stacked"),
677            });
678        }
679
680        if i > 0 && tensor.shape() != first_shape {
681            return Err(OCRError::InvalidInput {
682                message: format!(
683                    "All tensors must have the same shape for stacking. Tensor 0 has shape {:?}, tensor {} has shape {:?}",
684                    first_shape,
685                    i,
686                    tensor.shape()
687                ),
688            });
689        }
690    }
691
692    let result_size = tensors
693        .len()
694        .checked_mul(first_shape.iter().product::<usize>())
695        .ok_or_else(|| OCRError::InvalidInput {
696            message: format!(
697                "Stacking {} tensors of shape {:?} would cause integer overflow",
698                tensors.len(),
699                first_shape
700            ),
701        })?;
702
703    const MAX_TENSOR_ELEMENTS: usize = 1_000_000_000;
704    if result_size > MAX_TENSOR_ELEMENTS {
705        return Err(OCRError::InvalidInput {
706            message: format!(
707                "Stacked tensor size {result_size} exceeds maximum allowed size {MAX_TENSOR_ELEMENTS}"
708            ),
709        });
710    }
711
712    let views: Vec<_> = tensors.iter().map(|t| t.view()).collect();
713
714    ndarray::stack(Axis(0), &views).map_err(OCRError::Tensor)
715}