hpt_dataloader/utils.rs
1use hpt_common::{error::base::TensorError, shape::shape::Shape};
2
/// Creates (or truncates) a file at `path`, making sure the file name ends with `.{ext}`.
///
/// If `path` already has exactly the extension `ext`, it is used unchanged. Otherwise
/// `.{ext}` is *appended* to the full path. Any existing, different extension is kept, so
/// `data.txt` with `ext = "bin"` becomes `data.txt.bin`.
///
/// `ext` should be given without a leading dot, because a `.` is always inserted before it.
///
/// # Errors
///
/// Returns any I/O error reported by [`std::fs::File::create`].
pub(crate) fn create_file(path: std::path::PathBuf, ext: &str) -> std::io::Result<std::fs::File> {
    if path.extension() == Some(std::ffi::OsStr::new(ext)) {
        return std::fs::File::create(path);
    }
    // Build the new name on the `OsString` instead of going through `to_str().unwrap()`,
    // so paths that are not valid UTF-8 are handled instead of panicking.
    let mut with_ext = path.into_os_string();
    with_ext.push(".");
    with_ext.push(ext);
    std::fs::File::create(with_ext)
}
14
15/// A trait defines empty function for Tensor that will allocate memory on CPU.
16pub trait CPUTensorCreator<T> {
17 /// the output type of the creator
18 type Output;
19
20 /// Creates a tensor with uninitialized elements of the specified shape.
21 ///
22 /// This function allocates memory for a tensor in CPU of the given shape, but the values are uninitialized, meaning they may contain random data.
23 ///
24 /// # Arguments
25 ///
26 /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
27 ///
28 /// # Returns
29 ///
30 /// * A tensor with the specified shape, but with uninitialized data.
31 ///
32 /// # Panics
33 ///
34 /// * This function may panic if the requested shape is invalid or too large for available memory.
35 #[track_caller]
36 fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
37}