// ferrite/autograd/tensor/creation.rs
use crate::tensor_storage::*;
use super::base::*;
use ndarray;
use num_traits;

impl TensorCreation for Tensor {
    /// Creates a tensor whose storage comes from `TensorStorage::zeros(shape, None)`.
    ///
    /// `requires_grad` is treated as `false` when the caller passes `None`.
    /// NOTE(review): the meaning of the second (`None`) argument to
    /// `TensorStorage::zeros` is defined in `tensor_storage` — confirm there.
    fn zeros(shape: Vec<usize>, requires_grad: Option<bool>) -> Self {
        Tensor::new(TensorStorage::zeros(shape, None), requires_grad.unwrap_or(false))
    }

    /// Creates a tensor whose storage comes from `TensorStorage::ones(shape, None)`.
    ///
    /// `requires_grad` is treated as `false` when the caller passes `None`.
    /// NOTE(review): the meaning of the second (`None`) argument to
    /// `TensorStorage::ones` is defined in `tensor_storage` — confirm there.
    fn ones(shape: Vec<usize>, requires_grad: Option<bool>) -> Self {
        Tensor::new(TensorStorage::ones(shape, None), requires_grad.unwrap_or(false))
    }

    /// Creates a tensor from any `ndarray` array whose element type `T`
    /// implements `AsPrimitive<f32>`, delegating the conversion to
    /// `TensorStorage::from_ndarray(data, None)`.
    ///
    /// `requires_grad` is treated as `false` when the caller passes `None`.
    fn from_ndarray<S, D, T>(data: &ndarray::ArrayBase<S, D>, requires_grad: Option<bool>) -> Self
    where
        S: ndarray::Data<Elem = T>,
        T: num_traits::AsPrimitive<f32>,
        D: ndarray::Dimension,
    {
        Tensor::new(
            TensorStorage::from_ndarray(data, None),
            requires_grad.unwrap_or(false),
        )
    }
}