Skip to main content

oxiphysics_softbody/
solver.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! XPBD solver for soft-body simulation.
5
6use oxiphysics_core::math::Real;
7
8use crate::constraint::SoftConstraint;
9use crate::particle::{SoftBody, SoftParticle};
10
11// ---------------------------------------------------------------------------
12// Constraint kind enum
13// ---------------------------------------------------------------------------
14
15/// The kind of compliance a constraint carries.
16///
17/// Compliance is the inverse of stiffness: higher compliance means softer.
18#[derive(Debug, Clone, Copy, PartialEq)]
19pub enum ConstraintKind {
20    /// Perfectly rigid – zero compliance, enforced exactly.
21    Rigid,
22    /// Linear elastic with explicit stiffness (N/m).
23    Elastic {
24        /// Spring stiffness in N/m.
25        stiffness: Real,
26    },
27    /// Bend/dihedral softness, measured in N·m/rad.
28    Bending {
29        /// Bending stiffness in N·m/rad.
30        stiffness: Real,
31    },
32    /// Volume preservation with given bulk modulus (Pa).
33    Volume {
34        /// Bulk modulus in Pa.
35        bulk_modulus: Real,
36    },
37    /// Collision response with a restitution coefficient.
38    Collision {
39        /// Coefficient of restitution (dimensionless, 0–1).
40        restitution: Real,
41    },
42    /// Arbitrary compliance value (inverse stiffness, m²/N).
43    Custom {
44        /// Compliance value (α = 1/k).
45        compliance: Real,
46    },
47}
48
49impl ConstraintKind {
50    /// Convert the kind to an XPBD compliance value (α).
51    ///
52    /// Returns 0 for `Rigid` and `Collision` (hard constraints).
53    pub fn compliance(&self) -> Real {
54        match self {
55            ConstraintKind::Rigid => 0.0,
56            ConstraintKind::Elastic { stiffness } => {
57                if *stiffness > 0.0 {
58                    1.0 / stiffness
59                } else {
60                    0.0
61                }
62            }
63            ConstraintKind::Bending { stiffness } => {
64                if *stiffness > 0.0 {
65                    1.0 / stiffness
66                } else {
67                    0.0
68                }
69            }
70            ConstraintKind::Volume { bulk_modulus } => {
71                if *bulk_modulus > 0.0 {
72                    1.0 / bulk_modulus
73                } else {
74                    0.0
75                }
76            }
77            ConstraintKind::Collision { .. } => 0.0,
78            ConstraintKind::Custom { compliance } => *compliance,
79        }
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Sleep state
85// ---------------------------------------------------------------------------
86
/// Sleep status the solver tracks for the body it simulates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SleepState {
    /// Simulation is active: the body is integrated and its constraints projected.
    Awake,
    /// Motion has stayed below the sleep threshold long enough that the body
    /// is no longer integrated.
    Asleep,
}
95
96// ---------------------------------------------------------------------------
97// XpbdSolver
98// ---------------------------------------------------------------------------
99
100/// Extended Position Based Dynamics (XPBD) solver.
101///
102/// The solver performs the following pipeline per call to [`XpbdSolver::solve`]:
103///
104/// 1. Predict positions using semi-implicit Euler integration.
105/// 2. For each sub-step, project all constraints.
106/// 3. Update velocities from corrected positions.
107/// 4. Apply velocity damping.
108#[derive(Debug, Clone)]
109pub struct XpbdSolver {
110    /// Number of sub-steps per solver call.
111    pub num_substeps: usize,
112    /// Number of constraint-projection iterations per sub-step.
113    pub num_iterations: usize,
114    /// Velocity magnitude threshold below which a body is considered asleep.
115    pub sleep_threshold: Real,
116    /// Number of consecutive solver calls all particles must be below the
117    /// sleep threshold before the body transitions to [`SleepState::Asleep`].
118    pub sleep_counter_max: usize,
119    /// Internal counter: how many calls have been below the sleep threshold.
120    sleep_counter: usize,
121    /// Current sleep state.
122    pub sleep_state: SleepState,
123}
124
125impl XpbdSolver {
126    /// Create a new XPBD solver with the given number of sub-steps.
127    pub fn new(num_substeps: usize) -> Self {
128        Self {
129            num_substeps,
130            num_iterations: 1,
131            sleep_threshold: 1e-4,
132            sleep_counter_max: 10,
133            sleep_counter: 0,
134            sleep_state: SleepState::Awake,
135        }
136    }
137
138    /// Create a solver with explicit sub-step and iteration counts.
139    pub fn with_iterations(num_substeps: usize, num_iterations: usize) -> Self {
140        Self {
141            num_substeps,
142            num_iterations,
143            ..Self::new(num_substeps)
144        }
145    }
146
147    /// Wake the body up (reset sleep counter and state).
148    pub fn wake(&mut self) {
149        self.sleep_counter = 0;
150        self.sleep_state = SleepState::Awake;
151    }
152
153    /// Compute an adaptive CFL time-step given the current body state.
154    ///
155    /// Returns the largest `dt` such that no particle moves more than
156    /// `max_displacement` in one sub-step.
157    ///
158    /// If all particles are static or have zero velocity, returns `dt_max`.
159    pub fn cfl_timestep(body: &SoftBody, dt_max: Real, max_displacement: Real) -> Real {
160        let v_max = body
161            .particles
162            .iter()
163            .filter(|p| !p.is_static())
164            .map(|p| p.velocity.norm())
165            .fold(0.0_f64, f64::max);
166
167        if v_max < 1e-12 {
168            return dt_max;
169        }
170
171        let dt_cfl = max_displacement / v_max;
172        dt_cfl.min(dt_max)
173    }
174
175    /// Check whether the body should go to sleep based on current velocities.
176    ///
177    /// Returns `true` if all dynamic particles are below [`Self::sleep_threshold`].
178    fn check_sleep(&mut self, body: &SoftBody) -> bool {
179        let all_slow = body
180            .particles
181            .iter()
182            .filter(|p| !p.is_static())
183            .all(|p| p.velocity.norm() < self.sleep_threshold);
184
185        if all_slow {
186            self.sleep_counter += 1;
187        } else {
188            self.sleep_counter = 0;
189        }
190
191        if self.sleep_counter >= self.sleep_counter_max {
192            self.sleep_state = SleepState::Asleep;
193            true
194        } else {
195            self.sleep_state = SleepState::Awake;
196            false
197        }
198    }
199
200    /// Run one full solve step over `body` with the given `constraints`.
201    pub fn solve(
202        &mut self,
203        body: &mut SoftBody,
204        constraints: &mut [Box<dyn SoftConstraint>],
205        dt: Real,
206    ) {
207        let n = body.particles.len();
208        if n == 0 || self.num_substeps == 0 {
209            return;
210        }
211
212        // Skip integration if asleep (but still check if we should stay asleep).
213        if self.sleep_state == SleepState::Asleep {
214            return;
215        }
216
217        let dt_sub = dt / self.num_substeps as Real;
218
219        for _sub in 0..self.num_substeps {
220            // 1. Predict positions.
221            for p in &mut body.particles {
222                if p.is_static() {
223                    continue;
224                }
225                p.velocity += p.external_force * (p.inverse_mass * dt_sub);
226                p.prev_position = p.position;
227                p.position += p.velocity * dt_sub;
228            }
229
230            // 2. Project constraints (multiple iterations per sub-step).
231            for _iter in 0..self.num_iterations {
232                for c in constraints.iter_mut() {
233                    c.project(&mut body.particles, dt_sub);
234                }
235            }
236
237            // 3. Update velocities from position corrections.
238            for p in &mut body.particles {
239                if p.is_static() {
240                    continue;
241                }
242                p.velocity = (p.position - p.prev_position) / dt_sub;
243            }
244
245            // 4. Apply damping.
246            let damp = 1.0 - body.damping;
247            for p in &mut body.particles {
248                p.velocity *= damp;
249            }
250        }
251
252        // 5. Update sleep state.
253        self.check_sleep(body);
254    }
255
256    /// Integrate particle positions forward by `dt` without constraint projection.
257    ///
258    /// Useful when you want to call the integration and projection phases
259    /// separately (e.g. from a higher-level PBD loop).
260    pub fn integrate_positions(&self, body: &mut SoftBody, dt: Real) {
261        for p in &mut body.particles {
262            if p.is_static() {
263                continue;
264            }
265            p.velocity += p.external_force * (p.inverse_mass * dt);
266            p.prev_position = p.position;
267            p.position += p.velocity * dt;
268        }
269    }
270
271    /// Update particle velocities from the displacement since the last
272    /// `integrate_positions` call (i.e. from `prev_position`).
273    ///
274    /// Call this **after** all constraint projections for a sub-step.
275    pub fn integrate_velocities(&self, body: &mut SoftBody, dt: Real) {
276        for p in &mut body.particles {
277            if p.is_static() {
278                continue;
279            }
280            p.velocity = (p.position - p.prev_position) / dt;
281        }
282    }
283
284    /// Apply velocity damping to all dynamic particles.
285    pub fn apply_damping(&self, body: &mut SoftBody) {
286        let damp = 1.0 - body.damping;
287        for p in &mut body.particles {
288            p.velocity *= damp;
289        }
290    }
291
292    /// Compute total kinetic energy of the body (½ Σ mᵢ |vᵢ|²).
293    pub fn kinetic_energy(body: &SoftBody) -> Real {
294        body.particles
295            .iter()
296            .filter(|p| !p.is_static())
297            .map(|p| {
298                let m = if p.inverse_mass > 0.0 {
299                    1.0 / p.inverse_mass
300                } else {
301                    0.0
302                };
303                0.5 * m * p.velocity.norm_squared()
304            })
305            .sum()
306    }
307
308    /// Compute the maximum particle displacement in the last sub-step.
309    ///
310    /// Useful for adaptive iteration count decisions.
311    pub fn max_displacement(body: &SoftBody) -> Real {
312        body.particles
313            .iter()
314            .filter(|p| !p.is_static())
315            .map(|p| (p.position - p.prev_position).norm())
316            .fold(0.0_f64, f64::max)
317    }
318}
319
320impl Default for XpbdSolver {
321    fn default() -> Self {
322        Self::new(10)
323    }
324}
325
326// ---------------------------------------------------------------------------
327// SolverConvergenceTracker
328// ---------------------------------------------------------------------------
329
330/// Tracks convergence statistics across solver iterations.
331#[derive(Debug, Clone)]
332pub struct SolverConvergenceTracker {
333    /// Per-iteration maximum constraint error.
334    pub error_history: Vec<Real>,
335    /// Total number of constraint projections performed.
336    pub total_projections: usize,
337    /// Whether the last solve converged below the threshold.
338    pub converged: bool,
339    /// Convergence threshold (maximum allowed constraint error).
340    pub threshold: Real,
341}
342
343impl SolverConvergenceTracker {
344    /// Create a new tracker.
345    pub fn new(threshold: Real) -> Self {
346        Self {
347            error_history: Vec::new(),
348            total_projections: 0,
349            converged: false,
350            threshold,
351        }
352    }
353
354    /// Record an error value for the current iteration.
355    pub fn record(&mut self, error: Real) {
356        self.error_history.push(error);
357        self.converged = error < self.threshold;
358    }
359
360    /// Reset the tracker for a new solve.
361    pub fn reset(&mut self) {
362        self.error_history.clear();
363        self.total_projections = 0;
364        self.converged = false;
365    }
366
367    /// Convergence rate: ratio of last two errors (< 1 means converging).
368    pub fn convergence_rate(&self) -> Option<Real> {
369        let n = self.error_history.len();
370        if n < 2 {
371            return None;
372        }
373        let prev = self.error_history[n - 2];
374        let curr = self.error_history[n - 1];
375        if prev.abs() < 1e-14 {
376            return None;
377        }
378        Some(curr / prev)
379    }
380
381    /// Suggest an iteration count based on convergence rate.
382    ///
383    /// If convergence is fast (rate < 0.5), fewer iterations are needed.
384    /// If convergence is slow (rate > 0.9), more iterations help.
385    pub fn suggest_iterations(&self, current: usize, min: usize, max: usize) -> usize {
386        match self.convergence_rate() {
387            Some(rate) if rate < 0.3 => (current / 2).max(min),
388            Some(rate) if rate > 0.8 => (current * 2).min(max),
389            _ => current,
390        }
391    }
392}
393
394// ---------------------------------------------------------------------------
395// ConstraintOrderingStrategy
396// ---------------------------------------------------------------------------
397
/// Strategy for ordering constraint projections within an iteration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstraintOrdering {
    /// Process constraints in the order they were added. This is the order
    /// [`XpbdSolver::solve`] always uses.
    Sequential,
    /// Alternate direction: forward when `sub-step index + iteration index` is
    /// even, reversed when it is odd. Intended to improve convergence for chains.
    Alternating,
    /// Deterministic pseudo-shuffle: the list is visited as a cyclic rotation
    /// whose starting index is derived from the sub-step and iteration indices.
    /// No randomness is involved, so results are reproducible.
    Shuffled,
}
408
409/// Solve with a specified constraint ordering strategy.
410pub fn solve_with_ordering(
411    solver: &mut XpbdSolver,
412    body: &mut SoftBody,
413    constraints: &mut [Box<dyn SoftConstraint>],
414    dt: Real,
415    ordering: ConstraintOrdering,
416) {
417    let n = body.particles.len();
418    if n == 0 || solver.num_substeps == 0 {
419        return;
420    }
421    if solver.sleep_state == SleepState::Asleep {
422        return;
423    }
424
425    let dt_sub = dt / solver.num_substeps as Real;
426
427    for sub in 0..solver.num_substeps {
428        // 1. Predict positions
429        for i in 0..n {
430            let p = &mut body.particles[i];
431            if p.is_static() {
432                continue;
433            }
434            p.velocity += p.external_force * (p.inverse_mass * dt_sub);
435            p.prev_position = p.position;
436            p.position += p.velocity * dt_sub;
437        }
438
439        // 2. Project constraints with ordering
440        for iter in 0..solver.num_iterations {
441            match ordering {
442                ConstraintOrdering::Sequential => {
443                    for c in constraints.iter_mut() {
444                        c.project(&mut body.particles, dt_sub);
445                    }
446                }
447                ConstraintOrdering::Alternating => {
448                    if (sub + iter) % 2 == 0 {
449                        for c in constraints.iter_mut() {
450                            c.project(&mut body.particles, dt_sub);
451                        }
452                    } else {
453                        for c in constraints.iter_mut().rev() {
454                            c.project(&mut body.particles, dt_sub);
455                        }
456                    }
457                }
458                ConstraintOrdering::Shuffled => {
459                    // Deterministic pseudo-shuffle based on iteration index
460                    let offset = (iter * 7 + sub * 13) % constraints.len().max(1);
461                    for k in 0..constraints.len() {
462                        let idx = (k + offset) % constraints.len();
463                        constraints[idx].project(&mut body.particles, dt_sub);
464                    }
465                }
466            }
467        }
468
469        // 3. Update velocities
470        for i in 0..n {
471            let p = &mut body.particles[i];
472            if p.is_static() {
473                continue;
474            }
475            p.velocity = (p.position - p.prev_position) / dt_sub;
476        }
477
478        // 4. Damping
479        let damp = 1.0 - body.damping;
480        for i in 0..n {
481            body.particles[i].velocity *= damp;
482        }
483    }
484}
485
486// ---------------------------------------------------------------------------
487// Solver warmstarting
488// ---------------------------------------------------------------------------
489
490/// Stores per-constraint Lagrange multipliers from the previous solve
491/// for warm-starting the next solve.
492#[derive(Debug, Clone, Default)]
493pub struct WarmstartCache {
494    /// Previous Lagrange multipliers indexed by constraint.
495    pub lambdas: Vec<Real>,
496}
497
498impl WarmstartCache {
499    /// Create a new cache.
500    pub fn new() -> Self {
501        Self::default()
502    }
503
504    /// Resize the cache to match the number of constraints.
505    pub fn resize(&mut self, n: usize) {
506        self.lambdas.resize(n, 0.0);
507    }
508
509    /// Apply warm-start displacements to particles based on cached lambdas.
510    ///
511    /// This is a simplified version: it scales all particle velocities by
512    /// a factor derived from the previous solve's total lambda.
513    pub fn apply_warmstart(&self, body: &mut SoftBody, factor: Real) {
514        if self.lambdas.is_empty() {
515            return;
516        }
517        let avg_lambda: Real =
518            self.lambdas.iter().map(|l| l.abs()).sum::<Real>() / self.lambdas.len() as Real;
519        // Nudge velocities in their current direction
520        for p in &mut body.particles {
521            if !p.is_static() {
522                let v_mag = p.velocity.norm();
523                if v_mag > 1e-14 {
524                    let scale = 1.0 + factor * avg_lambda / (v_mag + 1e-10);
525                    p.velocity *= scale.clamp(0.5, 2.0);
526                }
527            }
528        }
529    }
530
531    /// Reset all cached lambdas to zero.
532    pub fn clear(&mut self) {
533        for l in &mut self.lambdas {
534            *l = 0.0;
535        }
536    }
537}
538
539// ---------------------------------------------------------------------------
540// Gauss-Seidel solver
541// ---------------------------------------------------------------------------
542
543/// A simple Gauss-Seidel constraint solver that processes constraints
544/// one at a time, immediately applying corrections.
545///
546/// This is essentially what the standard XPBD loop does, but packaged
547/// as a separate utility for clarity and testability.
548#[derive(Debug, Clone)]
549pub struct GaussSeidelSolver {
550    /// Number of iterations.
551    pub iterations: usize,
552    /// Successive over-relaxation factor (1.0 = standard GS, >1 = SOR).
553    pub omega: Real,
554}
555
556impl GaussSeidelSolver {
557    /// Create a new Gauss-Seidel solver.
558    pub fn new(iterations: usize) -> Self {
559        Self {
560            iterations,
561            omega: 1.0,
562        }
563    }
564
565    /// Create with SOR factor.
566    pub fn with_sor(iterations: usize, omega: Real) -> Self {
567        Self { iterations, omega }
568    }
569
570    /// Project all constraints `iterations` times.
571    pub fn solve(
572        &self,
573        particles: &mut [SoftParticle],
574        constraints: &mut [Box<dyn SoftConstraint>],
575        dt_sub: Real,
576    ) {
577        for _ in 0..self.iterations {
578            for c in constraints.iter_mut() {
579                c.project(particles, dt_sub);
580            }
581            // Apply SOR if omega != 1
582            if (self.omega - 1.0).abs() > 1e-10 {
583                for p in particles.iter_mut() {
584                    if !p.is_static() {
585                        let displacement = p.position - p.prev_position;
586                        p.position = p.prev_position + displacement * self.omega;
587                    }
588                }
589            }
590        }
591    }
592}
593
594// ---------------------------------------------------------------------------
595// Tests
596// ---------------------------------------------------------------------------
597
598#[cfg(test)]
599mod tests {
600    use super::*;
601    use crate::constraint::DistanceConstraint;
602
603    use oxiphysics_core::math::Vec3;
604
605    // T1. Substep decomposition: running 1 step with N substeps should move
606    //     a freely-falling particle by roughly the same amount as N individual
607    //     1-substep solves over the same total dt.
608    #[test]
609    fn test_substep_decomposition() {
610        let make_body = || {
611            let mut body =
612                SoftBody::from_particles(vec![SoftParticle::new(Vec3::new(0.0, 10.0, 0.0), 1.0)]);
613            body.apply_force(&Vec3::new(0.0, -9.81, 0.0));
614            body
615        };
616
617        let dt = 1.0 / 60.0;
618        let n_sub = 5;
619
620        // Solver A: single call with n_sub substeps.
621        let mut body_a = make_body();
622        let mut solver_a = XpbdSolver::new(n_sub);
623        solver_a.solve(&mut body_a, &mut [], dt);
624
625        // Solver B: n_sub calls each with 1 substep over dt/n_sub.
626        let mut body_b = make_body();
627        let mut solver_b = XpbdSolver::new(1);
628        for _ in 0..n_sub {
629            solver_b.solve(&mut body_b, &mut [], dt / n_sub as Real);
630        }
631
632        let dy_a = (body_a.particles[0].position.y - 10.0).abs();
633        let dy_b = (body_b.particles[0].position.y - 10.0).abs();
634
635        // Results should be very close (within 1e-10).
636        assert!(
637            (dy_a - dy_b).abs() < 1e-10,
638            "Substep decomposition mismatch: dy_a={dy_a}, dy_b={dy_b}"
639        );
640    }
641
642    // T2. Sleep detection: a stationary body should go to sleep after enough
643    //     calls with no external force.
644    #[test]
645    fn test_sleep_detection() {
646        let mut body =
647            SoftBody::from_particles(vec![SoftParticle::new(Vec3::new(0.0, 0.0, 0.0), 1.0)]);
648        let mut solver = XpbdSolver::new(1);
649        solver.sleep_counter_max = 3;
650        solver.sleep_threshold = 1e-3;
651
652        // No external force → particle stays still → should fall asleep.
653        for _ in 0..5 {
654            solver.solve(&mut body, &mut [], 1.0 / 60.0);
655        }
656        assert_eq!(
657            solver.sleep_state,
658            SleepState::Asleep,
659            "Body should be asleep when velocity is zero"
660        );
661    }
662
663    // T3. Sleeping body does not move.
664    #[test]
665    fn test_sleeping_body_not_integrated() {
666        let mut body =
667            SoftBody::from_particles(vec![SoftParticle::new(Vec3::new(0.0, 5.0, 0.0), 1.0)]);
668        body.apply_force(&Vec3::new(0.0, -9.81, 0.0));
669        let mut solver = XpbdSolver::new(1);
670        solver.sleep_state = SleepState::Asleep;
671
672        let y_before = body.particles[0].position.y;
673        solver.solve(&mut body, &mut [], 1.0 / 60.0);
674        let y_after = body.particles[0].position.y;
675
676        assert!(
677            (y_before - y_after).abs() < 1e-15,
678            "Sleeping body must not move"
679        );
680    }
681
682    // T4. Wake resets sleep state and counter.
683    #[test]
684    fn test_wake_resets_sleep() {
685        let mut solver = XpbdSolver::new(5);
686        solver.sleep_state = SleepState::Asleep;
687        solver.sleep_counter = 99;
688        solver.wake();
689        assert_eq!(solver.sleep_state, SleepState::Awake);
690        assert_eq!(solver.sleep_counter, 0);
691    }
692
693    // T5. Constraint iteration: a chain of particles with competing constraints
694    //     converges further with more iterations per sub-step.
695    //
696    // A chain: p0 (pinned) - p1 - p2 - p3, rest length 1.0 between each pair,
697    // all initially displaced to positions 0, 3, 6, 9 (3× rest length).
698    // With only 1 iteration per sub-step the correction cannot fully propagate
699    // along the chain; with 30 iterations it gets much closer to rest lengths.
700    #[test]
701    fn test_constraint_iterations_converge() {
702        let make_setup = |iters: usize| {
703            let mut particles = vec![
704                SoftParticle::new(Vec3::new(0.0, 0.0, 0.0), 1.0),
705                SoftParticle::new(Vec3::new(3.0, 0.0, 0.0), 1.0),
706                SoftParticle::new(Vec3::new(6.0, 0.0, 0.0), 1.0),
707                SoftParticle::new(Vec3::new(9.0, 0.0, 0.0), 1.0),
708            ];
709            particles[0].inverse_mass = 0.0; // pin first particle
710            let rest = 1.0;
711            let constraints: Vec<Box<dyn SoftConstraint>> = vec![
712                Box::new(DistanceConstraint::new(0, 1, rest, 0.0)),
713                Box::new(DistanceConstraint::new(1, 2, rest, 0.0)),
714                Box::new(DistanceConstraint::new(2, 3, rest, 0.0)),
715            ];
716            let mut body = SoftBody::from_particles(particles);
717            let mut constraints = constraints;
718            let mut solver = XpbdSolver::with_iterations(1, iters);
719            solver.solve(&mut body, &mut constraints, 1.0 / 60.0);
720            // Sum of errors across all three springs.
721            let d01 = (body.particles[0].position - body.particles[1].position).norm();
722            let d12 = (body.particles[1].position - body.particles[2].position).norm();
723            let d23 = (body.particles[2].position - body.particles[3].position).norm();
724            (d01 - rest).abs() + (d12 - rest).abs() + (d23 - rest).abs()
725        };
726
727        let err_1 = make_setup(1);
728        let err_30 = make_setup(30);
729
730        assert!(
731            err_30 < err_1,
732            "More iterations should give smaller total error: err_1={err_1:.4}, err_30={err_30:.4}"
733        );
734    }
735
736    // T6. CFL timestep clamps correctly.
737    #[test]
738    fn test_cfl_timestep() {
739        let mut body = SoftBody::from_particles(vec![
740            SoftParticle::new(Vec3::new(0.0, 0.0, 0.0), 1.0),
741            SoftParticle::new(Vec3::new(1.0, 0.0, 0.0), 1.0),
742        ]);
743        // Give one particle a large velocity.
744        body.particles[0].velocity = Vec3::new(100.0, 0.0, 0.0);
745
746        let dt_max = 0.1;
747        let max_disp = 0.5;
748        let dt_cfl = XpbdSolver::cfl_timestep(&body, dt_max, max_disp);
749
750        // dt_cfl should be max_disp / 100 = 0.005 < dt_max.
751        assert!(
752            dt_cfl < dt_max,
753            "CFL dt should be smaller than dt_max: {dt_cfl}"
754        );
755        assert!(
756            (dt_cfl - max_disp / 100.0).abs() < 1e-10,
757            "CFL dt mismatch: {dt_cfl}"
758        );
759    }
760
761    // T7. ConstraintKind::compliance() returns correct values.
762    #[test]
763    fn test_constraint_kind_compliance() {
764        assert_eq!(ConstraintKind::Rigid.compliance(), 0.0);
765        assert_eq!(
766            ConstraintKind::Collision { restitution: 0.5 }.compliance(),
767            0.0
768        );
769        let k = ConstraintKind::Elastic { stiffness: 1000.0 };
770        assert!((k.compliance() - 1e-3).abs() < 1e-12);
771        let cv = ConstraintKind::Custom { compliance: 0.007 };
772        assert!((cv.compliance() - 0.007).abs() < 1e-12);
773    }
774
775    // T8. kinetic_energy returns zero for static-only body.
776    #[test]
777    fn test_kinetic_energy_static() {
778        let body = SoftBody::from_particles(vec![SoftParticle::new_static(Vec3::zeros())]);
779        assert_eq!(XpbdSolver::kinetic_energy(&body), 0.0);
780    }
781
782    // T9. integrate_positions + integrate_velocities round-trip.
783    #[test]
784    fn test_integrate_round_trip() {
785        let mut body =
786            SoftBody::from_particles(vec![SoftParticle::new(Vec3::new(0.0, 0.0, 0.0), 1.0)]);
787        body.particles[0].velocity = Vec3::new(1.0, 0.0, 0.0);
788        let solver = XpbdSolver::new(1);
789        let dt = 0.1;
790        solver.integrate_positions(&mut body, dt);
791        // pos should now be 0.1 in x
792        assert!((body.particles[0].position.x - 0.1).abs() < 1e-10);
793        // velocity re-derived from displacement should match.
794        solver.integrate_velocities(&mut body, dt);
795        assert!((body.particles[0].velocity.x - 1.0).abs() < 1e-10);
796    }
797
798    // T10. SolverConvergenceTracker records and reports.
799    #[test]
800    fn test_convergence_tracker() {
801        let mut tracker = SolverConvergenceTracker::new(0.01);
802        tracker.record(1.0);
803        tracker.record(0.5);
804        tracker.record(0.25);
805
806        assert_eq!(tracker.error_history.len(), 3);
807        assert!(!tracker.converged, "0.25 > 0.01, should not be converged");
808
809        let rate = tracker.convergence_rate().unwrap();
810        assert!((rate - 0.5).abs() < 1e-10, "Expected rate 0.5, got {rate}");
811
812        tracker.record(0.005);
813        assert!(tracker.converged, "0.005 < 0.01, should be converged");
814
815        tracker.reset();
816        assert!(tracker.error_history.is_empty());
817        assert!(!tracker.converged);
818    }
819
820    // T11. Convergence tracker suggest_iterations.
821    #[test]
822    fn test_suggest_iterations() {
823        let mut tracker = SolverConvergenceTracker::new(0.01);
824        // Fast convergence
825        tracker.record(1.0);
826        tracker.record(0.1); // rate = 0.1
827        let suggested = tracker.suggest_iterations(10, 2, 50);
828        assert!(
829            suggested <= 10,
830            "Fast convergence should suggest fewer iters"
831        );
832
833        tracker.reset();
834        // Slow convergence
835        tracker.record(1.0);
836        tracker.record(0.95); // rate = 0.95
837        let suggested_slow = tracker.suggest_iterations(10, 2, 50);
838        assert!(
839            suggested_slow >= 10,
840            "Slow convergence should suggest more iters"
841        );
842    }
843
844    // T12. ConstraintOrdering::Alternating vs Sequential.
845    #[test]
846    fn test_alternating_ordering() {
847        let mut particles = vec![
848            SoftParticle::new(Vec3::new(0.0, 0.0, 0.0), 1.0),
849            SoftParticle::new(Vec3::new(3.0, 0.0, 0.0), 1.0),
850            SoftParticle::new(Vec3::new(6.0, 0.0, 0.0), 1.0),
851        ];
852        particles[0].inverse_mass = 0.0; // pin first
853        let rest = 1.0;
854        let mut constraints: Vec<Box<dyn SoftConstraint>> = vec![
855            Box::new(DistanceConstraint::new(0, 1, rest, 0.0)),
856            Box::new(DistanceConstraint::new(1, 2, rest, 0.0)),
857        ];
858        let mut body = SoftBody::from_particles(particles);
859        let mut solver = XpbdSolver::with_iterations(1, 5);
860        solve_with_ordering(
861            &mut solver,
862            &mut body,
863            &mut constraints,
864            1.0 / 60.0,
865            ConstraintOrdering::Alternating,
866        );
867
868        // Verify positions are finite and constraints partially satisfied
869        for p in &body.particles {
870            assert!(p.position.x.is_finite(), "position should be finite");
871        }
872    }
873
874    // T13. WarmstartCache basic operations.
875    #[test]
876    fn test_warmstart_cache() {
877        let mut cache = WarmstartCache::new();
878        cache.resize(5);
879        assert_eq!(cache.lambdas.len(), 5);
880        for l in &cache.lambdas {
881            assert!(l.abs() < 1e-14);
882        }
883
884        cache.lambdas[0] = 1.0;
885        cache.lambdas[1] = -0.5;
886        cache.clear();
887        for l in &cache.lambdas {
888            assert!(l.abs() < 1e-14, "clear should zero all lambdas");
889        }
890    }
891
892    // T14. GaussSeidelSolver basic operation.
893    #[test]
894    fn test_gauss_seidel_solver() {
895        let mut particles = vec![
896            SoftParticle::new(Vec3::new(0.0, 0.0, 0.0), 1.0),
897            SoftParticle::new(Vec3::new(3.0, 0.0, 0.0), 1.0),
898        ];
899        particles[0].inverse_mass = 0.0;
900        // Save prev_position for the solver
901        for p in &mut particles {
902            p.prev_position = p.position;
903        }
904        let mut constraints: Vec<Box<dyn SoftConstraint>> =
905            vec![Box::new(DistanceConstraint::new(0, 1, 1.0, 0.0))];
906        let gs = GaussSeidelSolver::new(20);
907        gs.solve(&mut particles, &mut constraints, 1.0 / 60.0);
908
909        let d = (particles[0].position - particles[1].position).norm();
910        assert!(
911            (d - 1.0).abs() < 0.5,
912            "GS should bring particles closer to rest length: d={d}"
913        );
914    }
915
916    // T15. max_displacement for stationary body.
917    #[test]
918    fn test_max_displacement_zero() {
919        let body = SoftBody::from_particles(vec![SoftParticle::new(Vec3::zeros(), 1.0)]);
920        let d = XpbdSolver::max_displacement(&body);
921        assert!(
922            d.abs() < 1e-14,
923            "Stationary body should have zero displacement"
924        );
925    }
926
927    // T16. SOR factor in GaussSeidelSolver.
928    #[test]
929    fn test_gauss_seidel_sor() {
930        let gs = GaussSeidelSolver::with_sor(10, 1.5);
931        assert!((gs.omega - 1.5).abs() < 1e-14);
932        assert_eq!(gs.iterations, 10);
933    }
934
935    // T17. ConstraintKind::Volume compliance.
936    #[test]
937    fn test_volume_constraint_kind() {
938        let k = ConstraintKind::Volume { bulk_modulus: 1e6 };
939        assert!((k.compliance() - 1e-6).abs() < 1e-12);
940    }
941
942    // T18. Bending ConstraintKind compliance.
943    #[test]
944    fn test_bending_constraint_kind() {
945        let k = ConstraintKind::Bending { stiffness: 500.0 };
946        assert!((k.compliance() - 1.0 / 500.0).abs() < 1e-12);
947    }
948}
949
950// ---------------------------------------------------------------------------
951// Constraint batching for XPBD
952// ---------------------------------------------------------------------------
953
/// A group of constraints whose particle sets are pairwise disjoint.
///
/// Because no particle index appears in two member constraints, the
/// constraints of one batch can be solved in parallel without conflicting
/// writes to the same particle.
#[derive(Clone, Debug)]
pub struct ConstraintBatch {
    /// Positions of the member constraints in the caller's original constraint list.
    pub indices: Vec<usize>,
}
963
964impl ConstraintBatch {
965    /// Create a new empty batch.
966    pub fn new() -> Self {
967        Self {
968            indices: Vec::new(),
969        }
970    }
971
972    /// Number of constraints in this batch.
973    pub fn len(&self) -> usize {
974        self.indices.len()
975    }
976
977    /// Check if the batch is empty.
978    pub fn is_empty(&self) -> bool {
979        self.indices.is_empty()
980    }
981}
982
983impl Default for ConstraintBatch {
984    fn default() -> Self {
985        Self::new()
986    }
987}
988
989/// Partition constraints into independent batches (graph coloring).
990///
991/// Two constraints conflict if they share a particle index.
992/// Uses greedy graph coloring to assign batches.
993///
994/// # Arguments
995/// * `constraint_particles` – for each constraint, the list of particle indices it touches
996///
997/// Returns a list of batches, each batch containing constraint indices that
998/// can be solved in parallel.
999pub fn partition_constraints_into_batches(
1000    constraint_particles: &[Vec<usize>],
1001) -> Vec<ConstraintBatch> {
1002    let n = constraint_particles.len();
1003    let mut batch_id = vec![usize::MAX; n];
1004    let mut batches: Vec<ConstraintBatch> = Vec::new();
1005
1006    for i in 0..n {
1007        // Find the set of batch IDs used by conflicting constraints
1008        let mut forbidden = std::collections::BTreeSet::new();
1009        for j in 0..i {
1010            if batch_id[j] == usize::MAX {
1011                continue;
1012            }
1013            // Check if constraints i and j share any particle
1014            let share = constraint_particles[i]
1015                .iter()
1016                .any(|p| constraint_particles[j].contains(p));
1017            if share {
1018                forbidden.insert(batch_id[j]);
1019            }
1020        }
1021
1022        // Find the smallest non-forbidden batch id
1023        let mut b = 0;
1024        while forbidden.contains(&b) {
1025            b += 1;
1026        }
1027        batch_id[i] = b;
1028
1029        if b >= batches.len() {
1030            batches.push(ConstraintBatch::new());
1031        }
1032        batches[b].indices.push(i);
1033    }
1034
1035    batches
1036}
1037
1038// ---------------------------------------------------------------------------
1039// Compliance matrix computation
1040// ---------------------------------------------------------------------------
1041
1042/// Compute the effective compliance matrix for a set of XPBD constraints.
1043///
1044/// For XPBD, each constraint has an effective compliance:
1045/// α_tilde = α / (dt² * sum_w_i |∇C_i|²)
1046///
1047/// where α is the physical compliance (1/stiffness) and the denominator
1048/// accounts for the weighted gradient contributions.
1049///
1050/// This function computes the diagonal of the compliance matrix (one entry
1051/// per constraint).
1052pub fn xpbd_compliance_diagonal(
1053    compliances: &[Real],
1054    gradient_norms_sq: &[Real],
1055    dt: Real,
1056) -> Vec<Real> {
1057    let dt_sq = dt * dt;
1058    compliances
1059        .iter()
1060        .zip(gradient_norms_sq.iter())
1061        .map(|(&alpha, &grad_sq)| {
1062            let denom = dt_sq * grad_sq;
1063            if denom.abs() < 1e-60 {
1064                alpha / 1e-60
1065            } else {
1066                alpha / denom
1067            }
1068        })
1069        .collect()
1070}
1071
1072/// Compute the XPBD constraint residual for a distance constraint.
1073///
1074/// C = |x_b - x_a| - rest_length
1075///
1076/// Returns (constraint_value, gradient_norm_squared).
1077pub fn distance_constraint_residual(
1078    pos_a: [Real; 3],
1079    pos_b: [Real; 3],
1080    rest_length: Real,
1081) -> (Real, Real) {
1082    let dx = pos_b[0] - pos_a[0];
1083    let dy = pos_b[1] - pos_a[1];
1084    let dz = pos_b[2] - pos_a[2];
1085    let len = (dx * dx + dy * dy + dz * dz).sqrt();
1086    let c = len - rest_length;
1087    // |∇C|² = 2 (unit vector dotted each particle: sum = 2)
1088    let grad_sq = 2.0; // for unit inverse mass
1089    (c, grad_sq)
1090}
1091
1092// ---------------------------------------------------------------------------
1093// XPBD global step with position update
1094// ---------------------------------------------------------------------------
1095
1096/// Perform one global XPBD update step.
1097///
1098/// This is the "global" form of XPBD where all constraint corrections
1099/// are accumulated and applied simultaneously (Jacobi-style), suitable
1100/// for GPU-ready parallel execution.
1101///
1102/// Returns the sum of |Δx| (total displacement applied this step).
1103pub fn xpbd_global_step(
1104    positions: &mut [[Real; 3]],
1105    inv_masses: &[Real],
1106    constraints: &[(usize, usize, Real, Real)], // (a, b, rest_len, compliance)
1107    dt: Real,
1108) -> Real {
1109    let n = positions.len();
1110    let mut deltas = vec![[0.0_f64; 3]; n];
1111    let mut counts = vec![0_usize; n];
1112    let dt_sq = dt * dt;
1113
1114    for &(a, b, rest, alpha) in constraints {
1115        if a >= n || b >= n {
1116            continue;
1117        }
1118        let dx = positions[b][0] - positions[a][0];
1119        let dy = positions[b][1] - positions[a][1];
1120        let dz = positions[b][2] - positions[a][2];
1121        let len = (dx * dx + dy * dy + dz * dz).sqrt();
1122        if len < 1e-15 {
1123            continue;
1124        }
1125
1126        let c = len - rest;
1127        let wa = inv_masses[a];
1128        let wb = inv_masses[b];
1129        let w_sum = wa + wb;
1130        if w_sum < 1e-30 {
1131            continue;
1132        }
1133
1134        // XPBD lambda: Δλ = (-C - α̃ λ) / (w_sum + α̃)
1135        // Simplified (λ=0 start): Δλ = -C / (w_sum + α / dt²)
1136        let alpha_tilde = alpha / dt_sq;
1137        let d_lambda = -c / (w_sum + alpha_tilde);
1138
1139        let nx = dx / len;
1140        let ny = dy / len;
1141        let nz = dz / len;
1142
1143        deltas[a][0] -= wa * d_lambda * nx;
1144        deltas[a][1] -= wa * d_lambda * ny;
1145        deltas[a][2] -= wa * d_lambda * nz;
1146        deltas[b][0] += wb * d_lambda * nx;
1147        deltas[b][1] += wb * d_lambda * ny;
1148        deltas[b][2] += wb * d_lambda * nz;
1149        counts[a] += 1;
1150        counts[b] += 1;
1151    }
1152
1153    // Apply averaged corrections
1154    let mut total_disp = 0.0;
1155    for i in 0..n {
1156        if counts[i] > 0 {
1157            let s = 1.0 / counts[i] as Real;
1158            positions[i][0] += deltas[i][0] * s;
1159            positions[i][1] += deltas[i][1] * s;
1160            positions[i][2] += deltas[i][2] * s;
1161            let d = (deltas[i][0] * s).hypot(deltas[i][1] * s);
1162            total_disp += (d * d + (deltas[i][2] * s).powi(2)).sqrt();
1163        }
1164    }
1165    total_disp
1166}
1167
1168// ---------------------------------------------------------------------------
1169// Parallel Gauss-Seidel via graph coloring
1170// ---------------------------------------------------------------------------
1171
1172/// Parallel Gauss-Seidel solver using pre-computed constraint batches.
1173///
1174/// Each batch is processed sequentially, but within a batch all constraints
1175/// can be solved in parallel (no shared particles).
1176pub struct ParallelGaussSeidelSolver {
1177    /// Number of iterations.
1178    pub iterations: usize,
1179    /// Pre-computed constraint batches.
1180    pub batches: Vec<ConstraintBatch>,
1181}
1182
1183impl ParallelGaussSeidelSolver {
1184    /// Create a new parallel GS solver from constraint topology.
1185    pub fn new(iterations: usize, constraint_particles: &[Vec<usize>]) -> Self {
1186        let batches = partition_constraints_into_batches(constraint_particles);
1187        Self {
1188            iterations,
1189            batches,
1190        }
1191    }
1192
1193    /// Number of batches (colors).
1194    pub fn n_batches(&self) -> usize {
1195        self.batches.len()
1196    }
1197}
1198
1199// ---------------------------------------------------------------------------
1200// GPU-ready position update utility
1201// ---------------------------------------------------------------------------
1202
1203/// GPU-ready position update: given current and previous positions,
1204/// compute new velocity for each particle.
1205///
1206/// `v_i = (x_i^new - x_i^prev) / dt`
1207///
1208/// Returns the kinetic energy.
1209pub fn compute_velocities_from_positions(
1210    positions: &[[Real; 3]],
1211    prev_positions: &[[Real; 3]],
1212    inv_masses: &[Real],
1213    dt: Real,
1214) -> (Vec<[Real; 3]>, Real) {
1215    assert_eq!(positions.len(), prev_positions.len());
1216    assert_eq!(positions.len(), inv_masses.len());
1217
1218    let mut velocities = Vec::with_capacity(positions.len());
1219    let mut ke = 0.0;
1220
1221    for i in 0..positions.len() {
1222        let vx = (positions[i][0] - prev_positions[i][0]) / dt;
1223        let vy = (positions[i][1] - prev_positions[i][1]) / dt;
1224        let vz = (positions[i][2] - prev_positions[i][2]) / dt;
1225        velocities.push([vx, vy, vz]);
1226
1227        if inv_masses[i] > 0.0 {
1228            let mass = 1.0 / inv_masses[i];
1229            ke += 0.5 * mass * (vx * vx + vy * vy + vz * vz);
1230        }
1231    }
1232
1233    (velocities, ke)
1234}
1235
1236// ---------------------------------------------------------------------------
1237// Additional tests for XPBD solver extensions
1238// ---------------------------------------------------------------------------
1239
#[cfg(test)]
mod tests_extended {

