1use oxygengine_utils::{grid_2d::Grid2d, Scalar};
2#[cfg(feature = "parallel")]
3use rayon::prelude::*;
4use std::{
5 collections::{HashSet, VecDeque},
6 iter::FromIterator,
7};
8
9const NEIGHBOR_COORD_DIRS: [Direction; 4] = [
10 Direction::Left,
11 Direction::Right,
12 Direction::Top,
13 Direction::Bottom,
14];
15
/// Side of a cell (or pattern) on which another cell (or pattern) lies.
///
/// In the solver, `Left`/`Right` are the cells at column - 1 / + 1 and `Top`/`Bottom`
/// the cells at row - 1 / + 1 (wrapping around the world edges).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Column - 1.
    Left,
    /// Column + 1.
    Right,
    /// Row - 1.
    Top,
    /// Row + 1.
    Bottom,
}
23
/// Errors reported while building a model, preparing a solver, or collapsing.
#[derive(Clone, Copy, Debug)]
pub enum WaveFunctionCollapseError {
    /// The input pattern at this index was given a frequency of zero.
    FoundPatternWithZeroFrequency(usize),
    /// The input pattern at this index has no cells.
    FoundEmptyPattern(usize),
    /// No model pattern matches any allowed value of the superposition cell at
    /// `(column, row)`.
    SuperpositionCellHasNoPattern(usize, usize),
    /// A cell did not hold exactly one pattern when a collapsed world was requested.
    FoundUncollapsedCell,
    /// Pruning during solver building left some cell without any pattern.
    FoundImpossibleInitialState,
    /// The builder was asked to build before it finished processing.
    BuilderInProgress,
}
36
37#[derive(Debug, Clone)]
38pub enum WaveFunctionCollapseResult<T> {
39 Incomplete,
40 Collapsed(Grid2d<T>),
41 Impossible,
42}
43
44#[derive(Debug, Default, Clone)]
45pub struct WaveFunctionCollapseModel<T>
46where
47 T: Clone + Send + Sync + PartialEq,
48{
49 patterns: Vec<(Grid2d<T>, Scalar)>,
51 neighbors: Vec<HashSet<(usize, Direction)>>,
53}
54
55impl<T> WaveFunctionCollapseModel<T>
56where
57 T: Clone + Send + Sync + PartialEq,
58{
59 pub fn from_patterns(
60 patterns: Vec<(Grid2d<T>, usize)>,
61 ) -> Result<Self, WaveFunctionCollapseError> {
62 for (i, (p, f)) in patterns.iter().enumerate() {
63 if *f == 0 {
64 return Err(WaveFunctionCollapseError::FoundPatternWithZeroFrequency(i));
65 } else if p.is_empty() {
66 return Err(WaveFunctionCollapseError::FoundEmptyPattern(i));
67 }
68 }
69 let total = patterns.iter().fold(0, |a, (_, f)| a + f) as Scalar;
70 let mut unique = Vec::with_capacity(patterns.len());
71 for (p, f) in patterns {
72 if let Some((_, f2)) = unique.iter_mut().find(|(p2, _)| &p == p2) {
73 *f2 += f;
74 } else {
75 unique.push((p, f));
76 }
77 }
78 let patterns = unique
79 .into_iter()
80 .map(|(p, f)| (p, f as Scalar / total))
81 .collect::<Vec<_>>();
82 let mut neighbors = vec![HashSet::default(); patterns.len()];
83 for (ai, (ap, _)) in patterns.iter().enumerate() {
84 let (ac, ar) = ap.size();
85 for (bi, (bp, _)) in patterns.iter().enumerate() {
86 let (bc, br) = bp.size();
87 if ar == br {
88 if bp.has_union_with(ap, 1, 0) {
89 neighbors[ai].insert((bi, Direction::Left));
90 neighbors[bi].insert((ai, Direction::Right));
91 }
92 if ap.has_union_with(bp, 1, 0) {
93 neighbors[ai].insert((bi, Direction::Right));
94 neighbors[bi].insert((ai, Direction::Left));
95 }
96 }
97 if ac == bc {
98 if bp.has_union_with(ap, 0, 1) {
99 neighbors[ai].insert((bi, Direction::Top));
100 neighbors[bi].insert((ai, Direction::Bottom));
101 }
102 if ap.has_union_with(bp, 0, 1) {
103 neighbors[ai].insert((bi, Direction::Bottom));
104 neighbors[bi].insert((ai, Direction::Top));
105 }
106 }
107 }
108 }
109 Ok(Self {
110 patterns,
111 neighbors,
112 })
113 }
114
115 pub fn from_views(
116 sample_size: (usize, usize),
117 seamless: bool,
118 views: Vec<Grid2d<Option<T>>>,
119 ) -> Result<Self, WaveFunctionCollapseError> {
120 let f = |w: Grid2d<&Option<T>>| {
121 let items = w
122 .iter()
123 .filter_map(|c| c.as_ref().cloned())
124 .collect::<Vec<_>>();
125 if items.len() == w.len() {
126 Some((Grid2d::with_cells(w.cols(), items), 1))
127 } else {
128 None
129 }
130 };
131 let patterns = views
132 .into_iter()
133 .flat_map(|view| {
134 if seamless {
135 view.windows_seamless(sample_size)
136 .filter_map(f)
137 .collect::<Vec<_>>()
138 } else {
139 view.windows(sample_size).filter_map(f).collect::<Vec<_>>()
140 }
141 })
142 .collect();
143 Self::from_patterns(patterns)
144 }
145
146 pub fn patterns(&self) -> &[(Grid2d<T>, Scalar)] {
148 &self.patterns
149 }
150
151 pub fn neighbors(&self) -> &[HashSet<(usize, Direction)>] {
153 &self.neighbors
154 }
155}
156
157#[derive(Debug, Clone)]
158struct Cell {
159 patterns: HashSet<usize>,
160 entropy: Scalar,
161}
162
163#[derive(Debug, Clone, Copy)]
164enum BuilderPhase {
165 Process(usize),
167 Done,
168 Error(WaveFunctionCollapseError),
169}
170
171#[derive(Clone)]
172pub struct WaveFunctionCollapseSolverBuilder<T>
173where
174 T: Clone + Send + Sync + PartialEq,
175{
176 model: WaveFunctionCollapseModel<T>,
177 superposition: [Grid2d<Cell>; 2],
178 current: usize,
179 phase: BuilderPhase,
180 cells_per_step: usize,
181}
182
183impl<T> WaveFunctionCollapseSolverBuilder<T>
184where
185 T: Clone + Send + Sync + PartialEq,
186{
187 fn new(
188 model: WaveFunctionCollapseModel<T>,
189 superposition: Grid2d<Vec<T>>,
190 cells_per_step: Option<usize>,
191 ) -> Result<Self, WaveFunctionCollapseError> {
192 let (cols, rows) = superposition.size();
193 let cells = superposition
194 .iter_view((0, 0)..(cols, rows))
195 .map(|(col, row, cells)| {
196 let patterns = cells
197 .iter()
198 .flat_map(|cell| {
199 model
200 .patterns()
201 .iter()
202 .enumerate()
203 .filter_map(|(index, (pattern, _))| {
204 let pattern_cell = pattern.cell(0, 0).unwrap();
205 if cell == pattern_cell {
206 Some(index)
207 } else {
208 None
209 }
210 })
211 .collect::<HashSet<_>>()
212 })
213 .collect::<HashSet<_>>();
214 if patterns.is_empty() {
215 Err(WaveFunctionCollapseError::SuperpositionCellHasNoPattern(
216 col, row,
217 ))
218 } else {
219 let entropy = calculate_entropy(&model, &patterns);
220 Ok(Cell { patterns, entropy })
221 }
222 })
223 .collect::<Result<Vec<_>, _>>()?;
224 let max_patterns = cells
225 .iter()
226 .map(|cell| cell.patterns.len())
227 .max_by(|a, b| a.cmp(b))
228 .unwrap_or(1);
229 let cells_per_step = if let Some(cells_per_step) = cells_per_step {
230 cells_per_step
231 } else if max_patterns > 0 {
232 cells.len() / max_patterns
233 } else {
234 cells.len()
235 }
236 .max(1);
237 let superposition = Grid2d::with_cells(cols, cells);
238 Ok(Self {
239 model,
240 superposition: [superposition.clone(), superposition],
241 current: 0,
242 phase: BuilderPhase::Process(0),
243 cells_per_step,
244 })
245 }
246
247 pub fn process(&mut self) -> bool {
249 match self.phase {
250 BuilderPhase::Done | BuilderPhase::Error(_) => false,
251 BuilderPhase::Process(mut index) => {
252 let mut remaining = self.cells_per_step;
253 let mut reduced = false;
254 let cols = self.source().cols();
255 let rows = self.source().rows();
256 let count = self.source().len();
257 while index < count && remaining > 0 {
258 let col = index % cols;
259 let row = index / cols;
260 let patterns = &self.source().cell(col, row).unwrap().patterns;
261 let count = patterns.len();
262 match count {
263 0 | 1 => {
264 let cell = Cell {
265 patterns: patterns.clone(),
266 entropy: 0.0,
267 };
268 self.target().set(col, row, cell)
269 }
270 _ => {
271 let samples = [
272 self.source().cell((cols + col - 1) % cols, row).unwrap(),
273 self.source().cell((col + 1) % cols, row).unwrap(),
274 self.source().cell(col, (rows + row - 1) % rows).unwrap(),
275 self.source().cell(col, (row + 1) % rows).unwrap(),
276 ];
277 #[cfg(not(feature = "parallel"))]
278 let patterns = patterns.iter();
279 #[cfg(feature = "parallel")]
280 let patterns = patterns.par_iter();
281 let patterns = patterns
282 .filter(|index| {
283 let neighbors = self.model.neighbors().get(**index).unwrap();
284 if neighbors.is_empty() {
285 return false;
286 }
287 NEIGHBOR_COORD_DIRS.iter().enumerate().all(|(i, d)| {
288 samples[i].patterns.iter().any(|n| {
289 neighbors.iter().any(|(neighbor, direction)| {
290 direction == d && neighbor == n
291 })
292 })
293 })
294 })
295 .cloned()
296 .collect::<HashSet<_>>();
297 if patterns.is_empty() {
298 self.phase = BuilderPhase::Error(
299 WaveFunctionCollapseError::FoundImpossibleInitialState,
300 );
301 return false;
302 } else if patterns.len() < count {
303 reduced = true;
304 }
305 let entropy = calculate_entropy(&self.model, &patterns);
306 self.target().set(col, row, Cell { patterns, entropy });
307 }
308 }
309 index += 1;
310 remaining -= 1;
311 }
312 if index == count {
313 if reduced {
314 self.phase = BuilderPhase::Process(0);
315 self.current = (self.current + 1) % 2;
316 true
317 } else {
318 self.phase = BuilderPhase::Done;
319 false
320 }
321 } else {
322 self.phase = BuilderPhase::Process(index);
323 true
324 }
325 }
326 }
327 }
328
329 pub fn progress(&self) -> (usize, usize) {
331 let count = self.source().len();
332 match self.phase {
333 BuilderPhase::Done | BuilderPhase::Error(_) => (count, count),
334 BuilderPhase::Process(index) => (index, count),
335 }
336 }
337
338 pub fn build(self) -> Result<WaveFunctionCollapseSolver<T>, WaveFunctionCollapseError> {
339 match self.phase {
340 BuilderPhase::Error(error) => Err(error),
341 BuilderPhase::Done => {
342 let count = self.source().len();
343 Ok(WaveFunctionCollapseSolver {
344 superposition: self.source().clone(),
345 model: self.model,
346 cached_progress: 0,
347 cached_open: VecDeque::with_capacity(count),
348 lately_updated: HashSet::with_capacity(count),
349 })
350 }
351 BuilderPhase::Process(_) => Err(WaveFunctionCollapseError::BuilderInProgress),
352 }
353 }
354
355 fn source(&self) -> &Grid2d<Cell> {
356 &self.superposition[self.current]
357 }
358
359 fn target(&mut self) -> &mut Grid2d<Cell> {
360 &mut self.superposition[(self.current + 1) % 2]
361 }
362}
363
364#[derive(Clone)]
365pub struct WaveFunctionCollapseSolver<T>
366where
367 T: Clone + Send + Sync + PartialEq,
368{
369 model: WaveFunctionCollapseModel<T>,
370 superposition: Grid2d<Cell>,
371 cached_progress: usize,
372 cached_open: VecDeque<(usize, usize)>,
373 lately_updated: HashSet<(usize, usize)>,
374}
375
376impl<T> std::fmt::Debug for WaveFunctionCollapseSolver<T>
377where
378 T: Clone + Send + Sync + PartialEq + std::fmt::Debug,
379{
380 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 f.debug_struct("WaveFunctionCollapseSolver")
382 .field("model", &self.model)
383 .field("superposition", &self.superposition)
384 .field("cached_progress", &self.cached_progress)
385 .field("cached_open", &self.cached_open)
386 .field("lately_updated", &self.lately_updated)
387 .finish()
388 }
389}
390
391impl<T> WaveFunctionCollapseSolver<T>
392where
393 T: Clone + Send + Sync + PartialEq,
394{
395 pub fn lately_updated(&self) -> &HashSet<(usize, usize)> {
396 &self.lately_updated
397 }
398
399 pub fn lately_updated_uncollapsed_cells<V>(&self) -> Vec<(usize, usize, V)>
400 where
401 V: FromIterator<T>,
402 {
403 self.lately_updated
404 .iter()
405 .filter_map(|(col, row)| {
406 self.superposition.cell(*col, *row).map(|cell| {
407 let items = cell
408 .patterns
409 .iter()
410 .map(|index| self.model.patterns()[*index].0.cell(0, 0).unwrap().clone())
411 .collect::<V>();
412 (*col, *row, items)
413 })
414 })
415 .collect()
416 }
417
418 pub fn lately_updated_collapsed_cells(&self) -> Vec<(usize, usize, T)> {
419 self.lately_updated
420 .iter()
421 .filter_map(|(col, row)| {
422 if let Some(cell) = self.superposition.cell(*col, *row) {
423 if cell.patterns.len() == 1 {
424 let index = *cell.patterns.iter().next().unwrap();
425 let item = self.model.patterns()[index].0.cell(0, 0).unwrap().clone();
426 Some((*col, *row, item))
427 } else {
428 None
429 }
430 } else {
431 None
432 }
433 })
434 .collect()
435 }
436
437 pub fn build(
438 model: WaveFunctionCollapseModel<T>,
439 superposition: Grid2d<Vec<T>>,
440 cells_per_step: Option<usize>,
441 ) -> Result<WaveFunctionCollapseSolverBuilder<T>, WaveFunctionCollapseError> {
442 WaveFunctionCollapseSolverBuilder::new(model, superposition, cells_per_step)
443 }
444
445 pub fn new(
446 model: WaveFunctionCollapseModel<T>,
447 superposition: Grid2d<Vec<T>>,
448 ) -> Result<Self, WaveFunctionCollapseError> {
449 let count = superposition.len();
450 let mut builder = Self::build(model, superposition, Some(count))?;
451 while builder.process() {}
452 builder.build()
453 }
454
455 pub fn new_inspect<F>(
456 model: WaveFunctionCollapseModel<T>,
457 superposition: Grid2d<Vec<T>>,
458 cells_per_step: Option<usize>,
459 mut f: F,
460 ) -> Result<Self, WaveFunctionCollapseError>
461 where
462 F: FnMut(usize, usize),
463 {
464 let mut builder = Self::build(model, superposition, cells_per_step)?;
465 let (p, m) = builder.progress();
466 f(p, m);
467 while builder.process() {
468 let (p, m) = builder.progress();
469 f(p, m);
470 }
471 let (p, m) = builder.progress();
472 f(p, m);
473 builder.build()
474 }
475
476 pub fn collapse<R>(&mut self, gen_range: R) -> WaveFunctionCollapseResult<T>
477 where
478 R: FnMut(Scalar, Scalar) -> Scalar + Clone,
479 {
480 loop {
481 match self.collapse_step(gen_range.clone()) {
482 WaveFunctionCollapseResult::Incomplete => continue,
483 result => return result,
484 }
485 }
486 }
487
488 pub fn collapse_with_tries<R>(
489 &mut self,
490 mut tries: usize,
491 gen_range: R,
492 ) -> WaveFunctionCollapseResult<T>
493 where
494 R: FnMut(Scalar, Scalar) -> Scalar + Clone,
495 {
496 while tries > 0 {
497 match self.collapse(gen_range.clone()) {
498 WaveFunctionCollapseResult::Impossible => {
499 tries -= 1;
500 continue;
501 }
502 result => return result,
503 }
504 }
505 WaveFunctionCollapseResult::Impossible
506 }
507
508 pub fn collapse_inspect<R, F>(
509 &mut self,
510 gen_range: R,
511 mut f: F,
512 ) -> WaveFunctionCollapseResult<T>
513 where
514 F: FnMut(usize, usize, &Self),
515 R: FnMut(Scalar, Scalar) -> Scalar + Clone,
516 {
517 loop {
518 match self.collapse_step(gen_range.clone()) {
519 WaveFunctionCollapseResult::Incomplete => {
520 let (p, m) = self.progress();
521 f(p, m, self);
522 continue;
523 }
524 result => return result,
525 }
526 }
527 }
528
529 pub fn collapse_inspect_with_tries<R, F>(
530 &mut self,
531 mut tries: usize,
532 gen_range: R,
533 mut f: F,
534 ) -> WaveFunctionCollapseResult<T>
535 where
536 F: FnMut() -> Box<dyn FnMut(usize, usize, &Self)>,
537 R: FnMut(Scalar, Scalar) -> Scalar + Clone,
538 {
539 while tries > 0 {
540 match self.collapse_inspect(gen_range.clone(), f()) {
541 WaveFunctionCollapseResult::Impossible => {
542 tries -= 1;
543 continue;
544 }
545 result => return result,
546 }
547 }
548 WaveFunctionCollapseResult::Impossible
549 }
550
551 pub fn collapse_step<R>(&mut self, gen_range: R) -> WaveFunctionCollapseResult<T>
552 where
553 R: FnMut(Scalar, Scalar) -> Scalar,
554 {
555 let coord = if let Ok(coord) = self.get_uncollapsed_coord() {
556 coord
557 } else {
558 return WaveFunctionCollapseResult::Impossible;
559 };
560 let (col, row) = if let Some(coord) = coord {
561 coord
562 } else if let Ok(collapsed) =
563 Self::superposition_to_collapsed_world(&self.model, &self.superposition)
564 {
565 return WaveFunctionCollapseResult::Collapsed(collapsed);
566 } else {
567 return WaveFunctionCollapseResult::Impossible;
568 };
569 self.lately_updated.clear();
570 if !self.collapse_cell(col, row, gen_range) {
571 return WaveFunctionCollapseResult::Impossible;
572 }
573 self.lately_updated.insert((col, row));
574 let (cols, rows) = self.superposition.size();
575 self.cached_open.push_back(((col + cols - 1) % cols, row));
576 self.cached_open.push_back(((col + 1) % cols, row));
577 self.cached_open.push_back((col, (row + rows - 1) % rows));
578 self.cached_open.push_back((col, (row + 1) % rows));
579 while !self.cached_open.is_empty() {
580 self.partially_reduce_superposition();
581 }
582 self.cached_progress =
583 self.superposition
584 .iter()
585 .fold(0, |a, c| if c.patterns.len() == 1 { a + 1 } else { a });
586 WaveFunctionCollapseResult::Incomplete
587 }
588
589 pub fn progress(&self) -> (usize, usize) {
590 (self.cached_progress, self.superposition.len())
591 }
592
593 fn collapse_cell<R>(&mut self, col: usize, row: usize, mut gen_range: R) -> bool
594 where
595 R: FnMut(Scalar, Scalar) -> Scalar,
596 {
597 let patterns = self.model.patterns();
598 let cell = self.superposition.cell(col, row).unwrap();
599 let total = cell
600 .patterns
601 .iter()
602 .fold(0.0, |accum, index| accum + patterns[*index].1);
603 let mut selected = gen_range(0.0, total);
604 for index in cell.patterns.iter() {
605 let weight = patterns[*index].1;
606 if selected <= weight {
607 let mut patterns = HashSet::with_capacity(1);
608 patterns.insert(*index);
609 self.superposition.set(
610 col,
611 row,
612 Cell {
613 patterns,
614 entropy: 0.0,
615 },
616 );
617 return true;
618 }
619 selected -= weight;
620 }
621 false
622 }
623
624 fn get_uncollapsed_coord(&self) -> Result<Option<(usize, usize)>, ()> {
625 if self
626 .superposition
627 .iter()
628 .any(|cell| cell.patterns.is_empty())
629 {
630 return Err(());
631 }
632 let cols = self.superposition.cols();
633 let result = {
634 let mut result = None;
635 for (index, cell) in self.superposition.iter().enumerate() {
636 let col = index % cols;
637 let row = index / cols;
638 if cell.patterns.len() > 1 {
639 if let Some((_, _, entropy)) = result {
640 if cell.entropy < entropy {
641 result = Some((col, row, cell.entropy));
642 }
643 } else {
644 result = Some((col, row, cell.entropy));
645 }
646 }
647 }
648 result
649 };
650 Ok(result.map(|(col, row, _)| (col, row)))
651 }
652
653 fn partially_reduce_superposition(&mut self) {
654 if self.cached_open.is_empty() {
655 return;
656 }
657 let (col, row) = self.cached_open.pop_front().unwrap();
658 let (cols, rows) = self.superposition.size();
659 let patterns = &self.superposition.cell(col, row).unwrap().patterns;
660 let count = patterns.len();
661 if count > 1 {
662 let samples = [
663 self.superposition
664 .cell((cols + col - 1) % cols, row)
665 .unwrap(),
666 self.superposition.cell((col + 1) % cols, row).unwrap(),
667 self.superposition
668 .cell(col, (rows + row - 1) % rows)
669 .unwrap(),
670 self.superposition.cell(col, (row + 1) % rows).unwrap(),
671 ];
672 let neighbors = self.model.neighbors();
673 #[cfg(not(feature = "parallel"))]
674 let patterns = patterns.iter();
675 #[cfg(feature = "parallel")]
676 let patterns = patterns.par_iter();
677 let patterns = patterns
678 .filter(|index| {
679 let neighbors = neighbors.get(**index).unwrap();
680 if neighbors.is_empty() {
681 return false;
682 }
683 NEIGHBOR_COORD_DIRS.iter().enumerate().all(|(i, d)| {
684 samples[i].patterns.iter().any(|n| {
685 neighbors
686 .iter()
687 .any(|(neighbor, direction)| direction == d && neighbor == n)
688 })
689 })
690 })
691 .cloned()
692 .collect::<HashSet<_>>();
693 if patterns.len() < count {
694 self.lately_updated.insert((col, row));
695 let coord = ((col + cols - 1) % cols, row);
696 if samples[0].patterns.len() > 1 && !self.cached_open.contains(&coord) {
697 self.cached_open.push_back(coord);
698 }
699 let coord = ((col + 1) % cols, row);
700 if samples[1].patterns.len() > 1 && !self.cached_open.contains(&coord) {
701 self.cached_open.push_back(coord);
702 }
703 let coord = (col, (row + rows - 1) % rows);
704 if samples[2].patterns.len() > 1 && !self.cached_open.contains(&coord) {
705 self.cached_open.push_back(coord);
706 }
707 let coord = (col, (row + 1) % rows);
708 if samples[3].patterns.len() > 1 && !self.cached_open.contains(&coord) {
709 self.cached_open.push_back(coord);
710 }
711 let entropy = calculate_entropy(&self.model, &patterns);
712 self.superposition.set(col, row, Cell { patterns, entropy });
713 }
714 }
715 }
716
717 pub fn get_uncollapsed_world(&self) -> Grid2d<Vec<T>> {
718 let cols = self.superposition.cols();
719 let cells = self
720 .superposition
721 .iter()
722 .map(|cell| {
723 cell.patterns
724 .iter()
725 .map(|index| self.model.patterns()[*index].0.cell(0, 0).unwrap().clone())
726 .collect::<Vec<_>>()
727 })
728 .collect::<Vec<_>>();
729 Grid2d::with_cells(cols, cells)
730 }
731
732 fn superposition_to_collapsed_world(
733 model: &WaveFunctionCollapseModel<T>,
734 superposition: &Grid2d<Cell>,
735 ) -> Result<Grid2d<T>, WaveFunctionCollapseError> {
736 let cols = superposition.cols();
737 let cells = superposition
738 .iter()
739 .map(|cell| {
740 if cell.patterns.len() == 1 {
741 let index = cell.patterns.iter().next().unwrap();
742 Ok(model.patterns()[*index].0.cell(0, 0).unwrap().clone())
743 } else {
744 Err(WaveFunctionCollapseError::FoundUncollapsedCell)
745 }
746 })
747 .collect::<Result<Vec<_>, _>>()?;
748 Ok(Grid2d::with_cells(cols, cells))
749 }
750}
751
752fn calculate_entropy<T>(model: &WaveFunctionCollapseModel<T>, patterns: &HashSet<usize>) -> Scalar
753where
754 T: Clone + Send + Sync + PartialEq,
755{
756 if patterns.is_empty() {
757 return 0.0;
758 }
759 let mut total_weight = 0.0;
760 let mut total_weight_log = 0.0;
761 for index in patterns {
762 let weight = model.patterns()[*index].1;
763 total_weight += weight;
764 total_weight_log += weight * weight.log2();
765 }
766 total_weight.log2() - (total_weight_log / total_weight)
767}
768
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a text view: each non-empty line is a row and each non-whitespace char a
    /// filled cell. Whitespace and missing trailing characters stay `None`.
    #[allow(dead_code)]
    fn parse_view(data: &str) -> Grid2d<Option<char>> {
        let lines = data
            .split(|c| c == '\n' || c == '\r')
            .filter(|l| !l.is_empty())
            .collect::<Vec<_>>();
        // Count chars, not bytes: cells are filled per `char` below, so byte lengths
        // would add phantom empty columns for non-ASCII input.
        let cols = lines.iter().map(|l| l.chars().count()).max().unwrap_or(0);
        let rows = lines.len();
        let mut result = Grid2d::new(cols, rows, None);
        for (row, line) in lines.into_iter().enumerate() {
            for (col, character) in line.chars().enumerate() {
                if !character.is_whitespace() {
                    result.set(col, row, Some(character));
                }
            }
        }
        print_view("= VIEW:", &result);
        result
    }

    /// Prints a view, rendering empty cells as spaces.
    #[allow(dead_code)]
    fn print_view(msg: &str, pattern: &Grid2d<Option<char>>) {
        println!("{}", msg);
        for row in 0..pattern.rows() {
            for cell in pattern.get_row_cells(row).unwrap() {
                if let Some(cell) = cell {
                    print!("{}", cell);
                } else {
                    print!(" ");
                }
            }
            println!();
        }
    }

    /// Prints a fully collapsed world.
    #[allow(dead_code)]
    fn print_collapsed_world(msg: &str, world: &Grid2d<char>) {
        println!("{}", msg);
        for row in 0..world.rows() {
            for cell in world.get_row_cells(row).unwrap() {
                print!("{}", cell);
            }
            println!();
        }
    }

    /// Prints a partially collapsed world; cells without exactly one value show `uncertain`.
    #[allow(dead_code)]
    fn print_uncollapsed_world(msg: &str, world: &Grid2d<Vec<char>>, uncertain: char) {
        println!("{}", msg);
        for row in 0..world.rows() {
            for cell in world.get_row_cells(row).unwrap() {
                if cell.len() == 1 {
                    print!("{}", cell[0]);
                } else {
                    print!("{}", uncertain);
                }
            }
            println!();
        }
    }

    /// Prints a single pattern.
    #[allow(dead_code)]
    fn print_pattern(msg: &str, pattern: &Grid2d<char>) {
        println!("{}", msg);
        for row in 0..pattern.rows() {
            for cell in pattern.get_row_cells(row).unwrap() {
                print!("{}", cell);
            }
            println!();
        }
    }

    /// End-to-end run on `resources/view.txt`: builds a 75x75 world and collapses it,
    /// printing progress. Only compiled with the `longrun` feature.
    #[test]
    #[cfg(feature = "longrun")]
    fn test_general() {
        use rand::{thread_rng, Rng};
        use std::time::Instant;

        let view = parse_view(include_str!("../resources/view.txt"));
        let values = {
            let mut values = view.iter().filter_map(|c| c.clone()).collect::<Vec<_>>();
            values.sort();
            values.dedup();
            values
        };
        println!("= VALUES: {:?}", values);
        let model = WaveFunctionCollapseModel::from_views((3, 3), true, vec![view]).unwrap();
        let world = Grid2d::new(75, 75, values);
        let timer = Instant::now();
        let mut timer2 = Instant::now();
        let mut solver = WaveFunctionCollapseSolver::new_inspect(model, world, None, |p, m| {
            if timer2.elapsed().as_millis() >= 400 {
                timer2 = Instant::now();
                println!(
                    "= INITIALIZE: {} / {} ({}%)",
                    p,
                    m,
                    100.0 * p as Scalar / m as Scalar
                );
            }
        })
        .unwrap();
        println!("= INITIALIZED IN: {:?}", timer.elapsed());
        let timer = Instant::now();
        let mut timer2 = Instant::now();
        let mut rng = thread_rng();
        let mut max_changes = 0;
        let result = solver.collapse_inspect(
            move |f, t| rng.gen_range(f..t),
            |p, m, s| {
                max_changes = max_changes.max(s.lately_updated().len());
                if timer2.elapsed().as_millis() >= 400 {
                    timer2 = Instant::now();
                    println!();
                    println!();
                    print_uncollapsed_world(
                        "= UNCOLLAPSED WORLD:",
                        &s.get_uncollapsed_world(),
                        ' ',
                    );
                    println!(
                        "= PROGRESS: {} / {} ({}%)",
                        p,
                        m,
                        100.0 * p as Scalar / m as Scalar
                    )
                }
            },
        );
        match result {
            WaveFunctionCollapseResult::Collapsed(world) => {
                println!();
                println!();
                println!(
                    "= COLLAPSED IN: {:?} | MAX CHANGES: {}",
                    timer.elapsed(),
                    max_changes
                );
                print_collapsed_world("= COLLAPSED WORLD:", &world);
            }
            _ => panic!("= IMPOSSIBLE WORLD"),
        }
    }
}