// ferrite/tensor/creation.rs

1use crate::*;
2use ndarray;
3use num_traits;
4
5
6
7impl Tensor {
8  pub fn zeros(shape: Vec<usize>, device: Device, requires_grad: Option<bool>) -> Self {
9    let tensor = Storage::zeros(shape, Some(device), None);
10    let requires_grad = requires_grad.unwrap_or(false);
11    Tensor::new(tensor, device, requires_grad)
12  }
13
14  pub fn ones(shape: Vec<usize>, device: Device, requires_grad: Option<bool>) -> Self {
15    let tensor = Storage::ones(shape, Some(device), None);
16    let requires_grad = requires_grad.unwrap_or(false);
17    Tensor::new(tensor, device, requires_grad)
18  }
19
20  pub fn from_ndarray<S, D, T>(data: &ndarray::ArrayBase<S, D>, device: Device, requires_grad: Option<bool>) -> Self
21  where 
22    S: ndarray::Data<Elem = T>,
23    T: num_traits::AsPrimitive<f32>,
24    D: ndarray::Dimension 
25  {
26    let tensor = Storage::from_ndarray(data, Some(device), None);
27    let requires_grad = requires_grad.unwrap_or(false);
28    Tensor::new(tensor, device, requires_grad)
29  }
30
31  pub fn uniform(l_bound: f32, r_bound: f32, shape: Vec<usize>, device: Device, requires_grad: Option<bool>) -> Self {
32    let tensor = Storage::uniform(l_bound, r_bound, shape, Some(device), None);
33    let requires_grad = requires_grad.unwrap_or(false);
34    Tensor::new(tensor, device, requires_grad)
35  }
36}