Skip to main content

oxiphysics_python/
constraints_api.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Constraint system API for Python interop.
6//!
7//! Provides Python-friendly types for rigid body constraint solving,
8//! joints, contact constraints, motors, island management, CCD,
9//! PBD/XPBD solvers, control systems, friction models, and warm starting.
10
11#![allow(missing_docs)]
12#![allow(dead_code)]
13
14use serde::{Deserialize, Serialize};
15
16// ---------------------------------------------------------------------------
17// Helper functions
18// ---------------------------------------------------------------------------
19
20/// Perform one PGS (Projected Gauss-Seidel) iteration over a set of constraints.
21///
22/// `lambda` — accumulated impulse per constraint (modified in place).
23/// `rhs`    — constraint violation / desired velocity change per constraint.
24/// `diag`   — diagonal of the effective-mass matrix per constraint.
25/// `lo`     — lower bound per constraint (clamping).
26/// `hi`     — upper bound per constraint (clamping).
27///
28/// Returns the total residual after the iteration.
29pub fn solve_pgs_iteration(
30    lambda: &mut [f64],
31    rhs: &[f64],
32    diag: &[f64],
33    lo: &[f64],
34    hi: &[f64],
35) -> f64 {
36    let n = lambda.len();
37    let mut residual = 0.0;
38    for i in 0..n {
39        if diag[i].abs() < 1e-15 {
40            continue;
41        }
42        let delta = (rhs[i] - diag[i] * lambda[i]) / diag[i];
43        let new_lambda = (lambda[i] + delta).clamp(lo[i], hi[i]);
44        let actual_delta = new_lambda - lambda[i];
45        lambda[i] = new_lambda;
46        residual += actual_delta * actual_delta;
47    }
48    residual.sqrt()
49}
50
/// Build the 12-entry Jacobian row of a distance/contact constraint between two bodies.
///
/// * `r_a` — lever arm from body A's center to the contact point.
/// * `r_b` — lever arm from body B's center to the contact point.
/// * `n`   — constraint direction (expected to be unit length).
///
/// Layout: `[n, r_a × n, -n, -(r_b × n)]`, i.e. linear A, angular A, linear B,
/// angular B, three components each.
pub fn compute_jacobian(r_a: [f64; 3], r_b: [f64; 3], n: [f64; 3]) -> [f64; 12] {
    let torque_a = cross(r_a, n);
    let torque_b = cross(r_b, n);
    let mut row = [0.0; 12];
    for k in 0..3 {
        row[k] = n[k];
        row[3 + k] = torque_a[k];
        row[6 + k] = -n[k];
        row[9 + k] = -torque_b[k];
    }
    row
}

