Skip to main content

differential_equations/methods/bvp/shooting/
single.rs

1use crate::{
2    bvp::Boundary,
3    error::Error,
4    interpolate::Interpolation,
5    linalg::{Matrix, lin_solve, lu_decomp},
6    methods::{ToleranceConfig, bvp::BVPMethod},
7    ode::{ODE, OrdinaryNumericalMethod, solve_ode},
8    solout::{DefaultSolout, Solout},
9    solution::Solution,
10    stats::{Evals, Steps},
11    tolerance::Tolerance,
12    traits::{Real, State},
13};
14
/// Single-shooting method for ODE boundary value problems.
///
/// The BVP is reduced to a sequence of initial value problems: starting from a
/// guess for the initial state, the IVP is integrated with `ode_solver`, and
/// Newton iteration adjusts the initial state until the endpoint boundary
/// residual is small in the infinity norm.
///
/// Defaults set by [`SingleShooting::new`]: at most 100 Newton iterations and a
/// residual tolerance of `1e-6`.
#[derive(Clone, Debug)]
pub struct SingleShooting<M> {
    /// Maximum number of Newton iterations before giving up.
    max_iterations: usize,
    /// Infinity-norm tolerance on the boundary residual.
    tolerance: f64,
    /// IVP solver; a fresh clone is used for every internal integration.
    ode_solver: M,
}