    use crate::solver::ConstraintBatch;
    use crate::solver::ParallelGaussSeidelSolver;
    use crate::solver::compute_velocities_from_positions;
    use crate::solver::distance_constraint_residual;
    use crate::solver::partition_constraints_into_batches;
    use crate::solver::xpbd_compliance_diagonal;
    use crate::solver::xpbd_global_step;

    #[test]
    fn test_partition_no_conflicts() {
        // Two constraints on different particles → same batch possible
        let cp = vec![vec![0, 1], vec![2, 3]];
        let batches = partition_constraints_into_batches(&cp);
        assert_eq!(
            batches.len(),
            1,
            "non-conflicting constraints should be in 1 batch"
        );
        assert_eq!(batches[0].len(), 2);
    }

    #[test]
    fn test_partition_all_conflicts() {
        // Three constraints all sharing particle 0
        let cp = vec![vec![0, 1], vec![0, 2], vec![0, 3]];
        let batches = partition_constraints_into_batches(&cp);
        assert_eq!(
            batches.len(),
            3,
            "all-conflicting constraints need 3 batches"
        );
    }

    #[test]
    fn test_partition_chain_of_constraints() {
        // Chain: 0-1, 1-2, 2-3 → path graph; greedy coloring in order
        // assigns colors 0, 1, 0.
        let cp = vec![vec![0, 1], vec![1, 2], vec![2, 3]];
        let batches = partition_constraints_into_batches(&cp);
        assert_eq!(batches.len(), 2, "a path graph needs exactly 2 colors");
        assert_eq!(batches[0].indices, vec![0, 2]);
        assert_eq!(batches[1].indices, vec![1]);
    }

