use crate::distance::*;
use crate::gate::*;
use crate::gradient::*;
use std::iter::Iterator;
/// A graph whose nodes are addressed by indices `0..len()`.
///
/// Each node exposes its connections as a sequence of [`Gate`]s through
/// [`Mesh::successors`].
pub trait Mesh {
    /// Concrete iterator returned by [`Mesh::successors`].
    type IntoIter: Iterator<Item = Gate>;

    /// Returns the gates attached to node `from`.
    ///
    /// NOTE(review): `backward` presumably selects incoming instead of
    /// outgoing connections; the exact meaning is up to each implementor.
    fn successors(&self, from: usize, backward: bool) -> Self::IntoIter;

    /// Number of nodes in the mesh.
    fn len(&self) -> usize;

    /// Returns `true` when the mesh has no nodes at all.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Consumes the mesh and wraps it in a [`FilteredMesh`] built by
    /// `FilteredMesh::new`.
    fn unique(self) -> FilteredMesh<Self>
    where
        Self: Sized + PartialEq,
    {
        FilteredMesh::<Self>::new(self)
    }
}
pub fn meshes_equal(a: &impl Mesh, b: &impl Mesh) -> bool {
if a.len() != b.len() {
return false;
}
for i in 0..a.len() {
let successors_a: Vec<Gate> = a.successors(i, false).collect();
let successors_b: Vec<Gate> = b.successors(i, false).collect();
if successors_a != successors_b {
return false;
}
}
true
}
/// Lists, in ascending order, every node whose distance from node `0` is
/// below [`DISTANCE_MAX`] after spreading a [`Gradient`] seeded at node `0`.
///
/// Returns an empty vector for an empty mesh.
pub fn reachable_from_origin(mesh: &impl Mesh) -> Vec<usize> {
    if mesh.is_empty() {
        return Vec::new();
    }

    // Seed the origin at distance zero and let the gradient propagate.
    let mut field = Gradient::from_mesh(mesh);
    field.set_distance(0, 0.0);
    field.spread(mesh);

    let mut reached = Vec::new();
    for node in 0..mesh.len() {
        if field.get_distance(node) < DISTANCE_MAX {
            reached.push(node);
        }
    }
    reached
}
/// A view of an inner [`Mesh`] restricted to a subset of its nodes, which are
/// renumbered densely as `0..len()`.
///
/// `FilteredMesh::new` keeps the nodes returned by `reachable_from_origin`.
/// The two index maps below translate between the filtered and the original
/// numbering.
// With the `serde` feature, the explicit `bound` attribute means only `M`
// itself has to implement `Serialize` / `Deserialize`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "M: serde::Serialize",
        deserialize = "M: serde::Deserialize<'de>"
    ))
)]
pub struct FilteredMesh<M: Mesh + PartialEq> {
    /// The unfiltered mesh that queries are forwarded to.
    inner: M,
    /// Filtered index -> original index (position in the vector is the
    /// filtered index).
    new_to_old: Vec<usize>,
    /// Original index -> filtered index; `usize::MAX` marks original nodes
    /// that were filtered out.
    old_to_new: Vec<usize>,
}
impl<M: Mesh + PartialEq> FilteredMesh<M> {
    /// Wraps `mesh`, keeping only the nodes returned by
    /// [`reachable_from_origin`], i.e. the nodes reached from node `0`.
    ///
    /// Surviving nodes are renumbered densely as `0..self.len()`, in ascending
    /// order of their original indices.
    pub fn new(mesh: M) -> Self {
        // Position in this list = filtered index; stored value = original index.
        let new_to_old = reachable_from_origin(&mesh);
        // Reverse map; `usize::MAX` marks original nodes that were dropped.
        let mut old_to_new = vec![usize::MAX; mesh.len()];
        for (new_idx, &old_idx) in new_to_old.iter().enumerate() {
            old_to_new[old_idx] = new_idx;
        }
        Self {
            inner: mesh,
            new_to_old,
            old_to_new,
        }
    }

    /// Borrows the unfiltered inner mesh.
    pub fn inner(&self) -> &M {
        &self.inner
    }