impl<M> SingleShooting<M> {
    /// Create a single-shooting method from an ODE IVP solver.
    ///
    /// Starts with `max_iterations = 100` and `tolerance = 1e-6`; use the
    /// builder methods to change them.
    #[must_use]
    pub fn new(ode_solver: M) -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
            ode_solver,
        }
    }

    /// Set the maximum number of Newton iterations.
    ///
    /// Returns the updated method; the receiver is consumed, so the result must
    /// be used for the setting to take effect.
    #[must_use]
    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Set the infinity-norm tolerance for the boundary residual.
    ///
    /// Returns the updated method; the receiver is consumed, so the result must
    /// be used for the setting to take effect.
    #[must_use]
    pub fn tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }
}
48
49impl<M, T> ToleranceConfig<T> for SingleShooting<M>
50where
51    T: Real,
52    M: ToleranceConfig<T>,
53{
54    fn rtol<V: Into<Tolerance<T>>>(mut self, rtol: V) -> Self {
55        self.ode_solver = self.ode_solver.rtol(rtol);
56        self
57    }
58
59    fn atol<V: Into<Tolerance<T>>>(mut self, atol: V) -> Self {
60        self.ode_solver = self.ode_solver.atol(atol);
61        self
62    }
63}
64
/// Wrapper to adapt a BVP definition to the ODE trait for internal IVP solves.
///
/// Only a shared reference to the problem is stored, so `EqType` may be
/// unsized (`?Sized`), e.g. a trait object.
struct BvpToOde<'a, EqType: ?Sized> {
    /// The boundary value problem whose right-hand side is integrated.
    problem: &'a EqType,
}
69
70impl<EqType, T: Real, Y: State<T>> ODE<T, Y> for BvpToOde<'_, EqType>
71where
72    EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
73{
74    #[inline]
75    fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
76        self.problem.diff(t, y, dydt);
77    }
78}
79
80impl<M, T, Y> BVPMethod<T, Y> for SingleShooting<M>
81where
82    T: Real,
83    Y: State<T>,
84    M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
85{
86    fn solve<EqType, SoloutType>(
87        &mut self,
88        problem: &EqType,
89        t0: T,
90        tf: T,
91        y_guess: &Y,
92        solout: &mut SoloutType,
93    ) -> Result<Solution<T, Y>, Error<T, Y>>
94    where
95        EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
96        SoloutType: Solout<T, Y>,
97    {
98        let dim = y_guess.len();
99        let mut y = y_guess.clone();
100        let mut residual = y_guess.zeros_like();
101        let mut jacobian = Matrix::<T>::zeros(dim, dim);
102        let mut ip = vec![0; dim];
103        let mut total_evals = Evals::new();
104        let mut total_steps = Steps::new();
105        let tolerance = T::from_f64(self.tolerance).ok_or_else(|| Error::BadInput {
106            msg: "BVP shooting tolerance cannot be represented by scalar type.".to_string(),
107        })?;
108
109        let ode_system = BvpToOde { problem };
110
111        for _ in 0..self.max_iterations {
112            let mut trial_solver = self.ode_solver.clone();
113            let mut trial_solout = DefaultSolout::new();
114            let sol = solve_ode(
115                &mut trial_solver,
116                &ode_system,
117                t0,
118                tf,
119                &y,
120                &mut trial_solout,
121            )?;
122            total_evals += sol.evals;
123            total_steps += sol.steps;
124
125            let (_, y_f) = sol.last().map_err(|err| Error::BadInput {
126                msg: format!("Internal IVP solve returned an empty solution: {err}"),
127            })?;
128
129            problem.boundary(&y, y_f, &mut residual);
130
131            if residual.max_norm() <= tolerance {
132                let mut final_solver = self.ode_solver.clone();
133                let mut solution = solve_ode(&mut final_solver, &ode_system, t0, tf, &y, solout)?;
134                solution.evals += total_evals;
135                solution.steps += total_steps;
136                return Ok(solution);
137            }
138
139            let eps = T::default_epsilon().sqrt();
140            for j in 0..dim {
141                let mut y_perturbed = y.clone();
142                let y_j = y.get_component(j);
143                let perturbation = eps * y_j.abs().max(T::one());
144                y_perturbed.set_component(j, y_j + perturbation);
145
146                let mut perturbed_solver = self.ode_solver.clone();
147                let mut perturbed_solout = DefaultSolout::new();
148                let sol_perturbed = solve_ode(
149                    &mut perturbed_solver,
150                    &ode_system,
151                    t0,
152                    tf,
153                    &y_perturbed,
154                    &mut perturbed_solout,
155                )?;
156                total_evals += sol_perturbed.evals;
157                total_steps += sol_perturbed.steps;
158                let (_, y_f_perturbed) = sol_perturbed.last().map_err(|err| Error::BadInput {
159                    msg: format!("Internal perturbed IVP solve returned an empty solution: {err}"),
160                })?;
161
162                let mut res_perturbed = residual.clone();
163                problem.boundary(&y_perturbed, y_f_perturbed, &mut res_perturbed);
164                total_evals.jacobian += 1;
165
166                for i in 0..dim {
167                    jacobian[(i, j)] =
168                        (res_perturbed.get_component(i) - residual.get_component(i)) / perturbation;
169                }
170            }
171
172            let mut step = y.zeros_like();
173            for i in 0..dim {
174                step.set_component(i, -residual.get_component(i));
175            }
176
177            lu_decomp(&mut jacobian, &mut ip).map_err(|err| Error::LinearAlgebra {
178                t: t0,
179                y: y.clone(),
180                msg: err.to_string(),
181            })?;
182            lin_solve(&jacobian, &mut step, &ip);
183            total_evals.newton += 1;
184            total_evals.decompositions += 1;
185            total_evals.solves += 1;
186
187            for i in 0..dim {
188                y.set_component(i, y.get_component(i) + step.get_component(i));
189            }
190        }
191
192        Err(Error::MaxSteps {
193            t: t0,
194            y: y_guess.clone(),
195        })
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bvp::BVP,
        methods::{ExplicitRungeKutta, bvp::Shooting},
    };
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};

    /// Absolute tolerance for every check on the converged trajectory.
    const TOL: f64 = 1e-5;

    /// Initial guess with slope `x'(0) = 0.5`, away from the exact slope `1`.
    const GUESS: [f64; 2] = [0.0, 0.5];

    /// Panic unless `actual` lies strictly within [`TOL`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected} ± {TOL}, got {actual}"
        );
    }

    /// `x'' = -x` with `x(0) = 0` and `x(tf) = target`, written as a
    /// first-order system in `[x, x']`.
    struct HarmonicOscillatorBvp {
        target: f64,
    }

    impl ODE<f64, [f64; 2]> for HarmonicOscillatorBvp {
        fn diff(&self, _t: f64, y: &[f64; 2], dydt: &mut [f64; 2]) {
            let [x, v] = *y;
            *dydt = [v, -x];
        }
    }

    impl Boundary<f64, [f64; 2]> for HarmonicOscillatorBvp {
        fn boundary(&self, y_a: &[f64; 2], y_b: &[f64; 2], res: &mut [f64; 2]) {
            *res = [y_a[0], y_b[0] - self.target];
        }
    }

    #[test]
    fn shooting_solves_harmonic_oscillator_with_trait_api() {
        let problem = HarmonicOscillatorBvp { target: 1.0 };
        let method = Shooting::single(ExplicitRungeKutta::dop853());

        let result = BVP::ode(&problem, 0.0, FRAC_PI_2, GUESS)
            .method(method)
            .solve()
            .expect("BVP solve should converge");

        let (_, y_initial) = result.iter().next().expect("solution has an initial point");
        let (_, y_final) = result.last().expect("solution has a final point");

        // Exact solution: x(t) = sin(t).
        assert_close(y_initial[0], 0.0);
        assert_close(y_initial[1], 1.0);
        assert_close(y_final[0], 1.0);
        assert_close(y_final[1], 0.0);
    }

    #[test]
    fn shooting_solves_harmonic_oscillator_with_closure_api() {
        let method = Shooting::single(ExplicitRungeKutta::dop853());

        let result = BVP::ode_from_fn(
            |_t, y: &[f64; 2], dydt: &mut [f64; 2]| {
                *dydt = [y[1], -y[0]];
            },
            |y_a: &[f64; 2], y_b: &[f64; 2], res: &mut [f64; 2]| {
                *res = [y_a[0], y_b[0] - 1.0];
            },
            0.0,
            FRAC_PI_2,
            GUESS,
        )
        .method(method)
        .solve()
        .expect("BVP solve should converge");

        let (_, y_initial) = result.iter().next().expect("solution has an initial point");
        let (_, y_final) = result.last().expect("solution has a final point");

        assert_close(y_initial[1], 1.0);
        assert_close(y_final[0], 1.0);
    }

    #[test]
    fn shooting_supports_t_eval_output_for_final_trajectory() {
        let problem = HarmonicOscillatorBvp { target: 1.0 };
        let method = Shooting::single(ExplicitRungeKutta::dop853());
        let points = [0.0, FRAC_PI_4, FRAC_PI_2];

        let result = BVP::ode(&problem, 0.0, FRAC_PI_2, GUESS)
            .t_eval(points)
            .method(method)
            .solve()
            .expect("BVP solve should converge with t_eval output");

        assert_eq!(result.t, points);
        assert_eq!(result.y.len(), points.len());
        assert_close(result.y[0][1], 1.0);
        assert_close(result.y[2][0], 1.0);
    }

    #[test]
    fn shooting_reports_internal_ivp_and_newton_statistics() {
        let problem = HarmonicOscillatorBvp { target: 1.0 };
        let method = Shooting::single(ExplicitRungeKutta::dop853().rtol(1e-10).atol(1e-12));

        let result = BVP::ode(&problem, 0.0, FRAC_PI_2, GUESS)
            .method(method)
            .solve()
            .expect("BVP solve should converge");

        let evals = &result.evals;
        assert!(evals.function > 0);
        assert!(evals.jacobian > 0);
        assert!(evals.newton > 0);
        // One LU factorization and one linear solve per Newton iteration.
        assert_eq!(evals.decompositions, evals.newton);
        assert_eq!(evals.solves, evals.newton);
        assert!(result.steps.total() > 0);
    }
}