use crate::wfc_util::*;
use petgraph::Graph;
use std::collections::HashMap;
use std::collections::HashSet;
/// A strategy object that knows how to populate a [`GridSystem`] with cells and edges.
///
/// Only [`build_grid_system`](GridBuilder::build_grid_system) is mandatory; the two
/// descriptive methods fall back to "no dimensions" and `"CustomGrid"`.
pub trait GridBuilder {
    /// Adds this builder's cells and edges to `grid`.
    ///
    /// # Errors
    /// Returns a [`GridError`] when the topology cannot be built (typically one
    /// propagated from a failing `GridSystem` call).
    fn build_grid_system(&mut self, grid: &mut GridSystem) -> Result<(), GridError>;

    /// Extent of each grid axis; the default reports none (an empty vector).
    fn get_dimensions(&self) -> Vec<usize> {
        Vec::new()
    }

    /// Human-readable name of this grid type; defaults to `"CustomGrid"`.
    fn get_grid_type_name(&self) -> &'static str {
        "CustomGrid"
    }
}
/// Graph-backed grid: cells are graph nodes and adjacency is expressed as graph edges.
///
/// Alongside the graph it keeps an optional name -> cell index and the set of
/// placeholder ("virtual") nodes created when an edge is requested without a real
/// target.
pub struct GridSystem {
    /// Underlying cell/edge graph.
    graph: WFCGraph,
    /// Ids of placeholder nodes that stand in for missing edge targets.
    virtual_nodes: HashSet<CellId>,
    /// Name -> cell id index for cells registered with a name.
    cell_lookup: HashMap<String, CellId>,
}
impl GridSystem {
    /// Creates an empty grid system: no cells, edges, names or virtual nodes.
    pub fn new() -> Self {
        Self {
            graph: Graph::new(),
            cell_lookup: HashMap::new(),
            virtual_nodes: HashSet::new(),
        }
    }

    /// Creates an empty grid system whose graph pre-allocates room for `nodes`
    /// cells and `edges` edges. The name and virtual-node tables start empty.
    pub fn with_capacity(nodes: usize, edges: usize) -> Self {
        Self {
            graph: Graph::with_capacity(nodes, edges),
            cell_lookup: HashMap::new(),
            virtual_nodes: HashSet::new(),
        }
    }

    /// Runs `builder` against this grid, adding its cells and edges to whatever
    /// is already present.
    ///
    /// # Errors
    /// Returns whatever [`GridError`] the builder reports.
    pub fn build_with<T: GridBuilder>(&mut self, mut builder: T) -> Result<(), GridError> {
        builder.build_grid_system(self)
    }

    /// Creates a fresh grid and populates it with `builder`.
    ///
    /// # Errors
    /// Returns whatever [`GridError`] the builder reports; the partially built
    /// grid is discarded.
    pub fn from_builder<T: GridBuilder>(builder: T) -> Result<Self, GridError> {
        let mut grid = Self::new();
        grid.build_with(builder)?;
        Ok(grid)
    }

    /// Adds an unnamed cell and returns its id.
    pub fn add_cell(&mut self, cell_data: Cell) -> CellId {
        self.graph.add_node(cell_data)
    }

    /// Adds a cell and registers it under `name` for [`get_cell_by_name`](Self::get_cell_by_name).
    ///
    /// If `name` was already registered, the entry is repointed to the new cell.
    /// The previously named cell stays in the graph but is no longer reachable by name.
    pub fn add_cell_with_name(&mut self, cell_data: Cell, name: String) -> CellId {
        let cell_id = self.add_cell(cell_data);
        self.cell_lookup.insert(name, cell_id);
        cell_id
    }

    /// Looks up a cell registered via [`add_cell_with_name`](Self::add_cell_with_name).
    pub fn get_cell_by_name(&self, name: &str) -> Option<CellId> {
        self.cell_lookup.get(name).copied()
    }

    /// Returns `true` if `node_id` is a placeholder created by
    /// [`create_edge`](Self::create_edge) for a `None` target.
    pub fn is_virtual_node(&self, node_id: CellId) -> bool {
        self.virtual_nodes.contains(&node_id)
    }

