phys-collision 2.0.1-beta.0

Provides collision detection capabilities.
// Copyright (C) 2020-2025 phys-collision authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;

use glam::Point3;
use phys_geom::Aabb3;
use rustc_hash::FxHashMap;
use vslab::{new_type_id, Slab};

use super::traits::{ComplexShapeTrait, ShapeSetTrait};
use crate::ray::RaycastHitResult;
use crate::{ContainsResult, Ray};

new_type_id!(
    /// Identifier of a complex shape stored in a [`ShapeContainer`].
    ///
    /// An id is only meaningful together with the concrete shape type it was
    /// created for, e.g. `container.get::<T>(id)`.
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    ComplexShapeId
);

/// Typed storage for all shapes of one concrete [`ComplexShapeTrait`] type `T`,
/// addressed by [`ComplexShapeId`].
#[derive(Clone)]
struct ShapeSet<T: 'static + ComplexShapeTrait>(Slab<ComplexShapeId, T>);
impl<T: 'static + ComplexShapeTrait> ShapeSet<T> {
    /// Creates a set containing no shapes.
    pub fn new() -> Self {
        Self(Slab::default())
    }

    /// Stores `value` and returns the id the underlying slab assigned to it.
    pub fn add(&mut self, value: T) -> ComplexShapeId {
        let Self(slab) = self;
        slab.insert(value)
    }

    /// Looks up the shape stored under `id`, if the slab has one.
    pub fn get(&self, id: ComplexShapeId) -> Option<&T> {
        let Self(slab) = self;
        slab.get(id)
    }

    /// Takes the shape stored under `id` out of the set, if the slab has one.
    pub fn remove(&mut self, id: ComplexShapeId) -> Option<T> {
        let Self(slab) = self;
        slab.remove(id)
    }
}

impl<T: 'static + ComplexShapeTrait> ShapeSet<T> {
    /// Returns the shape stored under `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not resolve to a shape in this set (for example
    /// because that shape was removed). `#[track_caller]` makes the panic
    /// report the calling trait method rather than this helper.
    #[track_caller]
    fn shape_by_id(&self, id: ComplexShapeId) -> &T {
        self.get(id)
            .expect("ComplexShapeId does not refer to a shape stored in this shape set")
    }
}

/// Type-erased access to a [`ShapeSet`]: every query forwards to the shape
/// stored under the given id.
///
/// All id-taking methods panic if the id does not resolve in this set.
impl<T: 'static + Send + Sync + ComplexShapeTrait> ShapeSetTrait for ShapeSet<T> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn compute_aabb(&self, id: ComplexShapeId) -> Aabb3 {
        self.shape_by_id(id).compute_aabb()
    }

    fn max_radius_and_max_angular_expansion(&self, id: ComplexShapeId) -> (f32, f32) {
        self.shape_by_id(id).max_radius_and_max_angular_expansion()
    }

    fn raycast(
        &self,
        id: ComplexShapeId,
        local_ray: Ray,
        max_distance: f32,
        discard_inside_hit: bool,
    ) -> Option<RaycastHitResult> {
        self.shape_by_id(id)
            .raycast(local_ray, max_distance, discard_inside_hit)
    }

    fn contains_point_with_threshold(
        &self,
        id: ComplexShapeId,
        local_point: Point3,
        threshold: f32,
    ) -> ContainsResult {
        self.shape_by_id(id)
            .contains_point_with_threshold(local_point, threshold)
    }

    fn signed_distance_to_point(&self, id: ComplexShapeId, local_point: Point3) -> f32 {
        self.shape_by_id(id).signed_distance_to_point(local_point)
    }

    fn compute_volume(&self, id: ComplexShapeId) -> f32 {
        self.shape_by_id(id).compute_volume()
    }
}

/// Store for complex shapes of arbitrary types.
///
/// Shapes are grouped into one type-erased shape set per shape type, keyed by
/// that type's `PLUGIN_ID`. A set is created the first time a shape of its type
/// is added.
#[derive(Default, Clone)]
pub struct ShapeContainer {
    /// One shape set per shape type, keyed by the type's `PLUGIN_ID`.
    type_containers: FxHashMap<u64, Box<dyn ShapeSetTrait>>,
}

impl ShapeContainer {
    /// Panic message used when the set registered under a `PLUGIN_ID` holds a
    /// different concrete shape type, i.e. two shape types share one `PLUGIN_ID`.
    const PLUGIN_TYPE_MISMATCH: &'static str =
        "shape set registered under this PLUGIN_ID holds a different shape type (duplicate PLUGIN_ID?)";

    /// Returns the type-erased shape set registered under `plugin_id`.
    ///
    /// # Panics
    ///
    /// Panics with "plugin not found" if no shape with this plugin id has been added.
    #[inline]
    #[must_use]
    pub fn get_shape_set_by_id(&self, plugin_id: u64) -> &dyn ShapeSetTrait {
        self.type_containers
            .get(&plugin_id)
            .expect("plugin not found")
            .as_ref()
    }

