mathhook_core/calculus/ode/
systems.rs

1//! Linear system of ODEs solver
2//!
3//! Solves systems of first-order linear ODEs with constant coefficients:
4//! dx/dt = Ax where A is a constant matrix
5//!
6//! Uses eigenvalue-eigenvector method for diagonalizable systems:
7//! x(t) = c₁e^(λ₁t)v₁ + c₂e^(λ₂t)v₂ + ... + cₙe^(λₙt)vₙ
8
9use crate::algebra::solvers::{linear::LinearSolver, EquationSolver, SolverResult};
10use crate::calculus::ode::first_order::ODEError;
11use crate::core::{Expression, Symbol};
12use crate::matrices::Matrix;
13use crate::simplify::Simplify;
14use std::collections::HashMap;
15
/// Solver for linear systems of first-order ODEs
///
/// Solves systems dx/dt = Ax where A is a constant coefficient matrix.
///
/// The type is a stateless unit struct, so it is cheap to copy and can be
/// created either as `LinearSystemSolver` or via `LinearSystemSolver::default()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LinearSystemSolver;
20
21impl LinearSystemSolver {
22    /// Solve linear system dx/dt = Ax
23    ///
24    /// Uses eigenvalue-eigenvector method. For an n×n system:
25    /// - Compute eigenvalues λ₁, λ₂, ..., λₙ and eigenvectors v₁, v₂, ..., vₙ
26    /// - General solution: x(t) = c₁e^(λ₁t)v₁ + c₂e^(λ₂t)v₂ + ... + cₙe^(λₙt)vₙ
27    ///
28    /// # Complexity
29    ///
30    /// * **Time:** O(n³) for eigenvalue decomposition of n×n matrix
31    /// * **Space:** O(n²) for storing eigenvectors and intermediate results
32    ///
33    /// # Arguments
34    ///
35    /// * `coefficient_matrix` - The constant coefficient matrix A
36    /// * `independent_var` - The independent variable (typically t)
37    /// * `initial_conditions` - Optional initial state vector x(t₀) = x₀
38    ///
39    /// # Returns
40    ///
41    /// Vector of expressions representing the solution [x₁(t), x₂(t), ..., xₙ(t)]
42    ///
43    /// # Examples
44    ///
45    /// ```rust
46    /// use mathhook_core::calculus::ode::systems::LinearSystemSolver;
47    /// use mathhook_core::matrices::Matrix;
48    /// use mathhook_core::{symbol, expr};
49    ///
50    /// let t = symbol!(t);
51    ///
52    /// // 2×2 system: dx/dt = [1 0; 0 2]x
53    /// let matrix = Matrix::diagonal(vec![expr!(1), expr!(2)]);
54    ///
55    /// let solver = LinearSystemSolver;
56    /// let solution = solver.solve(&matrix, &t, None);
57    /// ```
58    pub fn solve(
59        &self,
60        coefficient_matrix: &Matrix,
61        independent_var: &Symbol,
62        initial_conditions: Option<Vec<Expression>>,
63    ) -> Result<Vec<Expression>, ODEError> {
64        let (rows, cols) = coefficient_matrix.dimensions();
65
66        if rows != cols {
67            return Err(ODEError::NotLinearForm {
68                reason: format!("Coefficient matrix must be square, got {}×{}", rows, cols),
69            });
70        }
71
72        let n = rows;
73
74        if !coefficient_matrix.is_diagonalizable() {
75            return Err(ODEError::NotImplemented {
76                feature: "Non-diagonalizable systems (requires Jordan normal form)".to_owned(),
77            });
78        }
79
80        let eigen_decomp =
81            coefficient_matrix
82                .eigen_decomposition()
83                .ok_or_else(|| ODEError::NotImplemented {
84                    feature: "Eigendecomposition failed".to_owned(),
85                })?;
86
87        let eigenvalues = &eigen_decomp.eigenvalues;
88        let eigenvectors = &eigen_decomp.eigenvectors;
89
90        let solution_components: Vec<Vec<Expression>> = eigenvalues
91            .iter()
92            .enumerate()
93            .map(|(i, lambda)| {
94                let eigenvector_col: Vec<Expression> = (0..n)
95                    .map(|row_idx| eigenvectors.get_element(row_idx, i))
96                    .collect();
97
98                let exponent = Expression::mul(vec![
99                    lambda.clone(),
100                    Expression::symbol(independent_var.clone()),
101                ]);
102                let exp_term = Expression::function("exp", vec![exponent]);
103
104                let c_symbol = Symbol::new(format!("C{}", i + 1));
105                let c = Expression::symbol(c_symbol);
106
107                let scaled_exp = Expression::mul(vec![c, exp_term]);
108
109                eigenvector_col
110                    .into_iter()
111                    .map(|component| Expression::mul(vec![scaled_exp.clone(), component]))
112                    .collect()
113            })
114            .collect();
115
116        let final_solution: Vec<Expression> = (0..n)
117            .map(|i| {
118                let sum_terms: Vec<Expression> = solution_components
119                    .iter()
120                    .map(|comp| comp[i].clone())
121                    .collect();
122                Expression::add(sum_terms).simplify()
123            })
124            .collect();
125
126        if let Some(ic) = initial_conditions {
127            return self.apply_initial_conditions(&final_solution, &ic, n, eigenvectors);
128        }
129
130        Ok(final_solution)
131    }
132
133    /// Apply initial conditions to solve for integration constants
134    ///
135    /// Solves the linear system V*c = y₀ where:
136    /// - V is the eigenvector matrix
137    /// - c is the vector of constants [C1, C2, ..., Cn]
138    /// - y₀ is the initial condition vector
139    fn apply_initial_conditions(
140        &self,
141        general_solution: &[Expression],
142        initial_conditions: &[Expression],
143        n: usize,
144        eigenvectors: &Matrix,
145    ) -> Result<Vec<Expression>, ODEError> {
146        if initial_conditions.len() != n {
147            return Err(ODEError::NotLinearForm {
148                reason: format!(
149                    "Initial conditions length {} does not match system size {}",
150                    initial_conditions.len(),
151                    n
152                ),
153            });
154        }
155
156        let linear_solver = LinearSolver::new_fast();
157        let mut constant_values: HashMap<String, Expression> = HashMap::new();
158
159        for i in 0..n {
160            let constant_name = format!("C{}", i + 1);
161            let equation = self.build_constant_equation(i, n, eigenvectors, initial_conditions);
162
163            let substituted_equation = if i == 0 {
164                equation
165            } else {
166                equation.substitute(&constant_values).simplify()
167            };
168
169            let constant_symbol = Symbol::new(&constant_name);
170            let value = self.solve_for_constant(
171                &linear_solver,
172                &substituted_equation,
173                &constant_symbol,
174                &constant_name,
175            )?;
176
177            constant_values.insert(constant_name, value);
178        }
179
180        let particular_solution: Vec<Expression> = general_solution
181            .iter()
182            .map(|expr| expr.substitute(&constant_values).simplify())
183            .collect();
184
185        Ok(particular_solution)
186    }
187
188    /// Build equation for a single integration constant
189    ///
190    /// Constructs: Σⱼ vᵢⱼ*Cⱼ - y₀ᵢ = 0
191    fn build_constant_equation(
192        &self,
193        row_index: usize,
194        n: usize,
195        eigenvectors: &Matrix,
196        initial_conditions: &[Expression],
197    ) -> Expression {
198        let mut equation_terms = Vec::new();
199
200        for j in 0..n {
201            let eigenvector_component = eigenvectors.get_element(row_index, j);
202            let c_symbol = Symbol::new(format!("C{}", j + 1));
203            let term = Expression::mul(vec![eigenvector_component, Expression::symbol(c_symbol)]);
204            equation_terms.push(term);
205        }
206
207        equation_terms.push(Expression::mul(vec![
208            Expression::integer(-1),
209            initial_conditions[row_index].clone(),
210        ]));
211
212        Expression::add(equation_terms)
213    }
214
215    /// Solve for a single integration constant
216    ///
217    /// Handles all possible solver result cases with appropriate error messages
218    fn solve_for_constant(
219        &self,
220        solver: &LinearSolver,
221        equation: &Expression,
222        variable: &Symbol,
223        constant_name: &str,
224    ) -> Result<Expression, ODEError> {
225        match solver.solve(equation, variable) {
226            SolverResult::Single(value) => Ok(value),
227            SolverResult::NoSolution => Err(ODEError::NotLinearForm {
228                reason: format!(
229                    "No solution for integration constant {} (inconsistent system)",
230                    constant_name
231                ),
232            }),
233            SolverResult::InfiniteSolutions => Err(ODEError::NotLinearForm {
234                reason: format!(
235                    "Infinite solutions for integration constant {} (underdetermined)",
236                    constant_name
237                ),
238            }),
239            SolverResult::Multiple(_) => Err(ODEError::NotLinearForm {
240                reason: format!(
241                    "Multiple solutions for integration constant {}",
242                    constant_name
243                ),
244            }),
245            SolverResult::Parametric(_) => Err(ODEError::NotLinearForm {
246                reason: format!(
247                    "Parametric solutions not supported for integration constant {}",
248                    constant_name
249                ),
250            }),
251            SolverResult::Partial(_) => Err(ODEError::NotLinearForm {
252                reason: format!(
253                    "Partial solutions not supported for integration constant {}",
254                    constant_name
255                ),
256            }),
257        }
258    }
259
260    /// Solve 2×2 linear system dx/dt = Ax
261    ///
262    /// Specialized solver for 2×2 systems with explicit formulas.
263    ///
264    /// # Complexity
265    ///
266    /// * **Time:** O(1) for 2×2 eigenvalue computation (quadratic formula)
267    /// * **Space:** O(1) for storing solution components
268    ///
269    /// # Arguments
270    ///
271    /// * `a11`, `a12`, `a21`, `a22` - Matrix coefficients [a11 a12; a21 a22]
272    /// * `independent_var` - The independent variable (typically t)
273    ///
274    /// # Returns
275    ///
276    /// Vector [x₁(t), x₂(t)] representing the solution
277    ///
278    /// # Examples
279    ///
280    /// ```rust
281    /// use mathhook_core::calculus::ode::systems::LinearSystemSolver;
282    /// use mathhook_core::{symbol, expr};
283    ///
284    /// let t = symbol!(t);
285    ///
286    /// // dx/dt = [1 0; 0 2]x
287    /// let solver = LinearSystemSolver;
288    /// let solution = solver.solve_2x2(
289    ///     &expr!(1), &expr!(0),
290    ///     &expr!(0), &expr!(2),
291    ///     &t
292    /// );
293    /// ```
294    pub fn solve_2x2(
295        &self,
296        a11: &Expression,
297        a12: &Expression,
298        a21: &Expression,
299        a22: &Expression,
300        independent_var: &Symbol,
301    ) -> Result<Vec<Expression>, ODEError> {
302        let matrix = Matrix::dense(vec![
303            vec![a11.clone(), a12.clone()],
304            vec![a21.clone(), a22.clone()],
305        ]);
306
307        self.solve(&matrix, independent_var, None)
308    }
309
310    /// Solve 3×3 linear system dx/dt = Ax
311    ///
312    /// Specialized solver for 3×3 systems.
313    ///
314    /// # Complexity
315    ///
316    /// * **Time:** O(1) for 3×3 eigenvalue computation (cubic formula)
317    /// * **Space:** O(1) for storing solution components
318    ///
319    /// # Arguments
320    ///
321    /// * `matrix_entries` - Flattened 3×3 matrix entries [a11, a12, a13, a21, a22, a23, a31, a32, a33]
322    /// * `independent_var` - The independent variable (typically t)
323    ///
324    /// # Returns
325    ///
326    /// Vector [x₁(t), x₂(t), x₃(t)] representing the solution
327    ///
328    /// # Examples
329    ///
330    /// ```rust
331    /// use mathhook_core::calculus::ode::systems::LinearSystemSolver;
332    /// use mathhook_core::{symbol, expr};
333    ///
334    /// let t = symbol!(t);
335    ///
336    /// // dx/dt = [1 0 0; 0 2 0; 0 0 3]x (diagonal)
337    /// let solver = LinearSystemSolver;
338    /// let solution = solver.solve_3x3(
339    ///     &[expr!(1), expr!(0), expr!(0),
340    ///       expr!(0), expr!(2), expr!(0),
341    ///       expr!(0), expr!(0), expr!(3)],
342    ///     &t
343    /// );
344    /// ```
345    pub fn solve_3x3(
346        &self,
347        matrix_entries: &[Expression; 9],
348        independent_var: &Symbol,
349    ) -> Result<Vec<Expression>, ODEError> {
350        let matrix = Matrix::dense(vec![
351            vec![
352                matrix_entries[0].clone(),
353                matrix_entries[1].clone(),
354                matrix_entries[2].clone(),
355            ],
356            vec![
357                matrix_entries[3].clone(),
358                matrix_entries[4].clone(),
359                matrix_entries[5].clone(),
360            ],
361            vec![
362                matrix_entries[6].clone(),
363                matrix_entries[7].clone(),
364                matrix_entries[8].clone(),
365            ],
366        ]);
367
368        self.solve(&matrix, independent_var, None)
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Coefficient matrix diag(1, 2) used by most tests.
    fn diag_1_2() -> Matrix {
        Matrix::diagonal(vec![expr!(1), expr!(2)])
    }

    /// Substitute 0 for `var` in every solution component and simplify.
    fn evaluate_at_zero(solution: &[Expression], var: &Symbol) -> Vec<Expression> {
        let mut t_subs = HashMap::new();
        t_subs.insert(var.name().to_string(), expr!(0));

        solution
            .iter()
            .map(|component| component.substitute(&t_subs).simplify())
            .collect()
    }

    #[test]
    fn test_diagonal_2x2_system() {
        let t = symbol!(t);
        let solution = LinearSystemSolver.solve(&diag_1_2(), &t, None);

        assert!(solution.is_ok(), "Should solve diagonal system");
        let components = solution.unwrap();
        assert_eq!(components.len(), 2, "Should have 2 solution components");
    }

    #[test]
    fn test_non_square_matrix_error() {
        let t = symbol!(t);
        // 3 rows × 2 columns: must be rejected before any eigen computation.
        let tall_matrix = Matrix::dense(vec![
            vec![expr!(1), expr!(0)],
            vec![expr!(0), expr!(2)],
            vec![expr!(1), expr!(1)],
        ]);

        let result = LinearSystemSolver.solve(&tall_matrix, &t, None);

        assert!(result.is_err(), "Should reject non-square matrix");
    }

    #[test]
    fn test_2x2_system_with_initial_conditions() {
        let t = symbol!(t);
        let x0 = vec![expr!(3), expr!(4)];

        let solution = LinearSystemSolver.solve(&diag_1_2(), &t, Some(x0));

        assert!(
            solution.is_ok(),
            "Should solve system with initial conditions: {:?}",
            solution.err()
        );

        let components = solution.unwrap();
        assert_eq!(components.len(), 2, "Should have 2 solution components");

        // The particular solution must reproduce the initial state at t = 0.
        let at_zero = evaluate_at_zero(&components, &t);

        assert_eq!(
            at_zero[0].simplify(),
            expr!(3),
            "First component at t=0 should be 3"
        );
        assert_eq!(
            at_zero[1].simplify(),
            expr!(4),
            "Second component at t=0 should be 4"
        );
    }

    #[test]
    fn test_2x2_system_zero_initial_conditions() {
        let t = symbol!(t);
        let x0 = vec![expr!(0), expr!(0)];

        let solution = LinearSystemSolver.solve(&diag_1_2(), &t, Some(x0));

        assert!(
            solution.is_ok(),
            "Should solve with zero initial conditions"
        );

        let components = solution.unwrap();
        let at_zero = evaluate_at_zero(&components, &t);

        assert_eq!(
            at_zero[0].simplify(),
            expr!(0),
            "First component at t=0 should be 0"
        );
        assert_eq!(
            at_zero[1].simplify(),
            expr!(0),
            "Second component at t=0 should be 0"
        );
    }

    #[test]
    fn test_wrong_size_initial_conditions() {
        let t = symbol!(t);
        // Three values for a 2×2 system.
        let oversized_x0 = vec![expr!(1), expr!(2), expr!(3)];

        let result = LinearSystemSolver.solve(&diag_1_2(), &t, Some(oversized_x0));

        assert!(
            result.is_err(),
            "Should reject mismatched initial condition size"
        );
    }
}