use std::collections::HashMap;
use hisab::Vec2;
use serde::{Deserialize, Serialize};
use crate::mesh::{NavMesh, NavPolyId};
#[cfg(feature = "logging")]
use tracing::instrument;
/// Identifier of a navigation layer inside a [`MultiLayerNavMesh`].
///
/// Layers are stored and looked up by this id; its numeric value carries no
/// ordering or spatial meaning in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LayerId(pub u32);
/// A polygon addressed across layers: which layer it lives on plus its id in
/// that layer's [`NavMesh`].
///
/// This is the node type of paths returned by [`MultiLayerNavMesh::find_path`]
/// and the endpoint type of [`LayerConnection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LayeredPolyId {
/// Layer that contains the polygon.
pub layer: LayerId,
/// Polygon id within that layer's mesh.
pub poly: NavPolyId,
}
/// An explicit link between two polygons, typically on different layers
/// (stairs, lifts, ladders, ...).
///
/// Endpoints are not validated when the connection is added; pathfinding skips
/// connections whose endpoints do not exist.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct LayerConnection {
/// Polygon the connection starts at.
pub from: LayeredPolyId,
/// Polygon the connection leads to.
pub to: LayeredPolyId,
/// Edge weight used by pathfinding when following this connection.
pub cost: f32,
/// When `true`, the connection can also be followed from `to` back to
/// `from`, at the same `cost`.
pub bidirectional: bool,
}
/// A collection of navigation meshes, one per [`LayerId`], plus explicit
/// [`LayerConnection`]s that let paths move between polygons on different layers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MultiLayerNavMesh {
/// Navigation mesh of each layer, keyed by layer id.
layers: HashMap<LayerId, NavMesh>,
/// Registered connections, kept in insertion order.
connections: Vec<LayerConnection>,
}
impl MultiLayerNavMesh {
    /// Creates an empty multi-layer navmesh with no layers and no connections.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `mesh` as layer `id`, replacing any mesh previously stored under `id`.
    ///
    /// Existing connections are kept even if they reference polygons the new mesh
    /// does not contain; [`find_path`](Self::find_path) skips such dangling endpoints.
    pub fn add_layer(&mut self, id: LayerId, mesh: NavMesh) {
        self.layers.insert(id, mesh);
    }

    /// Removes layer `id` together with every connection that starts or ends on it.
    ///
    /// Returns `true` if the layer existed. Returns `false` otherwise, in which case
    /// nothing is modified.
    pub fn remove_layer(&mut self, id: LayerId) -> bool {
        if self.layers.remove(&id).is_none() {
            return false;
        }
        self.connections
            .retain(|c| c.from.layer != id && c.to.layer != id);
        true
    }

    /// Returns the mesh stored for layer `id`, if any.
    #[must_use]
    pub fn get_layer(&self, id: LayerId) -> Option<&NavMesh> {
        self.layers.get(&id)
    }

    /// Number of layers currently stored.
    #[must_use]
    pub fn layer_count(&self) -> usize {
        self.layers.len()
    }

    /// Registers a connection.
    ///
    /// The connection is not validated. Endpoints on missing layers or polygons are
    /// stored as-is and ignored during pathfinding.
    pub fn add_connection(&mut self, conn: LayerConnection) {
        self.connections.push(conn);
    }

    /// All registered connections, in insertion order.
    #[must_use]
    pub fn connections(&self) -> &[LayerConnection] {
        &self.connections
    }

