infa_core/ops.rs

1use infa_impl::Dequantize;
2use infa_impl::Float32Tensor;
3use infa_impl::TensorOps;
4
/// A floating-point tensor that wraps one of the available backend tensor types.
///
/// Operations on this type dispatch to the wrapped backend tensor. The GGUF
/// variant is only compiled in when the `gguf` feature is enabled.
pub enum FloatTensor {
    /// A tensor from the `infa_gguf` backend.
    /// NOTE(review): presumably quantized storage, since mixed-backend operations
    /// call `dequantize()` on it — confirm against `infa_gguf`.
    #[cfg(feature = "gguf")]
    GGUFFloatTensor(infa_gguf::GGUFFloatTensor),
    /// A tensor from the `infa_impl::Float32Tensor` backend. Tensor-producing
    /// operations such as `sum`, `apply`, `apply_xy` and `matmul` return this variant.
    Float32Tensor(infa_impl::Float32Tensor),
}
10
11impl infa_impl::TensorOps<FloatTensor, f32> for FloatTensor {
12    fn sum(&self, dim: i64) -> infa_impl::Result<FloatTensor> {
13        Ok(match self {
14            #[cfg(feature = "gguf")]
15            FloatTensor::GGUFFloatTensor(t) => FloatTensor::Float32Tensor(t.sum(dim)?),
16            FloatTensor::Float32Tensor(t) => FloatTensor::Float32Tensor(t.sum(dim)?),
17        })
18    }
19    fn item(&self) -> infa_impl::Result<Vec<Self::Item>> {
20        Ok(match self {
21            #[cfg(feature = "gguf")]
22            FloatTensor::GGUFFloatTensor(t) => t.item()?,
23            FloatTensor::Float32Tensor(t) => t.item()?,
24        })
25    }
26    fn size(&self) -> infa_impl::Result<usize> {
27        Ok(match self {
28            #[cfg(feature = "gguf")]
29            FloatTensor::GGUFFloatTensor(t) => t.size()?,
30            FloatTensor::Float32Tensor(t) => t.size()?,
31        })
32    }
33    fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> infa_impl::Result<FloatTensor> {
34        Ok(match self {
35            #[cfg(feature = "gguf")]
36            FloatTensor::GGUFFloatTensor(t1) => FloatTensor::Float32Tensor(t1.apply(f)?),
37            FloatTensor::Float32Tensor(t1) => FloatTensor::Float32Tensor(t1.apply(f)?),
38        })
39    }
40    fn apply_xy(
41        &self,
42        rhs: &FloatTensor,
43        f: impl Fn(Self::Item, Self::Item) -> Self::Item,
44    ) -> infa_impl::Result<FloatTensor> {
45        Ok(match (self, rhs) {
46            #[cfg(feature = "gguf")]
47            (FloatTensor::GGUFFloatTensor(t1), FloatTensor::Float32Tensor(t2)) => {
48                FloatTensor::Float32Tensor(t1.apply_xy(t2, f)?)
49            }
50            #[cfg(feature = "gguf")]
51            (FloatTensor::Float32Tensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
52                FloatTensor::Float32Tensor(t2.apply_xy(t1, f)?)
53            }
54            #[cfg(feature = "gguf")]
55            (FloatTensor::GGUFFloatTensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
56                FloatTensor::Float32Tensor(t1.apply_xy(&t2.dequantize()?, f)?)
57            }
58            (FloatTensor::Float32Tensor(t1), FloatTensor::Float32Tensor(t2)) => {
59                FloatTensor::Float32Tensor(t1.apply_xy(t2, f)?)
60            }
61        })
62    }
63
64    fn matmul(&self, rhs: &FloatTensor) -> infa_impl::Result<FloatTensor> {
65        Ok(match (self, rhs) {
66            #[cfg(feature = "gguf")]
67            (FloatTensor::GGUFFloatTensor(t1), FloatTensor::Float32Tensor(t2)) => {
68                FloatTensor::Float32Tensor(t1.matmul(t2)?)
69            }
70            #[cfg(feature = "gguf")]
71            (FloatTensor::Float32Tensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
72                FloatTensor::Float32Tensor(t1.matmul(&t2.dequantize()?)?)
73            }
74            #[cfg(feature = "gguf")]
75            (FloatTensor::GGUFFloatTensor(t1), FloatTensor::GGUFFloatTensor(t2)) => {
76                FloatTensor::Float32Tensor(t1.matmul(&t2.dequantize()?)?)
77            }
78            (FloatTensor::Float32Tensor(t1), FloatTensor::Float32Tensor(t2)) => {
79                FloatTensor::Float32Tensor(t1.matmul(t2)?)
80            }
81        })
82    }
83}
84
85impl infa_impl::BaseTensorOps for FloatTensor {
86    type Item = f32;
87    fn shape(&self) -> &Vec<u64> {
88        match self {
89            #[cfg(feature = "gguf")]
90            FloatTensor::GGUFFloatTensor(t) => t.shape(),
91            FloatTensor::Float32Tensor(t) => t.shape(),
92        }
93    }
94
95    fn reshape(&self, shape: Vec<i64>) -> infa_impl::Result<Self> {
96        Ok(match self {
97            #[cfg(feature = "gguf")]
98            FloatTensor::GGUFFloatTensor(t) => FloatTensor::GGUFFloatTensor(t.reshape(shape)?),
99            FloatTensor::Float32Tensor(t) => FloatTensor::Float32Tensor(t.reshape(shape)?),
100        })
101    }
102    fn from_values(shape: Vec<u64>, values: Vec<Self::Item>) -> infa_impl::Result<Self> {
103        Float32Tensor::from_values(shape, values).map(FloatTensor::Float32Tensor)
104    }
105}
106
107impl<'a> std::ops::Add<&'a FloatTensor> for &'a FloatTensor {
108    type Output = infa_impl::Result<FloatTensor>;
109
110    fn add(self, rhs: Self) -> Self::Output {
111        TensorOps::add(self, rhs)
112    }
113}
114impl<'a> std::ops::Add<&'a f32> for &'a FloatTensor {
115    type Output = infa_impl::Result<FloatTensor>;
116
117    fn add(self, rhs: &f32) -> Self::Output {
118        TensorOps::add_item(self, rhs)
119    }
120}
121impl<'a> std::ops::Sub<&'a f32> for &'a FloatTensor {
122    type Output = infa_impl::Result<FloatTensor>;
123
124    fn sub(self, rhs: &f32) -> Self::Output {
125        TensorOps::sub_item(self, rhs)
126    }
127}
128impl<'a> std::ops::Sub<&'a FloatTensor> for &'a FloatTensor {
129    type Output = infa_impl::Result<FloatTensor>;
130
131    fn sub(self, rhs: &FloatTensor) -> Self::Output {
132        TensorOps::sub(self, rhs)
133    }
134}
135
136impl<'a> std::ops::Mul<&'a FloatTensor> for &'a FloatTensor {
137    type Output = infa_impl::Result<FloatTensor>;
138
139    fn mul(self, rhs: Self) -> Self::Output {
140        TensorOps::mul(self, rhs)
141    }
142}
143
144impl<'a> std::ops::Mul<&'a f32> for &'a FloatTensor {
145    type Output = infa_impl::Result<FloatTensor>;
146
147    fn mul(self, rhs: &f32) -> Self::Output {
148        TensorOps::mul_item(self, rhs)
149    }
150}
151impl<'a> std::ops::Div<&'a f32> for &'a FloatTensor {
152    type Output = infa_impl::Result<FloatTensor>;
153
154    fn div(self, rhs: &f32) -> Self::Output {
155        TensorOps::div_item(self, rhs)
156    }
157}
158impl<'a> std::ops::Div<&'a FloatTensor> for &'a FloatTensor {
159    type Output = infa_impl::Result<FloatTensor>;
160
161    fn div(self, rhs: Self) -> Self::Output {
162        TensorOps::div(self, rhs)
163    }
164}
165
166impl<'a> std::ops::Neg for &'a FloatTensor {
167    type Output = infa_impl::Result<FloatTensor>;
168
169    fn neg(self) -> Self::Output {
170        Ok(match self {
171            FloatTensor::Float32Tensor(s) => FloatTensor::Float32Tensor(s.neg()?),
172            #[cfg(feature = "gguf")]
173            FloatTensor::GGUFFloatTensor(s) => FloatTensor::Float32Tensor(s.neg()?),
174        })
175    }
176}