Skip to main content

scirs2_integrate/
bvp_collocation.rs

1//! Boundary-value problem solver using Lobatto IIIA collocation
2//!
3//! This module implements a collocation method for two-point boundary-value
4//! problems of the form:
5//!
6//! ```text
7//!   y'(x) = f(x, y),   a <= x <= b
8//!   bc(y(a), y(b)) = 0
9//! ```
10//!
11//! The solver places collocation points at the Lobatto nodes (including the
12//! mesh endpoints) on each sub-interval, constructs a global nonlinear system
13//! from the collocation residuals plus boundary conditions, and solves it
14//! with damped Newton iteration. Mesh refinement is driven by a defect-based
15//! error estimate.
16//!
17//! ## Algorithm outline
18//!
19//! 1. Start with an initial mesh and guess.
20//! 2. On each sub-interval `[x_i, x_{i+1}]` place the 3-point Lobatto IIIA
21//!    nodes (endpoints + midpoint). The collocation polynomial is degree 3
22//!    (4th-order accurate).
23//! 3. Assemble the global Newton system and solve.
24//! 4. Estimate the defect on each sub-interval; refine where needed.
25//! 5. Repeat until the tolerance is met or the budget is exhausted.
26//!
27//! ## References
28//!
29//! - U. Ascher, R. Mattheij, R. Russell (1995), "Numerical Solution of
30//!   Boundary Value Problems for Ordinary Differential Equations"
31//! - J. Kierzenka, L. Shampine (2001), "A BVP Solver Based on Residual
32//!   Control and the MATLAB PSE"
33
34use crate::error::{IntegrateError, IntegrateResult};
35use crate::IntegrateFloat;
36use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
37
38/// Helper to convert f64 to generic float
39#[inline(always)]
40fn to_f<F: IntegrateFloat>(v: f64) -> F {
41    F::from_f64(v).unwrap_or_else(|| F::zero())
42}
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
48/// Options for the collocation BVP solver
49#[derive(Debug, Clone)]
50pub struct CollocationBVPOptions<F: IntegrateFloat> {
51    /// Tolerance on the defect (default 1e-6)
52    pub tol: F,
53    /// Maximum Newton iterations per mesh (default 40)
54    pub max_newton_iter: usize,
55    /// Maximum number of mesh refinement cycles (default 10)
56    pub max_mesh_refinements: usize,
57    /// Maximum allowed mesh size (default 500 nodes)
58    pub max_mesh_size: usize,
59    /// Damping factor for Newton step (0 < factor <= 1, default 1.0)
60    pub damping: F,
61}
62
63impl<F: IntegrateFloat> Default for CollocationBVPOptions<F> {
64    fn default() -> Self {
65        Self {
66            tol: to_f(1e-6),
67            max_newton_iter: 40,
68            max_mesh_refinements: 10,
69            max_mesh_size: 500,
70            damping: F::one(),
71        }
72    }
73}
74
75/// Result from the collocation BVP solver
76#[derive(Debug, Clone)]
77pub struct CollocationBVPResult<F: IntegrateFloat> {
78    /// Mesh points
79    pub x: Vec<F>,
80    /// Solution values at mesh points, each of length `n_dim`
81    pub y: Vec<Array1<F>>,
82    /// Number of Newton iterations (total across all refinements)
83    pub n_newton_iter: usize,
84    /// Number of mesh refinement cycles
85    pub n_refinements: usize,
86    /// Final maximum defect
87    pub max_defect: F,
88    /// Whether the solver converged
89    pub converged: bool,
90}
91
92// ---------------------------------------------------------------------------
93// Core solve function
94// ---------------------------------------------------------------------------
95
96/// Solve a two-point BVP using Lobatto IIIA collocation.
97///
98/// # Arguments
99///
100/// * `ode`      - Right-hand side `y'(x) = ode(x, y)`
101/// * `bc`       - Boundary conditions: `bc(y(a), y(b))` returns residual vector
102///   (length must equal the system dimension)
103/// * `x_mesh`   - Initial mesh (strictly increasing, at least 2 points)
104/// * `y_guess`  - Initial guess at each mesh point
105/// * `options`  - Solver options (optional)
106///
107/// # Examples
108///
109/// ```
110/// use scirs2_core::ndarray::{array, Array1, ArrayView1};
111/// use scirs2_integrate::bvp_collocation::{solve_bvp_collocation, CollocationBVPOptions};
112///
113/// // Solve y'' = -y on [0, pi], y(0)=0, y(pi)=0
114/// // Rewrite as system: u1' = u2, u2' = -u1
115/// let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], -y[0]];
116/// let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0], yb[0]];
117///
118/// let n = 11;
119/// let pi = std::f64::consts::PI;
120/// let x_mesh: Vec<f64> = (0..n).map(|i| i as f64 * pi / (n as f64 - 1.0)).collect();
121/// let y_guess: Vec<Array1<f64>> = x_mesh.iter()
122///     .map(|&x| array![x.sin(), x.cos()])
123///     .collect();
124///
125/// let result = solve_bvp_collocation(ode, bc, &x_mesh, &y_guess, None)
126///     .expect("collocation solve");
127/// assert!(result.converged);
128/// ```
129pub fn solve_bvp_collocation<F, OdeFn, BcFn>(
130    ode: OdeFn,
131    bc: BcFn,
132    x_mesh: &[F],
133    y_guess: &[Array1<F>],
134    options: Option<CollocationBVPOptions<F>>,
135) -> IntegrateResult<CollocationBVPResult<F>>
136where
137    F: IntegrateFloat,
138    OdeFn: Fn(F, ArrayView1<F>) -> Array1<F> + Copy,
139    BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
140{
141    let opts = options.unwrap_or_default();
142
143    // Validate inputs
144    if x_mesh.len() < 2 {
145        return Err(IntegrateError::ValueError(
146            "Mesh must have at least 2 points".into(),
147        ));
148    }
149    if x_mesh.len() != y_guess.len() {
150        return Err(IntegrateError::ValueError(
151            "Mesh and guess must have the same length".into(),
152        ));
153    }
154
155    let n_dim = y_guess[0].len();
156    for (i, yg) in y_guess.iter().enumerate() {
157        if yg.len() != n_dim {
158            return Err(IntegrateError::ValueError(format!(
159                "Guess at index {i} has wrong dimension: {} vs expected {n_dim}",
160                yg.len()
161            )));
162        }
163    }
164
165    // Check mesh is strictly increasing
166    for i in 1..x_mesh.len() {
167        if x_mesh[i] <= x_mesh[i - 1] {
168            return Err(IntegrateError::ValueError(
169                "Mesh must be strictly increasing".into(),
170            ));
171        }
172    }
173
174    let mut mesh = x_mesh.to_vec();
175    let mut y_sol: Vec<Array1<F>> = y_guess.to_vec();
176    let mut total_newton = 0_usize;
177    let mut n_refinements = 0_usize;
178
179    loop {
180        // Newton iteration on current mesh
181        let (new_y, newton_iter, converged) =
182            newton_collocation(&ode, &bc, &mesh, &y_sol, &opts, n_dim)?;
183        total_newton += newton_iter;
184        y_sol = new_y;
185
186        if !converged {
187            return Ok(CollocationBVPResult {
188                x: mesh,
189                y: y_sol,
190                n_newton_iter: total_newton,
191                n_refinements,
192                max_defect: F::infinity(),
193                converged: false,
194            });
195        }
196
197        // Compute defect on each sub-interval
198        let defects = compute_defects(&ode, &mesh, &y_sol, n_dim)?;
199        let max_defect = defects
200            .iter()
201            .copied()
202            .fold(F::zero(), |a, b| if b > a { b } else { a });
203
204        if max_defect <= opts.tol {
205            return Ok(CollocationBVPResult {
206                x: mesh,
207                y: y_sol,
208                n_newton_iter: total_newton,
209                n_refinements,
210                max_defect,
211                converged: true,
212            });
213        }
214
215        n_refinements += 1;
216        if n_refinements >= opts.max_mesh_refinements {
217            return Ok(CollocationBVPResult {
218                x: mesh,
219                y: y_sol,
220                n_newton_iter: total_newton,
221                n_refinements,
222                max_defect,
223                converged: false,
224            });
225        }
226
227        // Refine mesh: bisect intervals where defect exceeds tolerance
228        let (new_mesh, new_y_sol) = refine_mesh(
229            &ode,
230            &mesh,
231            &y_sol,
232            &defects,
233            opts.tol,
234            opts.max_mesh_size,
235            n_dim,
236        )?;
237
238        if new_mesh.len() >= opts.max_mesh_size {
239            return Ok(CollocationBVPResult {
240                x: new_mesh,
241                y: new_y_sol,
242                n_newton_iter: total_newton,
243                n_refinements,
244                max_defect,
245                converged: false,
246            });
247        }
248
249        mesh = new_mesh;
250        y_sol = new_y_sol;
251    }
252}
253
254// ---------------------------------------------------------------------------
255// Newton collocation solve on a fixed mesh
256// ---------------------------------------------------------------------------
257
258/// Perform Newton iteration on the collocation system for the given mesh.
259/// Returns `(solution, n_iterations, converged)`.
260fn newton_collocation<F, OdeFn, BcFn>(
261    ode: &OdeFn,
262    bc: &BcFn,
263    mesh: &[F],
264    y_init: &[Array1<F>],
265    opts: &CollocationBVPOptions<F>,
266    n_dim: usize,
267) -> IntegrateResult<(Vec<Array1<F>>, usize, bool)>
268where
269    F: IntegrateFloat,
270    OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
271    BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
272{
273    let n_pts = mesh.len();
274    let n_intervals = n_pts - 1;
275    // Total unknowns: n_pts * n_dim
276    let n_vars = n_pts * n_dim;
277    // Equations: n_dim boundary conditions + n_intervals * n_dim collocation equations
278    let n_eqs = n_dim + n_intervals * n_dim;
279
280    if n_eqs != n_vars {
281        return Err(IntegrateError::DimensionMismatch(format!(
282            "Collocation system: {n_eqs} equations vs {n_vars} unknowns"
283        )));
284    }
285
286    let mut y_flat = flatten_solution(y_init, n_dim);
287    let eps: F = to_f(1e-8);
288
289    let mut converged = false;
290    let mut iter = 0_usize;
291
292    while iter < opts.max_newton_iter {
293        iter += 1;
294
295        // Evaluate residual
296        let residual = assemble_residual(ode, bc, mesh, &y_flat, n_dim, n_pts)?;
297
298        // Check convergence
299        let res_norm = residual
300            .iter()
301            .fold(F::zero(), |acc, &r| acc + r * r)
302            .sqrt()
303            / to_f::<F>(n_eqs as f64).max(F::one());
304
305        if res_norm < opts.tol {
306            converged = true;
307            break;
308        }
309
310        // Assemble Jacobian by finite differences
311        let jac = assemble_jacobian(ode, bc, mesh, &y_flat, &residual, n_dim, n_pts, eps)?;
312
313        // Solve J * delta = -residual
314        let neg_res = residual.mapv(|r| -r);
315        let delta = solve_dense_system(&jac, &neg_res)?;
316
317        // Update with damping
318        for i in 0..n_vars {
319            y_flat[i] += opts.damping * delta[i];
320        }
321    }
322
323    let y_sol = unflatten_solution(&y_flat, n_dim, n_pts);
324    Ok((y_sol, iter, converged))
325}
326
327// ---------------------------------------------------------------------------
328// Residual assembly
329// ---------------------------------------------------------------------------
330
331/// Assemble the global residual vector:
332///   R = [ bc(y(a), y(b)); collocation residuals ]
333fn assemble_residual<F, OdeFn, BcFn>(
334    ode: &OdeFn,
335    bc: &BcFn,
336    mesh: &[F],
337    y_flat: &Array1<F>,
338    n_dim: usize,
339    n_pts: usize,
340) -> IntegrateResult<Array1<F>>
341where
342    F: IntegrateFloat,
343    OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
344    BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
345{
346    let n_intervals = n_pts - 1;
347    let n_eqs = n_dim + n_intervals * n_dim;
348    let mut res = Array1::zeros(n_eqs);
349
350    // Extract y_a and y_b
351    let y_a = y_flat.slice(s![0..n_dim]);
352    let y_b = y_flat.slice(s![(n_pts - 1) * n_dim..n_pts * n_dim]);
353
354    // Boundary conditions
355    let bc_res = bc(y_a, y_b);
356    for j in 0..n_dim {
357        res[j] = bc_res[j];
358    }
359
360    // Collocation equations: on each interval [x_i, x_{i+1}],
361    // we enforce the midpoint collocation condition:
362    //   y_{i+1} - y_i - h * f(x_mid, y_mid) = 0
363    // where y_mid = (y_i + y_{i+1}) / 2 + h/8 * (f_i - f_{i+1})
364    // (Lobatto IIIA 3-point formula, cubic accurate)
365    let half: F = to_f(0.5);
366    let eighth: F = to_f(0.125);
367
368    for i in 0..n_intervals {
369        let x_i = mesh[i];
370        let x_ip1 = mesh[i + 1];
371        let h = x_ip1 - x_i;
372        let x_mid = (x_i + x_ip1) * half;
373
374        let y_i = y_flat.slice(s![i * n_dim..(i + 1) * n_dim]);
375        let y_ip1 = y_flat.slice(s![(i + 1) * n_dim..(i + 2) * n_dim]);
376
377        let f_i = ode(x_i, y_i);
378        let f_ip1 = ode(x_ip1, y_ip1);
379
380        // Lobatto IIIA midpoint predictor
381        let mut y_mid = Array1::zeros(n_dim);
382        for j in 0..n_dim {
383            y_mid[j] = (y_i[j] + y_ip1[j]) * half + h * eighth * (f_i[j] - f_ip1[j]);
384        }
385
386        let f_mid = ode(x_mid, y_mid.view());
387
388        // Lobatto IIIA collocation residual:
389        //   y_{i+1} - y_i = h/6 * (f_i + 4*f_mid + f_ip1)  (Simpson-like)
390        let sixth: F = to_f(1.0 / 6.0);
391        let four: F = to_f(4.0);
392        let eq_offset = n_dim + i * n_dim;
393        for j in 0..n_dim {
394            res[eq_offset + j] =
395                y_ip1[j] - y_i[j] - h * sixth * (f_i[j] + four * f_mid[j] + f_ip1[j]);
396        }
397    }
398
399    Ok(res)
400}
401
402// ---------------------------------------------------------------------------
403// Jacobian assembly (finite differences)
404// ---------------------------------------------------------------------------
405
406fn assemble_jacobian<F, OdeFn, BcFn>(
407    ode: &OdeFn,
408    bc: &BcFn,
409    mesh: &[F],
410    y_flat: &Array1<F>,
411    res0: &Array1<F>,
412    n_dim: usize,
413    n_pts: usize,
414    eps: F,
415) -> IntegrateResult<Array2<F>>
416where
417    F: IntegrateFloat,
418    OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
419    BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
420{
421    let n_vars = n_pts * n_dim;
422    let n_eqs = res0.len();
423    let mut jac = Array2::zeros((n_eqs, n_vars));
424
425    for col in 0..n_vars {
426        let mut y_pert = y_flat.clone();
427        let delta = eps * (F::one() + y_pert[col].abs());
428        y_pert[col] += delta;
429
430        let res_pert = assemble_residual(ode, bc, mesh, &y_pert, n_dim, n_pts)?;
431
432        for row in 0..n_eqs {
433            jac[[row, col]] = (res_pert[row] - res0[row]) / delta;
434        }
435    }
436
437    Ok(jac)
438}
439
440// ---------------------------------------------------------------------------
441// Dense linear solver (LU with partial pivoting)
442// ---------------------------------------------------------------------------
443
444fn solve_dense_system<F: IntegrateFloat>(
445    a: &Array2<F>,
446    b: &Array1<F>,
447) -> IntegrateResult<Array1<F>> {
448    let n = a.nrows();
449    if n != a.ncols() || n != b.len() {
450        return Err(IntegrateError::DimensionMismatch(
451            "solve_dense_system: dimension mismatch".into(),
452        ));
453    }
454
455    let mut lu = a.clone();
456    let mut piv: Vec<usize> = (0..n).collect();
457    let tiny = F::from_f64(1e-30).unwrap_or_else(|| F::epsilon());
458
459    for k in 0..n {
460        let mut max_val = lu[[piv[k], k]].abs();
461        let mut max_idx = k;
462        for i in (k + 1)..n {
463            let v = lu[[piv[i], k]].abs();
464            if v > max_val {
465                max_val = v;
466                max_idx = i;
467            }
468        }
469        if max_val < tiny {
470            return Err(IntegrateError::LinearSolveError(
471                "Singular matrix in collocation solver".into(),
472            ));
473        }
474        piv.swap(k, max_idx);
475
476        for i in (k + 1)..n {
477            let factor = lu[[piv[i], k]] / lu[[piv[k], k]];
478            lu[[piv[i], k]] = factor;
479            for j in (k + 1)..n {
480                let val = lu[[piv[k], j]];
481                lu[[piv[i], j]] -= factor * val;
482            }
483        }
484    }
485
486    let mut z = Array1::zeros(n);
487    for i in 0..n {
488        let mut s = b[piv[i]];
489        for j in 0..i {
490            s -= lu[[piv[i], j]] * z[j];
491        }
492        z[i] = s;
493    }
494
495    let mut x = Array1::zeros(n);
496    for i in (0..n).rev() {
497        let mut s = z[i];
498        for j in (i + 1)..n {
499            s -= lu[[piv[i], j]] * x[j];
500        }
501        if lu[[piv[i], i]].abs() < tiny {
502            return Err(IntegrateError::LinearSolveError(
503                "Zero diagonal in collocation LU".into(),
504            ));
505        }
506        x[i] = s / lu[[piv[i], i]];
507    }
508
509    Ok(x)
510}
511
512// ---------------------------------------------------------------------------
513// Defect estimation
514// ---------------------------------------------------------------------------
515
516/// Compute the defect (residual of the continuous ODE) at the midpoint of
517/// each sub-interval using the cubic Hermite interpolant.
518fn compute_defects<F, OdeFn>(
519    ode: &OdeFn,
520    mesh: &[F],
521    y_sol: &[Array1<F>],
522    n_dim: usize,
523) -> IntegrateResult<Vec<F>>
524where
525    F: IntegrateFloat,
526    OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
527{
528    let n_intervals = mesh.len() - 1;
529    let mut defects = Vec::with_capacity(n_intervals);
530    let half: F = to_f(0.5);
531
532    for i in 0..n_intervals {
533        let h = mesh[i + 1] - mesh[i];
534        let x_mid = (mesh[i] + mesh[i + 1]) * half;
535
536        let f_i = ode(mesh[i], y_sol[i].view());
537        let f_ip1 = ode(mesh[i + 1], y_sol[i + 1].view());
538
539        // Cubic Hermite interpolation at midpoint
540        let mut y_mid = Array1::zeros(n_dim);
541        for j in 0..n_dim {
542            y_mid[j] =
543                (y_sol[i][j] + y_sol[i + 1][j]) * half + h * to_f::<F>(0.125) * (f_i[j] - f_ip1[j]);
544        }
545
546        // Derivative of Hermite interpolant at midpoint
547        let mut yp_mid = Array1::zeros(n_dim);
548        for j in 0..n_dim {
549            yp_mid[j] = (y_sol[i + 1][j] - y_sol[i][j]) / h - to_f::<F>(0.25) * (f_i[j] + f_ip1[j])
550                + half * to_f::<F>(1.0) * ((y_sol[i + 1][j] - y_sol[i][j]) / h);
551            // Simplified: the cubic Hermite slope at midpoint
552            yp_mid[j] = to_f::<F>(1.5) * (y_sol[i + 1][j] - y_sol[i][j]) / h
553                - to_f::<F>(0.25) * (f_i[j] + f_ip1[j]);
554        }
555
556        let f_mid = ode(x_mid, y_mid.view());
557
558        // Defect = ||yp_mid - f(x_mid, y_mid)||
559        let mut defect_sq = F::zero();
560        for j in 0..n_dim {
561            let d = yp_mid[j] - f_mid[j];
562            defect_sq += d * d;
563        }
564        defects.push(defect_sq.sqrt());
565    }
566
567    Ok(defects)
568}
569
570// ---------------------------------------------------------------------------
571// Mesh refinement
572// ---------------------------------------------------------------------------
573
574/// Refine the mesh by bisecting intervals with large defects.
575fn refine_mesh<F, OdeFn>(
576    ode: &OdeFn,
577    mesh: &[F],
578    y_sol: &[Array1<F>],
579    defects: &[F],
580    tol: F,
581    max_size: usize,
582    n_dim: usize,
583) -> IntegrateResult<(Vec<F>, Vec<Array1<F>>)>
584where
585    F: IntegrateFloat,
586    OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
587{
588    let mut new_mesh = Vec::new();
589    let mut new_y = Vec::new();
590
591    new_mesh.push(mesh[0]);
592    new_y.push(y_sol[0].clone());
593
594    for i in 0..(mesh.len() - 1) {
595        if defects[i] > tol && new_mesh.len() + 2 <= max_size {
596            // Insert midpoint
597            let half: F = to_f(0.5);
598            let x_mid = (mesh[i] + mesh[i + 1]) * half;
599            let h = mesh[i + 1] - mesh[i];
600
601            let f_i = ode(mesh[i], y_sol[i].view());
602            let f_ip1 = ode(mesh[i + 1], y_sol[i + 1].view());
603
604            let mut y_mid = Array1::zeros(n_dim);
605            for j in 0..n_dim {
606                y_mid[j] = (y_sol[i][j] + y_sol[i + 1][j]) * half
607                    + h * to_f::<F>(0.125) * (f_i[j] - f_ip1[j]);
608            }
609
610            new_mesh.push(x_mid);
611            new_y.push(y_mid);
612        }
613
614        new_mesh.push(mesh[i + 1]);
615        new_y.push(y_sol[i + 1].clone());
616    }
617
618    Ok((new_mesh, new_y))
619}
620
621// ---------------------------------------------------------------------------
622// Utilities
623// ---------------------------------------------------------------------------
624
625fn flatten_solution<F: IntegrateFloat>(y: &[Array1<F>], n_dim: usize) -> Array1<F> {
626    let n_pts = y.len();
627    let mut flat = Array1::zeros(n_pts * n_dim);
628    for (i, yi) in y.iter().enumerate() {
629        for j in 0..n_dim {
630            flat[i * n_dim + j] = yi[j];
631        }
632    }
633    flat
634}
635
636fn unflatten_solution<F: IntegrateFloat>(
637    flat: &Array1<F>,
638    n_dim: usize,
639    n_pts: usize,
640) -> Vec<Array1<F>> {
641    let mut y = Vec::with_capacity(n_pts);
642    for i in 0..n_pts {
643        let start = i * n_dim;
644        let yi = Array1::from_vec(
645            flat.slice(s![start..start + n_dim])
646                .iter()
647                .copied()
648                .collect(),
649        );
650        y.push(yi);
651    }
652    y
653}
654
655// ---------------------------------------------------------------------------
656// Tests
657// ---------------------------------------------------------------------------
658
659#[cfg(test)]
660mod tests {
661    use super::*;
662    use scirs2_core::ndarray::array;
663
664    #[test]
665    fn test_linear_bvp() {
666        // y'' = 0, y(0)=0, y(1)=1 => y(x) = x
667        // System: u1' = u2, u2' = 0
668        let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], 0.0];
669        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0] - 0.0, yb[0] - 1.0];
670
671        let n = 5;
672        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
673        let guess: Vec<Array1<f64>> = mesh.iter().map(|&x| array![x, 1.0]).collect();
674
675        let result = solve_bvp_collocation(ode, bc, &mesh, &guess, None).expect("linear BVP solve");
676
677        assert!(result.converged, "linear BVP should converge");
678
679        // Check that y(x) ≈ x at all mesh points
680        for (i, xi) in result.x.iter().enumerate() {
681            assert!(
682                (result.y[i][0] - *xi).abs() < 1e-4,
683                "y({xi}) = {}, expected {xi}",
684                result.y[i][0]
685            );
686        }
687    }
688
689    #[test]
690    fn test_exponential_bvp() {
691        // y'' = y', y(0)=1, y(1)=e^(-1) => y(x) = e^(-x)
692        // System: u1' = u2, u2' = -u2  (so that u1 = e^(-x))
693        // Actually simpler: y' = -y, y(0)=1 is an IVP not BVP.
694        // Let's use a proper 2-point BVP: y'' + y' = 0, y(0)=1, y(1)=e^{-1}
695        // System: u1' = u2, u2' = -u2; BCs: u1(0)=1, u1(1)=e^{-1}
696        let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], -y[1]];
697        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| {
698            let exact_end = (-1.0_f64).exp();
699            array![ya[0] - 1.0, yb[0] - exact_end]
700        };
701
702        let n = 11;
703        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
704        let guess: Vec<Array1<f64>> = mesh
705            .iter()
706            .map(|&x| array![(-x).exp(), -(-x).exp()])
707            .collect();
708
709        let result = solve_bvp_collocation(
710            ode,
711            bc,
712            &mesh,
713            &guess,
714            Some(CollocationBVPOptions {
715                max_newton_iter: 100,
716                ..Default::default()
717            }),
718        )
719        .expect("exp BVP solve");
720
721        assert!(result.converged, "exp BVP should converge");
722
723        let y_final = result.y.last().expect("has solution")[0];
724        let exact = (-1.0_f64).exp();
725        assert!(
726            (y_final - exact).abs() < 1e-2,
727            "y(1) = {y_final}, expected {exact}"
728        );
729    }
730
731    #[test]
732    fn test_nonlinear_bvp() {
733        // Nonlinear BVP: y'' = -exp(y), y(0)=0, y(1)=0
734        // (Bratu problem, has solution for small lambda)
735        // System: u1' = u2, u2' = -exp(u1)
736        let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], -y[0].exp()];
737        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0], yb[0]];
738
739        let n = 21;
740        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
741        // Guess: parabola satisfying BCs, y ≈ 4*x*(1-x) * 0.1
742        let guess: Vec<Array1<f64>> = mesh
743            .iter()
744            .map(|&x| {
745                let y_val = 0.4 * x * (1.0 - x);
746                let yp_val = 0.4 * (1.0 - 2.0 * x);
747                array![y_val, yp_val]
748            })
749            .collect();
750
751        let result = solve_bvp_collocation(
752            ode,
753            bc,
754            &mesh,
755            &guess,
756            Some(CollocationBVPOptions {
757                max_newton_iter: 100,
758                ..Default::default()
759            }),
760        )
761        .expect("Bratu BVP solve");
762
763        assert!(result.converged, "Bratu BVP should converge");
764
765        // Check BCs
766        assert!(
767            result.y[0][0].abs() < 1e-4,
768            "y(0) should be 0, got {}",
769            result.y[0][0]
770        );
771        let y_end = result.y.last().expect("has solution")[0];
772        assert!(y_end.abs() < 1e-4, "y(1) should be 0, got {y_end}");
773
774        // Solution should be positive in interior
775        let mid_idx = n / 2;
776        assert!(
777            result.y[mid_idx][0] > 0.0,
778            "Interior should be positive, got {}",
779            result.y[mid_idx][0]
780        );
781    }
782
783    #[test]
784    fn test_stiff_bvp() {
785        // Stiff BVP: epsilon*y'' - y = 0, y(0)=1, y(1)=0
786        // For small epsilon, solution has boundary layer
787        // With epsilon = 0.1, the exact solution involves exp(x/sqrt(eps))
788        let epsilon = 0.1;
789        let ode = move |_x: f64, y: ArrayView1<f64>| array![y[1], y[0] / epsilon];
790        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0] - 1.0, yb[0]];
791
792        let n = 31;
793        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
794        // Guess: linear decay from 1 to 0
795        let guess: Vec<Array1<f64>> = mesh.iter().map(|&x| array![1.0 - x, -1.0]).collect();
796
797        let result = solve_bvp_collocation(
798            ode,
799            bc,
800            &mesh,
801            &guess,
802            Some(CollocationBVPOptions {
803                tol: to_f(1e-4),
804                max_newton_iter: 60,
805                ..Default::default()
806            }),
807        )
808        .expect("stiff BVP solve");
809
810        assert!(result.converged, "stiff BVP should converge");
811        // Check BC
812        assert!(
813            (result.y[0][0] - 1.0).abs() < 1e-3,
814            "y(0) should be 1.0, got {}",
815            result.y[0][0]
816        );
817        let y_end = result.y.last().expect("has solution")[0];
818        assert!(y_end.abs() < 0.1, "y(1) should be ~0, got {y_end}");
819    }
820
821    #[test]
822    fn test_invalid_mesh() {
823        let ode = |_x: f64, _y: ArrayView1<f64>| array![0.0];
824        let bc = |ya: ArrayView1<f64>, _yb: ArrayView1<f64>| array![ya[0]];
825
826        // Mesh not increasing
827        let res = solve_bvp_collocation(ode, bc, &[1.0, 0.0], &[array![0.0], array![0.0]], None);
828        assert!(res.is_err());
829    }
830
831    #[test]
832    fn test_mesh_guess_mismatch() {
833        let ode = |_x: f64, _y: ArrayView1<f64>| array![0.0];
834        let bc = |ya: ArrayView1<f64>, _yb: ArrayView1<f64>| array![ya[0]];
835
836        let res = solve_bvp_collocation(
837            ode,
838            bc,
839            &[0.0, 0.5, 1.0],
840            &[array![0.0], array![0.0]], // wrong length
841            None,
842        );
843        assert!(res.is_err());
844    }
845}