    /// Finds a sequence of polygons leading from `start_pos` on `start_layer` to
    /// `goal_pos` on `goal_layer`.
    ///
    /// # Same-layer requests
    ///
    /// When both positions are on the same layer, that layer's own
    /// [`NavMesh::find_path`] is tried first. If it finds a route, that route is
    /// returned unchanged and connections are not considered.
    ///
    /// # Combined search
    ///
    /// Otherwise an A* search runs over one graph that combines every layer's polygon
    /// adjacency with all registered [`LayerConnection`]s. This covers different
    /// layers, and also a same-layer request with no route inside that layer alone.
    ///
    /// Edge costs in the combined graph:
    /// - Stepping to a neighbouring polygon costs the distance between the two
    ///   centroids multiplied by the source polygon's `cost`.
    /// - Following a connection costs `conn.cost`.
    ///
    /// All costs must be non-negative for the result to be the cheapest route.
    ///
    /// # Returns
    ///
    /// `None` if a layer is unknown, a position is not inside any polygon of its
    /// layer, or no route exists.
    #[cfg_attr(feature = "logging", instrument(skip(self)))]
    #[must_use]
    pub fn find_path(
        &self,
        start_layer: LayerId,
        start_pos: Vec2,
        goal_layer: LayerId,
        goal_pos: Vec2,
    ) -> Option<Vec<LayeredPolyId>> {
        use std::cmp::Ordering;
        use std::collections::BinaryHeap;

        /// Open-list entry: node index plus its A* priority `f = g + h`.
        #[derive(Clone, Copy)]
        struct N {
            idx: usize,
            f: f32,
        }
        impl PartialEq for N {
            fn eq(&self, o: &Self) -> bool {
                self.cmp(o) == Ordering::Equal
            }
        }
        impl Eq for N {}
        impl PartialOrd for N {
            fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
                Some(self.cmp(o))
            }
        }
        impl Ord for N {
            // Reversed so the max-heap `BinaryHeap` pops the lowest `f` first.
            // `total_cmp` keeps this a total order even if a cost is NaN.
            fn cmp(&self, o: &Self) -> Ordering {
                o.f.total_cmp(&self.f)
            }
        }

        // Fast path: a single-layer request the layer's own mesh can answer.
        if start_layer == goal_layer {
            let mesh = self.layers.get(&start_layer)?;
            if let Some(poly_path) = mesh.find_path(start_pos, goal_pos) {
                return Some(
                    poly_path
                        .into_iter()
                        .map(|poly| LayeredPolyId {
                            layer: start_layer,
                            poly,
                        })
                        .collect(),
                );
            }
            // No route inside this layer alone. A detour through other layers via
            // connections may still exist, so fall through to the combined search.
        }

        // Resolve both endpoints before paying for the graph construction.
        let start_id = LayeredPolyId {
            layer: start_layer,
            poly: self.layers.get(&start_layer)?.find_poly_at(start_pos)?,
        };
        let goal_id = LayeredPolyId {
            layer: goal_layer,
            poly: self.layers.get(&goal_layer)?.find_poly_at(goal_pos)?,
        };

        // Give every polygon of every layer a dense index.
        let mut nodes: Vec<LayeredPolyId> = Vec::new();
        let mut positions: Vec<Vec2> = Vec::new();
        let mut node_map: HashMap<LayeredPolyId, usize> = HashMap::new();
        for (&layer, mesh) in &self.layers {
            for poly in mesh.polys() {
                let id = LayeredPolyId {
                    layer,
                    poly: poly.id,
                };
                node_map.insert(id, nodes.len());
                nodes.push(id);
                positions.push(poly.centroid());
            }
        }
        let start_idx = *node_map.get(&start_id)?;
        let goal_idx = *node_map.get(&goal_id)?;
        let n = nodes.len();

        // Adjacency list. First the intra-layer neighbours, then the explicit
        // connections. Endpoints that are not present in `node_map` are skipped.
        let mut adj: Vec<Vec<(usize, f32)>> = vec![Vec::new(); n];
        for (&layer, mesh) in &self.layers {
            for poly in mesh.polys() {
                // Every polygon was indexed in the loop above.
                let from = node_map[&LayeredPolyId {
                    layer,
                    poly: poly.id,
                }];
                for &nid in &poly.neighbors {
                    if let Some(&to) = node_map.get(&LayeredPolyId { layer, poly: nid }) {
                        let cost = positions[from].distance(positions[to]) * poly.cost;
                        adj[from].push((to, cost));
                    }
                }
            }
        }
        for conn in &self.connections {
            if let (Some(&from), Some(&to)) = (node_map.get(&conn.from), node_map.get(&conn.to)) {
                adj[from].push((to, conn.cost));
                if conn.bidirectional {
                    adj[to].push((from, conn.cost));
                }
            }
        }

