1use crate::calculus::pde::types::{BoundaryCondition, BoundaryLocation};
15use crate::core::{Expression, Symbol};
16
17#[derive(Debug, Clone, PartialEq)]
19pub struct EigenvalueSolution {
20 pub eigenvalues: Vec<Expression>,
22 pub eigenfunctions: Vec<Expression>,
24 pub variable: Symbol,
26 pub domain: (Expression, Expression),
28}
29
/// Kind of boundary condition at the (left, right) endpoints.
///
/// Robin conditions are rejected before classification, so only the four
/// Dirichlet/Neumann combinations exist here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BoundaryType {
    /// Dirichlet at the left endpoint, Dirichlet at the right.
    DirichletDirichlet,
    /// Neumann at the left endpoint, Neumann at the right.
    NeumannNeumann,
    /// Dirichlet at the left endpoint, Neumann at the right.
    DirichletNeumann,
    /// Neumann at the left endpoint, Dirichlet at the right.
    NeumannDirichlet,
}
42
43pub fn solve_sturm_liouville(
74 bc_left: &BoundaryCondition,
75 bc_right: &BoundaryCondition,
76 num_modes: usize,
77) -> Result<EigenvalueSolution, String> {
78 let (var, domain) = extract_domain(bc_left, bc_right)?;
79 let bc_type = classify_boundary_conditions(bc_left, bc_right)?;
80
81 let (a, b) = domain.clone();
82 let length = compute_length(&a, &b);
83
84 let (eigenvalues, eigenfunctions) = match bc_type {
85 BoundaryType::DirichletDirichlet => solve_dirichlet_dirichlet(&var, &length, num_modes),
86 BoundaryType::NeumannNeumann => solve_neumann_neumann(&var, &length, num_modes),
87 BoundaryType::DirichletNeumann => solve_dirichlet_neumann(&var, &length, num_modes),
88 BoundaryType::NeumannDirichlet => solve_neumann_dirichlet(&var, &length, num_modes),
89 };
90
91 Ok(EigenvalueSolution {
92 eigenvalues,
93 eigenfunctions,
94 variable: var,
95 domain,
96 })
97}
98
99fn extract_domain(
101 bc_left: &BoundaryCondition,
102 bc_right: &BoundaryCondition,
103) -> Result<(Symbol, (Expression, Expression)), String> {
104 let (var_left, a) = extract_location(bc_left)?;
105 let (var_right, b) = extract_location(bc_right)?;
106
107 if var_left != var_right {
108 return Err(format!(
109 "Boundary conditions have different variables: {} and {}",
110 var_left.name(),
111 var_right.name()
112 ));
113 }
114
115 Ok((var_left, (a, b)))
116}
117
118fn extract_location(bc: &BoundaryCondition) -> Result<(Symbol, Expression), String> {
120 let location = match bc {
121 BoundaryCondition::Dirichlet { location, .. } => location,
122 BoundaryCondition::Neumann { location, .. } => location,
123 BoundaryCondition::Robin { location, .. } => location,
124 };
125
126 match location {
127 BoundaryLocation::Simple { variable, value } => Ok((variable.clone(), value.clone())),
128 _ => Err("Only simple boundary locations (var = value) are supported".to_owned()),
129 }
130}
131
132fn classify_boundary_conditions(
134 bc_left: &BoundaryCondition,
135 bc_right: &BoundaryCondition,
136) -> Result<BoundaryType, String> {
137 let left_is_dirichlet = matches!(bc_left, BoundaryCondition::Dirichlet { .. });
138 let right_is_dirichlet = matches!(bc_right, BoundaryCondition::Dirichlet { .. });
139
140 let left_is_neumann = matches!(bc_left, BoundaryCondition::Neumann { .. });
141 let right_is_neumann = matches!(bc_right, BoundaryCondition::Neumann { .. });
142
143 if matches!(bc_left, BoundaryCondition::Robin { .. })
144 || matches!(bc_right, BoundaryCondition::Robin { .. })
145 {
146 return Err("Robin boundary conditions not yet implemented".to_owned());
147 }
148
149 match (left_is_dirichlet, right_is_dirichlet) {
150 (true, true) => Ok(BoundaryType::DirichletDirichlet),
151 (false, false) if left_is_neumann && right_is_neumann => Ok(BoundaryType::NeumannNeumann),
152 (true, false) if right_is_neumann => Ok(BoundaryType::DirichletNeumann),
153 (false, true) if left_is_neumann => Ok(BoundaryType::NeumannDirichlet),
154 _ => Err("Unsupported boundary condition combination".to_owned()),
155 }
156}
157
158fn compute_length(a: &Expression, b: &Expression) -> Expression {
160 Expression::add(vec![
161 b.clone(),
162 Expression::mul(vec![Expression::integer(-1), a.clone()]),
163 ])
164}
165
166fn solve_dirichlet_dirichlet(
170 var: &Symbol,
171 length: &Expression,
172 num_modes: usize,
173) -> (Vec<Expression>, Vec<Expression>) {
174 let mut eigenvalues = Vec::new();
175 let mut eigenfunctions = Vec::new();
176
177 for n in 1..=num_modes {
178 let n_expr = Expression::integer(n as i64);
179
180 let n_pi = Expression::mul(vec![n_expr.clone(), Expression::pi()]);
181 let n_pi_squared = Expression::pow(n_pi.clone(), Expression::integer(2));
182 let length_squared = Expression::pow(length.clone(), Expression::integer(2));
183 let lambda_n = Expression::mul(vec![
184 n_pi_squared,
185 Expression::pow(length_squared, Expression::integer(-1)),
186 ]);
187 eigenvalues.push(lambda_n);
188
189 let arg = Expression::mul(vec![
190 n_pi,
191 Expression::symbol(var.clone()),
192 Expression::pow(length.clone(), Expression::integer(-1)),
193 ]);
194 let x_n = Expression::function("sin", vec![arg]);
195 eigenfunctions.push(x_n);
196 }
197
198 (eigenvalues, eigenfunctions)
199}
200
201fn solve_neumann_neumann(
205 var: &Symbol,
206 length: &Expression,
207 num_modes: usize,
208) -> (Vec<Expression>, Vec<Expression>) {
209 let mut eigenvalues = Vec::new();
210 let mut eigenfunctions = Vec::new();
211
212 eigenvalues.push(Expression::integer(0));
213 eigenfunctions.push(Expression::integer(1));
214
215 for n in 1..num_modes {
216 let n_expr = Expression::integer(n as i64);
217
218 let n_pi = Expression::mul(vec![n_expr.clone(), Expression::pi()]);
219 let n_pi_squared = Expression::pow(n_pi.clone(), Expression::integer(2));
220 let length_squared = Expression::pow(length.clone(), Expression::integer(2));
221 let lambda_n = Expression::mul(vec![
222 n_pi_squared,
223 Expression::pow(length_squared, Expression::integer(-1)),
224 ]);
225 eigenvalues.push(lambda_n);
226
227 let arg = Expression::mul(vec![
228 n_pi,
229 Expression::symbol(var.clone()),
230 Expression::pow(length.clone(), Expression::integer(-1)),
231 ]);
232 let x_n = Expression::function("cos", vec![arg]);
233 eigenfunctions.push(x_n);
234 }
235
236 (eigenvalues, eigenfunctions)
237}
238
239fn solve_dirichlet_neumann(
243 var: &Symbol,
244 length: &Expression,
245 num_modes: usize,
246) -> (Vec<Expression>, Vec<Expression>) {
247 let mut eigenvalues = Vec::new();
248 let mut eigenfunctions = Vec::new();
249
250 for n in 1..=num_modes {
251 let two_n_minus_1 = Expression::add(vec![
252 Expression::mul(vec![Expression::integer(2), Expression::integer(n as i64)]),
253 Expression::integer(-1),
254 ]);
255
256 let numerator = Expression::mul(vec![two_n_minus_1.clone(), Expression::pi()]);
257 let numerator_squared = Expression::pow(numerator.clone(), Expression::integer(2));
258
259 let denominator = Expression::mul(vec![
260 Expression::integer(4),
261 Expression::pow(length.clone(), Expression::integer(2)),
262 ]);
263
264 let lambda_n = Expression::mul(vec![
265 numerator_squared,
266 Expression::pow(denominator, Expression::integer(-1)),
267 ]);
268 eigenvalues.push(lambda_n);
269
270 let arg_numerator = Expression::mul(vec![
271 two_n_minus_1,
272 Expression::pi(),
273 Expression::symbol(var.clone()),
274 ]);
275 let arg_denominator = Expression::mul(vec![Expression::integer(2), length.clone()]);
276 let arg = Expression::mul(vec![
277 arg_numerator,
278 Expression::pow(arg_denominator, Expression::integer(-1)),
279 ]);
280 let x_n = Expression::function("sin", vec![arg]);
281 eigenfunctions.push(x_n);
282 }
283
284 (eigenvalues, eigenfunctions)
285}
286
287fn solve_neumann_dirichlet(
291 var: &Symbol,
292 length: &Expression,
293 num_modes: usize,
294) -> (Vec<Expression>, Vec<Expression>) {
295 let mut eigenvalues = Vec::new();
296 let mut eigenfunctions = Vec::new();
297
298 for n in 1..=num_modes {
299 let two_n_minus_1 = Expression::add(vec![
300 Expression::mul(vec![Expression::integer(2), Expression::integer(n as i64)]),
301 Expression::integer(-1),
302 ]);
303
304 let numerator = Expression::mul(vec![two_n_minus_1.clone(), Expression::pi()]);
305 let numerator_squared = Expression::pow(numerator.clone(), Expression::integer(2));
306
307 let denominator = Expression::mul(vec![
308 Expression::integer(4),
309 Expression::pow(length.clone(), Expression::integer(2)),
310 ]);
311
312 let lambda_n = Expression::mul(vec![
313 numerator_squared,
314 Expression::pow(denominator, Expression::integer(-1)),
315 ]);
316 eigenvalues.push(lambda_n);
317
318 let arg_numerator = Expression::mul(vec![
319 two_n_minus_1,
320 Expression::pi(),
321 Expression::symbol(var.clone()),
322 ]);
323 let arg_denominator = Expression::mul(vec![Expression::integer(2), length.clone()]);
324 let arg = Expression::mul(vec![
325 arg_numerator,
326 Expression::pow(arg_denominator, Expression::integer(-1)),
327 ]);
328 let x_n = Expression::function("cos", vec![arg]);
329 eigenfunctions.push(x_n);
330 }
331
332 (eigenvalues, eigenfunctions)
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Solves for three modes and checks both returned vectors hold three entries.
    fn assert_three_modes(bc_left: &BoundaryCondition, bc_right: &BoundaryCondition) {
        let solution = solve_sturm_liouville(bc_left, bc_right, 3)
            .expect("Sturm-Liouville solve should succeed");
        assert_eq!(solution.eigenvalues.len(), 3);
        assert_eq!(solution.eigenfunctions.len(), 3);
    }

    #[test]
    fn test_dirichlet_dirichlet_eigenvalues() {
        let x = symbol!(x);
        assert_three_modes(
            &BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0)),
            &BoundaryCondition::dirichlet_at(x, expr!(pi), expr!(0)),
        );
    }

    #[test]
    fn test_neumann_neumann_eigenvalues() {
        let x = symbol!(x);
        assert_three_modes(
            &BoundaryCondition::neumann_at(x.clone(), expr!(0), expr!(0)),
            &BoundaryCondition::neumann_at(x, expr!(pi), expr!(0)),
        );
    }

    #[test]
    fn test_mixed_boundary_conditions() {
        let x = symbol!(x);
        assert_three_modes(
            &BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0)),
            &BoundaryCondition::neumann_at(x, expr!(pi), expr!(0)),
        );
    }

    #[test]
    fn test_incompatible_variables() {
        let left = BoundaryCondition::dirichlet_at(symbol!(x), expr!(0), expr!(0));
        let right = BoundaryCondition::dirichlet_at(symbol!(y), expr!(pi), expr!(0));

        assert!(solve_sturm_liouville(&left, &right, 3).is_err());
    }

    #[test]
    fn test_extract_domain() {
        let x = symbol!(x);
        let left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let right = BoundaryCondition::dirichlet_at(x.clone(), expr!(1), expr!(0));

        let (var, (a, b)) =
            extract_domain(&left, &right).expect("domain extraction should succeed");
        assert_eq!(var, x);
        assert_eq!(a, expr!(0));
        assert_eq!(b, expr!(1));
    }

    #[test]
    fn test_classify_boundary_conditions_dirichlet_dirichlet() {
        let x = symbol!(x);
        let left = BoundaryCondition::dirichlet_at(x.clone(), expr!(0), expr!(0));
        let right = BoundaryCondition::dirichlet_at(x, expr!(1), expr!(0));

        assert_eq!(
            classify_boundary_conditions(&left, &right).unwrap(),
            BoundaryType::DirichletDirichlet
        );
    }

    #[test]
    fn test_classify_boundary_conditions_neumann_neumann() {
        let x = symbol!(x);
        let left = BoundaryCondition::neumann_at(x.clone(), expr!(0), expr!(0));
        let right = BoundaryCondition::neumann_at(x, expr!(1), expr!(0));

        assert_eq!(
            classify_boundary_conditions(&left, &right).unwrap(),
            BoundaryType::NeumannNeumann
        );
    }

    #[test]
    fn test_dirichlet_neumann_mode_count() {
        let (eigenvalues, eigenfunctions) =
            solve_dirichlet_neumann(&symbol!(x), &Expression::integer(1), 5);
        assert_eq!(eigenvalues.len(), 5);
        assert_eq!(eigenfunctions.len(), 5);
    }

    #[test]
    fn test_neumann_dirichlet_mode_count() {
        let (eigenvalues, eigenfunctions) =
            solve_neumann_dirichlet(&symbol!(x), &Expression::integer(1), 4);
        assert_eq!(eigenvalues.len(), 4);
        assert_eq!(eigenfunctions.len(), 4);
    }
}
447}