Skip to main content

scirs2_integrate/shooting/
mod.rs

1//! Shooting methods for Boundary Value Problems (BVPs)
2//!
3//! This module provides shooting-based methods for solving two-point boundary
4//! value problems (BVPs) of the form:
5//!
6//!   y'(t) = f(t, y),   t ∈ [a, b]
7//!   g(y(a), y(b)) = 0
8//!
9//! ## Methods
10//!
11//! - **Single shooting**: Parameterize y(a) with free parameters, integrate to b,
12//!   and solve the boundary residual using Newton's method.
13//! - **Multiple shooting**: Divide \[a,b\] into subintervals, shoot over each, and
14//!   enforce continuity + boundary conditions simultaneously.
15//! - **Orthogonal collocation**: Collocate at Gaussian or Radau points within
16//!   subintervals for higher accuracy.
17//! - **Periodic orbit finder**: Find limit cycles by single shooting with period
18//!   as an additional unknown and a phase condition.
19//!
20//! ## References
21//!
22//! - Keller (1968), "Numerical Methods for Two-Point Boundary Value Problems"
23//! - Stoer & Bulirsch (1980), "Introduction to Numerical Analysis"
24//! - Ascher, Mattheij, Russell (1995), "Numerical Solution of Boundary Value ODEs"
25
26use crate::error::{IntegrateError, IntegrateResult};
27use scirs2_core::ndarray::{Array1, Array2};
28
29// ---------------------------------------------------------------------------
30// Helper
31// ---------------------------------------------------------------------------
32
/// Identity helper on `f64`: returns `v` unchanged.
#[inline]
fn to_f(v: f64) -> f64 {
    v
}
37
38/// Gaussian elimination with partial pivoting (modifies A and b in place)
39fn gauss_solve(a: &mut Array2<f64>, b: &mut Array1<f64>) -> IntegrateResult<Array1<f64>> {
40    let n = b.len();
41    for col in 0..n {
42        let mut max_row = col;
43        let mut max_val = a[[col, col]].abs();
44        for row in (col + 1)..n {
45            let v = a[[row, col]].abs();
46            if v > max_val {
47                max_val = v;
48                max_row = row;
49            }
50        }
51        if max_val < 1e-300 {
52            return Err(IntegrateError::LinearSolveError(
53                "Singular matrix in shooting solve".to_string(),
54            ));
55        }
56        if max_row != col {
57            for j in col..n {
58                let tmp = a[[col, j]];
59                a[[col, j]] = a[[max_row, j]];
60                a[[max_row, j]] = tmp;
61            }
62            b.swap(col, max_row);
63        }
64        let pivot = a[[col, col]];
65        for row in (col + 1)..n {
66            let factor = a[[row, col]] / pivot;
67            for j in col..n {
68                let u = factor * a[[col, j]];
69                a[[row, j]] -= u;
70            }
71            let bup = factor * b[col];
72            b[row] -= bup;
73        }
74    }
75    let mut x = Array1::<f64>::zeros(n);
76    for i in (0..n).rev() {
77        let mut s = b[i];
78        for j in (i + 1)..n {
79            s -= a[[i, j]] * x[j];
80        }
81        x[i] = s / a[[i, i]];
82    }
83    Ok(x)
84}
85
86/// Classical 4th-order Runge-Kutta step (fixed step)
87fn rk4_step<F>(f: &F, t: f64, y: &Array1<f64>, h: f64) -> Array1<f64>
88where
89    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
90{
91    let n = y.len();
92    let k1 = f(t, y);
93    let mut y2 = Array1::<f64>::zeros(n);
94    for i in 0..n {
95        y2[i] = y[i] + 0.5 * h * k1[i];
96    }
97    let k2 = f(t + 0.5 * h, &y2);
98    let mut y3 = Array1::<f64>::zeros(n);
99    for i in 0..n {
100        y3[i] = y[i] + 0.5 * h * k2[i];
101    }
102    let k3 = f(t + 0.5 * h, &y3);
103    let mut y4 = Array1::<f64>::zeros(n);
104    for i in 0..n {
105        y4[i] = y[i] + h * k3[i];
106    }
107    let k4 = f(t + h, &y4);
108    let mut y_new = Array1::<f64>::zeros(n);
109    for i in 0..n {
110        y_new[i] = y[i] + h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
111    }
112    y_new
113}
114
115/// Integrate ODE from t0 to t1 using fixed-step RK4, returns final state
116fn integrate_rk4<F>(f: &F, t0: f64, y0: &Array1<f64>, t1: f64, n_steps: usize) -> Array1<f64>
117where
118    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
119{
120    let h = (t1 - t0) / n_steps as f64;
121    let mut t = t0;
122    let mut y = y0.clone();
123    for _ in 0..n_steps {
124        y = rk4_step(f, t, &y, h);
125        t += h;
126    }
127    y
128}
129
130/// Compute numerical Jacobian of residual g(s) w.r.t. s using central differences
131fn numerical_jacobian<G>(g: &G, s: &Array1<f64>, eps: f64) -> Array2<f64>
132where
133    G: Fn(&Array1<f64>) -> Array1<f64>,
134{
135    let n = s.len();
136    let m = g(s).len();
137    let mut jac = Array2::<f64>::zeros((m, n));
138    for j in 0..n {
139        let mut sp = s.clone();
140        sp[j] += eps;
141        let mut sm = s.clone();
142        sm[j] -= eps;
143        let fp = g(&sp);
144        let fm = g(&sm);
145        for i in 0..m {
146            jac[[i, j]] = (fp[i] - fm[i]) / (2.0 * eps);
147        }
148    }
149    jac
150}
151
152// ---------------------------------------------------------------------------
153// Result type
154// ---------------------------------------------------------------------------
155
/// Result of a shooting-based BVP solve.
///
/// Returned by the solvers in this module whether or not the Newton iteration
/// converged; check `success` before trusting the trajectory.
#[derive(Debug, Clone)]
pub struct BVPResult {
    /// Time points of the solution (same length as `y`)
    pub t: Vec<f64>,
    /// Solution state at each time point in `t`
    pub y: Vec<Array1<f64>>,
    /// Euclidean norm of the boundary-condition residual evaluated on the
    /// first and last states of the returned trajectory
    pub residual: f64,
    /// Error estimate (the solvers in this module set it equal to `residual`)
    pub error: f64,
    /// Number of Newton iterations
    pub n_newton_iters: usize,
    /// Whether the solver converged
    pub success: bool,
    /// Message describing termination
    pub message: String,
}
174
175// ---------------------------------------------------------------------------
176// Single Shooting
177// ---------------------------------------------------------------------------
178
/// Configuration for single and multiple shooting methods
#[derive(Debug, Clone)]
pub struct ShootingConfig {
    /// Number of RK4 integration steps per subinterval
    /// (single shooting integrates the whole of `t_span` with this many steps)
    pub n_steps: usize,
    /// Newton tolerance: iteration stops once the Euclidean norm of the
    /// residual falls below this value
    pub newton_tol: f64,
    /// Maximum Newton iterations
    pub max_newton_iter: usize,
    /// Step used for the central-difference Jacobian
    pub fd_eps: f64,
    /// Number of subintervals for multiple shooting
    // NOTE(review): `MultipleShooting::solve` derives the subinterval count from
    // `t_nodes` and does not read this field in the code shown here; presumably
    // callers use it to build `t_nodes` — confirm.
    pub n_subintervals: usize,
}

