Skip to main content

differential_equations/methods/bvp/shooting/
multiple.rs

1use crate::{
2    bvp::Boundary,
3    error::Error,
4    interpolate::Interpolation,
5    linalg::{Matrix, lin_solve, lu_decomp},
6    methods::{ToleranceConfig, bvp::BVPMethod},
7    ode::{ODE, OrdinaryNumericalMethod, solve_ode},
8    solout::{DefaultSolout, Solout, TEvalSolout},
9    solution::Solution,
10    stats::{Evals, Steps},
11    tolerance::Tolerance,
12    traits::{Real, State},
13};
14
/// Boundary value problem solver based on multiple shooting.
///
/// The integration interval is divided into shooting segments, an initial
/// value problem is integrated across every segment, and a Newton iteration
/// adjusts the node states until the endpoint boundary conditions hold and
/// neighboring segments join continuously.
#[derive(Clone, Debug)]
pub struct MultipleShooting<M> {
    /// Number of shooting segments the interval is divided into.
    segments: usize,
    /// Upper bound on the number of Newton iterations.
    max_iterations: usize,
    /// Infinity-norm tolerance applied to the full multiple-shooting residual.
    tolerance: f64,
    /// ODE initial value problem solver used to integrate the segments.
    ode_solver: M,
}
27
28impl<M> MultipleShooting<M> {
29    /// Create a multiple-shooting method from an ODE IVP solver.
30    pub fn new(ode_solver: M) -> Self {
31        Self {
32            segments: 4,
33            max_iterations: 100,
34            tolerance: 1e-6,
35            ode_solver,
36        }
37    }
38
39    /// Set the number of shooting segments.
40    pub fn segments(mut self, segments: usize) -> Self {
41        self.segments = segments;
42        self
43    }
44
45    /// Set the maximum number of Newton iterations.
46    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
47        self.max_iterations = max_iterations;
48        self
49    }
50
51    /// Set the infinity-norm tolerance for the full multiple-shooting residual.
52    pub fn tolerance(mut self, tolerance: f64) -> Self {
53        self.tolerance = tolerance;
54        self
55    }
56}
57
58impl<M, T> ToleranceConfig<T> for MultipleShooting<M>
59where
60    T: Real,
61    M: ToleranceConfig<T>,
62{
63    fn rtol<V: Into<Tolerance<T>>>(mut self, rtol: V) -> Self {
64        self.ode_solver = self.ode_solver.rtol(rtol);
65        self
66    }
67
68    fn atol<V: Into<Tolerance<T>>>(mut self, atol: V) -> Self {
69        self.ode_solver = self.ode_solver.atol(atol);
70        self
71    }
72}
73
/// Borrowing adapter around a boundary value problem.
///
/// Holds a reference to the problem so that it can be passed where only its
/// differential equation is needed.
struct BvpToOde<'a, EqType>
where
    EqType: ?Sized,
{
    /// The wrapped boundary value problem.
    problem: &'a EqType,
}
77
78impl<EqType, T: Real, Y: State<T>> ODE<T, Y> for BvpToOde<'_, EqType>
79where
80    EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
81{
82    #[inline]
83    fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
84        self.problem.diff(t, y, dydt);
85    }
86}
87
88impl<M> MultipleShooting<M> {
89    fn mesh<T, Y>(&self, t0: T, tf: T) -> Result<Vec<T>, Error<T, Y>>
90    where
91        T: Real,
92        Y: State<T>,
93    {
94        if self.segments == 0 {
95            return Err(Error::BadInput {
96                msg: "Multiple shooting requires at least one segment.".to_string(),
97            });
98        }
99
100        let mut mesh = Vec::with_capacity(self.segments + 1);
101        let span = tf - t0;
102        let denominator = T::from_usize(self.segments).ok_or_else(|| Error::BadInput {
103            msg: "Could not represent multiple-shooting segment count as scalar type.".to_string(),
104        })?;
105
106        for i in 0..=self.segments {
107            let numerator = T::from_usize(i).ok_or_else(|| Error::BadInput {
108                msg: "Could not represent multiple-shooting mesh index as scalar type.".to_string(),
109            })?;
110            mesh.push(t0 + span * numerator / denominator);
111        }
112
113        Ok(mesh)
114    }
115
116    fn initial_nodes<EqType, T, Y>(
117        &self,
118        ode_system: &BvpToOde<'_, EqType>,
119        mesh: &[T],
120        t0: T,
121        tf: T,
122        y_guess: &Y,
123    ) -> Result<(Vec<Y>, Evals, Steps), Error<T, Y>>
124    where
125        T: Real,
126        Y: State<T>,
127        M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
128        EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
129    {
130        let mut solver = self.ode_solver.clone();
131        let mut solout = TEvalSolout::new(mesh, t0, tf);
132        let solution = solve_ode(&mut solver, ode_system, t0, tf, y_guess, &mut solout)?;
133        if solution.y.len() != mesh.len() {
134            return Err(Error::BadInput {
135                msg: "Initial multiple-shooting IVP did not produce every mesh node.".to_string(),
136            });
137        }
138
139        Ok((solution.y, solution.evals, solution.steps))
140    }
141
142    fn residual<EqType, T, Y>(
143        &self,
144        problem: &EqType,
145        ode_system: &BvpToOde<'_, EqType>,
146        mesh: &[T],
147        nodes: &[Y],
148        dim: usize,
149    ) -> Result<(Vec<T>, Evals, Steps), Error<T, Y>>
150    where
151        T: Real,
152        Y: State<T>,
153        M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
154        EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
155    {
156        let mut residual = Vec::with_capacity(nodes.len() * dim);
157        let mut boundary_residual = nodes[0].zeros_like();
158        let mut total_evals = Evals::new();
159        let mut total_steps = Steps::new();
160
161        problem.boundary(&nodes[0], &nodes[nodes.len() - 1], &mut boundary_residual);
162        for i in 0..dim {
163            residual.push(boundary_residual.get_component(i));
164        }
165
166        for segment_idx in 0..self.segments {
167            let mut solver = self.ode_solver.clone();
168            let mut solout = DefaultSolout::new();
169            let solution = solve_ode(
170                &mut solver,
171                ode_system,
172                mesh[segment_idx],
173                mesh[segment_idx + 1],
174                &nodes[segment_idx],
175                &mut solout,
176            )?;
177            total_evals += solution.evals;
178            total_steps += solution.steps;
179
180            let (_, y_end) = solution.last().map_err(|err| Error::BadInput {
181                msg: format!("Internal multiple-shooting IVP returned an empty solution: {err}"),
182            })?;
183
184            for i in 0..dim {
185                residual.push(y_end.get_component(i) - nodes[segment_idx + 1].get_component(i));
186            }
187        }
188
189        Ok((residual, total_evals, total_steps))
190    }
191
192    fn residual_norm<T: Real>(residual: &[T]) -> T {
193        residual
194            .iter()
195            .fold(T::zero(), |max_norm, value| max_norm.max(value.abs()))
196    }
197
198    fn apply_newton_step<T, Y>(nodes: &mut [Y], step: &[T], dim: usize)
199    where
200        T: Real,
201        Y: State<T>,
202    {
203        for (node_idx, node) in nodes.iter_mut().enumerate() {
204            for component_idx in 0..dim {
205                let index = node_idx * dim + component_idx;
206                node.set_component(
207                    component_idx,
208                    node.get_component(component_idx) + step[index],
209                );
210            }
211        }
212    }
213}
214
215impl<M, T, Y> BVPMethod<T, Y> for MultipleShooting<M>
216where
217    T: Real,
218    Y: State<T>,
219    M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
220{
221    fn solve<EqType, SoloutType>(
222        &mut self,
223        problem: &EqType,
224        t0: T,
225        tf: T,
226        y_guess: &Y,
227        solout: &mut SoloutType,
228    ) -> Result<Solution<T, Y>, Error<T, Y>>
229    where
230        EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
231        SoloutType: Solout<T, Y>,
232    {
233        let dim = y_guess.len();
234        let unknowns = (self.segments + 1) * dim;
235        let mesh = self.mesh(t0, tf)?;
236        let ode_system = BvpToOde { problem };
237        let tolerance = T::from_f64(self.tolerance).ok_or_else(|| Error::BadInput {
238            msg: "BVP multiple-shooting tolerance cannot be represented by scalar type."
239                .to_string(),
240        })?;
241        let eps = T::default_epsilon().sqrt();
242
243        let (mut nodes, evals, steps) = self.initial_nodes(&ode_system, &mesh, t0, tf, y_guess)?;
244        let mut total_evals = evals;
245        let mut total_steps = steps;
246        let mut jacobian = Matrix::<T>::zeros(unknowns, unknowns);
247        let mut ip = vec![0; unknowns];
248
249        for _ in 0..self.max_iterations {
250            let (residual, evals, steps) =
251                self.residual(problem, &ode_system, &mesh, &nodes, dim)?;
252            total_evals += evals;
253            total_steps += steps;
254
255            if Self::residual_norm(&residual) <= tolerance {
256                let mut final_solver = self.ode_solver.clone();
257                let mut solution =
258                    solve_ode(&mut final_solver, &ode_system, t0, tf, &nodes[0], solout)?;
259                solution.evals += total_evals;
260                solution.steps += total_steps;
261                return Ok(solution);
262            }
263
264            for j in 0..unknowns {
265                let node_idx = j / dim;
266                let component_idx = j % dim;
267                let mut perturbed_nodes = nodes.clone();
268                let y_j = perturbed_nodes[node_idx].get_component(component_idx);
269                let perturbation = eps * y_j.abs().max(T::one());
270                perturbed_nodes[node_idx].set_component(component_idx, y_j + perturbation);
271
272                let (perturbed_residual, evals, steps) =
273                    self.residual(problem, &ode_system, &mesh, &perturbed_nodes, dim)?;
274                total_evals += evals;
275                total_steps += steps;
276                total_evals.jacobian += 1;
277
278                for i in 0..unknowns {
279                    jacobian[(i, j)] = (perturbed_residual[i] - residual[i]) / perturbation;
280                }
281            }
282
283            let mut step = residual.iter().map(|value| -*value).collect::<Vec<_>>();
284            lu_decomp(&mut jacobian, &mut ip).map_err(|err| Error::LinearAlgebra {
285                t: t0,
286                y: nodes[0].clone(),
287                msg: err.to_string(),
288            })?;
289            lin_solve(&jacobian, &mut step, &ip);
290            total_evals.newton += 1;
291            total_evals.decompositions += 1;
292            total_evals.solves += 1;
293            Self::apply_newton_step(&mut nodes, &step, dim);
294        }
295
296        Err(Error::MaxSteps {
297            t: t0,
298            y: nodes.first().cloned().unwrap_or_else(|| y_guess.clone()),
299        })
300    }
301}