use crate::error::SpaceError;
use crate::region::{BoundingShape, RegionPlan, RegionSpec};
use crate::space::Space;
use indexmap::IndexSet;
use murk_core::{Coord, SpaceInstanceId};
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::fmt;
/// How per-component distances are combined by `ProductSpace::metric_distance`.
#[derive(Clone, Debug, PartialEq)]
pub enum ProductMetric {
    /// Sum of the component distances.
    L1,
    /// Maximum of the component distances (folded starting from `0.0`).
    LInfinity,
    /// Weighted sum `Σ wᵢ·dᵢ`; must hold exactly one weight per component,
    /// otherwise `metric_distance` returns an error.
    Weighted(Vec<f64>),
}
/// Cartesian product of several [`Space`]s.
///
/// A product coordinate is the concatenation of one coordinate from each
/// component, in component order.
pub struct ProductSpace {
    /// Identifier assigned at construction (`SpaceInstanceId::next()` in `new`).
    instance_id: SpaceInstanceId,
    /// The factor spaces, in the order their coordinates are concatenated.
    components: Vec<Box<dyn Space>>,
    /// Prefix sums of component `ndim()`s (length `components.len() + 1`):
    /// component `i` owns axes `dim_offsets[i]..dim_offsets[i + 1]`.
    dim_offsets: Vec<usize>,
    /// Sum of all component `ndim()`s.
    total_ndim: usize,
    /// Product of all component `cell_count()`s.
    total_cells: usize,
}
impl fmt::Debug for ProductSpace {
    /// Prints a shape summary: the component count (not the components
    /// themselves), dimensionality, cell count and axis offsets.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let component_count = self.components.len();
        let mut builder = f.debug_struct("ProductSpace");
        builder.field("n_components", &component_count);
        builder.field("total_ndim", &self.total_ndim);
        builder.field("total_cells", &self.total_cells);
        builder.field("dim_offsets", &self.dim_offsets);
        builder.finish()
    }
}
impl ProductSpace {
    /// Builds the Cartesian product of `components`, in the given order.
    ///
    /// A product coordinate is the concatenation of one coordinate per
    /// component, so `ndim()` is the sum of the components' `ndim()` and
    /// `cell_count()` is the product of their `cell_count()`.
    ///
    /// # Errors
    ///
    /// Returns [`SpaceError::InvalidComposition`] if `components` is empty or
    /// if the product of the component cell counts overflows `usize`.
    pub fn new(components: Vec<Box<dyn Space>>) -> Result<Self, SpaceError> {
        if components.is_empty() {
            return Err(SpaceError::InvalidComposition {
                reason: "ProductSpace requires at least one component".to_string(),
            });
        }

        // Prefix sums of component dimensionalities: component `i` owns the
        // axes `dim_offsets[i]..dim_offsets[i + 1]` of a product coordinate.
        let mut dim_offsets = Vec::with_capacity(components.len() + 1);
        dim_offsets.push(0);
        let mut total_ndim = 0usize;
        for comp in &components {
            total_ndim += comp.ndim();
            dim_offsets.push(total_ndim);
        }

        let total_cells = components
            .iter()
            .try_fold(1usize, |acc, comp| acc.checked_mul(comp.cell_count()))
            .ok_or_else(|| SpaceError::InvalidComposition {
                reason: "total cell count overflows usize".to_string(),
            })?;

        Ok(Self {
            components,
            dim_offsets,
            total_ndim,
            total_cells,
            instance_id: SpaceInstanceId::next(),
        })
    }

    /// Number of component spaces.
    pub fn n_components(&self) -> usize {
        self.components.len()
    }

    /// Borrows component `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.n_components()`.
    pub fn component(&self, i: usize) -> &dyn Space {
        &*self.components[i]
    }

    /// Extracts component `i`'s sub-coordinate from a product coordinate.
    ///
    /// Panics if `coord` has fewer than `dim_offsets[i + 1]` axes.
    fn split_coord(&self, coord: &Coord, i: usize) -> Coord {
        let start = self.dim_offsets[i];
        let end = self.dim_offsets[i + 1];
        SmallVec::from_slice(&coord[start..end])
    }