    /// Consumes the wrapper and returns the unfiltered inner mesh.
    pub fn into_inner(self) -> M {
        self.inner
    }

    /// Maps a filtered node index back to its index in the inner mesh.
    ///
    /// # Panics
    ///
    /// Panics if `filtered_index >= self.len()`.
    pub fn to_original_index(&self, filtered_index: usize) -> usize {
        self.new_to_old[filtered_index]
    }

    /// Maps an inner-mesh node index to its filtered index.
    ///
    /// Returns `None` when `original_index` is outside the inner mesh or when
    /// that node was filtered out.
    pub fn to_filtered_index(&self, original_index: usize) -> Option<usize> {
        self.old_to_new
            .get(original_index)
            .copied()
            .filter(|&new_idx| new_idx != usize::MAX)
    }
}
impl<M: Mesh + PartialEq> Mesh for FilteredMesh<M> {
    type IntoIter = std::vec::IntoIter<Gate>;

    /// Returns the inner mesh's gates for the node behind filtered index
    /// `from`.
    ///
    /// Gates leading to filtered-out nodes are dropped. Each remaining gate's
    /// target is rewritten into the filtered numbering, and its distance is
    /// kept. Panics if `from >= self.len()`.
    fn successors(&self, from: usize, backward: bool) -> Self::IntoIter {
        let original = self.new_to_old[from];
        let mut kept = Vec::new();
        for gate in self.inner.successors(original, backward) {
            let mapped = self.old_to_new[gate.target()];
            // `usize::MAX` marks a target that was filtered out.
            if mapped == usize::MAX {
                continue;
            }
            kept.push(Gate::new(mapped, gate.distance));
        }
        kept.into_iter()
    }

