1use crate::error::{IntegrateError, IntegrateResult};
35use crate::IntegrateFloat;
36use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
37
38#[inline(always)]
40fn to_f<F: IntegrateFloat>(v: f64) -> F {
41 F::from_f64(v).unwrap_or_else(|| F::zero())
42}
43
/// Tuning parameters for [`solve_bvp_collocation`].
#[derive(Debug, Clone)]
pub struct CollocationBVPOptions<F: IntegrateFloat> {
    /// Tolerance shared by the Newton stopping test (scaled residual norm) and the
    /// accepted maximum midpoint defect; intervals whose defect exceeds it are
    /// candidates for mesh refinement.
    pub tol: F,
    /// Cap on Newton iterations performed on each mesh.
    pub max_newton_iter: usize,
    /// Limit on the number of mesh-refinement passes.
    pub max_mesh_refinements: usize,
    /// Limit on the number of mesh points produced by refinement.
    pub max_mesh_size: usize,
    /// Factor multiplying every Newton step (`1` gives the full, undamped step).
    pub damping: F,
}
62
63impl<F: IntegrateFloat> Default for CollocationBVPOptions<F> {
64 fn default() -> Self {
65 Self {
66 tol: to_f(1e-6),
67 max_newton_iter: 40,
68 max_mesh_refinements: 10,
69 max_mesh_size: 500,
70 damping: F::one(),
71 }
72 }
73}
74
/// Outcome of [`solve_bvp_collocation`].
#[derive(Debug, Clone)]
pub struct CollocationBVPResult<F: IntegrateFloat> {
    /// Mesh on which `y` is reported.
    pub x: Vec<F>,
    /// State vector at each point of `x` (same length as `x`).
    pub y: Vec<Array1<F>>,
    /// Newton iterations accumulated over all meshes that were solved.
    pub n_newton_iter: usize,
    /// Mesh-refinement counter at the time the solver stopped.
    pub n_refinements: usize,
    /// Largest per-interval midpoint defect measured on the last solved mesh;
    /// `F::infinity()` when the Newton iteration did not converge.
    pub max_defect: F,
    /// `true` only when Newton converged and `max_defect <= tol`.
    pub converged: bool,
}
91
92pub fn solve_bvp_collocation<F, OdeFn, BcFn>(
130 ode: OdeFn,
131 bc: BcFn,
132 x_mesh: &[F],
133 y_guess: &[Array1<F>],
134 options: Option<CollocationBVPOptions<F>>,
135) -> IntegrateResult<CollocationBVPResult<F>>
136where
137 F: IntegrateFloat,
138 OdeFn: Fn(F, ArrayView1<F>) -> Array1<F> + Copy,
139 BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
140{
141 let opts = options.unwrap_or_default();
142
143 if x_mesh.len() < 2 {
145 return Err(IntegrateError::ValueError(
146 "Mesh must have at least 2 points".into(),
147 ));
148 }
149 if x_mesh.len() != y_guess.len() {
150 return Err(IntegrateError::ValueError(
151 "Mesh and guess must have the same length".into(),
152 ));
153 }
154
155 let n_dim = y_guess[0].len();
156 for (i, yg) in y_guess.iter().enumerate() {
157 if yg.len() != n_dim {
158 return Err(IntegrateError::ValueError(format!(
159 "Guess at index {i} has wrong dimension: {} vs expected {n_dim}",
160 yg.len()
161 )));
162 }
163 }
164
165 for i in 1..x_mesh.len() {
167 if x_mesh[i] <= x_mesh[i - 1] {
168 return Err(IntegrateError::ValueError(
169 "Mesh must be strictly increasing".into(),
170 ));
171 }
172 }
173
174 let mut mesh = x_mesh.to_vec();
175 let mut y_sol: Vec<Array1<F>> = y_guess.to_vec();
176 let mut total_newton = 0_usize;
177 let mut n_refinements = 0_usize;
178
179 loop {
180 let (new_y, newton_iter, converged) =
182 newton_collocation(&ode, &bc, &mesh, &y_sol, &opts, n_dim)?;
183 total_newton += newton_iter;
184 y_sol = new_y;
185
186 if !converged {
187 return Ok(CollocationBVPResult {
188 x: mesh,
189 y: y_sol,
190 n_newton_iter: total_newton,
191 n_refinements,
192 max_defect: F::infinity(),
193 converged: false,
194 });
195 }
196
197 let defects = compute_defects(&ode, &mesh, &y_sol, n_dim)?;
199 let max_defect = defects
200 .iter()
201 .copied()
202 .fold(F::zero(), |a, b| if b > a { b } else { a });
203
204 if max_defect <= opts.tol {
205 return Ok(CollocationBVPResult {
206 x: mesh,
207 y: y_sol,
208 n_newton_iter: total_newton,
209 n_refinements,
210 max_defect,
211 converged: true,
212 });
213 }
214
215 n_refinements += 1;
216 if n_refinements >= opts.max_mesh_refinements {
217 return Ok(CollocationBVPResult {
218 x: mesh,
219 y: y_sol,
220 n_newton_iter: total_newton,
221 n_refinements,
222 max_defect,
223 converged: false,
224 });
225 }
226
227 let (new_mesh, new_y_sol) = refine_mesh(
229 &ode,
230 &mesh,
231 &y_sol,
232 &defects,
233 opts.tol,
234 opts.max_mesh_size,
235 n_dim,
236 )?;
237
238 if new_mesh.len() >= opts.max_mesh_size {
239 return Ok(CollocationBVPResult {
240 x: new_mesh,
241 y: new_y_sol,
242 n_newton_iter: total_newton,
243 n_refinements,
244 max_defect,
245 converged: false,
246 });
247 }
248
249 mesh = new_mesh;
250 y_sol = new_y_sol;
251 }
252}
253
254fn newton_collocation<F, OdeFn, BcFn>(
261 ode: &OdeFn,
262 bc: &BcFn,
263 mesh: &[F],
264 y_init: &[Array1<F>],
265 opts: &CollocationBVPOptions<F>,
266 n_dim: usize,
267) -> IntegrateResult<(Vec<Array1<F>>, usize, bool)>
268where
269 F: IntegrateFloat,
270 OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
271 BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
272{
273 let n_pts = mesh.len();
274 let n_intervals = n_pts - 1;
275 let n_vars = n_pts * n_dim;
277 let n_eqs = n_dim + n_intervals * n_dim;
279
280 if n_eqs != n_vars {
281 return Err(IntegrateError::DimensionMismatch(format!(
282 "Collocation system: {n_eqs} equations vs {n_vars} unknowns"
283 )));
284 }
285
286 let mut y_flat = flatten_solution(y_init, n_dim);
287 let eps: F = to_f(1e-8);
288
289 let mut converged = false;
290 let mut iter = 0_usize;
291
292 while iter < opts.max_newton_iter {
293 iter += 1;
294
295 let residual = assemble_residual(ode, bc, mesh, &y_flat, n_dim, n_pts)?;
297
298 let res_norm = residual
300 .iter()
301 .fold(F::zero(), |acc, &r| acc + r * r)
302 .sqrt()
303 / to_f::<F>(n_eqs as f64).max(F::one());
304
305 if res_norm < opts.tol {
306 converged = true;
307 break;
308 }
309
310 let jac = assemble_jacobian(ode, bc, mesh, &y_flat, &residual, n_dim, n_pts, eps)?;
312
313 let neg_res = residual.mapv(|r| -r);
315 let delta = solve_dense_system(&jac, &neg_res)?;
316
317 for i in 0..n_vars {
319 y_flat[i] += opts.damping * delta[i];
320 }
321 }
322
323 let y_sol = unflatten_solution(&y_flat, n_dim, n_pts);
324 Ok((y_sol, iter, converged))
325}
326
327fn assemble_residual<F, OdeFn, BcFn>(
334 ode: &OdeFn,
335 bc: &BcFn,
336 mesh: &[F],
337 y_flat: &Array1<F>,
338 n_dim: usize,
339 n_pts: usize,
340) -> IntegrateResult<Array1<F>>
341where
342 F: IntegrateFloat,
343 OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
344 BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
345{
346 let n_intervals = n_pts - 1;
347 let n_eqs = n_dim + n_intervals * n_dim;
348 let mut res = Array1::zeros(n_eqs);
349
350 let y_a = y_flat.slice(s![0..n_dim]);
352 let y_b = y_flat.slice(s![(n_pts - 1) * n_dim..n_pts * n_dim]);
353
354 let bc_res = bc(y_a, y_b);
356 for j in 0..n_dim {
357 res[j] = bc_res[j];
358 }
359
360 let half: F = to_f(0.5);
366 let eighth: F = to_f(0.125);
367
368 for i in 0..n_intervals {
369 let x_i = mesh[i];
370 let x_ip1 = mesh[i + 1];
371 let h = x_ip1 - x_i;
372 let x_mid = (x_i + x_ip1) * half;
373
374 let y_i = y_flat.slice(s![i * n_dim..(i + 1) * n_dim]);
375 let y_ip1 = y_flat.slice(s![(i + 1) * n_dim..(i + 2) * n_dim]);
376
377 let f_i = ode(x_i, y_i);
378 let f_ip1 = ode(x_ip1, y_ip1);
379
380 let mut y_mid = Array1::zeros(n_dim);
382 for j in 0..n_dim {
383 y_mid[j] = (y_i[j] + y_ip1[j]) * half + h * eighth * (f_i[j] - f_ip1[j]);
384 }
385
386 let f_mid = ode(x_mid, y_mid.view());
387
388 let sixth: F = to_f(1.0 / 6.0);
391 let four: F = to_f(4.0);
392 let eq_offset = n_dim + i * n_dim;
393 for j in 0..n_dim {
394 res[eq_offset + j] =
395 y_ip1[j] - y_i[j] - h * sixth * (f_i[j] + four * f_mid[j] + f_ip1[j]);
396 }
397 }
398
399 Ok(res)
400}
401
402fn assemble_jacobian<F, OdeFn, BcFn>(
407 ode: &OdeFn,
408 bc: &BcFn,
409 mesh: &[F],
410 y_flat: &Array1<F>,
411 res0: &Array1<F>,
412 n_dim: usize,
413 n_pts: usize,
414 eps: F,
415) -> IntegrateResult<Array2<F>>
416where
417 F: IntegrateFloat,
418 OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
419 BcFn: Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
420{
421 let n_vars = n_pts * n_dim;
422 let n_eqs = res0.len();
423 let mut jac = Array2::zeros((n_eqs, n_vars));
424
425 for col in 0..n_vars {
426 let mut y_pert = y_flat.clone();
427 let delta = eps * (F::one() + y_pert[col].abs());
428 y_pert[col] += delta;
429
430 let res_pert = assemble_residual(ode, bc, mesh, &y_pert, n_dim, n_pts)?;
431
432 for row in 0..n_eqs {
433 jac[[row, col]] = (res_pert[row] - res0[row]) / delta;
434 }
435 }
436
437 Ok(jac)
438}
439
440fn solve_dense_system<F: IntegrateFloat>(
445 a: &Array2<F>,
446 b: &Array1<F>,
447) -> IntegrateResult<Array1<F>> {
448 let n = a.nrows();
449 if n != a.ncols() || n != b.len() {
450 return Err(IntegrateError::DimensionMismatch(
451 "solve_dense_system: dimension mismatch".into(),
452 ));
453 }
454
455 let mut lu = a.clone();
456 let mut piv: Vec<usize> = (0..n).collect();
457 let tiny = F::from_f64(1e-30).unwrap_or_else(|| F::epsilon());
458
459 for k in 0..n {
460 let mut max_val = lu[[piv[k], k]].abs();
461 let mut max_idx = k;
462 for i in (k + 1)..n {
463 let v = lu[[piv[i], k]].abs();
464 if v > max_val {
465 max_val = v;
466 max_idx = i;
467 }
468 }
469 if max_val < tiny {
470 return Err(IntegrateError::LinearSolveError(
471 "Singular matrix in collocation solver".into(),
472 ));
473 }
474 piv.swap(k, max_idx);
475
476 for i in (k + 1)..n {
477 let factor = lu[[piv[i], k]] / lu[[piv[k], k]];
478 lu[[piv[i], k]] = factor;
479 for j in (k + 1)..n {
480 let val = lu[[piv[k], j]];
481 lu[[piv[i], j]] -= factor * val;
482 }
483 }
484 }
485
486 let mut z = Array1::zeros(n);
487 for i in 0..n {
488 let mut s = b[piv[i]];
489 for j in 0..i {
490 s -= lu[[piv[i], j]] * z[j];
491 }
492 z[i] = s;
493 }
494
495 let mut x = Array1::zeros(n);
496 for i in (0..n).rev() {
497 let mut s = z[i];
498 for j in (i + 1)..n {
499 s -= lu[[piv[i], j]] * x[j];
500 }
501 if lu[[piv[i], i]].abs() < tiny {
502 return Err(IntegrateError::LinearSolveError(
503 "Zero diagonal in collocation LU".into(),
504 ));
505 }
506 x[i] = s / lu[[piv[i], i]];
507 }
508
509 Ok(x)
510}
511
512fn compute_defects<F, OdeFn>(
519 ode: &OdeFn,
520 mesh: &[F],
521 y_sol: &[Array1<F>],
522 n_dim: usize,
523) -> IntegrateResult<Vec<F>>
524where
525 F: IntegrateFloat,
526 OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
527{
528 let n_intervals = mesh.len() - 1;
529 let mut defects = Vec::with_capacity(n_intervals);
530 let half: F = to_f(0.5);
531
532 for i in 0..n_intervals {
533 let h = mesh[i + 1] - mesh[i];
534 let x_mid = (mesh[i] + mesh[i + 1]) * half;
535
536 let f_i = ode(mesh[i], y_sol[i].view());
537 let f_ip1 = ode(mesh[i + 1], y_sol[i + 1].view());
538
539 let mut y_mid = Array1::zeros(n_dim);
541 for j in 0..n_dim {
542 y_mid[j] =
543 (y_sol[i][j] + y_sol[i + 1][j]) * half + h * to_f::<F>(0.125) * (f_i[j] - f_ip1[j]);
544 }
545
546 let mut yp_mid = Array1::zeros(n_dim);
548 for j in 0..n_dim {
549 yp_mid[j] = (y_sol[i + 1][j] - y_sol[i][j]) / h - to_f::<F>(0.25) * (f_i[j] + f_ip1[j])
550 + half * to_f::<F>(1.0) * ((y_sol[i + 1][j] - y_sol[i][j]) / h);
551 yp_mid[j] = to_f::<F>(1.5) * (y_sol[i + 1][j] - y_sol[i][j]) / h
553 - to_f::<F>(0.25) * (f_i[j] + f_ip1[j]);
554 }
555
556 let f_mid = ode(x_mid, y_mid.view());
557
558 let mut defect_sq = F::zero();
560 for j in 0..n_dim {
561 let d = yp_mid[j] - f_mid[j];
562 defect_sq += d * d;
563 }
564 defects.push(defect_sq.sqrt());
565 }
566
567 Ok(defects)
568}
569
570fn refine_mesh<F, OdeFn>(
576 ode: &OdeFn,
577 mesh: &[F],
578 y_sol: &[Array1<F>],
579 defects: &[F],
580 tol: F,
581 max_size: usize,
582 n_dim: usize,
583) -> IntegrateResult<(Vec<F>, Vec<Array1<F>>)>
584where
585 F: IntegrateFloat,
586 OdeFn: Fn(F, ArrayView1<F>) -> Array1<F>,
587{
588 let mut new_mesh = Vec::new();
589 let mut new_y = Vec::new();
590
591 new_mesh.push(mesh[0]);
592 new_y.push(y_sol[0].clone());
593
594 for i in 0..(mesh.len() - 1) {
595 if defects[i] > tol && new_mesh.len() + 2 <= max_size {
596 let half: F = to_f(0.5);
598 let x_mid = (mesh[i] + mesh[i + 1]) * half;
599 let h = mesh[i + 1] - mesh[i];
600
601 let f_i = ode(mesh[i], y_sol[i].view());
602 let f_ip1 = ode(mesh[i + 1], y_sol[i + 1].view());
603
604 let mut y_mid = Array1::zeros(n_dim);
605 for j in 0..n_dim {
606 y_mid[j] = (y_sol[i][j] + y_sol[i + 1][j]) * half
607 + h * to_f::<F>(0.125) * (f_i[j] - f_ip1[j]);
608 }
609
610 new_mesh.push(x_mid);
611 new_y.push(y_mid);
612 }
613
614 new_mesh.push(mesh[i + 1]);
615 new_y.push(y_sol[i + 1].clone());
616 }
617
618 Ok((new_mesh, new_y))
619}
620
621fn flatten_solution<F: IntegrateFloat>(y: &[Array1<F>], n_dim: usize) -> Array1<F> {
626 let n_pts = y.len();
627 let mut flat = Array1::zeros(n_pts * n_dim);
628 for (i, yi) in y.iter().enumerate() {
629 for j in 0..n_dim {
630 flat[i * n_dim + j] = yi[j];
631 }
632 }
633 flat
634}
635
636fn unflatten_solution<F: IntegrateFloat>(
637 flat: &Array1<F>,
638 n_dim: usize,
639 n_pts: usize,
640) -> Vec<Array1<F>> {
641 let mut y = Vec::with_capacity(n_pts);
642 for i in 0..n_pts {
643 let start = i * n_dim;
644 let yi = Array1::from_vec(
645 flat.slice(s![start..start + n_dim])
646 .iter()
647 .copied()
648 .collect(),
649 );
650 y.push(yi);
651 }
652 y
653}
654
// Unit tests for the collocation BVP solver: smooth linear and nonlinear problems,
// a singularly perturbed problem, and input validation.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // y'' = 0 written as the first-order system (y, y'); y(0) = 0, y(1) = 1.
    // The exact solution is y = x.
    #[test]
    fn test_linear_bvp() {
        let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], 0.0];
        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0] - 0.0, yb[0] - 1.0];

        let n = 5;
        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
        let guess: Vec<Array1<f64>> = mesh.iter().map(|&x| array![x, 1.0]).collect();

        let result = solve_bvp_collocation(ode, bc, &mesh, &guess, None).expect("linear BVP solve");

        assert!(result.converged, "linear BVP should converge");

        // Compare against y = x at every returned mesh point.
        for (i, xi) in result.x.iter().enumerate() {
            assert!(
                (result.y[i][0] - *xi).abs() < 1e-4,
                "y({xi}) = {}, expected {xi}",
                result.y[i][0]
            );
        }
    }

    // y'' = -y' with y(0) = 1, y(1) = e^{-1}; the exact solution is y = e^{-x}.
    // The initial guess is the exact solution itself.
    #[test]
    fn test_exponential_bvp() {
        let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], -y[1]];
        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| {
            let exact_end = (-1.0_f64).exp();
            array![ya[0] - 1.0, yb[0] - exact_end]
        };

        let n = 11;
        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
        let guess: Vec<Array1<f64>> = mesh
            .iter()
            .map(|&x| array![(-x).exp(), -(-x).exp()])
            .collect();

        let result = solve_bvp_collocation(
            ode,
            bc,
            &mesh,
            &guess,
            Some(CollocationBVPOptions {
                max_newton_iter: 100,
                ..Default::default()
            }),
        )
        .expect("exp BVP solve");

        assert!(result.converged, "exp BVP should converge");

        let y_final = result.y.last().expect("has solution")[0];
        let exact = (-1.0_f64).exp();
        assert!(
            (y_final - exact).abs() < 1e-2,
            "y(1) = {y_final}, expected {exact}"
        );
    }

    // Bratu-type problem y'' = -e^y with y(0) = y(1) = 0. The guess is a small parabola
    // 0.4 x (1 - x); the interior of the computed solution is expected to be positive.
    #[test]
    fn test_nonlinear_bvp() {
        let ode = |_x: f64, y: ArrayView1<f64>| array![y[1], -y[0].exp()];
        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0], yb[0]];

        let n = 21;
        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
        let guess: Vec<Array1<f64>> = mesh
            .iter()
            .map(|&x| {
                let y_val = 0.4 * x * (1.0 - x);
                let yp_val = 0.4 * (1.0 - 2.0 * x);
                array![y_val, yp_val]
            })
            .collect();

        let result = solve_bvp_collocation(
            ode,
            bc,
            &mesh,
            &guess,
            Some(CollocationBVPOptions {
                max_newton_iter: 100,
                ..Default::default()
            }),
        )
        .expect("Bratu BVP solve");

        assert!(result.converged, "Bratu BVP should converge");

        // Both boundary conditions must hold.
        assert!(
            result.y[0][0].abs() < 1e-4,
            "y(0) should be 0, got {}",
            result.y[0][0]
        );
        let y_end = result.y.last().expect("has solution")[0];
        assert!(y_end.abs() < 1e-4, "y(1) should be 0, got {y_end}");

        // NOTE(review): indexes the returned mesh with the *initial* mesh size; if the
        // mesh was refined, this is not necessarily the midpoint x = 0.5.
        let mid_idx = n / 2;
        assert!(
            result.y[mid_idx][0] > 0.0,
            "Interior should be positive, got {}",
            result.y[mid_idx][0]
        );
    }

    // Singularly perturbed problem y'' = y / epsilon (epsilon = 0.1) with y(0) = 1,
    // y(1) = 0. Exact solution: sinh(k (1 - x)) / sinh(k) with k = 1 / sqrt(epsilon).
    #[test]
    fn test_stiff_bvp() {
        let epsilon = 0.1;
        let ode = move |_x: f64, y: ArrayView1<f64>| array![y[1], y[0] / epsilon];
        let bc = |ya: ArrayView1<f64>, yb: ArrayView1<f64>| array![ya[0] - 1.0, yb[0]];

        let n = 31;
        let mesh: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
        // Linear guess that already satisfies both boundary conditions.
        let guess: Vec<Array1<f64>> = mesh.iter().map(|&x| array![1.0 - x, -1.0]).collect();

        let result = solve_bvp_collocation(
            ode,
            bc,
            &mesh,
            &guess,
            Some(CollocationBVPOptions {
                tol: to_f(1e-4),
                max_newton_iter: 60,
                ..Default::default()
            }),
        )
        .expect("stiff BVP solve");

        assert!(result.converged, "stiff BVP should converge");
        // Check both boundary values of the computed solution.
        assert!(
            (result.y[0][0] - 1.0).abs() < 1e-3,
            "y(0) should be 1.0, got {}",
            result.y[0][0]
        );
        let y_end = result.y.last().expect("has solution")[0];
        assert!(y_end.abs() < 0.1, "y(1) should be ~0, got {y_end}");
    }

    // A decreasing mesh must be rejected with an error.
    #[test]
    fn test_invalid_mesh() {
        let ode = |_x: f64, _y: ArrayView1<f64>| array![0.0];
        let bc = |ya: ArrayView1<f64>, _yb: ArrayView1<f64>| array![ya[0]];

        let res = solve_bvp_collocation(ode, bc, &[1.0, 0.0], &[array![0.0], array![0.0]], None);
        assert!(res.is_err());
    }

    // A guess with fewer entries than mesh points must be rejected with an error.
    #[test]
    fn test_mesh_guess_mismatch() {
        let ode = |_x: f64, _y: ArrayView1<f64>| array![0.0];
        let bc = |ya: ArrayView1<f64>, _yb: ArrayView1<f64>| array![ya[0]];

        let res = solve_bvp_collocation(
            ode,
            bc,
            &[0.0, 0.5, 1.0],
            &[array![0.0], array![0.0]],
            None,
        );
        assert!(res.is_err());
    }
}