/// Defaults: 100 RK4 steps, Newton tolerance `1e-8`, at most 50 Newton
/// iterations, finite-difference step `1e-7`, 5 subintervals.
impl Default for ShootingConfig {
    fn default() -> Self {
        Self {
            n_steps: 100,
            newton_tol: 1e-8,
            max_newton_iter: 50,
            fd_eps: 1e-7,
            n_subintervals: 5,
        }
    }
}
205
206/// Single-shooting method for BVPs.
207///
208/// The missing initial conditions y(a) are parameterized by a vector s ∈ R^k.
209/// We integrate to b and solve the boundary condition g(y(a), y(b)) = 0.
210///
211/// # Arguments
212///
213/// * `f` - ODE function: f(t, y) → y'
214/// * `g` - Boundary residual: g(s, y_b) → 0, where s = free parameters at a
215/// * `merge_initial` - Merge known+free initial conditions: (s) → y(a)
216/// * `t_span` - [a, b]
217/// * `s0` - Initial guess for free parameters
218/// * `cfg` - Solver configuration
219///
220/// # Returns
221///
222/// `BVPResult` with the solution trajectory or error.
223pub struct SingleShooting;
224
225impl SingleShooting {
226    /// Solve BVP by single shooting.
227    ///
228    /// The boundary condition is `g(s, y_b) = 0` where s are the free initial
229    /// conditions and y_b = y(b) is obtained by integrating forward.
230    pub fn solve<ODE, BC, IC>(
231        ode: &ODE,
232        bc: &BC,
233        initial_condition: &IC,
234        t_span: [f64; 2],
235        s0: Array1<f64>,
236        cfg: &ShootingConfig,
237    ) -> IntegrateResult<BVPResult>
238    where
239        ODE: Fn(f64, &Array1<f64>) -> Array1<f64>,
240        BC: Fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
241        IC: Fn(&Array1<f64>) -> Array1<f64>,
242    {
243        let [ta, tb] = t_span;
244        let n_s = s0.len();
245
246        // Shooting function: shoot from s, return boundary residual g(ya, yb)
247        let shoot = |s: &Array1<f64>| -> Array1<f64> {
248            let ya = initial_condition(s);
249            let yb = integrate_rk4(ode, ta, &ya, tb, cfg.n_steps);
250            bc(&ya, &yb)
251        };
252
253        let mut s = s0.clone();
254        let mut n_iters = 0usize;
255        let mut converged = false;
256
257        for iter in 0..cfg.max_newton_iter {
258            let res = shoot(&s);
259            let res_norm: f64 = res.iter().map(|&v| v * v).sum::<f64>().sqrt();
260
261            if res_norm < cfg.newton_tol {
262                n_iters = iter + 1;
263                converged = true;
264                break;
265            }
266
267            // Jacobian of shooting function w.r.t. s
268            let mut jac = numerical_jacobian(&shoot, &s, cfg.fd_eps);
269            let mut neg_res = res.mapv(|v| -v);
270
271            match gauss_solve(&mut jac, &mut neg_res) {
272                Ok(delta) => {
273                    for i in 0..n_s {
274                        s[i] += delta[i];
275                    }
276                }
277                Err(_) => {
278                    // Fallback: gradient descent step
279                    let res_ref = shoot(&s);
280                    let grad_norm_sq: f64 = res_ref.iter().map(|&v| v * v).sum();
281                    if grad_norm_sq > 0.0 {
282                        for i in 0..n_s {
283                            s[i] -= cfg.fd_eps * res_ref[i.min(res_ref.len() - 1)];
284                        }
285                    }
286                }
287            }
288        }
289
290        if !converged {
291            n_iters = cfg.max_newton_iter;
292        }
293
294        // Reconstruct solution trajectory
295        let ya = initial_condition(&s);
296        let (t_traj, y_traj) = trajectory_rk4(ode, ta, &ya, tb, cfg.n_steps);
297
298        let yb = y_traj.last().cloned().unwrap_or_else(|| ya.clone());
299        let final_res = bc(&ya, &yb);
300        let residual: f64 = final_res.iter().map(|&v| v * v).sum::<f64>().sqrt();
301
302        Ok(BVPResult {
303            t: t_traj,
304            y: y_traj,
305            residual,
306            error: residual,
307            n_newton_iters: n_iters,
308            success: converged,
309            message: if converged {
310                "Single shooting converged".to_string()
311            } else {
312                format!(
313                    "Single shooting did not converge in {} iterations",
314                    cfg.max_newton_iter
315                )
316            },
317        })
318    }
319}
320
321/// Reconstruct trajectory using RK4, returning (times, states)
322fn trajectory_rk4<F>(
323    f: &F,
324    t0: f64,
325    y0: &Array1<f64>,
326    t1: f64,
327    n_steps: usize,
328) -> (Vec<f64>, Vec<Array1<f64>>)
329where
330    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
331{
332    let h = (t1 - t0) / n_steps as f64;
333    let mut t = t0;
334    let mut y = y0.clone();
335    let mut ts = vec![t];
336    let mut ys = vec![y.clone()];
337
338    for _ in 0..n_steps {
339        y = rk4_step(f, t, &y, h);
340        t += h;
341        ts.push(t);
342        ys.push(y.clone());
343    }
344    (ts, ys)
345}
346
347// ---------------------------------------------------------------------------
348// Multiple Shooting
349// ---------------------------------------------------------------------------
350
351/// Multiple-shooting method for BVPs.
352///
353/// Divides [a, b] into M subintervals [t_0, t_1, ..., t_M] and introduces
354/// state variables s_i = y(t_i^-) at each interior (and initial) node.
355///
356/// The system to solve is:
357///   - Boundary conditions: g(s_0, s_M) = 0 (n_bc equations)
358///   - Continuity: y(t_i; s_i) = s_{i+1} for i = 1, ..., M-1 (n*(M-1) equations)
359///   - Total: n*M unknowns in [s_0, s_1, ..., s_{M-1}]
360///
361/// The Jacobian has a block-bidiagonal structure (solved here via dense Newton).
362pub struct MultipleShooting;
363
364impl MultipleShooting {
365    /// Solve BVP by multiple shooting.
366    ///
367    /// # Arguments
368    ///
369    /// * `ode` - ODE function f(t, y)
370    /// * `bc` - Boundary conditions: g(y(a), y(b)) → 0 (n_bc equations)
371    /// * `t_nodes` - Subinterval nodes [t_0, t_1, ..., t_M] (M+1 nodes, M intervals)
372    /// * `s0` - Initial guesses for state at each node: (M, n) as `Vec<Array1>`
373    /// * `cfg` - Solver configuration
374    pub fn solve<ODE, BC>(
375        ode: &ODE,
376        bc: &BC,
377        t_nodes: &[f64],
378        s0: Vec<Array1<f64>>,
379        cfg: &ShootingConfig,
380    ) -> IntegrateResult<BVPResult>
381    where
382        ODE: Fn(f64, &Array1<f64>) -> Array1<f64>,
383        BC: Fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
384    {
385        if t_nodes.len() < 2 {
386            return Err(IntegrateError::InvalidInput(
387                "t_nodes must have at least 2 elements".to_string(),
388            ));
389        }
390        let m = t_nodes.len() - 1; // number of subintervals
391        if s0.len() != m {
392            return Err(IntegrateError::DimensionMismatch(format!(
393                "s0 length {} must equal number of subintervals {}",
394                s0.len(),
395                m
396            )));
397        }
398
399        let n = s0[0].len(); // state dimension
400        let total_unknowns = m * n;
401
402        // Flatten unknowns: [s_0 | s_1 | ... | s_{M-1}]
403        let mut s_flat = Array1::<f64>::zeros(total_unknowns);
404        for (i, si) in s0.iter().enumerate() {
405            for j in 0..n {
406                s_flat[i * n + j] = si[j];
407            }
408        }
409
410        let mut n_iters = 0usize;
411        let mut converged = false;
412
413        for iter in 0..cfg.max_newton_iter {
414            // Build residual: [BC | continuity...]
415            let n_bc_eqs = {
416                let ya = s_flat.slice(scirs2_core::ndarray::s![..n]).to_owned();
417                let yb_start = &s_flat
418                    .slice(scirs2_core::ndarray::s![(m - 1) * n..])
419                    .to_owned();
420                let t_last = t_nodes[m - 1];
421                let t_end = t_nodes[m];
422                let yb = integrate_rk4(ode, t_last, yb_start, t_end, cfg.n_steps);
423                bc(&ya, &yb).len()
424            };
425
426            let residual_len = n_bc_eqs + (m - 1) * n;
427            let mut res = Array1::<f64>::zeros(residual_len);
428
429            // BC residual
430            let ya = s_flat.slice(scirs2_core::ndarray::s![..n]).to_owned();
431            let yb_start = s_flat
432                .slice(scirs2_core::ndarray::s![(m - 1) * n..])
433                .to_owned();
434            let yb = integrate_rk4(ode, t_nodes[m - 1], &yb_start, t_nodes[m], cfg.n_steps);
435            let bc_res = bc(&ya, &yb);
436            for i in 0..n_bc_eqs {
437                res[i] = bc_res[i];
438            }
439
440            // Continuity residuals: y(t_i^+; s_i) - s_{i+1} = 0
441            for interval in 0..(m - 1) {
442                let si = s_flat
443                    .slice(scirs2_core::ndarray::s![interval * n..(interval + 1) * n])
444                    .to_owned();
445                let si_next = s_flat
446                    .slice(scirs2_core::ndarray::s![
447                        (interval + 1) * n..(interval + 2) * n
448                    ])
449                    .to_owned();
450                let y_shot = integrate_rk4(
451                    ode,
452                    t_nodes[interval],
453                    &si,
454                    t_nodes[interval + 1],
455                    cfg.n_steps,
456                );
457                for j in 0..n {
458                    res[n_bc_eqs + interval * n + j] = y_shot[j] - si_next[j];
459                }
460            }
461
462            let res_norm: f64 = res.iter().map(|&v| v * v).sum::<f64>().sqrt();
463            if res_norm < cfg.newton_tol {
464                n_iters = iter + 1;
465                converged = true;
466                break;
467            }
468
469            // Numerical Jacobian of full residual w.r.t. s_flat
470            let shoot_residual = |s: &Array1<f64>| {
471                let mut r = Array1::<f64>::zeros(residual_len);
472                let ya = s.slice(scirs2_core::ndarray::s![..n]).to_owned();
473                let yb_s = s.slice(scirs2_core::ndarray::s![(m - 1) * n..]).to_owned();
474                let yb = integrate_rk4(ode, t_nodes[m - 1], &yb_s, t_nodes[m], cfg.n_steps);
475                let bcr = bc(&ya, &yb);
476                for i in 0..n_bc_eqs {
477                    r[i] = bcr[i];
478                }
479                for interval in 0..(m - 1) {
480                    let si = s
481                        .slice(scirs2_core::ndarray::s![interval * n..(interval + 1) * n])
482                        .to_owned();
483                    let si_next = s
484                        .slice(scirs2_core::ndarray::s![
485                            (interval + 1) * n..(interval + 2) * n
486                        ])
487                        .to_owned();
488                    let y_shot = integrate_rk4(
489                        ode,
490                        t_nodes[interval],
491                        &si,
492                        t_nodes[interval + 1],
493                        cfg.n_steps,
494                    );
495                    for j in 0..n {
496                        r[n_bc_eqs + interval * n + j] = y_shot[j] - si_next[j];
497                    }
498                }
499                r
500            };
501
502            let mut jac = numerical_jacobian(&shoot_residual, &s_flat, cfg.fd_eps);
503            let mut neg_res = res.mapv(|v| -v);
504
505            match gauss_solve(&mut jac, &mut neg_res) {
506                Ok(delta) => {
507                    for i in 0..total_unknowns {
508                        s_flat[i] += delta[i];
509                    }
510                }
511                Err(_) => {
512                    return Err(IntegrateError::LinearSolveError(
513                        "Multiple shooting: singular Jacobian".to_string(),
514                    ));
515                }
516            }
517        }
518
519        if !converged {
520            n_iters = cfg.max_newton_iter;
521        }
522
523        // Reconstruct trajectory
524        let mut t_traj = Vec::new();
525        let mut y_traj = Vec::new();
526        for interval in 0..m {
527            let si = s_flat
528                .slice(scirs2_core::ndarray::s![interval * n..(interval + 1) * n])
529                .to_owned();
530            let (ts, ys) = trajectory_rk4(
531                ode,
532                t_nodes[interval],
533                &si,
534                t_nodes[interval + 1],
535                cfg.n_steps / m.max(1),
536            );
537            if interval == 0 {
538                t_traj.extend_from_slice(&ts);
539                y_traj.extend_from_slice(&ys);
540            } else {
541                t_traj.extend_from_slice(&ts[1..]);
542                y_traj.extend_from_slice(&ys[1..]);
543            }
544        }
545
546        let ya = y_traj
547            .first()
548            .cloned()
549            .unwrap_or_else(|| Array1::<f64>::zeros(n));
550        let yb = y_traj
551            .last()
552            .cloned()
553            .unwrap_or_else(|| Array1::<f64>::zeros(n));
554        let final_bc = bc(&ya, &yb);
555        let residual: f64 = final_bc.iter().map(|&v| v * v).sum::<f64>().sqrt();
556
557        Ok(BVPResult {
558            t: t_traj,
559            y: y_traj,
560            residual,
561            error: residual,
562            n_newton_iters: n_iters,
563            success: converged,
564            message: if converged {
565                "Multiple shooting converged".to_string()
566            } else {
567                format!(
568                    "Multiple shooting did not converge in {} iterations",
569                    cfg.max_newton_iter
570                )
571            },
572        })
573    }
574}
575
576// ---------------------------------------------------------------------------
577// Orthogonal Collocation
578// ---------------------------------------------------------------------------
579
/// Gauss-Legendre nodes on [-1, 1] for orders 1 to 5, in ascending order.
///
/// The nodes are the roots of the Legendre polynomial of degree `order` and
/// are evaluated from their closed-form expressions so that they carry full
/// `f64` precision. Any unsupported `order` (0 or above 5) falls back to the
/// 2-point rule.
fn gauss_legendre_nodes(order: usize) -> Vec<f64> {
    match order {
        1 => vec![0.0],
        2 => {
            let x = 1.0 / 3.0_f64.sqrt();
            vec![-x, x]
        }
        3 => {
            // Roots of P3: 0 and ±sqrt(3/5).
            let x = (3.0_f64 / 5.0).sqrt();
            vec![-x, 0.0, x]
        }
        4 => {
            // Roots of P4: ±sqrt((3 ∓ 2*sqrt(6/5)) / 7).
            let spread = 2.0 * (6.0_f64 / 5.0).sqrt();
            let inner = ((3.0 - spread) / 7.0).sqrt();
            let outer = ((3.0 + spread) / 7.0).sqrt();
            vec![-outer, -inner, inner, outer]
        }
        5 => {
            // Roots of P5: 0 and ±(1/3)*sqrt(5 ∓ 2*sqrt(10/7)).
            let spread = 2.0 * (10.0_f64 / 7.0).sqrt();
            let inner = (5.0 - spread).sqrt() / 3.0;
            let outer = (5.0 + spread).sqrt() / 3.0;
            vec![-outer, -inner, 0.0, inner, outer]
        }
        _ => {
            // default to 2-point
            let x = 1.0 / 3.0_f64.sqrt();
            vec![-x, x]
        }
    }
}
597
/// Gauss-Legendre quadrature weights on [-1, 1] for orders 1 to 5.
///
/// Weights are listed in the order of the ascending nodes and are computed
/// from their closed-form expressions so that they carry full `f64`
/// precision (each rule's weights sum to 2). Any unsupported `order` (0 or
/// above 5) falls back to the 2-point rule.
fn gauss_legendre_weights(order: usize) -> Vec<f64> {
    match order {
        1 => vec![2.0],
        2 => vec![1.0, 1.0],
        3 => vec![5.0 / 9.0, 8.0 / 9.0, 5.0 / 9.0],
        4 => {
            // (18 ∓ sqrt(30)) / 36; the smaller weight belongs to the outer nodes.
            let root = 30.0_f64.sqrt();
            let outer = (18.0 - root) / 36.0;
            let inner = (18.0 + root) / 36.0;
            vec![outer, inner, inner, outer]
        }
        5 => {
            // (322 ∓ 13*sqrt(70)) / 900 and 128/225 for the centre node.
            let root = 13.0 * 70.0_f64.sqrt();
            let outer = (322.0 - root) / 900.0;
            let inner = (322.0 + root) / 900.0;
            vec![outer, inner, 128.0 / 225.0, inner, outer]
        }
        _ => vec![1.0, 1.0],
    }
}
615
/// Configuration for orthogonal collocation
#[derive(Debug, Clone)]
pub struct CollocationConfig {
    /// Number of subintervals (mesh intervals); the mesh is uniform over `t_span`
    pub n_subintervals: usize,
    /// Number of Gauss-Legendre points per interval
    /// (the solver restricts this to the range 1 to 5)
    pub collocation_order: usize,
    /// Newton solver tolerance on the Euclidean norm of the residual
    pub newton_tol: f64,
    /// Maximum Newton iterations
    pub max_newton_iter: usize,
    /// Step used for the central-difference Jacobian
    pub fd_eps: f64,
}

