1use crate::BaseTensorOps;
2
/// A dense tensor of `f32` elements.
///
/// `data` holds every element in one flat buffer and `shape` lists the size of
/// each dimension. `matmul` reads 2-D data in row-major order (element `(i, k)`
/// lives at `i * cols + k`); presumably higher ranks follow the same row-major
/// convention — TODO confirm for any code that relies on it.
///
/// Both fields are public, so the invariant "product of `shape` == `data.len()`"
/// is only checked by constructors such as `from_values`; code that builds this
/// struct literally must uphold it itself.
#[derive(Debug)]
pub struct Float32Tensor {
    /// Size of each dimension (for 2-D data, `shape[0]` is the row count).
    pub shape: Vec<u64>,
    /// All tensor elements in a single flat buffer.
    pub data: Vec<f32>,
}
8
9impl crate::TensorOps<Float32Tensor, f32> for Float32Tensor {
10 fn sum(&self, dim: i64) -> crate::Result<Float32Tensor> {
11 let dim = self.resolve_dim(dim)?;
12 if dim == 0 {
13 let sum = self.data.iter().sum();
14 return Ok(Float32Tensor {
15 shape: vec![1],
16 data: vec![sum],
17 });
18 }
19 let mut data = vec![0.0; self.data.len()];
20 for i in 0..self.data.len() {
21 data[i % self.shape[dim as usize] as usize] += self.data[i];
22 }
23 let mut shape = self.shape.clone();
24 shape[dim as usize] = 1;
25 Ok(Float32Tensor { shape, data })
26 }
27 fn item(&self) -> crate::Result<Vec<Self::Item>> {
28 Ok(self.data.clone())
29 }
30 fn size(&self) -> crate::Result<usize> {
31 Ok(self.data.len())
32 }
33 fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> crate::Result<Float32Tensor> {
34 let mut result_data = Vec::with_capacity(self.data.len());
35 for item in self.data.iter() {
36 result_data.push(f(*item));
37 }
38 Ok(Float32Tensor {
39 shape: self.shape.clone(),
40 data: result_data,
41 })
42 }
43 fn apply_xy(
44 &self,
45 rhs: &Float32Tensor,
46 f: impl Fn(Self::Item, Self::Item) -> Self::Item,
47 ) -> crate::Result<Float32Tensor> {
48 let rhs = if rhs.data.len() == 1 {
49 &Float32Tensor {
50 shape: self.shape.clone(),
51 data: vec![rhs.data[0]; self.size()? as usize],
52 }
53 } else {
54 rhs
55 };
56 let mut result_data = Vec::with_capacity(self.data.len());
57 for (x, y) in self.data.iter().zip(rhs.data.iter()) {
58 result_data.push(f(*x, *y));
59 }
60 Ok(Float32Tensor {
61 shape: self.shape.clone(),
62 data: result_data,
63 })
64 }
65
66 fn matmul(&self, rhs: &Float32Tensor) -> crate::Result<Float32Tensor> {
67 let ar = self.shape[0] as usize;
68 let ac = self.shape[1] as usize;
69 let br = rhs.shape[0] as usize;
70 let bc = rhs.shape[1] as usize;
71 if ac != br {
72 return Err(crate::Error::InvalidShape(
73 self.shape.clone(),
74 rhs.shape.clone(),
75 ));
76 }
77 let mut data = vec![0.0; ar * bc];
78 for i in 0..ar {
79 for j in 0..bc {
80 for k in 0..ac {
81 data[i * bc + j] += self.data[i * ac + k] * rhs.data[k * bc + j];
82 }
83 }
84 }
85 let shape = vec![ar as u64, bc as u64];
86 Ok(Float32Tensor { shape, data })
87 }
88}
89
90impl crate::BaseTensorOps for Float32Tensor {
91 type Item = f32;
92 fn shape(&self) -> &Vec<u64> {
93 &self.shape
94 }
95 fn reshape(&self, shape: Vec<i64>) -> crate::Result<Self> {
96 let shape = self.resolve_shape(shape)?;
97 Ok(Float32Tensor {
98 shape,
99 data: self.data.clone(),
100 })
101 }
102 fn from_values(shape: Vec<u64>, values: Vec<Self::Item>) -> crate::Result<Self> {
103 let size: u64 = shape.iter().product();
104 if size != values.len() as u64 {
105 return Err(crate::Error::InvalidShape(
106 shape.clone(),
107 vec![values.len() as u64],
108 ));
109 }
110 Ok(Self {
111 shape,
112 data: values,
113 })
114 }
115}