differential_equations/methods/bvp/shooting/
single.rs1use crate::{
2 bvp::Boundary,
3 error::Error,
4 interpolate::Interpolation,
5 linalg::{Matrix, lin_solve, lu_decomp},
6 methods::{ToleranceConfig, bvp::BVPMethod},
7 ode::{ODE, OrdinaryNumericalMethod, solve_ode},
8 solout::{DefaultSolout, Solout},
9 solution::Solution,
10 stats::{Evals, Steps},
11 tolerance::Tolerance,
12 traits::{Real, State},
13};
14
/// Single-shooting method for two-point boundary value problems.
///
/// Wraps an initial value problem solver `M` and iteratively corrects the
/// initial state `y(t0)` with Newton steps on the boundary residual until the
/// integrated trajectory satisfies the boundary conditions.
#[derive(Debug, Clone)]
pub struct SingleShooting<M> {
    /// Maximum number of Newton iterations attempted before reporting failure.
    max_iterations: usize,
    /// Convergence threshold compared against the max-norm of the boundary residual.
    tolerance: f64,
    /// IVP solver; a fresh clone is used for every internal integration.
    ode_solver: M,
}
25
26impl<M> SingleShooting<M> {
27 pub fn new(ode_solver: M) -> Self {
29 Self {
30 max_iterations: 100,
31 tolerance: 1e-6,
32 ode_solver,
33 }
34 }
35
36 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
38 self.max_iterations = max_iterations;
39 self
40 }
41
42 pub fn tolerance(mut self, tolerance: f64) -> Self {
44 self.tolerance = tolerance;
45 self
46 }
47}
48
49impl<M, T> ToleranceConfig<T> for SingleShooting<M>
50where
51 T: Real,
52 M: ToleranceConfig<T>,
53{
54 fn rtol<V: Into<Tolerance<T>>>(mut self, rtol: V) -> Self {
55 self.ode_solver = self.ode_solver.rtol(rtol);
56 self
57 }
58
59 fn atol<V: Into<Tolerance<T>>>(mut self, atol: V) -> Self {
60 self.ode_solver = self.ode_solver.atol(atol);
61 self
62 }
63}
64
/// Thin adapter presenting a boundary value problem as a plain ODE system.
///
/// It holds only a shared borrow of the problem; the `?Sized` bound lets the
/// wrapped problem type be unsized while the adapter itself stays `Sized`.
struct BvpToOde<'p, P: ?Sized> {
    /// The borrowed problem whose right-hand side is integrated.
    problem: &'p P,
}
69
70impl<EqType, T: Real, Y: State<T>> ODE<T, Y> for BvpToOde<'_, EqType>
71where
72 EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
73{
74 #[inline]
75 fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
76 self.problem.diff(t, y, dydt);
77 }
78}
79
80impl<M, T, Y> BVPMethod<T, Y> for SingleShooting<M>
81where
82 T: Real,
83 Y: State<T>,
84 M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
85{
86 fn solve<EqType, SoloutType>(
87 &mut self,
88 problem: &EqType,
89 t0: T,
90 tf: T,
91 y_guess: &Y,
92 solout: &mut SoloutType,
93 ) -> Result<Solution<T, Y>, Error<T, Y>>
94 where
95 EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
96 SoloutType: Solout<T, Y>,
97 {
98 let dim = y_guess.len();
99 let mut y = y_guess.clone();
100 let mut residual = y_guess.zeros_like();
101 let mut jacobian = Matrix::<T>::zeros(dim, dim);
102 let mut ip = vec![0; dim];
103 let mut total_evals = Evals::new();
104 let mut total_steps = Steps::new();
105 let tolerance = T::from_f64(self.tolerance).ok_or_else(|| Error::BadInput {
106 msg: "BVP shooting tolerance cannot be represented by scalar type.".to_string(),
107 })?;
108
109 let ode_system = BvpToOde { problem };
110
111 for _ in 0..self.max_iterations {
112 let mut trial_solver = self.ode_solver.clone();
113 let mut trial_solout = DefaultSolout::new();
114 let sol = solve_ode(
115 &mut trial_solver,
116 &ode_system,
117 t0,
118 tf,
119 &y,
120 &mut trial_solout,
121 )?;
122 total_evals += sol.evals;
123 total_steps += sol.steps;
124
125 let (_, y_f) = sol.last().map_err(|err| Error::BadInput {
126 msg: format!("Internal IVP solve returned an empty solution: {err}"),
127 })?;
128
129 problem.boundary(&y, y_f, &mut residual);
130
131 if residual.max_norm() <= tolerance {
132 let mut final_solver = self.ode_solver.clone();
133 let mut solution = solve_ode(&mut final_solver, &ode_system, t0, tf, &y, solout)?;
134 solution.evals += total_evals;
135 solution.steps += total_steps;
136 return Ok(solution);
137 }
138
139 let eps = T::default_epsilon().sqrt();
140 for j in 0..dim {
141 let mut y_perturbed = y.clone();
142 let y_j = y.get_component(j);
143 let perturbation = eps * y_j.abs().max(T::one());
144 y_perturbed.set_component(j, y_j + perturbation);
145
146 let mut perturbed_solver = self.ode_solver.clone();
147 let mut perturbed_solout = DefaultSolout::new();
148 let sol_perturbed = solve_ode(
149 &mut perturbed_solver,
150 &ode_system,
151 t0,
152 tf,
153 &y_perturbed,
154 &mut perturbed_solout,
155 )?;
156 total_evals += sol_perturbed.evals;
157 total_steps += sol_perturbed.steps;
158 let (_, y_f_perturbed) = sol_perturbed.last().map_err(|err| Error::BadInput {
159 msg: format!("Internal perturbed IVP solve returned an empty solution: {err}"),
160 })?;
161
162 let mut res_perturbed = residual.clone();
163 problem.boundary(&y_perturbed, y_f_perturbed, &mut res_perturbed);
164 total_evals.jacobian += 1;
165
166 for i in 0..dim {
167 jacobian[(i, j)] =
168 (res_perturbed.get_component(i) - residual.get_component(i)) / perturbation;
169 }
170 }
171
172 let mut step = y.zeros_like();
173 for i in 0..dim {
174 step.set_component(i, -residual.get_component(i));
175 }
176
177 lu_decomp(&mut jacobian, &mut ip).map_err(|err| Error::LinearAlgebra {
178 t: t0,
179 y: y.clone(),
180 msg: err.to_string(),
181 })?;
182 lin_solve(&jacobian, &mut step, &ip);
183 total_evals.newton += 1;
184 total_evals.decompositions += 1;
185 total_evals.solves += 1;
186
187 for i in 0..dim {
188 y.set_component(i, y.get_component(i) + step.get_component(i));
189 }
190 }
191
192 Err(Error::MaxSteps {
193 t: t0,
194 y: y_guess.clone(),
195 })
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bvp::BVP,
        methods::{ExplicitRungeKutta, bvp::Shooting},
    };
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};

    /// Absolute accuracy every check in this module requires.
    const ACCURACY: f64 = 1e-5;

    /// Initial-state guess `[y(0), y'(0)]` handed to every solve.
    const GUESS: [f64; 2] = [0.0, 0.5];

    /// Panics unless `actual` lies strictly within [`ACCURACY`] of `expected`.
    fn assert_near(actual: f64, expected: f64) {
        let error = (actual - expected).abs();
        assert!(
            error < ACCURACY,
            "expected {expected}, got {actual} (error {error})"
        );
    }

    /// `y'' = -y` on `[0, π/2]` with `y(0) = 0` and `y(π/2) = target`.
    ///
    /// The exact solution is `y(t) = target · sin(t)`, so `y'(0) = target`
    /// and `y'(π/2) = 0`.
    struct HarmonicOscillatorBvp {
        target: f64,
    }

    impl ODE<f64, [f64; 2]> for HarmonicOscillatorBvp {
        fn diff(&self, _t: f64, y: &[f64; 2], dydt: &mut [f64; 2]) {
            let [position, velocity] = *y;
            *dydt = [velocity, -position];
        }
    }

    impl Boundary<f64, [f64; 2]> for HarmonicOscillatorBvp {
        fn boundary(&self, y_a: &[f64; 2], y_b: &[f64; 2], res: &mut [f64; 2]) {
            *res = [y_a[0], y_b[0] - self.target];
        }
    }

    #[test]
    fn shooting_solves_harmonic_oscillator_with_trait_api() {
        let problem = HarmonicOscillatorBvp { target: 1.0 };
        let method = Shooting::single(ExplicitRungeKutta::dop853());

        let result = BVP::ode(&problem, 0.0, FRAC_PI_2, GUESS)
            .method(method)
            .solve()
            .expect("BVP solve should converge");

        let (_, y_initial) = result.iter().next().expect("solution has an initial point");
        let (_, y_final) = result.last().expect("solution has a final point");

        assert_near(y_initial[0], 0.0);
        assert_near(y_initial[1], 1.0);
        assert_near(y_final[0], 1.0);
        assert_near(y_final[1], 0.0);
    }

    #[test]
    fn shooting_solves_harmonic_oscillator_with_closure_api() {
        let method = Shooting::single(ExplicitRungeKutta::dop853());

        let result = BVP::ode_from_fn(
            |_t, y: &[f64; 2], dydt: &mut [f64; 2]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            |y_a: &[f64; 2], y_b: &[f64; 2], res: &mut [f64; 2]| {
                res[0] = y_a[0];
                res[1] = y_b[0] - 1.0;
            },
            0.0,
            FRAC_PI_2,
            GUESS,
        )
        .method(method)
        .solve()
        .expect("BVP solve should converge");

        let (_, y_initial) = result.iter().next().expect("solution has an initial point");
        let (_, y_final) = result.last().expect("solution has a final point");

        assert_near(y_initial[1], 1.0);
        assert_near(y_final[0], 1.0);
    }

    #[test]
    fn shooting_supports_t_eval_output_for_final_trajectory() {
        let problem = HarmonicOscillatorBvp { target: 1.0 };
        let method = Shooting::single(ExplicitRungeKutta::dop853());
        let points = [0.0, FRAC_PI_4, FRAC_PI_2];

        let result = BVP::ode(&problem, 0.0, FRAC_PI_2, GUESS)
            .t_eval(points)
            .method(method)
            .solve()
            .expect("BVP solve should converge with t_eval output");

        assert_eq!(result.t, points);
        assert_eq!(result.y.len(), points.len());
        assert_near(result.y[0][1], 1.0);
        assert_near(result.y[2][0], 1.0);
    }

    #[test]
    fn shooting_reports_internal_ivp_and_newton_statistics() {
        let problem = HarmonicOscillatorBvp { target: 1.0 };
        let solver = ExplicitRungeKutta::dop853().rtol(1e-10).atol(1e-12);
        let method = Shooting::single(solver);

        let result = BVP::ode(&problem, 0.0, FRAC_PI_2, GUESS)
            .method(method)
            .solve()
            .expect("BVP solve should converge");

        let evals = &result.evals;
        assert!(evals.function > 0);
        assert!(evals.jacobian > 0);
        assert!(evals.newton > 0);
        assert_eq!(evals.decompositions, evals.newton);
        assert_eq!(evals.solves, evals.newton);
        assert!(result.steps.total() > 0);
    }
}