    /// Returns the shape of type `T` stored under `id`.
    ///
    /// Returns `None` if no `T` shape has ever been added, or if `id` does not
    /// resolve in the `T` shape set.
    ///
    /// # Panics
    ///
    /// Panics if the set registered under `T::PLUGIN_ID` belongs to another type.
    #[inline]
    #[must_use]
    pub fn get<T: ComplexShapeTrait + 'static>(&self, id: ComplexShapeId) -> Option<&T> {
        self.get_shape_set::<T>()?.get(id)
    }

    /// Stores `value` and returns its id, creating the `T` shape set on first use.
    ///
    /// # Panics
    ///
    /// Panics if a different shape type is already registered under `T::PLUGIN_ID`.
    #[inline]
    #[must_use]
    pub fn add<T: ComplexShapeTrait + 'static + Send + Sync>(
        &mut self,
        value: T,
    ) -> ComplexShapeId {
        // Single hash lookup: reuse the existing set or insert an empty one.
        self.type_containers
            .entry(T::PLUGIN_ID)
            .or_insert_with(|| Box::new(ShapeSet::<T>::new()))
            .as_any_mut()
            .downcast_mut::<ShapeSet<T>>()
            .expect(Self::PLUGIN_TYPE_MISMATCH)
            .add(value)
    }

    /// Removes and returns the `T` shape stored under `id`.
    ///
    /// Returns `None` if no `T` shape has ever been added, or if `id` does not
    /// resolve in the `T` shape set. The `T` set stays registered even when it
    /// becomes empty.
    ///
    /// # Panics
    ///
    /// Panics if the set registered under `T::PLUGIN_ID` belongs to another type.
    #[inline]
    pub fn remove<T: ComplexShapeTrait + 'static>(&mut self, id: ComplexShapeId) -> Option<T> {
        self.type_containers
            .get_mut(&T::PLUGIN_ID)?
            .as_any_mut()
            .downcast_mut::<ShapeSet<T>>()
            .expect(Self::PLUGIN_TYPE_MISMATCH)
            .remove(id)
    }

    /// Returns the typed shape set for `T`, or `None` if no `T` has been added.
    ///
    /// # Panics
    ///
    /// Panics if the set registered under `T::PLUGIN_ID` belongs to another type.
    fn get_shape_set<T: ComplexShapeTrait + 'static>(&self) -> Option<&ShapeSet<T>> {
        self.type_containers.get(&T::PLUGIN_ID).map(|container| {
            container
                .as_any()
                .downcast_ref::<ShapeSet<T>>()
                .expect(Self::PLUGIN_TYPE_MISMATCH)
        })
    }
}

#[cfg(test)]
mod tests {
    use glam_det::Point3;
    use phys_geom::volume::ComputeVolume;
    use phys_geom::ComputeAabb3;
    use wasm_bindgen_test::*;

    use super::ComplexShapeId;
    use crate::ray::{Raycast, RaycastHitResult};
    use crate::shapes::convex_hull::ConvexHull;
    use crate::shapes::ShapeContainer;
    use crate::{ComplexShapeTrait, ContainsPoint, Expansion, ShapePlugin, SignedDistanceToPoint};

    /// Minimal shape with trivial geometry, used to exercise [`ShapeContainer`].
    #[derive(Clone, Debug)]
    struct TestShape;

    impl ShapePlugin for TestShape {
        const PLUGIN_ID: u64 = 0xc54a471a;
    }
    impl ComputeAabb3 for TestShape {
        fn compute_aabb(&self) -> phys_geom::Aabb3 {
            phys_geom::Aabb3::default()
        }
    }
    impl Expansion for TestShape {
        fn max_radius_and_max_angular_expansion(&self) -> (f32, f32) {
            (0.0, 0.0)
        }
    }
    impl Raycast for TestShape {
        fn raycast(
            &self,
            _local_ray: crate::Ray,
            _max_distance: f32,
            _discard_inside_hit: bool,
        ) -> std::option::Option<RaycastHitResult> {
            None
        }
    }
    impl ContainsPoint for TestShape {
        fn contains_point_with_threshold(
            &self,
            _local_point: Point3,
            _threshold: f32,
        ) -> crate::ContainsResult {
            crate::ContainsResult::Outside
        }
    }
    impl SignedDistanceToPoint for TestShape {
        fn signed_distance_to_point(&self, _local_point: Point3) -> f32 {
            0.0
        }
    }
    impl ComputeVolume for TestShape {
        fn compute_volume(&self) -> phys_geom::math::Real {
            0.0
        }
    }
    impl ComplexShapeTrait for TestShape {}

    #[test]
    #[wasm_bindgen_test]
    fn test_add_and_get() {
        let _ = env_logger::builder().is_test(true).try_init();

        let test_shape = TestShape;
        let mut shape_batches = ShapeContainer::default();
        let test_shape_id = shape_batches.add(test_shape);
        let invalid_id = ComplexShapeId(0x12456);
        assert!(shape_batches.get::<TestShape>(test_shape_id).is_some());
        assert!(shape_batches.get::<ConvexHull>(invalid_id).is_none());

        let test_shape_set = shape_batches.get_shape_set::<TestShape>();
        let convex_hull_set = shape_batches.get_shape_set::<ConvexHull>();
        assert!(test_shape_set.is_some());
        assert!(convex_hull_set.is_none());

        let length = test_shape_set.unwrap().0.len();
        assert_eq!(length, 1);

        // Removing a stored shape hands it back, after which its id no longer
        // resolves; removing an id that holds no shape yields `None`.
        assert!(shape_batches.remove::<TestShape>(test_shape_id).is_some());
        assert!(shape_batches.get::<TestShape>(test_shape_id).is_none());
        assert!(shape_batches.remove::<TestShape>(invalid_id).is_none());

        // The shape set stays registered (now empty) after its last shape is removed.
        let test_shape_set = shape_batches.get_shape_set::<TestShape>();
        let length = test_shape_set.unwrap().0.len();
        assert_eq!(length, 0);
    }
}