1use crate::position::Position;
5use crate::utils::create_file_with_parents;
6
7use bincode::{deserialize_from, serialize_into};
8use pyo3::exceptions::PyValueError;
9use pyo3::{PyResult, pyclass, pymethods};
10use serde::{Deserialize, Serialize};
11use std::fs;
12use std::fs::File;
13use std::io::Read;
14use std::path::{Path, PathBuf};
15
/// A triangle in 3D space given by its three vertices.
///
/// Declared as a Python class (`#[pyclass]`) in module `cs2_nav`. It is also stored
/// inside the BVH that `CollisionChecker` serializes with bincode, so the field order
/// below is part of that binary format.
#[pyclass(module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Triangle {
    /// First vertex.
    pub p1: Position,
    /// Second vertex.
    pub p2: Position,
    /// Third vertex.
    pub p3: Position,
}
24
25#[pymethods]
26impl Triangle {
27 #[new]
28 #[must_use]
29 pub const fn new(p1: Position, p2: Position, p3: Position) -> Self {
30 Self { p1, p2, p3 }
31 }
32
33 #[must_use]
34 pub fn get_centroid(&self) -> Position {
35 Position::new(
36 (self.p1.x + self.p2.x + self.p3.x) / 3.0,
37 (self.p1.y + self.p2.y + self.p3.y) / 3.0,
38 (self.p1.z + self.p2.z + self.p3.z) / 3.0,
39 )
40 }
41}
42
/// An edge record made of four integer links.
///
/// NOTE(review): the field names match a half-edge (DCEL) layout, with indices into
/// edge/vertex/face tables. This type is not used anywhere in this file, so confirm the
/// meaning (including whether negative values act as "no link") against its users.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Edge {
    /// Link to the next edge (presumably an index — confirm).
    pub next: i32,
    /// Link to the opposite/twin edge (presumably an index — confirm).
    pub twin: i32,
    /// Link to the origin vertex (presumably an index — confirm).
    pub origin: i32,
    /// Link to the adjacent face (presumably an index — confirm).
    pub face: i32,
}
51
/// Axis-aligned bounding box described by its two extreme corners.
///
/// Serialized as part of the BVH, so the field order is part of the bincode format.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Aabb {
    /// Corner holding the smallest x, y and z coordinates.
    pub min_point: Position,
    /// Corner holding the largest x, y and z coordinates.
    pub max_point: Position,
}
58
/// Clips a ray against a single axis-aligned slab.
///
/// Returns the interval `(t_near, t_far)` of ray parameters `t` for which
/// `origin + t * direction` lies between `min_val` and `max_val` on this axis.
/// When `|direction| < epsilon`, the ray is treated as parallel to the slab. In that case
/// the interval is unbounded (`(-inf, +inf)`) unless `origin` is strictly outside
/// `[min_val, max_val]`, and then it is empty (`(+inf, -inf)`).
fn check_axis(origin: f64, direction: f64, min_val: f64, max_val: f64, epsilon: f64) -> (f64, f64) {
    let nearly_parallel = direction.abs() < epsilon;
    if nearly_parallel {
        // The ray never crosses either slab plane: it is inside or outside for every t.
        let outside = origin < min_val || origin > max_val;
        if outside {
            (f64::INFINITY, f64::NEG_INFINITY)
        } else {
            (f64::NEG_INFINITY, f64::INFINITY)
        }
    } else {
        let t_at_min = (min_val - origin) / direction;
        let t_at_max = (max_val - origin) / direction;
        // With a negative direction the `max_val` plane is reached first, so order the pair.
        (t_at_min.min(t_at_max), t_at_min.max(t_at_max))
    }
}
70
71impl Aabb {
72 #[must_use]
73 pub const fn from_triangle(triangle: &Triangle) -> Self {
74 let min_point = Position::new(
75 triangle.p1.x.min(triangle.p2.x).min(triangle.p3.x),
76 triangle.p1.y.min(triangle.p2.y).min(triangle.p3.y),
77 triangle.p1.z.min(triangle.p2.z).min(triangle.p3.z),
78 );
79 let max_point = Position::new(
80 triangle.p1.x.max(triangle.p2.x).max(triangle.p3.x),
81 triangle.p1.y.max(triangle.p2.y).max(triangle.p3.y),
82 triangle.p1.z.max(triangle.p2.z).max(triangle.p3.z),
83 );
84 Self {
85 min_point,
86 max_point,
87 }
88 }
89
90 #[must_use]
91 pub fn intersects_ray(&self, ray_origin: &Position, ray_direction: &Position) -> bool {
92 let epsilon = 1e-6;
93
94 let (tx_min, tx_max) = check_axis(
95 ray_origin.x,
96 ray_direction.x,
97 self.min_point.x,
98 self.max_point.x,
99 epsilon,
100 );
101 let (ty_min, ty_max) = check_axis(
102 ray_origin.y,
103 ray_direction.y,
104 self.min_point.y,
105 self.max_point.y,
106 epsilon,
107 );
108 let (tz_min, tz_max) = check_axis(
109 ray_origin.z,
110 ray_direction.z,
111 self.min_point.z,
112 self.max_point.z,
113 epsilon,
114 );
115
116 let t_enter = tx_min.max(ty_min).max(tz_min);
117 let t_exit = tx_max.min(ty_max).min(tz_max);
118
119 t_enter <= t_exit && t_exit >= 0.0
120 }
121}
122
123impl std::fmt::Display for Aabb {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 write!(
126 f,
127 "AABB(min_point={:?}, max_point={:?})",
128 self.min_point, self.max_point
129 )
130 }
131}
132
/// A node of the bounding volume hierarchy built by `CollisionChecker::build_bvh`.
///
/// As built there, a leaf has `triangle: Some(..)` and no children. An internal node has
/// `triangle: None` and both children. In every node, `aabb` encloses all triangles in
/// that subtree. Serialized with the checker, so the field order is part of the bincode
/// format.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BVHNode {
    /// Bounding box of every triangle in this subtree.
    pub aabb: Aabb,
    /// The single triangle stored in a leaf; `None` for internal nodes.
    pub triangle: Option<Triangle>,
    /// Child holding the lower half of the centroid-sorted split; `None` for leaves.
    pub left: Option<Box<BVHNode>>,
    /// Child holding the upper half of the centroid-sorted split; `None` for leaves.
    pub right: Option<Box<BVHNode>>,
}
141
/// Ray/segment obstruction checker backed by a BVH over a set of triangles.
///
/// Declared to Python as `VisibilityChecker` in module `cs2_nav`. The whole struct is
/// written and read with bincode (`save_to_binary` / `from_binary`), so the field order
/// is part of the on-disk format.
#[pyclass(name = "VisibilityChecker", module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CollisionChecker {
    /// Number of triangles the hierarchy was built from (readable from Python).
    #[pyo3(get)]
    pub n_triangles: usize,
    /// Root node of the bounding volume hierarchy.
    pub root: BVHNode,
}
150
151impl CollisionChecker {
152 #[must_use]
154 pub fn new(tri_file: &Path) -> Self {
155 let triangles = Self::read_tri_file(tri_file, 1000);
156
157 let n_triangles = triangles.len();
158 let root = Self::build_bvh(triangles);
159 Self { n_triangles, root }
160 }
161
162 pub fn read_tri_file<P: AsRef<Path>>(tri_file: P, buffer_size: usize) -> Vec<Triangle> {
169 let chunk_size: usize = buffer_size * 9 * 4;
171 let mut triangles = Vec::new();
172 let mut file = fs::File::open(tri_file).expect("Unable to open tri file");
173 let mut buffer = vec![0u8; chunk_size].into_boxed_slice();
174
175 loop {
176 let n = file.read(&mut buffer).expect("Failed to read file");
177 if n == 0 {
178 break;
179 }
180 let num_complete_triangles = n / 36;
182 for i in 0..num_complete_triangles {
183 let offset = i * 36;
184 let slice = &buffer[offset..offset + 36];
185 let mut values = [0f32; 9];
186 for (i, chunk) in slice.chunks_exact(4).enumerate() {
187 values[i] = f32::from_ne_bytes(chunk.try_into().unwrap());
188 }
189 triangles.push(Triangle {
190 p1: Position::new(
191 f64::from(values[0]),
192 f64::from(values[1]),
193 f64::from(values[2]),
194 ),
195 p2: Position::new(
196 f64::from(values[3]),
197 f64::from(values[4]),
198 f64::from(values[5]),
199 ),
200 p3: Position::new(
201 f64::from(values[6]),
202 f64::from(values[7]),
203 f64::from(values[8]),
204 ),
205 });
206 }
207 }
208 triangles
209 }
210
211 pub fn build_bvh(triangles: Vec<Triangle>) -> BVHNode {
217 assert!(!triangles.is_empty(), "No triangles provided");
218 if triangles.len() == 1 {
219 return BVHNode {
220 aabb: Aabb::from_triangle(&triangles[0]),
221 triangle: Some(triangles[0].clone()),
222 left: None,
223 right: None,
224 };
225 }
226 let centroids: Vec<Position> = triangles.iter().map(Triangle::get_centroid).collect();
228
229 let (min_x, max_x) = centroids
231 .iter()
232 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
233 (min.min(c.x), max.max(c.x))
234 });
235 let (min_y, max_y) = centroids
236 .iter()
237 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
238 (min.min(c.y), max.max(c.y))
239 });
240 let (min_z, max_z) = centroids
241 .iter()
242 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
243 (min.min(c.z), max.max(c.z))
244 });
245 let x_spread = max_x - min_x;
246 let y_spread = max_y - min_y;
247 let z_spread = max_z - min_z;
248
249 let axis = if x_spread >= y_spread && x_spread >= z_spread {
251 0
252 } else if y_spread >= z_spread {
253 1
254 } else {
255 2
256 };
257
258 let mut triangles_sorted = triangles;
260 triangles_sorted.sort_by(|a, b| {
261 let ca = a.get_centroid();
262 let cb = b.get_centroid();
263 let coord_a = if axis == 0 {
264 ca.x
265 } else if axis == 1 {
266 ca.y
267 } else {
268 ca.z
269 };
270 let coord_b = if axis == 0 {
271 cb.x
272 } else if axis == 1 {
273 cb.y
274 } else {
275 cb.z
276 };
277 coord_a.partial_cmp(&coord_b).unwrap()
278 });
279
280 let mid = triangles_sorted.len() / 2;
281 let left = Self::build_bvh(triangles_sorted[..mid].to_vec());
282 let right = Self::build_bvh(triangles_sorted[mid..].to_vec());
283
284 let min_point = Position::new(
286 left.aabb.min_point.x.min(right.aabb.min_point.x),
287 left.aabb.min_point.y.min(right.aabb.min_point.y),
288 left.aabb.min_point.z.min(right.aabb.min_point.z),
289 );
290 let max_point = Position::new(
291 left.aabb.max_point.x.max(right.aabb.max_point.x),
292 left.aabb.max_point.y.max(right.aabb.max_point.y),
293 left.aabb.max_point.z.max(right.aabb.max_point.z),
294 );
295
296 BVHNode {
297 aabb: Aabb {
298 min_point,
299 max_point,
300 },
301 triangle: None,
302 left: Some(Box::new(left)),
303 right: Some(Box::new(right)),
304 }
305 }
306
307 fn traverse_bvh(
309 node: &BVHNode,
310 ray_origin: &Position,
311 ray_direction: &Position,
312 max_distance: f64,
313 ) -> bool {
314 if !node.aabb.intersects_ray(ray_origin, ray_direction) {
315 return false;
316 }
317
318 if let Some(ref tri) = node.triangle {
319 if let Some(t) = Self::ray_triangle_intersection(ray_origin, ray_direction, tri) {
320 return t <= max_distance;
321 }
322 return false;
323 }
324
325 let left_hit = Self::traverse_bvh(
326 node.left.as_ref().unwrap(),
327 ray_origin,
328 ray_direction,
329 max_distance,
330 );
331 let right_hit = Self::traverse_bvh(
332 node.right.as_ref().unwrap(),
333 ray_origin,
334 ray_direction,
335 max_distance,
336 );
337 left_hit || right_hit
338 }
339
340 pub fn save_to_binary(&self, filename: &Path) {
346 let mut file = create_file_with_parents(filename);
347 serialize_into(&mut file, &self).unwrap();
348 }
349
350 #[must_use]
355 pub fn from_binary(filename: &Path) -> Self {
356 let mut file = File::open(filename).unwrap();
357 deserialize_from(&mut file).unwrap()
358 }
359}
360
361#[pymethods]
362impl CollisionChecker {
363 #[new]
369 #[pyo3(signature = (path=None, triangles=None))]
370 pub fn py_new(path: Option<PathBuf>, triangles: Option<Vec<Triangle>>) -> PyResult<Self> {
371 let triangles = match (path, triangles) {
372 (Some(tri_file), None) => Self::read_tri_file(tri_file, 1000),
373 (None, Some(triangles)) => triangles,
374 _ => {
375 return Err(PyValueError::new_err(
376 "Exactly one of tri_file or triangles must be provided",
377 ));
378 }
379 };
380
381 let n_triangles = triangles.len();
382 if n_triangles == 0 {
383 return Err(PyValueError::new_err("No triangles provided"));
384 }
385 let root = Self::build_bvh(triangles);
386 Ok(Self { n_triangles, root })
387 }
388
389 #[must_use]
392 #[pyo3(name = "is_visible")]
393 pub fn connection_unobstructed(&self, start: Position, end: Position) -> bool {
394 let mut direction = end - start;
395 let distance = direction.length();
396 if distance < 1e-6 {
397 return true;
398 }
399 direction = direction.normalize();
400 !Self::traverse_bvh(&self.root, &start, &direction, distance)
402 }
403
404 #[must_use]
405 pub fn __repr__(&self) -> String {
406 format!("VisibilityChecker(n_triangles={})", self.n_triangles)
407 }
408
409 #[must_use]
410 #[allow(clippy::needless_pass_by_value)]
411 #[staticmethod]
412 #[pyo3(name = "read_tri_file")]
413 #[pyo3(signature = (tri_file, buffer_size=1000))]
414 fn py_read_tri_file(tri_file: PathBuf, buffer_size: usize) -> Vec<Triangle> {
415 Self::read_tri_file(tri_file, buffer_size)
416 }
417
418 #[must_use]
421 #[staticmethod]
422 #[pyo3(name = "_ray_triangle_intersection")]
423 pub fn ray_triangle_intersection(
424 ray_origin: &Position,
425 ray_direction: &Position,
426 triangle: &Triangle,
427 ) -> Option<f64> {
428 let epsilon = 1e-6;
429 let edge1 = triangle.p2 - triangle.p1;
430 let edge2 = triangle.p3 - triangle.p1;
431 let h = ray_direction.cross(&edge2);
432 let a = edge1.dot(&h);
433
434 if a.abs() < epsilon {
435 return None;
436 }
437
438 let f = 1.0 / a;
439 let s = *ray_origin - triangle.p1;
440 let u = f * s.dot(&h);
441 if !(0.0..=1.0).contains(&u) {
442 return None;
443 }
444
445 let q = s.cross(&edge1);
446 let v = f * ray_direction.dot(&q);
447 if v < 0.0 || (u + v) > 1.0 {
448 return None;
449 }
450
451 let t = f * edge2.dot(&q);
452 if t > epsilon { Some(t) } else { None }
453 }
454}
455
456impl std::fmt::Display for CollisionChecker {
457 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458 write!(f, "CollisionChecker(n_triangles={})", self.n_triangles)
459 }
460}
461
/// Which triangle set `load_collision_checker` uses for a map.
///
/// The variant order is part of the serde representation (bincode encodes the variant
/// index).
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub enum CollisionCheckerStyle {
    /// Uses `<map>.tri` (cached as `<map>.vis`).
    Visibility,
    /// Uses `<map>-clippings.tri` (cached as `<map>-clippings.vis`).
    Walkability,
}
467
468#[must_use]
474pub fn load_collision_checker(map_name: &str, style: CollisionCheckerStyle) -> CollisionChecker {
475 let postfix = match style {
476 CollisionCheckerStyle::Visibility => "",
477 CollisionCheckerStyle::Walkability => "-clippings",
478 };
479 let current_file = PathBuf::from(file!());
480 let base = current_file
481 .parent()
482 .expect("No parent found")
483 .parent()
484 .unwrap();
485 let tri_path = base.join("tri").join(format!("{map_name}{postfix}.tri"));
486 let mut binary_path = tri_path.clone();
487 binary_path.set_extension("vis");
488
489 if binary_path.exists() {
490 println!(
491 "Loading collision checker with style {style:?} from binary: {}",
492 binary_path.file_stem().unwrap().to_string_lossy()
493 );
494 return CollisionChecker::from_binary(&binary_path);
495 }
496 println!("{tri_path:?}");
497 println!(
498 "Building collision checker with style {style:?} from tri: {}",
499 tri_path.file_stem().unwrap().to_string_lossy()
500 );
501 let vis_checker = CollisionChecker::new(&tri_path);
502 vis_checker.save_to_binary(&binary_path);
503 vis_checker
504}