use std::any::Any;
use glam::Point3;
use phys_geom::Aabb3;
use rustc_hash::FxHashMap;
use vslab::{new_type_id, Slab};
use super::traits::{ComplexShapeTrait, ShapeSetTrait};
use crate::ray::RaycastHitResult;
use crate::{ContainsResult, Ray};
// Identifier for a complex shape stored in a `ShapeSet` / `ShapeContainer`; it is the
// key type of the per-type `Slab` and is returned by `ShapeContainer::add`.
// With the `serde` feature enabled the generated type also derives
// `Serialize`/`Deserialize`.
new_type_id!(
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
ComplexShapeId
);
/// Storage for every shape of a single concrete type `T`, keyed by [`ComplexShapeId`].
#[derive(Clone)]
struct ShapeSet<T: 'static + ComplexShapeTrait>(Slab<ComplexShapeId, T>);

impl<T: 'static + ComplexShapeTrait> ShapeSet<T> {
    /// Creates a set that holds no shapes yet.
    pub fn new() -> Self {
        Self(Slab::default())
    }

    /// Stores `value` and returns the id it was filed under.
    pub fn add(&mut self, value: T) -> ComplexShapeId {
        let Self(slab) = self;
        slab.insert(value)
    }

    /// Borrows the shape filed under `id`, or `None` if there is none.
    pub fn get(&self, id: ComplexShapeId) -> Option<&T> {
        let Self(slab) = self;
        slab.get(id)
    }

    /// Moves the shape filed under `id` out of the set, or returns `None` if there is none.
    pub fn remove(&mut self, id: ComplexShapeId) -> Option<T> {
        let Self(slab) = self;
        slab.remove(id)
    }
}
impl<T: 'static + ComplexShapeTrait> ShapeSet<T> {
    /// Borrows the shape stored under `id`, panicking if there is none.
    ///
    /// The `ShapeSetTrait` queries have no way to report a missing shape, so an id
    /// that is not live in this set is treated as a caller bug.
    ///
    /// # Panics
    /// Panics if `id` does not refer to a live shape in this set. Thanks to
    /// `#[track_caller]`, the reported location is the calling trait method.
    #[track_caller]
    fn expect_shape(&self, id: ComplexShapeId) -> &T {
        self.get(id)
            .expect("ComplexShapeId does not refer to a live shape in this shape set")
    }
}

/// Type-erased access to a `ShapeSet<T>`: each id-taking query forwards to the shape
/// stored under that id.
///
/// # Panics
/// Every method taking a `ComplexShapeId` panics if the id is not live in this set.
impl<T: 'static + Send + Sync + ComplexShapeTrait> ShapeSetTrait for ShapeSet<T> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn compute_aabb(&self, id: ComplexShapeId) -> Aabb3 {
        self.expect_shape(id).compute_aabb()
    }

    fn max_radius_and_max_angular_expansion(&self, id: ComplexShapeId) -> (f32, f32) {
        self.expect_shape(id).max_radius_and_max_angular_expansion()
    }

    fn raycast(
        &self,
        id: ComplexShapeId,
        local_ray: Ray,
        max_distance: f32,
        discard_inside_hit: bool,
    ) -> Option<RaycastHitResult> {
        self.expect_shape(id)
            .raycast(local_ray, max_distance, discard_inside_hit)
    }

    fn contains_point_with_threshold(
        &self,
        id: ComplexShapeId,
        local_point: Point3,
        threshold: f32,
    ) -> ContainsResult {
        self.expect_shape(id)
            .contains_point_with_threshold(local_point, threshold)
    }

    fn signed_distance_to_point(&self, id: ComplexShapeId, local_point: Point3) -> f32 {
        self.expect_shape(id).signed_distance_to_point(local_point)
    }

    fn compute_volume(&self, id: ComplexShapeId) -> f32 {
        self.expect_shape(id).compute_volume()
    }
}
/// Heterogeneous store of complex shapes, bucketed by each shape type's `PLUGIN_ID`.
///
/// Each bucket is a type-erased per-type shape set. Typed access goes through the
/// generic `get` / `add` / `remove` methods, and `get_shape_set_by_id` gives id-based,
/// type-erased access.
#[derive(Clone, Default)]
pub struct ShapeContainer {
    /// One type-erased shape set per plugin id.
    type_containers: FxHashMap<u64, Box<dyn ShapeSetTrait>>,
}
impl ShapeContainer {
    /// Panic message used when the set registered under `T::PLUGIN_ID` is not a
    /// `ShapeSet<T>`. That can only happen if two shape types share a `PLUGIN_ID`.
    const PLUGIN_ID_COLLISION: &'static str =
        "shape set registered under this PLUGIN_ID holds a different shape type (duplicate PLUGIN_ID?)";

    /// Returns the type-erased shape set registered under `plugin_id`.
    ///
    /// # Panics
    /// Panics with `"plugin not found"` if no shape with that plugin id has been added.
    #[inline]
    #[must_use]
    pub fn get_shape_set_by_id(&self, plugin_id: u64) -> &dyn ShapeSetTrait {
        self.type_containers
            .get(&plugin_id)
            .expect("plugin not found")
            .as_ref()
    }

