cs2_nav/
collisions.rs

//! Module for ray collision detection using a Bounding Volume Hierarchy tree.
//!
//! Taken from: <https://github.com/pnxenopoulos/awpy/blob/main/awpy/visibility.py>
4use crate::position::Position;
5use crate::utils::create_file_with_parents;
6
7use bincode::{deserialize_from, serialize_into};
8use pyo3::exceptions::PyValueError;
9use pyo3::{PyResult, pyclass, pymethods};
10use serde::{Deserialize, Serialize};
11use std::fs;
12use std::fs::File;
13use std::io::Read;
14use std::path::{Path, PathBuf};
15
/// A triangle in 3D space used for ray intersection checks.
///
/// Exposed to Python as a class of the `cs2_nav` module and (de)serializable via serde.
#[pyclass(module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Triangle {
    /// First vertex of the triangle.
    pub p1: Position,
    /// Second vertex of the triangle.
    pub p2: Position,
    /// Third vertex of the triangle.
    pub p3: Position,
}
24
25#[pymethods]
26impl Triangle {
27    #[new]
28    #[must_use]
29    pub const fn new(p1: Position, p2: Position, p3: Position) -> Self {
30        Self { p1, p2, p3 }
31    }
32
33    #[must_use]
34    pub fn get_centroid(&self) -> Position {
35        Position::new(
36            (self.p1.x + self.p2.x + self.p3.x) / 3.0,
37            (self.p1.y + self.p2.y + self.p3.y) / 3.0,
38            (self.p1.z + self.p2.z + self.p3.z) / 3.0,
39        )
40    }
41}
42
// ---------- Edge ----------
/// Edge record made of four integer indices.
///
/// NOTE(review): the field names match a half-edge mesh layout; nothing in this
/// file reads or writes them, so the per-field notes below are presumptions to
/// confirm against the code that produces these records.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Edge {
    /// Presumably the index of the next edge around the same face — TODO confirm.
    pub next: i32,
    /// Presumably the index of the oppositely oriented edge — TODO confirm.
    pub twin: i32,
    /// Presumably the index of the vertex this edge starts at — TODO confirm.
    pub origin: i32,
    /// Presumably the index of the face this edge belongs to — TODO confirm.
    pub face: i32,
}
51
/// Axis-Aligned Bounding Box for efficient collision detection.
///
/// The box is described by its two extreme corners.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Aabb {
    /// Corner holding the smallest x, y and z coordinates of the box.
    pub min_point: Position,
    /// Corner holding the largest x, y and z coordinates of the box.
    pub max_point: Position,
}
58
/// Slab test along a single axis.
///
/// Returns the `(near, far)` ray parameters at which a ray starting at `origin`
/// and moving with `direction` crosses the interval `[min_val, max_val]`.
///
/// When `|direction| < epsilon` the ray is treated as parallel to the slab:
/// the result is `(-inf, +inf)` if `origin` lies inside the interval (bounds
/// included) and the empty interval `(+inf, -inf)` otherwise.
fn check_axis(origin: f64, direction: f64, min_val: f64, max_val: f64, epsilon: f64) -> (f64, f64) {
    let is_parallel = direction.abs() < epsilon;
    if !is_parallel {
        let to_min = (min_val - origin) / direction;
        let to_max = (max_val - origin) / direction;
        // Order the two crossings so the nearer one comes first.
        return (to_min.min(to_max), to_min.max(to_max));
    }

    let outside_slab = origin < min_val || origin > max_val;
    if outside_slab {
        (f64::INFINITY, f64::NEG_INFINITY)
    } else {
        (f64::NEG_INFINITY, f64::INFINITY)
    }
}
70
71impl Aabb {
72    #[must_use]
73    pub const fn from_triangle(triangle: &Triangle) -> Self {
74        let min_point = Position::new(
75            triangle.p1.x.min(triangle.p2.x).min(triangle.p3.x),
76            triangle.p1.y.min(triangle.p2.y).min(triangle.p3.y),
77            triangle.p1.z.min(triangle.p2.z).min(triangle.p3.z),
78        );
79        let max_point = Position::new(
80            triangle.p1.x.max(triangle.p2.x).max(triangle.p3.x),
81            triangle.p1.y.max(triangle.p2.y).max(triangle.p3.y),
82            triangle.p1.z.max(triangle.p2.z).max(triangle.p3.z),
83        );
84        Self {
85            min_point,
86            max_point,
87        }
88    }
89
90    #[must_use]
91    pub fn intersects_ray(&self, ray_origin: &Position, ray_direction: &Position) -> bool {
92        let epsilon = 1e-6;
93
94        let (tx_min, tx_max) = check_axis(
95            ray_origin.x,
96            ray_direction.x,
97            self.min_point.x,
98            self.max_point.x,
99            epsilon,
100        );
101        let (ty_min, ty_max) = check_axis(
102            ray_origin.y,
103            ray_direction.y,
104            self.min_point.y,
105            self.max_point.y,
106            epsilon,
107        );
108        let (tz_min, tz_max) = check_axis(
109            ray_origin.z,
110            ray_direction.z,
111            self.min_point.z,
112            self.max_point.z,
113            epsilon,
114        );
115
116        let t_enter = tx_min.max(ty_min).max(tz_min);
117        let t_exit = tx_max.min(ty_max).min(tz_max);
118
119        t_enter <= t_exit && t_exit >= 0.0
120    }
121}
122
123impl std::fmt::Display for Aabb {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        write!(
126            f,
127            "AABB(min_point={:?}, max_point={:?})",
128            self.min_point, self.max_point
129        )
130    }
131}
132
/// Node in the Bounding Volume Hierarchy tree.
///
/// As built by `CollisionChecker::build_bvh`, a leaf holds exactly one triangle
/// and no children, while an internal node holds no triangle and both children.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BVHNode {
    /// Bounding box of the node: the triangle's box for a leaf, the union of
    /// the children's boxes for an internal node.
    pub aabb: Aabb,
    /// The triangle stored in a leaf; `None` for internal nodes.
    pub triangle: Option<Triangle>,
    /// Left subtree; `None` for leaves.
    pub left: Option<Box<BVHNode>>,
    /// Right subtree; `None` for leaves.
    pub right: Option<Box<BVHNode>>,
}
141
/// Collision checker using a Bounding Volume Hierarchy tree.
///
/// Exposed to Python under the name `VisibilityChecker` in the `cs2_nav` module.
#[pyclass(name = "VisibilityChecker", module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CollisionChecker {
    /// Number of triangles the tree was built from (readable from Python).
    #[pyo3(get)]
    pub n_triangles: usize,
    /// Root node of the BVH tree.
    pub root: BVHNode,
}
150
151impl CollisionChecker {
152    /// Construct a new `CollisionChecker` from a file of triangles or an existing list.
153    #[must_use]
154    pub fn new(tri_file: &Path) -> Self {
155        let triangles = Self::read_tri_file(tri_file, 1000);
156
157        let n_triangles = triangles.len();
158        let root = Self::build_bvh(triangles);
159        Self { n_triangles, root }
160    }
161
162    /// Read a .tri file containing triangles.
163    ///
164    /// From <https://github.com/pnxenopoulos/awpy/blob/main/awpy/visibility.py#L757>
165    /// # Panics
166    ///
167    /// Will panic if no file exists at the given path or if the file cannot be read.
168    pub fn read_tri_file<P: AsRef<Path>>(tri_file: P, buffer_size: usize) -> Vec<Triangle> {
169        // 9 f32 values per triangle, each f32 is 4 bytes.
170        let chunk_size: usize = buffer_size * 9 * 4;
171        let mut triangles = Vec::new();
172        let mut file = fs::File::open(tri_file).expect("Unable to open tri file");
173        let mut buffer = vec![0u8; chunk_size].into_boxed_slice();
174
175        loop {
176            let n = file.read(&mut buffer).expect("Failed to read file");
177            if n == 0 {
178                break;
179            }
180            // number of complete triangles in the buffer.
181            let num_complete_triangles = n / 36;
182            for i in 0..num_complete_triangles {
183                let offset = i * 36;
184                let slice = &buffer[offset..offset + 36];
185                let mut values = [0f32; 9];
186                for (i, chunk) in slice.chunks_exact(4).enumerate() {
187                    values[i] = f32::from_ne_bytes(chunk.try_into().unwrap());
188                }
189                triangles.push(Triangle {
190                    p1: Position::new(
191                        f64::from(values[0]),
192                        f64::from(values[1]),
193                        f64::from(values[2]),
194                    ),
195                    p2: Position::new(
196                        f64::from(values[3]),
197                        f64::from(values[4]),
198                        f64::from(values[5]),
199                    ),
200                    p3: Position::new(
201                        f64::from(values[6]),
202                        f64::from(values[7]),
203                        f64::from(values[8]),
204                    ),
205                });
206            }
207        }
208        triangles
209    }
210
211    /// Build a Bounding Volume Hierarchy tree from a list of triangles.
212    ///
213    /// # Panics
214    ///
215    /// Will panic if not triangles were provided or a triangle centroid coordinate comparison fails.
216    pub fn build_bvh(triangles: Vec<Triangle>) -> BVHNode {
217        assert!(!triangles.is_empty(), "No triangles provided");
218        if triangles.len() == 1 {
219            return BVHNode {
220                aabb: Aabb::from_triangle(&triangles[0]),
221                triangle: Some(triangles[0].clone()),
222                left: None,
223                right: None,
224            };
225        }
226        // Compute centroids.
227        let centroids: Vec<Position> = triangles.iter().map(Triangle::get_centroid).collect();
228
229        // Find spread along each axis.
230        let (min_x, max_x) = centroids
231            .iter()
232            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
233                (min.min(c.x), max.max(c.x))
234            });
235        let (min_y, max_y) = centroids
236            .iter()
237            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
238                (min.min(c.y), max.max(c.y))
239            });
240        let (min_z, max_z) = centroids
241            .iter()
242            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
243                (min.min(c.z), max.max(c.z))
244            });
245        let x_spread = max_x - min_x;
246        let y_spread = max_y - min_y;
247        let z_spread = max_z - min_z;
248
249        // Choose split axis: 0 = x, 1 = y, 2 = z.
250        let axis = if x_spread >= y_spread && x_spread >= z_spread {
251            0
252        } else if y_spread >= z_spread {
253            1
254        } else {
255            2
256        };
257
258        // Sort triangles based on centroid coordinate.
259        let mut triangles_sorted = triangles;
260        triangles_sorted.sort_by(|a, b| {
261            let ca = a.get_centroid();
262            let cb = b.get_centroid();
263            let coord_a = if axis == 0 {
264                ca.x
265            } else if axis == 1 {
266                ca.y
267            } else {
268                ca.z
269            };
270            let coord_b = if axis == 0 {
271                cb.x
272            } else if axis == 1 {
273                cb.y
274            } else {
275                cb.z
276            };
277            coord_a.partial_cmp(&coord_b).unwrap()
278        });
279
280        let mid = triangles_sorted.len() / 2;
281        let left = Self::build_bvh(triangles_sorted[..mid].to_vec());
282        let right = Self::build_bvh(triangles_sorted[mid..].to_vec());
283
284        // Create encompassing AABB from children.
285        let min_point = Position::new(
286            left.aabb.min_point.x.min(right.aabb.min_point.x),
287            left.aabb.min_point.y.min(right.aabb.min_point.y),
288            left.aabb.min_point.z.min(right.aabb.min_point.z),
289        );
290        let max_point = Position::new(
291            left.aabb.max_point.x.max(right.aabb.max_point.x),
292            left.aabb.max_point.y.max(right.aabb.max_point.y),
293            left.aabb.max_point.z.max(right.aabb.max_point.z),
294        );
295
296        BVHNode {
297            aabb: Aabb {
298                min_point,
299                max_point,
300            },
301            triangle: None,
302            left: Some(Box::new(left)),
303            right: Some(Box::new(right)),
304        }
305    }
306
307    /// Traverse the BVH tree to check for ray intersections.
308    fn traverse_bvh(
309        node: &BVHNode,
310        ray_origin: &Position,
311        ray_direction: &Position,
312        max_distance: f64,
313    ) -> bool {
314        if !node.aabb.intersects_ray(ray_origin, ray_direction) {
315            return false;
316        }
317
318        if let Some(ref tri) = node.triangle {
319            if let Some(t) = Self::ray_triangle_intersection(ray_origin, ray_direction, tri) {
320                return t <= max_distance;
321            }
322            return false;
323        }
324
325        let left_hit = Self::traverse_bvh(
326            node.left.as_ref().unwrap(),
327            ray_origin,
328            ray_direction,
329            max_distance,
330        );
331        let right_hit = Self::traverse_bvh(
332            node.right.as_ref().unwrap(),
333            ray_origin,
334            ray_direction,
335            max_distance,
336        );
337        left_hit || right_hit
338    }
339
340    /// Save the loaded collision checker with the BVH to a file.
341    ///
342    /// # Panics
343    ///
344    /// Will panic if the file cannot be created or written to.
345    pub fn save_to_binary(&self, filename: &Path) {
346        let mut file = create_file_with_parents(filename);
347        serialize_into(&mut file, &self).unwrap();
348    }
349
350    /// Load a struct instance from a JSON file
351    /// # Panics
352    ///
353    /// Will panic if the file cannot be read or deserialized.
354    #[must_use]
355    pub fn from_binary(filename: &Path) -> Self {
356        let mut file = File::open(filename).unwrap();
357        deserialize_from(&mut file).unwrap()
358    }
359}
360
361#[pymethods]
362impl CollisionChecker {
363    /// Construct a new `CollisionChecker` from a file of triangles or an existing list.
364    ///
365    /// # Errors
366    ///
367    /// Will return an error if both or neither of `tri_file` and `triangles` are provided.
368    #[new]
369    #[pyo3(signature = (path=None, triangles=None))]
370    pub fn py_new(path: Option<PathBuf>, triangles: Option<Vec<Triangle>>) -> PyResult<Self> {
371        let triangles = match (path, triangles) {
372            (Some(tri_file), None) => Self::read_tri_file(tri_file, 1000),
373            (None, Some(triangles)) => triangles,
374            _ => {
375                return Err(PyValueError::new_err(
376                    "Exactly one of tri_file or triangles must be provided",
377                ));
378            }
379        };
380
381        let n_triangles = triangles.len();
382        if n_triangles == 0 {
383            return Err(PyValueError::new_err("No triangles provided"));
384        }
385        let root = Self::build_bvh(triangles);
386        Ok(Self { n_triangles, root })
387    }
388
389    /// Check if the line segment between start and end is visible.
390    /// Returns true if no triangle obstructs the view.
391    #[must_use]
392    #[pyo3(name = "is_visible")]
393    pub fn connection_unobstructed(&self, start: Position, end: Position) -> bool {
394        let mut direction = end - start;
395        let distance = direction.length();
396        if distance < 1e-6 {
397            return true;
398        }
399        direction = direction.normalize();
400        // If any intersection is found along the ray, then the segment is not visible.
401        !Self::traverse_bvh(&self.root, &start, &direction, distance)
402    }
403
404    #[must_use]
405    pub fn __repr__(&self) -> String {
406        format!("VisibilityChecker(n_triangles={})", self.n_triangles)
407    }
408
409    #[must_use]
410    #[allow(clippy::needless_pass_by_value)]
411    #[staticmethod]
412    #[pyo3(name = "read_tri_file")]
413    #[pyo3(signature = (tri_file, buffer_size=1000))]
414    fn py_read_tri_file(tri_file: PathBuf, buffer_size: usize) -> Vec<Triangle> {
415        Self::read_tri_file(tri_file, buffer_size)
416    }
417
418    /// Check for ray-triangle intersection.
419    /// Returns Some(distance) if intersecting; otherwise None.
420    #[must_use]
421    #[staticmethod]
422    #[pyo3(name = "_ray_triangle_intersection")]
423    pub fn ray_triangle_intersection(
424        ray_origin: &Position,
425        ray_direction: &Position,
426        triangle: &Triangle,
427    ) -> Option<f64> {
428        let epsilon = 1e-6;
429        let edge1 = triangle.p2 - triangle.p1;
430        let edge2 = triangle.p3 - triangle.p1;
431        let h = ray_direction.cross(&edge2);
432        let a = edge1.dot(&h);
433
434        if a.abs() < epsilon {
435            return None;
436        }
437
438        let f = 1.0 / a;
439        let s = *ray_origin - triangle.p1;
440        let u = f * s.dot(&h);
441        if !(0.0..=1.0).contains(&u) {
442            return None;
443        }
444
445        let q = s.cross(&edge1);
446        let v = f * ray_direction.dot(&q);
447        if v < 0.0 || (u + v) > 1.0 {
448            return None;
449        }
450
451        let t = f * edge2.dot(&q);
452        if t > epsilon { Some(t) } else { None }
453    }
454}
455
456impl std::fmt::Display for CollisionChecker {
457    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458        write!(f, "CollisionChecker(n_triangles={})", self.n_triangles)
459    }
460}
461
/// Selects which triangle file `load_collision_checker` uses for a map.
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub enum CollisionCheckerStyle {
    /// Uses the map's `{map_name}.tri` file (no file-name postfix).
    Visibility,
    /// Uses the map's `{map_name}-clippings.tri` file.
    Walkability,
}
467
468/// Load a visibility checker from a pickle file if available; otherwise build from a .tri file.
469/// # Panics
470///
471/// Will panic if the bath path for the tri or vis file cannot be constructed.
472/// Is "`CURRENT_FILE_PATH`/../../"
473#[must_use]
474pub fn load_collision_checker(map_name: &str, style: CollisionCheckerStyle) -> CollisionChecker {
475    let postfix = match style {
476        CollisionCheckerStyle::Visibility => "",
477        CollisionCheckerStyle::Walkability => "-clippings",
478    };
479    let current_file = PathBuf::from(file!());
480    let base = current_file
481        .parent()
482        .expect("No parent found")
483        .parent()
484        .unwrap();
485    let tri_path = base.join("tri").join(format!("{map_name}{postfix}.tri"));
486    let mut binary_path = tri_path.clone();
487    binary_path.set_extension("vis");
488
489    if binary_path.exists() {
490        println!(
491            "Loading collision checker with style {style:?} from binary: {}",
492            binary_path.file_stem().unwrap().to_string_lossy()
493        );
494        return CollisionChecker::from_binary(&binary_path);
495    }
496    println!("{tri_path:?}");
497    println!(
498        "Building collision checker with style {style:?} from tri: {}",
499        tri_path.file_stem().unwrap().to_string_lossy()
500    );
501    let vis_checker = CollisionChecker::new(&tri_path);
502    vis_checker.save_to_binary(&binary_path);
503    vis_checker
504}