1use crate::position::{Position, PositionFromInputOptions};
5use crate::utils::create_file_with_parents;
6
7use bincode::{deserialize_from, serialize_into};
8use pyo3::exceptions::PyValueError;
9use pyo3::{PyResult, pyclass, pymethods};
10use serde::{Deserialize, Serialize};
11use std::fs;
12use std::fs::File;
13use std::io::Read;
14use std::path::{Path, PathBuf};
15
/// A triangle in 3D space defined by its three vertices.
///
/// Exposed to Python as `cs2_nav.Triangle`. The vertex order matters to
/// [`Triangle::ray_intersection`], which builds its edge vectors as
/// `p2 - p1` and `p3 - p1`.
#[pyclass(module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Triangle {
    /// First vertex (shared origin of both edge vectors used for intersection).
    pub p1: Position,
    /// Second vertex.
    pub p2: Position,
    /// Third vertex.
    pub p3: Position,
}
24
25#[pymethods]
26impl Triangle {
27 #[new]
28 #[must_use]
29 pub const fn new(p1: Position, p2: Position, p3: Position) -> Self {
30 Self { p1, p2, p3 }
31 }
32
33 #[must_use]
34 pub fn get_centroid(&self) -> Position {
35 Position::new(
36 (self.p1.x + self.p2.x + self.p3.x) / 3.0,
37 (self.p1.y + self.p2.y + self.p3.y) / 3.0,
38 (self.p1.z + self.p2.z + self.p3.z) / 3.0,
39 )
40 }
41
42 #[must_use]
45 pub fn ray_intersection(&self, ray_origin: &Position, ray_direction: &Position) -> Option<f64> {
46 let epsilon = 1e-6;
47 let edge1 = self.p2 - self.p1;
48 let edge2 = self.p3 - self.p1;
49 let h = ray_direction.cross(&edge2);
50 let a = edge1.dot(&h);
51
52 if a.abs() < epsilon {
53 return None;
54 }
55
56 let f = 1.0 / a;
57 let s = *ray_origin - self.p1;
58 let u = f * s.dot(&h);
59 if !(0.0..=1.0).contains(&u) {
60 return None;
61 }
62
63 let q = s.cross(&edge1);
64 let v = f * ray_direction.dot(&q);
65 if v < 0.0 || (u + v) > 1.0 {
66 return None;
67 }
68
69 let t = f * edge2.dot(&q);
70 if t > epsilon { Some(t) } else { None }
71 }
72}
73
/// Mesh edge record made of four `i32` indices.
///
/// NOTE(review): the field names (`next`, `twin`, `origin`, `face`) suggest a
/// half-edge mesh representation. Nothing in this section reads these fields,
/// so the index targets (and whether negative values mean "none") are
/// unconfirmed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Edge {
    pub next: i32,
    pub twin: i32,
    pub origin: i32,
    pub face: i32,
}
82
/// Axis-aligned bounding box described by its minimum and maximum corners.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Aabb {
    /// Corner with the smallest x, y and z coordinates.
    pub min_point: Position,
    /// Corner with the largest x, y and z coordinates.
    pub max_point: Position,
}
89
/// Computes the ray-parameter interval `(t_near, t_far)` spent inside one axis slab.
///
/// The slab is the region `min_val..=max_val` along a single axis. If
/// `|direction| < epsilon`, the ray is treated as parallel to the slab, which
/// gives two cases:
/// - the origin lies inside the slab: the interval is unbounded,
///   `(-inf, +inf)`;
/// - the origin lies outside: the interval is empty, `(+inf, -inf)`.
fn check_axis(
    origin: f64,
    direction: f64,
    min_val: f64,
    max_val: f64,
    epsilon: f64,
) -> (f64, f64) {
    if direction.abs() < epsilon {
        // The ray never crosses either boundary plane, so it is either inside
        // the slab for every t or for none.
        let outside_slab = origin < min_val || origin > max_val;
        if outside_slab {
            (f64::INFINITY, f64::NEG_INFINITY)
        } else {
            (f64::NEG_INFINITY, f64::INFINITY)
        }
    } else {
        let t_at_min_plane = (min_val - origin) / direction;
        let t_at_max_plane = (max_val - origin) / direction;
        (
            t_at_min_plane.min(t_at_max_plane),
            t_at_min_plane.max(t_at_max_plane),
        )
    }
}
101
102impl Aabb {
103 #[must_use]
104 pub const fn from_triangle(triangle: &Triangle) -> Self {
105 let min_point = Position::new(
106 triangle.p1.x.min(triangle.p2.x).min(triangle.p3.x),
107 triangle.p1.y.min(triangle.p2.y).min(triangle.p3.y),
108 triangle.p1.z.min(triangle.p2.z).min(triangle.p3.z),
109 );
110 let max_point = Position::new(
111 triangle.p1.x.max(triangle.p2.x).max(triangle.p3.x),
112 triangle.p1.y.max(triangle.p2.y).max(triangle.p3.y),
113 triangle.p1.z.max(triangle.p2.z).max(triangle.p3.z),
114 );
115 Self {
116 min_point,
117 max_point,
118 }
119 }
120
121 #[must_use]
122 pub fn intersects_ray(&self, ray_origin: &Position, ray_direction: &Position) -> bool {
123 let epsilon = 1e-6;
124
125 let (tx_min, tx_max) = check_axis(
126 ray_origin.x,
127 ray_direction.x,
128 self.min_point.x,
129 self.max_point.x,
130 epsilon,
131 );
132 let (ty_min, ty_max) = check_axis(
133 ray_origin.y,
134 ray_direction.y,
135 self.min_point.y,
136 self.max_point.y,
137 epsilon,
138 );
139 let (tz_min, tz_max) = check_axis(
140 ray_origin.z,
141 ray_direction.z,
142 self.min_point.z,
143 self.max_point.z,
144 epsilon,
145 );
146
147 let t_enter = tx_min.max(ty_min).max(tz_min);
148 let t_exit = tx_max.min(ty_max).min(tz_max);
149
150 t_enter <= t_exit && t_exit >= 0.0
151 }
152}
153
154impl std::fmt::Display for Aabb {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 write!(
157 f,
158 "AABB(min_point={:?}, max_point={:?})",
159 self.min_point, self.max_point
160 )
161 }
162}
163
/// A node of the bounding volume hierarchy built by [`CollisionChecker::build_bvh`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BVHNode {
    /// Box enclosing every triangle in this node's subtree.
    pub aabb: Aabb,
    /// Whether this node holds a single triangle or two child subtrees.
    pub kind: BVHNodeKind,
}
170
/// Payload of a [`BVHNode`].
///
/// Variant order is part of the serialized (bincode) format; do not reorder.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum BVHNodeKind {
    /// Terminal node storing exactly one triangle.
    Leaf {
        triangle: Triangle,
    },
    /// Inner node with two child subtrees.
    Internal {
        left: Box<BVHNode>,
        right: Box<BVHNode>,
    },
}
182
/// Answers line-of-sight queries against a triangle mesh using a BVH.
///
/// Exposed to Python as `cs2_nav.VisibilityChecker`.
#[pyclass(name = "VisibilityChecker", module = "cs2_nav")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CollisionChecker {
    /// Number of triangles the hierarchy was built from (read-only from Python).
    #[pyo3(get)]
    pub n_triangles: usize,
    /// Root node of the bounding volume hierarchy.
    pub root: BVHNode,
}
191
192impl CollisionChecker {
193 #[must_use]
195 pub fn new(tri_file: &Path) -> Self {
196 let triangles = Self::read_tri_file(tri_file, 1000);
197
198 let n_triangles = triangles.len();
199 let root = Self::build_bvh(triangles);
200 Self { n_triangles, root }
201 }
202
203 pub fn read_tri_file<P: AsRef<Path>>(tri_file: P, buffer_size: usize) -> Vec<Triangle> {
210 let chunk_size: usize = buffer_size * 9 * 4;
212 let mut triangles = Vec::new();
213 let mut file = fs::File::open(tri_file).expect("Unable to open tri file");
214 let mut buffer = vec![0u8; chunk_size].into_boxed_slice();
215
216 loop {
217 let n = file.read(&mut buffer).expect("Failed to read file");
218 if n == 0 {
219 break;
220 }
221 let num_complete_triangles = n / 36;
223 for i in 0..num_complete_triangles {
224 let offset = i * 36;
225 let slice = &buffer[offset..offset + 36];
226 let mut values = [0f32; 9];
227 for (i, chunk) in slice.chunks_exact(4).enumerate() {
228 values[i] = f32::from_ne_bytes(chunk.try_into().unwrap());
229 }
230 triangles.push(Triangle {
231 p1: Position::new(
232 f64::from(values[0]),
233 f64::from(values[1]),
234 f64::from(values[2]),
235 ),
236 p2: Position::new(
237 f64::from(values[3]),
238 f64::from(values[4]),
239 f64::from(values[5]),
240 ),
241 p3: Position::new(
242 f64::from(values[6]),
243 f64::from(values[7]),
244 f64::from(values[8]),
245 ),
246 });
247 }
248 }
249 triangles
250 }
251
252 pub fn build_bvh(triangles: Vec<Triangle>) -> BVHNode {
258 assert!(!triangles.is_empty(), "No triangles provided");
259 if triangles.len() == 1 {
260 return BVHNode {
261 aabb: Aabb::from_triangle(&triangles[0]),
262 kind: BVHNodeKind::Leaf {
263 triangle: triangles[0].clone(),
264 },
265 };
266 }
267 let centroids: Vec<Position> = triangles.iter().map(Triangle::get_centroid).collect();
269
270 let (min_x, max_x) = centroids
272 .iter()
273 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
274 (min.min(c.x), max.max(c.x))
275 });
276 let (min_y, max_y) = centroids
277 .iter()
278 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
279 (min.min(c.y), max.max(c.y))
280 });
281 let (min_z, max_z) = centroids
282 .iter()
283 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), c| {
284 (min.min(c.z), max.max(c.z))
285 });
286 let x_spread = max_x - min_x;
287 let y_spread = max_y - min_y;
288 let z_spread = max_z - min_z;
289
290 let axis = if x_spread >= y_spread && x_spread >= z_spread {
292 0
293 } else if y_spread >= z_spread {
294 1
295 } else {
296 2
297 };
298
299 let mut triangles_sorted = triangles;
301 triangles_sorted.sort_by(|a, b| {
302 let ca = a.get_centroid();
303 let cb = b.get_centroid();
304 let coord_a = if axis == 0 {
305 ca.x
306 } else if axis == 1 {
307 ca.y
308 } else {
309 ca.z
310 };
311 let coord_b = if axis == 0 {
312 cb.x
313 } else if axis == 1 {
314 cb.y
315 } else {
316 cb.z
317 };
318 coord_a.partial_cmp(&coord_b).unwrap()
319 });
320
321 let mid = triangles_sorted.len() / 2;
322 let left = Self::build_bvh(triangles_sorted[..mid].to_vec());
323 let right = Self::build_bvh(triangles_sorted[mid..].to_vec());
324
325 let min_point = Position::new(
327 left.aabb.min_point.x.min(right.aabb.min_point.x),
328 left.aabb.min_point.y.min(right.aabb.min_point.y),
329 left.aabb.min_point.z.min(right.aabb.min_point.z),
330 );
331 let max_point = Position::new(
332 left.aabb.max_point.x.max(right.aabb.max_point.x),
333 left.aabb.max_point.y.max(right.aabb.max_point.y),
334 left.aabb.max_point.z.max(right.aabb.max_point.z),
335 );
336
337 BVHNode {
338 aabb: Aabb {
339 min_point,
340 max_point,
341 },
342 kind: BVHNodeKind::Internal {
343 left: Box::new(left),
344 right: Box::new(right),
345 },
346 }
347 }
348
349 fn traverse_bvh(
351 node: &BVHNode,
352 ray_origin: &Position,
353 ray_direction: &Position,
354 max_distance: f64,
355 ) -> bool {
356 if !node.aabb.intersects_ray(ray_origin, ray_direction) {
357 return false;
358 }
359
360 match node.kind {
361 BVHNodeKind::Internal {
362 ref left,
363 ref right,
364 } => {
365 Self::traverse_bvh(left, ray_origin, ray_direction, max_distance)
366 || Self::traverse_bvh(right, ray_origin, ray_direction, max_distance)
367 }
368 BVHNodeKind::Leaf { ref triangle } => triangle
369 .ray_intersection(ray_origin, ray_direction)
370 .is_some_and(|t| t <= max_distance),
371 }
372 }
373
374 pub fn save_to_binary(&self, filename: &Path) {
380 let mut file = create_file_with_parents(filename);
381 serialize_into(&mut file, &self).unwrap();
382 }
383
384 #[must_use]
389 pub fn from_binary(filename: &Path) -> Self {
390 let mut file = File::open(filename).unwrap();
391 deserialize_from(&mut file).unwrap()
392 }
393
394 #[must_use]
397 pub fn connection_unobstructed(&self, start: Position, end: Position) -> bool {
398 let mut direction = end - start;
399 let distance = direction.length();
400 if distance < 1e-6 {
401 return true;
402 }
403 direction = direction.normalize();
404 !Self::traverse_bvh(&self.root, &start, &direction, distance)
406 }
407}
408
409#[pymethods]
410impl CollisionChecker {
411 #[new]
417 #[pyo3(signature = (path=None, triangles=None))]
418 fn py_new(path: Option<PathBuf>, triangles: Option<Vec<Triangle>>) -> PyResult<Self> {
419 let triangles = match (path, triangles) {
420 (Some(tri_file), None) => Self::read_tri_file(tri_file, 1000),
421 (None, Some(triangles)) => triangles,
422 _ => {
423 return Err(PyValueError::new_err(
424 "Exactly one of tri_file or triangles must be provided",
425 ));
426 }
427 };
428
429 let n_triangles = triangles.len();
430 if n_triangles == 0 {
431 return Err(PyValueError::new_err("No triangles provided"));
432 }
433 let root = Self::build_bvh(triangles);
434 Ok(Self { n_triangles, root })
435 }
436
437 fn is_visible(
438 &self,
439 start: PositionFromInputOptions,
440 end: PositionFromInputOptions,
441 ) -> PyResult<bool> {
442 Ok(self.connection_unobstructed(Position::from_input(start)?, Position::from_input(end)?))
443 }
444
445 #[must_use]
446 fn __repr__(&self) -> String {
447 format!("VisibilityChecker(n_triangles={})", self.n_triangles)
448 }
449
450 #[must_use]
451 #[allow(clippy::needless_pass_by_value)]
452 #[staticmethod]
453 #[pyo3(name = "read_tri_file")]
454 #[pyo3(signature = (tri_file, buffer_size=1000))]
455 fn py_read_tri_file(tri_file: PathBuf, buffer_size: usize) -> Vec<Triangle> {
456 Self::read_tri_file(tri_file, buffer_size)
457 }
458}
459
460impl std::fmt::Display for CollisionChecker {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 write!(f, "CollisionChecker(n_triangles={})", self.n_triangles)
463 }
464}
465
/// Selects which triangle file [`load_collision_checker`] uses for a map.
///
/// `Visibility` reads `<map>.tri`; `Walkability` reads `<map>-clippings.tri`.
/// Variant order is part of the serde representation; do not reorder.
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub enum CollisionCheckerStyle {
    Visibility,
    Walkability,
}
471
472#[must_use]
478pub fn load_collision_checker(map_name: &str, style: CollisionCheckerStyle) -> CollisionChecker {
479 let postfix = match style {
480 CollisionCheckerStyle::Visibility => "",
481 CollisionCheckerStyle::Walkability => "-clippings",
482 };
483 let current_file = PathBuf::from(file!());
484 let base = current_file
485 .parent()
486 .expect("No parent found")
487 .parent()
488 .unwrap();
489 let tri_path = base.join("tri").join(format!("{map_name}{postfix}.tri"));
490 let mut binary_path = tri_path.clone();
491 binary_path.set_extension("vis");
492
493 if binary_path.exists() {
494 println!(
495 "Loading collision checker with style {style:?} from binary: {}",
496 binary_path.file_stem().unwrap().to_string_lossy()
497 );
498 return CollisionChecker::from_binary(&binary_path);
499 }
500 println!("{tri_path:?}");
501 println!(
502 "Building collision checker with style {style:?} from tri: {}",
503 tri_path.file_stem().unwrap().to_string_lossy()
504 );
505 let vis_checker = CollisionChecker::new(&tri_path);
506 vis_checker.save_to_binary(&binary_path);
507 vis_checker
508}