1use crate::core::constants::EPSILON;
4use crate::core::{Expression, Number};
5use crate::simplify::Simplify;
6
/// Stateless namespace type for this module's general helpers: expression
/// equality, zero detection and dimension validation.
///
/// The type holds no data; all helpers are associated functions. The common
/// traits are derived so the type can be printed, copied and default-built
/// like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct PartialUtils;
9
10impl PartialUtils {
11 pub fn expressions_equal(expr1: &Expression, expr2: &Expression) -> bool {
28 if std::ptr::eq(expr1, expr2) {
29 return true;
30 }
31
32 match (expr1, expr2) {
33 (Expression::Number(n1), Expression::Number(n2)) => n1 == n2,
34 (Expression::Symbol(s1), Expression::Symbol(s2)) => s1 == s2,
35 _ => format!("{:?}", expr1.simplify()) == format!("{:?}", expr2.simplify()),
36 }
37 }
38
39 pub fn is_zero(expr: &Expression) -> bool {
51 match expr {
52 Expression::Number(Number::Integer(0)) => true,
53 Expression::Number(Number::Float(f)) if f.abs() < EPSILON => true,
54 _ => matches!(expr.simplify(), Expression::Number(Number::Integer(0))),
55 }
56 }
57
/// Checks that `actual` equals `expected`.
///
/// # Errors
///
/// Returns a message of the form
/// `"<name>: dimension mismatch - expected <expected>, got <actual>"`
/// when the two values differ.
pub fn validate_dimensions(name: &str, expected: usize, actual: usize) -> Result<(), String> {
    if expected == actual {
        return Ok(());
    }

    Err(format!(
        "{}: dimension mismatch - expected {}, got {}",
        name, expected, actual
    ))
}
77}
78
/// Stateless namespace type for this module's matrix helpers (determinants
/// of matrices of `Expression`s).
///
/// The type holds no data; all helpers are associated functions. The common
/// traits are derived so the type can be printed, copied and default-built
/// like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatrixUtils;
81
82impl MatrixUtils {
83 pub fn determinant(matrix: &[Vec<Expression>]) -> Expression {
97 let n = matrix.len();
98 if n == 0 {
99 panic!("Matrix must be square and non-empty");
100 }
101
102 let expected_cols = matrix[0].len();
104 if expected_cols != n {
105 panic!("Matrix must be square and non-empty");
106 }
107
108 for row in matrix.iter() {
109 if row.len() != expected_cols {
110 panic!("Matrix must be square and non-empty");
111 }
112 }
113
114 match n {
115 1 => matrix[0][0].clone(),
116 2 => Self::det_2x2(matrix),
117 3 => Self::det_3x3(matrix),
118 _ => Self::det_symbolic(matrix),
119 }
120 }
121
122 fn det_2x2(matrix: &[Vec<Expression>]) -> Expression {
125 let ad = Expression::mul(vec![matrix[0][0].clone(), matrix[1][1].clone()]).simplify();
126 let bc = Expression::mul(vec![matrix[0][1].clone(), matrix[1][0].clone()]).simplify();
127 let neg_bc = Expression::mul(vec![Expression::integer(-1), bc]).simplify();
128
129 Expression::add(vec![ad, neg_bc]).simplify()
130 }
131
132 fn det_3x3(matrix: &[Vec<Expression>]) -> Expression {
134 let mut terms = Vec::with_capacity(3);
135
136 for i in 0..3 {
137 let sign = if i % 2 == 0 { 1 } else { -1 };
138 let cofactor = Self::cofactor_2x2(matrix, 0, i);
139 terms.push(Expression::mul(vec![
140 Expression::integer(sign),
141 matrix[0][i].clone(),
142 cofactor,
143 ]));
144 }
145
146 Expression::add(terms).simplify()
147 }
148
149 fn cofactor_2x2(matrix: &[Vec<Expression>], skip_row: usize, skip_col: usize) -> Expression {
151 let elements: Vec<Expression> = (0..3)
152 .filter(|&i| i != skip_row)
153 .flat_map(|i| {
154 (0..3)
155 .filter(|&j| j != skip_col)
156 .map(move |j| matrix[i][j].clone())
157 })
158 .collect();
159
160 let ad = Expression::mul(vec![elements[0].clone(), elements[3].clone()]).simplify();
162 let bc = Expression::mul(vec![elements[1].clone(), elements[2].clone()]).simplify();
163 let neg_bc = Expression::mul(vec![Expression::integer(-1), bc]).simplify();
164
165 Expression::add(vec![ad, neg_bc]).simplify()
166 }
167
168 fn det_symbolic(matrix: &[Vec<Expression>]) -> Expression {
170 Expression::function(
171 "det",
172 vec![Expression::function(
173 "matrix",
174 matrix.iter().flat_map(|row| row.iter().cloned()).collect(),
175 )],
176 )
177 }
178}
179
180#[cfg(test)]
181mod tests {
182 use super::*;
183 use crate::symbol;
184 use crate::Symbol;
185 use std::f64::consts::PI;
186
187 fn test_symbols() -> (Symbol, Symbol, Symbol) {
188 (symbol!(x), symbol!(y), symbol!(z))
189 }
190
#[test]
fn test_expression_equality() {
    let (x, y, _) = test_symbols();

    // Expressions built from the same symbol are equal; different symbols are not.
    let sym_x = Expression::symbol(x.clone());
    let sym_x_again = Expression::symbol(x);
    let sym_y = Expression::symbol(y);
    assert!(PartialUtils::expressions_equal(&sym_x, &sym_x_again));
    assert!(!PartialUtils::expressions_equal(&sym_x, &sym_y));

    // Integer literals compare by value.
    let forty_two = Expression::integer(42);
    assert!(PartialUtils::expressions_equal(
        &forty_two,
        &Expression::integer(42)
    ));
    assert!(!PartialUtils::expressions_equal(
        &forty_two,
        &Expression::integer(24)
    ));

    // Identical float literals are equal.
    assert!(PartialUtils::expressions_equal(
        &Expression::float(PI),
        &Expression::float(PI)
    ));
}
218
#[test]
fn test_complex_expression_equality() {
    let (x, _, _) = test_symbols();
    let var = || Expression::symbol(x.clone());

    // x + 1, built twice independently.
    let sum = || Expression::add(vec![var(), Expression::integer(1)]);
    assert!(PartialUtils::expressions_equal(&sum(), &sum()));

    // x ^ 2, built twice independently.
    let square = || Expression::pow(var(), Expression::integer(2));
    assert!(PartialUtils::expressions_equal(&square(), &square()));

    // 2 * x, built twice independently.
    let doubled = || Expression::mul(vec![Expression::integer(2), var()]);
    assert!(PartialUtils::expressions_equal(&doubled(), &doubled()));
}
238
#[test]
fn test_zero_detection() {
    // Literal zeros, integer and float.
    let zeros = vec![Expression::integer(0), Expression::float(0.0)];
    for zero in &zeros {
        assert!(PartialUtils::is_zero(zero));
    }

    // Non-zero integers, non-zero floats and a bare symbol.
    let non_zeros = vec![
        Expression::integer(1),
        Expression::integer(-5),
        Expression::float(PI),
        Expression::float(-2.71),
        Expression::symbol(symbol!(x)),
    ];
    for non_zero in &non_zeros {
        assert!(!PartialUtils::is_zero(non_zero));
    }
}
259
#[test]
fn test_zero_expressions() {
    let (x, _, _) = test_symbols();
    let var = Expression::symbol(x);

    // 0 + 0
    let zero_plus_zero = Expression::add(vec![Expression::integer(0), Expression::integer(0)]);
    assert!(PartialUtils::is_zero(&zero_plus_zero));

    // 0 * x
    let zero_times_x = Expression::mul(vec![Expression::integer(0), var.clone()]);
    assert!(PartialUtils::is_zero(&zero_times_x));

    // x + (-1) * x
    let negated_x = Expression::mul(vec![Expression::integer(-1), var.clone()]);
    let x_minus_x = Expression::add(vec![var, negated_x]);
    assert!(PartialUtils::is_zero(&x_minus_x));
}
280
#[test]
fn test_dimension_validation() {
    // Equal sizes are accepted, including zero.
    let matching = [("test", 3), ("gradient", 2), ("hessian", 4), ("empty", 0)];
    for &(label, size) in matching.iter() {
        assert!(PartialUtils::validate_dimensions(label, size, size).is_ok());
    }

    // A mismatch is reported with both sizes in the message.
    let message = PartialUtils::validate_dimensions("jacobian", 3, 2).unwrap_err();
    let fragments = ["dimension mismatch", "expected 3", "got 2"];
    for fragment in fragments.iter() {
        assert!(message.contains(*fragment));
    }

    // Expecting one element but receiving none is also a mismatch.
    assert!(PartialUtils::validate_dimensions("non-empty", 1, 0).is_err());
}
301
#[test]
fn test_1x1_determinant() {
    // The determinant of a 1x1 matrix is its only entry, numeric...
    let numeric = vec![vec![Expression::integer(5)]];
    assert_eq!(MatrixUtils::determinant(&numeric), Expression::integer(5));

    // ...or symbolic.
    let x = symbol!(x);
    let symbolic = vec![vec![Expression::symbol(x.clone())]];
    assert_eq!(MatrixUtils::determinant(&symbolic), Expression::symbol(x));
}
315
#[test]
fn test_2x2_determinant() {
    // Numeric case: 1 * 4 - 2 * 3 = -2.
    let numeric = vec![
        vec![Expression::integer(1), Expression::integer(2)],
        vec![Expression::integer(3), Expression::integer(4)],
    ];
    assert_eq!(
        MatrixUtils::determinant(&numeric).simplify(),
        Expression::integer(-2)
    );

    // Symbolic case: [[a, b], [c, d]] gives a * d + (-1) * (b * c).
    let sym = |s: &Symbol| Expression::symbol(s.clone());
    let (a, b, c, d) = (symbol!(a), symbol!(b), symbol!(c), symbol!(d));
    let general = vec![vec![sym(&a), sym(&b)], vec![sym(&c), sym(&d)]];
    let actual = MatrixUtils::determinant(&general);

    let a_times_d = Expression::mul(vec![sym(&a), sym(&d)]);
    let b_times_c = Expression::mul(vec![sym(&b), sym(&c)]);
    let expected = Expression::add(vec![
        a_times_d,
        Expression::mul(vec![Expression::integer(-1), b_times_c]),
    ]);
    assert_eq!(actual.simplify(), expected.simplify());
}
346
#[test]
fn test_3x3_determinant() {
    // Builds one row of three integer entries.
    let row = |a, b, c| {
        vec![
            Expression::integer(a),
            Expression::integer(b),
            Expression::integer(c),
        ]
    };

    // The identity matrix has determinant 1.
    let identity = vec![row(1, 0, 0), row(0, 1, 0), row(0, 0, 1)];
    assert_eq!(
        MatrixUtils::determinant(&identity).simplify(),
        Expression::integer(1)
    );

    // Rows 1..=9 in order are linearly dependent, so the determinant is 0.
    let singular = vec![row(1, 2, 3), row(4, 5, 6), row(7, 8, 9)];
    assert_eq!(
        MatrixUtils::determinant(&singular).simplify(),
        Expression::integer(0)
    );
}
395
#[test]
fn test_3x3_symbolic_determinant() {
    let (x, y, z) = test_symbols();
    let zero = || Expression::integer(0);

    // diag(x, y, z) should have determinant x * y * z.
    let diagonal = vec![
        vec![Expression::symbol(x.clone()), zero(), zero()],
        vec![zero(), Expression::symbol(y.clone()), zero()],
        vec![zero(), zero(), Expression::symbol(z.clone())],
    ];
    let actual = MatrixUtils::determinant(&diagonal);

    let expected = Expression::mul(vec![
        Expression::symbol(x),
        Expression::symbol(y),
        Expression::symbol(z),
    ]);
    assert_eq!(actual.simplify(), expected.simplify());
}
429
#[test]
fn test_large_matrix_symbolic() {
    // 4x4 matrix holding 1..=16 in row-major order.
    let matrix: Vec<Vec<Expression>> = (0..4)
        .map(|row| {
            (1..=4)
                .map(|col| Expression::integer(4 * row + col))
                .collect()
        })
        .collect();

    // Matrices larger than 3x3 come back as an unevaluated `det` function.
    let det = MatrixUtils::determinant(&matrix);
    if let Expression::Function { name, .. } = det {
        assert_eq!(name.as_ref(), "det");
    } else {
        panic!("Expected function call for large matrix determinant");
    }
}
470
#[test]
fn test_special_matrices() {
    // Builds one row of two integer entries.
    let row = |a, b| vec![Expression::integer(a), Expression::integer(b)];

    // The all-zero matrix has determinant 0.
    let zero_matrix = vec![row(0, 0), row(0, 0)];
    assert_eq!(
        MatrixUtils::determinant(&zero_matrix).simplify(),
        Expression::integer(0)
    );

    // Upper-triangular [[1, 2], [0, 3]]: product of the diagonal, 3.
    let upper_triangular = vec![row(1, 2), row(0, 3)];
    assert_eq!(
        MatrixUtils::determinant(&upper_triangular).simplify(),
        Expression::integer(3)
    );
}
491
#[test]
fn test_rational_determinant() {
    let frac = |numerator, denominator| Expression::rational(numerator, denominator);

    // (1/2)(1/5) - (1/3)(1/4) = 1/10 - 1/12 = 1/60.
    let matrix = vec![
        vec![frac(1, 2), frac(1, 3)],
        vec![frac(1, 4), frac(1, 5)],
    ];
    let actual = MatrixUtils::determinant(&matrix);

    assert_eq!(actual.simplify(), frac(1, 60).simplify());
}
506
#[test]
#[should_panic(expected = "Matrix must be square and non-empty")]
fn test_non_square_matrix_panic() {
    // Two rows of lengths 2 and 3: the second row breaks squareness.
    let first_row = vec![Expression::integer(1), Expression::integer(2)];
    let second_row = vec![
        Expression::integer(3),
        Expression::integer(4),
        Expression::integer(5),
    ];
    let ragged = vec![first_row, second_row];

    let _ = MatrixUtils::determinant(&ragged);
}
520
#[test]
#[should_panic(expected = "Matrix must be square and non-empty")]
fn test_empty_matrix_panic() {
    // A matrix with no rows at all must be rejected.
    let no_rows: Vec<Vec<Expression>> = Vec::new();

    let _ = MatrixUtils::determinant(&no_rows);
}
527}