    /// Concatenates per-component sub-coordinates into one product coordinate.
    fn join_coords(&self, parts: &[Coord]) -> Coord {
        let mut out = SmallVec::with_capacity(self.total_ndim);
        for part in parts {
            out.extend_from_slice(part);
        }
        out
    }

    /// Sorts `coords` by canonical rank; coordinates without a rank sort last,
    /// keeping their relative order.
    fn sort_canonical(&self, coords: &mut [Coord]) {
        // Cached (and stable) sort: each rank is computed once per element
        // instead of once per comparison.
        coords.sort_by_cached_key(|c| self.canonical_rank(c).unwrap_or(usize::MAX));
    }

    /// Distance between `a` and `b` under the chosen product `metric`.
    ///
    /// Per-component distances come from each component's `distance`; they
    /// are summed (`L1`), maxed starting from `0.0` (`LInfinity`), or combined
    /// as a weighted sum (`Weighted`).
    ///
    /// # Errors
    ///
    /// - [`SpaceError::CoordOutOfBounds`] if `a` or `b` does not have exactly
    ///   `ndim()` axes.
    /// - [`SpaceError::InvalidComposition`] if a `Weighted` metric does not
    ///   supply exactly one weight per component.
    pub fn metric_distance(
        &self,
        a: &Coord,
        b: &Coord,
        metric: &ProductMetric,
    ) -> Result<f64, SpaceError> {
        // A short coordinate would otherwise panic while slicing, and a long
        // one would have its extra axes silently ignored.
        for coord in [a, b] {
            if coord.len() != self.total_ndim {
                return Err(SpaceError::CoordOutOfBounds {
                    coord: coord.clone(),
                    bounds: format!("expected {}D coordinate", self.total_ndim),
                });
            }
        }

        // Validate the weights before doing any distance work.
        if let ProductMetric::Weighted(weights) = metric {
            if weights.len() != self.components.len() {
                return Err(SpaceError::InvalidComposition {
                    reason: format!(
                        "weighted metric requires exactly one weight per component \
                         (got {} weights for {} components)",
                        weights.len(),
                        self.components.len(),
                    ),
                });
            }
        }

        let per_comp = self
            .components
            .iter()
            .enumerate()
            .map(|(i, comp)| comp.distance(&self.split_coord(a, i), &self.split_coord(b, i)));

        Ok(match metric {
            ProductMetric::L1 => per_comp.sum(),
            ProductMetric::LInfinity => per_comp.fold(0.0f64, f64::max),
            ProductMetric::Weighted(weights) => per_comp.zip(weights).map(|(d, w)| d * w).sum(),
        })
    }

