// comfy_core/spatial_hash.rs

use crate::*;
use fxhash::FxHashMap;

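// Experimental spatial hash: shapes are bucketed into a uniform grid of cells
// keyed by integer cell coordinates, for shape-overlap queries and raycasts.
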
6#[derive(Copy, Clone, Debug)]
7pub struct Intersection {
8    pub point: Vec2,
9    pub normal: Vec2,
10}
11
12#[derive(Clone, Copy, Debug)]
13pub struct AabbShape {
14    pub min: Vec2,
15    pub max: Vec2,
16}
17
18impl AabbShape {
19    pub fn shape(center: Vec2, size: Vec2) -> Shape {
20        let min = center - size / 2.0;
21        let max = center + size / 2.0;
22        Shape::Aabb(AabbShape { min, max })
23    }
24
25    pub fn intersects_circle(&self, circle: CircleShape) -> bool {
26        let closest = self.min.max(self.max.min(circle.center));
27        let distance = circle.center.distance(closest);
28        distance <= circle.radius
29    }
30
31    pub fn intersects_aabb(&self, aabb: AabbShape) -> bool {
32        self.min.x <= aabb.max.x &&
33            self.max.x >= aabb.min.x &&
34            self.min.y <= aabb.max.y &&
35            self.max.y >= aabb.min.y
36    }
37
38    pub fn center(&self) -> Vec2 {
39        (self.min + self.max) / 2.0
40    }
41
42    pub fn size(&self) -> Vec2 {
43        self.max - self.min
44    }
45
46    pub fn line_intersection(
47        &self,
48        start: Vec2,
49        end: Vec2,
50    ) -> Option<Intersection> {
51        let dir = end - start;
52
53        let mut tmin = (self.min.x - start.x) / dir.x;
54        let mut tmax = (self.max.x - start.x) / dir.x;
55
56        if tmin > tmax {
57            std::mem::swap(&mut tmin, &mut tmax);
58        }
59
60        let mut tymin = (self.min.y - start.y) / dir.y;
61        let mut tymax = (self.max.y - start.y) / dir.y;
62
63        if tymin > tymax {
64            std::mem::swap(&mut tymin, &mut tymax);
65        }
66
67        tmin = tmin.max(tymin);
68        tmax = tmax.min(tymax);
69
70        if tmin > tmax {
71            return None;
72        }
73
74        let t = if (0.0..=1.0).contains(&tmin) {
75            tmin
76        } else if (0.0..=1.0).contains(&tmax) {
77            tmax
78        } else {
79            return None;
80        };
81
82        let intersection_point = start + dir * t;
83
84        // Compute the normal
85        let mut normal = Vec2::ZERO;
86
87        // Determine which face was hit based on the intersection point
88        let tolerance = 1e-5; // A small tolerance value to account for floating point errors
89
90        if (intersection_point.x - self.min.x).abs() < tolerance {
91            normal = Vec2::new(-1.0, 0.0);
92        } else if (intersection_point.x - self.max.x).abs() < tolerance {
93            normal = Vec2::new(1.0, 0.0);
94        } else if (intersection_point.y - self.min.y).abs() < tolerance {
95            normal = Vec2::new(0.0, -1.0);
96        } else if (intersection_point.y - self.max.y).abs() < tolerance {
97            normal = Vec2::new(0.0, 1.0);
98        }
99
100        Some(Intersection { point: intersection_point, normal })
101    }
102}
103
104#[derive(Clone, Copy, Debug)]
105pub struct CircleShape {
106    pub center: Vec2,
107    pub radius: f32,
108}
109
110impl CircleShape {
111    pub fn bounding_rect(&self) -> AabbShape {
112        let min = self.center - Vec2::splat(self.radius);
113        let max = self.center + Vec2::splat(self.radius);
114        AabbShape { min, max }
115    }
116
117    pub fn intersects_circle(&self, circle: CircleShape) -> bool {
118        let distance = self.center.distance(circle.center);
119        distance <= self.radius + circle.radius
120    }
121
122    pub fn intersects_aabb(&self, aabb: AabbShape) -> bool {
123        aabb.intersects_circle(*self)
124    }
125
126    // pub fn intersects_line(&self, start: Vec2, end: Vec2) -> Option<Vec2> {
127    //     let to_target = self.center - start;
128    //
129    //     let line_vec = end - start;
130    //     let ray_len = line_vec.length();
131    //     let ray_dir = line_vec.normalize();
132    //
133    //     let dot = to_target.dot(ray_dir);
134    //
135    //     if dot < 0.0 || dot > ray_len {
136    //         return None;
137    //     }
138    //
139    //     let closest_point = start + ray_dir * dot;
140    //
141    //     let dist_squared = (self.center - closest_point).length_squared();
142    //
143    //     if dist_squared > self.radius.powi(2) {
144    //         return None;
145    //     }
146    //
147    //     let t = (self.radius.powi(2) - dist_squared).sqrt();
148    //
149    //     let intersection1 = closest_point + ray_dir * (0.0 - t);
150    //     let intersection2 = closest_point + ray_dir * t;
151    //
152    //     if (intersection1 - start).length() < (intersection2 - start).length() {
153    //         Some(intersection1)
154    //     } else {
155    //         Some(intersection2)
156    //     }
157    // }
158
159    pub fn intersects_line(
160        &self,
161        start: Vec2,
162        end: Vec2,
163    ) -> Option<Intersection> {
164        let to_target = self.center - start;
165
166        let line_vec = end - start;
167        let ray_len = line_vec.length();
168        let ray_dir = line_vec.normalize();
169
170        let dot = to_target.dot(ray_dir);
171
172        if dot < 0.0 || dot > ray_len {
173            return None;
174        }
175
176        let closest_point = start + ray_dir * dot;
177
178        let dist_squared = (self.center - closest_point).length_squared();
179
180        if dist_squared > self.radius.powi(2) {
181            return None;
182        }
183
184        let t = (self.radius.powi(2) - dist_squared).sqrt();
185
186        let intersection1 = closest_point + ray_dir * (0.0 - t);
187        let intersection2 = closest_point + ray_dir * t;
188
189        let intersection_point = if (intersection1 - start).length() <
190            (intersection2 - start).length()
191        {
192            intersection1
193        } else {
194            intersection2
195        };
196
197        // Calculate the normal at the intersection point
198        let normal = (intersection_point - self.center).normalize();
199
200        Some(Intersection { point: intersection_point, normal })
201    }
202}
203
204#[derive(Clone, Copy, Debug)]
205pub enum Shape {
206    Circle(CircleShape),
207    Aabb(AabbShape),
208}
209
210impl Shape {
211    pub fn bounding_rect(&self) -> AabbShape {
212        match self {
213            Shape::Circle(circle) => circle.bounding_rect(),
214            Shape::Aabb(aabb) => *aabb,
215        }
216    }
217
218    pub fn intersects_shape(&self, shape: Shape) -> bool {
219        match (*self, shape) {
220            (Shape::Circle(circle1), Shape::Circle(circle2)) => {
221                circle1.intersects_circle(circle2)
222            }
223            (Shape::Circle(circle), Shape::Aabb(aabb)) |
224            (Shape::Aabb(aabb), Shape::Circle(circle)) => {
225                circle.intersects_aabb(aabb)
226            }
227            (Shape::Aabb(aabb1), Shape::Aabb(aabb2)) => {
228                aabb1.intersects_aabb(aabb2)
229            }
230        }
231    }
232
233    pub fn intersects_line(
234        &self,
235        start: Vec2,
236        end: Vec2,
237    ) -> Option<Intersection> {
238        match self {
239            Shape::Circle(circle) => circle.intersects_line(start, end),
240            Shape::Aabb(aabb) => aabb.line_intersection(start, end),
241        }
242    }
243}
244
245#[derive(Clone, Copy)]
246pub enum SpatialQuery {
247    ShapeQuery(Shape),
248}
249
250#[derive(Clone, Copy)]
251pub struct SpatialHashData {
252    pub shape: Shape,
253    pub userdata: UserData,
254}
255
256pub struct SpatialHash {
257    pub grid_size: f32,
258    pub inner: FxHashMap<(i32, i32), Vec<SpatialHashData>>,
259}
260
261impl SpatialHash {
262    pub fn new() -> Self {
263        const DEFAULT_GRID_SIZE: f32 = 100.0;
264        Self { grid_size: DEFAULT_GRID_SIZE, inner: FxHashMap::default() }
265    }
266
267    pub fn clear(&mut self) {
268        self.inner.clear();
269    }
270
271    pub fn add_shape(&mut self, shape: Shape, data: UserData) {
272        match shape {
273            Shape::Circle(circle) => {
274                self.add_shape(Shape::Aabb(circle.bounding_rect()), data);
275            }
276            Shape::Aabb(aabb) => {
277                let min = aabb.min / self.grid_size;
278                let max = aabb.max / self.grid_size;
279                let min = min.floor();
280                let max = max.ceil();
281
282                for x in min.x as i32..max.x as i32 {
283                    for y in min.y as i32..max.y as i32 {
284                        let key = (x, y);
285                        let entry = self.inner.entry(key).or_default();
286
287                        entry.push(SpatialHashData { shape, userdata: data });
288                    }
289                }
290            }
291        }
292    }
293
294    pub fn query(
295        &self,
296        query: SpatialQuery,
297    ) -> impl Iterator<Item = &UserData> {
298        match query {
299            SpatialQuery::ShapeQuery(shape) => {
300                let bounding_rect = shape.bounding_rect();
301                let min = bounding_rect.min / self.grid_size;
302                let max = bounding_rect.max / self.grid_size;
303                let min = min.floor();
304                let max = max.ceil();
305                (min.x as i32..max.x as i32)
306                    .flat_map(move |x| {
307                        (min.y as i32..max.y as i32).map(move |y| (x, y))
308                    })
309                    .flat_map(move |key| {
310                        self.inner.get(&key).into_iter().flatten()
311                    })
312                    .filter(move |data| data.shape.intersects_shape(shape))
313                    .map(|data| &data.userdata)
314            }
315        }
316    }
317
318    pub fn raycast(
319        &self,
320        start: Vec2,
321        end: Vec2,
322    ) -> Option<(Intersection, &UserData)> {
323        let mut t = 0.0;
324        // let dir = (end - start).normalize();
325        let mut closest_intersection: Option<(Intersection, &UserData)> = None;
326
327        while t <= 1.0 {
328            let current_point = start + t * (end - start);
329            let key = (
330                (current_point.x / self.grid_size).floor() as i32,
331                (current_point.y / self.grid_size).floor() as i32,
332            );
333
334            if let Some(cell) = self.inner.get(&key) {
335                for spatial_data in cell {
336                    if let Some(intersection) =
337                        spatial_data.shape.intersects_line(start, end)
338                    {
339                        if closest_intersection.map_or(
340                            true,
341                            |(closest_point, _)| {
342                                intersection.point.distance_squared(start) <
343                                    closest_point
344                                        .point
345                                        .distance_squared(start)
346                            },
347                        ) {
348                            closest_intersection =
349                                Some((intersection, &spatial_data.userdata));
350                        }
351                    }
352                }
353            }
354
355            t += self.grid_size / (end - start).length();
356        }
357
358        // draw_text(
359        //     &format!("{:.1?} {:.1?} {:#.1?}", start, end, closest_intersection),
360        //     start,
361        //     WHITE,
362        //     TextAlign::Center,
363        // );
364
365        closest_intersection
366    }
367}
368
369#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
370#[repr(C)]
371pub struct UserData {
372    pub entity_type: u64,
373    pub entity: Option<Entity>,
374}
375
376pub fn draw_spatial(spatial: &SpatialHash) {
377    for (_, bucket) in spatial.inner.iter() {
378        for item in bucket.iter() {
379            match &item.shape {
380                Shape::Circle(circle) => {
381                    draw_circle_outline(
382                        circle.center,
383                        circle.radius,
384                        0.1,
385                        RED,
386                        499,
387                    )
388                }
389                Shape::Aabb(aabb) => {
390                    draw_rect_outline(
391                        aabb.center(),
392                        aabb.size(),
393                        0.1,
394                        RED,
395                        499,
396                    );
397                }
398            }
399        }
400    }
401}