/// Defaults: 10 subintervals, 3 points per interval, Newton tolerance `1e-8`,
/// at most 30 Newton iterations, finite-difference step `1e-7`.
impl Default for CollocationConfig {
    fn default() -> Self {
        Self {
            n_subintervals: 10,
            collocation_order: 3,
            newton_tol: 1e-8,
            max_newton_iter: 30,
            fd_eps: 1e-7,
        }
    }
}
642
643/// Orthogonal collocation at Gauss-Legendre points.
644///
645/// Approximates y(t) as a piecewise polynomial, enforcing the ODE at collocation
646/// points within each subinterval and matching state values at mesh nodes.
647///
648/// For each subinterval [t_i, t_{i+1}] with m collocation points τ_j:
649///   - The polynomial p(t) satisfies p(t_i) = y_i (left endpoint)
650///   - p'(τ_j) = f(τ_j, p(τ_j)) for j = 1..m (collocation conditions)
651///   - y_{i+1} = p(t_{i+1}) (right endpoint value)
652///
653/// This results in a large nonlinear system that is solved by Newton iterations.
654pub struct OrthogonalCollocation;
655
656impl OrthogonalCollocation {
657    /// Solve BVP by collocation at Gauss-Legendre points.
658    ///
659    /// # Arguments
660    ///
661    /// * `ode` - ODE function f(t, y)
662    /// * `bc` - Boundary condition: g(y(a), y(b)) = 0 (n_bc equations)
663    /// * `t_span` - [a, b]
664    /// * `y_init_guess` - Closure providing initial guess for y at time t
665    /// * `cfg` - Collocation configuration
666    pub fn solve<ODE, BC, Guess>(
667        ode: &ODE,
668        bc: &BC,
669        t_span: [f64; 2],
670        y_init_guess: &Guess,
671        n_state: usize,
672        cfg: &CollocationConfig,
673    ) -> IntegrateResult<BVPResult>
674    where
675        ODE: Fn(f64, &Array1<f64>) -> Array1<f64>,
676        BC: Fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
677        Guess: Fn(f64) -> Array1<f64>,
678    {
679        let [ta, tb] = t_span;
680        let m = cfg.n_subintervals;
681        let k = cfg.collocation_order.min(5).max(1);
682        let h = (tb - ta) / m as f64;
683
684        // Mesh nodes
685        let nodes: Vec<f64> = (0..=m).map(|i| ta + i as f64 * h).collect();
686
687        // Collocation nodes (on reference element [-1,1])
688        let ref_nodes = gauss_legendre_nodes(k);
689
690        // Total unknowns: y at mesh nodes (m+1)*n  +  y at collocation pts m*k*n
691        // But we can simplify: unknowns = y at mesh nodes (m+1)*n only, and collocation
692        // pts determined via polynomial. For simplicity, we use the simpler formulation:
693        // unknowns = y at all mesh nodes (m+1)*n,
694        // residuals = ODE satisfied at collocation pts + BC.
695        //
696        // Each mesh interval [t_i, t_{i+1}] contributes k residuals from ODE collocation.
697        // We use linear interpolation to get y at collocation pts (order 1 approximation)
698        // and add Hermite-type update.
699
700        let n = n_state;
701        let total_unknowns = (m + 1) * n;
702
703        // Initial guess from provided function
704        let mut y_flat = Array1::<f64>::zeros(total_unknowns);
705        for i in 0..=m {
706            let guess = y_init_guess(nodes[i]);
707            for j in 0..n {
708                y_flat[i * n + j] = guess[j];
709            }
710        }
711
712        let n_bc_eqs = bc(
713            &y_flat.slice(scirs2_core::ndarray::s![..n]).to_owned(),
714            &y_flat.slice(scirs2_core::ndarray::s![m * n..]).to_owned(),
715        )
716        .len();
717
718        // Number of residuals: n_bc_eqs + m*k*n (collocation) - n*(m) (redundant continuity)
719        // Simpler: n_bc_eqs + m*k*n equations, (m+1)*n unknowns
720        // We use a simplified formulation: n*(m+1) equations with n*(m+1) unknowns
721        // Residuals:
722        //   - BC: n_bc_eqs equations
723        //   - For each interval: n equations from integral form (trapezoidal approx)
724
725        // Use a simpler consistent formulation:
726        // residual[0..n_bc] = bc(y0, yM)
727        // residual[n_bc + i*n .. n_bc + (i+1)*n] = integral residual for interval i
728        // via: y_{i+1} - y_i - h/2*(f(t_i, y_i) + f(t_{i+1}, y_{i+1})) = 0  (trapezoidal)
729        // Total: n_bc + m*n equations, (m+1)*n unknowns.
730        // Constraint: n_bc = n for well-posed system. (m+1)*n = n_bc + m*n ✓
731
732        let build_residual = |yf: &Array1<f64>| {
733            let rlen = n_bc_eqs + m * n;
734            let mut r = Array1::<f64>::zeros(rlen);
735            let ya = yf.slice(scirs2_core::ndarray::s![..n]).to_owned();
736            let ym = yf.slice(scirs2_core::ndarray::s![m * n..]).to_owned();
737            let bcr = bc(&ya, &ym);
738            for i in 0..n_bc_eqs {
739                r[i] = bcr[i];
740            }
741
742            for interval in 0..m {
743                let ti = nodes[interval];
744                let tip1 = nodes[interval + 1];
745                let yi = yf
746                    .slice(scirs2_core::ndarray::s![interval * n..(interval + 1) * n])
747                    .to_owned();
748                let yip1 = yf
749                    .slice(scirs2_core::ndarray::s![
750                        (interval + 1) * n..(interval + 2) * n
751                    ])
752                    .to_owned();
753
754                // Collocation at Gauss-Legendre points in [ti, tip1]
755                // For each collocation point, evaluate ODE and add contribution
756                // Using k-point Gauss rule to integrate the ODE residual
757                let hi = tip1 - ti;
758                let wts = gauss_legendre_weights(k);
759
760                // Trapezoidal/high-order integral: y_{i+1} - y_i - integral f dt = 0
761                let mut integral = Array1::<f64>::zeros(n);
762                for (q, &xi) in ref_nodes.iter().enumerate() {
763                    // Map from [-1,1] to [ti, tip1]
764                    let tc = ti + (xi + 1.0) * 0.5 * hi;
765                    // Linear interpolation for y at collocation point
766                    let alpha = (xi + 1.0) * 0.5;
767                    let mut yc = Array1::<f64>::zeros(n);
768                    for j in 0..n {
769                        yc[j] = (1.0 - alpha) * yi[j] + alpha * yip1[j];
770                    }
771                    let fc = ode(tc, &yc);
772                    let wt = wts[q] * hi * 0.5;
773                    for j in 0..n {
774                        integral[j] += wt * fc[j];
775                    }
776                }
777
778                let base = n_bc_eqs + interval * n;
779                for j in 0..n {
780                    r[base + j] = yip1[j] - yi[j] - integral[j];
781                }
782            }
783            r
784        };
785
786        let mut n_iters = 0usize;
787        let mut converged = false;
788
789        for iter in 0..cfg.max_newton_iter {
790            let res = build_residual(&y_flat);
791            let res_norm: f64 = res.iter().map(|&v| v * v).sum::<f64>().sqrt();
792            if res_norm < cfg.newton_tol {
793                n_iters = iter + 1;
794                converged = true;
795                break;
796            }
797
798            let mut jac = numerical_jacobian(&build_residual, &y_flat, cfg.fd_eps);
799            let mut neg_res = res.mapv(|v| -v);
800            match gauss_solve(&mut jac, &mut neg_res) {
801                Ok(delta) => {
802                    for i in 0..total_unknowns {
803                        y_flat[i] += delta[i];
804                    }
805                }
806                Err(e) => return Err(e),
807            }
808        }
809
810        if !converged {
811            n_iters = cfg.max_newton_iter;
812        }
813
814        // Extract trajectory
815        let t_traj: Vec<f64> = nodes.clone();
816        let y_traj: Vec<Array1<f64>> = (0..=m)
817            .map(|i| {
818                y_flat
819                    .slice(scirs2_core::ndarray::s![i * n..(i + 1) * n])
820                    .to_owned()
821            })
822            .collect();
823
824        let ya = y_traj
825            .first()
826            .cloned()
827            .unwrap_or_else(|| Array1::<f64>::zeros(n));
828        let yb = y_traj
829            .last()
830            .cloned()
831            .unwrap_or_else(|| Array1::<f64>::zeros(n));
832        let final_bc = bc(&ya, &yb);
833        let residual: f64 = final_bc.iter().map(|&v| v * v).sum::<f64>().sqrt();
834
835        Ok(BVPResult {
836            t: t_traj,
837            y: y_traj,
838            residual,
839            error: residual,
840            n_newton_iters: n_iters,
841            success: converged,
842            message: if converged {
843                "Orthogonal collocation converged".to_string()
844            } else {
845                format!(
846                    "Collocation did not converge in {} iterations",
847                    cfg.max_newton_iter
848                )
849            },
850        })
851    }
852}
853
854// ---------------------------------------------------------------------------
855// Periodic Orbit Finder
856// ---------------------------------------------------------------------------
857
/// Configuration for the periodic orbit finder
#[derive(Debug, Clone)]
pub struct PeriodicOrbitConfig {
    /// Number of RK4 steps for one period integration
    pub n_steps: usize,
    /// Newton solver tolerance
    pub newton_tol: f64,
    /// Maximum Newton iterations
    pub max_newton_iter: usize,
    /// Finite difference epsilon for Jacobians
    pub fd_eps: f64,
    /// Index of phase condition (which component is fixed)
    // NOTE(review): presumably must be a valid index into the state vector —
    // confirm against `PeriodicOrbitFinder::find`.
    pub phase_condition_idx: usize,
}

