hpt_dataloader/
utils.rs

1use hpt_common::{error::base::TensorError, shape::shape::Shape};
2
/// Creates (or truncates) a file at `path`, making sure the file name ends with `.{ext}`.
///
/// * If `path` already has the extension `ext`, it is used unchanged.
/// * Otherwise `.{ext}` is appended to the full path. An existing, different extension is
///   kept, not replaced. For example, `data.v1` with `ext = "hpt"` becomes `data.v1.hpt`,
///   and `data` becomes `data.hpt`.
///
/// # Arguments
///
/// * `path` - Target path, with or without the extension.
/// * `ext` - Extension to enforce, given without the leading dot.
///
/// # Errors
///
/// Returns any I/O error produced by [`std::fs::File::create`].
pub(crate) fn create_file(path: std::path::PathBuf, ext: &str) -> std::io::Result<std::fs::File> {
    if path.extension() == Some(std::ffi::OsStr::new(ext)) {
        return std::fs::File::create(path);
    }
    // Append on the raw OS string rather than via `to_str()`, so paths that are not valid
    // UTF-8 are handled instead of panicking.
    let mut with_ext = path.into_os_string();
    with_ext.push(".");
    with_ext.push(ext);
    std::fs::File::create(with_ext)
}
14
15/// A trait defines empty function for Tensor that will allocate memory on CPU.
16pub trait CPUTensorCreator<T> {
17    /// the output type of the creator
18    type Output;
19
20    /// Creates a tensor with uninitialized elements of the specified shape.
21    ///
22    /// This function allocates memory for a tensor in CPU of the given shape, but the values are uninitialized, meaning they may contain random data.
23    ///
24    /// # Arguments
25    ///
26    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
27    ///
28    /// # Returns
29    ///
30    /// * A tensor with the specified shape, but with uninitialized data.
31    ///
32    /// # Panics
33    ///
34    /// * This function may panic if the requested shape is invalid or too large for available memory.
35    #[track_caller]
36    fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
37}