1#![allow(clippy::needless_range_loop, clippy::type_complexity)]
2#![allow(dead_code)]
13#![allow(clippy::too_many_arguments)]
14
/// A point on a solution branch of `f(u, λ) = 0`, together with the data the
/// pseudo-arclength method carries from one step to the next.
#[derive(Debug, Clone)]
pub struct ContinuationState {
    /// Continuation parameter λ.
    pub lambda: f64,
    /// State vector `u` (dimension `n`).
    pub u: Vec<f64>,
    /// Tangent `[du; dλ]` with `n + 1` entries; the last entry is the λ-component.
    pub tangent: Vec<f64>,
    /// Arclength accumulated along the branch so far.
    pub arc_length: f64,
    /// Current step size.
    pub ds: f64,
}

impl ContinuationState {
    /// Creates a state at `(u, lambda)` with zero accumulated arclength and a tangent that
    /// points purely in the +λ direction, i.e. `[0, …, 0, 1]`.
    pub fn new(lambda: f64, u: Vec<f64>, ds: f64) -> Self {
        let tangent: Vec<f64> = std::iter::repeat(0.0)
            .take(u.len())
            .chain(std::iter::once(1.0))
            .collect();
        Self {
            lambda,
            u,
            tangent,
            arc_length: 0.0,
            ds,
        }
    }

    /// Dimension `n` of the state vector (the tangent has `n + 1` entries).
    pub fn dim(&self) -> usize {
        self.u.len()
    }