    #[test]
    fn test_partition_empty() {
        let cp: Vec<Vec<usize>> = Vec::new();
        assert!(partition_constraints_into_batches(&cp).is_empty());
    }

    #[test]
    fn test_xpbd_compliance_diagonal() {
        let compliances = vec![1e-3, 0.0];
        let grad_sq = vec![2.0, 2.0];
        let dt = 0.01;
        let diag = xpbd_compliance_diagonal(&compliances, &grad_sq, dt);
        assert_eq!(diag.len(), 2);
        // alpha_tilde = alpha / (dt^2 * grad_sq)
        let expected = 1e-3 / (0.01 * 0.01 * 2.0);
        assert!(
            (diag[0] - expected).abs() / expected < 1e-10,
            "diag[0] = {}",
            diag[0]
        );
        // A rigid (zero-compliance) constraint keeps zero effective compliance.
        assert_eq!(diag[1], 0.0, "diag[1] = {}", diag[1]);
    }

    #[test]
    fn test_distance_constraint_residual() {
        let a = [0.0, 0.0, 0.0];
        let b = [3.0, 0.0, 0.0]; // actual distance = 3.0
        let rest = 1.0;
        let (c, grad_sq) = distance_constraint_residual(a, b, rest);
        assert!(
            (c - 2.0).abs() < 1e-12,
            "constraint value should be 2.0: {c}"
        );
        assert!(
            (grad_sq - 2.0).abs() < 1e-12,
            "gradient norm sq = {grad_sq}"
        );
    }