    /// Adds an edge from `from` to `to`.
    ///
    /// When `to` is `None`, a new virtual node (a cell named `"__VIRTUAL__"`) is
    /// created, recorded in the virtual-node set and used as the target.
    ///
    /// # Errors
    /// * [`GridError::NodeNotFound`] if `from` (or `to`, when `Some`) is not a cell of this grid.
    /// * [`GridError::SelfLoop`] if `from == to`.
    /// * [`GridError::EdgeAlreadyExists`] if an edge `from -> to` already exists.
    pub fn create_edge(&mut self, from: CellId, to: Option<CellId>) -> Result<EdgeId, GridError> {
        if !self.contains_cell(from) {
            return Err(GridError::NodeNotFound);
        }
        let target_node = match to {
            Some(real_to) => {
                if !self.contains_cell(real_to) {
                    return Err(GridError::NodeNotFound);
                }
                real_to
            }
            None => {
                // Only created once `from` is known to be valid, so a failed call
                // never leaves an orphan placeholder behind.
                let virtual_node = self
                    .graph
                    .add_node(Cell::with_name("__VIRTUAL__".to_string()));
                self.virtual_nodes.insert(virtual_node);
                virtual_node
            }
        };
        if from == target_node {
            return Err(GridError::SelfLoop);
        }
        if self.graph.find_edge(from, target_node).is_some() {
            return Err(GridError::EdgeAlreadyExists);
        }
        Ok(self.graph.add_edge(from, target_node, GraphEdge::new()))
    }

    /// Returns the targets of edges leaving `cell_id`, in the order petgraph's
    /// `Graph::neighbors` yields them. Direction lookups index into this order.
    pub fn get_neighbors(&self, cell_id: CellId) -> Vec<CellId> {
        self.graph.neighbors(cell_id).collect()
    }

    /// Returns the id of the edge `from -> to`, if any.
    pub fn find_edge(&self, from: CellId, to: CellId) -> Option<EdgeId> {
        self.graph.find_edge(from, to)
    }

