mathhook_core/calculus/pde/standard/
wave.rs1use crate::calculus::pde::common::{compute_wave_eigenvalues, create_symbolic_coefficients};
8use crate::calculus::pde::registry::{PDEError, PDEResult, PDESolver};
9use crate::calculus::pde::types::{BoundaryCondition, InitialCondition, PDESolution, Pde, PdeType};
10use crate::core::{Expression, Symbol};
11
12#[derive(Debug, Clone, PartialEq)]
14pub struct WaveSolution {
15 pub solution: Expression,
16 pub wave_speed: Expression,
17 pub eigenvalues: Vec<Expression>,
18 pub position_coefficients: Vec<Expression>,
19 pub velocity_coefficients: Vec<Expression>,
20}
21
/// Separation-of-variables solver for the 1D wave equation `u_tt = c² u_xx`.
///
/// The only configuration is the series-length limit forwarded to
/// `compute_wave_eigenvalues`.
#[derive(Debug, Clone)]
pub struct WaveEquationSolver {
    /// Maximum number of series terms requested from `compute_wave_eigenvalues`.
    max_terms: usize,
}
26
27impl WaveEquationSolver {
28 pub fn new() -> Self {
29 Self { max_terms: 10 }
30 }
31
32 pub fn with_max_terms(max_terms: usize) -> Self {
33 Self { max_terms }
34 }
35
36 #[allow(unused_variables)]
49 pub fn solve_wave_equation_1d(
50 &self,
51 pde: &Pde,
52 wave_speed: &Expression,
53 boundary_conditions: &[BoundaryCondition],
54 initial_position: &InitialCondition,
55 initial_velocity: &InitialCondition,
56 ) -> Result<WaveSolution, PDEError> {
57 if pde.independent_vars.len() != 2 {
58 return Err(PDEError::InvalidForm {
59 reason: "1D wave equation requires exactly 2 independent variables (x, t)"
60 .to_owned(),
61 });
62 }
63
64 let spatial_var = &pde.independent_vars[0];
65
66 let eigenvalues =
67 compute_wave_eigenvalues(boundary_conditions, spatial_var, self.max_terms)?;
68
69 let position_coeffs = create_symbolic_coefficients("A", eigenvalues.len())?;
70
71 let velocity_coeffs = create_symbolic_coefficients("B", eigenvalues.len())?;
72
73 let solution = self.construct_wave_solution(
74 &pde.independent_vars,
75 wave_speed,
76 &eigenvalues,
77 &position_coeffs,
78 &velocity_coeffs,
79 );
80
81 Ok(WaveSolution {
82 solution,
83 wave_speed: wave_speed.clone(),
84 eigenvalues,
85 position_coefficients: position_coeffs,
86 velocity_coefficients: velocity_coeffs,
87 })
88 }
89
90 fn construct_wave_solution(
95 &self,
96 vars: &[Symbol],
97 wave_speed: &Expression,
98 eigenvalues: &[Expression],
99 position_coeffs: &[Expression],
100 velocity_coeffs: &[Expression],
101 ) -> Expression {
102 let x = &vars[0];
103 let t = &vars[1];
104
105 if eigenvalues.is_empty() || position_coeffs.is_empty() || velocity_coeffs.is_empty() {
106 return Expression::integer(0);
107 }
108
109 let mut terms = Vec::new();
110
111 for ((lambda, a_n), b_n) in eigenvalues
112 .iter()
113 .zip(position_coeffs.iter())
114 .zip(velocity_coeffs.iter())
115 {
116 let spatial = Expression::function(
117 "sin",
118 vec![Expression::mul(vec![
119 lambda.clone(),
120 Expression::symbol(x.clone()),
121 ])],
122 );
123
124 let omega = Expression::mul(vec![
125 lambda.clone(),
126 wave_speed.clone(),
127 Expression::symbol(t.clone()),
128 ]);
129
130 let cos_term = Expression::mul(vec![
131 a_n.clone(),
132 Expression::function("cos", vec![omega.clone()]),
133 ]);
134
135 let sin_term =
136 Expression::mul(vec![b_n.clone(), Expression::function("sin", vec![omega])]);
137
138 let temporal = Expression::add(vec![cos_term, sin_term]);
139
140 let term = Expression::mul(vec![temporal, spatial]);
141 terms.push(term);
142 }
143
144 Expression::add(terms)
145 }
146}
147
148impl PDESolver for WaveEquationSolver {
149 fn solve(&self, pde: &Pde) -> PDEResult {
150 use crate::expr;
151
152 let wave_speed = expr!(1);
153 let ic_pos = InitialCondition::value(expr!(1));
154 let ic_vel = InitialCondition::derivative(expr!(0));
155
156 let result = self.solve_wave_equation_1d(pde, &wave_speed, &[], &ic_pos, &ic_vel)?;
157
158 let mut all_coeffs = result.position_coefficients.clone();
159 all_coeffs.extend(result.velocity_coefficients.clone());
160
161 Ok(PDESolution::wave(
162 result.solution,
163 result.wave_speed,
164 result.eigenvalues,
165 all_coeffs,
166 ))
167 }
168
169 fn can_solve(&self, pde_type: PdeType) -> bool {
170 matches!(pde_type, PdeType::Hyperbolic)
171 }
172
173 fn priority(&self) -> u8 {
174 100
175 }
176
177 fn name(&self) -> &'static str {
178 "Wave Equation Solver"
179 }
180
181 fn description(&self) -> &'static str {
182 "Solves wave equation ∂²u/∂t² = c²∇²u using separation of variables and Fourier series"
183 }
184}
185
186impl Default for WaveEquationSolver {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192pub fn solve_wave_equation_1d(
194 pde: &Pde,
195 wave_speed: &Expression,
196 boundary_conditions: &[BoundaryCondition],
197 initial_position: &InitialCondition,
198 initial_velocity: &InitialCondition,
199) -> Result<WaveSolution, String> {
200 WaveEquationSolver::new()
201 .solve_wave_equation_1d(
202 pde,
203 wave_speed,
204 boundary_conditions,
205 initial_position,
206 initial_velocity,
207 )
208 .map_err(|e| format!("{:?}", e))
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::calculus::pde::types::BoundaryLocation;
    use crate::{expr, symbol};

    /// Builds a PDE for `u` over `vars`; the equation body is a placeholder.
    fn pde_over(vars: Vec<Symbol>) -> Pde {
        let u = symbol!(u);
        let equation = expr!(u);
        Pde::new(equation, u, vars)
    }

    /// Placeholder initial conditions (value `1`, derivative `0`), matching
    /// those used by `PDESolver::solve`.
    fn placeholder_initial_conditions() -> (InitialCondition, InitialCondition) {
        (
            InitialCondition::value(expr!(1)),
            InitialCondition::derivative(expr!(0)),
        )
    }

    /// Dirichlet condition with value `0` at `variable = value`.
    fn zero_dirichlet(variable: Symbol, value: Expression) -> BoundaryCondition {
        BoundaryCondition::dirichlet(expr!(0), BoundaryLocation::Simple { variable, value })
    }

    #[test]
    fn test_wave_solver_creation() {
        let solver = WaveEquationSolver::new();
        assert_eq!(solver.name(), "Wave Equation Solver");
        assert_eq!(solver.priority(), 100);
    }

    #[test]
    fn test_default_matches_new() {
        assert_eq!(
            WaveEquationSolver::default().max_terms,
            WaveEquationSolver::new().max_terms
        );
    }

    #[test]
    fn test_with_max_terms_sets_limit() {
        assert_eq!(WaveEquationSolver::with_max_terms(3).max_terms, 3);
    }

    #[test]
    fn test_wave_solver_can_solve() {
        let solver = WaveEquationSolver::new();
        assert!(solver.can_solve(PdeType::Hyperbolic));
        assert!(!solver.can_solve(PdeType::Parabolic));
        assert!(!solver.can_solve(PdeType::Elliptic));
    }

    #[test]
    fn test_solve_wave_equation_1d_basic() {
        let x = symbol!(x);
        let t = symbol!(t);
        let pde = pde_over(vec![x.clone(), t]);
        let c = expr!(1);
        let bcs = [
            zero_dirichlet(x.clone(), expr!(0)),
            zero_dirichlet(x, expr!(1)),
        ];
        let (ic_pos, ic_vel) = placeholder_initial_conditions();

        let solution = WaveEquationSolver::new()
            .solve_wave_equation_1d(&pde, &c, &bcs, &ic_pos, &ic_vel)
            .expect("well-formed 1D wave problem should solve");

        assert_eq!(solution.wave_speed, c);
        assert!(!solution.eigenvalues.is_empty());
        assert!(!solution.position_coefficients.is_empty());
        assert!(!solution.velocity_coefficients.is_empty());
    }

    #[test]
    fn test_solve_wave_equation_wrong_dimensions() {
        let pde = pde_over(vec![symbol!(x)]);
        let c = expr!(1);
        let (ic_pos, ic_vel) = placeholder_initial_conditions();

        let result =
            WaveEquationSolver::new().solve_wave_equation_1d(&pde, &c, &[], &ic_pos, &ic_vel);
        assert!(matches!(result, Err(PDEError::InvalidForm { .. })));
    }

    #[test]
    fn test_construct_wave_solution() {
        let solver = WaveEquationSolver::new();
        let vars = vec![symbol!(x), symbol!(t)];
        let c = expr!(1);
        let eigenvalues = vec![expr!(1)];
        let a_coeffs = vec![Expression::symbol(symbol!(A_1))];
        let b_coeffs = vec![Expression::symbol(symbol!(B_1))];

        let solution =
            solver.construct_wave_solution(&vars, &c, &eigenvalues, &a_coeffs, &b_coeffs);
        assert!(
            matches!(solution, Expression::Add(_) | Expression::Mul(_)),
            "Expected addition or multiplication expression for wave solution"
        );
    }

    #[test]
    fn test_construct_wave_solution_without_terms_is_zero() {
        let solver = WaveEquationSolver::new();
        let vars = vec![symbol!(x), symbol!(t)];
        let c = expr!(1);

        let solution = solver.construct_wave_solution(&vars, &c, &[], &[], &[]);
        assert_eq!(solution, Expression::integer(0));
    }

    #[test]
    fn test_wave_solution_structure() {
        let x = symbol!(x);
        let t = symbol!(t);
        let pde = pde_over(vec![x.clone(), t]);
        let c = expr!(1);
        let bcs = [zero_dirichlet(x, expr!(0))];
        let (ic_pos, ic_vel) = placeholder_initial_conditions();

        let solution = WaveEquationSolver::new()
            .solve_wave_equation_1d(&pde, &c, &bcs, &ic_pos, &ic_vel)
            .expect("single-boundary 1D wave problem should solve");

        assert_eq!(
            solution.eigenvalues.len(),
            solution.position_coefficients.len()
        );
        assert_eq!(
            solution.eigenvalues.len(),
            solution.velocity_coefficients.len()
        );
    }

    #[test]
    fn test_wave_solution_clone() {
        let solution = WaveSolution {
            solution: expr!(1),
            wave_speed: expr!(1),
            eigenvalues: vec![expr!(1)],
            position_coefficients: vec![expr!(1)],
            velocity_coefficients: vec![expr!(1)],
        };

        let cloned = solution.clone();
        assert_eq!(cloned, solution);
    }

    #[test]
    fn test_pde_solver_trait() {
        let solver = WaveEquationSolver::new();
        let pde = pde_over(vec![symbol!(x), symbol!(t)]);

        let result = solver.solve(&pde);
        assert!(result.is_ok());
    }

    #[test]
    fn test_legacy_function() {
        let pde = pde_over(vec![symbol!(x), symbol!(t)]);
        let c = expr!(1);
        let (ic_pos, ic_vel) = placeholder_initial_conditions();

        let result = solve_wave_equation_1d(&pde, &c, &[], &ic_pos, &ic_vel);
        assert!(result.is_ok());
    }

    #[test]
    fn test_legacy_function_reports_error() {
        let pde = pde_over(vec![symbol!(x)]);
        let c = expr!(1);
        let (ic_pos, ic_vel) = placeholder_initial_conditions();

        let err = solve_wave_equation_1d(&pde, &c, &[], &ic_pos, &ic_vel)
            .expect_err("one independent variable must be rejected");
        assert!(!err.is_empty());
    }
}