/// Right-handed 3D cross product `a × b`.
fn cross(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
75
/// Effective mass `1 / (J M⁻¹ Jᵀ)` of a single constraint row.
///
/// * `inv_mass_a`, `inv_mass_b` — scalar inverse masses.
/// * `inv_inertia_a`, `inv_inertia_b` — diagonal of each inverse inertia tensor.
/// * `j` — Jacobian row laid out as produced by `compute_jacobian`
///   (`[lin_a, ang_a, lin_b, ang_b]`, three entries each).
///
/// Returns `0.0` when the denominator is (near) zero, e.g. both bodies are static,
/// instead of dividing by it.
pub fn compute_effective_mass(
    inv_mass_a: f64,
    inv_mass_b: f64,
    inv_inertia_a: [f64; 3],
    inv_inertia_b: [f64; 3],
    j: [f64; 12],
) -> f64 {
    // Σ_k row[k]² · w[k] for one 3-component segment of the Jacobian.
    let weighted = |row: &[f64], w: [f64; 3]| {
        row[0] * row[0] * w[0] + row[1] * row[1] * w[1] + row[2] * row[2] * w[2]
    };
    let denom = weighted(&j[0..3], [inv_mass_a; 3])
        + weighted(&j[3..6], inv_inertia_a)
        + weighted(&j[6..9], [inv_mass_b; 3])
        + weighted(&j[9..12], inv_inertia_b);
    if denom.abs() < 1e-15 {
        0.0
    } else {
        denom.recip()
    }
}
106
/// Clamp a tangential (friction) impulse to the Coulomb friction cone.
///
/// * `lambda_n` — normal impulse. Negative values are treated as `0` (no cone).
/// * `lambda_t` — current tangential impulse `[lt_x, lt_y]`.
/// * `mu`       — friction coefficient. Negative values, and NaN (via `f64::max`),
///   are treated as `0`, i.e. frictionless.
///
/// Returns `lambda_t` unchanged when it lies inside the cone of radius
/// `mu * lambda_n`. Otherwise it is rescaled onto the cone boundary, keeping its direction.
pub fn clamp_impulse(lambda_n: f64, lambda_t: [f64; 2], mu: f64) -> [f64; 2] {
    // A negative friction coefficient would give a negative cone radius, and the
    // rescale below would then reverse (and amplify) the tangential impulse.
    let max_t = mu.max(0.0) * lambda_n.max(0.0);
    let mag = (lambda_t[0] * lambda_t[0] + lambda_t[1] * lambda_t[1]).sqrt();
    if mag > max_t && mag > 1e-15 {
        [lambda_t[0] / mag * max_t, lambda_t[1] / mag * max_t]
    } else {
        lambda_t
    }
}
123
124// ---------------------------------------------------------------------------
125// Solver type
126// ---------------------------------------------------------------------------
127
128/// Constraint solver algorithm selection.
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub enum SolverType {
131    /// Projected Gauss-Seidel.
132    Pgs,
133    /// Temporal Gauss-Seidel (velocity + position level).
134    Tgs,
135    /// Sequential impulse (SI).
136    Si,
137}
138
139/// Main constraint solver configuration.
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct PyConstraintSolver {
142    /// Solver algorithm.
143    pub solver_type: SolverType,
144    /// Number of velocity-level iterations.
145    pub velocity_iterations: u32,
146    /// Number of position-level iterations (for TGS/non-penetration correction).
147    pub position_iterations: u32,
148    /// Whether to use warm starting from the previous frame.
149    pub warm_start: bool,
150    /// Successive over-relaxation factor (1.0 = standard PGS).
151    pub sor_factor: f64,
152    /// Convergence tolerance.
153    pub tolerance: f64,
154    /// Maximum allowed penetration before applying a correction impulse.
155    pub slop: f64,
156}
157
158impl PyConstraintSolver {
159    /// Create a default PGS solver.
160    pub fn default_pgs() -> Self {
161        Self {
162            solver_type: SolverType::Pgs,
163            velocity_iterations: 10,
164            position_iterations: 5,
165            warm_start: true,
166            sor_factor: 1.0,
167            tolerance: 1e-4,
168            slop: 0.005,
169        }
170    }
171
172    /// Create a TGS solver (fewer iterations needed).
173    pub fn default_tgs() -> Self {
174        Self {
175            solver_type: SolverType::Tgs,
176            velocity_iterations: 4,
177            position_iterations: 2,
178            warm_start: true,
179            sor_factor: 1.3,
180            tolerance: 1e-4,
181            slop: 0.005,
182        }
183    }
184
185    /// Solve a simple set of constraints given accumulated lambdas, RHS and diagonal effective masses.
186    ///
187    /// Returns the final accumulated impulse array.
188    pub fn solve(&self, rhs: &[f64], diag: &[f64], lo: &[f64], hi: &[f64]) -> Vec<f64> {
189        let n = rhs.len();
190        let mut lambda = vec![0.0f64; n];
191        for _ in 0..self.velocity_iterations {
192            let res = solve_pgs_iteration(&mut lambda, rhs, diag, lo, hi);
193            if res < self.tolerance {
194                break;
195            }
196        }
197        lambda
198    }
199}
200
201// ---------------------------------------------------------------------------
202// PyJoint
203// ---------------------------------------------------------------------------
204
205/// Joint axis and limit specification.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct AxisLimits {
208    /// Axis direction (unit vector).
209    pub axis: [f64; 3],
210    /// Lower limit (radians or meters).
211    pub lower: f64,
212    /// Upper limit.
213    pub upper: f64,
214    /// Whether limits are enabled.
215    pub enabled: bool,
216}
217
218impl AxisLimits {
219    /// Unlimited joint along an axis.
220    pub fn unlimited(axis: [f64; 3]) -> Self {
221        Self {
222            axis,
223            lower: -f64::INFINITY,
224            upper: f64::INFINITY,
225            enabled: false,
226        }
227    }
228
229    /// Limited joint.
230    pub fn limited(axis: [f64; 3], lower: f64, upper: f64) -> Self {
231        Self {
232            axis,
233            lower,
234            upper,
235            enabled: true,
236        }
237    }
238}
239
240/// The kind of joint.
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub enum JointKind {
243    /// Fully fixed — all 6 DOF locked.
244    Fixed,
245    /// Revolute (hinge) joint — 1 rotational DOF.
246    Revolute(AxisLimits),
247    /// Prismatic (sliding) joint — 1 translational DOF.
248    Prismatic(AxisLimits),
249    /// Ball-and-socket joint with swing and twist limits.
250    Ball { swing_limit: f64, twist_limit: f64 },
251    /// Spring joint with stiffness and damping.
252    Spring {
253        stiffness: f64,
254        damping: f64,
255        rest_length: f64,
256    },
257}
258
259/// A joint connecting two rigid bodies.
260#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct PyJoint {
262    /// Unique joint identifier.
263    pub id: u32,
264    /// Body A handle.
265    pub body_a: u32,
266    /// Body B handle.
267    pub body_b: u32,
268    /// Anchor point on body A in local space.
269    pub anchor_a: [f64; 3],
270    /// Anchor point on body B in local space.
271    pub anchor_b: [f64; 3],
272    /// Joint type.
273    pub kind: JointKind,
274    /// Whether this joint is enabled.
275    pub enabled: bool,
276    /// Joint breaking force (inf = unbreakable).
277    pub break_force: f64,
278}
279
280impl PyJoint {
281    /// Create a fixed joint between two bodies.
282    pub fn fixed(
283        id: u32,
284        body_a: u32,
285        body_b: u32,
286        anchor_a: [f64; 3],
287        anchor_b: [f64; 3],
288    ) -> Self {
289        Self {
290            id,
291            body_a,
292            body_b,
293            anchor_a,
294            anchor_b,
295            kind: JointKind::Fixed,
296            enabled: true,
297            break_force: f64::INFINITY,
298        }
299    }
300
301    /// Create a revolute joint.
302    pub fn revolute(id: u32, body_a: u32, body_b: u32, anchor: [f64; 3], axis: [f64; 3]) -> Self {
303        Self {
304            id,
305            body_a,
306            body_b,
307            anchor_a: anchor,
308            anchor_b: anchor,
309            kind: JointKind::Revolute(AxisLimits::unlimited(axis)),
310            enabled: true,
311            break_force: f64::INFINITY,
312        }
313    }
314
315    /// Create a spring joint.
316    #[allow(clippy::too_many_arguments)]
317    pub fn spring(
318        id: u32,
319        body_a: u32,
320        body_b: u32,
321        anchor_a: [f64; 3],
322        anchor_b: [f64; 3],
323        stiffness: f64,
324        damping: f64,
325        rest_length: f64,
326    ) -> Self {
327        Self {
328            id,
329            body_a,
330            body_b,
331            anchor_a,
332            anchor_b,
333            kind: JointKind::Spring {
334                stiffness,
335                damping,
336                rest_length,
337            },
338            enabled: true,
339            break_force: f64::INFINITY,
340        }
341    }
342
343    /// Check if the joint is breakable.
344    pub fn is_breakable(&self) -> bool {
345        self.break_force.is_finite()
346    }
347}
348
349// ---------------------------------------------------------------------------
350// PyContactConstraint
351// ---------------------------------------------------------------------------
352
353/// A single contact point constraint between two bodies.
354#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct PyContactConstraint {
356    /// Body A handle.
357    pub body_a: u32,
358    /// Body B handle.
359    pub body_b: u32,
360    /// Contact normal (from B to A, unit vector).
361    pub normal: [f64; 3],
362    /// Penetration depth (positive = overlap).
363    pub penetration: f64,
364    /// Contact position in world space.
365    pub position: [f64; 3],
366    /// Accumulated normal impulse (warm start).
367    pub lambda_n: f64,
368    /// Accumulated tangent impulse in primary direction.
369    pub lambda_t1: f64,
370    /// Accumulated tangent impulse in secondary direction.
371    pub lambda_t2: f64,
372    /// Coefficient of restitution.
373    pub restitution: f64,
374    /// Coulomb friction coefficient.
375    pub friction: f64,
376}
377
378impl PyContactConstraint {
379    /// Create a new contact constraint.
380    #[allow(clippy::too_many_arguments)]
381    pub fn new(
382        body_a: u32,
383        body_b: u32,
384        normal: [f64; 3],
385        penetration: f64,
386        position: [f64; 3],
387        restitution: f64,
388        friction: f64,
389    ) -> Self {
390        Self {
391            body_a,
392            body_b,
393            normal,
394            penetration,
395            position,
396            lambda_n: 0.0,
397            lambda_t1: 0.0,
398            lambda_t2: 0.0,
399            restitution,
400            friction,
401        }
402    }
403
404    /// Clamp the tangential impulses to the friction cone.
405    pub fn clamp_friction(&mut self) {
406        let clamped = clamp_impulse(
407            self.lambda_n,
408            [self.lambda_t1, self.lambda_t2],
409            self.friction,
410        );
411        self.lambda_t1 = clamped[0];
412        self.lambda_t2 = clamped[1];
413    }
414
415    /// Compute the primary and secondary tangent directions orthogonal to the normal.
416    pub fn tangent_basis(&self) -> ([f64; 3], [f64; 3]) {
417        let n = self.normal;
418        let t1 = if n[0].abs() < 0.9 {
419            let raw = [0.0 - n[1] * n[1], n[0] * n[1] - 0.0, 0.0];
420            // Just use a simple perpendicular
421            let t = [1.0 - n[0] * n[0], -n[0] * n[1], -n[0] * n[2]];
422            let len = (t[0] * t[0] + t[1] * t[1] + t[2] * t[2]).sqrt();
423            if len > 1e-12 {
424                [t[0] / len, t[1] / len, t[2] / len]
425            } else {
426                raw
427            }
428        } else {
429            let t = [-n[1] * n[0], 1.0 - n[1] * n[1], -n[1] * n[2]];
430            let len = (t[0] * t[0] + t[1] * t[1] + t[2] * t[2]).sqrt();
431            if len > 1e-12 {
432                [t[0] / len, t[1] / len, t[2] / len]
433            } else {
434                [0.0, 1.0, 0.0]
435            }
436        };
437        let t2 = cross(n, t1);
438        (t1, t2)
439    }
440}
441
442// ---------------------------------------------------------------------------
443// PyMotorConstraint
444// ---------------------------------------------------------------------------
445
446/// PID controller for motor target tracking.
447#[derive(Debug, Clone, Serialize, Deserialize)]
448pub struct PidGains {
449    /// Proportional gain.
450    pub kp: f64,
451    /// Integral gain.
452    pub ki: f64,
453    /// Derivative gain.
454    pub kd: f64,
455    /// Accumulated integral error.
456    pub integral: f64,
457    /// Previous error (for derivative).
458    pub prev_error: f64,
459}
460
461impl PidGains {
462    /// Create new PID gains.
463    pub fn new(kp: f64, ki: f64, kd: f64) -> Self {
464        Self {
465            kp,
466            ki,
467            kd,
468            integral: 0.0,
469            prev_error: 0.0,
470        }
471    }
472
473    /// Compute PID output given the current error and time step.
474    pub fn compute(&mut self, error: f64, dt: f64) -> f64 {
475        self.integral += error * dt;
476        let derivative = if dt > 1e-15 {
477            (error - self.prev_error) / dt
478        } else {
479            0.0
480        };
481        self.prev_error = error;
482        self.kp * error + self.ki * self.integral + self.kd * derivative
483    }
484
485    /// Reset the integral and derivative state.
486    pub fn reset(&mut self) {
487        self.integral = 0.0;
488        self.prev_error = 0.0;
489    }
490}
491
492/// Motor thermal model to cap torque under sustained load.
493#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct MotorThermal {
495    /// Current motor temperature (°C).
496    pub temperature: f64,
497    /// Ambient temperature (°C).
498    pub ambient: f64,
499    /// Thermal resistance (°C/W).
500    pub resistance: f64,
501    /// Thermal capacitance (J/°C).
502    pub capacitance: f64,
503    /// Motor resistance for heat generation (Ω).
504    pub motor_resistance: f64,
505    /// Temperature derating threshold (°C).
506    pub derate_threshold: f64,
507    /// Maximum allowed temperature (°C).
508    pub max_temperature: f64,
509}
510
511impl MotorThermal {
512    /// Create a default thermal model.
513    pub fn new() -> Self {
514        Self {
515            temperature: 25.0,
516            ambient: 25.0,
517            resistance: 0.5,
518            capacitance: 100.0,
519            motor_resistance: 0.1,
520            derate_threshold: 80.0,
521            max_temperature: 120.0,
522        }
523    }
524
525    /// Update temperature given current (A) and time step (s).
526    pub fn update(&mut self, current: f64, dt: f64) {
527        let heat_in = current * current * self.motor_resistance;
528        let heat_out = (self.temperature - self.ambient) / self.resistance;
529        self.temperature += (heat_in - heat_out) / self.capacitance * dt;
530    }
531
532    /// Compute the torque scale factor (1 = full, < 1 = derated).
533    pub fn torque_scale(&self) -> f64 {
534        if self.temperature <= self.derate_threshold {
535            1.0
536        } else {
537            let range = self.max_temperature - self.derate_threshold;
538            if range < 1e-9 {
539                0.0
540            } else {
541                (1.0 - (self.temperature - self.derate_threshold) / range).max(0.0)
542            }
543        }
544    }
545}
546
547impl Default for MotorThermal {
548    fn default() -> Self {
549        Self::new()
550    }
551}
552
553/// A motor/actuator constraint driving a joint.
554#[derive(Debug, Clone, Serialize, Deserialize)]
555pub struct PyMotorConstraint {
556    /// Target angular velocity (rad/s). Used in velocity mode.
557    pub target_velocity: f64,
558    /// Target position (rad or m). Used in position mode.
559    pub target_position: f64,
560    /// Maximum torque/force (N·m or N).
561    pub max_torque: f64,
562    /// PID controller.
563    pub pid: PidGains,
564    /// Whether in position mode (true) or velocity mode (false).
565    pub position_mode: bool,
566    /// Current applied torque.
567    pub applied_torque: f64,
568    /// Thermal model.
569    pub thermal: MotorThermal,
570}
571
572impl PyMotorConstraint {
573    /// Create a velocity-mode motor.
574    pub fn velocity_mode(target_velocity: f64, max_torque: f64, kp: f64) -> Self {
575        Self {
576            target_velocity,
577            target_position: 0.0,
578            max_torque,
579            pid: PidGains::new(kp, 0.0, 0.0),
580            position_mode: false,
581            applied_torque: 0.0,
582            thermal: MotorThermal::new(),
583        }
584    }
585
586    /// Create a position-mode motor with PID gains.
587    pub fn position_mode(target_position: f64, max_torque: f64, pid: PidGains) -> Self {
588        Self {
589            target_velocity: 0.0,
590            target_position,
591            max_torque,
592            pid,
593            position_mode: true,
594            applied_torque: 0.0,
595            thermal: MotorThermal::new(),
596        }
597    }
598
599    /// Step the motor given the current measured value and time step.
600    pub fn step(&mut self, current_value: f64, dt: f64) -> f64 {
601        let error = if self.position_mode {
602            self.target_position - current_value
603        } else {
604            self.target_velocity - current_value
605        };
606        let raw_torque = self.pid.compute(error, dt);
607        let scale = self.thermal.torque_scale();
608        let torque = raw_torque.clamp(-self.max_torque, self.max_torque) * scale;
609        let current = (torque / self.max_torque.max(1e-9)).abs() * 10.0;
610        self.thermal.update(current, dt);
611        self.applied_torque = torque;
612        torque
613    }
614}
615
616// ---------------------------------------------------------------------------
617// PyIsland
618// ---------------------------------------------------------------------------
619
620/// Sleeping state of an island.
621#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
622pub enum IslandSleepState {
623    /// Island is awake and being simulated.
624    Awake,
625    /// Island is dormant (below velocity threshold).
626    Sleeping,
627    /// Island is in the process of going to sleep.
628    Drowsy,
629}
630
631/// An island groups bodies and constraints that can interact.
632#[derive(Debug, Clone, Serialize, Deserialize)]
633pub struct PyIsland {
634    /// Unique island id.
635    pub id: u32,
636    /// Handles of rigid bodies in this island.
637    pub bodies: Vec<u32>,
638    /// Handles of constraints in this island.
639    pub constraints: Vec<u32>,
640    /// Sleep state.
641    pub sleep_state: IslandSleepState,
642    /// Frames below the velocity threshold (used for sleep triggering).
643    pub quiet_frames: u32,
644    /// Velocity threshold for sleeping (m/s and rad/s combined).
645    pub sleep_threshold: f64,
646    /// Number of frames needed to trigger sleep.
647    pub sleep_frames: u32,
648}
649
650impl PyIsland {
651    /// Create a new island.
652    pub fn new(id: u32) -> Self {
653        Self {
654            id,
655            bodies: Vec::new(),
656            constraints: Vec::new(),
657            sleep_state: IslandSleepState::Awake,
658            quiet_frames: 0,
659            sleep_threshold: 0.05,
660            sleep_frames: 60,
661        }
662    }
663
664    /// Add a body to the island.
665    pub fn add_body(&mut self, body: u32) {
666        if !self.bodies.contains(&body) {
667            self.bodies.push(body);
668        }
669    }
670
671    /// Add a constraint to the island.
672    pub fn add_constraint(&mut self, constraint: u32) {
673        if !self.constraints.contains(&constraint) {
674            self.constraints.push(constraint);
675        }
676    }
677
678    /// Update sleep state given the maximum body speed this frame.
679    pub fn update_sleep(&mut self, max_speed: f64) {
680        match self.sleep_state {
681            IslandSleepState::Sleeping => {
682                if max_speed > self.sleep_threshold * 2.0 {
683                    self.sleep_state = IslandSleepState::Awake;
684                    self.quiet_frames = 0;
685                }
686            }
687            _ => {
688                if max_speed < self.sleep_threshold {
689                    self.quiet_frames += 1;
690                    if self.quiet_frames >= self.sleep_frames {
691                        self.sleep_state = IslandSleepState::Sleeping;
692                    } else {
693                        self.sleep_state = IslandSleepState::Drowsy;
694                    }
695                } else {
696                    self.quiet_frames = 0;
697                    self.sleep_state = IslandSleepState::Awake;
698                }
699            }
700        }
701    }
702
703    /// Merge another island into this one.
704    pub fn merge(&mut self, other: &PyIsland) {
705        for &b in &other.bodies {
706            self.add_body(b);
707        }
708        for &c in &other.constraints {
709            self.add_constraint(c);
710        }
711        if other.sleep_state == IslandSleepState::Awake {
712            self.sleep_state = IslandSleepState::Awake;
713        }
714    }
715
716    /// Whether this island is currently active.
717    pub fn is_active(&self) -> bool {
718        self.sleep_state != IslandSleepState::Sleeping
719    }
720}
721
722// ---------------------------------------------------------------------------
723// PyCCDResult
724// ---------------------------------------------------------------------------
725
726/// Result of a continuous collision detection (CCD) query.
727#[derive(Debug, Clone, Serialize, Deserialize)]
728pub struct PyCCDResult {
729    /// Whether a collision was detected.
730    pub hit: bool,
731    /// Time of impact in \[0, 1\] (fraction of the time step).
732    pub toi: f64,
733    /// Contact normal at the time of impact.
734    pub normal: [f64; 3],
735    /// Witness point on body A.
736    pub witness_a: [f64; 3],
737    /// Witness point on body B.
738    pub witness_b: [f64; 3],
739    /// Body A handle.
740    pub body_a: u32,
741    /// Body B handle.
742    pub body_b: u32,
743}
744
745impl PyCCDResult {
746    /// No collision result.
747    pub fn miss(body_a: u32, body_b: u32) -> Self {
748        Self {
749            hit: false,
750            toi: 1.0,
751            normal: [0.0, 1.0, 0.0],
752            witness_a: [0.0; 3],
753            witness_b: [0.0; 3],
754            body_a,
755            body_b,
756        }
757    }
758
759    /// Collision result at a given time of impact.
760    #[allow(clippy::too_many_arguments)]
761    pub fn hit(
762        body_a: u32,
763        body_b: u32,
764        toi: f64,
765        normal: [f64; 3],
766        witness_a: [f64; 3],
767        witness_b: [f64; 3],
768    ) -> Self {
769        Self {
770            hit: true,
771            toi: toi.clamp(0.0, 1.0),
772            normal,
773            witness_a,
774            witness_b,
775            body_a,
776            body_b,
777        }
778    }
779
780    /// Separation distance between witness points.
781    pub fn witness_distance(&self) -> f64 {
782        let dx = self.witness_a[0] - self.witness_b[0];
783        let dy = self.witness_a[1] - self.witness_b[1];
784        let dz = self.witness_a[2] - self.witness_b[2];
785        (dx * dx + dy * dy + dz * dz).sqrt()
786    }
787}
788
789// ---------------------------------------------------------------------------
790// PyPbdSolver
791// ---------------------------------------------------------------------------
792
793/// PBD/XPBD constraint type.
794#[derive(Debug, Clone, Serialize, Deserialize)]
795pub enum PbdConstraintType {
796    /// Distance / stretch constraint.
797    Stretch { rest_length: f64, compliance: f64 },
798    /// Bending constraint between three particles.
799    Bend { rest_angle: f64, compliance: f64 },
800    /// Volume conservation constraint.
801    Volume { rest_volume: f64, compliance: f64 },
802    /// Collision contact constraint.
803    Contact { normal: [f64; 3], penetration: f64 },
804}
805
806/// A PBD/XPBD constraint connecting particle indices.
807#[derive(Debug, Clone, Serialize, Deserialize)]
808pub struct PbdConstraint {
809    /// Particle indices involved (1, 2 or 3).
810    pub particles: Vec<usize>,
811    /// Constraint type and parameters.
812    pub kind: PbdConstraintType,
813    /// Accumulated Lagrange multiplier (XPBD).
814    pub lambda: f64,
815}
816
817/// Position-based dynamics (XPBD) solver.
818#[derive(Debug, Clone, Serialize, Deserialize)]
819pub struct PyPbdSolver {
820    /// Current particle positions (flat \[x0, y0, z0, x1, ...\]).
821    pub positions: Vec<f64>,
822    /// Previous particle positions (for XPBD velocity estimation).
823    pub prev_positions: Vec<f64>,
824    /// Per-particle inverse mass (0 = kinematic/fixed).
825    pub inv_masses: Vec<f64>,
826    /// Constraints.
827    pub constraints: Vec<PbdConstraint>,
828    /// Number of substeps per time step.
829    pub substeps: u32,
830    /// Gravity vector.
831    pub gravity: [f64; 3],
832}
833
834impl PyPbdSolver {
835    /// Create a new PBD solver.
836    pub fn new(positions: Vec<f64>, inv_masses: Vec<f64>) -> Self {
837        let prev = positions.clone();
838        Self {
839            positions,
840            prev_positions: prev,
841            inv_masses,
842            constraints: Vec::new(),
843            substeps: 8,
844            gravity: [0.0, -9.81, 0.0],
845        }
846    }
847
848    /// Add a stretch constraint between two particles.
849    pub fn add_stretch(&mut self, i: usize, j: usize, compliance: f64) {
850        let dx = self.positions[3 * i] - self.positions[3 * j];
851        let dy = self.positions[3 * i + 1] - self.positions[3 * j + 1];
852        let dz = self.positions[3 * i + 2] - self.positions[3 * j + 2];
853        let rest = (dx * dx + dy * dy + dz * dz).sqrt();
854        self.constraints.push(PbdConstraint {
855            particles: vec![i, j],
856            kind: PbdConstraintType::Stretch {
857                rest_length: rest,
858                compliance,
859            },
860            lambda: 0.0,
861        });
862    }
863
864    /// Add a volume conservation constraint for a tetrahedron.
865    pub fn add_volume(&mut self, i: usize, j: usize, k: usize, l: usize, compliance: f64) {
866        self.constraints.push(PbdConstraint {
867            particles: vec![i, j, k, l],
868            kind: PbdConstraintType::Volume {
869                rest_volume: 1.0,
870                compliance,
871            },
872            lambda: 0.0,
873        });
874    }
875
876    /// Perform one substep of XPBD simulation.
877    pub fn substep(&mut self, dt: f64) {
878        let h = dt / self.substeps as f64;
879        let np = self.positions.len() / 3;
880        // Apply gravity and predict positions
881        for i in 0..np {
882            if self.inv_masses[i] > 0.0 {
883                let vx = self.positions[3 * i] - self.prev_positions[3 * i];
884                let vy = self.positions[3 * i + 1] - self.prev_positions[3 * i + 1];
885                let vz = self.positions[3 * i + 2] - self.prev_positions[3 * i + 2];
886                self.prev_positions[3 * i] = self.positions[3 * i];
887                self.prev_positions[3 * i + 1] = self.positions[3 * i + 1];
888                self.prev_positions[3 * i + 2] = self.positions[3 * i + 2];
889                self.positions[3 * i] += vx + self.gravity[0] * h * h;
890                self.positions[3 * i + 1] += vy + self.gravity[1] * h * h;
891                self.positions[3 * i + 2] += vz + self.gravity[2] * h * h;
892            }
893        }
894        // Reset lambdas
895        for c in &mut self.constraints {
896            c.lambda = 0.0;
897        }
898        // Solve stretch constraints (simplified)
899        let n_constraints = self.constraints.len();
900        for ci in 0..n_constraints {
901            if let PbdConstraintType::Stretch {
902                rest_length,
903                compliance,
904            } = self.constraints[ci].kind.clone()
905            {
906                let parts = self.constraints[ci].particles.clone();
907                if parts.len() < 2 {
908                    continue;
909                }
910                let (i, j) = (parts[0], parts[1]);
911                let wi = self.inv_masses[i];
912                let wj = self.inv_masses[j];
913                let w_sum = wi + wj;
914                if w_sum < 1e-15 {
915                    continue;
916                }
917                let dx = self.positions[3 * i] - self.positions[3 * j];
918                let dy = self.positions[3 * i + 1] - self.positions[3 * j + 1];
919                let dz = self.positions[3 * i + 2] - self.positions[3 * j + 2];
920                let len = (dx * dx + dy * dy + dz * dz).sqrt();
921                if len < 1e-12 {
922                    continue;
923                }
924                let alpha = compliance / (h * h);
925                let c_val = len - rest_length;
926                let d_lambda = (-c_val - alpha * self.constraints[ci].lambda) / (w_sum + alpha);
927                self.constraints[ci].lambda += d_lambda;
928                let nx = dx / len;
929                let ny = dy / len;
930                let nz = dz / len;
931                self.positions[3 * i] += wi * d_lambda * nx;
932                self.positions[3 * i + 1] += wi * d_lambda * ny;
933                self.positions[3 * i + 2] += wi * d_lambda * nz;
934                self.positions[3 * j] -= wj * d_lambda * nx;
935                self.positions[3 * j + 1] -= wj * d_lambda * ny;
936                self.positions[3 * j + 2] -= wj * d_lambda * nz;
937            }
938        }
939    }
940
941    /// Step the simulation for one full time step (using substeps).
942    pub fn step(&mut self, dt: f64) {
943        for _ in 0..self.substeps {
944            self.substep(dt);
945        }
946    }
947
948    /// Return particle count.
949    pub fn particle_count(&self) -> usize {
950        self.positions.len() / 3
951    }
952}
953
954// ---------------------------------------------------------------------------
955// PyControlSystem
956// ---------------------------------------------------------------------------
957
/// A simple state-space system: dx/dt = A x + B u, y = C x + D u.
///
/// All matrices are stored row-major in flat `Vec<f64>`s: element (i, j) of a
/// matrix with `c` columns lives at index `i * c + j`. Nothing in this type
/// checks that the vector lengths agree with `order`, `num_inputs` and
/// `num_outputs`; the methods index them assuming the sizes documented below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyControlSystem {
    /// System order (n), i.e. the number of state variables.
    pub order: usize,
    /// State matrix A (n×n, row-major flat).
    pub a_matrix: Vec<f64>,
    /// Input matrix B (n×m, row-major flat).
    pub b_matrix: Vec<f64>,
    /// Output matrix C (p×n, row-major flat).
    pub c_matrix: Vec<f64>,
    /// Feedthrough matrix D (p×m, row-major flat).
    pub d_matrix: Vec<f64>,
    /// Number of inputs (m).
    pub num_inputs: usize,
    /// Number of outputs (p).
    pub num_outputs: usize,
    /// Current state vector (n values).
    pub state: Vec<f64>,
}
978
979impl PyControlSystem {
980    /// Create a first-order lag system: dx/dt = -x/tau + u/tau.
981    pub fn first_order_lag(tau: f64) -> Self {
982        Self {
983            order: 1,
984            a_matrix: vec![-1.0 / tau],
985            b_matrix: vec![1.0 / tau],
986            c_matrix: vec![1.0],
987            d_matrix: vec![0.0],
988            num_inputs: 1,
989            num_outputs: 1,
990            state: vec![0.0],
991        }
992    }
993
994    /// Step the system using Euler integration.
995    pub fn step_euler(&mut self, u: &[f64], dt: f64) -> Vec<f64> {
996        let n = self.order;
997        let m = self.num_inputs;
998        let p = self.num_outputs;
999        // dx = A x + B u
1000        let mut dx = vec![0.0f64; n];
1001        for i in 0..n {
1002            for j in 0..n {
1003                dx[i] += self.a_matrix[i * n + j] * self.state[j];
1004            }
1005            for j in 0..m {
1006                dx[i] += self.b_matrix[i * m + j] * u.get(j).copied().unwrap_or(0.0);
1007            }
1008        }
1009        for i in 0..n {
1010            self.state[i] += dx[i] * dt;
1011        }
1012        // y = C x + D u
1013        let mut y = vec![0.0f64; p];
1014        for i in 0..p {
1015            for j in 0..n {
1016                y[i] += self.c_matrix[i * n + j] * self.state[j];
1017            }
1018            for j in 0..m {
1019                y[i] += self.d_matrix[i * m + j] * u.get(j).copied().unwrap_or(0.0);
1020            }
1021        }
1022        y
1023    }
1024
1025    /// Compute step response over `n_steps` time steps.
1026    pub fn step_response(&mut self, dt: f64, n_steps: usize) -> Vec<f64> {
1027        self.state = vec![0.0; self.order];
1028        let mut out = Vec::with_capacity(n_steps);
1029        for _ in 0..n_steps {
1030            let y = self.step_euler(&[1.0], dt);
1031            out.push(y[0]);
1032        }
1033        out
1034    }
1035
1036    /// Evaluate the frequency response (gain) at frequency f (Hz).
1037    pub fn frequency_gain(&self, f: f64) -> f64 {
1038        // For first-order system: |G(jw)| = 1/sqrt(1 + (w*tau)^2)
1039        // Generic approximation: just use the DC gain for now
1040        let w = 2.0 * std::f64::consts::PI * f;
1041        // For simple first-order lag: tau = -1/a[0]
1042        if self.order == 1 && self.a_matrix[0] < 0.0 {
1043            let tau = -1.0 / self.a_matrix[0];
1044            let b = self.b_matrix[0] * tau;
1045            b / (1.0 + (w * tau) * (w * tau)).sqrt()
1046        } else {
1047            1.0
1048        }
1049    }
1050}
1051
1052// ---------------------------------------------------------------------------
1053// PyFrictionModel
1054// ---------------------------------------------------------------------------
1055
/// Friction model type.
///
/// Selects how `PyFrictionModel::friction_limit` chooses the friction
/// coefficient for a contact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FrictionModelKind {
    /// Standard Coulomb friction; the coefficient is `PyFrictionModel::mu`.
    Coulomb,
    /// Anisotropic friction with different coefficients along the u (`mu_u`)
    /// and v (`mu_v`) tangent directions. `friction_limit` reduces this to the
    /// scalar bound `max(mu_u, mu_v)`.
    Anisotropic { mu_u: f64, mu_v: f64 },
    /// Velocity-dependent friction (Stribeck curve): the coefficient moves
    /// exponentially from `mu_static` at zero slip toward `mu_kinetic`.
    VelocityDependent {
        /// Coefficient at zero slip speed.
        mu_static: f64,
        /// Coefficient approached as the slip speed grows.
        mu_kinetic: f64,
        /// Characteristic slip speed of the exponential transition.
        stribeck_velocity: f64,
    },
    /// Friction cone (3D). `friction_limit` uses this variant's own `mu`
    /// rather than `PyFrictionModel::mu`.
    Cone { mu: f64 },
}
1072
/// Friction model for contact resolution.
///
/// Pairs a [`FrictionModelKind`] with a scalar coefficient and a
/// regularization parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyFrictionModel {
    /// Friction model type.
    pub kind: FrictionModelKind,
    /// Scalar friction coefficient.
    ///
    /// `friction_limit` reads it only for `Coulomb`; the `Cone` variant carries
    /// its own `mu`. The `anisotropic` constructor sets it to `max(mu_u, mu_v)`
    /// and the `stribeck` constructor sets it to `mu_static`.
    pub mu: f64,
    /// Regularization parameter to avoid singular Jacobians at zero slip.
    /// NOTE(review): not read by `friction_limit`; presumably consumed by the
    /// contact solver elsewhere — confirm.
    pub epsilon: f64,
}
1083
1084impl PyFrictionModel {
1085    /// Standard Coulomb friction.
1086    pub fn coulomb(mu: f64) -> Self {
1087        Self {
1088            kind: FrictionModelKind::Coulomb,
1089            mu,
1090            epsilon: 1e-4,
1091        }
1092    }
1093
1094    /// Anisotropic friction.
1095    pub fn anisotropic(mu_u: f64, mu_v: f64) -> Self {
1096        Self {
1097            kind: FrictionModelKind::Anisotropic { mu_u, mu_v },
1098            mu: mu_u.max(mu_v),
1099            epsilon: 1e-4,
1100        }
1101    }
1102
1103    /// Velocity-dependent (Stribeck) friction.
1104    pub fn stribeck(mu_static: f64, mu_kinetic: f64, stribeck_velocity: f64) -> Self {
1105        Self {
1106            kind: FrictionModelKind::VelocityDependent {
1107                mu_static,
1108                mu_kinetic,
1109                stribeck_velocity,
1110            },
1111            mu: mu_static,
1112            epsilon: 1e-4,
1113        }
1114    }
1115
1116    /// Compute the friction force limit given normal force and slip velocity.
1117    pub fn friction_limit(&self, normal_force: f64, slip_speed: f64) -> f64 {
1118        let mu = match &self.kind {
1119            FrictionModelKind::Coulomb => self.mu,
1120            FrictionModelKind::Cone { mu } => *mu,
1121            FrictionModelKind::Anisotropic { mu_u, mu_v } => mu_u.max(*mu_v),
1122            FrictionModelKind::VelocityDependent {
1123                mu_static,
1124                mu_kinetic,
1125                stribeck_velocity,
1126            } => {
1127                let alpha = (-slip_speed / stribeck_velocity).exp();
1128                mu_kinetic + (mu_static - mu_kinetic) * alpha
1129            }
1130        };
1131        mu * normal_force.max(0.0)
1132    }
1133}
1134
1135// ---------------------------------------------------------------------------
1136// PyWarmStart
1137// ---------------------------------------------------------------------------
1138
/// Cached impulse from the previous frame for warm starting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedImpulse {
    /// Body pair key (body_a, body_b). `PyWarmStart` compares this tuple
    /// exactly, so `(a, b)` and `(b, a)` are distinct entries.
    pub key: (u32, u32),
    /// Cached normal impulse.
    pub lambda_n: f64,
    /// Cached tangential impulse \[t1, t2\].
    pub lambda_t: [f64; 2],
    /// Frame counter (for aging): number of `PyWarmStart::age_and_prune` calls
    /// since this entry was last written by `PyWarmStart::store`, which resets
    /// it to 0.
    pub age: u32,
}
1151
/// Warm start cache for constraint impulses, keyed by body pair.
///
/// Entries are kept in a flat `Vec` and located by linear search on the key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyWarmStart {
    /// Cached impulses from the previous solve.
    pub cache: Vec<CachedImpulse>,
    /// Scale factor applied to every component of a cached impulse when it is
    /// retrieved (0.85 from `new`).
    pub aging_factor: f64,
    /// Minimum quality threshold to keep a cached impulse.
    /// NOTE(review): no method of this type reads it; confirm whether a caller does.
    pub quality_threshold: f64,
    /// Maximum age an entry may reach and still survive `age_and_prune`;
    /// entries whose age exceeds it are removed.
    pub max_age: u32,
}
1164
1165impl PyWarmStart {
1166    /// Create a new warm-start cache.
1167    pub fn new() -> Self {
1168        Self {
1169            cache: Vec::new(),
1170            aging_factor: 0.85,
1171            quality_threshold: 0.5,
1172            max_age: 3,
1173        }
1174    }
1175
1176    /// Store an impulse for the next frame.
1177    pub fn store(&mut self, key: (u32, u32), lambda_n: f64, lambda_t: [f64; 2]) {
1178        if let Some(c) = self.cache.iter_mut().find(|c| c.key == key) {
1179            c.lambda_n = lambda_n;
1180            c.lambda_t = lambda_t;
1181            c.age = 0;
1182        } else {
1183            self.cache.push(CachedImpulse {
1184                key,
1185                lambda_n,
1186                lambda_t,
1187                age: 0,
1188            });
1189        }
1190    }
1191
1192    /// Retrieve the cached impulse for a body pair, scaled by aging factor.
1193    pub fn retrieve(&self, key: (u32, u32)) -> Option<(f64, [f64; 2])> {
1194        self.cache.iter().find(|c| c.key == key).map(|c| {
1195            (
1196                c.lambda_n * self.aging_factor,
1197                [
1198                    c.lambda_t[0] * self.aging_factor,
1199                    c.lambda_t[1] * self.aging_factor,
1200                ],
1201            )
1202        })
1203    }
1204
1205    /// Age all cached impulses and remove expired ones.
1206    pub fn age_and_prune(&mut self) {
1207        for c in &mut self.cache {
1208            c.age += 1;
1209        }
1210        self.cache.retain(|c| c.age <= self.max_age);
1211    }
1212
1213    /// Clear all cached impulses.
1214    pub fn clear(&mut self) {
1215        self.cache.clear();
1216    }
1217}
1218
1219impl Default for PyWarmStart {
1220    fn default() -> Self {
1221        Self::new()
1222    }
1223}
1224
1225// ---------------------------------------------------------------------------
1226// Tests
1227// ---------------------------------------------------------------------------
1228
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_solve_pgs_basic() {
        let mut lam = vec![0.0f64];
        let rhs = [1.0];
        let diag = [2.0];
        let lo = [0.0];
        let hi = [10.0];
        let res = solve_pgs_iteration(&mut lam, &rhs, &diag, &lo, &hi);
        assert!(res >= 0.0);
        assert!(lam[0] > 0.0);
    }

    #[test]
    fn test_solve_pgs_clamped() {
        let mut lam = vec![0.0f64];
        let rhs = [100.0];
        let diag = [1.0];
        let lo = [0.0];
        let hi = [5.0];
        solve_pgs_iteration(&mut lam, &rhs, &diag, &lo, &hi);
        assert!(lam[0] <= 5.0);
    }

    #[test]
    fn test_compute_jacobian() {
        let j = compute_jacobian([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        assert_eq!(j.len(), 12);
        assert!((j[0] - 1.0).abs() < 1e-9);
        assert!((j[6] + 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_compute_effective_mass() {
        let j = compute_jacobian([0.0; 3], [0.0; 3], [1.0, 0.0, 0.0]);
        let em = compute_effective_mass(1.0, 1.0, [1.0; 3], [1.0; 3], j);
        assert!(em > 0.0);
    }

    #[test]
    fn test_clamp_impulse_within_cone() {
        let out = clamp_impulse(10.0, [0.5, 0.5], 0.8);
        let mag = (out[0] * out[0] + out[1] * out[1]).sqrt();
        assert!(mag <= 10.0 * 0.8 + 1e-9);
    }

    #[test]
    fn test_clamp_impulse_outside_cone() {
        let out = clamp_impulse(1.0, [5.0, 5.0], 0.5);
        let mag = (out[0] * out[0] + out[1] * out[1]).sqrt();
        assert!((mag - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_constraint_solver_pgs() {
        let solver = PyConstraintSolver::default_pgs();
        assert_eq!(solver.solver_type, SolverType::Pgs);
        let lambda = solver.solve(&[1.0, 1.0], &[2.0, 2.0], &[0.0, 0.0], &[10.0, 10.0]);
        assert_eq!(lambda.len(), 2);
        assert!(lambda[0] >= 0.0);
    }

    #[test]
    fn test_joint_fixed() {
        let j = PyJoint::fixed(0, 1, 2, [0.0; 3], [0.0; 3]);
        assert!(!j.is_breakable());
        assert!(j.enabled);
    }

    #[test]
    fn test_joint_spring() {
        let j = PyJoint::spring(1, 0, 1, [0.0; 3], [1.0, 0.0, 0.0], 100.0, 5.0, 1.0);
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(matches!(j.kind, JointKind::Spring { .. }));
    }

    #[test]
    fn test_contact_constraint_clamp_friction() {
        let mut c = PyContactConstraint::new(0, 1, [0.0, 1.0, 0.0], 0.01, [0.0; 3], 0.0, 0.5);
        c.lambda_n = 10.0;
        c.lambda_t1 = 8.0;
        c.lambda_t2 = 0.0;
        c.clamp_friction();
        assert!(c.lambda_t1 <= 5.0 + 1e-9);
    }

    #[test]
    fn test_contact_constraint_tangent_basis() {
        let c = PyContactConstraint::new(0, 1, [0.0, 1.0, 0.0], 0.01, [0.0; 3], 0.0, 0.5);
        let (t1, t2) = c.tangent_basis();
        // t1 and t2 should be perpendicular to normal
        let dot1 = t1[0] * 0.0 + t1[1] * 1.0 + t1[2] * 0.0;
        assert!(dot1.abs() < 1e-9);
        let dot2 = t2[0] * 0.0 + t2[1] * 1.0 + t2[2] * 0.0;
        assert!(dot2.abs() < 1e-9);
    }

    #[test]
    fn test_pid_gains() {
        let mut pid = PidGains::new(1.0, 0.0, 0.0);
        let out = pid.compute(2.0, 0.01);
        assert!((out - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_pid_reset() {
        let mut pid = PidGains::new(1.0, 1.0, 0.0);
        pid.compute(1.0, 0.1);
        pid.reset();
        assert_eq!(pid.integral, 0.0);
    }

    #[test]
    fn test_motor_thermal() {
        let mut th = MotorThermal::new();
        assert!((th.torque_scale() - 1.0).abs() < 1e-9);
        th.temperature = 100.0;
        assert!(th.torque_scale() < 1.0);
        th.temperature = 120.0;
        assert!(th.torque_scale() <= 0.0);
    }

    #[test]
    fn test_motor_constraint_step() {
        let mut motor = PyMotorConstraint::velocity_mode(10.0, 100.0, 2.0);
        let torque = motor.step(0.0, 0.01);
        assert!(torque.abs() > 0.0);
    }

    #[test]
    fn test_island_sleep() {
        let mut island = PyIsland::new(0);
        island.add_body(1);
        assert!(island.is_active());
        for _ in 0..60 {
            island.update_sleep(0.01); // below threshold
        }
        assert_eq!(island.sleep_state, IslandSleepState::Sleeping);
        island.update_sleep(1.0); // wake up
        assert_eq!(island.sleep_state, IslandSleepState::Awake);
    }

    #[test]
    fn test_island_merge() {
        let mut a = PyIsland::new(0);
        a.add_body(1);
        let mut b = PyIsland::new(1);
        b.add_body(2);
        a.merge(&b);
        assert_eq!(a.bodies.len(), 2);
    }

    #[test]
    fn test_ccd_result_miss() {
        let r = PyCCDResult::miss(0, 1);
        assert!(!r.hit);
        assert!((r.toi - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_ccd_result_hit() {
        let r = PyCCDResult::hit(0, 1, 0.3, [0.0, 1.0, 0.0], [0.0; 3], [0.0, 0.01, 0.0]);
        assert!(r.hit);
        assert!((r.toi - 0.3).abs() < 1e-9);
    }

    #[test]
    fn test_pbd_solver_stretch() {
        let positions = vec![0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
        let inv_masses = vec![1.0, 1.0];
        let mut solver = PyPbdSolver::new(positions, inv_masses);
        solver.add_stretch(0, 1, 1e-4);
        solver.step(0.016);
        // Particles should have moved due to gravity
        assert!(solver.positions[1] < 0.0); // y should be negative
    }

    #[test]
    fn test_pbd_particle_count() {
        let pos = vec![0.0; 9]; // 3 particles
        let inv = vec![1.0; 3];
        let s = PyPbdSolver::new(pos, inv);
        assert_eq!(s.particle_count(), 3);
    }

    #[test]
    fn test_control_system_step_response() {
        let mut sys = PyControlSystem::first_order_lag(0.1);
        let resp = sys.step_response(0.01, 50);
        assert_eq!(resp.len(), 50);
        // Should converge toward 1.0
        assert!(*resp.last().unwrap() > 0.5);
    }

    #[test]
    fn test_control_system_frequency() {
        let sys = PyControlSystem::first_order_lag(0.1);
        let dc = sys.frequency_gain(0.0);
        assert!(dc > 0.0);
        let hf = sys.frequency_gain(100.0);
        assert!(hf < dc);
    }

    #[test]
    fn test_friction_model_coulomb() {
        let fm = PyFrictionModel::coulomb(0.5);
        assert!((fm.friction_limit(10.0, 0.0) - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_friction_model_stribeck() {
        let fm = PyFrictionModel::stribeck(1.0, 0.5, 0.1);
        let f_static = fm.friction_limit(10.0, 0.0);
        let f_kinetic = fm.friction_limit(10.0, 1.0);
        assert!(f_static >= f_kinetic);
    }

    #[test]
    fn test_warm_start_store_retrieve() {
        let mut ws = PyWarmStart::new();
        ws.store((0, 1), 5.0, [1.0, 2.0]);
        let r = ws.retrieve((0, 1));
        assert!(r.is_some());
        let (ln, _lt) = r.unwrap();
        assert!((ln - 5.0 * 0.85).abs() < 1e-9);
    }

    #[test]
    fn test_warm_start_age_prune() {
        let mut ws = PyWarmStart::new();
        ws.store((0, 1), 5.0, [0.0, 0.0]);
        for _ in 0..=ws.max_age {
            ws.age_and_prune();
        }
        assert!(ws.retrieve((0, 1)).is_none());
    }

    #[test]
    fn test_warm_start_clear() {
        let mut ws = PyWarmStart::new();
        ws.store((0, 1), 1.0, [0.0, 0.0]);
        ws.clear();
        assert!(ws.cache.is_empty());
    }
}