    /// Iterates over every node id, including virtual nodes.
    pub fn get_all_cells(&self) -> impl Iterator<Item = CellId> + '_ {
        self.graph.node_indices()
    }

    /// Number of nodes in the graph, including virtual nodes.
    pub fn get_cells_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Number of edges in the graph.
    pub fn get_edges_count(&self) -> usize {
        self.graph.edge_count()
    }

    /// Returns the neighbor of `cell_id` in `direction`.
    ///
    /// If `direction.to_neighbor_index()` is `Some(i)`, this is the `i`-th entry of
    /// [`get_neighbors`](Self::get_neighbors). Otherwise it searches for a cell whose
    /// neighbor slot for the opposite direction is `cell_id`.
    pub fn get_neighbor_by_direction<D>(&self, cell_id: CellId, direction: D) -> Option<CellId>
    where
        D: DirectionTrait,
    {
        match direction.to_neighbor_index() {
            // `nth` on the neighbor iterator avoids collecting a Vec just to index it.
            Some(index) => self.graph.neighbors(cell_id).nth(index),
            None => self.find_incoming_neighbor_by_direction(cell_id, direction),
        }
    }

    /// Finds the lowest-id node whose outgoing neighbor at the slot of
    /// `direction.opposite()` is `cell_id`. This is a linear scan over all nodes.
    fn find_incoming_neighbor_by_direction<D>(
        &self,
        cell_id: CellId,
        direction: D,
    ) -> Option<CellId>
    where
        D: DirectionTrait,
    {
        // The slot being tested does not depend on the node, so resolve it once.
        let index = direction.opposite()?.to_neighbor_index()?;
        self.graph
            .node_indices()
            .find(|&node_id| self.graph.neighbors(node_id).nth(index) == Some(cell_id))
    }

    /// Like [`get_neighbor_by_direction`](Self::get_neighbor_by_direction), but returns
    /// the result as a vector with zero or one element.
    pub fn get_neighbors_by_direction<D>(&self, cell_id: CellId, direction: D) -> Vec<CellId>
    where
        D: DirectionTrait,
    {
        self.get_neighbor_by_direction(cell_id, direction)
            .into_iter()
            .collect()
    }

    /// Returns `true` if `cell_id` names a node of this grid.
    ///
    /// `node_weight` is `Some` exactly for the ids yielded by `node_indices()`, so
    /// this matches a scan of all nodes without the O(n) cost.
    pub fn contains_cell(&self, cell_id: CellId) -> bool {
        self.graph.node_weight(cell_id).is_some()
    }

    /// Returns `true` if an edge `from -> to` exists.
    pub fn contains_edge(&self, from: CellId, to: CellId) -> bool {
        self.graph.find_edge(from, to).is_some()
    }

    /// Allocated `(node, edge)` capacity of the underlying graph.
    pub fn capacity(&self) -> (usize, usize) {
        self.graph.capacity()
    }

    /// Removes every cell, edge, name and virtual-node marker.
    ///
    /// Node ids are handed out again from the start after the graph is cleared, so
    /// the virtual-node set must be reset too. Otherwise stale placeholder ids would
    /// make newly added real cells report `is_virtual_node() == true`.
    pub fn clear(&mut self) {
        self.graph.clear();
        self.cell_lookup.clear();
        self.virtual_nodes.clear();
    }

    /// Number of outgoing neighbors of `cell_id` (0 for unknown ids).
    pub fn get_cell_degree(&self, cell_id: CellId) -> usize {
        self.graph.neighbors(cell_id).count()
    }

    /// Checks internal consistency.
    ///
    /// Verifies that every edge endpoint, every named cell and every recorded
    /// virtual node refers to a node of the graph.
    ///
    /// # Errors
    /// [`GridError::NodeNotFound`] if any of those references is dangling.
    pub fn validate_structure(&self) -> Result<(), GridError> {
        let edges_ok = self
            .graph
            .edge_indices()
            .filter_map(|edge_id| self.graph.edge_endpoints(edge_id))
            .all(|(source, target)| self.contains_cell(source) && self.contains_cell(target));
        let names_ok = self.cell_lookup.values().all(|&id| self.contains_cell(id));
        let virtuals_ok = self.virtual_nodes.iter().all(|&id| self.contains_cell(id));
        if edges_ok && names_ok && virtuals_ok {
            Ok(())
        } else {
            Err(GridError::NodeNotFound)
        }
    }

    /// Multi-line summary of node/edge counts, graph capacity and named cells.
    pub fn get_statistics(&self) -> String {
        format!(
            "GridSystem Statistics:\n Nodes: {}\n Edges: {}\n Capacity: {:?}\n Named cells: {}",
            self.get_cells_count(),
            self.get_edges_count(),
            self.capacity(),
            self.cell_lookup.len()
        )
    }

    /// Prints `cell_id`'s neighbor list and the result of each `Direction4` query to stdout.
    pub fn debug_print_neighbors(&self, cell_id: CellId) {
        println!("Cell {:?} neighbors:", cell_id);
        for (i, neighbor) in self.graph.neighbors(cell_id).enumerate() {
            println!(" [{}]: {:?}", i, neighbor);
        }
        println!(" Direction queries:");
        for direction in Direction4::all_directions() {
            match self.get_neighbor_by_direction(cell_id, direction) {
                Some(neighbor) => println!(" {}: {:?}", direction.name(), neighbor),
                None => println!(" {}: None", direction.name()),
            }
        }
    }

    /// Prints statistics, every node's neighbor list and the name table to stdout.
    pub fn debug_print_grid(&self) {
        println!("=== Grid System Debug Info ===");
        println!("{}", self.get_statistics());
        println!("\nAll cells:");
        for cell_id in self.get_all_cells() {
            let neighbors = self.get_neighbors(cell_id);
            println!(" {:?}: neighbors = {:?}", cell_id, neighbors);
        }
        println!("\nNamed cells:");
        for (name, cell_id) in &self.cell_lookup {
            println!(" '{}': {:?}", name, cell_id);
        }
    }
}
impl Default for GridSystem {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only builder for a `width` x `height` rectangle of named cells.
    ///
    /// Each cell gets one edge toward its east neighbor and one toward its south
    /// neighbor. On the east/south border that edge targets a fresh virtual node.
    struct SimpleGridBuilder {
        width: usize,
        height: usize,
    }

    impl SimpleGridBuilder {
        fn new(width: usize, height: usize) -> Self {
            SimpleGridBuilder { width, height }
        }
    }

    impl GridBuilder for SimpleGridBuilder {
        fn build_grid_system(&mut self, grid: &mut GridSystem) -> Result<(), GridError> {
            let (width, height) = (self.width, self.height);

            // Pass 1: create every cell row by row; payload ids are row-major (`y * width + x`).
            let mut rows = Vec::with_capacity(height);
            for y in 0..height {
                let row: Vec<_> = (0..width)
                    .map(|x| {
                        grid.add_cell_with_name(
                            Cell::with_id((y * width + x) as u32),
                            format!("cell_{}_{}", x, y),
                        )
                    })
                    .collect();
                rows.push(row);
            }

            // Pass 2: per cell, the east edge first, then the south edge. A `None`
            // target (off the border) makes `create_edge` add a virtual node.
            for (y, row) in rows.iter().enumerate() {
                for (x, &current) in row.iter().enumerate() {
                    let east = row.get(x + 1).copied();
                    let south = rows.get(y + 1).map(|next_row| next_row[x]);
                    grid.create_edge(current, east)?;
                    grid.create_edge(current, south)?;
                }
            }
            Ok(())
        }