    /// Number of nodes that survived filtering.
    fn len(&self) -> usize {
        self.new_to_old.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh_2d::Compact2D;
    use crate::mesh_2d::Full2D;
    use crate::mesh_topo::{MeshWithTopology, Topology};

    /// Test topology that treats node index `i` as column `x = i % width` and
    /// forbids every move that crosses the vertical line between column
    /// `wall_x - 1` and column `wall_x`.
    struct VerticalWallTopology {
        width: usize,
        // Columns `x < wall_x` form the left side; the rest form the right side.
        wall_x: usize,
    }

    impl VerticalWallTopology {
        fn new(width: usize, wall_x: usize) -> Self {
            Self { width, wall_x }
        }

        /// Column of `index`, assuming `width` columns per row.
        fn index_to_x(&self, index: usize) -> usize {
            index % self.width
        }
    }

    impl Topology for VerticalWallTopology {
        /// A move is allowed only when both endpoints are on the same side of the wall.
        fn allowed(&self, from: usize, to: usize) -> crate::errors::Result<bool> {
            let from_x = self.index_to_x(from);
            let to_x = self.index_to_x(to);
            let from_left = from_x < self.wall_x;
            let to_left = to_x < self.wall_x;
            Ok(from_left == to_left)
        }
    }

    /// With a wall at x = 3 on an 18-node grid of width 6, only the 9 left-side
    /// nodes stay reachable from node 0, and all of them remain connected after
    /// filtering.
    #[test]
    fn test_filtered_mesh_with_topology() {
        let base_mesh = Full2D::new(6, 3);
        assert_eq!(base_mesh.len(), 18);
        let topology = VerticalWallTopology::new(6, 3);
        let mesh_with_topo = MeshWithTopology::new(&base_mesh, &topology);
        // The topology wrapper does not change the node count, only the edges.
        assert_eq!(mesh_with_topo.len(), 18);
        let successors_0: Vec<_> = mesh_with_topo.successors(0, false).collect();
        assert!(successors_0.iter().all(|g| topology.allowed(0, g.target()).unwrap()));
        let filtered = FilteredMesh::new(mesh_with_topo);
        assert_eq!(filtered.len(), 9);
        // Every filtered node must be reachable from filtered node 0.
        let mut gradient = Gradient::from_mesh(&filtered);
        gradient.set_distance(0, 0.0);
        gradient.spread(&filtered);
        for i in 0..filtered.len() {
            assert!(
                gradient.get_distance(i) < DISTANCE_MAX,
                "Node {} should be reachable",
                i
            );
        }
    }

    /// Same scenario as above, built through the `Mesh::unique` convenience method.
    #[test]
    fn test_filtered_mesh_with_topology_via_unique() {
        let base_mesh = Full2D::new(6, 3);
        let topology = VerticalWallTopology::new(6, 3);
        let mesh_with_topo = MeshWithTopology::new(&base_mesh, &topology);
        let filtered = mesh_with_topo.unique();
        assert_eq!(filtered.len(), 9);
        assert_eq!(filtered.to_original_index(0), 0);
    }

    /// A text map with 14 open cells, of which 8 are reachable from cell 0.
    #[test]
    fn test_filtered_mesh_basic() {
        let mesh =
            Compact2D::from_text("##########\n# # #\n# # #\n##########\n").unwrap();
        assert_eq!(mesh.len(), 14);
        let filtered = FilteredMesh::new(mesh);
        // Only the component containing cell 0 is kept.
        assert_eq!(filtered.len(), 8);
    }

    /// After filtering, every remaining node is reachable from node 0.
    #[test]
    fn test_filtered_mesh_with_gradient() {
        let mesh =
            Compact2D::from_text("##########\n# # #\n# # #\n##########\n").unwrap();
        let filtered = FilteredMesh::new(mesh);
        let mut gradient = Gradient::from_mesh(&filtered);
        gradient.set_distance(0, 0.0);
        gradient.spread(&filtered);
        for i in 0..filtered.len() {
            assert!(gradient.get_distance(i) < DISTANCE_MAX);
        }
    }

    /// Index translation in both directions; original node 13 is expected to be
    /// filtered out.
    #[test]
    fn test_filtered_mesh_index_translation() {
        let mesh =
            Compact2D::from_text("##########\n# # #\n# # #\n##########\n").unwrap();
        let filtered = FilteredMesh::new(mesh);
        assert_eq!(filtered.to_original_index(0), 0);
        assert_eq!(filtered.to_filtered_index(0), Some(0));
        assert_eq!(filtered.to_filtered_index(13), None);
    }

    /// An all-open 3x3 map is fully connected, so filtering keeps every node.
    #[test]
    fn test_filtered_mesh_fully_connected() {
        let mesh = Compact2D::from_text("...\n...\n...").unwrap();
        assert_eq!(mesh.len(), 9);
        let filtered = FilteredMesh::new(mesh);
        // Nothing is filtered out.
        assert_eq!(filtered.len(), 9);
    }

    /// A map with no open cells yields an empty mesh and an empty filtered mesh.
    #[test]
    fn test_filtered_mesh_empty() {
        let mesh = Compact2D::from_text("###\n###\n###").unwrap();
        assert_eq!(mesh.len(), 0);
        let filtered = FilteredMesh::new(mesh);
        assert_eq!(filtered.len(), 0);
    }

    /// A JSON round-trip preserves the node count, the index mapping and the
    /// connectivity of the filtered mesh.
    #[test]
    #[cfg(feature = "serde")]
    fn test_filtered_mesh_serde() {
        let mesh =
            Compact2D::from_text("##########\n# # #\n# # #\n##########\n").unwrap();
        let filtered = FilteredMesh::new(mesh);
        assert_eq!(filtered.len(), 8);
        let json = serde_json::to_string(&filtered).unwrap();
        let deserialized: FilteredMesh<Compact2D> = serde_json::from_str(&json).unwrap();
        assert_eq!(filtered.len(), deserialized.len());
        assert_eq!(
            filtered.to_original_index(0),
            deserialized.to_original_index(0)
        );
        // The deserialized mesh must still be fully connected from node 0.
        let mut gradient = Gradient::from_mesh(&deserialized);
        gradient.set_distance(0, 0.0);
        gradient.spread(&deserialized);
        for i in 0..deserialized.len() {
            assert!(gradient.get_distance(i) < DISTANCE_MAX);
        }
    }
}