        // Scale the heuristic so it stays admissible.
        // - Why: straight-line centroid distance never overestimates only if no edge
        //   is cheaper than the planar distance it spans. Polygons with `cost < 1`
        //   and cheap long-range connections (teleports, offset lifts) break that,
        //   and A* would then return a non-cheapest route.
        // - How: use the smallest cost-per-distance ratio over all edges as the
        //   scale factor.
        // - Result: the heuristic is consistent. Capping the factor at 1 means
        //   meshes where plain distance was already admissible search exactly as before.
        let mut h_scale = 1.0_f32;
        for (from, edges) in adj.iter().enumerate() {
            for &(to, cost) in edges {
                let span = positions[from].distance(positions[to]);
                if span > 0.0 {
                    h_scale = h_scale.min(cost / span);
                }
            }
        }
        let h_scale = h_scale.max(0.0);
        let heuristic = |i: usize| h_scale * positions[i].distance(positions[goal_idx]);

        // A* with lazy deletion: stale heap entries are skipped via `closed`.
        let mut g = vec![f32::INFINITY; n];
        let mut came: Vec<Option<usize>> = vec![None; n];
        let mut closed = vec![false; n];
        let mut open = BinaryHeap::new();
        g[start_idx] = 0.0;
        open.push(N {
            idx: start_idx,
            f: heuristic(start_idx),
        });
        while let Some(N { idx: cur, .. }) = open.pop() {
            if cur == goal_idx {
                // Walk predecessor links back to the start, then flip into travel order.
                let mut path: Vec<LayeredPolyId> =
                    std::iter::successors(Some(goal_idx), |&i| came[i])
                        .map(|i| nodes[i])
                        .collect();
                path.reverse();
                return Some(path);
            }
            if closed[cur] {
                continue;
            }
            closed[cur] = true;
            for &(next, cost) in &adj[cur] {
                if closed[next] {
                    continue;
                }
                let tentative = g[cur] + cost;
                if tentative < g[next] {
                    g[next] = tentative;
                    came[next] = Some(cur);
                    open.push(N {
                        idx: next,
                        f: tentative + heuristic(next),
                    });
                }
            }
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh::{NavMesh, NavPoly, NavPolyId};

    /// Two 5x5 quads side by side and neighbouring each other.
    /// Polygon `0` spans x in [0, 5] and polygon `1` spans x in [5, 10].
    fn make_simple_mesh() -> NavMesh {
        let mut mesh = NavMesh::new();
        mesh.add_poly(NavPoly {
            id: NavPolyId(0),
            vertices: vec![
                Vec2::new(0.0, 0.0),
                Vec2::new(5.0, 0.0),
                Vec2::new(5.0, 5.0),
                Vec2::new(0.0, 5.0),
            ],
            neighbors: vec![NavPolyId(1)],
            cost: 1.0,
            layer: 0,
        });
        mesh.add_poly(NavPoly {
            id: NavPolyId(1),
            vertices: vec![
                Vec2::new(5.0, 0.0),
                Vec2::new(10.0, 0.0),
                Vec2::new(10.0, 5.0),
                Vec2::new(5.0, 5.0),
            ],
            neighbors: vec![NavPolyId(0)],
            cost: 1.0,
            layer: 0,
        });
        mesh
    }

    /// A single 10x10 quad with id `0` and no neighbours.
    fn make_single_poly_mesh() -> NavMesh {
        let mut mesh = NavMesh::new();
        mesh.add_poly(NavPoly {
            id: NavPolyId(0),
            vertices: vec![
                Vec2::ZERO,
                Vec2::new(10.0, 0.0),
                Vec2::new(10.0, 10.0),
                Vec2::new(0.0, 10.0),
            ],
            neighbors: vec![],
            cost: 1.0,
            layer: 0,
        });
        mesh
    }

    /// Shorthand for polygon `poly` on layer `layer`.
    fn at(layer: u32, poly: NavPolyId) -> LayeredPolyId {
        LayeredPolyId {
            layer: LayerId(layer),
            poly,
        }
    }

    /// Shorthand for a connection record.
    fn link(
        from: LayeredPolyId,
        to: LayeredPolyId,
        cost: f32,
        bidirectional: bool,
    ) -> LayerConnection {
        LayerConnection {
            from,
            to,
            cost,
            bidirectional,
        }
    }

    /// Layers 0 and 1, each holding [`make_simple_mesh`], with no connections.
    fn two_simple_layers() -> MultiLayerNavMesh {
        let mut ml = MultiLayerNavMesh::new();
        ml.add_layer(LayerId(0), make_simple_mesh());
        ml.add_layer(LayerId(1), make_simple_mesh());
        ml
    }

    /// Layers `0..count`, each holding [`make_single_poly_mesh`], with no connections.
    fn single_poly_layers(count: u32) -> MultiLayerNavMesh {
        let mut ml = MultiLayerNavMesh::new();
        for lid in 0..count {
            ml.add_layer(LayerId(lid), make_single_poly_mesh());
        }
        ml
    }

    /// Centre of [`make_single_poly_mesh`]'s only polygon.
    fn center() -> Vec2 {
        Vec2::new(5.0, 5.0)
    }

    #[test]
    fn same_layer_path() {
        let mut ml = MultiLayerNavMesh::new();
        ml.add_layer(LayerId(0), make_simple_mesh());
        let path = ml
            .find_path(
                LayerId(0),
                Vec2::new(1.0, 1.0),
                LayerId(0),
                Vec2::new(9.0, 1.0),
            )
            .expect("both points lie on connected polygons of layer 0");
        assert_eq!(path.len(), 2);
        assert!(path.iter().all(|lp| lp.layer == LayerId(0)));
    }

    #[test]
    fn cross_layer_path_via_connection() {
        let mut ml = two_simple_layers();
        ml.add_connection(link(at(0, NavPolyId(1)), at(1, NavPolyId(0)), 2.0, false));
        let path = ml
            .find_path(
                LayerId(0),
                Vec2::new(1.0, 1.0),
                LayerId(1),
                Vec2::new(9.0, 1.0),
            )
            .expect("layer 0 reaches layer 1 through the connection");
        assert!(path.len() >= 3);
        assert_eq!(path.first().unwrap().layer, LayerId(0));
        assert_eq!(path.last().unwrap().layer, LayerId(1));
    }

    #[test]
    fn no_path_without_connection() {
        let ml = two_simple_layers();
        let path = ml.find_path(
            LayerId(0),
            Vec2::new(1.0, 1.0),
            LayerId(1),
            Vec2::new(9.0, 1.0),
        );
        assert!(path.is_none());
    }

    #[test]
    fn bidirectional_connection() {
        let mut ml = two_simple_layers();
        ml.add_connection(link(at(0, NavPolyId(1)), at(1, NavPolyId(0)), 2.0, true));
        let fwd = ml.find_path(
            LayerId(0),
            Vec2::new(1.0, 1.0),
            LayerId(1),
            Vec2::new(9.0, 1.0),
        );
        assert!(fwd.is_some());
        let rev = ml.find_path(
            LayerId(1),
            Vec2::new(1.0, 1.0),
            LayerId(0),
            Vec2::new(9.0, 1.0),
        );
        assert!(rev.is_some());
    }

    #[test]
    fn remove_layer() {
        let mut ml = two_simple_layers();
        ml.add_connection(link(at(0, NavPolyId(0)), at(1, NavPolyId(0)), 1.0, false));
        assert_eq!(ml.layer_count(), 2);
        assert_eq!(ml.connections().len(), 1);
        assert!(ml.remove_layer(LayerId(1)));
        assert_eq!(ml.layer_count(), 1);
        assert_eq!(ml.connections().len(), 0);
        assert!(!ml.remove_layer(LayerId(99)));
    }

    #[test]
    fn get_layer_lookup() {
        let mut ml = MultiLayerNavMesh::new();
        ml.add_layer(LayerId(0), make_simple_mesh());
        assert!(ml.get_layer(LayerId(0)).is_some());
        assert!(ml.get_layer(LayerId(1)).is_none());
    }

    #[test]
    fn unknown_layer_has_no_path() {
        let ml = two_simple_layers();
        let p = Vec2::new(1.0, 1.0);
        assert!(ml.find_path(LayerId(7), p, LayerId(7), p).is_none());
        assert!(ml.find_path(LayerId(0), p, LayerId(7), p).is_none());
    }

    #[test]
    fn serde_roundtrip() {
        let mut ml = two_simple_layers();
        ml.add_connection(link(at(0, NavPolyId(1)), at(1, NavPolyId(0)), 3.0, true));
        let json = serde_json::to_string(&ml).unwrap();
        let deserialized: MultiLayerNavMesh = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.layer_count(), 2);
        assert_eq!(deserialized.connections().len(), 1);
        assert!((deserialized.connections()[0].cost - 3.0).abs() < f32::EPSILON);
    }