        fn get_dimensions(&self) -> Vec<usize> {
            vec![self.width, self.height]
        }

        fn get_grid_type_name(&self) -> &'static str {
            "SimpleGrid"
        }
    }

    #[test]
    fn test_grid_system_creation() {
        let grid = GridSystem::new();
        assert_eq!((grid.get_cells_count(), grid.get_edges_count()), (0, 0));
    }

    #[test]
    fn test_add_cells_and_edges() {
        let mut grid = GridSystem::new();
        let first = grid.add_cell(Cell::with_id(1));
        let second = grid.add_cell(Cell::with_id(2));
        assert_eq!(grid.get_cells_count(), 2);

        grid.create_edge(first, Some(second)).unwrap();
        assert_eq!(grid.get_edges_count(), 1);
        assert_eq!(grid.get_neighbors(first), vec![second]);
    }

    #[test]
    fn test_direction_queries() {
        let mut grid = GridSystem::new();
        let top = [
            grid.add_cell(Cell::with_id(0)),
            grid.add_cell(Cell::with_id(1)),
        ];
        let bottom = [
            grid.add_cell(Cell::with_id(2)),
            grid.add_cell(Cell::with_id(3)),
        ];
        let (center, east, south) = (top[0], top[1], bottom[0]);

        grid.create_edge(center, Some(east)).unwrap();
        grid.create_edge(center, Some(south)).unwrap();

        assert_eq!(
            grid.get_neighbor_by_direction(center, Direction4::East),
            Some(east)
        );
        assert_eq!(
            grid.get_neighbor_by_direction(center, Direction4::South),
            Some(south)
        );
        assert_eq!(
            grid.get_neighbor_by_direction(center, Direction4::West),
            None
        );
        assert_eq!(
            grid.get_neighbor_by_direction(center, Direction4::North),
            None
        );
    }

    #[test]
    fn test_error_handling() {
        let mut grid = GridSystem::new();
        let lone = grid.add_cell(Cell::new());
        assert_eq!(grid.create_edge(lone, Some(lone)), Err(GridError::SelfLoop));

        let other = grid.add_cell(Cell::new());
        grid.create_edge(lone, Some(other)).unwrap();
        assert_eq!(
            grid.create_edge(lone, Some(other)),
            Err(GridError::EdgeAlreadyExists)
        );
    }

    #[test]
    fn test_named_cells() {
        let mut grid = GridSystem::new();
        let named = grid.add_cell_with_name(Cell::new(), "test_cell".to_string());
        assert_eq!(grid.get_cell_by_name("test_cell"), Some(named));
        assert!(grid.get_cell_by_name("nonexistent").is_none());
    }

    #[test]
    fn test_structure_validation() {
        let mut grid = GridSystem::new();
        let a = grid.add_cell(Cell::new());
        let b = grid.add_cell(Cell::new());
        grid.create_edge(a, Some(b)).unwrap();
        assert!(grid.validate_structure().is_ok());
    }

    #[test]
    fn test_grid_builder_trait() {
        let mut grid = GridSystem::new();
        grid.build_with(SimpleGridBuilder::new(3, 2)).unwrap();

        // 6 real cells + 2 east-border and 3 south-border virtual nodes.
        assert_eq!(grid.get_cells_count(), 11);
        assert!(grid.get_cell_by_name("cell_0_0").is_some());
        assert!(grid.get_cell_by_name("cell_2_1").is_some());
        assert!(grid.get_cell_by_name("cell_3_0").is_none());
        // Two edges per real cell.
        assert_eq!(grid.get_edges_count(), 12);
    }

    #[test]
    fn test_from_builder() {
        let grid = GridSystem::from_builder(SimpleGridBuilder::new(2, 2)).unwrap();

        // 4 real cells + 2 east-border and 2 south-border virtual nodes.
        assert_eq!(grid.get_cells_count(), 8);
        let origin = grid.get_cell_by_name("cell_0_0").unwrap();
        assert_eq!(grid.get_neighbors(origin).len(), 2);
        assert_eq!(grid.get_edges_count(), 8);
    }
}