use arrayvec::ArrayVec;
use nalgebra::SVector;
use symtropy_math::Shape;
/// Maximum number of GJK refinement steps; when exhausted, `intersects`
/// reports the shapes as not intersecting.
const MAX_ITERATIONS: usize = 64;
/// Capacity of the simplex buffer. A full-dimensional simplex in `D`
/// dimensions has `D + 1` vertices, so it only fits for `D <= 4`;
/// `ArrayVec::push` panics when the buffer is already full.
const MAX_SIMPLEX: usize = 5;
/// Fixed-capacity (inline, non-heap) list of Minkowski-difference vertices.
type Simplex<const D: usize> = ArrayVec<SVector<f64, D>, MAX_SIMPLEX>;
/// Outcome of a GJK boolean intersection query.
#[derive(Clone, Debug)]
pub struct GjkResult<const D: usize> {
/// `true` when the origin was found inside (or on) the Minkowski difference
/// `A - B`. `false` is also reported when the iteration budget runs out;
/// touching configurations depend on the numerical tolerances used.
pub intersecting: bool,
/// Simplex of Minkowski-difference vertices (newest last) held when the
/// query concluded.
pub simplex: Simplex<D>,
/// Zero-based index of the iteration on which the query concluded, or
/// `MAX_ITERATIONS` if the iteration budget was exhausted.
pub iterations: usize,
}
pub fn intersects<const D: usize>(
shape_a: &dyn Shape<D>,
pos_a: &SVector<f64, D>,
shape_b: &dyn Shape<D>,
pos_b: &SVector<f64, D>,
) -> GjkResult<D> {
let mut direction = pos_b - pos_a;
if direction.norm_squared() < 1e-20 {
direction = SVector::zeros();
direction[0] = 1.0;
}
let first = minkowski_support(shape_a, pos_a, shape_b, pos_b, &direction);
let mut simplex = Simplex::new();
simplex.push(first);
direction = -first;
for iteration in 0..MAX_ITERATIONS {
if direction.norm_squared() < 1e-20 {
return GjkResult {
intersecting: true,
simplex,
iterations: iteration,
};
}
let new_point = minkowski_support(shape_a, pos_a, shape_b, pos_b, &direction);
if new_point.dot(&direction) < -1e-10 {
return GjkResult {
intersecting: false,
simplex,
iterations: iteration,
};
}
simplex.push(new_point);
if do_simplex(&mut simplex, &mut direction) {
return GjkResult {
intersecting: true,
simplex,
iterations: iteration,
};
}
}
GjkResult {
intersecting: false,
simplex,
iterations: MAX_ITERATIONS,
}
}
/// Support point of the Minkowski difference `A - B` in `direction`.
///
/// This is the farthest point of `A` (translated by `pos_a`) along
/// `direction`, minus the farthest point of `B` (translated by `pos_b`) along
/// `-direction`.
fn minkowski_support<const D: usize>(
    shape_a: &dyn Shape<D>,
    pos_a: &SVector<f64, D>,
    shape_b: &dyn Shape<D>,
    pos_b: &SVector<f64, D>,
    direction: &SVector<f64, D>,
) -> SVector<f64, D> {
    (shape_a.support(direction) + pos_a) - (shape_b.support(&-direction) + pos_b)
}
/// Updates `simplex` (newest vertex last) and the next search `direction`,
/// dispatching on the number of vertices.
///
/// Returns `true` when the origin has been found inside (or on) the simplex,
/// i.e. the shapes intersect.
///
/// - The line and triangle routines work in any dimension.
/// - The tetrahedron routine only tests faces within the span of its own
///   edges. It can therefore only establish containment when four points span
///   the whole space, i.e. for `D == 3`.
/// - Everything else goes to `do_simplex_general`. That includes four-vertex
///   simplices when `D >= 4` and full `D + 1`-vertex simplices. The general
///   routine actually tests containment rather than assuming it.
fn do_simplex<const D: usize>(
    simplex: &mut Simplex<D>,
    direction: &mut SVector<f64, D>,
) -> bool {
    match simplex.len() {
        2 => do_simplex_line(simplex, direction),
        3 => do_simplex_triangle(simplex, direction),
        4 if D == 3 => do_simplex_tetrahedron(simplex, direction),
        _ => do_simplex_general(simplex, direction),
    }
}
/// Handles a two-vertex simplex `[b, a]` (newest vertex `a` last).
///
/// There are two cases:
/// - If the origin lies in the slab beside the segment (toward `b` from `a`),
///   the next direction is the component of `ao` perpendicular to `ab`. If
///   that vanishes, the origin lies on the segment and `true` is returned.
/// - Otherwise only `a` is kept and the search heads straight at the origin.
fn do_simplex_line<const D: usize>(
    simplex: &mut Simplex<D>,
    direction: &mut SVector<f64, D>,
) -> bool {
    let newest = simplex[1];
    let older = simplex[0];
    let edge = older - newest;
    let to_origin = -newest;

    if edge.dot(&to_origin) > 0.0 {
        *direction = triple_cross_product(&edge, &to_origin, &edge);
        return direction.norm_squared() < 1e-20;
    }

    // The origin is behind the newest vertex: the older one is irrelevant.
    simplex.clear();
    simplex.push(newest);
    *direction = to_origin;
    false
}
/// Handles a three-vertex simplex `[c, b, a]` (newest vertex `a` last).
///
/// If the origin lies beyond edge `ab` or edge `ac` (measured within the
/// triangle's plane), the simplex is reduced to that edge and re-processed by
/// `do_simplex_line`. Edge `bc` is not tested. This presumably relies on the
/// usual GJK argument that `a` was found by searching toward the origin from
/// `bc`.
///
/// For `D == 2`, passing both edge tests returns `true` (origin enclosed). In
/// higher dimensions:
/// - The next search direction is the component of `ao` orthogonal to the
///   triangle's plane.
/// - If that component vanishes, the origin lies on the triangle and `true`
///   is returned.
/// - A degenerate (collinear) triangle is reduced to the segment `[b, a]`.
fn do_simplex_triangle<const D: usize>(
simplex: &mut Simplex<D>,
direction: &mut SVector<f64, D>,
) -> bool {
let a = simplex[2]; let b = simplex[1];
let c = simplex[0];
let ab = b - a;
let ac = c - a;
let ao = -a;
// In-plane vector perpendicular to `ab`, pointing away from `c`: the negated
// component of `ac` orthogonal to `ab`. Zero when `ab` is (near) zero length.
let ab_sq = ab.dot(&ab);
let ab_perp = if ab_sq > 1e-20 {
let proj = &ab * (ac.dot(&ab) / ab_sq);
-(ac - proj)
} else {
SVector::zeros()
};
// Origin is beyond edge `ab`: keep only that edge.
if ab_perp.dot(&ao) > 0.0 {
simplex.clear();
simplex.push(b);
simplex.push(a);
return do_simplex_line(simplex, direction);
}
// Same test for edge `ac`; this perpendicular points away from `b`.
let ac_sq = ac.dot(&ac);
let ac_perp = if ac_sq > 1e-20 {
let proj = &ac * (ab.dot(&ac) / ac_sq);
-(ab - proj)
} else {
SVector::zeros()
};
if ac_perp.dot(&ao) > 0.0 {
simplex.clear();
simplex.push(c);
simplex.push(a);
return do_simplex_line(simplex, direction);
}
// In the plane, passing both edge tests means the origin is enclosed.
if D == 2 {
return true; }
// Project `ao` onto span(ab, ac) by solving the 2x2 normal equations with
// Gram matrix [[e1e1, e1e2], [e1e2, e2e2]]; `det` is its determinant.
let e1e1 = ab.norm_squared();
let e2e2 = ac.norm_squared();
let e1e2 = ab.dot(&ac);
let det = e1e1 * e2e2 - e1e2 * e1e2;
// Collinear edges: no well-defined plane, so fall back to the segment [b, a].
if det.abs() < 1e-20 {
simplex.clear();
simplex.push(b);
simplex.push(a);
return do_simplex_line(simplex, direction);
}
let ao_e1 = ao.dot(&ab);
let ao_e2 = ao.dot(&ac);
let proj = &ab * ((ao_e1 * e2e2 - ao_e2 * e1e2) / det)
+ &ac * ((ao_e2 * e1e1 - ao_e1 * e1e2) / det);
// Component of `ao` orthogonal to the triangle's plane (shadows the
// `face_normal` function inside this block only).
let face_normal = &ao - &proj;
// The origin lies in the triangle's plane and inside both edge tests.
if face_normal.norm_squared() < 1e-20 {
return true; }
*direction = face_normal;
false
}
/// Handles a four-vertex simplex `[d, c, b, a]` (newest vertex `a` last).
///
/// The three faces that contain `a` are tested in the order `abc`, `acd`,
/// `adb`; face `bcd` is not tested. If the origin lies strictly outside one
/// of them, the simplex is reduced to that triangle and handed to
/// `do_simplex_triangle`. Otherwise `true` is returned.
///
/// The face normals produced by `face_normal` lie in the span of the three
/// edges. Returning `true` therefore only means the origin is enclosed when
/// the four points span the whole space, i.e. `D == 3`.
fn do_simplex_tetrahedron<const D: usize>(
    simplex: &mut Simplex<D>,
    direction: &mut SVector<f64, D>,
) -> bool {
    let [d, c, b, a] = [simplex[0], simplex[1], simplex[2], simplex[3]];
    let to_origin = -a;
    let (ab, ac, ad) = (b - a, c - a, d - a);

    // Each face through `a`, paired with its outward normal (pointing away from
    // the vertex not on the face) and its vertices, oldest first.
    let candidates = [
        (face_normal(&ab, &ac, &ad), [c, b, a]),
        (face_normal(&ac, &ad, &ab), [d, c, a]),
        (face_normal(&ad, &ab, &ac), [b, d, a]),
    ];

    for (normal, triangle) in candidates {
        if normal.dot(&to_origin) > 0.0 {
            simplex.clear();
            for vertex in triangle {
                simplex.push(vertex);
            }
            return do_simplex_triangle(simplex, direction);
        }
    }
    true
}
/// Outward normal of the face spanned by `edge1` and `edge2`, pointing away
/// from `opposite` (the edge to the vertex not on the face).
///
/// The result is `P(opposite) - opposite`, where `P` is the orthogonal
/// projection onto `span(edge1, edge2)`. It is solved through the 2x2 Gram
/// matrix of the edges. The vector is not normalised. It is zero when the
/// edges are (numerically) parallel, i.e. when the Gram determinant is below
/// `1e-20`.
fn face_normal<const D: usize>(
    edge1: &SVector<f64, D>,
    edge2: &SVector<f64, D>,
    opposite: &SVector<f64, D>,
) -> SVector<f64, D> {
    // Gram matrix entries: [[g11, g12], [g12, g22]].
    let g11 = edge1.norm_squared();
    let g22 = edge2.norm_squared();
    let g12 = edge1.dot(edge2);
    let gram_det = g11 * g22 - g12 * g12;
    if gram_det.abs() < 1e-20 {
        return SVector::zeros();
    }

    let r1 = opposite.dot(edge1);
    let r2 = opposite.dot(edge2);
    // Coefficients of the projection of `opposite` onto the face plane.
    let s = (r1 * g22 - r2 * g12) / gram_det;
    let t = (r2 * g11 - r1 * g12) / gram_det;
    let in_plane = edge1 * s + edge2 * t;
    -(opposite - in_plane)
}
fn do_simplex_general<const D: usize>(
simplex: &mut Simplex<D>,
direction: &mut SVector<f64, D>,
) -> bool {
let n = simplex.len();
let a = simplex[n - 1]; let ao = -a;
for i in 0..(n - 1) {
let mut face_vecs: ArrayVec<SVector<f64, D>, MAX_SIMPLEX> = ArrayVec::new();
for j in 0..(n - 1) {
if j != i {
face_vecs.push(simplex[j] - a);
}
}
if face_vecs.len() >= 1 {
let to_removed = simplex[i] - a;
let mut normal = to_removed;
for edge in &face_vecs {
let edge_nsq = edge.norm_squared();
if edge_nsq > 1e-20 {
let proj = edge * (normal.dot(edge) / edge_nsq);
normal -= proj;
}
}
let outward = -normal;
if outward.dot(&ao) > 0.0 {
simplex.remove(i);
*direction = outward;
return false;
}
}
}
true
}
/// Computes `b (a · c) - a (b · c)`.
///
/// In three dimensions this equals the vector triple product `(a × b) × c`,
/// and the expression is well defined in any dimension. With `a = c = ab` and
/// `b = ao` it gives `|ab|^2` times the component of `ao` perpendicular to
/// `ab`.
#[inline]
fn triple_cross_product<const D: usize>(
    a: &SVector<f64, D>,
    b: &SVector<f64, D>,
    c: &SVector<f64, D>,
) -> SVector<f64, D> {
    let a_dot_c = a.dot(c);
    let b_dot_c = b.dot(c);
    b * a_dot_c - a * b_dot_c
}
// Example-based checks of `intersects` on spheres and convex hulls in 2-D,
// 3-D and 4-D.
#[cfg(test)]
mod tests {
use super::*;
use symtropy_math::{ConvexHull, Point, Sphere};
// Unit spheres with centres 1.0 apart: distance < 2r, so they overlap.
#[test]
fn spheres_overlapping_3d() {
let a = Sphere::<3>::new(Point::new([0.0, 0.0, 0.0]), 1.0);
let b = Sphere::<3>::new(Point::origin(), 1.0);
let pa = SVector::from([0.0, 0.0, 0.0]);
let pb = SVector::from([1.0, 0.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(result.intersecting);
}
// Unit spheres with centres 3.0 apart: distance > 2r, so they are separated.
#[test]
fn spheres_separated_3d() {
let a = Sphere::<3>::new(Point::origin(), 1.0);
let b = Sphere::<3>::new(Point::origin(), 1.0);
let pa = SVector::from([0.0, 0.0, 0.0]);
let pb = SVector::from([3.0, 0.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(!result.intersecting);
}
// Exactly touching unit spheres (distance == 2r). Only termination before the
// iteration cap is asserted; `intersecting` is deliberately left unchecked,
// presumably because contact is tolerance-sensitive.
#[test]
fn spheres_touching_3d() {
let a = Sphere::<3>::new(Point::origin(), 1.0);
let b = Sphere::<3>::new(Point::origin(), 1.0);
let pa = SVector::from([0.0, 0.0, 0.0]);
let pb = SVector::from([2.0, 0.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(result.iterations < MAX_ITERATIONS);
}
// NOTE(review): the box tests assume `unit_cube()` is large enough that a 1.0
// offset overlaps and a 5.0 offset separates — confirm against symtropy_math.
#[test]
fn boxes_overlapping_2d() {
let a = ConvexHull::<2>::unit_cube();
let b = ConvexHull::<2>::unit_cube();
let pa = SVector::from([0.0, 0.0]);
let pb = SVector::from([1.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(result.intersecting);
}
#[test]
fn boxes_separated_2d() {
let a = ConvexHull::<2>::unit_cube();
let b = ConvexHull::<2>::unit_cube();
let pa = SVector::from([0.0, 0.0]);
let pb = SVector::from([5.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(!result.intersecting);
}
// 4-D variants exercise the general-dimension simplex handling.
#[test]
fn boxes_overlapping_4d() {
let a = ConvexHull::<4>::unit_cube(); let b = ConvexHull::<4>::unit_cube();
let pa = SVector::from([0.0, 0.0, 0.0, 0.0]);
let pb = SVector::from([1.0, 0.0, 0.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(result.intersecting);
}
#[test]
fn boxes_separated_4d() {
let a = ConvexHull::<4>::unit_cube();
let b = ConvexHull::<4>::unit_cube();
let pa = SVector::from([0.0, 0.0, 0.0, 0.0]);
let pb = SVector::from([5.0, 0.0, 0.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
assert!(!result.intersecting);
}
// Mixed shapes: a unit sphere against the unit cube at offset 1.5 (overlap).
#[test]
fn sphere_vs_box_3d() {
let sphere = Sphere::<3>::new(Point::origin(), 1.0);
let cube = ConvexHull::<3>::unit_cube();
let ps = SVector::from([0.0, 0.0, 0.0]);
let pc = SVector::from([1.5, 0.0, 0.0]);
let result = intersects(&sphere, &ps, &cube, &pc);
assert!(result.intersecting, "sphere+box should overlap at dist 1.5");
}
// The same pair at offset 3.0 is expected to be separated.
#[test]
fn sphere_vs_box_separated_3d() {
let sphere = Sphere::<3>::new(Point::origin(), 1.0);
let cube = ConvexHull::<3>::unit_cube();
let ps = SVector::from([0.0, 0.0, 0.0]);
let pc = SVector::from([3.0, 0.0, 0.0]);
let result = intersects(&sphere, &ps, &cube, &pc);
assert!(!result.intersecting, "sphere+box should be separated at dist 3");
}
// A shape tested against itself at the same position (coincident centres take
// the `+x` seed-direction fallback).
#[test]
fn identical_position_intersects() {
let a = Sphere::<3>::unit();
let pos = SVector::from([5.0, 5.0, 5.0]);
let result = intersects(&a, &pos, &a, &pos);
assert!(result.intersecting);
}
// Off-axis offset so the search is not aligned with a coordinate axis.
#[test]
fn convergence_sphere_box() {
let a = Sphere::<3>::unit();
let b = ConvexHull::<3>::unit_cube();
let pa = SVector::zeros();
let pb = SVector::from([0.5, 0.3, 0.1]);
let result = intersects(&a, &pa, &b, &pb);
assert!(result.intersecting, "sphere+cube should intersect, iters={}", result.iterations);
}
}
// Property-based checks of `intersects` over randomly sampled sphere placements.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
use symtropy_math::{Point, Sphere};
proptest! {
// Two equal spheres at the same position must always be reported as intersecting.
#[test]
fn coincident_spheres_always_intersect(
x in -100.0f64..100.0,
y in -100.0f64..100.0,
z in -100.0f64..100.0,
r in 0.1f64..10.0,
) {
let a = Sphere::<3>::new(Point::origin(), r);
let b = Sphere::<3>::new(Point::origin(), r);
let pos = SVector::from([x, y, z]);
let result = intersects(&a, &pos, &b, &pos);
prop_assert!(result.intersecting, "coincident spheres must intersect");
}
// Every centre coordinate differs by at least 2r + 10, so the centre distance
// is far larger than the sum of radii (2r).
#[test]
fn distant_spheres_never_intersect(
x in 0.0f64..10.0,
y in 0.0f64..10.0,
z in 0.0f64..10.0,
r in 0.1f64..1.0,
) {
let a = Sphere::<3>::new(Point::origin(), r);
let b = Sphere::<3>::new(Point::origin(), r);
let pa = SVector::from([0.0, 0.0, 0.0]);
let pb = SVector::from([x + 2.0 * r + 10.0, y + 2.0 * r + 10.0, z + 2.0 * r + 10.0]);
let result = intersects(&a, &pa, &b, &pb);
prop_assert!(!result.intersecting, "distant spheres must not intersect");
}
// Swapping the argument order must not change the verdict (2-D unit spheres).
#[test]
fn gjk_is_symmetric(
ax in -10.0f64..10.0, ay in -10.0f64..10.0,
bx in -10.0f64..10.0, by in -10.0f64..10.0,
) {
let a = Sphere::<2>::unit();
let b = Sphere::<2>::unit();
let pa = SVector::from([ax, ay]);
let pb = SVector::from([bx, by]);
let ab = intersects(&a, &pa, &b, &pb);
let ba = intersects(&b, &pb, &a, &pa);
prop_assert!(ab.intersecting == ba.intersecting, "GJK must be symmetric: ab={} ba={}", ab.intersecting, ba.intersecting);
}
// NOTE(review): `intersects` never reports more than MAX_ITERATIONS, so this
// assertion cannot fail; the test effectively checks that the call returns
// without panicking for arbitrary placements.
#[test]
fn gjk_always_terminates(
ax in -50.0f64..50.0, ay in -50.0f64..50.0, az in -50.0f64..50.0,
bx in -50.0f64..50.0, by in -50.0f64..50.0, bz in -50.0f64..50.0,
r in 0.01f64..5.0,
) {
let a = Sphere::<3>::new(Point::origin(), r);
let b = Sphere::<3>::new(Point::origin(), r);
let pa = SVector::from([ax, ay, az]);
let pb = SVector::from([bx, by, bz]);
let result = intersects(&a, &pa, &b, &pb);
prop_assert!(result.iterations <= MAX_ITERATIONS, "GJK must terminate");
}
// Compares against the exact test `dist < 2r`. Samples within 0.1 of the
// contact distance are skipped, presumably because numerical tolerances may
// decide the outcome there.
#[test]
fn gjk_matches_analytical_sphere_check(
dist in 0.0f64..5.0,
r in 0.1f64..2.0,
) {
let a = Sphere::<3>::new(Point::origin(), r);
let b = Sphere::<3>::new(Point::origin(), r);
let pa = SVector::from([0.0, 0.0, 0.0]);
let pb = SVector::from([dist, 0.0, 0.0]);
let result = intersects(&a, &pa, &b, &pb);
let analytical = dist < 2.0 * r;
if (dist - 2.0 * r).abs() > 0.1 {
prop_assert!(
result.intersecting == analytical,
"GJK disagrees with analytical at dist={dist}, r={r}: gjk={}, analytical={analytical}",
result.intersecting,
);
}
}
}
}