1use std::fmt::Display;
9use std::ops::Add;
10use std::ops::AddAssign;
11use std::ops::Div;
12use std::ops::Mul;
13use std::ops::MulAssign;
14use std::ops::Neg;
15use std::ops::Sub;
16use std::ops::SubAssign;
17use std::str::FromStr;
18
19use super::PolynomialRingElement;
20
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22#[derive(Debug, Clone, PartialEq, Hash, Eq)]
23pub struct Matrix<T: PolynomialRingElement> {
24 pub dimensions: Vec<usize>,
26 pub values: Vec<T>,
27}
28
29impl<T: PolynomialRingElement> Matrix<T> {
30 pub fn len(&self) -> usize {
31 self.values.len()
32 }
33
34 #[allow(dead_code)]
35 pub fn is_empty(&self) -> bool {
36 self.values.is_empty()
37 }
38
39 pub fn mul_scalar(&self, v: T) -> Self {
40 let values = self.values.iter().map(|x| x.clone() * v.clone()).collect();
41 Matrix {
42 dimensions: self.dimensions.clone(),
43 values,
44 }
45 }
46
47 pub fn invert(&self) -> Self {
48 let values = self.values.iter().map(|x| T::one() / x.clone()).collect();
49 Matrix {
50 dimensions: self.dimensions.clone(),
51 values,
52 }
53 }
54
55 pub fn retrieve_indices(&self, indices: &[usize]) -> (Self, usize) {
58 let mul_sum = |vec: &Vec<usize>, start: usize| -> usize {
59 let mut out = 1;
60 for v in &vec[start..] {
61 out *= v;
62 }
63 out
64 };
65 let mut offset = 0;
66 for x in 0..indices.len() {
67 if x == indices.len() - 1 && indices.len() == self.dimensions.len() {
70 offset += indices[x];
71 } else {
72 offset += indices[x] * mul_sum(&self.dimensions, x + 1);
73 }
74 }
75
76 let mut new_dimensions = vec![];
77 for x in indices.len()..self.dimensions.len() {
78 new_dimensions.push(self.dimensions[x]);
79 }
80 if new_dimensions.is_empty() {
82 new_dimensions.push(1);
83 }
84 let offset_end = if indices.len() == self.dimensions.len() {
85 offset + 1
86 } else {
87 offset + mul_sum(&self.dimensions, indices.len())
88 };
89 (
90 Self {
91 dimensions: new_dimensions,
92 values: self.values[offset..offset_end].to_vec(),
93 },
94 offset,
95 )
96 }
97
98 pub fn _assert_internal_consistency(&self) {
99 assert_eq!(self.values.len(), self.dimensions.iter().product::<usize>());
100 }
101
102 pub fn assert_eq_shape(&self, m: &Matrix<T>) {
103 if self.dimensions.len() != m.dimensions.len() {
104 panic!("lhs and rhs dimensions are not equal: {:?} {:?}", self, m);
105 }
106 for x in 0..m.dimensions.len() {
107 if self.dimensions[x] != m.dimensions[x] {
108 panic!(
109 "lhs and rhs inner dimensions are not equal: {:?} {:?}",
110 self, m
111 );
112 }
113 }
114 }
115}
116
117impl<T: PolynomialRingElement> Add for Matrix<T> {
118 type Output = Self;
119
120 fn add(self, other: Self) -> Self {
121 self.assert_eq_shape(&other);
122 let values = self
123 .values
124 .iter()
125 .zip(other.values.iter())
126 .map(|(a, b)| a.clone() + b.clone())
127 .collect();
128 Matrix {
129 dimensions: self.dimensions,
130 values,
131 }
132 }
133}
134
135impl<T: PolynomialRingElement> AddAssign for Matrix<T> {
136 fn add_assign(&mut self, other: Self) {
137 self.assert_eq_shape(&other);
138 for i in 0..self.values.len() {
139 self.values[i] += other.values[i].clone();
140 }
141 }
142}
143
144impl<T: PolynomialRingElement> Sub for Matrix<T> {
145 type Output = Self;
146
147 fn sub(self, other: Self) -> Self {
148 self.assert_eq_shape(&other);
149 let values = self
150 .values
151 .iter()
152 .zip(other.values.iter())
153 .map(|(a, b)| a.clone() - b.clone())
154 .collect();
155 Matrix {
156 dimensions: self.dimensions,
157 values,
158 }
159 }
160}
161
162impl<T: PolynomialRingElement> SubAssign for Matrix<T> {
163 fn sub_assign(&mut self, other: Self) {
164 self.assert_eq_shape(&other);
165 for i in 0..self.values.len() {
166 self.values[i] -= other.values[i].clone();
167 }
168 }
169}
170
171impl<T: PolynomialRingElement> Mul for Matrix<T> {
172 type Output = Self;
173
174 fn mul(self, other: Self) -> Self {
175 self.assert_eq_shape(&other);
176 let values = self
177 .values
178 .iter()
179 .zip(other.values.iter())
180 .map(|(a, b)| a.clone() * b.clone())
181 .collect();
182 Matrix {
183 dimensions: self.dimensions,
184 values,
185 }
186 }
187}
188
189impl<T: PolynomialRingElement> MulAssign for Matrix<T> {
190 fn mul_assign(&mut self, other: Self) {
191 self.assert_eq_shape(&other);
192 for i in 0..self.values.len() {
193 self.values[i] *= other.values[i].clone();
194 }
195 }
196}
197
198impl<T: PolynomialRingElement> Div for Matrix<T> {
199 type Output = Self;
200
201 fn div(self, other: Self) -> Self {
202 self.assert_eq_shape(&other);
203 let values = self
204 .values
205 .iter()
206 .zip(other.values.iter())
207 .map(|(a, b)| a.clone() / b.clone())
208 .collect();
209 Matrix {
210 dimensions: self.dimensions,
211 values,
212 }
213 }
214}
215
216impl<T: PolynomialRingElement> Neg for Matrix<T> {
217 type Output = Self;
218
219 fn neg(self) -> Self {
220 let values = self.values.iter().map(|x| -x.clone()).collect();
221 Matrix {
222 dimensions: self.dimensions,
223 values,
224 }
225 }
226}
227
228impl<T: PolynomialRingElement> From<T> for Matrix<T> {
229 fn from(v: T) -> Self {
230 Matrix {
231 dimensions: vec![1],
232 values: vec![v],
233 }
234 }
235}
236
237impl<T: PolynomialRingElement> From<u64> for Matrix<T> {
238 fn from(v: u64) -> Self {
239 Matrix::from(T::from(v))
240 }
241}
242
243impl<T: PolynomialRingElement> FromStr for Matrix<T> {
244 type Err = T::Err;
245
246 fn from_str(s: &str) -> Result<Self, Self::Err> {
247 Ok(Matrix::from(T::from_str(s)?))
248 }
249}
250
251impl<T: PolynomialRingElement> Display for Matrix<T> {
252 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
254 let mut s = String::new();
255 s.push_str(&format!(
256 "dimensions: {}\n",
257 self.dimensions
258 .clone()
259 .into_iter()
260 .map(|x| x.to_string())
261 .collect::<Vec<_>>()
262 .join("x")
263 ));
264 for i in 0..self.values.len() {
265 s.push_str(&format!("{}, ", self.values[i]));
266 }
267 write!(f, "{}", s)
268 }
269}