Skip to main content

gizmo_physics_rigid/joints/
solver.rs

1use super::data::*;
2use gizmo_physics_core::components::Transform;
3use crate::components::{RigidBody, Velocity};
4use gizmo_math::Vec3;
5
/// Iterative (sequential-impulse) solver for joint constraints between rigid bodies.
///
/// Constraints are enforced by correcting body velocities; positional drift is fed back
/// as a bias velocity `position_bias * error / dt`, limited by the `max_*_speed` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct JointSolver {
    /// Number of velocity iterations run per step for constraint joints.
    pub iterations: usize,
    /// Upper bound on the linear bias velocity (world units per second) used to pull
    /// separated anchors back together.
    pub max_correction_speed: f32,
    /// Upper bound on the angular bias velocity (radians per second) used to correct
    /// orientation errors.
    pub max_angular_speed: f32,
    /// Fraction of the positional error corrected per step (Baumgarte factor).
    pub position_bias: f32,
}

impl Default for JointSolver {
    /// 10 iterations, a bias factor of 0.3, and bias velocities limited to 5 units/s
    /// (linear) and 5 rad/s (angular).
    fn default() -> Self {
        Self {
            iterations: 10,
            max_correction_speed: 5.0,
            max_angular_speed: 5.0,
            position_bias: 0.3,
        }
    }
}
23
24impl JointSolver {
25    pub fn new(iterations: usize) -> Self {
26        Self {
27            iterations,
28            ..Default::default()
29        }
30    }
31
32    pub fn solve_joints(
33        &self,
34        joints: &mut [Joint],
35        entity_index_map: &std::collections::HashMap<u32, usize>,
36        rigid_bodies: &[RigidBody],
37        transforms: &[Transform],
38        velocities: &mut [Velocity],
39        dt: f32,
40    ) {
41        for _ in 0..self.iterations {
42            for joint in joints.iter_mut() {
43                if joint.is_broken {
44                    continue;
45                }
46
47                let idx_a = entity_index_map.get(&joint.entity_a.id()).copied();
48                let idx_b = entity_index_map.get(&joint.entity_b.id()).copied();
49                let (Some(idx_a), Some(idx_b)) = (idx_a, idx_b) else {
50                    continue;
51                };
52                if idx_a == idx_b {
53                    continue;
54                }
55
56                match joint.joint_type() {
57                    "Fixed" => self.solve_fixed_joint(
58                        joint,
59                        rigid_bodies,
60                        transforms,
61                        velocities,
62                        idx_a,
63                        idx_b,
64                        dt,
65                    ),
66                    "Hinge" => self.solve_hinge_joint(
67                        joint,
68                        rigid_bodies,
69                        transforms,
70                        velocities,
71                        idx_a,
72                        idx_b,
73                        dt,
74                    ),
75                    "BallSocket" => self.solve_ball_socket_joint(
76                        joint,
77                        rigid_bodies,
78                        transforms,
79                        velocities,
80                        idx_a,
81                        idx_b,
82                        dt,
83                    ),
84                    "Slider" => self.solve_slider_joint(
85                        joint,
86                        rigid_bodies,
87                        transforms,
88                        velocities,
89                        idx_a,
90                        idx_b,
91                        dt,
92                    ),
93                    "Spring" => self.solve_spring_joint(
94                        joint,
95                        rigid_bodies,
96                        transforms,
97                        velocities,
98                        idx_a,
99                        idx_b,
100                        dt,
101                    ),
102                    _ => {}
103                }
104            }
105        }
106    }
107
108    // ── helpers ──────────────────────────────────────────────────────────────
109
110    /// Two unit vectors perpendicular to `v`.
111    fn perpendiculars(v: Vec3) -> (Vec3, Vec3) {
112        let p1 = if v.x.abs() < 0.9 {
113            v.cross(Vec3::X).normalize()
114        } else {
115            v.cross(Vec3::Y).normalize()
116        };
117        (p1, v.cross(p1))
118    }
119
120    /// Apply a 1-DOF angular velocity constraint along `direction`.
121    /// `error` is the positional error in radians (positive = bodies need to rotate apart).
122    fn apply_angular_constraint(
123        &self,
124        rigid_bodies: &[RigidBody],
125        transforms: &[Transform],
126        velocities: &mut [Velocity],
127        idx_a: usize,
128        idx_b: usize,
129        direction: Vec3,
130        error: f32,
131        dt: f32,
132        lambda_min: f32,
133        lambda_max: f32,
134    ) -> f32 {
135        if direction.length_squared() < 1e-10 {
136            return 0.0;
137        }
138
139        let inv_i_a = rigid_bodies[idx_a].inv_world_inertia_tensor(transforms[idx_a].rotation);
140        let inv_i_b = rigid_bodies[idx_b].inv_world_inertia_tensor(transforms[idx_b].rotation);
141        let w_a = velocities[idx_a].angular;
142        let w_b = velocities[idx_b].angular;
143        let dyn_a = rigid_bodies[idx_a].is_dynamic();
144        let dyn_b = rigid_bodies[idx_b].is_dynamic();
145
146        let k = direction.dot(inv_i_a.mul_vec3(direction)) + direction.dot(inv_i_b.mul_vec3(direction));
147        if k < 1e-10 {
148            return 0.0;
149        }
150
151        let vel_err = (w_b - w_a).dot(direction);
152        let position_bias = (self.position_bias * error / dt)
153            .clamp(-self.max_angular_speed, self.max_angular_speed);
154        let lambda = ((-vel_err + position_bias) / k).clamp(lambda_min, lambda_max);
155
156        let delta_a = inv_i_a.mul_vec3(direction) * lambda;
157        let delta_b = inv_i_b.mul_vec3(direction) * lambda;
158
159        if idx_a < idx_b {
160            let (l, r) = velocities.split_at_mut(idx_b);
161            if dyn_a {
162                l[idx_a].angular -= delta_a;
163            }
164            if dyn_b {
165                r[0].angular += delta_b;
166            }
167        } else {
168            let (l, r) = velocities.split_at_mut(idx_a);
169            if dyn_b {
170                l[idx_b].angular += delta_b;
171            }
172            if dyn_a {
173                r[0].angular -= delta_a;
174            }
175        }
176        lambda
177    }
178
179    /// Apply a 1-DOF linear velocity constraint along `direction` at the anchor points.
180    fn apply_linear_constraint(
181        &self,
182        rigid_bodies: &[RigidBody],
183        transforms: &[Transform],
184        velocities: &mut [Velocity],
185        idx_a: usize,
186        idx_b: usize,
187        direction: Vec3,
188        r_a: Vec3,
189        r_b: Vec3,
190        error: f32,
191        dt: f32,
192        lambda_min: f32,
193        lambda_max: f32,
194    ) -> f32 {
195        let inv_m_a = rigid_bodies[idx_a].inv_mass();
196        let inv_m_b = rigid_bodies[idx_b].inv_mass();
197        let inv_i_a = rigid_bodies[idx_a].inv_world_inertia_tensor(transforms[idx_a].rotation);
198        let inv_i_b = rigid_bodies[idx_b].inv_world_inertia_tensor(transforms[idx_b].rotation);
199        let v_a = velocities[idx_a].linear + velocities[idx_a].angular.cross(r_a);
200        let v_b = velocities[idx_b].linear + velocities[idx_b].angular.cross(r_b);
201        let dyn_a = rigid_bodies[idx_a].is_dynamic();
202        let dyn_b = rigid_bodies[idx_b].is_dynamic();
203
204        let ang_a = (inv_i_a.mul_vec3(r_a).cross(direction)).cross(r_a);
205        let ang_b = (inv_i_b.mul_vec3(r_b).cross(direction)).cross(r_b);
206        let k = inv_m_a + inv_m_b + ang_a.dot(direction) + ang_b.dot(direction);
207        if k < 1e-10 {
208            return 0.0;
209        }
210
211        let rel_vel = (v_b - v_a).dot(direction);
212        let position_bias = (self.position_bias * error / dt)
213            .clamp(-self.max_correction_speed, self.max_correction_speed);
214        let lambda = ((-rel_vel + position_bias) / k).clamp(lambda_min, lambda_max);
215
216        let impulse = direction * lambda;
217
218        if idx_a < idx_b {
219            let (l, r) = velocities.split_at_mut(idx_b);
220            if dyn_a {
221                l[idx_a].linear -= impulse * inv_m_a;
222                l[idx_a].angular -= inv_i_a.mul_vec3(r_a.cross(impulse));
223            }
224            if dyn_b {
225                r[0].linear += impulse * inv_m_b;
226                r[0].angular += inv_i_b.mul_vec3(r_b.cross(impulse));
227            }
228        } else {
229            let (l, r) = velocities.split_at_mut(idx_a);
230            if dyn_b {
231                l[idx_b].linear += impulse * inv_m_b;
232                l[idx_b].angular += inv_i_b.mul_vec3(r_b.cross(impulse));
233            }
234            if dyn_a {
235                r[0].linear -= impulse * inv_m_a;
236                r[0].angular -= inv_i_a.mul_vec3(r_a.cross(impulse));
237            }
238        }
239        lambda
240    }
241
242    // ── joint solvers ─────────────────────────────────────────────────────────
243
244    fn solve_fixed_joint(
245        &self,
246        joint: &mut Joint,
247        rigid_bodies: &[RigidBody],
248        transforms: &[Transform],
249        velocities: &mut [Velocity],
250        idx_a: usize,
251        idx_b: usize,
252        dt: f32,
253    ) {
254        let anchor_a =
255            transforms[idx_a].position + transforms[idx_a].rotation * joint.local_anchor_a;
256        let anchor_b =
257            transforms[idx_b].position + transforms[idx_b].rotation * joint.local_anchor_b;
258        let error = anchor_a - anchor_b; // target = a, current = b, so error = a - b
259        let err_len = error.length();
260
261        if err_len < 0.0001 {
262            return;
263        }
264
265        let r_a = anchor_a - transforms[idx_a].position;
266        let r_b = anchor_b - transforms[idx_b].position;
267
268        let max_impulse = f32::MAX;
269        let min_impulse = f32::MIN;
270
271        let mut impulse_sum = 0.0;
272        impulse_sum += self
273            .apply_linear_constraint(
274                rigid_bodies,
275                transforms,
276                velocities,
277                idx_a,
278                idx_b,
279                Vec3::new(1.0, 0.0, 0.0),
280                r_a,
281                r_b,
282                error.x,
283                dt,
284                min_impulse,
285                max_impulse,
286            )
287            .abs();
288        impulse_sum += self
289            .apply_linear_constraint(
290                rigid_bodies,
291                transforms,
292                velocities,
293                idx_a,
294                idx_b,
295                Vec3::new(0.0, 1.0, 0.0),
296                r_a,
297                r_b,
298                error.y,
299                dt,
300                min_impulse,
301                max_impulse,
302            )
303            .abs();
304        impulse_sum += self
305            .apply_linear_constraint(
306                rigid_bodies,
307                transforms,
308                velocities,
309                idx_a,
310                idx_b,
311                Vec3::new(0.0, 0.0, 1.0),
312                r_a,
313                r_b,
314                error.z,
315                dt,
316                min_impulse,
317                max_impulse,
318            )
319            .abs();
320
321        if impulse_sum / dt > joint.break_force {
322            joint.is_broken = true;
323        }
324    }
325
326    fn solve_hinge_joint(
327        &self,
328        joint: &mut Joint,
329        rigid_bodies: &[RigidBody],
330        transforms: &[Transform],
331        velocities: &mut [Velocity],
332        idx_a: usize,
333        idx_b: usize,
334        dt: f32,
335    ) {
336        // 1. Position constraint — keep anchor points together
337        self.solve_fixed_joint(
338            joint,
339            rigid_bodies,
340            transforms,
341            velocities,
342            idx_a,
343            idx_b,
344            dt,
345        );
346
347        let JointData::Hinge(ref mut data) = joint.data else {
348            return;
349        };
350
351        let rot_a = transforms[idx_a].rotation;
352        let rot_b = transforms[idx_b].rotation;
353        let axis_a = rot_a * data.axis;
354        let axis_b = rot_b * data.axis;
355
356        // 2. Angular constraint — keep hinge axes aligned (2 DOF)
357        let ang_err = axis_a.cross(axis_b);
358        let err_mag = ang_err.length();
359        let mut total_ang_impulse = 0.0;
360        if err_mag > 1e-6 {
361            let err_dir = ang_err / err_mag;
362            total_ang_impulse += self
363                .apply_angular_constraint(
364                    rigid_bodies,
365                    transforms,
366                    velocities,
367                    idx_a,
368                    idx_b,
369                    err_dir,
370                    -err_mag,
371                    dt,
372                    f32::NEG_INFINITY,
373                    f32::INFINITY,
374                )
375                .abs();
376        }
377
378        // 3. Track current angle
379        let ref_local = if data.axis.cross(Vec3::X).length() > 0.1 {
380            data.axis.cross(Vec3::X).normalize()
381        } else {
382            data.axis.cross(Vec3::Y).normalize()
383        };
384
385        let rot_a = transforms[idx_a].rotation;
386        let rot_b = transforms[idx_b].rotation;
387        let axis_w = rot_a * data.axis;
388        let ref_a_w = rot_a * ref_local;
389        let ref_b_w = rot_b * ref_local;
390
391        let proj_a = (ref_a_w - axis_w * ref_a_w.dot(axis_w)).normalize_or_zero();
392        let proj_b = (ref_b_w - axis_w * ref_b_w.dot(axis_w)).normalize_or_zero();
393
394        if proj_a.length_squared() > 0.01 && proj_b.length_squared() > 0.01 {
395            let cos_a = proj_a.dot(proj_b).clamp(-1.0, 1.0);
396            let sign = if proj_a.cross(proj_b).dot(axis_w) >= 0.0 {
397                1.0_f32
398            } else {
399                -1.0
400            };
401            data.current_angle = sign * cos_a.acos();
402
403            // 4. Angle limits
404            if data.use_limits {
405                if data.current_angle < data.lower_limit {
406                    let err = data.lower_limit - data.current_angle;
407                    // axis_w points from A to B; positive lambda increases angle
408                    total_ang_impulse += self
409                        .apply_angular_constraint(
410                            rigid_bodies,
411                            transforms,
412                            velocities,
413                            idx_a,
414                            idx_b,
415                            axis_w,
416                            err,
417                            dt,
418                            0.0,
419                            f32::INFINITY,
420                        )
421                        .abs();
422                } else if data.current_angle > data.upper_limit {
423                    let err = data.upper_limit - data.current_angle; // negative
424                    total_ang_impulse += self
425                        .apply_angular_constraint(
426                            rigid_bodies,
427                            transforms,
428                            velocities,
429                            idx_a,
430                            idx_b,
431                            axis_w,
432                            err,
433                            dt,
434                            f32::NEG_INFINITY,
435                            0.0,
436                        )
437                        .abs();
438                }
439            }
440        }
441
442        if total_ang_impulse / dt > joint.break_torque {
443            joint.is_broken = true;
444            return;
445        }
446
447        // 5. Motor — velocity constraint along hinge axis
448        if data.use_motor {
449            let axis_w = transforms[idx_a].rotation * data.axis;
450            let inv_i_a = rigid_bodies[idx_a].inv_world_inertia_tensor(transforms[idx_a].rotation);
451            let inv_i_b = rigid_bodies[idx_b].inv_world_inertia_tensor(transforms[idx_b].rotation);
452            let w_a = velocities[idx_a].angular;
453            let w_b = velocities[idx_b].angular;
454            let dyn_a = rigid_bodies[idx_a].is_dynamic();
455            let dyn_b = rigid_bodies[idx_b].is_dynamic();
456
457            let k = axis_w.dot(inv_i_a.mul_vec3(axis_w)) + axis_w.dot(inv_i_b.mul_vec3(axis_w));
458            if k > 1e-10 {
459                let rel_vel = (w_b - w_a).dot(axis_w);
460                let vel_err = data.motor_target_velocity - rel_vel;
461                let max_impulse = data.motor_max_force * dt;
462                let lambda = (vel_err / k).clamp(-max_impulse, max_impulse);
463
464                let delta_a = inv_i_a.mul_vec3(axis_w) * lambda;
465                let delta_b = inv_i_b.mul_vec3(axis_w) * lambda;
466
467                if idx_a < idx_b {
468                    let (l, r) = velocities.split_at_mut(idx_b);
469                    if dyn_a {
470                        l[idx_a].angular -= delta_a;
471                    }
472                    if dyn_b {
473                        r[0].angular += delta_b;
474                    }
475                } else {
476                    let (l, r) = velocities.split_at_mut(idx_a);
477                    if dyn_b {
478                        l[idx_b].angular += delta_b;
479                    }
480                    if dyn_a {
481                        r[0].angular -= delta_a;
482                    }
483                }
484            }
485        }
486    }
487
488    fn solve_ball_socket_joint(
489        &self,
490        joint: &mut Joint,
491        rigid_bodies: &[RigidBody],
492        transforms: &[Transform],
493        velocities: &mut [Velocity],
494        idx_a: usize,
495        idx_b: usize,
496        dt: f32,
497    ) {
498        // 1. Position constraint
499        self.solve_fixed_joint(
500            joint,
501            rigid_bodies,
502            transforms,
503            velocities,
504            idx_a,
505            idx_b,
506            dt,
507        );
508
509        let JointData::BallSocket(ref mut data) = joint.data else {
510            return;
511        };
512        if !data.use_cone_limit {
513            return;
514        }
515
516        // 2. Initialise reference rotation on first solve
517        let relative_rot = transforms[idx_a].rotation.inverse() * transforms[idx_b].rotation;
518        let initial_rot = match data.initial_relative_rotation {
519            None => {
520                data.initial_relative_rotation = Some(relative_rot);
521                return;
522            }
523            Some(rot) => rot,
524        };
525
526        // Compute the "swing" rotation of B away from its initial orientation (in A's frame)
527        let swing_quat = initial_rot.inverse() * relative_rot;
528
529        // Small-angle: angular error ≈ 2 * quat.xyz (when w ≥ 0)
530        let swing_err_local = if swing_quat.w >= 0.0 {
531            Vec3::new(swing_quat.x, swing_quat.y, swing_quat.z) * 2.0
532        } else {
533            -Vec3::new(swing_quat.x, swing_quat.y, swing_quat.z) * 2.0
534        };
535
536        let swing_angle = swing_err_local.length();
537        if swing_angle <= data.cone_limit_angle || swing_angle < 1e-6 {
538            return;
539        }
540
541        let excess = swing_angle - data.cone_limit_angle;
542        let swing_dir_local = swing_err_local / swing_angle;
543
544        // Convert error direction to world space
545        let swing_dir_world = transforms[idx_a].rotation * swing_dir_local;
546
547        let mut total_ang_impulse = 0.0;
548        total_ang_impulse += self
549            .apply_angular_constraint(
550                rigid_bodies,
551                transforms,
552                velocities,
553                idx_a,
554                idx_b,
555                swing_dir_world,
556                -excess,
557                dt,
558                f32::NEG_INFINITY,
559                0.0,
560            )
561            .abs();
562        if total_ang_impulse / dt > joint.break_torque {
563            joint.is_broken = true;
564        }
565    }
566
567    fn solve_slider_joint(
568        &self,
569        joint: &mut Joint,
570        rigid_bodies: &[RigidBody],
571        transforms: &[Transform],
572        velocities: &mut [Velocity],
573        idx_a: usize,
574        idx_b: usize,
575        dt: f32,
576    ) {
577        let JointData::Slider(ref mut data) = joint.data else {
578            return;
579        };
580
581        let anchor_a =
582            transforms[idx_a].position + transforms[idx_a].rotation * joint.local_anchor_a;
583        let anchor_b =
584            transforms[idx_b].position + transforms[idx_b].rotation * joint.local_anchor_b;
585        let axis_w = (transforms[idx_a].rotation * data.axis).normalize();
586
587        let delta = anchor_b - anchor_a;
588        let along = delta.dot(axis_w);
589        let off_axis = anchor_a - (anchor_b - axis_w * along); // error = target - current
590
591        data.current_position = along;
592
593        let r_a = anchor_a - transforms[idx_a].position;
594        let r_b = anchor_b - transforms[idx_b].position;
595
596        let mut total_lin_impulse = 0.0;
597        let mut total_ang_impulse = 0.0;
598
599        // 1. Off-axis constraint: project onto two perpendicular directions
600        let (perp1, perp2) = Self::perpendiculars(axis_w);
601
602        let err1 = off_axis.dot(perp1);
603        if err1.abs() > 1e-4 {
604            total_lin_impulse += self
605                .apply_linear_constraint(
606                    rigid_bodies,
607                    transforms,
608                    velocities,
609                    idx_a,
610                    idx_b,
611                    perp1,
612                    r_a,
613                    r_b,
614                    err1,
615                    dt,
616                    f32::NEG_INFINITY,
617                    f32::INFINITY,
618                )
619                .abs();
620        }
621
622        let err2 = off_axis.dot(perp2);
623        if err2.abs() > 1e-4 {
624            total_lin_impulse += self
625                .apply_linear_constraint(
626                    rigid_bodies,
627                    transforms,
628                    velocities,
629                    idx_a,
630                    idx_b,
631                    perp2,
632                    r_a,
633                    r_b,
634                    err2,
635                    dt,
636                    f32::NEG_INFINITY,
637                    f32::INFINITY,
638                )
639                .abs();
640        }
641
642        // 2. Angular lock — full 3-DOF rotation constraint using quaternion error
643        let relative_rot = transforms[idx_a].rotation.inverse() * transforms[idx_b].rotation;
644        if let Some(initial_rot) = data.initial_relative_rotation {
645            let err_quat = initial_rot.inverse() * relative_rot;
646            let ang_err_local = if err_quat.w >= 0.0 {
647                Vec3::new(err_quat.x, err_quat.y, err_quat.z) * 2.0
648            } else {
649                -Vec3::new(err_quat.x, err_quat.y, err_quat.z) * 2.0
650            };
651
652            let err_world = transforms[idx_a].rotation * ang_err_local;
653            let err_mag = err_world.length();
654            if err_mag > 1e-6 {
655                total_ang_impulse += self
656                    .apply_angular_constraint(
657                        rigid_bodies,
658                        transforms,
659                        velocities,
660                        idx_a,
661                        idx_b,
662                        err_world / err_mag,
663                        -err_mag,
664                        dt,
665                        f32::NEG_INFINITY,
666                        f32::INFINITY,
667                    )
668                    .abs();
669            }
670        } else {
671            data.initial_relative_rotation = Some(relative_rot);
672        }
673
674        // 3. Along-axis limits
675        if data.use_limits {
676            if along < data.lower_limit {
677                let err = data.lower_limit - along;
678                total_lin_impulse += self
679                    .apply_linear_constraint(
680                        rigid_bodies,
681                        transforms,
682                        velocities,
683                        idx_a,
684                        idx_b,
685                        axis_w,
686                        r_a,
687                        r_b,
688                        err,
689                        dt,
690                        f32::NEG_INFINITY,
691                        0.0,
692                    )
693                    .abs();
694            } else if along > data.upper_limit {
695                let err = data.upper_limit - along; // negative
696                total_lin_impulse += self
697                    .apply_linear_constraint(
698                        rigid_bodies,
699                        transforms,
700                        velocities,
701                        idx_a,
702                        idx_b,
703                        axis_w,
704                        r_a,
705                        r_b,
706                        err,
707                        dt,
708                        0.0,
709                        f32::INFINITY,
710                    )
711                    .abs();
712            }
713        }
714
715        if total_lin_impulse / dt > joint.break_force || total_ang_impulse / dt > joint.break_torque
716        {
717            joint.is_broken = true;
718            return;
719        }
720
721        // 4. Motor — velocity along axis
722        if data.use_motor {
723            let max_impulse = data.motor_max_force * dt;
724
725            let v_a = velocities[idx_a].linear + velocities[idx_a].angular.cross(r_a);
726            let v_b = velocities[idx_b].linear + velocities[idx_b].angular.cross(r_b);
727            let rel_vel = (v_b - v_a).dot(axis_w);
728            let vel_err = data.motor_target_velocity - rel_vel;
729
730            let inv_m_a = rigid_bodies[idx_a].inv_mass();
731            let inv_m_b = rigid_bodies[idx_b].inv_mass();
732            let inv_i_a = rigid_bodies[idx_a].inv_world_inertia_tensor(transforms[idx_a].rotation);
733            let inv_i_b = rigid_bodies[idx_b].inv_world_inertia_tensor(transforms[idx_b].rotation);
734            let dyn_a = rigid_bodies[idx_a].is_dynamic();
735            let dyn_b = rigid_bodies[idx_b].is_dynamic();
736
737            let ang_a = (inv_i_a.mul_vec3(r_a).cross(axis_w)).cross(r_a);
738            let ang_b = (inv_i_b.mul_vec3(r_b).cross(axis_w)).cross(r_b);
739            let k = inv_m_a + inv_m_b + ang_a.dot(axis_w) + ang_b.dot(axis_w);
740            if k > 1e-10 {
741                let lambda = (vel_err / k).clamp(-max_impulse, max_impulse);
742                let impulse = axis_w * lambda;
743
744                if idx_a < idx_b {
745                    let (l, r) = velocities.split_at_mut(idx_b);
746                    if dyn_a {
747                        l[idx_a].linear -= impulse * inv_m_a;
748                        l[idx_a].angular -= inv_i_a.mul_vec3(r_a.cross(impulse));
749                    }
750                    if dyn_b {
751                        r[0].linear += impulse * inv_m_b;
752                        r[0].angular += inv_i_b.mul_vec3(r_b.cross(impulse));
753                    }
754                } else {
755                    let (l, r) = velocities.split_at_mut(idx_a);
756                    if dyn_b {
757                        l[idx_b].linear += impulse * inv_m_b;
758                        l[idx_b].angular += inv_i_b.mul_vec3(r_b.cross(impulse));
759                    }
760                    if dyn_a {
761                        r[0].linear -= impulse * inv_m_a;
762                        r[0].angular -= inv_i_a.mul_vec3(r_a.cross(impulse));
763                    }
764                }
765            }
766        }
767    }
768
769    fn solve_spring_joint(
770        &self,
771        joint: &Joint,
772        rigid_bodies: &[RigidBody],
773        transforms: &[Transform],
774        velocities: &mut [Velocity],
775        idx_a: usize,
776        idx_b: usize,
777        dt: f32,
778    ) {
779        let JointData::Spring(data) = joint.data else {
780            return;
781        };
782
783        let anchor_a =
784            transforms[idx_a].position + transforms[idx_a].rotation * joint.local_anchor_a;
785        let anchor_b =
786            transforms[idx_b].position + transforms[idx_b].rotation * joint.local_anchor_b;
787
788        let diff = anchor_b - anchor_a;
789        let length = diff.length();
790        if length < 1e-6 {
791            return;
792        }
793
794        let direction = diff / length;
795        let r_a = anchor_a - transforms[idx_a].position;
796        let r_b = anchor_b - transforms[idx_b].position;
797
798        let v_a = velocities[idx_a].linear + velocities[idx_a].angular.cross(r_a);
799        let v_b = velocities[idx_b].linear + velocities[idx_b].angular.cross(r_b);
800
801        // Force calculation
802        // direction points from A to B
803        let spring_force = data.stiffness * (length - data.rest_length); // Positive if stretched (pulls together)
804        let relative_vel = (v_b - v_a).dot(direction); // Positive if B is moving away from A
805        let damping_force = data.damping * relative_vel;
806
807        // Total force pulling them together
808        let pull_force = spring_force + damping_force;
809        let pull_impulse = pull_force * dt;
810
811        // Hard limits (optional max_length)
812        let clamped_impulse = if length <= data.min_length && pull_impulse > 0.0 {
813            0.0 // already at min length, stop pulling
814        } else if let Some(max_len) = data.max_length {
815            if length >= max_len && pull_impulse < 0.0 {
816                0.0 // already at max length, stop pushing apart
817            } else {
818                pull_impulse
819            }
820        } else {
821            pull_impulse
822        };
823
824        if clamped_impulse.abs() < 1e-10 {
825            return;
826        }
827
828        // Apply impulse along direction (A to B)
829        // If clamped_impulse > 0, they are pulled together: A moves to B (+), B moves to A (-)
830        let impulse = direction * clamped_impulse;
831        let inv_m_a = rigid_bodies[idx_a].inv_mass();
832        let inv_m_b = rigid_bodies[idx_b].inv_mass();
833        let inv_i_a = rigid_bodies[idx_a].inv_world_inertia_tensor(transforms[idx_a].rotation);
834        let inv_i_b = rigid_bodies[idx_b].inv_world_inertia_tensor(transforms[idx_b].rotation);
835        let dyn_a = rigid_bodies[idx_a].is_dynamic();
836        let dyn_b = rigid_bodies[idx_b].is_dynamic();
837
838        if idx_a < idx_b {
839            let (l, r) = velocities.split_at_mut(idx_b);
840            if dyn_a {
841                l[idx_a].linear += impulse * inv_m_a;
842                l[idx_a].angular += inv_i_a.mul_vec3(r_a.cross(impulse));
843            }
844            if dyn_b {
845                r[0].linear -= impulse * inv_m_b;
846                r[0].angular -= inv_i_b.mul_vec3(r_b.cross(impulse));
847            }
848        } else {
849            let (l, r) = velocities.split_at_mut(idx_a);
850            if dyn_b {
851                l[idx_b].linear -= impulse * inv_m_b;
852                l[idx_b].angular -= inv_i_b.mul_vec3(r_b.cross(impulse));
853            }
854            if dyn_a {
855                r[0].linear += impulse * inv_m_a;
856                r[0].angular += inv_i_a.mul_vec3(r_a.cross(impulse));
857            }
858        }
859    }
860}
861
#[cfg(test)]
mod tests {
    use super::*;
    use gizmo_core::entity::Entity;