/// Defaults: 500 RK4 steps, Newton tolerance `1e-8`, at most 50 Newton
/// iterations, finite-difference step `1e-7`, phase condition on component 0.
impl Default for PeriodicOrbitConfig {
    fn default() -> Self {
        Self {
            n_steps: 500,
            newton_tol: 1e-8,
            max_newton_iter: 50,
            fd_eps: 1e-7,
            phase_condition_idx: 0,
        }
    }
}
884
885/// Result of a periodic orbit computation
886#[derive(Debug, Clone)]
887pub struct PeriodicOrbitResult {
888    /// One full period of the orbit (time points)
889    pub t: Vec<f64>,
890    /// States along the orbit
891    pub y: Vec<Array1<f64>>,
892    /// Period T
893    pub period: f64,
894    /// Initial state of the orbit y(0)
895    pub y0: Array1<f64>,
896    /// Residual of the periodicity condition
897    pub residual: f64,
898    /// Number of Newton iterations
899    pub n_newton_iters: usize,
900    /// Whether the solver converged
901    pub success: bool,
902    /// Termination message
903    pub message: String,
904}
905
/// Shooting-based periodic orbit finder.
///
/// Seeks a state y* and period T such that φ(T, y*) = y*, where φ is the flow
/// map of the ODE (approximated here by fixed-step RK4 integration from t = 0).
///
/// The system solved is:
///   F(y*, T) = φ(T, y*) - y* = 0   (periodicity, n equations)
///   g(y*, T) = 0                    (phase condition, 1 equation)
///
/// giving n+1 equations in n+1 unknowns (y* ∈ R^n, T ∈ R). It is solved by
/// Newton's method with a finite-difference Jacobian.
///
/// The phase condition fixes the phase along the orbit to remove translational
/// invariance: any point on a periodic orbit is a valid y*, so one component is
/// pinned to its value in the initial guess:
/// y*\[phase_condition_idx\] - y_guess\[phase_condition_idx\] = 0.
pub struct PeriodicOrbitFinder;
919
920impl PeriodicOrbitFinder {
921    /// Find a periodic orbit near (y_guess, t_guess).
922    ///
923    /// # Arguments
924    ///
925    /// * `ode` - Autonomous ODE function f(t, y)
926    /// * `y_guess` - Initial guess for initial state on the orbit
927    /// * `t_guess` - Initial guess for the period
928    /// * `cfg` - Configuration
929    pub fn find<ODE>(
930        ode: &ODE,
931        y_guess: &Array1<f64>,
932        t_guess: f64,
933        cfg: &PeriodicOrbitConfig,
934    ) -> IntegrateResult<PeriodicOrbitResult>
935    where
936        ODE: Fn(f64, &Array1<f64>) -> Array1<f64>,
937    {
938        let n = y_guess.len();
939        let phase_idx = cfg.phase_condition_idx.min(n - 1);
940        let y0_ref_phase = y_guess[phase_idx];
941
942        // Extended state: z = [y* (n), T (1)], total n+1
943        let mut z = Array1::<f64>::zeros(n + 1);
944        for i in 0..n {
945            z[i] = y_guess[i];
946        }
947        z[n] = t_guess;
948
949        // Residual function
950        let residual_fn = |zv: &Array1<f64>| {
951            let mut y_cur = Array1::<f64>::zeros(n);
952            for i in 0..n {
953                y_cur[i] = zv[i];
954            }
955            let period = zv[n].max(1e-10);
956            let y_end = integrate_rk4(ode, 0.0, &y_cur, period, cfg.n_steps);
957            let mut r = Array1::<f64>::zeros(n + 1);
958            for i in 0..n {
959                r[i] = y_end[i] - y_cur[i];
960            }
961            // Phase condition: y*[phase_idx] = y0_ref_phase
962            r[n] = y_cur[phase_idx] - y0_ref_phase;
963            r
964        };
965
966        let mut n_iters = 0usize;
967        let mut converged = false;
968
969        for iter in 0..cfg.max_newton_iter {
970            let res = residual_fn(&z);
971            let res_norm: f64 = res.iter().map(|&v| v * v).sum::<f64>().sqrt();
972            if res_norm < cfg.newton_tol {
973                n_iters = iter + 1;
974                converged = true;
975                break;
976            }
977
978            let mut jac = numerical_jacobian(&residual_fn, &z, cfg.fd_eps);
979            let mut neg_res = res.mapv(|v| -v);
980            match gauss_solve(&mut jac, &mut neg_res) {
981                Ok(delta) => {
982                    for i in 0..=n {
983                        z[i] += delta[i];
984                    }
985                    // Keep period positive
986                    if z[n] < 1e-10 {
987                        z[n] = 1e-10;
988                    }
989                }
990                Err(_) => {
991                    return Err(IntegrateError::LinearSolveError(
992                        "Periodic orbit: singular Jacobian".to_string(),
993                    ));
994                }
995            }
996        }
997
998        if !converged {
999            n_iters = cfg.max_newton_iter;
1000        }
1001
1002        let mut y_star = Array1::<f64>::zeros(n);
1003        for i in 0..n {
1004            y_star[i] = z[i];
1005        }
1006        let period = z[n];
1007
1008        let (t_traj, y_traj) = trajectory_rk4(ode, 0.0, &y_star, period, cfg.n_steps);
1009
1010        let final_res = residual_fn(&z);
1011        let residual: f64 = final_res.iter().map(|&v| v * v).sum::<f64>().sqrt();
1012
1013        Ok(PeriodicOrbitResult {
1014            t: t_traj,
1015            y: y_traj,
1016            period,
1017            y0: y_star,
1018            residual,
1019            n_newton_iters: n_iters,
1020            success: converged,
1021            message: if converged {
1022                format!("Periodic orbit found: T = {:.6}", period)
1023            } else {
1024                format!(
1025                    "Periodic orbit not found in {} iterations",
1026                    cfg.max_newton_iter
1027                )
1028            },
1029        })
1030    }
1031}
1032
1033// Suppress unused warning for helper
1034#[allow(dead_code)]
1035fn _use_to_f() {
1036    let _ = to_f(0.0);
1037}
1038
1039// ---------------------------------------------------------------------------
1040// Tests
1041// ---------------------------------------------------------------------------
1042
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};

    /// Right-hand side of y'' + y = 0 written as a first-order system:
    /// y' = [y1, -y0]. The system is autonomous, so the time argument is unused.
    ///
    /// With y(0) = 0 and y(π/2) = 1 the exact solution is y(t) = sin(t).
    fn linear_ode(_t: f64, y: &Array1<f64>) -> Array1<f64> {
        array![y[1], -y[0]]
    }

    /// Boundary residual for y(0) = 0, y(π/2) = 1.
    fn linear_bc(ya: &Array1<f64>, yb: &Array1<f64>) -> Array1<f64> {
        let left = ya[0];
        let right = yb[0] - 1.0;
        array![left, right]
    }

    #[test]
    fn test_single_shooting_linear_bvp() {
        // Solve y' = [y1, -y0] with y(0)=0, y(π/2)=1; exact solution is sin(t).
        let cfg = ShootingConfig {
            n_steps: 200,
            newton_tol: 1e-8,
            ..Default::default()
        };
        let t_span = [0.0, FRAC_PI_2];

        // The only shooting parameter is the unknown slope y'(0) = y[1](0).
        let initial_condition = |s: &Array1<f64>| array![0.0, s[0]];
        let bc = |ya: &Array1<f64>, yb: &Array1<f64>| array![ya[0], yb[0] - 1.0];
        // Start Newton from y'(0) = 1.
        let s0 = array![1.0];

        let result = SingleShooting::solve(&linear_ode, &bc, &initial_condition, t_span, s0, &cfg)
            .expect("Single shooting failed");

        assert!(result.success, "Should converge: {}", result.message);
        assert!(
            result.residual < 1e-6,
            "Residual {} too large",
            result.residual
        );

        // Compare against sin(π/4) ≈ 0.7071 at the first sample near t = π/4.
        // The comparison is skipped when no sample lies within 0.02 of π/4.
        let near_quarter = result.t.iter().position(|&t| (t - FRAC_PI_4).abs() < 0.02);
        if let Some(i) = near_quarter {
            let exact = FRAC_PI_4.sin();
            let y_val = result.y[i][0];
            assert!(
                (y_val - exact).abs() < 0.02,
                "y(π/4)={} != sin(π/4)={}",
                y_val,
                exact
            );
        }
    }

    #[test]
    fn test_collocation_linear_bvp() {
        let t_span = [0.0, FRAC_PI_2];
        // Seed the solver with the exact solution [sin t, cos t].
        let guess = |t: f64| array![t.sin(), t.cos()];

        let cfg = CollocationConfig {
            n_subintervals: 8,
            collocation_order: 3,
            newton_tol: 1e-6,
            max_newton_iter: 30,
            fd_eps: 1e-6,
        };

        let result = OrthogonalCollocation::solve(&linear_ode, &linear_bc, t_span, &guess, 2, &cfg)
            .expect("Collocation failed");

        assert!(
            result.residual < 1e-4,
            "Collocation residual {} too large",
            result.residual
        );
    }

    #[test]
    fn test_periodic_orbit_harmonic_oscillator() {
        // Harmonic oscillator y'' + y = 0 as y' = [y1, -y0].
        // The orbit through [1, 0] has exact period T = 2π.
        let two_pi = 2.0 * PI;
        let y_guess = array![1.0, 0.0];
        let t_guess = two_pi;

        let cfg = PeriodicOrbitConfig {
            n_steps: 200,
            newton_tol: 1e-8,
            max_newton_iter: 20,
            fd_eps: 1e-6,
            phase_condition_idx: 0,
        };

        let result = PeriodicOrbitFinder::find(&linear_ode, &y_guess, t_guess, &cfg)
            .expect("Periodic orbit finder failed");

        // The computed period must stay close to 2π.
        assert!(
            (result.period - two_pi).abs() < 0.1,
            "Period {} != 2π",
            result.period
        );
    }
}