1use infa_impl::Dequantize;
2use infa_impl::Float32Tensor;
3use infa_impl::TensorOps;
4
5pub enum FloatTensor {
6 #[cfg(feature = "gguf")]
7 GGUFFloatTensor(infa_gguf::GGUFFloatTensor),
8 Float32Tensor(infa_impl::Float32Tensor),
9}
10
11impl infa_impl::TensorOps<FloatTensor, f32> for FloatTensor {
12 fn sum(&self, dim: i64) -> infa_impl::Result<FloatTensor> {
13 Ok(match self {
14 #[cfg(feature = "gguf")]
15 FloatTensor::GGUFFloatTensor(t) => FloatTensor::Float32Tensor(t.sum(dim)?),
16 FloatTensor::Float32Tensor(t) => FloatTensor::Float32Tensor(t.sum(dim)?),
17 })
18 }
19 fn item(&self) -> infa_impl::Result<Vec<Self::Item>> {
20 Ok(match self {
21 #[cfg(feature = "gguf")]
22 FloatTensor::GGUFFloatTensor(t) => t.item()?,
23 FloatTensor::Float32Tensor(t) => t.item()?,
24 })
25 }
26 fn size(&self) -> infa_impl::Result<usize> {
27 Ok(match self {
28 #[cfg(feature = "gguf")]
29 FloatTensor::GGUFFloatTensor(t) => t.size()?,
30 FloatTensor::Float32Tensor(t) => t.size()?,
31 })
32 }
33 fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> infa_impl::Result<FloatTensor> {
34 Ok(match self {
35 #[cfg(feature = "gguf")]
36 FloatTensor::GGUFFloatTensor(t1) => FloatTensor::Float32Tensor(t1.apply(f)?),
37 FloatTensor::Float32Tensor(t1) => FloatTensor::Float32Tensor(t1.apply(f)?),
38 })
39 }
40 fn apply_xy(
41 &self,
42 rhs: &FloatTensor,
43 f: impl Fn(Self::Item, Self::Item) -> Self::Item,
44 ) -> infa_impl::Result<FloatTensor> {
45 Ok(match (self, rhs) {
46 #[cfg(feature = "gguf")]
47 (FloatTensor::GGUFFloatTensor(t1), FloatTensor::Float32Tensor(t2)) => {
48 FloatTensor::Float32Tensor(t1.apply_xy(t2, f)?)
49 }
50 #[cfg(feature = "gguf")]
51 (FloatTensor::Float32Tensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
52 FloatTensor::Float32Tensor(t2.apply_xy(t1, f)?)
53 }
54 #[cfg(feature = "gguf")]
55 (FloatTensor::GGUFFloatTensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
56 FloatTensor::Float32Tensor(t1.apply_xy(&t2.dequantize()?, f)?)
57 }
58 (FloatTensor::Float32Tensor(t1), FloatTensor::Float32Tensor(t2)) => {
59 FloatTensor::Float32Tensor(t1.apply_xy(t2, f)?)
60 }
61 })
62 }
63
64 fn matmul(&self, rhs: &FloatTensor) -> infa_impl::Result<FloatTensor> {
65 Ok(match (self, rhs) {
66 #[cfg(feature = "gguf")]
67 (FloatTensor::GGUFFloatTensor(t1), FloatTensor::Float32Tensor(t2)) => {
68 FloatTensor::Float32Tensor(t1.matmul(t2)?)
69 }
70 #[cfg(feature = "gguf")]
71 (FloatTensor::Float32Tensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
72 FloatTensor::Float32Tensor(t1.matmul(&t2.dequantize()?)?)
73 }
74 #[cfg(feature = "gguf")]
75 (FloatTensor::GGUFFloatTensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
76 FloatTensor::Float32Tensor(t1.matmul(&t2.dequantize()?)?)
77 }
78 (FloatTensor::Float32Tensor(t1), FloatTensor::Float32Tensor(t2)) => {
79 FloatTensor::Float32Tensor(t1.matmul(t2)?)
80 }
81 })
82 }
83}
84
85impl infa_impl::BaseTensorOps for FloatTensor {
86 type Item = f32;
87 fn shape(&self) -> &Vec<u64> {
88 match self {
89 #[cfg(feature = "gguf")]
90 FloatTensor::GGUFFloatTensor(t) => t.shape(),
91 FloatTensor::Float32Tensor(t) => t.shape(),
92 }
93 }
94
95 fn reshape(&self, shape: Vec<i64>) -> infa_impl::Result<Self> {
96 Ok(match self {
97 #[cfg(feature = "gguf")]
98 FloatTensor::GGUFFloatTensor(t) => FloatTensor::GGUFFloatTensor(t.reshape(shape)?),
99 FloatTensor::Float32Tensor(t) => FloatTensor::Float32Tensor(t.reshape(shape)?),
100 })
101 }
102 fn from_values(shape: Vec<u64>, values: Vec<Self::Item>) -> infa_impl::Result<Self> {
103 Float32Tensor::from_values(shape, values).map(FloatTensor::Float32Tensor)
104 }
105}
106
107impl<'a> std::ops::Add<&'a FloatTensor> for &'a FloatTensor {
108 type Output = infa_impl::Result<FloatTensor>;
109
110 fn add(self, rhs: Self) -> Self::Output {
111 TensorOps::add(self, rhs)
112 }
113}
114impl<'a> std::ops::Add<&'a f32> for &'a FloatTensor {
115 type Output = infa_impl::Result<FloatTensor>;
116
117 fn add(self, rhs: &f32) -> Self::Output {
118 TensorOps::add_item(self, rhs)
119 }
120}
121impl<'a> std::ops::Sub<&'a f32> for &'a FloatTensor {
122 type Output = infa_impl::Result<FloatTensor>;
123
124 fn sub(self, rhs: &f32) -> Self::Output {
125 TensorOps::sub_item(self, rhs)
126 }
127}
128impl<'a> std::ops::Sub<&'a FloatTensor> for &'a FloatTensor {
129 type Output = infa_impl::Result<FloatTensor>;
130
131 fn sub(self, rhs: &FloatTensor) -> Self::Output {
132 TensorOps::sub(self, rhs)
133 }
134}
135
136impl<'a> std::ops::Mul<&'a FloatTensor> for &'a FloatTensor {
137 type Output = infa_impl::Result<FloatTensor>;
138
139 fn mul(self, rhs: Self) -> Self::Output {
140 TensorOps::mul(self, rhs)
141 }
142}
143
144impl<'a> std::ops::Mul<&'a f32> for &'a FloatTensor {
145 type Output = infa_impl::Result<FloatTensor>;
146
147 fn mul(self, rhs: &f32) -> Self::Output {
148 TensorOps::mul_item(self, rhs)
149 }
150}
151impl<'a> std::ops::Div<&'a f32> for &'a FloatTensor {
152 type Output = infa_impl::Result<FloatTensor>;
153
154 fn div(self, rhs: &f32) -> Self::Output {
155 TensorOps::div_item(self, rhs)
156 }
157}
158impl<'a> std::ops::Div<&'a FloatTensor> for &'a FloatTensor {
159 type Output = infa_impl::Result<FloatTensor>;
160
161 fn div(self, rhs: Self) -> Self::Output {
162 TensorOps::div(self, rhs)
163 }
164}
165
166impl<'a> std::ops::Neg for &'a FloatTensor {
167 type Output = infa_impl::Result<FloatTensor>;
168
169 fn neg(self) -> Self::Output {
170 Ok(match self {
171 FloatTensor::Float32Tensor(s) => FloatTensor::Float32Tensor(s.neg()?),
172 #[cfg(feature = "gguf")]
173 FloatTensor::GGUFFloatTensor(s) => FloatTensor::Float32Tensor(s.neg()?),
174 })
175 }
176}