    /// Combines per-component region plans into the plan for their Cartesian
    /// product.
    ///
    /// The product's bounding box is the concatenation of the component
    /// bounding dims. Tensor indices are flattened row-major across
    /// components (last component fastest). Coords are emitted in the same
    /// odometer order. If any component plan is empty, the product is empty
    /// (all-zero mask).
    fn compile_cartesian_product(&self, per_comp: &[RegionPlan]) -> RegionPlan {
        let mut bounding_dims = Vec::new();
        for plan in per_comp {
            match &plan.bounding_shape {
                BoundingShape::Rect(dims) => bounding_dims.extend(dims),
            }
        }
        let bounding_total: usize = bounding_dims.iter().product();
        let mut valid_mask = vec![0u8; bounding_total];

        // strides[i] = product of the bounding sizes of components i+1..n.
        let n = per_comp.len();
        let mut strides = vec![1usize; n];
        for i in (1..n).rev() {
            strides[i - 1] = strides[i] * per_comp[i].bounding_shape.total_elements();
        }

        // Usable entries per component: coords are paired with tensor indices
        // positionally, so only the common prefix counts.
        let lens: Vec<usize> = per_comp
            .iter()
            .map(|p| p.coords.len().min(p.tensor_indices.len()))
            .collect();

        // An empty factor makes the whole product empty; bail out before the
        // odometer below would index into an empty list.
        if lens.iter().any(|&len| len == 0) {
            return RegionPlan {
                coords: Vec::new(),
                tensor_indices: Vec::new(),
                valid_mask,
                bounding_shape: BoundingShape::Rect(bounding_dims),
            };
        }

        let total = lens
            .iter()
            .try_fold(1usize, |acc, &len| acc.checked_mul(len))
            .unwrap_or(0);
        let mut coords: Vec<Coord> = Vec::with_capacity(total);
        let mut tensor_indices = Vec::with_capacity(total);

        // Mixed-radix odometer over component entry positions.
        let mut indices = vec![0usize; n];
        loop {
            let mut product_tensor_idx = 0usize;
            let mut product_coord: Coord = SmallVec::with_capacity(self.total_ndim);
            for ((plan, &idx), &stride) in per_comp.iter().zip(&indices).zip(&strides) {
                product_tensor_idx += plan.tensor_indices[idx] * stride;
                product_coord.extend_from_slice(&plan.coords[idx]);
            }
            valid_mask[product_tensor_idx] = 1;
            coords.push(product_coord);
            tensor_indices.push(product_tensor_idx);

            // Increment the last position first; carry leftwards on wrap.
            let mut wrapped = true;
            for (idx, &len) in indices.iter_mut().zip(&lens).rev() {
                *idx += 1;
                if *idx < len {
                    wrapped = false;
                    break;
                }
                *idx = 0;
            }
            if wrapped {
                break;
            }
        }

        RegionPlan {
            coords,
            tensor_indices,
            valid_mask,
            bounding_shape: BoundingShape::Rect(bounding_dims),
        }
    }
}
impl Space for ProductSpace {
    /// Total number of axes: the sum of the components' `ndim()`.
    fn ndim(&self) -> usize {
        self.total_ndim
    }

    /// Number of cells: the product of the components' `cell_count()`
    /// (overflow-checked in `ProductSpace::new`).
    fn cell_count(&self) -> usize {
        self.total_cells
    }

    /// Neighbours in the product graph: each result differs from `coord` in
    /// exactly one component, replaced by one of that component's
    /// neighbours. Results are grouped by component, in component order.
    fn neighbours(&self, coord: &Coord) -> SmallVec<[Coord; 8]> {
        let mut parts: Vec<Coord> = (0..self.components.len())
            .map(|i| self.split_coord(coord, i))
            .collect();
        let mut result = SmallVec::new();
        for (i, component) in self.components.iter().enumerate() {
            // Move slot `i` out, substitute each neighbour into it in turn,
            // then put the original back, instead of cloning every part per
            // neighbour.
            let own = std::mem::take(&mut parts[i]);
            for nb in component.neighbours(&own) {
                parts[i] = nb;
                result.push(self.join_coords(&parts));
            }
            parts[i] = own;
        }
        result
    }

    /// Sum of the component distances (the `ProductMetric::L1` combination).
    fn distance(&self, a: &Coord, b: &Coord) -> f64 {
        (0..self.components.len())
            .map(|i| {
                let ca = self.split_coord(a, i);
                let cb = self.split_coord(b, i);
                self.components[i].distance(&ca, &cb)
            })
            .sum()
    }

