mathhook_core/matrices/unified/
operations.rs1use crate::core::Expression;
4use crate::error::MathError;
5use crate::matrices::types::*;
6use crate::matrices::unified::Matrix;
7use crate::simplify::Simplify;
8
9impl Matrix {
10 #[inline]
12 pub fn trace(&self) -> Expression {
13 match self {
14 Matrix::Identity(data) => Expression::integer(data.size as i64),
15 Matrix::Zero(_) => Expression::integer(0),
16 Matrix::Scalar(data) => Expression::mul(vec![
17 Expression::integer(data.size as i64),
18 data.scalar_value.clone(),
19 ])
20 .simplify(),
21 Matrix::Diagonal(data) => Expression::add(data.diagonal_elements.clone()).simplify(),
22 _ => {
23 let (rows, _) = self.dimensions();
24 let diagonal_elements: Vec<Expression> =
25 (0..rows).map(|i| self.get_element(i, i)).collect();
26 Expression::add(diagonal_elements).simplify()
27 }
28 }
29 }
30
31 pub fn determinant(&self) -> Result<Expression, MathError> {
45 let (rows, cols) = self.dimensions();
46
47 if rows != cols {
48 return Err(MathError::DomainError {
49 operation: "determinant".to_string(),
50 value: Expression::function("matrix", vec![]),
51 reason: format!("Determinant requires square matrix, got {}x{}", rows, cols),
52 });
53 }
54
55 match self {
56 Matrix::Identity(_) => return Ok(Expression::integer(1)),
57 Matrix::Zero(_) => return Ok(Expression::integer(0)),
58 Matrix::Scalar(data) => {
59 return Ok(Expression::pow(
60 data.scalar_value.clone(),
61 Expression::integer(data.size as i64),
62 )
63 .simplify())
64 }
65 Matrix::Diagonal(data) => {
66 return Ok(Expression::mul(data.diagonal_elements.clone()).simplify())
67 }
68 Matrix::Permutation(_data) => return Ok(Expression::integer(1)),
69 _ => {}
70 }
71
72 if let Some(numeric) = self.as_numeric() {
73 let det = numeric.determinant()?;
74 return Ok(Expression::float(det));
75 }
76
77 if rows == 0 {
78 return Ok(Expression::integer(1));
79 }
80
81 if rows == 1 {
82 return Ok(self.get_element(0, 0));
83 }
84
85 if rows == 2 {
86 let a = self.get_element(0, 0);
87 let b = self.get_element(0, 1);
88 let c = self.get_element(1, 0);
89 let d = self.get_element(1, 1);
90
91 let ad = Expression::mul(vec![a, d]);
92 let bc = Expression::mul(vec![b, c]);
93 return Ok(Expression::add(vec![
94 ad,
95 Expression::mul(vec![Expression::integer(-1), bc]),
96 ])
97 .simplify());
98 }
99
100 Ok(self.determinant_lu())
101 }
102
103 fn determinant_lu(&self) -> Expression {
109 let (n, _) = self.dimensions();
110
111 let mut a: Vec<Vec<Expression>> = (0..n)
112 .map(|i| (0..n).map(|j| self.get_element(i, j)).collect())
113 .collect();
114
115 let mut sign = 1i64;
116
117 for k in 0..n {
118 let pivot = ((k + 1)..n).find(|&i| !a[i][k].is_zero_fast()).unwrap_or(k);
119
120 if pivot != k {
121 a.swap(k, pivot);
122 sign = -sign;
123 }
124
125 if a[k][k].is_zero_fast() {
126 return Expression::integer(0);
127 }
128
129 for i in (k + 1)..n {
130 let factor = Expression::mul(vec![
131 a[i][k].clone(),
132 Expression::pow(a[k][k].clone(), Expression::integer(-1)),
133 ]);
134
135 let pivot_row: Vec<Expression> = a[k][k..n].to_vec();
136 for (j_offset, pivot_val) in pivot_row.into_iter().enumerate() {
137 let j = k + j_offset;
138 let subtraction = Expression::mul(vec![factor.clone(), pivot_val]);
139 a[i][j] = a[i][j].clone() - subtraction;
140 }
141 }
142 }
143
144 let det_u: Vec<Expression> = (0..n).map(|i| a[i][i].clone()).collect();
145
146 Expression::mul(vec![Expression::integer(sign), Expression::mul(det_u)])
147 }
148
149 pub fn scalar_multiply(&self, scalar: &Expression) -> Matrix {
151 match self {
152 Matrix::Zero(data) => Matrix::Zero(data.clone()),
153 Matrix::Identity(data) => Matrix::Scalar(ScalarMatrixData {
154 size: data.size,
155 scalar_value: scalar.clone(),
156 }),
157 Matrix::Scalar(data) => Matrix::Scalar(ScalarMatrixData {
158 size: data.size,
159 scalar_value: Expression::mul(vec![scalar.clone(), data.scalar_value.clone()])
160 .simplify(),
161 }),
162 Matrix::Diagonal(data) => {
163 let scaled_elements: Vec<Expression> = data
164 .diagonal_elements
165 .iter()
166 .map(|elem| Expression::mul(vec![scalar.clone(), elem.clone()]).simplify())
167 .collect();
168 Matrix::Diagonal(DiagonalMatrixData {
169 diagonal_elements: scaled_elements,
170 })
171 }
172 _ => {
173 let (rows, cols) = self.dimensions();
174 let mut result_rows = Vec::with_capacity(rows);
175 for i in 0..rows {
176 let mut row = Vec::with_capacity(cols);
177 for j in 0..cols {
178 let elem = self.get_element(i, j);
179 let product = Expression::mul(vec![scalar.clone(), elem]).simplify();
180 row.push(product);
181 }
182 result_rows.push(row);
183 }
184 Matrix::Dense(MatrixData { rows: result_rows }).optimize()
185 }
186 }
187 }
188}
189
190pub trait CoreMatrixOps {
192 fn add(&self, other: &Matrix) -> Result<Matrix, MathError>;
193 fn multiply(&self, other: &Matrix) -> Result<Matrix, MathError>;
194 fn transpose(&self) -> Matrix;
195 fn inverse(&self) -> Matrix;
196}
197
198impl CoreMatrixOps for Matrix {
199 fn add(&self, other: &Matrix) -> Result<Matrix, MathError> {
200 let (rows1, cols1) = self.dimensions();
201 let (rows2, cols2) = other.dimensions();
202
203 if rows1 != rows2 || cols1 != cols2 {
204 return Err(MathError::DomainError {
205 operation: "matrix_addition".to_string(),
206 value: Expression::function("incompatible_matrices", vec![]),
207 reason: format!(
208 "Cannot add {}x{} matrix to {}x{} matrix",
209 rows1, cols1, rows2, cols2
210 ),
211 });
212 }
213
214 let result = match (self, other) {
215 (Matrix::Zero(_), other) => other.clone(),
216 (this, Matrix::Zero(_)) => this.clone(),
217
218 (Matrix::Identity(id), Matrix::Dense(dense))
219 | (Matrix::Dense(dense), Matrix::Identity(id)) => {
220 let mut result_rows = dense.rows.clone();
221 for i in 0..id.size.min(result_rows.len()) {
222 if let Some(row) = result_rows.get_mut(i) {
223 if let Some(elem) = row.get_mut(i) {
224 *elem = Expression::add(vec![elem.clone(), Expression::integer(1)]);
225 }
226 }
227 }
228 Matrix::Dense(MatrixData { rows: result_rows })
229 }
230
231 (Matrix::Diagonal(d1), Matrix::Diagonal(d2))
232 if d1.diagonal_elements.len() == d2.diagonal_elements.len() =>
233 {
234 let result_elements: Vec<Expression> = d1
235 .diagonal_elements
236 .iter()
237 .zip(d2.diagonal_elements.iter())
238 .map(|(a, b)| Expression::add(vec![a.clone(), b.clone()]).simplify())
239 .collect();
240 Matrix::Diagonal(DiagonalMatrixData {
241 diagonal_elements: result_elements,
242 })
243 }
244
245 (Matrix::Identity(id), Matrix::Diagonal(diag))
246 | (Matrix::Diagonal(diag), Matrix::Identity(id))
247 if diag.diagonal_elements.len() == id.size =>
248 {
249 let result_elements: Vec<Expression> = diag
250 .diagonal_elements
251 .iter()
252 .map(|elem| {
253 Expression::add(vec![elem.clone(), Expression::integer(1)]).simplify()
254 })
255 .collect();
256 Matrix::Diagonal(DiagonalMatrixData {
257 diagonal_elements: result_elements,
258 })
259 }
260
261 (Matrix::Scalar(s1), Matrix::Scalar(s2)) if s1.size == s2.size => {
262 Matrix::Scalar(ScalarMatrixData {
263 size: s1.size,
264 scalar_value: Expression::add(vec![
265 s1.scalar_value.clone(),
266 s2.scalar_value.clone(),
267 ])
268 .simplify(),
269 })
270 }
271
272 _ => {
273 let mut result_rows = Vec::with_capacity(rows1);
274 for i in 0..rows1 {
275 let mut row = Vec::with_capacity(cols1);
276 for j in 0..cols1 {
277 let elem1 = self.get_element(i, j);
278 let elem2 = other.get_element(i, j);
279 let sum = Expression::add(vec![elem1, elem2]).simplify();
280 row.push(sum);
281 }
282 result_rows.push(row);
283 }
284
285 Matrix::Dense(MatrixData { rows: result_rows }).optimize()
286 }
287 };
288
289 Ok(result)
290 }
291
292 fn multiply(&self, other: &Matrix) -> Result<Matrix, MathError> {
293 let (rows1, cols1) = self.dimensions();
294 let (rows2, cols2) = other.dimensions();
295
296 if cols1 != rows2 {
297 return Err(MathError::DomainError {
298 operation: "matrix_multiplication".to_string(),
299 value: Expression::function("incompatible_matrices", vec![]),
300 reason: format!(
301 "Cannot multiply {}x{} matrix by {}x{} matrix (inner dimensions {} != {})",
302 rows1, cols1, rows2, cols2, cols1, rows2
303 ),
304 });
305 }
306
307 let result = match (self, other) {
308 (Matrix::Zero(_), _) => Matrix::Zero(ZeroMatrixData {
309 rows: rows1,
310 cols: cols2,
311 }),
312 (_, Matrix::Zero(_)) => Matrix::Zero(ZeroMatrixData {
313 rows: rows1,
314 cols: cols2,
315 }),
316
317 (Matrix::Identity(_), other) => other.clone(),
318 (this, Matrix::Identity(_)) => this.clone(),
319
320 (Matrix::Diagonal(d1), Matrix::Diagonal(d2))
321 if d1.diagonal_elements.len() == d2.diagonal_elements.len() =>
322 {
323 let result_elements: Vec<Expression> = d1
324 .diagonal_elements
325 .iter()
326 .zip(d2.diagonal_elements.iter())
327 .map(|(a, b)| Expression::mul(vec![a.clone(), b.clone()]))
328 .collect();
329 Matrix::Diagonal(DiagonalMatrixData {
330 diagonal_elements: result_elements,
331 })
332 }
333
334 (Matrix::Scalar(s1), Matrix::Scalar(s2)) if s1.size == s2.size => {
335 let product_scalar =
336 Expression::mul(vec![s1.scalar_value.clone(), s2.scalar_value.clone()]);
337 Matrix::Scalar(ScalarMatrixData {
338 size: s1.size,
339 scalar_value: product_scalar,
340 })
341 }
342 (Matrix::Scalar(s), other) => other.scalar_multiply(&s.scalar_value),
343 (this, Matrix::Scalar(s)) => this.scalar_multiply(&s.scalar_value),
344
345 _ => {
346 if let (Some(num1), Some(num2)) = (self.as_numeric(), other.as_numeric()) {
347 let result = num1.multiply(&num2)?;
348 return Ok(result.to_matrix().optimize());
349 }
350
351 let mut result_rows = Vec::with_capacity(rows1);
352 for i in 0..rows1 {
353 let mut row = Vec::with_capacity(cols2);
354 for j in 0..cols2 {
355 let mut sum_terms = Vec::with_capacity(cols1);
356 for k in 0..cols1 {
357 let elem1 = self.get_element(i, k);
358 let elem2 = other.get_element(k, j);
359 sum_terms.push(Expression::mul(vec![elem1, elem2]));
360 }
361 let sum = Expression::add(sum_terms);
362 row.push(sum);
363 }
364 result_rows.push(row);
365 }
366
367 Matrix::Dense(MatrixData { rows: result_rows })
368 }
369 };
370
371 Ok(result)
372 }
373
374 fn transpose(&self) -> Matrix {
375 match self {
376 Matrix::Identity(data) => Matrix::Identity(data.clone()),
377 Matrix::Zero(data) => Matrix::Zero(ZeroMatrixData {
378 rows: data.cols,
379 cols: data.rows,
380 }),
381 Matrix::Scalar(data) => Matrix::Scalar(data.clone()),
382 Matrix::Diagonal(data) => Matrix::Diagonal(data.clone()),
383 Matrix::Symmetric(data) => Matrix::Symmetric(data.clone()),
384 Matrix::UpperTriangular(data) => Matrix::LowerTriangular(LowerTriangularMatrixData {
385 size: data.size,
386 elements: data.elements.clone(),
387 }),
388 Matrix::LowerTriangular(data) => Matrix::UpperTriangular(UpperTriangularMatrixData {
389 size: data.size,
390 elements: data.elements.clone(),
391 }),
392 _ => {
393 let (rows, cols) = self.dimensions();
394 let mut result_rows = Vec::with_capacity(cols);
395 for j in 0..cols {
396 let mut row = Vec::with_capacity(rows);
397 for i in 0..rows {
398 row.push(self.get_element(i, j));
399 }
400 result_rows.push(row);
401 }
402 Matrix::Dense(MatrixData { rows: result_rows }).optimize()
403 }
404 }
405 }
406
407 fn inverse(&self) -> Matrix {
408 match self {
409 Matrix::Identity(data) => Matrix::Identity(data.clone()),
410 Matrix::Scalar(data) => {
411 let inverse_scalar =
412 Expression::pow(data.scalar_value.clone(), Expression::integer(-1)).simplify();
413 Matrix::Scalar(ScalarMatrixData {
414 size: data.size,
415 scalar_value: inverse_scalar,
416 })
417 }
418 Matrix::Diagonal(data) => {
419 let inverse_elements: Vec<Expression> = data
420 .diagonal_elements
421 .iter()
422 .map(|elem| Expression::pow(elem.clone(), Expression::integer(-1)).simplify())
423 .collect();
424 Matrix::Diagonal(DiagonalMatrixData {
425 diagonal_elements: inverse_elements,
426 })
427 }
428 _ => {
429 if let Some(numeric) = self.as_numeric() {
430 if let Ok(inv) = numeric.inverse() {
431 return inv.to_matrix();
432 }
433 }
434
435 if let Some(inv) = self.inverse_via_lu() {
436 inv
437 } else {
438 self.gauss_jordan_inverse()
439 }
440 }
441 }
442 }
443}