use bitvec::{order::Lsb0, vec::BitVec};
use image::{GenericImageView, Rgba, RgbaImage};
use rand::{
distr::{Distribution, weighted::WeightedIndex},
prelude::IteratorRandom,
rng,
};
use rayon::prelude::*;
use std::{
cmp::Ordering,
collections::{HashMap, VecDeque},
};
use paraxis::{mathematics::ivector::IVector, structure::grid::Grid};
/// Cardinal direction in 2-D, used (via `as usize`) as the direction index into
/// the per-tile compatibility table built by `from_image`.
///
/// Discriminants are implicit and follow declaration order (0, 1, 2, 3). In
/// `collapse`, index `2 * axis` is used for a neighbour at a positive offset
/// along `axis` and `2 * axis + 1` for a negative offset. So `Left` (0) is the
/// entry consulted for a neighbour at +x, and `Up` (2) for a neighbour at +y.
/// NOTE(review): the variant names appear to describe the side of the current
/// cell rather than the neighbour — confirm the intended naming before relying
/// on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Direction {
    Left,
    Right,
    Up,
    Down,
}
/// Wave Function Collapse solver over an `N`-dimensional grid of cells.
///
/// Each cell stores a bit set of the tile ids still possible at that cell.
/// Tile id `i` refers to the payload `rules[i]`.
pub struct WFC<T, const N: usize> {
    /// Tile payloads, indexed by tile id.
    rules: Vec<T>,
    /// Relative sampling weight of each tile id (indexed like `rules`); used when
    /// a cell is collapsed to a single tile.
    weights: Vec<u32>,
    /// For every cell, the set of tile ids (bit positions) still allowed there.
    possibilities: Grid<BitVec<u64, Lsb0>, N>,
    /// `compatibility[tile][dir]` is the set of tile ids allowed as the neighbour
    /// of `tile` in direction index `dir`. Index `2 * axis` is used for the
    /// positive side of `axis` and `2 * axis + 1` for the negative side.
    compatibility: Vec<Vec<BitVec<u64, Lsb0>>>,
    /// `buckets[k]` holds cell positions that had `k` possibilities when they were
    /// enqueued. Entries may be stale, and are re-checked before use.
    buckets: Vec<VecDeque<IVector<N>>>,
}
impl<T, const N: usize> WFC<T, N> {
    /// Finds the cell with the fewest remaining (but still more than one)
    /// possibilities.
    ///
    /// Returns that cell's position and `Some(count)` of its remaining
    /// possibilities. If no cell has more than one possibility left, returns the
    /// origin and `None`.
    ///
    /// This re-sorts `self.possibilities` on every call, which costs
    /// O(n log n) in the number of cells. `collapse` uses the entropy buckets
    /// instead.
    /// NOTE(review): relies on `Grid::sort_by` changing only the iteration order
    /// (not which value lives at which position) — confirm against `paraxis`.
    ///
    /// # Panics
    /// Panics if the possibility grid is empty.
    pub fn lowest_entropy(&mut self) -> (IVector<N>, Option<u32>) {
        self.possibilities.sort_by(|_, b, _, d| {
            let bc = b.count_ones();
            let dc = d.count_ones();
            // Cells with <= 1 possibility (collapsed or contradictory) sort last.
            match (bc <= 1, dc <= 1) {
                (false, false) => bc.cmp(&dc),
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
            }
        });
        let lowest = self
            .possibilities
            .first()
            .expect("Possibility grid is empty.");
        let count = lowest.1.count_ones();
        if count > 1 {
            (lowest.0.clone(), Some(count as u32))
        } else {
            (IVector::new([0; N]), None)
        }
    }

