1use crate::{
2 bvp::Boundary,
3 error::Error,
4 interpolate::Interpolation,
5 linalg::{Matrix, lin_solve, lu_decomp},
6 methods::{ToleranceConfig, bvp::BVPMethod},
7 ode::{ODE, OrdinaryNumericalMethod, solve_ode},
8 solout::{DefaultSolout, Solout, TEvalSolout},
9 solution::Solution,
10 stats::{Evals, Steps},
11 tolerance::Tolerance,
12 traits::{Real, State},
13};
14
/// Multiple-shooting solver for two-point boundary-value problems.
///
/// The interval is split into `segments` sub-intervals. Each one is integrated with a clone
/// of `ode_solver`. The node states are then Newton-iterated until the combined boundary and
/// continuity residual is small enough.
#[derive(Clone, Debug)]
pub struct MultipleShooting<M> {
    /// Number of shooting sub-intervals (`new` defaults this to 4).
    segments: usize,
    /// Upper bound on Newton iterations before solving fails.
    max_iterations: usize,
    /// Convergence threshold on the max-norm of the shooting residual.
    /// Converted to the scalar type `T` when solving.
    tolerance: f64,
    /// Inner IVP solver; a fresh clone is used for every segment integration.
    ode_solver: M,
}
27
28impl<M> MultipleShooting<M> {
29 pub fn new(ode_solver: M) -> Self {
31 Self {
32 segments: 4,
33 max_iterations: 100,
34 tolerance: 1e-6,
35 ode_solver,
36 }
37 }
38
39 pub fn segments(mut self, segments: usize) -> Self {
41 self.segments = segments;
42 self
43 }
44
45 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
47 self.max_iterations = max_iterations;
48 self
49 }
50
51 pub fn tolerance(mut self, tolerance: f64) -> Self {
53 self.tolerance = tolerance;
54 self
55 }
56}
57
58impl<M, T> ToleranceConfig<T> for MultipleShooting<M>
59where
60 T: Real,
61 M: ToleranceConfig<T>,
62{
63 fn rtol<V: Into<Tolerance<T>>>(mut self, rtol: V) -> Self {
64 self.ode_solver = self.ode_solver.rtol(rtol);
65 self
66 }
67
68 fn atol<V: Into<Tolerance<T>>>(mut self, atol: V) -> Self {
69 self.ode_solver = self.ode_solver.atol(atol);
70 self
71 }
72}
73
/// Thin adapter that presents a boundary-value problem to the IVP machinery.
///
/// It only borrows the problem, so every segment integration reuses the caller's
/// definition without copying it.
struct BvpToOde<'p, Problem: ?Sized> {
    /// The wrapped boundary-value problem.
    problem: &'p Problem,
}
77
78impl<EqType, T: Real, Y: State<T>> ODE<T, Y> for BvpToOde<'_, EqType>
79where
80 EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
81{
82 #[inline]
83 fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
84 self.problem.diff(t, y, dydt);
85 }
86}
87
88impl<M> MultipleShooting<M> {
89 fn mesh<T, Y>(&self, t0: T, tf: T) -> Result<Vec<T>, Error<T, Y>>
90 where
91 T: Real,
92 Y: State<T>,
93 {
94 if self.segments == 0 {
95 return Err(Error::BadInput {
96 msg: "Multiple shooting requires at least one segment.".to_string(),
97 });
98 }
99
100 let mut mesh = Vec::with_capacity(self.segments + 1);
101 let span = tf - t0;
102 let denominator = T::from_usize(self.segments).ok_or_else(|| Error::BadInput {
103 msg: "Could not represent multiple-shooting segment count as scalar type.".to_string(),
104 })?;
105
106 for i in 0..=self.segments {
107 let numerator = T::from_usize(i).ok_or_else(|| Error::BadInput {
108 msg: "Could not represent multiple-shooting mesh index as scalar type.".to_string(),
109 })?;
110 mesh.push(t0 + span * numerator / denominator);
111 }
112
113 Ok(mesh)
114 }
115
116 fn initial_nodes<EqType, T, Y>(
117 &self,
118 ode_system: &BvpToOde<'_, EqType>,
119 mesh: &[T],
120 t0: T,
121 tf: T,
122 y_guess: &Y,
123 ) -> Result<(Vec<Y>, Evals, Steps), Error<T, Y>>
124 where
125 T: Real,
126 Y: State<T>,
127 M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
128 EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
129 {
130 let mut solver = self.ode_solver.clone();
131 let mut solout = TEvalSolout::new(mesh, t0, tf);
132 let solution = solve_ode(&mut solver, ode_system, t0, tf, y_guess, &mut solout)?;
133 if solution.y.len() != mesh.len() {
134 return Err(Error::BadInput {
135 msg: "Initial multiple-shooting IVP did not produce every mesh node.".to_string(),
136 });
137 }
138
139 Ok((solution.y, solution.evals, solution.steps))
140 }
141
142 fn residual<EqType, T, Y>(
143 &self,
144 problem: &EqType,
145 ode_system: &BvpToOde<'_, EqType>,
146 mesh: &[T],
147 nodes: &[Y],
148 dim: usize,
149 ) -> Result<(Vec<T>, Evals, Steps), Error<T, Y>>
150 where
151 T: Real,
152 Y: State<T>,
153 M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
154 EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
155 {
156 let mut residual = Vec::with_capacity(nodes.len() * dim);
157 let mut boundary_residual = nodes[0].zeros_like();
158 let mut total_evals = Evals::new();
159 let mut total_steps = Steps::new();
160
161 problem.boundary(&nodes[0], &nodes[nodes.len() - 1], &mut boundary_residual);
162 for i in 0..dim {
163 residual.push(boundary_residual.get_component(i));
164 }
165
166 for segment_idx in 0..self.segments {
167 let mut solver = self.ode_solver.clone();
168 let mut solout = DefaultSolout::new();
169 let solution = solve_ode(
170 &mut solver,
171 ode_system,
172 mesh[segment_idx],
173 mesh[segment_idx + 1],
174 &nodes[segment_idx],
175 &mut solout,
176 )?;
177 total_evals += solution.evals;
178 total_steps += solution.steps;
179
180 let (_, y_end) = solution.last().map_err(|err| Error::BadInput {
181 msg: format!("Internal multiple-shooting IVP returned an empty solution: {err}"),
182 })?;
183
184 for i in 0..dim {
185 residual.push(y_end.get_component(i) - nodes[segment_idx + 1].get_component(i));
186 }
187 }
188
189 Ok((residual, total_evals, total_steps))
190 }
191
192 fn residual_norm<T: Real>(residual: &[T]) -> T {
193 residual
194 .iter()
195 .fold(T::zero(), |max_norm, value| max_norm.max(value.abs()))
196 }
197
198 fn apply_newton_step<T, Y>(nodes: &mut [Y], step: &[T], dim: usize)
199 where
200 T: Real,
201 Y: State<T>,
202 {
203 for (node_idx, node) in nodes.iter_mut().enumerate() {
204 for component_idx in 0..dim {
205 let index = node_idx * dim + component_idx;
206 node.set_component(
207 component_idx,
208 node.get_component(component_idx) + step[index],
209 );
210 }
211 }
212 }
213}
214
215impl<M, T, Y> BVPMethod<T, Y> for MultipleShooting<M>
216where
217 T: Real,
218 Y: State<T>,
219 M: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y> + Clone,
220{
221 fn solve<EqType, SoloutType>(
222 &mut self,
223 problem: &EqType,
224 t0: T,
225 tf: T,
226 y_guess: &Y,
227 solout: &mut SoloutType,
228 ) -> Result<Solution<T, Y>, Error<T, Y>>
229 where
230 EqType: ODE<T, Y> + Boundary<T, Y> + ?Sized,
231 SoloutType: Solout<T, Y>,
232 {
233 let dim = y_guess.len();
234 let unknowns = (self.segments + 1) * dim;
235 let mesh = self.mesh(t0, tf)?;
236 let ode_system = BvpToOde { problem };
237 let tolerance = T::from_f64(self.tolerance).ok_or_else(|| Error::BadInput {
238 msg: "BVP multiple-shooting tolerance cannot be represented by scalar type."
239 .to_string(),
240 })?;
241 let eps = T::default_epsilon().sqrt();
242
243 let (mut nodes, evals, steps) = self.initial_nodes(&ode_system, &mesh, t0, tf, y_guess)?;
244 let mut total_evals = evals;
245 let mut total_steps = steps;
246 let mut jacobian = Matrix::<T>::zeros(unknowns, unknowns);
247 let mut ip = vec![0; unknowns];
248
249 for _ in 0..self.max_iterations {
250 let (residual, evals, steps) =
251 self.residual(problem, &ode_system, &mesh, &nodes, dim)?;
252 total_evals += evals;
253 total_steps += steps;
254
255 if Self::residual_norm(&residual) <= tolerance {
256 let mut final_solver = self.ode_solver.clone();
257 let mut solution =
258 solve_ode(&mut final_solver, &ode_system, t0, tf, &nodes[0], solout)?;
259 solution.evals += total_evals;
260 solution.steps += total_steps;
261 return Ok(solution);
262 }
263
264 for j in 0..unknowns {
265 let node_idx = j / dim;
266 let component_idx = j % dim;
267 let mut perturbed_nodes = nodes.clone();
268 let y_j = perturbed_nodes[node_idx].get_component(component_idx);
269 let perturbation = eps * y_j.abs().max(T::one());
270 perturbed_nodes[node_idx].set_component(component_idx, y_j + perturbation);
271
272 let (perturbed_residual, evals, steps) =
273 self.residual(problem, &ode_system, &mesh, &perturbed_nodes, dim)?;
274 total_evals += evals;
275 total_steps += steps;
276 total_evals.jacobian += 1;
277
278 for i in 0..unknowns {
279 jacobian[(i, j)] = (perturbed_residual[i] - residual[i]) / perturbation;
280 }
281 }
282
283 let mut step = residual.iter().map(|value| -*value).collect::<Vec<_>>();
284 lu_decomp(&mut jacobian, &mut ip).map_err(|err| Error::LinearAlgebra {
285 t: t0,
286 y: nodes[0].clone(),
287 msg: err.to_string(),
288 })?;
289 lin_solve(&jacobian, &mut step, &ip);
290 total_evals.newton += 1;
291 total_evals.decompositions += 1;
292 total_evals.solves += 1;
293 Self::apply_newton_step(&mut nodes, &step, dim);
294 }
295
296 Err(Error::MaxSteps {
297 t: t0,
298 y: nodes.first().cloned().unwrap_or_else(|| y_guess.clone()),
299 })
300 }
301}