// ferrite/tensor_storage/creation.rs
use super::base::TensorStorage;  // Import from parent module's base.rs

use ndarray::{ArrayBase, Dimension};
use num_traits::cast::AsPrimitive;

/// Constructors for tensor-like values.
///
/// Every constructor accepts an optional `requires_grad` flag; whether (and how) an
/// implementor honors it is left to that implementor.
pub trait TensorCreation {
  /// Creates a value of the given `shape` whose elements are all zero.
  fn zeros(shape: Vec<usize>, requires_grad: Option<bool>) -> Self;

  /// Creates a value of the given `shape` whose elements are all one.
  fn ones(shape: Vec<usize>, requires_grad: Option<bool>) -> Self;

  /// Creates a value from an `ndarray` array of any dimensionality whose element type
  /// can be cast to `f32` via [`AsPrimitive`].
  fn from_ndarray<S, D, T>(data: &ArrayBase<S, D>, requires_grad: Option<bool>) -> Self
  where
    D: Dimension,
    T: AsPrimitive<f32>,
    S: ndarray::Data<Elem = T>;
}

impl TensorCreation for TensorStorage {
  /// Builds a storage of `shape` with every element set to `0.0`.
  ///
  /// `requires_grad` is accepted for trait compatibility but is currently ignored.
  ///
  /// # Panics
  ///
  /// Panics if the product of the dimensions in `shape` overflows `usize`.
  fn zeros(shape: Vec<usize>, _requires_grad: Option<bool>) -> Self {
    let size = checked_numel(&shape);
    TensorStorage::new(vec![0.0; size], shape)
  }

  /// Builds a storage of `shape` with every element set to `1.0`.
  ///
  /// `requires_grad` is accepted for trait compatibility but is currently ignored.
  ///
  /// # Panics
  ///
  /// Panics if the product of the dimensions in `shape` overflows `usize`.
  fn ones(shape: Vec<usize>, _requires_grad: Option<bool>) -> Self {
    let size = checked_numel(&shape);
    TensorStorage::new(vec![1.0; size], shape)
  }

  /// Copies an `ndarray` array of any dimensionality into a new storage, converting each
  /// element to `f32` with [`AsPrimitive::as_`].
  ///
  /// Elements are read in the array's logical order (last index varying fastest), so the
  /// flat buffer lines up with `data.shape()` regardless of the array's memory layout.
  ///
  /// `requires_grad` is accepted for trait compatibility but is currently ignored.
  fn from_ndarray<S, D, T>(data: &ArrayBase<S, D>, _requires_grad: Option<bool>) -> Self
  where
    S: ndarray::Data<Elem = T>,
    T: AsPrimitive<f32>,
    D: Dimension,
  {
    let shape = data.shape().to_vec();
    // Convert while iterating rather than via `mapv`, which would first materialize a full
    // intermediate array only to copy it again into the flat buffer.
    let flat = data.iter().map(|&x| x.as_()).collect();
    TensorStorage::new(flat, shape)
  }
}

/// Number of elements in a tensor of the given `shape` (1 for an empty, i.e. scalar, shape).
///
/// # Panics
///
/// Panics if the product overflows `usize`. A plain `product()` would wrap in release builds
/// and silently produce a buffer whose length disagrees with `shape`.
fn checked_numel(shape: &[usize]) -> usize {
  shape
    .iter()
    .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    .unwrap_or_else(|| panic!("tensor shape {:?} has too many elements (overflows usize)", shape))
}