    /// Performs one observe-and-propagate step of Wave Function Collapse.
    ///
    /// The step works as follows:
    /// 1. Pick a uniformly random cell from the lowest non-empty entropy bucket.
    ///    Only cells with two or more remaining possibilities are considered.
    /// 2. Collapse that cell to a single tile, chosen with probability
    ///    proportional to `weights`.
    /// 3. Propagate the new constraints to neighbouring cells breadth-first.
    ///
    /// Returns `true` when no cell with two or more possibilities remains, meaning
    /// the grid is fully collapsed. Returns `false` otherwise, meaning the caller
    /// should call `collapse` again.
    ///
    /// NOTE(review): a contradiction (a neighbour reduced to zero possibilities)
    /// also returns `false`. Propagation stops at that point without writing the
    /// empty set, so the grid may be left only partially propagated.
    ///
    /// # Panics
    /// Panics if every candidate tile of the chosen cell has weight 0, or if a
    /// queried grid cell is missing.
    pub fn collapse(&mut self) -> bool {
        let mut rng = rng();

        // Observe: scan the buckets from the lowest entropy upward. Positions are
        // enqueued whenever a cell's count drops, so a bucket can hold stale
        // entries whose count has since fallen further. Counts never increase,
        // so stale entries can never become valid again; prune them here instead
        // of letting the buckets grow without bound. Buckets 0 and 1 need no
        // choice and are skipped.
        let selected = {
            let possibilities = &self.possibilities;
            let mut picked = None;
            for (entropy, bucket) in self.buckets.iter_mut().enumerate().skip(2) {
                bucket.retain(|&pos| possibilities.get(pos).unwrap().count_ones() == entropy);
                if let Some(pos) = bucket.iter().copied().choose(&mut rng) {
                    bucket.retain(|&p| p != pos);
                    picked = Some(pos);
                    break;
                }
            }
            picked
        };
        let origin = match selected {
            Some(pos) => pos,
            None => return true,
        };

        // Collapse the chosen cell to a single weighted-random tile.
        let candidates: Vec<usize> = self
            .possibilities
            .get(origin)
            .unwrap()
            .iter_ones()
            .collect();
        let candidate_weights: Vec<u32> =
            candidates.iter().map(|&id| self.weights[id]).collect();
        let dist = WeightedIndex::new(&candidate_weights)
            .expect("Failed to create weighted index (are all weights 0?)");
        let chosen_id = candidates[dist.sample(&mut rng)];
        let mut collapsed = BitVec::<u64, Lsb0>::repeat(false, self.rules.len());
        collapsed.set(chosen_id, true);
        self.possibilities.insert(collapsed, origin);

        // Propagate using a FIFO worklist of cells whose possibility set changed.
        // `queued` keeps a cell from being enqueued twice at the same time.
        let mut queue = VecDeque::new();
        let mut queued = std::collections::HashSet::new();
        queue.push_back(origin);
        queued.insert(origin);
        // Scratch buffer, reused for every neighbour to avoid reallocating.
        let mut allowed = BitVec::<u64, Lsb0>::repeat(false, self.rules.len());
        while let Some(current) = queue.pop_front() {
            queued.remove(&current);
            // Cloned because `self.possibilities` is mutated inside the loop.
            let current_set = self.possibilities.get(current).unwrap().clone();
            for (neighbour, _) in self.possibilities.neighbours(&current) {
                let delta = neighbour - current;
                // Direction index is `2 * axis` for a neighbour at a positive
                // offset along `axis` and `2 * axis + 1` for a negative offset.
                // NOTE(review): assumes `neighbours` yields only axis-aligned,
                // non-wrapping neighbours. A wrapped neighbour would get the
                // opposite sign — confirm against `paraxis::Grid`.
                let dir_index = (0..N)
                    .find_map(|axis| {
                        if delta[axis] > 0 {
                            Some(axis * 2)
                        } else if delta[axis] < 0 {
                            Some(axis * 2 + 1)
                        } else {
                            None
                        }
                    })
                    .unwrap_or_else(|| unreachable!("Neighbor is not adjacent!"));

                // Union of every tile allowed next to any tile still possible here.
                allowed.fill(false);
                for tile_id in current_set.iter_ones() {
                    allowed |= &self.compatibility[tile_id][dir_index];
                }

                let mut neighbour_set = self.possibilities.get(neighbour).unwrap().clone();
                let old_count = neighbour_set.count_ones();
                neighbour_set &= &allowed;
                let new_count = neighbour_set.count_ones();
                if new_count == 0 {
                    return false;
                }
                if new_count < old_count {
                    // Only cells that still need a decision are bucketed. The
                    // observe step never reads buckets 0 or 1.
                    if new_count >= 2 {
                        self.buckets[new_count].push_back(neighbour);
                    }
                    self.possibilities.insert(neighbour_set, neighbour);
                    if queued.insert(neighbour) {
                        queue.push_back(neighbour);
                    }
                }
            }
        }
        false
    }