    /// Compiles a region of the product space.
    ///
    /// `All` and `Rect` are delegated per component and combined as a
    /// Cartesian product. `Disk` and `Neighbours` use a BFS over product
    /// neighbours. `Coords` validates every coord, then sorts canonically and
    /// deduplicates.
    fn compile_region(&self, spec: &RegionSpec) -> Result<RegionPlan, SpaceError> {
        match spec {
            RegionSpec::All => {
                let per_comp: Vec<RegionPlan> = self
                    .components
                    .iter()
                    .map(|c| c.compile_region(&RegionSpec::All))
                    .collect::<Result<_, _>>()?;
                Ok(self.compile_cartesian_product(&per_comp))
            }
            RegionSpec::Rect { min, max } => {
                if min.len() != self.total_ndim || max.len() != self.total_ndim {
                    return Err(SpaceError::InvalidRegion {
                        reason: format!(
                            "Rect coordinates must have {} dimensions, got {}/{}",
                            self.total_ndim,
                            min.len(),
                            max.len()
                        ),
                    });
                }
                let per_comp: Vec<RegionPlan> = (0..self.components.len())
                    .map(|i| {
                        let start = self.dim_offsets[i];
                        let end = self.dim_offsets[i + 1];
                        let comp_min: Coord = SmallVec::from_slice(&min[start..end]);
                        let comp_max: Coord = SmallVec::from_slice(&max[start..end]);
                        self.components[i].compile_region(&RegionSpec::Rect {
                            min: comp_min,
                            max: comp_max,
                        })
                    })
                    .collect::<Result<_, _>>()?;
                Ok(self.compile_cartesian_product(&per_comp))
            }
            RegionSpec::Disk { center, radius } => self.compile_disk_bfs(center, *radius),
            RegionSpec::Neighbours { center, depth } => self.compile_disk_bfs(center, *depth),
            RegionSpec::Coords(coords) => {
                // Per-component sets of valid sub-coordinates. They are built
                // lazily on the first well-formed coord, so each canonical
                // ordering is materialised at most once per call rather than
                // once per (coord, component) pair with a linear scan.
                let mut valid_parts: Option<Vec<std::collections::HashSet<Coord>>> = None;
                for coord in coords {
                    if coord.len() != self.total_ndim {
                        return Err(SpaceError::CoordOutOfBounds {
                            coord: coord.clone(),
                            bounds: format!("expected {}D coordinate", self.total_ndim),
                        });
                    }
                    let valid = valid_parts.get_or_insert_with(|| {
                        self.components
                            .iter()
                            .map(|c| c.canonical_ordering().into_iter().collect())
                            .collect()
                    });
                    for (i, allowed) in valid.iter().enumerate() {
                        if !allowed.contains(&self.split_coord(coord, i)) {
                            return Err(SpaceError::CoordOutOfBounds {
                                coord: coord.clone(),
                                bounds: format!("component {i} coordinate out of bounds"),
                            });
                        }
                    }
                }
                let mut sorted: Vec<Coord> = coords.clone();
                self.sort_canonical(&mut sorted);
                // Valid coords have distinct ranks, so duplicates are adjacent.
                sorted.dedup();
                let n = sorted.len();
                let tensor_indices: Vec<usize> = (0..n).collect();
                let valid_mask = vec![1u8; n];
                Ok(RegionPlan {
                    coords: sorted,
                    tensor_indices,
                    valid_mask,
                    bounding_shape: BoundingShape::Rect(vec![n]),
                })
            }
        }
    }

    /// All product coordinates in lexicographic order of the component
    /// canonical orderings: the last component varies fastest.
    fn canonical_ordering(&self) -> Vec<Coord> {
        let orderings: Vec<Vec<Coord>> = self
            .components
            .iter()
            .map(|c| c.canonical_ordering())
            .collect();
        // An empty factor means an empty product; returning early also keeps
        // the odometer below from indexing into an empty ordering.
        if orderings.iter().any(Vec::is_empty) {
            return Vec::new();
        }
        let mut result = Vec::with_capacity(self.total_cells);
        let mut indices = vec![0usize; orderings.len()];
        loop {
            let mut coord: Coord = SmallVec::with_capacity(self.total_ndim);
            for (ordering, &idx) in orderings.iter().zip(&indices) {
                coord.extend_from_slice(&ordering[idx]);
            }
            result.push(coord);

            // Mixed-radix increment, last component first.
            let mut wrapped = true;
            for (idx, ordering) in indices.iter_mut().zip(&orderings).rev() {
                *idx += 1;
                if *idx < ordering.len() {
                    wrapped = false;
                    break;
                }
                *idx = 0;
            }
            if wrapped {
                break;
            }
        }
        result
    }

