flowtorch_core/
tensor.rs

1#![allow(dead_code)]
2use std::sync::Arc;
3
4use crate::{op::Op, storage::Storage, DType, Device};
5
6pub struct Tensor_ {
7    storage: Storage,
8    shape: Vec<usize>,
9    stride: Vec<usize>,
10    op: Option<Op>,
11}
12
13pub struct Tensor(Arc<Tensor_>);
14
15impl Tensor {
16    pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
17        let storage = device.zeros(shape, dtype);
18        let tensor_ = Tensor_ {
19            storage,
20            shape: shape.to_vec(),
21            stride: vec![1; shape.len()],
22            op: None,
23        };
24        Tensor(Arc::new(tensor_))
25    }
26
27    //The reason for self.0 is Tensor is a tuple struct wapper around Tensor_ with Arc
28    //https://doc.rust-lang.org/std/keyword.self.html
29    pub fn dtype(&self) -> DType {
30        self.0.storage.dtype()
31    }
32
33    pub fn device(&self) -> Device {
34        self.0.storage.device()
35    }
36
37    pub fn shape(&self) -> &[usize] {
38        &self.0.shape
39    }
40
41    pub fn stride(&self) -> &[usize] {
42        &self.0.stride
43    }
44
45    //The rank of a tensor is the number of dimensions or axes it has. In other words, it is the length of the shape of the tensor.
46    pub fn rank(&self) -> usize {
47        self.0.shape.len()
48    }
49
50    //Max number of elements in the Tensor
51    pub fn elem_count(&self) -> usize {
52        self.0.shape.iter().product()
53    }
54}