1#![allow(dead_code)]
2use std::sync::Arc;
3
4use crate::{op::Op, storage::Storage, DType, Device};
5
6pub struct Tensor_ {
7 storage: Storage,
8 shape: Vec<usize>,
9 stride: Vec<usize>,
10 op: Option<Op>,
11}
12
13pub struct Tensor(Arc<Tensor_>);
14
15impl Tensor {
16 pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
17 let storage = device.zeros(shape, dtype);
18 let tensor_ = Tensor_ {
19 storage,
20 shape: shape.to_vec(),
21 stride: vec![1; shape.len()],
22 op: None,
23 };
24 Tensor(Arc::new(tensor_))
25 }
26
27 pub fn dtype(&self) -> DType {
30 self.0.storage.dtype()
31 }
32
33 pub fn device(&self) -> Device {
34 self.0.storage.device()
35 }
36
37 pub fn shape(&self) -> &[usize] {
38 &self.0.shape
39 }
40
41 pub fn stride(&self) -> &[usize] {
42 &self.0.stride
43 }
44
45 pub fn rank(&self) -> usize {
47 self.0.shape.len()
48 }
49
50 pub fn elem_count(&self) -> usize {
52 self.0.shape.iter().product()
53 }
54}