    /// Position of `coord` in `canonical_ordering`, computed as a mixed-radix
    /// number with per-component ranks as digits and component `cell_count()`s
    /// as radices. Returns `None` for a wrong-length coord or when any
    /// component has no rank for its part.
    fn canonical_rank(&self, coord: &Coord) -> Option<usize> {
        if coord.len() != self.total_ndim {
            return None;
        }
        let mut rank = 0usize;
        let mut stride = 1usize;
        for i in (0..self.components.len()).rev() {
            let sub = self.split_coord(coord, i);
            let comp_rank = self.components[i].canonical_rank(&sub)?;
            rank += comp_rank * stride;
            stride *= self.components[i].cell_count();
        }
        Some(rank)
    }

    fn instance_id(&self) -> SpaceInstanceId {
        self.instance_id
    }

    /// Equal topology: `other` is also a `ProductSpace` with the same number
    /// of components, and the components match pairwise.
    fn topology_eq(&self, other: &dyn Space) -> bool {
        let Some(o) = (other as &dyn std::any::Any).downcast_ref::<Self>() else {
            return false;
        };
        self.components.len() == o.components.len()
            && self
                .components
                .iter()
                .zip(o.components.iter())
                .all(|(a, b)| a.topology_eq(b.as_ref()))
    }
}
impl ProductSpace {
    /// Collects every cell reachable from `center` in at most `radius` steps
    /// of `neighbours`, sorted into canonical order.
    ///
    /// The plan is flat: tensor indices are `0..n`, the mask is all ones, and
    /// the bounding shape is `[n]`.
    ///
    /// # Errors
    ///
    /// Returns [`SpaceError::CoordOutOfBounds`] if `center` has the wrong
    /// number of axes, or if any component part is not in that component's
    /// canonical ordering.
    fn compile_disk_bfs(&self, center: &Coord, radius: u32) -> Result<RegionPlan, SpaceError> {
        if center.len() != self.total_ndim {
            return Err(SpaceError::CoordOutOfBounds {
                coord: center.clone(),
                bounds: format!("expected {}D coordinate", self.total_ndim),
            });
        }
        for (i, component) in self.components.iter().enumerate() {
            let sub = self.split_coord(center, i);
            if !component.canonical_ordering().contains(&sub) {
                return Err(SpaceError::CoordOutOfBounds {
                    coord: center.clone(),
                    bounds: format!("component {i} coordinate {:?} out of bounds", sub),
                });
            }
        }

        // `visited` doubles as the result list: `IndexSet` keeps discovery
        // order, so no separate Vec of found cells is needed.
        let mut visited: IndexSet<Coord> = IndexSet::new();
        let mut frontier: VecDeque<(Coord, u32)> = VecDeque::new();
        visited.insert(center.clone());
        frontier.push_back((center.clone(), 0));
        while let Some((coord, dist)) = frontier.pop_front() {
            // Cells at the radius are kept but not expanded further.
            if dist >= radius {
                continue;
            }
            for nb in self.neighbours(&coord) {
                if visited.insert(nb.clone()) {
                    frontier.push_back((nb, dist + 1));
                }
            }
        }

        let mut cells: Vec<Coord> = visited.into_iter().collect();
        self.sort_canonical(&mut cells);
        let n = cells.len();
        let tensor_indices: Vec<usize> = (0..n).collect();
        let valid_mask = vec![1u8; n];
        Ok(RegionPlan {
            coords: cells,
            tensor_indices,
            valid_mask,
            bounding_shape: BoundingShape::Rect(vec![n]),
        })
    }
}
// Unit and property tests for `ProductSpace`: neighbourhoods, distances,
// canonical ordering, region compilation, and composition errors.
#[cfg(test)]
mod tests {
use super::*;
use crate::compliance;
use crate::{Hex2D, Line1D, Ring1D};
use murk_core::Coord;
use proptest::prelude::*;
use smallvec::smallvec;
// Fixture: 5x5 `Hex2D` x 10-cell absorbing `Line1D`, giving 3 axes and
// 5 * 5 * 10 = 250 cells.
fn hex_line() -> ProductSpace {
let hex = Hex2D::new(5, 5).unwrap();
let line = Line1D::new(10, crate::EdgeBehavior::Absorb).unwrap();
ProductSpace::new(vec![Box::new(hex), Box::new(line)]).unwrap()
}
// An interior cell has 6 hex-side neighbours (line part fixed at 5) plus
// 2 line-side neighbours (hex part fixed at (2, 1)).
#[test]
fn neighbours_hex_line() {
let s = hex_line();
let coord: Coord = smallvec![2, 1, 5];
let n = s.neighbours(&coord);
assert_eq!(n.len(), 8);
assert!(n.contains(&smallvec![3, 1, 5])); assert!(n.contains(&smallvec![3, 0, 5])); assert!(n.contains(&smallvec![2, 0, 5])); assert!(n.contains(&smallvec![1, 1, 5])); assert!(n.contains(&smallvec![1, 2, 5])); assert!(n.contains(&smallvec![2, 2, 5]));
assert!(n.contains(&smallvec![2, 1, 4]));
assert!(n.contains(&smallvec![2, 1, 6]));
}
// The default distance is the L1 sum: hex part 2 + line part |8 - 5| = 3.
#[test]
fn distance_hex_line() {
let s = hex_line();
let a: Coord = smallvec![2, 1, 5];
let b: Coord = smallvec![4, 0, 8];
assert_eq!(s.distance(&a, &b), 5.0);
}
// L-infinity takes the max of the component distances: max(2, 3) = 3.
#[test]
fn metric_distance_linf() {
let s = hex_line();
let a: Coord = smallvec![2, 1, 5];
let b: Coord = smallvec![4, 0, 8];
assert_eq!(
s.metric_distance(&a, &b, &ProductMetric::LInfinity)
.unwrap(),
3.0
);
}
// Weighted sum: 1.0 * 2 (hex) + 2.0 * 3 (line) = 8.
#[test]
fn metric_distance_weighted() {
let s = hex_line();
let a: Coord = smallvec![2, 1, 5];
let b: Coord = smallvec![4, 0, 8];
assert_eq!(
s.metric_distance(&a, &b, &ProductMetric::Weighted(vec![1.0, 2.0]))
.unwrap(),
8.0
);
}
// The last component (the line) varies fastest: the first 10 entries share
// hex part (0, 0), and the next 10 share hex part (1, 0).
#[test]
fn iteration_order_hex_line() {
let s = hex_line();
let order = s.canonical_ordering();
assert_eq!(order.len(), 250);
for (i, coord) in order.iter().enumerate().take(10) {
let expected: Coord = smallvec![0, 0, i as i32];
assert_eq!(*coord, expected);
}
for (j, coord) in order[10..20].iter().enumerate() {
let expected: Coord = smallvec![1, 0, j as i32];
assert_eq!(*coord, expected);
}
}
// Inclusive bounds: 2 x 2 hex cells times 3 line cells (3..=5) = 12.
#[test]
fn region_rect_hex_line() {
let s = hex_line();
let plan = s
.compile_region(&RegionSpec::Rect {
min: smallvec![1, 1, 3],
max: smallvec![2, 2, 5],
})
.unwrap();
assert_eq!(plan.cell_count(), 12);
}
// Hex2D contributes 2 axes and Line1D 1.
#[test]
fn ndim_sum() {
let s = hex_line();
assert_eq!(s.ndim(), 3); }
// 25 hex cells * 10 line cells.
#[test]
fn cell_count_product() {
let s = hex_line();
assert_eq!(s.cell_count(), 250); }
// Axes 2 + 1 + 1 = 4; cells 9 * 5 * 4 = 180.
#[test]
fn three_component() {
let hex = Hex2D::new(3, 3).unwrap();
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let ring = Ring1D::new(4).unwrap();
let s = ProductSpace::new(vec![Box::new(hex), Box::new(line), Box::new(ring)]).unwrap();
assert_eq!(s.ndim(), 4); assert_eq!(s.cell_count(), 180); }
// A one-component product should behave like the component itself.
#[test]
fn single_component() {
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(line)]).unwrap();
assert_eq!(s.ndim(), 1);
assert_eq!(s.cell_count(), 5);
let n = s.neighbours(&smallvec![2]);
assert_eq!(n.len(), 2);
assert!(n.contains(&smallvec![1]));
assert!(n.contains(&smallvec![3]));
}
#[test]
fn empty_components_error() {
let result = ProductSpace::new(vec![]);
assert!(matches!(result, Err(SpaceError::InvalidComposition { .. })));
}
// NOTE(review): despite the name, this compiles `RegionSpec::All`, not a
// disk. The whole-space plan is expected to have a fully set valid mask.
#[test]
fn valid_ratio_hex_disk_x_line_all() {
let hex = Hex2D::new(10, 10).unwrap();
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(hex), Box::new(line)]).unwrap();
let plan = s.compile_region(&RegionSpec::All).unwrap();
assert_eq!(plan.valid_ratio(), 1.0);
}
// Runs the crate's shared `Space` compliance suite on small products.
#[test]
fn compliance_hex_line_small() {
let hex = Hex2D::new(3, 3).unwrap();
let line = Line1D::new(3, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(hex), Box::new(line)]).unwrap();
compliance::run_full_compliance(&s);
}
#[test]
fn compliance_line_ring() {
let line = Line1D::new(4, crate::EdgeBehavior::Absorb).unwrap();
let ring = Ring1D::new(3).unwrap();
let s = ProductSpace::new(vec![Box::new(line), Box::new(ring)]).unwrap();
compliance::run_full_compliance(&s);
}
// A boxed `dyn Space` must downcast to its concrete type and to no other.
#[test]
fn downcast_ref_product_space() {
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let s: Box<dyn Space> = Box::new(ProductSpace::new(vec![Box::new(line)]).unwrap());
assert!(s.downcast_ref::<ProductSpace>().is_some());
assert!(s.downcast_ref::<Hex2D>().is_none());
}
// Disk plans must list coords at strictly increasing canonical positions,
// which also implies there are no duplicates.
#[test]
fn disk_coords_match_canonical_order() {
let hex = Hex2D::new(5, 5).unwrap();
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(hex), Box::new(line)]).unwrap();
let plan = s
.compile_region(&RegionSpec::Disk {
center: smallvec![2, 2, 2],
radius: 1,
})
.unwrap();
let canonical = s.canonical_ordering();
let mut last_pos = None;
for coord in &plan.coords {
let pos = canonical
.iter()
.position(|c| c == coord)
.expect("disk coord not in canonical ordering");
if let Some(lp) = last_pos {
assert!(pos > lp, "coords not in canonical order: {:?}", plan.coords);
}
last_pos = Some(pos);
}
}
// Explicit coords given out of order must come back canonically sorted.
#[test]
fn coords_region_matches_canonical_order() {
let hex = Hex2D::new(5, 5).unwrap();
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(hex), Box::new(line)]).unwrap();
let plan = s
.compile_region(&RegionSpec::Coords(vec![
smallvec![1, 0, 3],
smallvec![0, 1, 2], smallvec![2, 0, 0],
]))
.unwrap();
let canonical = s.canonical_ordering();
let mut last_pos = None;
for coord in &plan.coords {
let pos = canonical.iter().position(|c| c == coord).unwrap();
if let Some(lp) = last_pos {
assert!(pos > lp, "coords not in canonical order: {:?}", plan.coords);
}
last_pos = Some(pos);
}
}
// A centre whose hex part is outside the 5x5 grid must be rejected.
#[test]
fn disk_oob_center_rejected() {
let hex = Hex2D::new(5, 5).unwrap();
let line = Line1D::new(5, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(hex), Box::new(line)]).unwrap();
let result = s.compile_region(&RegionSpec::Disk {
center: smallvec![999, 0, 2],
radius: 1,
});
assert!(result.is_err());
}
// `hex_line` has 2 components, so a weighted metric needs exactly 2 weights.
#[test]
fn weighted_metric_too_few_weights_returns_err() {
let s = hex_line(); let a: Coord = smallvec![2, 1, 5];
let b: Coord = smallvec![4, 0, 8];
let result = s.metric_distance(&a, &b, &ProductMetric::Weighted(vec![1.0]));
assert!(matches!(result, Err(SpaceError::InvalidComposition { .. })));
}
#[test]
fn weighted_metric_too_many_weights_returns_err() {
let s = hex_line(); let a: Coord = smallvec![2, 1, 5];
let b: Coord = smallvec![4, 0, 8];
let result = s.metric_distance(&a, &b, &ProductMetric::Weighted(vec![1.0, 2.0, 3.0]));
assert!(matches!(result, Err(SpaceError::InvalidComposition { .. })));
}
// Unit weights reproduce the L1 distance (2 + 3).
#[test]
fn weighted_metric_exact_match_succeeds() {
let s = hex_line(); let a: Coord = smallvec![2, 1, 5];
let b: Coord = smallvec![4, 0, 8];
let d = s
.metric_distance(&a, &b, &ProductMetric::Weighted(vec![1.0, 1.0]))
.unwrap();
assert_eq!(d, 5.0); }
proptest! {
// Identity, symmetry and the triangle inequality on random Line1D x Line1D
// products; raw indices are reduced modulo each line's length to stay in
// bounds.
#[test]
fn distance_is_metric(
len_a in 2u32..5,
len_b in 2u32..5,
ai in 0i32..5, bi in 0i32..5,
aj in 0i32..5, bj in 0i32..5,
ci in 0i32..5, cj in 0i32..5,
) {
let ai = ai % len_a as i32;
let bi = bi % len_b as i32;
let aj = aj % len_a as i32;
let bj = bj % len_b as i32;
let ci = ci % len_a as i32;
let cj = cj % len_b as i32;
let line_a = Line1D::new(len_a, crate::EdgeBehavior::Absorb).unwrap();
let line_b = Line1D::new(len_b, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(line_a), Box::new(line_b)]).unwrap();
let a: Coord = smallvec![ai, bi];
let b: Coord = smallvec![aj, bj];
let cv: Coord = smallvec![ci, cj];
prop_assert!((s.distance(&a, &a) - 0.0).abs() < f64::EPSILON);
prop_assert!((s.distance(&a, &b) - s.distance(&b, &a)).abs() < f64::EPSILON);
prop_assert!(s.distance(&a, &cv) <= s.distance(&a, &b) + s.distance(&b, &cv) + f64::EPSILON);
}
// If nb is a neighbour of coord, then coord must be a neighbour of nb.
#[test]
fn neighbours_symmetric(
len_a in 2u32..5,
len_b in 2u32..5,
i in 0i32..5, j in 0i32..5,
) {
let i = i % len_a as i32;
let j = j % len_b as i32;
let line_a = Line1D::new(len_a, crate::EdgeBehavior::Absorb).unwrap();
let line_b = Line1D::new(len_b, crate::EdgeBehavior::Absorb).unwrap();
let s = ProductSpace::new(vec![Box::new(line_a), Box::new(line_b)]).unwrap();
let coord: Coord = smallvec![i, j];
for nb in s.neighbours(&coord) {
let nb_neighbours = s.neighbours(&nb);
prop_assert!(
nb_neighbours.contains(&coord),
"neighbour symmetry violated: {:?} in N({:?}) but {:?} not in N({:?})",
nb, coord, coord, nb,
);
}
}
}
}