infa_impl/float32.rs

1use crate::BaseTensorOps;
2
/// A dense tensor of `f32` values stored in a flat, row-major buffer.
///
/// `shape` holds one extent per dimension; `data` is the flat element
/// storage, whose length is expected to equal the product of `shape`.
/// (Row-major layout is established by the indexing in `matmul`:
/// element `(i, k)` lives at `i * cols + k`.)
#[derive(Debug, Clone, PartialEq)]
pub struct Float32Tensor {
    /// Extent of each dimension.
    pub shape: Vec<u64>,
    /// Flat element storage, row-major.
    pub data: Vec<f32>,
}
8
9impl crate::TensorOps<Float32Tensor, f32> for Float32Tensor {
10    fn sum(&self, dim: i64) -> crate::Result<Float32Tensor> {
11        let dim = self.resolve_dim(dim)?;
12        if dim == 0 {
13            let sum = self.data.iter().sum();
14            return Ok(Float32Tensor {
15                shape: vec![1],
16                data: vec![sum],
17            });
18        }
19        let mut data = vec![0.0; self.data.len()];
20        for i in 0..self.data.len() {
21            data[i % self.shape[dim as usize] as usize] += self.data[i];
22        }
23        let mut shape = self.shape.clone();
24        shape[dim as usize] = 1;
25        Ok(Float32Tensor { shape, data })
26    }
27    fn item(&self) -> crate::Result<Vec<Self::Item>> {
28        Ok(self.data.clone())
29    }
30    fn size(&self) -> crate::Result<usize> {
31        Ok(self.data.len())
32    }
33    fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> crate::Result<Float32Tensor> {
34        let mut result_data = Vec::with_capacity(self.data.len());
35        for item in self.data.iter() {
36            result_data.push(f(*item));
37        }
38        Ok(Float32Tensor {
39            shape: self.shape.clone(),
40            data: result_data,
41        })
42    }
43    fn apply_xy(
44        &self,
45        rhs: &Float32Tensor,
46        f: impl Fn(Self::Item, Self::Item) -> Self::Item,
47    ) -> crate::Result<Float32Tensor> {
48        let rhs = if rhs.data.len() == 1 {
49            &Float32Tensor {
50                shape: self.shape.clone(),
51                data: vec![rhs.data[0]; self.size()? as usize],
52            }
53        } else {
54            rhs
55        };
56        let mut result_data = Vec::with_capacity(self.data.len());
57        for (x, y) in self.data.iter().zip(rhs.data.iter()) {
58            result_data.push(f(*x, *y));
59        }
60        Ok(Float32Tensor {
61            shape: self.shape.clone(),
62            data: result_data,
63        })
64    }
65
66    fn matmul(&self, rhs: &Float32Tensor) -> crate::Result<Float32Tensor> {
67        let ar = self.shape[0] as usize;
68        let ac = self.shape[1] as usize;
69        let br = rhs.shape[0] as usize;
70        let bc = rhs.shape[1] as usize;
71        if ac != br {
72            return Err(crate::Error::InvalidShape(
73                self.shape.clone(),
74                rhs.shape.clone(),
75            ));
76        }
77        let mut data = vec![0.0; ar * bc];
78        for i in 0..ar {
79            for j in 0..bc {
80                for k in 0..ac {
81                    data[i * bc + j] += self.data[i * ac + k] * rhs.data[k * bc + j];
82                }
83            }
84        }
85        let shape = vec![ar as u64, bc as u64];
86        Ok(Float32Tensor { shape, data })
87    }
88}
89
90impl crate::BaseTensorOps for Float32Tensor {
91    type Item = f32;
92    fn shape(&self) -> &Vec<u64> {
93        &self.shape
94    }
95    fn reshape(&self, shape: Vec<i64>) -> crate::Result<Self> {
96        let shape = self.resolve_shape(shape)?;
97        Ok(Float32Tensor {
98            shape,
99            data: self.data.clone(),
100        })
101    }
102    fn from_values(shape: Vec<u64>, values: Vec<Self::Item>) -> crate::Result<Self> {
103        let size: u64 = shape.iter().product();
104        if size != values.len() as u64 {
105            return Err(crate::Error::InvalidShape(
106                shape.clone(),
107                vec![values.len() as u64],
108            ));
109        }
110        Ok(Self {
111            shape,
112            data: values,
113        })
114    }
115}