    /// Returns the tangent scaled to unit Euclidean length, or an unchanged copy when its
    /// norm is not above `1e-14`.
    pub fn normalised_tangent(&self) -> Vec<f64> {
        let length = self.tangent.iter().map(|t| t * t).sum::<f64>().sqrt();
        (length > 1e-14)
            .then(|| self.tangent.iter().map(|t| t / length).collect())
            .unwrap_or_else(|| self.tangent.clone())
    }
}
73
74pub fn pseudo_arclength_step(state: &ContinuationState) -> ContinuationState {
83 let n = state.u.len();
84 let ds = state.ds;
85 let u_pred: Vec<f64> = state
86 .u
87 .iter()
88 .enumerate()
89 .map(|(i, &ui)| ui + ds * state.tangent[i])
90 .collect();
91 let lambda_pred = state.lambda + ds * state.tangent[n];
92 ContinuationState {
93 lambda: lambda_pred,
94 u: u_pred,
95 tangent: state.tangent.clone(),
96 arc_length: state.arc_length + ds,
97 ds,
98 }
99}
100
/// Outcome of a call to `corrector_newton`.
#[derive(Debug, Clone)]
pub struct CorrectorResult {
    /// Whether the residual dropped below the requested tolerance.
    pub converged: bool,
    /// Number of Newton updates applied before returning.
    pub iterations: usize,
    /// Norm of the residual at the returned point, as computed by the corrector.
    pub residual: f64,
    /// State vector at the returned point.
    pub u: Vec<f64>,
    /// Continuation parameter at the returned point.
    pub lambda: f64,
}
119
120pub fn corrector_newton(
126 predicted: &ContinuationState,
127 prev: &ContinuationState,
128 f: &dyn Fn(&[f64], f64) -> Vec<f64>,
129 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
130 tol: f64,
131 max_iter: usize,
132) -> CorrectorResult {
133 let n = predicted.u.len();
134 let mut u = predicted.u.clone();
135 let mut lam = predicted.lambda;
136 let tangent = &prev.tangent;
137
138 for iter in 0..max_iter {
139 let res = f(&u, lam);
140 let arc_val: f64 = u
141 .iter()
142 .zip(prev.u.iter())
143 .enumerate()
144 .map(|(i, (&ui, &upi))| (ui - upi) * tangent[i])
145 .sum::<f64>()
146 + (lam - prev.lambda) * tangent[n]
147 - prev.ds;
148
149 let res_norm: f64 = (res.iter().map(|r| r * r).sum::<f64>() + arc_val * arc_val).sqrt();
150
151 if res_norm < tol {
152 return CorrectorResult {
153 converged: true,
154 iterations: iter,
155 residual: res_norm,
156 u,
157 lambda: lam,
158 };
159 }
160
161 let eps = 1e-7;
162 let res_lam_plus = f(&u, lam + eps);
163 let f_lam: Vec<f64> = res
164 .iter()
165 .zip(res_lam_plus.iter())
166 .map(|(r, rp)| (rp - r) / eps)
167 .collect();
168
169 let j = jac(&u, lam);
170 let m = n + 1;
171 let mut mat: Vec<Vec<f64>> = Vec::with_capacity(m);
172 for i in 0..n {
173 let mut row = j[i].clone();
174 row.push(f_lam[i]);
175 row.push(-res[i]);
176 mat.push(row);
177 }
178 {
179 let mut row: Vec<f64> = tangent[..n].to_vec();
180 row.push(tangent[n]);
181 row.push(-arc_val);
182 mat.push(row);
183 }
184
185 for col in 0..m {
187 let mut max_row = col;
188 let mut max_val = mat[col][col].abs();
189 for row in (col + 1)..m {
190 if mat[row][col].abs() > max_val {
191 max_val = mat[row][col].abs();
192 max_row = row;
193 }
194 }
195 mat.swap(col, max_row);
196 let pivot = mat[col][col];
197 if pivot.abs() < 1e-14 {
198 return CorrectorResult {
199 converged: false,
200 iterations: iter,
201 residual: res_norm,
202 u,
203 lambda: lam,
204 };
205 }
206 for row in (col + 1)..m {
207 let factor = mat[row][col] / pivot;
208 for k in col..=m {
209 let val = mat[col][k];
210 mat[row][k] -= factor * val;
211 }
212 }
213 }
214 let mut delta = vec![0.0_f64; m];
216 for i in (0..m).rev() {
217 let mut s = mat[i][m];
218 for jj in (i + 1)..m {
219 s -= mat[i][jj] * delta[jj];
220 }
221 delta[i] = s / mat[i][i];
222 }
223
224 for i in 0..n {
225 u[i] += delta[i];
226 }
227 lam += delta[n];
228 }
229
230 let res = f(&u, lam);
231 let res_norm: f64 = res.iter().map(|r| r * r).sum::<f64>().sqrt();
232 CorrectorResult {
233 converged: false,
234 iterations: max_iter,
235 residual: res_norm,
236 u,
237 lambda: lam,
238 }
239}
240
/// Determinant of a square matrix given as a slice of rows.
///
/// Sizes 0–2 use closed forms; the empty matrix has determinant `1.0`. Larger matrices
/// use Gaussian elimination with partial pivoting, and the determinant is the signed
/// product of the pivots.
///
/// A pivot that is exactly zero, or smaller than `1e-14` times the largest absolute entry
/// of `mat`, is treated as zero and the function returns `0.0`. Because the threshold is
/// relative, the singularity test does not depend on the overall scale of the matrix.
/// This matches the closed forms, which apply no absolute threshold.
///
/// # Panics
/// Panics if any row has fewer than `mat.len()` entries.
pub fn matrix_determinant(mat: &[Vec<f64>]) -> f64 {
    let n = mat.len();
    match n {
        0 => return 1.0,
        1 => return mat[0][0],
        2 => return mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0],
        _ => {}
    }

    let mut a: Vec<Vec<f64>> = mat.to_vec();
    // Largest absolute entry. `f64::max` skips NaN, so NaN inputs still propagate below.
    let scale = a
        .iter()
        .flat_map(|row| row[..n].iter())
        .fold(0.0_f64, |m, &x| m.max(x.abs()));
    let singular_tol = 1e-14 * scale;

    let mut sign = 1.0_f64;
    for col in 0..n {
        let mut max_row = col;
        let mut max_val = a[col][col].abs();
        for row in (col + 1)..n {
            if a[row][col].abs() > max_val {
                max_val = a[row][col].abs();
                max_row = row;
            }
        }
        if max_row != col {
            a.swap(col, max_row);
            sign = -sign;
        }
        let pivot = a[col][col];
        if pivot == 0.0 || pivot.abs() < singular_tol {
            return 0.0;
        }
        for row in (col + 1)..n {
            let factor = a[row][col] / pivot;
            for k in col..n {
                let val = a[col][k];
                a[row][k] -= factor * val;
            }
        }
    }
    let diag_product: f64 = (0..n).map(|i| a[i][i]).product();
    sign * diag_product
}
289
290pub fn detect_fold_point(
300 state_a: &ContinuationState,
301 state_b: &ContinuationState,
302 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
303) -> bool {
304 let det_a = matrix_determinant(&jac(&state_a.u, state_a.lambda));
305 let det_b = matrix_determinant(&jac(&state_b.u, state_b.lambda));
306 det_a * det_b <= 0.0
307}
308
/// Estimates how many eigenvalues of the square matrix `mat` have positive real part.
///
/// * `0×0`: `0`.
/// * `1×1`: `1` if the single entry is positive, else `0`.
/// * `2×2`: exact, from the trace/determinant closed form. A complex-conjugate pair
///   counts as `2` when the trace is positive.
/// * larger: Gershgorin heuristic. Counts the rows whose disc (centre `mat[i][i]`, radius
///   the sum of the other absolute entries in that row) lies strictly in the right
///   half-plane. This is an estimate, not an exact eigenvalue count.
pub fn stability_index(mat: &[Vec<f64>]) -> usize {
    match mat.len() {
        0 => 0,
        1 => usize::from(mat[0][0] > 0.0),
        2 => {
            let (a, b, c, d) = (mat[0][0], mat[0][1], mat[1][0], mat[1][1]);
            let trace = a + d;
            let det = a * d - b * c;
            let disc = trace * trace - 4.0 * det;
            if disc < 0.0 {
                // Complex-conjugate pair sharing real part trace / 2.
                if trace > 0.0 {
                    2
                } else {
                    0
                }
            } else {
                let root = disc.sqrt();
                [(trace + root) / 2.0, (trace - root) / 2.0]
                    .iter()
                    .filter(|&&eig| eig > 0.0)
                    .count()
            }
        }
        n => (0..n)
            .filter(|&i| {
                let off_diag: f64 = (0..n).filter(|&k| k != i).map(|k| mat[i][k].abs()).sum();
                mat[i][i] - off_diag > 0.0
            })
            .count(),
    }
}
360
361pub fn detect_bifurcation(
370 state_a: &ContinuationState,
371 state_b: &ContinuationState,
372 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
373) -> bool {
374 let ja = jac(&state_a.u, state_a.lambda);
375 let jb = jac(&state_b.u, state_b.lambda);
376 let idx_a = stability_index(&ja);
377 let idx_b = stability_index(&jb);
378 idx_a != idx_b
379}
380
/// Kind of singular point reported by `classify_bifurcation`.
#[derive(Debug, Clone, PartialEq)]
pub enum BifurcationType {
    /// Fold (turning point).
    Fold,
    /// Pitchfork bifurcation.
    Pitchfork,
    /// Hopf bifurcation.
    Hopf,
    /// Unclassified. NOTE(review): no code in this file produces this variant.
    Unknown,
}
397
/// An approximate bifurcation point together with its classification.
#[derive(Debug, Clone)]
pub struct BifurcationPoint {
    /// Approximate location of the point.
    pub state: ContinuationState,
    /// Classification of the point.
    pub bif_type: BifurcationType,
    /// Determinant of the Jacobian evaluated at `state`.
    pub det_j: f64,
}
408
409pub fn classify_bifurcation(
415 state_a: &ContinuationState,
416 state_b: &ContinuationState,
417 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
418) -> Option<BifurcationPoint> {
419 let det_a = matrix_determinant(&jac(&state_a.u, state_a.lambda));
420 let det_b = matrix_determinant(&jac(&state_b.u, state_b.lambda));
421 if det_a * det_b > 0.0 {
422 let idx_a = stability_index(&jac(&state_a.u, state_a.lambda));
424 let idx_b = stability_index(&jac(&state_b.u, state_b.lambda));
425 if idx_a == idx_b {
426 return None;
427 }
428 let u_mid: Vec<f64> = state_a
430 .u
431 .iter()
432 .zip(&state_b.u)
433 .map(|(a, b)| 0.5 * (a + b))
434 .collect();
435 let lam_mid = 0.5 * (state_a.lambda + state_b.lambda);
436 let mid_state = ContinuationState::new(lam_mid, u_mid, state_a.ds);
437 let j_mid = jac(&mid_state.u, lam_mid);
438 let det_mid = matrix_determinant(&j_mid);
439 return Some(BifurcationPoint {
440 state: mid_state,
441 bif_type: BifurcationType::Hopf,
442 det_j: det_mid,
443 });
444 }
445
446 let u_mid: Vec<f64> = state_a
448 .u
449 .iter()
450 .zip(&state_b.u)
451 .map(|(a, b)| 0.5 * (a + b))
452 .collect();
453 let lam_mid = 0.5 * (state_a.lambda + state_b.lambda);
454 let mid_state = ContinuationState::new(lam_mid, u_mid, state_a.ds);
455 let j_mid = jac(&mid_state.u, lam_mid);
456 let det_mid = matrix_determinant(&j_mid);
457
458 let n2 = j_mid.len();
460 let tr_mid: f64 = (0..n2).map(|i| j_mid[i][i]).sum();
461 let bif_type = if tr_mid.abs() < 1e-6 {
462 BifurcationType::Pitchfork
463 } else {
464 BifurcationType::Fold
465 };
466
467 Some(BifurcationPoint {
468 state: mid_state,
469 bif_type,
470 det_j: det_mid,
471 })
472}
473
474pub fn branch_switching(
484 state: &ContinuationState,
485 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
486 epsilon: f64,
487) -> ContinuationState {
488 let n = state.u.len();
489 let j = jac(&state.u, state.lambda);
490 let mut null_dir = vec![0.0_f64; n];
491 let mut min_col_norm = f64::MAX;
492 for col in 0..n {
493 let col_norm: f64 = (0..n)
494 .map(|row| j[row][col] * j[row][col])
495 .sum::<f64>()
496 .sqrt();
497 if col_norm < min_col_norm {
498 min_col_norm = col_norm;
499 for row in 0..n {
500 null_dir[row] = j[row][col];
501 }
502 }
503 }
504 let norm: f64 = null_dir.iter().map(|x| x * x).sum::<f64>().sqrt();
505 if norm > 1e-14 {
506 for x in &mut null_dir {
507 *x /= norm;
508 }
509 } else {
510 null_dir[0] = 1.0;
511 }
512 let u_new: Vec<f64> = state
513 .u
514 .iter()
515 .zip(&null_dir)
516 .map(|(&ui, &di)| ui + epsilon * di)
517 .collect();
518 let mut tangent_new = null_dir.clone();
519 tangent_new.push(0.0);
520 ContinuationState {
521 lambda: state.lambda,
522 u: u_new,
523 tangent: tangent_new,
524 arc_length: state.arc_length,
525 ds: state.ds,
526 }
527}
528
529pub struct BranchSwitching {
539 pub epsilon: f64,
541 pub max_iter: usize,
543 pub tol: f64,
545}
546
547impl BranchSwitching {
548 pub fn new(epsilon: f64, max_iter: usize, tol: f64) -> Self {
550 Self {
551 epsilon,
552 max_iter,
553 tol,
554 }
555 }
556
557 pub fn switch(
561 &self,
562 state: &ContinuationState,
563 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
564 ) -> ContinuationState {
565 branch_switching(state, jac, self.epsilon)
566 }
567}
568
569pub struct TurningPointLocator {
578 pub tol: f64,
580 pub max_iter: usize,
582}
583
584impl TurningPointLocator {
585 pub fn new(tol: f64, max_iter: usize) -> Self {
587 Self { tol, max_iter }
588 }
589
590 pub fn locate(
595 &self,
596 state_a: &ContinuationState,
597 state_b: &ContinuationState,
598 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
599 ) -> (ContinuationState, f64) {
600 let mut alpha_lo = 0.0_f64;
601 let mut alpha_hi = 1.0_f64;
602
603 let interp = |alpha: f64| {
604 let u_i: Vec<f64> = state_a
605 .u
606 .iter()
607 .zip(&state_b.u)
608 .map(|(a, b)| a + alpha * (b - a))
609 .collect();
610 let lam_i = state_a.lambda + alpha * (state_b.lambda - state_a.lambda);
611 ContinuationState::new(lam_i, u_i, state_a.ds)
612 };
613
614 let det_a = matrix_determinant(&jac(&state_a.u, state_a.lambda));
615 let mut det_lo = det_a;
616 let mut mid_state = interp(0.5_f64);
617 let mut det_mid = matrix_determinant(&jac(&mid_state.u, mid_state.lambda));
618
619 for _ in 0..self.max_iter {
620 if det_mid.abs() < self.tol {
621 break;
622 }
623 let alpha_mid = (alpha_lo + alpha_hi) / 2.0;
624 mid_state = interp(alpha_mid);
625 det_mid = matrix_determinant(&jac(&mid_state.u, mid_state.lambda));
626 if det_lo * det_mid <= 0.0 {
627 alpha_hi = alpha_mid;
628 } else {
629 alpha_lo = alpha_mid;
630 det_lo = det_mid;
631 }
632 }
633 (mid_state, det_mid)
634 }
635}
636
/// Stability classification returned by `StabilityAnalysis::label`.
#[derive(Debug, Clone, PartialEq)]
pub enum StabilityLabel {
    /// Classified as stable.
    Stable,
    /// Classified as unstable.
    Unstable,
    /// Too close to the stability boundary to decide (within the analysis tolerance).
    Marginal,
}
651
652pub struct StabilityAnalysis {
657 pub marginal_tol: f64,
659}
660
661impl StabilityAnalysis {
662 pub fn new(marginal_tol: f64) -> Self {
664 Self { marginal_tol }
665 }
666
667 pub fn label(
671 &self,
672 state: &ContinuationState,
673 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
674 ) -> StabilityLabel {
675 let j = jac(&state.u, state.lambda);
676 let n = j.len();
677 if n == 0 {
678 return StabilityLabel::Stable;
679 }
680 if n == 1 {
681 let e = j[0][0];
682 if e.abs() < self.marginal_tol {
683 return StabilityLabel::Marginal;
684 }
685 return if e < 0.0 {
686 StabilityLabel::Stable
687 } else {
688 StabilityLabel::Unstable
689 };
690 }
691 if n == 2 {
692 let tr = j[0][0] + j[1][1];
693 let det = j[0][0] * j[1][1] - j[0][1] * j[1][0];
694 let disc = tr * tr - 4.0 * det;
695 if disc < 0.0 {
696 if tr.abs() < self.marginal_tol {
698 return StabilityLabel::Marginal;
699 }
700 return if tr < 0.0 {
701 StabilityLabel::Stable
702 } else {
703 StabilityLabel::Unstable
704 };
705 }
706 let sqrt_d = disc.sqrt();
707 let e1 = (tr + sqrt_d) / 2.0;
708 let e2 = (tr - sqrt_d) / 2.0;
709 if e1.abs() < self.marginal_tol || e2.abs() < self.marginal_tol {
710 return StabilityLabel::Marginal;
711 }
712 if e1 > 0.0 || e2 > 0.0 {
713 return StabilityLabel::Unstable;
714 }
715 return StabilityLabel::Stable;
716 }
717 let mut any_unstable = false;
719 let mut any_marginal = false;
720 for i in 0..n {
721 let center = j[i][i];
722 let radius: f64 = (0..n).filter(|&jj| jj != i).map(|jj| j[i][jj].abs()).sum();
723 if center - radius > 0.0 {
724 any_unstable = true;
725 }
726 if center.abs() <= radius {
727 any_marginal = true;
728 }
729 }
730 if any_unstable {
731 StabilityLabel::Unstable
732 } else if any_marginal {
733 StabilityLabel::Marginal
734 } else {
735 StabilityLabel::Stable
736 }
737 }
738
739 pub fn label_branch(
741 &self,
742 states: &[ContinuationState],
743 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
744 ) -> Vec<StabilityLabel> {
745 states.iter().map(|s| self.label(s, jac)).collect()
746 }
747}
748
749pub struct PseudoArcLengthContinuation {
760 pub tol: f64,
762 pub max_iter: usize,
764 pub max_iter_fast: usize,
766 pub ds_min: f64,
768 pub ds_max: f64,
770}
771
772impl PseudoArcLengthContinuation {
773 pub fn new(tol: f64, max_iter: usize, max_iter_fast: usize, ds_min: f64, ds_max: f64) -> Self {
775 Self {
776 tol,
777 max_iter,
778 max_iter_fast,
779 ds_min,
780 ds_max,
781 }
782 }
783
784 pub fn step(
788 &self,
789 state: &mut ContinuationState,
790 f: &dyn Fn(&[f64], f64) -> Vec<f64>,
791 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
792 ) -> Option<ContinuationState> {
793 for attempt in 0..5usize {
795 let _ = attempt;
796 let predicted = pseudo_arclength_step(state);
797 let result = corrector_newton(&predicted, state, f, jac, self.tol, self.max_iter);
798 if result.converged {
799 let n = state.u.len();
801 let mut new_tangent = vec![0.0_f64; n + 1];
802 for i in 0..n {
803 new_tangent[i] = result.u[i] - state.u[i];
804 }
805 new_tangent[n] = result.lambda - state.lambda;
806 let t_norm: f64 = new_tangent.iter().map(|x| x * x).sum::<f64>().sqrt();
807 if t_norm > 1e-14 {
808 for x in &mut new_tangent {
809 *x /= t_norm;
810 }
811 }
812
813 let mut accepted = ContinuationState {
814 lambda: result.lambda,
815 u: result.u,
816 tangent: new_tangent,
817 arc_length: state.arc_length + state.ds,
818 ds: state.ds,
819 };
820
821 if result.iterations <= self.max_iter_fast {
823 accepted.ds = (accepted.ds * 2.0).min(self.ds_max);
824 }
825 return Some(accepted);
826 }
827 state.ds = (state.ds / 2.0).max(self.ds_min);
829 if state.ds <= self.ds_min {
830 break;
831 }
832 }
833 None
834 }
835}
836
/// One accepted step of `PathFollowing::follow`.
#[derive(Debug, Clone)]
pub struct PathStep {
    /// The accepted continuation state.
    pub state: ContinuationState,
    /// Result of `detect_fold_point` between the previous state and `state`.
    pub fold_detected: bool,
    /// Result of `detect_bifurcation` between the previous state and `state`.
    pub bifurcation_detected: bool,
}
851
852pub struct PathFollowing {
857 pub max_steps: usize,
859 pub continuation: PseudoArcLengthContinuation,
861}
862
863impl PathFollowing {
864 pub fn new(max_steps: usize, continuation: PseudoArcLengthContinuation) -> Self {
866 Self {
867 max_steps,
868 continuation,
869 }
870 }
871
872 pub fn follow(
876 &self,
877 initial_state: ContinuationState,
878 f: &dyn Fn(&[f64], f64) -> Vec<f64>,
879 jac: &dyn Fn(&[f64], f64) -> Vec<Vec<f64>>,
880 ) -> Vec<PathStep> {
881 let mut steps: Vec<PathStep> = Vec::with_capacity(self.max_steps);
882 let mut current = initial_state;
883
884 for _ in 0..self.max_steps {
885 let prev = current.clone();
886 match self.continuation.step(&mut current, f, jac) {
887 None => break,
888 Some(accepted) => {
889 let fold = detect_fold_point(&prev, &accepted, jac);
890 let bif = detect_bifurcation(&prev, &accepted, jac);
891 steps.push(PathStep {
892 state: accepted.clone(),
893 fold_detected: fold,
894 bifurcation_detected: bif,
895 });
896 current = accepted;
897 }
898 }
899 }
900 steps
901 }
902}
903
#[cfg(test)]
mod tests {
    use super::*;