    #[test]
    fn test_distance_constraint_residual_at_rest() {
        let a = [0.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        let (c, _) = distance_constraint_residual(a, b, 1.0);
        assert!(
            c.abs() < 1e-12,
            "constraint at rest length should be zero: {c}"
        );
    }

    #[test]
    fn test_xpbd_global_step_reduces_violation() {
        let mut positions = [[0.0, 0.0, 0.0_f64], [3.0, 0.0, 0.0]];
        let inv_masses = [1.0, 1.0];
        let constraints = vec![(0, 1, 1.0, 0.0)];
        let dt = 0.01;

        let disp = xpbd_global_step(&mut positions, &inv_masses, &constraints, dt);

        let dx = positions[1][0] - positions[0][0];
        let dy = positions[1][1] - positions[0][1];
        let dz = positions[1][2] - positions[0][2];
        let dist = (dx * dx + dy * dy + dz * dz).sqrt();
        assert!(dist < 3.0, "distance should decrease: {dist}");
        // Equal masses + zero compliance: a single constraint is projected
        // exactly onto its rest length, each particle moving by 1.0.
        assert!(
            (dist - 1.0).abs() < 1e-12,
            "rigid constraint should reach rest length: {dist}"
        );
        assert!((disp - 2.0).abs() < 1e-12, "total displacement = {disp}");
    }

    #[test]
    fn test_xpbd_global_step_static_particle() {
        let mut positions = [[0.0, 0.0, 0.0_f64], [3.0, 0.0, 0.0]];
        let inv_masses = [0.0, 1.0]; // first particle is static
        let constraints = vec![(0, 1, 1.0, 0.0)];
        let dt = 0.01;

        xpbd_global_step(&mut positions, &inv_masses, &constraints, dt);

        // Static particle should not move
        assert!(
            (positions[0][0]).abs() < 1e-14,
            "static particle should not move"
        );
        // The free particle absorbs the whole correction.
        assert!(
            (positions[1][0] - 1.0).abs() < 1e-12,
            "free particle should move to rest length: {}",
            positions[1][0]
        );
    }

    #[test]
    fn test_compute_velocities_from_positions() {
        let pos = [[1.0, 0.0, 0.0_f64]];
        let prev = [[0.0, 0.0, 0.0_f64]];
        let inv_m = [1.0];
        let dt = 0.1;

        let (vels, ke) = compute_velocities_from_positions(&pos, &prev, &inv_m, dt);
        assert_eq!(vels.len(), 1);
        assert!((vels[0][0] - 10.0).abs() < 1e-10, "vx = {}", vels[0][0]);
        // KE = 0.5 * 1.0 * 100 = 50
        assert!((ke - 50.0).abs() < 1e-10, "KE = {ke}");
    }

    #[test]
    fn test_compute_velocities_static_particle() {
        let pos = [[1.0, 0.0, 0.0_f64]];
        let prev = [[0.0, 0.0, 0.0_f64]];
        let inv_m = [0.0]; // static
        let dt = 0.1;

        let (vels, ke) = compute_velocities_from_positions(&pos, &prev, &inv_m, dt);
        assert_eq!(ke, 0.0, "static particle has zero KE");
        // Velocity is still reported for the static particle.
        assert!((vels[0][0] - 10.0).abs() < 1e-10, "vx = {}", vels[0][0]);
    }

    #[test]
    fn test_constraint_batch_default() {
        let batch = ConstraintBatch::default();
        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);
    }

    #[test]
    fn test_parallel_gauss_seidel_n_batches() {
        // 0-1 and 2-3 are independent; 1-2 conflicts with both → 2 colors.
        let cp = vec![vec![0, 1], vec![2, 3], vec![1, 2]];
        let pgs = ParallelGaussSeidelSolver::new(10, &cp);
        assert_eq!(pgs.n_batches(), 2, "this chain needs exactly 2 colors");
        assert_eq!(pgs.iterations, 10);
    }
}
1400}