    #[test]
    fn test_joint_creation() {
        let e1 = Entity::new(1, 0);
        let e2 = Entity::new(2, 0);
        let joint = Joint::fixed(e1, e2, Vec3::ZERO, Vec3::ZERO);
        assert_eq!(joint.joint_type(), "Fixed");
        assert!(!joint.is_broken);
    }

    #[test]
    fn test_hinge_joint() {
        let e1 = Entity::new(1, 0);
        let e2 = Entity::new(2, 0);
        let joint = Joint::hinge(e1, e2, Vec3::ZERO, Vec3::ZERO, Vec3::Y);
        assert_eq!(joint.joint_type(), "Hinge");
        if let JointData::Hinge(data) = joint.data {
            assert_eq!(data.axis, Vec3::Y);
        } else {
            panic!("expected hinge data");
        }
    }

    #[test]
    fn test_spring_joint() {
        let e1 = Entity::new(1, 0);
        let e2 = Entity::new(2, 0);
        let joint = Joint::spring(e1, e2, Vec3::ZERO, Vec3::ZERO, 1.0, 100.0, 10.0);
        if let JointData::Spring(data) = joint.data {
            assert_eq!(data.stiffness, 100.0);
            assert_eq!(data.damping, 10.0);
        } else {
            panic!("expected spring data");
        }
    }