    #[test]
    fn multilayer_cyclic_connections() {
        let mut ml = single_poly_layers(3);
        ml.add_connection(link(at(0, NavPolyId(0)), at(1, NavPolyId(0)), 1.0, false));
        ml.add_connection(link(at(1, NavPolyId(0)), at(2, NavPolyId(0)), 1.0, false));
        ml.add_connection(link(at(2, NavPolyId(0)), at(0, NavPolyId(0)), 1.0, false));
        let path = ml.find_path(LayerId(0), center(), LayerId(2), center());
        assert!(path.is_some());
    }

    #[test]
    fn multilayer_many_layers() {
        let mut ml = single_poly_layers(10);
        for lid in 1..10u32 {
            ml.add_connection(link(
                at(lid - 1, NavPolyId(0)),
                at(lid, NavPolyId(0)),
                1.0,
                true,
            ));
        }
        assert_eq!(ml.layer_count(), 10);
        let path = ml.find_path(LayerId(0), center(), LayerId(9), center());
        assert!(path.is_some());
    }

    #[test]
    fn multilayer_three_layer_path() {
        let mut ml = single_poly_layers(3);
        ml.add_connection(link(at(0, NavPolyId(0)), at(1, NavPolyId(0)), 1.0, true));
        ml.add_connection(link(at(1, NavPolyId(0)), at(2, NavPolyId(0)), 1.0, true));
        let path = ml
            .find_path(LayerId(0), center(), LayerId(2), center())
            .expect("layers 0 -> 1 -> 2 are chained");
        assert_eq!(path.len(), 3);
    }

    #[test]
    fn multilayer_remove_connected_layer() {
        let mut ml = single_poly_layers(3);
        ml.add_connection(link(at(0, NavPolyId(0)), at(1, NavPolyId(0)), 1.0, true));
        ml.add_connection(link(at(1, NavPolyId(0)), at(2, NavPolyId(0)), 1.0, true));
        assert!(ml.remove_layer(LayerId(1)));
        assert_eq!(ml.layer_count(), 2);
        assert!(ml
            .connections()
            .iter()
            .all(|c| c.from.layer != LayerId(1) && c.to.layer != LayerId(1)));
        assert!(ml
            .find_path(LayerId(0), center(), LayerId(2), center())
            .is_none());
    }
}