mathhook_core/calculus/ode/
systems.rs1use crate::algebra::solvers::{linear::LinearSolver, EquationSolver, SolverResult};
10use crate::calculus::ode::first_order::ODEError;
11use crate::core::{Expression, Symbol};
12use crate::matrices::Matrix;
13use crate::simplify::Simplify;
14use std::collections::HashMap;
15
16pub struct LinearSystemSolver;
20
21impl LinearSystemSolver {
22 pub fn solve(
59 &self,
60 coefficient_matrix: &Matrix,
61 independent_var: &Symbol,
62 initial_conditions: Option<Vec<Expression>>,
63 ) -> Result<Vec<Expression>, ODEError> {
64 let (rows, cols) = coefficient_matrix.dimensions();
65
66 if rows != cols {
67 return Err(ODEError::NotLinearForm {
68 reason: format!("Coefficient matrix must be square, got {}×{}", rows, cols),
69 });
70 }
71
72 let n = rows;
73
74 if !coefficient_matrix.is_diagonalizable() {
75 return Err(ODEError::NotImplemented {
76 feature: "Non-diagonalizable systems (requires Jordan normal form)".to_owned(),
77 });
78 }
79
80 let eigen_decomp =
81 coefficient_matrix
82 .eigen_decomposition()
83 .ok_or_else(|| ODEError::NotImplemented {
84 feature: "Eigendecomposition failed".to_owned(),
85 })?;
86
87 let eigenvalues = &eigen_decomp.eigenvalues;
88 let eigenvectors = &eigen_decomp.eigenvectors;
89
90 let solution_components: Vec<Vec<Expression>> = eigenvalues
91 .iter()
92 .enumerate()
93 .map(|(i, lambda)| {
94 let eigenvector_col: Vec<Expression> = (0..n)
95 .map(|row_idx| eigenvectors.get_element(row_idx, i))
96 .collect();
97
98 let exponent = Expression::mul(vec![
99 lambda.clone(),
100 Expression::symbol(independent_var.clone()),
101 ]);
102 let exp_term = Expression::function("exp", vec![exponent]);
103
104 let c_symbol = Symbol::new(format!("C{}", i + 1));
105 let c = Expression::symbol(c_symbol);
106
107 let scaled_exp = Expression::mul(vec![c, exp_term]);
108
109 eigenvector_col
110 .into_iter()
111 .map(|component| Expression::mul(vec![scaled_exp.clone(), component]))
112 .collect()
113 })
114 .collect();
115
116 let final_solution: Vec<Expression> = (0..n)
117 .map(|i| {
118 let sum_terms: Vec<Expression> = solution_components
119 .iter()
120 .map(|comp| comp[i].clone())
121 .collect();
122 Expression::add(sum_terms).simplify()
123 })
124 .collect();
125
126 if let Some(ic) = initial_conditions {
127 return self.apply_initial_conditions(&final_solution, &ic, n, eigenvectors);
128 }
129
130 Ok(final_solution)
131 }
132
133 fn apply_initial_conditions(
140 &self,
141 general_solution: &[Expression],
142 initial_conditions: &[Expression],
143 n: usize,
144 eigenvectors: &Matrix,
145 ) -> Result<Vec<Expression>, ODEError> {
146 if initial_conditions.len() != n {
147 return Err(ODEError::NotLinearForm {
148 reason: format!(
149 "Initial conditions length {} does not match system size {}",
150 initial_conditions.len(),
151 n
152 ),
153 });
154 }
155
156 let linear_solver = LinearSolver::new_fast();
157 let mut constant_values: HashMap<String, Expression> = HashMap::new();
158
159 for i in 0..n {
160 let constant_name = format!("C{}", i + 1);
161 let equation = self.build_constant_equation(i, n, eigenvectors, initial_conditions);
162
163 let substituted_equation = if i == 0 {
164 equation
165 } else {
166 equation.substitute(&constant_values).simplify()
167 };
168
169 let constant_symbol = Symbol::new(&constant_name);
170 let value = self.solve_for_constant(
171 &linear_solver,
172 &substituted_equation,
173 &constant_symbol,
174 &constant_name,
175 )?;
176
177 constant_values.insert(constant_name, value);
178 }
179
180 let particular_solution: Vec<Expression> = general_solution
181 .iter()
182 .map(|expr| expr.substitute(&constant_values).simplify())
183 .collect();
184
185 Ok(particular_solution)
186 }
187
188 fn build_constant_equation(
192 &self,
193 row_index: usize,
194 n: usize,
195 eigenvectors: &Matrix,
196 initial_conditions: &[Expression],
197 ) -> Expression {
198 let mut equation_terms = Vec::new();
199
200 for j in 0..n {
201 let eigenvector_component = eigenvectors.get_element(row_index, j);
202 let c_symbol = Symbol::new(format!("C{}", j + 1));
203 let term = Expression::mul(vec![eigenvector_component, Expression::symbol(c_symbol)]);
204 equation_terms.push(term);
205 }
206
207 equation_terms.push(Expression::mul(vec![
208 Expression::integer(-1),
209 initial_conditions[row_index].clone(),
210 ]));
211
212 Expression::add(equation_terms)
213 }
214
215 fn solve_for_constant(
219 &self,
220 solver: &LinearSolver,
221 equation: &Expression,
222 variable: &Symbol,
223 constant_name: &str,
224 ) -> Result<Expression, ODEError> {
225 match solver.solve(equation, variable) {
226 SolverResult::Single(value) => Ok(value),
227 SolverResult::NoSolution => Err(ODEError::NotLinearForm {
228 reason: format!(
229 "No solution for integration constant {} (inconsistent system)",
230 constant_name
231 ),
232 }),
233 SolverResult::InfiniteSolutions => Err(ODEError::NotLinearForm {
234 reason: format!(
235 "Infinite solutions for integration constant {} (underdetermined)",
236 constant_name
237 ),
238 }),
239 SolverResult::Multiple(_) => Err(ODEError::NotLinearForm {
240 reason: format!(
241 "Multiple solutions for integration constant {}",
242 constant_name
243 ),
244 }),
245 SolverResult::Parametric(_) => Err(ODEError::NotLinearForm {
246 reason: format!(
247 "Parametric solutions not supported for integration constant {}",
248 constant_name
249 ),
250 }),
251 SolverResult::Partial(_) => Err(ODEError::NotLinearForm {
252 reason: format!(
253 "Partial solutions not supported for integration constant {}",
254 constant_name
255 ),
256 }),
257 }
258 }
259
260 pub fn solve_2x2(
295 &self,
296 a11: &Expression,
297 a12: &Expression,
298 a21: &Expression,
299 a22: &Expression,
300 independent_var: &Symbol,
301 ) -> Result<Vec<Expression>, ODEError> {
302 let matrix = Matrix::dense(vec![
303 vec![a11.clone(), a12.clone()],
304 vec![a21.clone(), a22.clone()],
305 ]);
306
307 self.solve(&matrix, independent_var, None)
308 }
309
310 pub fn solve_3x3(
346 &self,
347 matrix_entries: &[Expression; 9],
348 independent_var: &Symbol,
349 ) -> Result<Vec<Expression>, ODEError> {
350 let matrix = Matrix::dense(vec![
351 vec![
352 matrix_entries[0].clone(),
353 matrix_entries[1].clone(),
354 matrix_entries[2].clone(),
355 ],
356 vec![
357 matrix_entries[3].clone(),
358 matrix_entries[4].clone(),
359 matrix_entries[5].clone(),
360 ],
361 vec![
362 matrix_entries[6].clone(),
363 matrix_entries[7].clone(),
364 matrix_entries[8].clone(),
365 ],
366 ]);
367
368 self.solve(&matrix, independent_var, None)
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Substitutes `t = 0` into each component of `solution` and simplifies the result.
    fn evaluate_at_zero(solution: &[Expression], t: &Symbol) -> Vec<Expression> {
        let mut at_origin = HashMap::new();
        at_origin.insert(t.name().to_string(), expr!(0));
        solution
            .iter()
            .map(|component| component.substitute(&at_origin).simplify())
            .collect()
    }

    #[test]
    fn test_diagonal_2x2_system() {
        let t = symbol!(t);
        let solver = LinearSystemSolver;
        let outcome = solver.solve(&Matrix::diagonal(vec![expr!(1), expr!(2)]), &t, None);

        assert!(outcome.is_ok(), "Should solve diagonal system");
        assert_eq!(
            outcome.unwrap().len(),
            2,
            "Should have 2 solution components"
        );
    }

    #[test]
    fn test_non_square_matrix_error() {
        let t = symbol!(t);
        let tall = Matrix::dense(vec![
            vec![expr!(1), expr!(0)],
            vec![expr!(0), expr!(2)],
            vec![expr!(1), expr!(1)],
        ]);

        let outcome = LinearSystemSolver.solve(&tall, &t, None);

        assert!(outcome.is_err(), "Should reject non-square matrix");
    }

    #[test]
    fn test_2x2_system_with_initial_conditions() {
        let t = symbol!(t);
        let system = Matrix::diagonal(vec![expr!(1), expr!(2)]);
        let solver = LinearSystemSolver;

        let solution = solver.solve(&system, &t, Some(vec![expr!(3), expr!(4)]));
        assert!(
            solution.is_ok(),
            "Should solve system with initial conditions: {:?}",
            solution.err()
        );

        let components = solution.unwrap();
        assert_eq!(components.len(), 2, "Should have 2 solution components");

        let at_zero = evaluate_at_zero(&components, &t);
        assert_eq!(
            at_zero[0].simplify(),
            expr!(3),
            "First component at t=0 should be 3"
        );
        assert_eq!(
            at_zero[1].simplify(),
            expr!(4),
            "Second component at t=0 should be 4"
        );
    }

    #[test]
    fn test_2x2_system_zero_initial_conditions() {
        let t = symbol!(t);
        let system = Matrix::diagonal(vec![expr!(1), expr!(2)]);
        let solver = LinearSystemSolver;

        let solution = solver.solve(&system, &t, Some(vec![expr!(0), expr!(0)]));
        assert!(
            solution.is_ok(),
            "Should solve with zero initial conditions"
        );

        let at_zero = evaluate_at_zero(&solution.unwrap(), &t);
        assert_eq!(
            at_zero[0].simplify(),
            expr!(0),
            "First component at t=0 should be 0"
        );
        assert_eq!(
            at_zero[1].simplify(),
            expr!(0),
            "Second component at t=0 should be 0"
        );
    }

    #[test]
    fn test_wrong_size_initial_conditions() {
        let t = symbol!(t);
        let system = Matrix::diagonal(vec![expr!(1), expr!(2)]);
        let too_many = vec![expr!(1), expr!(2), expr!(3)];

        let outcome = LinearSystemSolver.solve(&system, &t, Some(too_many));

        assert!(
            outcome.is_err(),
            "Should reject mismatched initial condition size"
        );
    }
}