    /// Checks whether `tile_b` may sit next to `tile_a` in direction `dir_index`
    /// under the overlapping model.
    ///
    /// When `tile_b` is shifted by one element from `tile_a`, the two tiles must
    /// agree on every element of their overlap. `dir_index` is `2 * axis` for a
    /// shift in the positive direction of `axis` and `2 * axis + 1` for the
    /// negative direction. This is the same encoding `collapse` uses.
    ///
    /// Tiles are flat slices of `size^N` elements, with axis 0 varying fastest.
    /// This matches how `from_image` lays out pixels (x in the inner loop, y in
    /// the outer loop) and the `[x, y]` order of `IVector` positions. One step
    /// along `axis` is therefore `size^axis` elements. Using `size^(N-1-axis)`
    /// would swap the axes relative to the grid and transpose the output.
    fn tiles_match(
        tile_a: &[Rgba<u8>],
        tile_b: &[Rgba<u8>],
        dir_index: usize,
        size: usize,
    ) -> bool {
        let axis = dir_index / 2;
        let is_positive = dir_index % 2 == 0;
        let stride = size.pow(axis as u32);
        let total_elements = size.pow(N as u32);
        (0..total_elements)
            // Skip elements on the last layer along `axis`; they have no overlap.
            .filter(|&i| (i / stride) % size < size - 1)
            .all(|i| {
                if is_positive {
                    tile_a[i + stride] == tile_b[i]
                } else {
                    tile_a[i] == tile_b[i + stride]
                }
            })
    }
}
#[cfg(feature = "image")]
impl WFC<Vec<Rgba<u8>>, 2> {
    /// Builds an overlapping-model solver from the image at `path`.
    ///
    /// Construction proceeds in these steps:
    /// - Every pixel of the source image is used as the top-left corner of a
    ///   `tilesize` x `tilesize` tile, wrapping around the image edges. Tiles are
    ///   stored with x varying fastest.
    /// - Identical tiles are deduplicated, and their occurrence counts become
    ///   the sampling weights.
    /// - Neighbour compatibility is derived with `tiles_match` for the four
    ///   `Direction`s.
    /// - The output grid of `output_size` cells starts with every tile possible.
    ///
    /// # Panics
    /// Panics if `tilesize` is 0 or the image cannot be opened or decoded.
    pub fn from_image(path: &str, tilesize: usize, output_size: IVector<2>) -> Self {
        assert!(tilesize > 0, "tilesize must be non-zero");
        let img = image::open(path).expect("Could not find image.");
        let (width, height) = img.dimensions();

        let mut rules: Vec<Vec<Rgba<u8>>> = Vec::new();
        let mut weights: Vec<u32> = Vec::new();
        // Hash lookup keeps extraction linear in the number of pixels. A linear
        // scan over `rules` per pixel would be quadratic in distinct tiles.
        let mut tile_to_id: HashMap<Vec<Rgba<u8>>, usize> = HashMap::new();
        for h in 0..height {
            for w in 0..width {
                let mut block = Vec::with_capacity(tilesize * tilesize);
                for bh in 0..tilesize as u32 {
                    for bw in 0..tilesize as u32 {
                        block.push(img.get_pixel((w + bw) % width, (h + bh) % height));
                    }
                }
                let id = match tile_to_id.get(&block) {
                    Some(&id) => id,
                    None => {
                        // New ids are assigned in first-seen order.
                        let new_id = rules.len();
                        rules.push(block.clone());
                        weights.push(0);
                        tile_to_id.insert(block, new_id);
                        new_id
                    }
                };
                weights[id] += 1;
            }
        }

        let tile_count = rules.len();
        let all_ids = BitVec::<u64, Lsb0>::repeat(true, tile_count);
        let mut possibilities = Grid::new(output_size);
        for x in 0..output_size[0] {
            for y in 0..output_size[1] {
                possibilities.insert(all_ids.clone(), IVector::new([x, y]));
            }
        }

        let dirs = [
            Direction::Left,
            Direction::Right,
            Direction::Up,
            Direction::Down,
        ];
        let mut compatibility =
            vec![vec![BitVec::<u64, Lsb0>::repeat(false, tile_count); dirs.len()]; tile_count];
        for (i, tile_a) in rules.iter().enumerate() {
            for &dir in &dirs {
                for (j, tile_b) in rules.iter().enumerate() {
                    if WFC::<Rgba<u8>, 2>::tiles_match(tile_a, tile_b, dir as usize, tilesize) {
                        compatibility[i][dir as usize].set(j, true);
                    }
                }
            }
        }

        // One bucket per possible possibility count (0..=tile_count). Every cell
        // starts fully undecided.
        let mut buckets: Vec<VecDeque<IVector<2>>> =
            (0..=tile_count).map(|_| VecDeque::new()).collect();
        buckets[tile_count].extend(possibilities.positions());
        println!("Extracted {} unique tiles.", tile_count);
        Self {
            possibilities,
            rules,
            compatibility,
            buckets,
            weights,
        }
    }

