cs2_nav/
collisions.rs

//! Module for ray collision detection using a Bounding Volume Hierarchy tree.
//!
//! Ported from: <https://github.com/pnxenopoulos/awpy/blob/main/awpy/visibility.py>
4use crate::position::{Position, PositionFromInputOptions};
5use crate::utils::create_file_with_parents;
6
7use bincode::{deserialize_from, serialize_into};
8use pyo3::exceptions::PyValueError;
9use pyo3::{PyResult, pyclass, pymethods};
10use serde::{Deserialize, Serialize};
11use std::fs;
12use std::fs::File;
13use std::io::Read;
14use std::path::{Path, PathBuf};
15
/// A triangle in 3D space used for ray intersection checks.
///
/// Exposed to Python as `Triangle` in the `cs2_nav` module. The corners are stored as given; no
/// winding order is enforced, and `ray_intersection` accepts hits from either side of the plane.
#[pyclass(module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Triangle {
    /// First corner; also the base point both edges are measured from in `ray_intersection`.
    pub p1: Position,
    /// Second corner.
    pub p2: Position,
    /// Third corner.
    pub p3: Position,
}
24
25#[pymethods]
26impl Triangle {
27    #[new]
28    #[must_use]
29    pub const fn new(p1: Position, p2: Position, p3: Position) -> Self {
30        Self { p1, p2, p3 }
31    }
32
33    #[must_use]
34    pub fn get_centroid(&self) -> Position {
35        Position::new(
36            (self.p1.x + self.p2.x + self.p3.x) / 3.0,
37            (self.p1.y + self.p2.y + self.p3.y) / 3.0,
38            (self.p1.z + self.p2.z + self.p3.z) / 3.0,
39        )
40    }
41
42    /// Check for ray-triangle intersection.
43    /// Returns Some(distance) if intersecting; otherwise None.
44    #[must_use]
45    pub fn ray_intersection(&self, ray_origin: &Position, ray_direction: &Position) -> Option<f64> {
46        let epsilon = 1e-6;
47        let edge1 = self.p2 - self.p1;
48        let edge2 = self.p3 - self.p1;
49        let h = ray_direction.cross(&edge2);
50        let a = edge1.dot(&h);
51
52        if a.abs() < epsilon {
53            return None;
54        }
55
56        let f = 1.0 / a;
57        let s = *ray_origin - self.p1;
58        let u = f * s.dot(&h);
59        if !(0.0..=1.0).contains(&u) {
60            return None;
61        }
62
63        let q = s.cross(&edge1);
64        let v = f * ray_direction.dot(&q);
65        if v < 0.0 || (u + v) > 1.0 {
66            return None;
67        }
68
69        let t = f * edge2.dot(&q);
70        if t > epsilon { Some(t) } else { None }
71    }
72}
73
// ---------- Edge ----------
/// Index-based edge record with `next`, `twin`, `origin` and `face` links.
///
/// NOTE(review): not referenced anywhere in this file. The field names suggest a half-edge mesh,
/// with indices into sibling edge, vertex and face arrays. Whether negative values mean
/// "no link" is not visible here; confirm against the code that produces these records.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Edge {
    /// Index of the next edge (presumably around the same face).
    pub next: i32,
    /// Index of the twin edge (presumably the oppositely oriented one).
    pub twin: i32,
    /// Index of the edge's origin vertex (presumably).
    pub origin: i32,
    /// Index of the face this edge belongs to (presumably).
    pub face: i32,
}
82
/// Axis-Aligned Bounding Box for efficient collision detection.
///
/// Expected to satisfy `min_point <= max_point` component-wise. `Aabb::from_triangle` and the
/// child-merging step of the BVH builder both produce boxes that satisfy this.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Aabb {
    /// Corner with the smallest x, y and z coordinates.
    pub min_point: Position,
    /// Corner with the largest x, y and z coordinates.
    pub max_point: Position,
}
89
/// Entry and exit ray parameters for a single axis-aligned slab `[min_val, max_val]`.
///
/// For a non-parallel ray, returns the two parameters `t` at which `origin + t * direction`
/// crosses the slab boundaries, smaller one first. When `|direction| < epsilon`, the ray is
/// treated as parallel to the slab. The result is then the whole real line
/// `(NEG_INFINITY, INFINITY)` if `origin` lies inside the slab. Otherwise it is the empty
/// interval `(INFINITY, NEG_INFINITY)`.
fn check_axis(origin: f64, direction: f64, min_val: f64, max_val: f64, epsilon: f64) -> (f64, f64) {
    let parallel = direction.abs() < epsilon;
    if !parallel {
        let t_at_min = (min_val - origin) / direction;
        let t_at_max = (max_val - origin) / direction;
        return (t_at_min.min(t_at_max), t_at_min.max(t_at_max));
    }

    let outside_slab = origin < min_val || origin > max_val;
    if outside_slab {
        (f64::INFINITY, f64::NEG_INFINITY)
    } else {
        (f64::NEG_INFINITY, f64::INFINITY)
    }
}
101
102impl Aabb {
103    #[must_use]
104    pub const fn from_triangle(triangle: &Triangle) -> Self {
105        let min_point = Position::new(
106            triangle.p1.x.min(triangle.p2.x).min(triangle.p3.x),
107            triangle.p1.y.min(triangle.p2.y).min(triangle.p3.y),
108            triangle.p1.z.min(triangle.p2.z).min(triangle.p3.z),
109        );
110        let max_point = Position::new(
111            triangle.p1.x.max(triangle.p2.x).max(triangle.p3.x),
112            triangle.p1.y.max(triangle.p2.y).max(triangle.p3.y),
113            triangle.p1.z.max(triangle.p2.z).max(triangle.p3.z),
114        );
115        Self {
116            min_point,
117            max_point,
118        }
119    }
120
121    #[must_use]
122    pub fn intersects_ray(&self, ray_origin: &Position, ray_direction: &Position) -> bool {
123        let epsilon = 1e-6;
124
125        let (tx_min, tx_max) = check_axis(
126            ray_origin.x,
127            ray_direction.x,
128            self.min_point.x,
129            self.max_point.x,
130            epsilon,
131        );
132        let (ty_min, ty_max) = check_axis(
133            ray_origin.y,
134            ray_direction.y,
135            self.min_point.y,
136            self.max_point.y,
137            epsilon,
138        );
139        let (tz_min, tz_max) = check_axis(
140            ray_origin.z,
141            ray_direction.z,
142            self.min_point.z,
143            self.max_point.z,
144            epsilon,
145        );
146
147        let t_enter = tx_min.max(ty_min).max(tz_min);
148        let t_exit = tx_max.min(ty_max).min(tz_max);
149
150        t_enter <= t_exit && t_exit >= 0.0
151    }
152}
153
154impl std::fmt::Display for Aabb {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(
157            f,
158            "AABB(min_point={:?}, max_point={:?})",
159            self.min_point, self.max_point
160        )
161    }
162}
163
/// Node in the Bounding Volume Hierarchy tree.
///
/// `CollisionChecker::build_bvh` constructs nodes so that `aabb` encloses every triangle stored
/// in the subtree rooted at this node. Traversal relies on this to skip whole subtrees.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BVHNode {
    /// Bounding box of all triangles below this node.
    pub aabb: Aabb,
    /// Whether this node is a leaf (one triangle) or has two children.
    pub kind: BVHNodeKind,
}
170
/// Kind of BVH node: either a leaf with a triangle, or an internal node with two children.
///
/// `CollisionChecker::build_bvh` only creates leaves for single triangles and gives each
/// internal node two non-empty children.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum BVHNodeKind {
    /// Terminal node holding exactly one triangle.
    Leaf {
        triangle: Triangle,
    },
    /// Node splitting its triangles between two subtrees.
    Internal {
        left: Box<BVHNode>,
        right: Box<BVHNode>,
    },
}
182
/// Collision checker using a Bounding Volume Hierarchy tree.
///
/// Exposed to Python as `VisibilityChecker` in the `cs2_nav` module. The whole struct is
/// serialized with bincode to the `.vis` cache file used by `load_collision_checker`.
#[pyclass(name = "VisibilityChecker", module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CollisionChecker {
    /// Number of triangles the BVH was built from (read-only attribute in Python).
    #[pyo3(get)]
    pub n_triangles: usize,
    /// Root of the BVH containing all triangles.
    pub root: BVHNode,
}
191
192impl CollisionChecker {
193    /// Construct a new `CollisionChecker` from a file of triangles or an existing list.
194    #[must_use]
195    pub fn new(tri_file: &Path) -> Self {
196        let triangles = Self::read_tri_file(tri_file, 1000);
197
198        let n_triangles = triangles.len();
199        let root = Self::build_bvh(triangles);
200        Self { n_triangles, root }
201    }
202
203    /// Read a .tri file containing triangles.
204    ///
205    /// From <https://github.com/pnxenopoulos/awpy/blob/main/awpy/visibility.py#L757>
206    /// # Panics
207    ///
208    /// Will panic if no file exists at the given path or if the file cannot be read.
209    pub fn read_tri_file<P: AsRef<Path>>(tri_file: P, buffer_size: usize) -> Vec<Triangle> {
210        // 9 f32 values per triangle, each f32 is 4 bytes.
211        let chunk_size: usize = buffer_size * 9 * 4;
212        let mut triangles = Vec::new();
213        let mut file = fs::File::open(tri_file).expect("Unable to open tri file");
214        let mut buffer = vec![0u8; chunk_size].into_boxed_slice();
215
216        loop {
217            let n = file.read(&mut buffer).expect("Failed to read file");
218            if n == 0 {
219                break;
220            }
221            // number of complete triangles in the buffer.
222            let num_complete_triangles = n / 36;
223            for i in 0..num_complete_triangles {
224                let offset = i * 36;
225                let slice = &buffer[offset..offset + 36];
226                let mut values = [0f32; 9];
227                for (i, chunk) in slice.chunks_exact(4).enumerate() {
228                    values[i] = f32::from_ne_bytes(chunk.try_into().unwrap());
229                }
230                triangles.push(Triangle {
231                    p1: Position::new(
232                        f64::from(values[0]),
233                        f64::from(values[1]),
234                        f64::from(values[2]),
235                    ),
236                    p2: Position::new(
237                        f64::from(values[3]),
238                        f64::from(values[4]),
239                        f64::from(values[5]),
240                    ),
241                    p3: Position::new(
242                        f64::from(values[6]),
243                        f64::from(values[7]),
244                        f64::from(values[8]),
245                    ),
246                });
247            }
248        }
249        triangles
250    }
251
252    /// Build a Bounding Volume Hierarchy tree from a list of triangles.
253    ///
254    /// # Panics
255    ///
256    /// Will panic if not triangles were provided or a triangle centroid coordinate comparison fails.
257    pub fn build_bvh(triangles: Vec<Triangle>) -> BVHNode {
258        assert!(!triangles.is_empty(), "No triangles provided");
259        if triangles.len() == 1 {
260            return BVHNode {
261                aabb: Aabb::from_triangle(&triangles[0]),
262                kind: BVHNodeKind::Leaf {
263                    triangle: triangles[0].clone(),
264                },
265            };
266        }
267        // Compute centroids.
268        let centroids: Vec<Position> = triangles.iter().map(Triangle::get_centroid).collect();
269
270        // Find spread along each axis.
271        let (min_x, max_x) = centroids
272            .iter()
273            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
274                (min.min(c.x), max.max(c.x))
275            });
276        let (min_y, max_y) = centroids
277            .iter()
278            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
279                (min.min(c.y), max.max(c.y))
280            });
281        let (min_z, max_z) = centroids
282            .iter()
283            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
284                (min.min(c.z), max.max(c.z))
285            });
286        let x_spread = max_x - min_x;
287        let y_spread = max_y - min_y;
288        let z_spread = max_z - min_z;
289
290        // Choose split axis: 0 = x, 1 = y, 2 = z.
291        let axis = if x_spread >= y_spread && x_spread >= z_spread {
292            0
293        } else if y_spread >= z_spread {
294            1
295        } else {
296            2
297        };
298
299        // Sort triangles based on centroid coordinate.
300        let mut triangles_sorted = triangles;
301        triangles_sorted.sort_by(|a, b| {
302            let ca = a.get_centroid();
303            let cb = b.get_centroid();
304            let coord_a = if axis == 0 {
305                ca.x
306            } else if axis == 1 {
307                ca.y
308            } else {
309                ca.z
310            };
311            let coord_b = if axis == 0 {
312                cb.x
313            } else if axis == 1 {
314                cb.y
315            } else {
316                cb.z
317            };
318            coord_a.partial_cmp(&coord_b).unwrap()
319        });
320
321        let mid = triangles_sorted.len() / 2;
322        let left = Self::build_bvh(triangles_sorted[..mid].to_vec());
323        let right = Self::build_bvh(triangles_sorted[mid..].to_vec());
324
325        // Create encompassing AABB from children.
326        let min_point = Position::new(
327            left.aabb.min_point.x.min(right.aabb.min_point.x),
328            left.aabb.min_point.y.min(right.aabb.min_point.y),
329            left.aabb.min_point.z.min(right.aabb.min_point.z),
330        );
331        let max_point = Position::new(
332            left.aabb.max_point.x.max(right.aabb.max_point.x),
333            left.aabb.max_point.y.max(right.aabb.max_point.y),
334            left.aabb.max_point.z.max(right.aabb.max_point.z),
335        );
336
337        BVHNode {
338            aabb: Aabb {
339                min_point,
340                max_point,
341            },
342            kind: BVHNodeKind::Internal {
343                left: Box::new(left),
344                right: Box::new(right),
345            },
346        }
347    }
348
349    /// Traverse the BVH tree to check for ray intersections.
350    fn traverse_bvh(
351        node: &BVHNode,
352        ray_origin: &Position,
353        ray_direction: &Position,
354        max_distance: f64,
355    ) -> bool {
356        if !node.aabb.intersects_ray(ray_origin, ray_direction) {
357            return false;
358        }
359
360        match node.kind {
361            BVHNodeKind::Internal {
362                ref left,
363                ref right,
364            } => {
365                Self::traverse_bvh(left, ray_origin, ray_direction, max_distance)
366                    || Self::traverse_bvh(right, ray_origin, ray_direction, max_distance)
367            }
368            BVHNodeKind::Leaf { ref triangle } => triangle
369                .ray_intersection(ray_origin, ray_direction)
370                .is_some_and(|t| t <= max_distance),
371        }
372    }
373
374    /// Save the loaded collision checker with the BVH to a file.
375    ///
376    /// # Panics
377    ///
378    /// Will panic if the file cannot be created or written to.
379    pub fn save_to_binary(&self, filename: &Path) {
380        let mut file = create_file_with_parents(filename);
381        serialize_into(&mut file, &self).unwrap();
382    }
383
384    /// Load a struct instance from a JSON file
385    /// # Panics
386    ///
387    /// Will panic if the file cannot be read or deserialized.
388    #[must_use]
389    pub fn from_binary(filename: &Path) -> Self {
390        let mut file = File::open(filename).unwrap();
391        deserialize_from(&mut file).unwrap()
392    }
393
394    /// Check if the line segment between start and end is visible.
395    /// Returns true if no triangle obstructs the view.
396    #[must_use]
397    pub fn connection_unobstructed(&self, start: Position, end: Position) -> bool {
398        let mut direction = end - start;
399        let distance = direction.length();
400        if distance < 1e-6 {
401            return true;
402        }
403        direction = direction.normalize();
404        // If any intersection is found along the ray, then the segment is not visible.
405        !Self::traverse_bvh(&self.root, &start, &direction, distance)
406    }
407}
408
409#[pymethods]
410impl CollisionChecker {
411    /// Construct a new `CollisionChecker` from a file of triangles or an existing list.
412    ///
413    /// # Errors
414    ///
415    /// Will return an error if both or neither of `tri_file` and `triangles` are provided.
416    #[new]
417    #[pyo3(signature = (path=None, triangles=None))]
418    fn py_new(path: Option<PathBuf>, triangles: Option<Vec<Triangle>>) -> PyResult<Self> {
419        let triangles = match (path, triangles) {
420            (Some(tri_file), None) => Self::read_tri_file(tri_file, 1000),
421            (None, Some(triangles)) => triangles,
422            _ => {
423                return Err(PyValueError::new_err(
424                    "Exactly one of tri_file or triangles must be provided",
425                ));
426            }
427        };
428
429        let n_triangles = triangles.len();
430        if n_triangles == 0 {
431            return Err(PyValueError::new_err("No triangles provided"));
432        }
433        let root = Self::build_bvh(triangles);
434        Ok(Self { n_triangles, root })
435    }
436
437    fn is_visible(
438        &self,
439        start: PositionFromInputOptions,
440        end: PositionFromInputOptions,
441    ) -> PyResult<bool> {
442        Ok(self.connection_unobstructed(Position::from_input(start)?, Position::from_input(end)?))
443    }
444
445    #[must_use]
446    fn __repr__(&self) -> String {
447        format!("VisibilityChecker(n_triangles={})", self.n_triangles)
448    }
449
450    #[must_use]
451    #[allow(clippy::needless_pass_by_value)]
452    #[staticmethod]
453    #[pyo3(name = "read_tri_file")]
454    #[pyo3(signature = (tri_file, buffer_size=1000))]
455    fn py_read_tri_file(tri_file: PathBuf, buffer_size: usize) -> Vec<Triangle> {
456        Self::read_tri_file(tri_file, buffer_size)
457    }
458}
459
460impl std::fmt::Display for CollisionChecker {
461    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462        write!(f, "CollisionChecker(n_triangles={})", self.n_triangles)
463    }
464}
465
/// Which triangle set `load_collision_checker` builds a checker from.
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub enum CollisionCheckerStyle {
    /// Uses `<map_name>.tri` (cached as `<map_name>.vis`).
    Visibility,
    /// Uses `<map_name>-clippings.tri` (cached as `<map_name>-clippings.vis`).
    Walkability,
}
471
472/// Load a visibility checker from a pickle file if available; otherwise build from a .tri file.
473/// # Panics
474///
475/// Will panic if the bath path for the tri or vis file cannot be constructed.
476/// Is "`CURRENT_FILE_PATH`/../../"
477#[must_use]
478pub fn load_collision_checker(map_name: &str, style: CollisionCheckerStyle) -> CollisionChecker {
479    let postfix = match style {
480        CollisionCheckerStyle::Visibility => "",
481        CollisionCheckerStyle::Walkability => "-clippings",
482    };
483    let current_file = PathBuf::from(file!());
484    let base = current_file
485        .parent()
486        .expect("No parent found")
487        .parent()
488        .unwrap();
489    let tri_path = base.join("tri").join(format!("{map_name}{postfix}.tri"));
490    let mut binary_path = tri_path.clone();
491    binary_path.set_extension("vis");
492
493    if binary_path.exists() {
494        println!(
495            "Loading collision checker with style {style:?} from binary: {}",
496            binary_path.file_stem().unwrap().to_string_lossy()
497        );
498        return CollisionChecker::from_binary(&binary_path);
499    }
500    println!("{tri_path:?}");
501    println!(
502        "Building collision checker with style {style:?} from tri: {}",
503        tri_path.file_stem().unwrap().to_string_lossy()
504    );
505    let vis_checker = CollisionChecker::new(&tri_path);
506    vis_checker.save_to_binary(&binary_path);
507    vis_checker
508}