    /// Returns the shape of type `T` stored under `id`.
    ///
    /// Returns `None` if no shape of type `T` was ever added, or if `id` is not live
    /// in the set for `T`.
    ///
    /// # Panics
    /// Panics if the set registered under `T::PLUGIN_ID` holds a different shape type.
    #[inline]
    #[must_use]
    pub fn get<T: ComplexShapeTrait + 'static>(&self, id: ComplexShapeId) -> Option<&T> {
        self.get_shape_set::<T>()?.get(id)
    }

    /// Stores `value` and returns its id. The set for `T` is created on first use.
    ///
    /// # Panics
    /// Panics if the set registered under `T::PLUGIN_ID` holds a different shape type.
    #[inline]
    #[must_use]
    pub fn add<T: ComplexShapeTrait + 'static + Send + Sync>(
        &mut self,
        value: T,
    ) -> ComplexShapeId {
        // A single hash lookup both finds an existing set and creates a missing one.
        let container = self
            .type_containers
            .entry(T::PLUGIN_ID)
            .or_insert_with(|| -> Box<dyn ShapeSetTrait> { Box::new(ShapeSet::<T>::new()) });
        container
            .as_any_mut()
            .downcast_mut::<ShapeSet<T>>()
            .expect(Self::PLUGIN_ID_COLLISION)
            .add(value)
    }

    /// Removes and returns the shape of type `T` stored under `id`.
    ///
    /// Returns `None` if no shape of type `T` was ever added, or if `id` is not live.
    ///
    /// # Panics
    /// Panics if the set registered under `T::PLUGIN_ID` holds a different shape type.
    #[inline]
    pub fn remove<T: ComplexShapeTrait + 'static>(&mut self, id: ComplexShapeId) -> Option<T> {
        let container = self.type_containers.get_mut(&T::PLUGIN_ID)?;
        container
            .as_any_mut()
            .downcast_mut::<ShapeSet<T>>()
            .expect(Self::PLUGIN_ID_COLLISION)
            .remove(id)
    }

    /// Returns the typed shape set for `T`, or `None` if no shape of type `T` was added.
    ///
    /// # Panics
    /// Panics if the set registered under `T::PLUGIN_ID` holds a different shape type.
    fn get_shape_set<T: ComplexShapeTrait + 'static>(&self) -> Option<&ShapeSet<T>> {
        self.type_containers.get(&T::PLUGIN_ID).map(|container| {
            container
                .as_any()
                .downcast_ref::<ShapeSet<T>>()
                .expect(Self::PLUGIN_ID_COLLISION)
        })
    }
}
#[cfg(test)]
mod tests {
    use glam_det::Point3;
    use phys_geom::volume::ComputeVolume;
    use phys_geom::ComputeAabb3;
    use wasm_bindgen_test::*;
    use super::ComplexShapeId;
    use crate::ray::{Raycast, RaycastHitResult};
    use crate::shapes::convex_hull::ConvexHull;
    use crate::shapes::ShapeContainer;
    use crate::{ComplexShapeTrait, ContainsPoint, Expansion, ShapePlugin, SignedDistanceToPoint};

    /// Minimal `ComplexShapeTrait` implementor used to exercise `ShapeContainer`.
    /// Every query returns a fixed placeholder value.
    #[derive(Clone, Debug)]
    struct TestShape;

    impl ShapePlugin for TestShape {
        const PLUGIN_ID: u64 = 0xc54a471a;
    }

    impl ComputeAabb3 for TestShape {
        fn compute_aabb(&self) -> phys_geom::Aabb3 {
            phys_geom::Aabb3::default()
        }
    }

    impl Expansion for TestShape {
        fn max_radius_and_max_angular_expansion(&self) -> (f32, f32) {
            (0.0, 0.0)
        }
    }

    impl Raycast for TestShape {
        fn raycast(
            &self,
            _local_ray: crate::Ray,
            _max_distance: f32,
            _discard_inside_hit: bool,
        ) -> std::option::Option<RaycastHitResult> {
            None
        }
    }

    impl ContainsPoint for TestShape {
        fn contains_point_with_threshold(
            &self,
            _local_point: Point3,
            _threshold: f32,
        ) -> crate::ContainsResult {
            crate::ContainsResult::Outside
        }
    }

    impl SignedDistanceToPoint for TestShape {
        fn signed_distance_to_point(&self, _local_point: Point3) -> f32 {
            0.0
        }
    }

    impl ComputeVolume for TestShape {
        fn compute_volume(&self) -> phys_geom::math::Real {
            0.0
        }
    }

    impl ComplexShapeTrait for TestShape {}

    /// Round-trips a shape through `add` / `get` / `remove`. Also checks that a shape
    /// type that was never added has no set, and that removing a live id returns it.
    #[test]
    #[wasm_bindgen_test]
    fn test_add_and_get() {
        let _ = env_logger::builder().is_test(true).try_init();
        let mut shape_batches = ShapeContainer::default();
        let test_shape_id = shape_batches.add(TestShape);
        let invalid_id = ComplexShapeId(0x12456);

        assert!(shape_batches.get::<TestShape>(test_shape_id).is_some());
        // No `ConvexHull` was ever added, so there is no set to look the id up in.
        assert!(shape_batches.get::<ConvexHull>(invalid_id).is_none());

        let test_shape_set = shape_batches.get_shape_set::<TestShape>();
        assert!(test_shape_set.is_some());
        assert!(shape_batches.get_shape_set::<ConvexHull>().is_none());
        assert_eq!(test_shape_set.unwrap().0.len(), 1);

        // Removing a live id must hand the stored shape back.
        assert!(shape_batches.remove::<TestShape>(test_shape_id).is_some());
        // Removing an id that was never issued must not disturb the set.
        shape_batches.remove::<TestShape>(invalid_id);

        let test_shape_set = shape_batches.get_shape_set::<TestShape>();
        assert_eq!(test_shape_set.unwrap().0.len(), 0);
    }
}