    // f(u, λ) = u² − λ, whose solution branch u = ±√λ has a fold at (0, 0).
    fn f1d(u: &[f64], lam: f64) -> Vec<f64> {
        vec![u[0] * u[0] - lam]
    }
    fn jac1d(u: &[f64], _lam: f64) -> Vec<Vec<f64>> {
        vec![vec![2.0 * u[0]]]
    }

    // Two-dimensional system with a cubic (pitchfork-type) first component.
    fn f2d_pitch(u: &[f64], lam: f64) -> Vec<f64> {
        vec![u[0].powi(3) - lam * u[0], u[1] + u[0]]
    }
    fn jac2d_pitch(u: &[f64], lam: f64) -> Vec<Vec<f64>> {
        vec![vec![3.0 * u[0] * u[0] - lam, 0.0], vec![1.0, 1.0]]
    }

    #[test]
    fn test_continuation_state_new() {
        let s = ContinuationState::new(0.0, vec![1.0, 2.0], 0.1);
        assert_eq!(s.dim(), 2);
        assert_eq!(s.lambda, 0.0);
        assert_eq!(s.ds, 0.1);
        assert!((s.tangent[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_pseudo_arclength_step_increments_arc_length() {
        let s = ContinuationState::new(1.0, vec![1.0], 0.1);
        let s2 = pseudo_arclength_step(&s);
        assert!((s2.arc_length - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_pseudo_arclength_step_lambda_moves() {
        let s = ContinuationState::new(1.0, vec![1.0], 0.5);
        let s2 = pseudo_arclength_step(&s);
        assert!((s2.lambda - 1.5).abs() < 1e-12);
    }

    #[test]
    fn test_pseudo_arclength_step_u_unchanged_when_tangent_zero() {
        let mut s = ContinuationState::new(0.0, vec![3.0, 4.0], 0.2);
        s.tangent = vec![0.0, 0.0, 1.0];
        let s2 = pseudo_arclength_step(&s);
        assert!((s2.u[0] - 3.0).abs() < 1e-12);
        assert!((s2.u[1] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_corrector_newton_converges_1d() {
        let prev = ContinuationState::new(1.0, vec![1.0], 0.1);
        let mut predicted = ContinuationState::new(1.0, vec![1.1], 0.1);
        predicted.tangent = prev.tangent.clone();
        let result = corrector_newton(&predicted, &prev, &f1d, &jac1d, 1e-10, 50);
        assert!(result.converged, "Newton should converge");
        assert!(result.residual < 1e-8);
    }

    #[test]
    fn test_corrector_newton_result_satisfies_equation() {
        let prev = ContinuationState::new(4.0, vec![2.0], 0.1);
        let mut predicted = ContinuationState::new(4.0, vec![2.05], 0.1);
        predicted.tangent = prev.tangent.clone();
        let result = corrector_newton(&predicted, &prev, &f1d, &jac1d, 1e-10, 50);
        // The corrector must converge here; a silent non-convergence would hide a regression.
        assert!(result.converged, "Newton should converge from a nearby prediction");
        let res = f1d(&result.u, result.lambda);
        assert!(res[0].abs() < 1e-6);
    }

    #[test]
    fn test_matrix_determinant_2x2() {
        let m = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let det = matrix_determinant(&m);
        assert!((det - (1.0 * 4.0 - 2.0 * 3.0)).abs() < 1e-10);
    }

    #[test]
    fn test_matrix_determinant_identity_3x3() {
        let m = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        assert!((matrix_determinant(&m) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_matrix_determinant_singular() {
        let m = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ];
        assert!(matrix_determinant(&m).abs() < 1e-10);
    }

    #[test]
    fn test_matrix_determinant_1x1() {
        let m = vec![vec![7.0]];
        assert!((matrix_determinant(&m) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_matrix_determinant_empty() {
        let m: Vec<Vec<f64>> = vec![];
        assert!((matrix_determinant(&m) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_detect_fold_point_true() {
        let sa = ContinuationState::new(0.5, vec![0.5], 0.1);
        let sb = ContinuationState::new(-0.5, vec![-0.5], 0.1);
        let fold = detect_fold_point(&sa, &sb, &jac1d);
        assert!(
            fold,
            "fold should be detected across a sign change in det(J)"
        );
    }

    #[test]
    fn test_detect_fold_point_false() {
        let sa = ContinuationState::new(1.0, vec![1.0], 0.1);
        let sb = ContinuationState::new(2.0, vec![2.0], 0.1);
        let fold = detect_fold_point(&sa, &sb, &jac1d);
        assert!(!fold);
    }

    #[test]
    fn test_stability_index_1x1_positive() {
        let m = vec![vec![3.0]];
        assert_eq!(stability_index(&m), 1);
    }

    #[test]
    fn test_stability_index_1x1_negative() {
        let m = vec![vec![-2.0]];
        assert_eq!(stability_index(&m), 0);
    }

    #[test]
    fn test_stability_index_2x2_stable() {
        let m = vec![vec![-2.0, 0.0], vec![0.0, -3.0]];
        assert_eq!(stability_index(&m), 0);
    }

    #[test]
    fn test_stability_index_2x2_unstable() {
        let m = vec![vec![2.0, 0.0], vec![0.0, 3.0]];
        assert_eq!(stability_index(&m), 2);
    }

    #[test]
    fn test_stability_index_2x2_one_unstable() {
        let m = vec![vec![1.0, 0.0], vec![0.0, -2.0]];
        assert_eq!(stability_index(&m), 1);
    }

    #[test]
    fn test_stability_index_empty() {
        let m: Vec<Vec<f64>> = vec![];
        assert_eq!(stability_index(&m), 0);
    }

    #[test]
    fn test_detect_bifurcation_detects_change() {
        let sa = ContinuationState::new(0.0, vec![0.0, 0.0], 0.1);
        let sb = ContinuationState::new(1.0, vec![1.0, 1.0], 0.1);
        let jac_stable = |_u: &[f64], _lam: f64| vec![vec![-1.0, 0.0], vec![0.0, -1.0]];
        assert!(!detect_bifurcation(&sa, &sb, &jac_stable));
        let jac_crossing = |u: &[f64], _lam: f64| {
            let v = u[0];
            vec![vec![v, 0.0], vec![0.0, v]]
        };
        let sa2 = ContinuationState::new(0.0, vec![-1.0, 0.0], 0.1);
        let sb2 = ContinuationState::new(1.0, vec![1.0, 0.0], 0.1);
        assert!(detect_bifurcation(&sa2, &sb2, &jac_crossing));
    }

    #[test]
    fn test_branch_switching_perturbs_solution() {
        let s = ContinuationState::new(1.0, vec![1.0, 0.0], 0.1);
        let jac2d = |_u: &[f64], _lam: f64| vec![vec![1e-15, 0.0], vec![0.0, 1.0]];
        let s_branch = branch_switching(&s, &jac2d, 0.01);
        let diff: f64 =
            s.u.iter()
                .zip(&s_branch.u)
                .map(|(a, b)| (a - b).abs())
                .sum();
        assert!(diff > 0.0, "branch switching must perturb the solution");
    }

    #[test]
    fn test_branch_switching_preserves_lambda() {
        let s = ContinuationState::new(2.5, vec![1.0], 0.05);
        let jac_id = |_u: &[f64], _lam: f64| vec![vec![1.0]];
        let s2 = branch_switching(&s, &jac_id, 0.1);
        assert!((s2.lambda - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_corrector_newton_diverges_on_singular() {
        let f_zero = |_u: &[f64], _lam: f64| vec![0.0];
        let jac_zero = |_u: &[f64], _lam: f64| vec![vec![0.0]];
        let prev = ContinuationState::new(0.0, vec![0.0], 0.1);
        let mut predicted = ContinuationState::new(0.0, vec![0.1], 0.1);
        predicted.tangent = prev.tangent.clone();
        let result = corrector_newton(&predicted, &prev, &f_zero, &jac_zero, 1e-10, 5);
        // The bordered matrix is singular, so the corrector must report failure.
        assert!(!result.converged);
    }

    #[test]
    fn test_arc_length_accumulates_over_steps() {
        let s0 = ContinuationState::new(0.0, vec![0.0], 0.2);
        let s1 = pseudo_arclength_step(&s0);
        let s2 = pseudo_arclength_step(&s1);
        assert!((s2.arc_length - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_matrix_determinant_4x4_diagonal() {
        let m = vec![
            vec![2.0, 0.0, 0.0, 0.0],
            vec![0.0, 3.0, 0.0, 0.0],
            vec![0.0, 0.0, 5.0, 0.0],
            vec![0.0, 0.0, 0.0, 7.0],
        ];
        let det = matrix_determinant(&m);
        assert!((det - 210.0).abs() < 1e-8);
    }

    #[test]
    fn test_continuation_state_dim() {
        let s = ContinuationState::new(0.0, vec![1.0, 2.0, 3.0], 0.1);
        assert_eq!(s.dim(), 3);
    }

    #[test]
    fn test_tangent_length_matches_n_plus_1() {
        let s = ContinuationState::new(0.0, vec![1.0, 2.0, 3.0], 0.1);
        assert_eq!(s.tangent.len(), 4);
    }

    #[test]
    fn test_classify_bifurcation_detects_fold() {
        let sa = ContinuationState::new(1.0, vec![1.0], 0.1);
        let sb = ContinuationState::new(1.0, vec![-1.0], 0.1);
        let bif = classify_bifurcation(&sa, &sb, &jac1d);
        assert!(bif.is_some(), "a bifurcation should be detected");
        let bp = bif.unwrap();
        assert!(
            bp.bif_type == BifurcationType::Fold || bp.bif_type == BifurcationType::Pitchfork,
            "expected Fold or Pitchfork, got {:?}",
            bp.bif_type
        );
    }

    #[test]
    fn test_classify_bifurcation_none_when_same_stability() {
        let sa = ContinuationState::new(1.0, vec![1.0], 0.1);
        let sb = ContinuationState::new(2.0, vec![2.0], 0.1);
        let bif = classify_bifurcation(&sa, &sb, &jac1d);
        assert!(bif.is_none());
    }

    #[test]
    fn test_stability_analysis_stable_label() {
        let sa = StabilityAnalysis::new(1e-6);
        let state = ContinuationState::new(0.0, vec![0.0], 0.1);
        let jac_neg = |_u: &[f64], _lam: f64| vec![vec![-1.0]];
        assert_eq!(sa.label(&state, &jac_neg), StabilityLabel::Stable);
    }

    #[test]
    fn test_stability_analysis_unstable_label() {
        let sa = StabilityAnalysis::new(1e-6);
        let state = ContinuationState::new(0.0, vec![0.0], 0.1);
        let jac_pos = |_u: &[f64], _lam: f64| vec![vec![1.0]];
        assert_eq!(sa.label(&state, &jac_pos), StabilityLabel::Unstable);
    }

    #[test]
    fn test_stability_analysis_marginal_label() {
        let sa = StabilityAnalysis::new(1e-6);
        let state = ContinuationState::new(0.0, vec![0.0], 0.1);
        let jac_zero = |_u: &[f64], _lam: f64| vec![vec![0.0]];
        assert_eq!(sa.label(&state, &jac_zero), StabilityLabel::Marginal);
    }

    #[test]
    fn test_stability_analysis_branch_labels() {
        let sa = StabilityAnalysis::new(1e-6);
        let states = vec![
            ContinuationState::new(0.0, vec![-1.0], 0.1),
            ContinuationState::new(0.0, vec![0.0], 0.1),
            ContinuationState::new(0.0, vec![1.0], 0.1),
        ];
        let jac_sign = |u: &[f64], _lam: f64| vec![vec![u[0]]];
        let labels = sa.label_branch(&states, &jac_sign);
        assert_eq!(labels.len(), 3);
        assert_eq!(labels[0], StabilityLabel::Stable);
        assert_eq!(labels[2], StabilityLabel::Unstable);
    }

    #[test]
    fn test_turning_point_locator_basic() {
        let sa = ContinuationState::new(0.25, vec![0.5], 0.1);
        let sb = ContinuationState::new(0.25, vec![-0.5], 0.1);
        let locator = TurningPointLocator::new(1e-6, 50);
        let (mid, det) = locator.locate(&sa, &sb, &jac1d);
        assert!(
            det.abs() < 0.1,
            "det at turning point should be near 0, got {}",
            det
        );
        // det(J) = 2u vanishes at u = 0, halfway between the two states.
        assert!(mid.u[0].abs() < 1e-12, "u at turning point = {}", mid.u[0]);
    }

    #[test]
    fn test_normalised_tangent_length() {
        let s = ContinuationState::new(0.0, vec![3.0, 4.0], 0.1);
        let nt = s.normalised_tangent();
        let norm: f64 = nt.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-10,
            "normalised tangent should have unit norm"
        );
    }

    #[test]
    fn test_branch_switching_struct() {
        let bs = BranchSwitching::new(0.01, 20, 1e-8);
        let s = ContinuationState::new(1.0, vec![1.0, 0.0], 0.1);
        let jac2d = |_u: &[f64], _lam: f64| vec![vec![1e-15, 0.0], vec![0.0, 1.0]];
        let s2 = bs.switch(&s, &jac2d);
        let diff: f64 = s.u.iter().zip(&s2.u).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 0.0);
    }

    #[test]
    fn test_pseudo_arc_length_continuation_step() {
        let mut state = ContinuationState::new(1.0, vec![1.0], 0.05);
        state.tangent = vec![1.0 / 2.0_f64.sqrt(), 1.0 / 2.0_f64.sqrt()];
        let cont = PseudoArcLengthContinuation::new(1e-8, 30, 5, 1e-4, 0.5);
        let result = cont.step(&mut state, &f1d, &jac1d);
        assert!(result.is_some(), "continuation step should succeed");
        let accepted = result.unwrap();
        let residual = (accepted.u[0] * accepted.u[0] - accepted.lambda).abs();
        assert!(residual < 1e-6, "residual = {}", residual);
    }

    #[test]
    fn test_path_following_runs_multiple_steps() {
        let initial = ContinuationState::new(1.0, vec![1.0], 0.05);
        let cont = PseudoArcLengthContinuation::new(1e-8, 30, 5, 1e-4, 0.5);
        let pf = PathFollowing::new(10, cont);
        let steps = pf.follow(initial, &f1d, &jac1d);
        assert!(!steps.is_empty(), "path following should produce steps");
    }

    #[test]
    fn test_path_following_arc_length_increasing() {
        let initial = ContinuationState::new(1.0, vec![1.0], 0.05);
        let cont = PseudoArcLengthContinuation::new(1e-8, 30, 5, 1e-4, 0.5);
        let pf = PathFollowing::new(5, cont);
        let steps = pf.follow(initial, &f1d, &jac1d);
        for w in steps.windows(2) {
            assert!(
                w[1].state.arc_length >= w[0].state.arc_length,
                "arc length should be non-decreasing"
            );
        }
    }

    #[test]
    fn test_matrix_determinant_3x3_known() {
        let m = vec![
            vec![1.0, 2.0, 3.0],
            vec![0.0, 4.0, 5.0],
            vec![1.0, 0.0, 6.0],
        ];
        let det = matrix_determinant(&m);
        assert!((det - 22.0).abs() < 1e-8, "det = {}", det);
    }

    #[test]
    fn test_jac2d_pitch_at_zero() {
        let j = jac2d_pitch(&[0.0, 0.0], 0.0);
        assert!((j[0][0] - 0.0).abs() < 1e-12);
        assert!((j[1][0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_f2d_pitch_at_trivial() {
        let res = f2d_pitch(&[0.0, 0.0], 1.0);
        assert!(res[0].abs() < 1e-12);
        assert!(res[1].abs() < 1e-12);
    }
}