    /// Renders the current grid state to an image file at `path`, with one
    /// output pixel per cell.
    ///
    /// Each cell is coloured according to its state:
    /// - A collapsed cell (one possibility) takes the top-left pixel of its tile.
    /// - An undecided cell takes the mean of the average colours of its
    ///   remaining tiles.
    /// - A contradictory cell (no possibilities) is left transparent black.
    ///
    /// # Panics
    /// Panics if a grid cell is missing or the image cannot be written.
    pub fn save_image(&self, path: &str) {
        let shape = self.possibilities.shape();
        let width = shape[0] as usize;
        let height = shape[1] as usize;

        // Mean colour of each tile, used to preview undecided cells.
        let tile_averages: Vec<[u8; 4]> = self
            .rules
            .par_iter()
            .map(|tile| {
                let mut sum = [0u32; 4];
                for p in tile {
                    for (s, &v) in sum.iter_mut().zip(p.0.iter()) {
                        *s += u32::from(v);
                    }
                }
                let n = tile.len() as u32;
                sum.map(|s| (s / n) as u8)
            })
            .collect();

        let mut pixels: Vec<[u8; 4]> = vec![[0; 4]; width * height];
        pixels.par_iter_mut().enumerate().for_each(|(i, pixel)| {
            let pos = IVector::new([(i % width) as i32, (i / width) as i32]);
            let set = self.possibilities.get(pos).unwrap();
            let count = set.count_ones() as u32;
            match count {
                // Contradiction: leave the pixel transparent black.
                0 => {}
                1 => {
                    let id = set.iter_ones().next().unwrap();
                    *pixel = self.rules[id][0].0;
                }
                _ => {
                    let mut sum = [0u32; 4];
                    for id in set.iter_ones() {
                        for (s, &v) in sum.iter_mut().zip(tile_averages[id].iter()) {
                            *s += u32::from(v);
                        }
                    }
                    *pixel = sum.map(|s| (s / count) as u8);
                }
            }
        });

        let output = RgbaImage::from_fn(width as u32, height as u32, |x, y| {
            Rgba(pixels[y as usize * width + x as usize])
        });
        output
            .save(path)
            .unwrap_or_else(|e| panic!("Failed to save image to {}: {}", path, e));
    }
}
/// End-to-end smoke test. It extracts tiles from `Spirals.png`, runs
/// `collapse` until it reports completion, and writes `output.png`.
///
/// Gated on the `image` feature because `WFC::from_image` and
/// `WFC::save_image` only exist when that feature is enabled. Without the gate,
/// `cargo test` fails to compile when the feature is off.
#[cfg(feature = "image")]
#[test]
fn basic_image() {
    let mut wfc = WFC::from_image("Spirals.png", 3, IVector::new([300, 300]));
    let mut iterations = 0;
    while !wfc.collapse() {
        iterations += 1;
        if iterations % 1000 == 0 {
            println!("{}", iterations);
        }
    }
    wfc.save_image("output.png");
}