    #[test]
    fn test_solver_new_keeps_default_tuning() {
        let solver = JointSolver::new(4);
        let defaults = JointSolver::default();
        assert_eq!(solver.iterations, 4);
        assert_eq!(solver.position_bias, defaults.position_bias);
        assert_eq!(solver.max_correction_speed, defaults.max_correction_speed);
        assert_eq!(solver.max_angular_speed, defaults.max_angular_speed);
    }

    #[test]
    fn test_solve_joints_with_no_joints_is_noop() {
        let solver = JointSolver::default();
        let map = std::collections::HashMap::new();
        let mut joints: Vec<Joint> = Vec::new();
        let mut velocities: Vec<Velocity> = Vec::new();
        solver.solve_joints(&mut joints, &map, &[], &[], &mut velocities, 1.0 / 60.0);
        solver.solve_joints(&mut joints, &map, &[], &[], &mut velocities, 0.0);
        assert!(joints.is_empty());
        assert!(velocities.is_empty());
    }

    #[test]
    fn test_perpendiculars_orthogonality() {
        let v = Vec3::new(1.0, 0.0, 0.0);
        let (p1, p2) = JointSolver::perpendiculars(v);
        assert!(p1.dot(v).abs() < 1e-5);
        assert!(p2.dot(v).abs() < 1e-5);
        assert!(p1.dot(p2).abs() < 1e-5);
    }

    #[test]
    fn test_perpendiculars_are_orthonormal_for_various_axes() {
        // Covers both branches (|v.x| < 0.9 and |v.x| >= 0.9) plus oblique axes.
        let axes = [
            Vec3::new(1.0, 0.0, 0.0),
            Vec3::new(0.0, 1.0, 0.0),
            Vec3::new(0.0, 0.0, 1.0),
            Vec3::new(1.0, 2.0, 3.0).normalize(),
            Vec3::new(-0.95, 0.3, 0.1).normalize(),
        ];
        for v in axes {
            let (p1, p2) = JointSolver::perpendiculars(v);
            assert!((p1.length() - 1.0).abs() < 1e-5);
            assert!((p2.length() - 1.0).abs() < 1e-5);
            assert!(p1.dot(v).abs() < 1e-5);
            assert!(p2.dot(v).abs() < 1e-5);
            assert!(p1.dot(p2).abs() < 1e-5);
        }
    }
}
910}