mathhook_core/calculus/pde/standard/wave.rs

//! Wave equation solver
//!
//! Solves the wave equation: ∂²u/∂t² = c²∇²u (currently the 1D case ∂²u/∂t² = c²∂²u/∂x²).
//!
//! Uses separation of variables to build a truncated Fourier sine-series solution whose
//! coefficients are left symbolic.

7use crate::calculus::pde::common::{compute_wave_eigenvalues, create_symbolic_coefficients};
8use crate::calculus::pde::registry::{PDEError, PDEResult, PDESolver};
9use crate::calculus::pde::types::{BoundaryCondition, InitialCondition, PDESolution, Pde, PdeType};
10use crate::core::{Expression, Symbol};
11
12/// Solution to the wave equation
13#[derive(Debug, Clone, PartialEq)]
14pub struct WaveSolution {
15    pub solution: Expression,
16    pub wave_speed: Expression,
17    pub eigenvalues: Vec<Expression>,
18    pub position_coefficients: Vec<Expression>,
19    pub velocity_coefficients: Vec<Expression>,
20}
21
/// Wave equation solver implementing the `PDESolver` trait.
///
/// The only configuration is the term limit `max_terms`, which is forwarded to the
/// eigenvalue computation when a solve is performed.
#[derive(Debug, Clone)]
pub struct WaveEquationSolver {
    /// Term limit passed to `compute_wave_eigenvalues`.
    max_terms: usize,
}
26
27impl WaveEquationSolver {
28    pub fn new() -> Self {
29        Self { max_terms: 10 }
30    }
31
32    pub fn with_max_terms(max_terms: usize) -> Self {
33        Self { max_terms }
34    }
35
36    /// Solves the 1D wave equation with full Fourier series computation
37    ///
38    /// For the wave equation: ∂²u/∂t² = c²∂²u/∂x²
39    /// with Dirichlet boundary conditions and initial conditions for u and ∂u/∂t
40    ///
41    /// # Arguments
42    ///
43    /// * `pde` - The wave equation PDE
44    /// * `wave_speed` - Wave propagation speed coefficient c
45    /// * `boundary_conditions` - Boundary conditions (typically Dirichlet: u(0,t)=0, u(L,t)=0)
46    /// * `initial_position` - Initial displacement: u(x,0) = f(x)
47    /// * `initial_velocity` - Initial velocity: ∂u/∂t(x,0) = g(x)
48    #[allow(unused_variables)]
49    pub fn solve_wave_equation_1d(
50        &self,
51        pde: &Pde,
52        wave_speed: &Expression,
53        boundary_conditions: &[BoundaryCondition],
54        initial_position: &InitialCondition,
55        initial_velocity: &InitialCondition,
56    ) -> Result<WaveSolution, PDEError> {
57        if pde.independent_vars.len() != 2 {
58            return Err(PDEError::InvalidForm {
59                reason: "1D wave equation requires exactly 2 independent variables (x, t)"
60                    .to_owned(),
61            });
62        }
63
64        let spatial_var = &pde.independent_vars[0];
65
66        let eigenvalues =
67            compute_wave_eigenvalues(boundary_conditions, spatial_var, self.max_terms)?;
68
69        let position_coeffs = create_symbolic_coefficients("A", eigenvalues.len())?;
70
71        let velocity_coeffs = create_symbolic_coefficients("B", eigenvalues.len())?;
72
73        let solution = self.construct_wave_solution(
74            &pde.independent_vars,
75            wave_speed,
76            &eigenvalues,
77            &position_coeffs,
78            &velocity_coeffs,
79        );
80
81        Ok(WaveSolution {
82            solution,
83            wave_speed: wave_speed.clone(),
84            eigenvalues,
85            position_coefficients: position_coeffs,
86            velocity_coefficients: velocity_coeffs,
87        })
88    }
89
90    /// Construct the complete wave equation solution
91    ///
92    /// Solution form: u(x,t) = Σ [Aₙcos(λₙct) + Bₙsin(λₙct)]sin(λₙx)
93    /// where λₙ are eigenvalues, Aₙ from initial position, Bₙ from initial velocity
94    fn construct_wave_solution(
95        &self,
96        vars: &[Symbol],
97        wave_speed: &Expression,
98        eigenvalues: &[Expression],
99        position_coeffs: &[Expression],
100        velocity_coeffs: &[Expression],
101    ) -> Expression {
102        let x = &vars[0];
103        let t = &vars[1];
104
105        if eigenvalues.is_empty() || position_coeffs.is_empty() || velocity_coeffs.is_empty() {
106            return Expression::integer(0);
107        }
108
109        let mut terms = Vec::new();
110
111        for ((lambda, a_n), b_n) in eigenvalues
112            .iter()
113            .zip(position_coeffs.iter())
114            .zip(velocity_coeffs.iter())
115        {
116            let spatial = Expression::function(
117                "sin",
118                vec![Expression::mul(vec![
119                    lambda.clone(),
120                    Expression::symbol(x.clone()),
121                ])],
122            );
123
124            let omega = Expression::mul(vec![
125                lambda.clone(),
126                wave_speed.clone(),
127                Expression::symbol(t.clone()),
128            ]);
129
130            let cos_term = Expression::mul(vec![
131                a_n.clone(),
132                Expression::function("cos", vec![omega.clone()]),
133            ]);
134
135            let sin_term =
136                Expression::mul(vec![b_n.clone(), Expression::function("sin", vec![omega])]);
137
138            let temporal = Expression::add(vec![cos_term, sin_term]);
139
140            let term = Expression::mul(vec![temporal, spatial]);
141            terms.push(term);
142        }
143
144        Expression::add(terms)
145    }
146}
147
148impl PDESolver for WaveEquationSolver {
149    fn solve(&self, pde: &Pde) -> PDEResult {
150        use crate::expr;
151
152        let wave_speed = expr!(1);
153        let ic_pos = InitialCondition::value(expr!(1));
154        let ic_vel = InitialCondition::derivative(expr!(0));
155
156        let result = self.solve_wave_equation_1d(pde, &wave_speed, &[], &ic_pos, &ic_vel)?;
157
158        let mut all_coeffs = result.position_coefficients.clone();
159        all_coeffs.extend(result.velocity_coefficients.clone());
160
161        Ok(PDESolution::wave(
162            result.solution,
163            result.wave_speed,
164            result.eigenvalues,
165            all_coeffs,
166        ))
167    }
168
169    fn can_solve(&self, pde_type: PdeType) -> bool {
170        matches!(pde_type, PdeType::Hyperbolic)
171    }
172
173    fn priority(&self) -> u8 {
174        100
175    }
176
177    fn name(&self) -> &'static str {
178        "Wave Equation Solver"
179    }
180
181    fn description(&self) -> &'static str {
182        "Solves wave equation ∂²u/∂t² = c²∇²u using separation of variables and Fourier series"
183    }
184}
185
186impl Default for WaveEquationSolver {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192/// Legacy function for backward compatibility
193pub fn solve_wave_equation_1d(
194    pde: &Pde,
195    wave_speed: &Expression,
196    boundary_conditions: &[BoundaryCondition],
197    initial_position: &InitialCondition,
198    initial_velocity: &InitialCondition,
199) -> Result<WaveSolution, String> {
200    WaveEquationSolver::new()
201        .solve_wave_equation_1d(
202            pde,
203            wave_speed,
204            boundary_conditions,
205            initial_position,
206            initial_velocity,
207        )
208        .map_err(|e| format!("{:?}", e))
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::calculus::pde::types::BoundaryLocation;
    use crate::{expr, symbol};

    #[test]
    fn test_wave_solver_creation() {
        let solver = WaveEquationSolver::new();
        assert_eq!(solver.name(), "Wave Equation Solver");
        assert_eq!(solver.priority(), 100);
    }

    #[test]
    fn test_wave_solver_can_solve() {
        let solver = WaveEquationSolver::new();
        assert!(solver.can_solve(PdeType::Hyperbolic));
        assert!(!solver.can_solve(PdeType::Parabolic));
        assert!(!solver.can_solve(PdeType::Elliptic));
    }

    #[test]
    fn test_solve_wave_equation_1d_basic() {
        let u = symbol!(u);
        let x = symbol!(x);
        let t = symbol!(t);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x.clone(), t]);
        let c = expr!(1);

        let bc1 = BoundaryCondition::dirichlet(
            expr!(0),
            BoundaryLocation::Simple {
                variable: x.clone(),
                value: expr!(0),
            },
        );
        let bc2 = BoundaryCondition::dirichlet(
            expr!(0),
            BoundaryLocation::Simple {
                variable: x,
                value: expr!(1),
            },
        );

        let ic_pos = InitialCondition::value(expr!(1));
        let ic_vel = InitialCondition::derivative(expr!(0));

        let solver = WaveEquationSolver::new();
        let solution = solver
            .solve_wave_equation_1d(&pde, &c, &[bc1, bc2], &ic_pos, &ic_vel)
            .expect("well-formed 1D wave equation should solve");

        assert_eq!(solution.wave_speed, c);
        assert!(!solution.eigenvalues.is_empty());
        assert_eq!(
            solution.eigenvalues.len(),
            solution.position_coefficients.len()
        );
        assert_eq!(
            solution.eigenvalues.len(),
            solution.velocity_coefficients.len()
        );
    }

    #[test]
    fn test_solve_wave_equation_wrong_dimensions() {
        let u = symbol!(u);
        let x = symbol!(x);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x]);
        let c = expr!(1);
        let ic_pos = InitialCondition::value(expr!(1));
        let ic_vel = InitialCondition::derivative(expr!(0));

        let solver = WaveEquationSolver::new();
        let result = solver.solve_wave_equation_1d(&pde, &c, &[], &ic_pos, &ic_vel);
        assert!(matches!(result, Err(PDEError::InvalidForm { .. })));
    }

    #[test]
    fn test_construct_wave_solution() {
        let solver = WaveEquationSolver::new();
        let x = symbol!(x);
        let t = symbol!(t);
        let vars = vec![x, t];
        let c = expr!(1);
        let eigenvalues = vec![expr!(1)];
        let a_coeffs = vec![Expression::symbol(symbol!(A_1))];
        let b_coeffs = vec![Expression::symbol(symbol!(B_1))];

        let solution =
            solver.construct_wave_solution(&vars, &c, &eigenvalues, &a_coeffs, &b_coeffs);
        match solution {
            Expression::Add(_) | Expression::Mul(_) => (),
            _ => panic!("Expected addition or multiplication expression for wave solution"),
        }
    }

    #[test]
    fn test_construct_wave_solution_empty_is_zero() {
        let solver = WaveEquationSolver::new();
        let vars = vec![symbol!(x), symbol!(t)];
        let c = expr!(1);

        let solution = solver.construct_wave_solution(&vars, &c, &[], &[], &[]);
        assert_eq!(solution, Expression::integer(0));
    }

    #[test]
    fn test_wave_solution_structure() {
        let u = symbol!(u);
        let x = symbol!(x);
        let t = symbol!(t);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x.clone(), t]);
        let c = expr!(1);

        let bc = BoundaryCondition::dirichlet(
            expr!(0),
            BoundaryLocation::Simple {
                variable: x,
                value: expr!(0),
            },
        );
        let ic_pos = InitialCondition::value(expr!(1));
        let ic_vel = InitialCondition::derivative(expr!(0));

        let solver = WaveEquationSolver::new();
        let solution = solver
            .solve_wave_equation_1d(&pde, &c, &[bc], &ic_pos, &ic_vel)
            .expect("single Dirichlet condition should still solve");

        assert_eq!(
            solution.eigenvalues.len(),
            solution.position_coefficients.len()
        );
        assert_eq!(
            solution.eigenvalues.len(),
            solution.velocity_coefficients.len()
        );
    }

    #[test]
    fn test_wave_solution_clone() {
        let solution = WaveSolution {
            solution: expr!(1),
            wave_speed: expr!(1),
            eigenvalues: vec![expr!(1)],
            position_coefficients: vec![expr!(1)],
            velocity_coefficients: vec![expr!(1)],
        };

        let cloned = solution.clone();
        assert_eq!(cloned, solution);
    }

    #[test]
    fn test_pde_solver_trait() {
        let solver = WaveEquationSolver::new();
        let u = symbol!(u);
        let x = symbol!(x);
        let t = symbol!(t);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x, t]);

        let result = solver.solve(&pde);
        assert!(result.is_ok());
    }

    #[test]
    fn test_pde_solver_trait_rejects_wrong_dimensions() {
        let solver = WaveEquationSolver::new();
        let u = symbol!(u);
        let x = symbol!(x);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x]);

        assert!(solver.solve(&pde).is_err());
    }

    #[test]
    fn test_legacy_function() {
        let u = symbol!(u);
        let x = symbol!(x);
        let t = symbol!(t);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x, t]);
        let c = expr!(1);
        let ic_pos = InitialCondition::value(expr!(1));
        let ic_vel = InitialCondition::derivative(expr!(0));

        let result = solve_wave_equation_1d(&pde, &c, &[], &ic_pos, &ic_vel);
        assert!(result.is_ok());
    }

    #[test]
    fn test_legacy_function_reports_error_as_string() {
        let u = symbol!(u);
        let x = symbol!(x);
        let equation = expr!(u);
        let pde = Pde::new(equation, u, vec![x]);
        let c = expr!(1);
        let ic_pos = InitialCondition::value(expr!(1));
        let ic_vel = InitialCondition::derivative(expr!(0));

        let result = solve_wave_equation_1d(&pde, &c, &[], &ic_pos, &ic_vel);
        let message = result.expect_err("one independent variable must be rejected");
        assert!(!message.is_empty());
    }
}