mathhook_core/calculus/pde/common/
eigenvalue_problem.rs

1//! Eigenvalue problem solver for Sturm-Liouville problems
2//!
3//! Solves standard eigenvalue problems of the form:
4//! X''(x) + λX(x) = 0
5//!
6//! With various boundary conditions:
7//! - Dirichlet: X(a) = 0, X(b) = 0
8//! - Neumann: X'(a) = 0, X'(b) = 0
9//! - Mixed: Dirichlet on one end, Neumann on the other
//! - Robin: αX(a) + βX'(a) = 0 (recognized, but not yet supported; rejected with an error)
11//!
12//! Returns both eigenvalues and corresponding eigenfunctions.
13
14use crate::calculus::pde::types::{BoundaryCondition, BoundaryLocation};
15use crate::core::{Expression, Symbol};
16
/// Result of solving an eigenvalue problem
///
/// `eigenvalues[i]` and `eigenfunctions[i]` together form the i-th eigenpair.
#[derive(Debug, Clone, PartialEq)]
pub struct EigenvalueSolution {
    /// The eigenvalues λₙ, one per computed mode
    pub eigenvalues: Vec<Expression>,
    /// The eigenfunctions Xₙ(x), paired index-for-index with `eigenvalues`
    pub eigenfunctions: Vec<Expression>,
    /// The variable (e.g., x) the eigenfunctions are expressed in
    pub variable: Symbol,
    /// The domain [a, b], taken from the locations of the left and right boundary conditions
    pub domain: (Expression, Expression),
}
29
/// Type of boundary condition pair
///
/// Each variant names the condition at the left endpoint first, then the one at the
/// right endpoint. Robin conditions have no variant: classification rejects them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BoundaryType {
    /// Dirichlet on both ends
    DirichletDirichlet,
    /// Neumann on both ends
    NeumannNeumann,
    /// Dirichlet at left, Neumann at right
    DirichletNeumann,
    /// Neumann at left, Dirichlet at right
    NeumannDirichlet,
}
42
43/// Solve Sturm-Liouville eigenvalue problem with boundary conditions
44///
45/// Solves: X''(x) + λX(x) = 0 on [a, b] with given BCs
46///
47/// # Arguments
48///
49/// * `bc_left` - Boundary condition at left endpoint
50/// * `bc_right` - Boundary condition at right endpoint
51/// * `num_modes` - Number of eigenvalue/eigenfunction pairs to compute
52///
53/// # Returns
54///
55/// Eigenvalues and eigenfunctions, or error if BCs are incompatible
56///
57/// # Examples
58///
59/// ```rust
60/// use mathhook_core::calculus::pde::common::eigenvalue_problem::solve_sturm_liouville;
61/// use mathhook_core::calculus::pde::types::BoundaryCondition;
62/// use mathhook_core::{symbol, expr};
63///
64/// let x = symbol!(x);
65/// let bc_left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
66/// let bc_right = BoundaryCondition::dirichlet_at(x.clone(), expr!(pi), expr!(0));
67///
68/// let result = solve_sturm_liouville(&bc_left, &bc_right, 5);
69/// assert!(result.is_ok());
70/// let solution = result.unwrap();
71/// assert_eq!(solution.eigenvalues.len(), 5);
72/// ```
73pub fn solve_sturm_liouville(
74    bc_left: &BoundaryCondition,
75    bc_right: &BoundaryCondition,
76    num_modes: usize,
77) -> Result<EigenvalueSolution, String> {
78    let (var, domain) = extract_domain(bc_left, bc_right)?;
79    let bc_type = classify_boundary_conditions(bc_left, bc_right)?;
80
81    let (a, b) = domain.clone();
82    let length = compute_length(&a, &b);
83
84    let (eigenvalues, eigenfunctions) = match bc_type {
85        BoundaryType::DirichletDirichlet => solve_dirichlet_dirichlet(&var, &length, num_modes),
86        BoundaryType::NeumannNeumann => solve_neumann_neumann(&var, &length, num_modes),
87        BoundaryType::DirichletNeumann => solve_dirichlet_neumann(&var, &length, num_modes),
88        BoundaryType::NeumannDirichlet => solve_neumann_dirichlet(&var, &length, num_modes),
89    };
90
91    Ok(EigenvalueSolution {
92        eigenvalues,
93        eigenfunctions,
94        variable: var,
95        domain,
96    })
97}
98
99/// Extract spatial variable and domain from boundary conditions
100fn extract_domain(
101    bc_left: &BoundaryCondition,
102    bc_right: &BoundaryCondition,
103) -> Result<(Symbol, (Expression, Expression)), String> {
104    let (var_left, a) = extract_location(bc_left)?;
105    let (var_right, b) = extract_location(bc_right)?;
106
107    if var_left != var_right {
108        return Err(format!(
109            "Boundary conditions have different variables: {} and {}",
110            var_left.name(),
111            var_right.name()
112        ));
113    }
114
115    Ok((var_left, (a, b)))
116}
117
118/// Extract variable and location from a boundary condition
119fn extract_location(bc: &BoundaryCondition) -> Result<(Symbol, Expression), String> {
120    let location = match bc {
121        BoundaryCondition::Dirichlet { location, .. } => location,
122        BoundaryCondition::Neumann { location, .. } => location,
123        BoundaryCondition::Robin { location, .. } => location,
124    };
125
126    match location {
127        BoundaryLocation::Simple { variable, value } => Ok((variable.clone(), value.clone())),
128        _ => Err("Only simple boundary locations (var = value) are supported".to_owned()),
129    }
130}
131
132/// Classify boundary condition types
133fn classify_boundary_conditions(
134    bc_left: &BoundaryCondition,
135    bc_right: &BoundaryCondition,
136) -> Result<BoundaryType, String> {
137    let left_is_dirichlet = matches!(bc_left, BoundaryCondition::Dirichlet { .. });
138    let right_is_dirichlet = matches!(bc_right, BoundaryCondition::Dirichlet { .. });
139
140    let left_is_neumann = matches!(bc_left, BoundaryCondition::Neumann { .. });
141    let right_is_neumann = matches!(bc_right, BoundaryCondition::Neumann { .. });
142
143    if matches!(bc_left, BoundaryCondition::Robin { .. })
144        || matches!(bc_right, BoundaryCondition::Robin { .. })
145    {
146        return Err("Robin boundary conditions not yet implemented".to_owned());
147    }
148
149    match (left_is_dirichlet, right_is_dirichlet) {
150        (true, true) => Ok(BoundaryType::DirichletDirichlet),
151        (false, false) if left_is_neumann && right_is_neumann => Ok(BoundaryType::NeumannNeumann),
152        (true, false) if right_is_neumann => Ok(BoundaryType::DirichletNeumann),
153        (false, true) if left_is_neumann => Ok(BoundaryType::NeumannDirichlet),
154        _ => Err("Unsupported boundary condition combination".to_owned()),
155    }
156}
157
158/// Compute domain length L = b - a
159fn compute_length(a: &Expression, b: &Expression) -> Expression {
160    Expression::add(vec![
161        b.clone(),
162        Expression::mul(vec![Expression::integer(-1), a.clone()]),
163    ])
164}
165
166/// Solve Dirichlet-Dirichlet problem: X(0) = 0, X(L) = 0
167///
168/// Solution: λₙ = (nπ/L)², Xₙ(x) = sin(nπx/L)
169fn solve_dirichlet_dirichlet(
170    var: &Symbol,
171    length: &Expression,
172    num_modes: usize,
173) -> (Vec<Expression>, Vec<Expression>) {
174    let mut eigenvalues = Vec::new();
175    let mut eigenfunctions = Vec::new();
176
177    for n in 1..=num_modes {
178        let n_expr = Expression::integer(n as i64);
179
180        let n_pi = Expression::mul(vec![n_expr.clone(), Expression::pi()]);
181        let n_pi_squared = Expression::pow(n_pi.clone(), Expression::integer(2));
182        let length_squared = Expression::pow(length.clone(), Expression::integer(2));
183        let lambda_n = Expression::mul(vec![
184            n_pi_squared,
185            Expression::pow(length_squared, Expression::integer(-1)),
186        ]);
187        eigenvalues.push(lambda_n);
188
189        let arg = Expression::mul(vec![
190            n_pi,
191            Expression::symbol(var.clone()),
192            Expression::pow(length.clone(), Expression::integer(-1)),
193        ]);
194        let x_n = Expression::function("sin", vec![arg]);
195        eigenfunctions.push(x_n);
196    }
197
198    (eigenvalues, eigenfunctions)
199}
200
201/// Solve Neumann-Neumann problem: X'(0) = 0, X'(L) = 0
202///
203/// Solution: λ₀ = 0, X₀ = 1; λₙ = (nπ/L)², Xₙ(x) = cos(nπx/L) for n ≥ 1
204fn solve_neumann_neumann(
205    var: &Symbol,
206    length: &Expression,
207    num_modes: usize,
208) -> (Vec<Expression>, Vec<Expression>) {
209    let mut eigenvalues = Vec::new();
210    let mut eigenfunctions = Vec::new();
211
212    eigenvalues.push(Expression::integer(0));
213    eigenfunctions.push(Expression::integer(1));
214
215    for n in 1..num_modes {
216        let n_expr = Expression::integer(n as i64);
217
218        let n_pi = Expression::mul(vec![n_expr.clone(), Expression::pi()]);
219        let n_pi_squared = Expression::pow(n_pi.clone(), Expression::integer(2));
220        let length_squared = Expression::pow(length.clone(), Expression::integer(2));
221        let lambda_n = Expression::mul(vec![
222            n_pi_squared,
223            Expression::pow(length_squared, Expression::integer(-1)),
224        ]);
225        eigenvalues.push(lambda_n);
226
227        let arg = Expression::mul(vec![
228            n_pi,
229            Expression::symbol(var.clone()),
230            Expression::pow(length.clone(), Expression::integer(-1)),
231        ]);
232        let x_n = Expression::function("cos", vec![arg]);
233        eigenfunctions.push(x_n);
234    }
235
236    (eigenvalues, eigenfunctions)
237}
238
239/// Solve Dirichlet-Neumann problem: X(0) = 0, X'(L) = 0
240///
241/// Solution: λₙ = ((2n-1)π/2L)², Xₙ(x) = sin((2n-1)πx/2L)
242fn solve_dirichlet_neumann(
243    var: &Symbol,
244    length: &Expression,
245    num_modes: usize,
246) -> (Vec<Expression>, Vec<Expression>) {
247    let mut eigenvalues = Vec::new();
248    let mut eigenfunctions = Vec::new();
249
250    for n in 1..=num_modes {
251        let two_n_minus_1 = Expression::add(vec![
252            Expression::mul(vec![Expression::integer(2), Expression::integer(n as i64)]),
253            Expression::integer(-1),
254        ]);
255
256        let numerator = Expression::mul(vec![two_n_minus_1.clone(), Expression::pi()]);
257        let numerator_squared = Expression::pow(numerator.clone(), Expression::integer(2));
258
259        let denominator = Expression::mul(vec![
260            Expression::integer(4),
261            Expression::pow(length.clone(), Expression::integer(2)),
262        ]);
263
264        let lambda_n = Expression::mul(vec![
265            numerator_squared,
266            Expression::pow(denominator, Expression::integer(-1)),
267        ]);
268        eigenvalues.push(lambda_n);
269
270        let arg_numerator = Expression::mul(vec![
271            two_n_minus_1,
272            Expression::pi(),
273            Expression::symbol(var.clone()),
274        ]);
275        let arg_denominator = Expression::mul(vec![Expression::integer(2), length.clone()]);
276        let arg = Expression::mul(vec![
277            arg_numerator,
278            Expression::pow(arg_denominator, Expression::integer(-1)),
279        ]);
280        let x_n = Expression::function("sin", vec![arg]);
281        eigenfunctions.push(x_n);
282    }
283
284    (eigenvalues, eigenfunctions)
285}
286
287/// Solve Neumann-Dirichlet problem: X'(0) = 0, X(L) = 0
288///
289/// Solution: λₙ = ((2n-1)π/2L)², Xₙ(x) = cos((2n-1)πx/2L)
290fn solve_neumann_dirichlet(
291    var: &Symbol,
292    length: &Expression,
293    num_modes: usize,
294) -> (Vec<Expression>, Vec<Expression>) {
295    let mut eigenvalues = Vec::new();
296    let mut eigenfunctions = Vec::new();
297
298    for n in 1..=num_modes {
299        let two_n_minus_1 = Expression::add(vec![
300            Expression::mul(vec![Expression::integer(2), Expression::integer(n as i64)]),
301            Expression::integer(-1),
302        ]);
303
304        let numerator = Expression::mul(vec![two_n_minus_1.clone(), Expression::pi()]);
305        let numerator_squared = Expression::pow(numerator.clone(), Expression::integer(2));
306
307        let denominator = Expression::mul(vec![
308            Expression::integer(4),
309            Expression::pow(length.clone(), Expression::integer(2)),
310        ]);
311
312        let lambda_n = Expression::mul(vec![
313            numerator_squared,
314            Expression::pow(denominator, Expression::integer(-1)),
315        ]);
316        eigenvalues.push(lambda_n);
317
318        let arg_numerator = Expression::mul(vec![
319            two_n_minus_1,
320            Expression::pi(),
321            Expression::symbol(var.clone()),
322        ]);
323        let arg_denominator = Expression::mul(vec![Expression::integer(2), length.clone()]);
324        let arg = Expression::mul(vec![
325            arg_numerator,
326            Expression::pow(arg_denominator, Expression::integer(-1)),
327        ]);
328        let x_n = Expression::function("cos", vec![arg]);
329        eigenfunctions.push(x_n);
330    }
331
332    (eigenvalues, eigenfunctions)
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_dirichlet_dirichlet_eigenvalues() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::dirichlet_at(x.clone(), expr!(pi), expr!(0));

        let result = solve_sturm_liouville(&bc_left, &bc_right, 3);
        assert!(result.is_ok());

        let solution = result.unwrap();
        assert_eq!(solution.eigenvalues.len(), 3);
        assert_eq!(solution.eigenfunctions.len(), 3);
    }

    #[test]
    fn test_neumann_neumann_eigenvalues() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::neumann_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::neumann_at(x.clone(), expr!(pi), expr!(0));

        let result = solve_sturm_liouville(&bc_left, &bc_right, 3);
        assert!(result.is_ok());

        let solution = result.unwrap();
        assert_eq!(solution.eigenvalues.len(), 3);
        assert_eq!(solution.eigenfunctions.len(), 3);
    }

    #[test]
    fn test_neumann_neumann_has_constant_mode() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::neumann_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::neumann_at(x, expr!(pi), expr!(0));

        let solution = solve_sturm_liouville(&bc_left, &bc_right, 2).unwrap();
        // The first Neumann-Neumann mode is the constant one: λ₀ = 0, X₀ = 1.
        assert_eq!(solution.eigenvalues[0], Expression::integer(0));
        assert_eq!(solution.eigenfunctions[0], Expression::integer(1));
    }

    #[test]
    fn test_mixed_boundary_conditions() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::neumann_at(x.clone(), expr!(pi), expr!(0));

        let result = solve_sturm_liouville(&bc_left, &bc_right, 3);
        assert!(result.is_ok());

        let solution = result.unwrap();
        assert_eq!(solution.eigenvalues.len(), 3);
        assert_eq!(solution.eigenfunctions.len(), 3);
    }

    #[test]
    fn test_neumann_dirichlet_via_solver() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::neumann_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::dirichlet_at(x, expr!(pi), expr!(0));

        let solution = solve_sturm_liouville(&bc_left, &bc_right, 3).unwrap();
        assert_eq!(solution.eigenvalues.len(), 3);
        assert_eq!(solution.eigenfunctions.len(), 3);
    }

    #[test]
    fn test_solution_records_variable_and_domain() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::dirichlet_at(x.clone(), expr!(1), expr!(0));

        let solution = solve_sturm_liouville(&bc_left, &bc_right, 1).unwrap();
        assert_eq!(solution.variable, x);
        assert_eq!(solution.domain, (expr!(0), expr!(1)));
    }

    #[test]
    fn test_incompatible_variables() {
        let x = symbol!(x);
        let y = symbol!(y);
        let bc_left = BoundaryCondition::dirichlet_at(x, expr!(0), expr!(0));
        let bc_right = BoundaryCondition::dirichlet_at(y, expr!(pi), expr!(0));

        let result = solve_sturm_liouville(&bc_left, &bc_right, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_domain() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::dirichlet_at(x.clone(), expr!(1), expr!(0));

        let result = extract_domain(&bc_left, &bc_right);
        assert!(result.is_ok());

        let (var, (a, b)) = result.unwrap();
        assert_eq!(var, x);
        assert_eq!(a, expr!(0));
        assert_eq!(b, expr!(1));
    }

    #[test]
    fn test_classify_boundary_conditions_dirichlet_dirichlet() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::dirichlet_at(x, expr!(1), expr!(0));

        let result = classify_boundary_conditions(&bc_left, &bc_right);
        assert_eq!(result.unwrap(), BoundaryType::DirichletDirichlet);
    }

    #[test]
    fn test_classify_boundary_conditions_neumann_neumann() {
        let x = symbol!(x);
        let bc_left = BoundaryCondition::neumann_at(x.clone(), expr!(0), expr!(0));
        let bc_right = BoundaryCondition::neumann_at(x, expr!(1), expr!(0));

        let result = classify_boundary_conditions(&bc_left, &bc_right);
        assert_eq!(result.unwrap(), BoundaryType::NeumannNeumann);
    }

    #[test]
    fn test_classify_boundary_conditions_mixed() {
        let x = symbol!(x);
        let dirichlet = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let neumann = BoundaryCondition::neumann_at(x, expr!(1), expr!(0));

        assert_eq!(
            classify_boundary_conditions(&dirichlet, &neumann).unwrap(),
            BoundaryType::DirichletNeumann
        );
        assert_eq!(
            classify_boundary_conditions(&neumann, &dirichlet).unwrap(),
            BoundaryType::NeumannDirichlet
        );
    }

    #[test]
    fn test_dirichlet_neumann_mode_count() {
        let x = symbol!(x);
        let length = Expression::integer(1);

        let (eigenvalues, eigenfunctions) = solve_dirichlet_neumann(&x, &length, 5);
        assert_eq!(eigenvalues.len(), 5);
        assert_eq!(eigenfunctions.len(), 5);
    }

    #[test]
    fn test_neumann_dirichlet_mode_count() {
        let x = symbol!(x);
        let length = Expression::integer(1);

        let (eigenvalues, eigenfunctions) = solve_neumann_dirichlet(&x, &length, 4);
        assert_eq!(eigenvalues.len(), 4);
        assert_eq!(eigenfunctions.len(), 4);
    }
}