oxygengine_procedural/
wave_function_collapse.rs

1use oxygengine_utils::{grid_2d::Grid2d, Scalar};
2#[cfg(feature = "parallel")]
3use rayon::prelude::*;
4use std::{
5    collections::{HashSet, VecDeque},
6    iter::FromIterator,
7};
8
/// The four neighbor directions in the fixed order left, right, top, bottom.
///
/// The order matters: the builder and the solver build their per-cell neighbor
/// sample arrays in this same order (previous column, next column, previous
/// row, next row) and pair each sample with the direction at the same index.
const NEIGHBOR_COORD_DIRS: [Direction; 4] = [
    Direction::Left,
    Direction::Right,
    Direction::Top,
    Direction::Bottom,
];
15
/// Side on which a neighboring pattern or cell sits, relative to the current one.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum Direction {
    /// Neighbor in the previous column (the propagation code wraps around the grid).
    Left,
    /// Neighbor in the next column (wrapping).
    Right,
    /// Neighbor in the previous row (wrapping).
    Top,
    /// Neighbor in the next row (wrapping).
    Bottom,
}
23
/// Errors reported while building a model or preparing a solver.
#[derive(Debug, Clone, Copy)]
pub enum WaveFunctionCollapseError {
    /// A pattern was supplied with a frequency of zero.
    ///
    /// Payload: pattern index.
    FoundPatternWithZeroFrequency(usize),
    /// A pattern grid was empty.
    ///
    /// Payload: pattern index.
    FoundEmptyPattern(usize),
    /// None of the values allowed for a superposition cell matches the
    /// top-left cell of any model pattern.
    ///
    /// Payload: (col, row).
    SuperpositionCellHasNoPattern(usize, usize),
    /// A cell did not hold exactly one pattern when the collapsed world was
    /// being assembled.
    FoundUncollapsedCell,
    /// Initial constraint propagation removed every candidate from a cell.
    FoundImpossibleInitialState,
    /// The builder was asked to build the solver before it finished processing.
    BuilderInProgress,
}
36
/// Outcome of a collapse attempt or of a single collapse step.
#[derive(Debug, Clone)]
pub enum WaveFunctionCollapseResult<T> {
    /// A step collapsed one cell and propagated constraints; more steps are needed.
    Incomplete,
    /// Every cell holds exactly one pattern; carries the resulting world.
    Collapsed(Grid2d<T>),
    /// The world cannot be completed from the current state (for example a
    /// cell has no candidate patterns left).
    Impossible,
}
43
/// Set of weighted patterns together with their allowed adjacencies.
#[derive(Debug, Default, Clone)]
pub struct WaveFunctionCollapseModel<T>
where
    T: Clone + Send + Sync + PartialEq,
{
    /// [(pattern, weight)]
    ///
    /// Weights are the pattern frequencies divided by the total frequency.
    patterns: Vec<(Grid2d<T>, Scalar)>,
    /// [[(pattern, direction)]]
    ///
    /// Indexed by pattern; an entry `(other, direction)` means pattern `other`
    /// may be placed on the `direction` side of the indexing pattern.
    neighbors: Vec<HashSet<(usize, Direction)>>,
}
54
55impl<T> WaveFunctionCollapseModel<T>
56where
57    T: Clone + Send + Sync + PartialEq,
58{
59    pub fn from_patterns(
60        patterns: Vec<(Grid2d<T>, usize)>,
61    ) -> Result<Self, WaveFunctionCollapseError> {
62        for (i, (p, f)) in patterns.iter().enumerate() {
63            if *f == 0 {
64                return Err(WaveFunctionCollapseError::FoundPatternWithZeroFrequency(i));
65            } else if p.is_empty() {
66                return Err(WaveFunctionCollapseError::FoundEmptyPattern(i));
67            }
68        }
69        let total = patterns.iter().fold(0, |a, (_, f)| a + f) as Scalar;
70        let mut unique = Vec::with_capacity(patterns.len());
71        for (p, f) in patterns {
72            if let Some((_, f2)) = unique.iter_mut().find(|(p2, _)| &p == p2) {
73                *f2 += f;
74            } else {
75                unique.push((p, f));
76            }
77        }
78        let patterns = unique
79            .into_iter()
80            .map(|(p, f)| (p, f as Scalar / total))
81            .collect::<Vec<_>>();
82        let mut neighbors = vec![HashSet::default(); patterns.len()];
83        for (ai, (ap, _)) in patterns.iter().enumerate() {
84            let (ac, ar) = ap.size();
85            for (bi, (bp, _)) in patterns.iter().enumerate() {
86                let (bc, br) = bp.size();
87                if ar == br {
88                    if bp.has_union_with(ap, 1, 0) {
89                        neighbors[ai].insert((bi, Direction::Left));
90                        neighbors[bi].insert((ai, Direction::Right));
91                    }
92                    if ap.has_union_with(bp, 1, 0) {
93                        neighbors[ai].insert((bi, Direction::Right));
94                        neighbors[bi].insert((ai, Direction::Left));
95                    }
96                }
97                if ac == bc {
98                    if bp.has_union_with(ap, 0, 1) {
99                        neighbors[ai].insert((bi, Direction::Top));
100                        neighbors[bi].insert((ai, Direction::Bottom));
101                    }
102                    if ap.has_union_with(bp, 0, 1) {
103                        neighbors[ai].insert((bi, Direction::Bottom));
104                        neighbors[bi].insert((ai, Direction::Top));
105                    }
106                }
107            }
108        }
109        Ok(Self {
110            patterns,
111            neighbors,
112        })
113    }
114
115    pub fn from_views(
116        sample_size: (usize, usize),
117        seamless: bool,
118        views: Vec<Grid2d<Option<T>>>,
119    ) -> Result<Self, WaveFunctionCollapseError> {
120        let f = |w: Grid2d<&Option<T>>| {
121            let items = w
122                .iter()
123                .filter_map(|c| c.as_ref().cloned())
124                .collect::<Vec<_>>();
125            if items.len() == w.len() {
126                Some((Grid2d::with_cells(w.cols(), items), 1))
127            } else {
128                None
129            }
130        };
131        let patterns = views
132            .into_iter()
133            .flat_map(|view| {
134                if seamless {
135                    view.windows_seamless(sample_size)
136                        .filter_map(f)
137                        .collect::<Vec<_>>()
138                } else {
139                    view.windows(sample_size).filter_map(f).collect::<Vec<_>>()
140                }
141            })
142            .collect();
143        Self::from_patterns(patterns)
144    }
145
    /// Returns the model patterns as `[(pattern, weight)]`.
    ///
    /// The slice index of a pattern is the pattern index used by
    /// [`Self::neighbors`].
    pub fn patterns(&self) -> &[(Grid2d<T>, Scalar)] {
        &self.patterns
    }
150
    /// Returns the adjacency table as `[[(pattern, direction)]]`.
    ///
    /// Entry `i` lists which pattern indices may sit on which side of pattern `i`.
    pub fn neighbors(&self) -> &[HashSet<(usize, Direction)>] {
        &self.neighbors
    }
155}
156
/// One grid cell of the superposition.
#[derive(Debug, Clone)]
struct Cell {
    /// Indices of the model patterns that are still candidates for this cell.
    patterns: HashSet<usize>,
    /// Entropy of the candidate set (see `calculate_entropy`); stored as 0.0
    /// for cells that were collapsed or carried over with at most one candidate.
    entropy: Scalar,
}
162
/// Progress state of the solver builder.
#[derive(Debug, Clone, Copy)]
enum BuilderPhase {
    /// A propagation pass is running.
    ///
    /// Payload: current cell index (the next cell to process).
    Process(usize),
    /// Propagation finished; the solver can be built.
    Done,
    /// Propagation failed with the given error.
    Error(WaveFunctionCollapseError),
}
170
/// Incrementally performs the initial constraint propagation for a solver.
///
/// Call `process` until it returns `false`, then `build`.
#[derive(Clone)]
pub struct WaveFunctionCollapseSolverBuilder<T>
where
    T: Clone + Send + Sync + PartialEq,
{
    /// Model providing the patterns and adjacency rules.
    model: WaveFunctionCollapseModel<T>,
    /// Two grids used alternately: one is read (source), the other written (target).
    superposition: [Grid2d<Cell>; 2],
    /// Index into `superposition` of the grid currently used as source.
    current: usize,
    /// Current processing state.
    phase: BuilderPhase,
    /// Maximum number of cells handled by one `process` call (at least 1).
    cells_per_step: usize,
}
182
183impl<T> WaveFunctionCollapseSolverBuilder<T>
184where
185    T: Clone + Send + Sync + PartialEq,
186{
187    fn new(
188        model: WaveFunctionCollapseModel<T>,
189        superposition: Grid2d<Vec<T>>,
190        cells_per_step: Option<usize>,
191    ) -> Result<Self, WaveFunctionCollapseError> {
192        let (cols, rows) = superposition.size();
193        let cells = superposition
194            .iter_view((0, 0)..(cols, rows))
195            .map(|(col, row, cells)| {
196                let patterns = cells
197                    .iter()
198                    .flat_map(|cell| {
199                        model
200                            .patterns()
201                            .iter()
202                            .enumerate()
203                            .filter_map(|(index, (pattern, _))| {
204                                let pattern_cell = pattern.cell(0, 0).unwrap();
205                                if cell == pattern_cell {
206                                    Some(index)
207                                } else {
208                                    None
209                                }
210                            })
211                            .collect::<HashSet<_>>()
212                    })
213                    .collect::<HashSet<_>>();
214                if patterns.is_empty() {
215                    Err(WaveFunctionCollapseError::SuperpositionCellHasNoPattern(
216                        col, row,
217                    ))
218                } else {
219                    let entropy = calculate_entropy(&model, &patterns);
220                    Ok(Cell { patterns, entropy })
221                }
222            })
223            .collect::<Result<Vec<_>, _>>()?;
224        let max_patterns = cells
225            .iter()
226            .map(|cell| cell.patterns.len())
227            .max_by(|a, b| a.cmp(b))
228            .unwrap_or(1);
229        let cells_per_step = if let Some(cells_per_step) = cells_per_step {
230            cells_per_step
231        } else if max_patterns > 0 {
232            cells.len() / max_patterns
233        } else {
234            cells.len()
235        }
236        .max(1);
237        let superposition = Grid2d::with_cells(cols, cells);
238        Ok(Self {
239            model,
240            superposition: [superposition.clone(), superposition],
241            current: 0,
242            phase: BuilderPhase::Process(0),
243            cells_per_step,
244        })
245    }
246
247    /// true if has to continue (is not done and has no error)
248    pub fn process(&mut self) -> bool {
249        match self.phase {
250            BuilderPhase::Done | BuilderPhase::Error(_) => false,
251            BuilderPhase::Process(mut index) => {
252                let mut remaining = self.cells_per_step;
253                let mut reduced = false;
254                let cols = self.source().cols();
255                let rows = self.source().rows();
256                let count = self.source().len();
257                while index < count && remaining > 0 {
258                    let col = index % cols;
259                    let row = index / cols;
260                    let patterns = &self.source().cell(col, row).unwrap().patterns;
261                    let count = patterns.len();
262                    match count {
263                        0 | 1 => {
264                            let cell = Cell {
265                                patterns: patterns.clone(),
266                                entropy: 0.0,
267                            };
268                            self.target().set(col, row, cell)
269                        }
270                        _ => {
271                            let samples = [
272                                self.source().cell((cols + col - 1) % cols, row).unwrap(),
273                                self.source().cell((col + 1) % cols, row).unwrap(),
274                                self.source().cell(col, (rows + row - 1) % rows).unwrap(),
275                                self.source().cell(col, (row + 1) % rows).unwrap(),
276                            ];
277                            #[cfg(not(feature = "parallel"))]
278                            let patterns = patterns.iter();
279                            #[cfg(feature = "parallel")]
280                            let patterns = patterns.par_iter();
281                            let patterns = patterns
282                                .filter(|index| {
283                                    let neighbors = self.model.neighbors().get(**index).unwrap();
284                                    if neighbors.is_empty() {
285                                        return false;
286                                    }
287                                    NEIGHBOR_COORD_DIRS.iter().enumerate().all(|(i, d)| {
288                                        samples[i].patterns.iter().any(|n| {
289                                            neighbors.iter().any(|(neighbor, direction)| {
290                                                direction == d && neighbor == n
291                                            })
292                                        })
293                                    })
294                                })
295                                .cloned()
296                                .collect::<HashSet<_>>();
297                            if patterns.is_empty() {
298                                self.phase = BuilderPhase::Error(
299                                    WaveFunctionCollapseError::FoundImpossibleInitialState,
300                                );
301                                return false;
302                            } else if patterns.len() < count {
303                                reduced = true;
304                            }
305                            let entropy = calculate_entropy(&self.model, &patterns);
306                            self.target().set(col, row, Cell { patterns, entropy });
307                        }
308                    }
309                    index += 1;
310                    remaining -= 1;
311                }
312                if index == count {
313                    if reduced {
314                        self.phase = BuilderPhase::Process(0);
315                        self.current = (self.current + 1) % 2;
316                        true
317                    } else {
318                        self.phase = BuilderPhase::Done;
319                        false
320                    }
321                } else {
322                    self.phase = BuilderPhase::Process(index);
323                    true
324                }
325            }
326        }
327    }
328
329    /// (current, max)
330    pub fn progress(&self) -> (usize, usize) {
331        let count = self.source().len();
332        match self.phase {
333            BuilderPhase::Done | BuilderPhase::Error(_) => (count, count),
334            BuilderPhase::Process(index) => (index, count),
335        }
336    }
337
338    pub fn build(self) -> Result<WaveFunctionCollapseSolver<T>, WaveFunctionCollapseError> {
339        match self.phase {
340            BuilderPhase::Error(error) => Err(error),
341            BuilderPhase::Done => {
342                let count = self.source().len();
343                Ok(WaveFunctionCollapseSolver {
344                    superposition: self.source().clone(),
345                    model: self.model,
346                    cached_progress: 0,
347                    cached_open: VecDeque::with_capacity(count),
348                    lately_updated: HashSet::with_capacity(count),
349                })
350            }
351            BuilderPhase::Process(_) => Err(WaveFunctionCollapseError::BuilderInProgress),
352        }
353    }
354
    /// Grid that the current pass reads from (`superposition[current]`).
    fn source(&self) -> &Grid2d<Cell> {
        &self.superposition[self.current]
    }
358
    /// Grid that the current pass writes to (the one that is not the source).
    fn target(&mut self) -> &mut Grid2d<Cell> {
        &mut self.superposition[(self.current + 1) % 2]
    }
362}
363
/// Collapses a superposition grid into a world, one cell at a time.
#[derive(Clone)]
pub struct WaveFunctionCollapseSolver<T>
where
    T: Clone + Send + Sync + PartialEq,
{
    /// Model providing the patterns and adjacency rules.
    model: WaveFunctionCollapseModel<T>,
    /// Candidate patterns of every cell.
    superposition: Grid2d<Cell>,
    /// Number of cells holding exactly one candidate, as counted at the end of
    /// the last collapse step that returned `Incomplete`.
    cached_progress: usize,
    /// Queue of `(col, row)` coordinates waiting to be re-evaluated during
    /// propagation.
    cached_open: VecDeque<(usize, usize)>,
    /// `(col, row)` coordinates changed by the last collapse step.
    lately_updated: HashSet<(usize, usize)>,
}
375
376impl<T> std::fmt::Debug for WaveFunctionCollapseSolver<T>
377where
378    T: Clone + Send + Sync + PartialEq + std::fmt::Debug,
379{
380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        f.debug_struct("WaveFunctionCollapseSolver")
382            .field("model", &self.model)
383            .field("superposition", &self.superposition)
384            .field("cached_progress", &self.cached_progress)
385            .field("cached_open", &self.cached_open)
386            .field("lately_updated", &self.lately_updated)
387            .finish()
388    }
389}
390
391impl<T> WaveFunctionCollapseSolver<T>
392where
393    T: Clone + Send + Sync + PartialEq,
394{
    /// Returns the `(col, row)` coordinates changed by the last collapse step:
    /// the collapsed cell plus every cell whose candidates were reduced.
    pub fn lately_updated(&self) -> &HashSet<(usize, usize)> {
        &self.lately_updated
    }
398
399    pub fn lately_updated_uncollapsed_cells<V>(&self) -> Vec<(usize, usize, V)>
400    where
401        V: FromIterator<T>,
402    {
403        self.lately_updated
404            .iter()
405            .filter_map(|(col, row)| {
406                self.superposition.cell(*col, *row).map(|cell| {
407                    let items = cell
408                        .patterns
409                        .iter()
410                        .map(|index| self.model.patterns()[*index].0.cell(0, 0).unwrap().clone())
411                        .collect::<V>();
412                    (*col, *row, items)
413                })
414            })
415            .collect()
416    }
417
418    pub fn lately_updated_collapsed_cells(&self) -> Vec<(usize, usize, T)> {
419        self.lately_updated
420            .iter()
421            .filter_map(|(col, row)| {
422                if let Some(cell) = self.superposition.cell(*col, *row) {
423                    if cell.patterns.len() == 1 {
424                        let index = *cell.patterns.iter().next().unwrap();
425                        let item = self.model.patterns()[index].0.cell(0, 0).unwrap().clone();
426                        Some((*col, *row, item))
427                    } else {
428                        None
429                    }
430                } else {
431                    None
432                }
433            })
434            .collect()
435    }
436
    /// Creates a builder that prepares a solver incrementally.
    ///
    /// `superposition` lists the allowed values of every cell. `cells_per_step`
    /// limits how many cells one `process` call of the builder handles; `None`
    /// lets the builder pick a value.
    ///
    /// # Errors
    ///
    /// Returns the error of the builder constructor, e.g.
    /// `SuperpositionCellHasNoPattern` when a cell matches no pattern.
    pub fn build(
        model: WaveFunctionCollapseModel<T>,
        superposition: Grid2d<Vec<T>>,
        cells_per_step: Option<usize>,
    ) -> Result<WaveFunctionCollapseSolverBuilder<T>, WaveFunctionCollapseError> {
        WaveFunctionCollapseSolverBuilder::new(model, superposition, cells_per_step)
    }
444
    /// Creates a solver, running the whole initial propagation synchronously.
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating or running the builder.
    pub fn new(
        model: WaveFunctionCollapseModel<T>,
        superposition: Grid2d<Vec<T>>,
    ) -> Result<Self, WaveFunctionCollapseError> {
        // One step covers all cells, so each `process` call runs a full pass.
        let count = superposition.len();
        let mut builder = Self::build(model, superposition, Some(count))?;
        while builder.process() {}
        builder.build()
    }
454
455    pub fn new_inspect<F>(
456        model: WaveFunctionCollapseModel<T>,
457        superposition: Grid2d<Vec<T>>,
458        cells_per_step: Option<usize>,
459        mut f: F,
460    ) -> Result<Self, WaveFunctionCollapseError>
461    where
462        F: FnMut(usize, usize),
463    {
464        let mut builder = Self::build(model, superposition, cells_per_step)?;
465        let (p, m) = builder.progress();
466        f(p, m);
467        while builder.process() {
468            let (p, m) = builder.progress();
469            f(p, m);
470        }
471        let (p, m) = builder.progress();
472        f(p, m);
473        builder.build()
474    }
475
476    pub fn collapse<R>(&mut self, gen_range: R) -> WaveFunctionCollapseResult<T>
477    where
478        R: FnMut(Scalar, Scalar) -> Scalar + Clone,
479    {
480        loop {
481            match self.collapse_step(gen_range.clone()) {
482                WaveFunctionCollapseResult::Incomplete => continue,
483                result => return result,
484            }
485        }
486    }
487
488    pub fn collapse_with_tries<R>(
489        &mut self,
490        mut tries: usize,
491        gen_range: R,
492    ) -> WaveFunctionCollapseResult<T>
493    where
494        R: FnMut(Scalar, Scalar) -> Scalar + Clone,
495    {
496        while tries > 0 {
497            match self.collapse(gen_range.clone()) {
498                WaveFunctionCollapseResult::Impossible => {
499                    tries -= 1;
500                    continue;
501                }
502                result => return result,
503            }
504        }
505        WaveFunctionCollapseResult::Impossible
506    }
507
508    pub fn collapse_inspect<R, F>(
509        &mut self,
510        gen_range: R,
511        mut f: F,
512    ) -> WaveFunctionCollapseResult<T>
513    where
514        F: FnMut(usize, usize, &Self),
515        R: FnMut(Scalar, Scalar) -> Scalar + Clone,
516    {
517        loop {
518            match self.collapse_step(gen_range.clone()) {
519                WaveFunctionCollapseResult::Incomplete => {
520                    let (p, m) = self.progress();
521                    f(p, m, self);
522                    continue;
523                }
524                result => return result,
525            }
526        }
527    }
528
529    pub fn collapse_inspect_with_tries<R, F>(
530        &mut self,
531        mut tries: usize,
532        gen_range: R,
533        mut f: F,
534    ) -> WaveFunctionCollapseResult<T>
535    where
536        F: FnMut() -> Box<dyn FnMut(usize, usize, &Self)>,
537        R: FnMut(Scalar, Scalar) -> Scalar + Clone,
538    {
539        while tries > 0 {
540            match self.collapse_inspect(gen_range.clone(), f()) {
541                WaveFunctionCollapseResult::Impossible => {
542                    tries -= 1;
543                    continue;
544                }
545                result => return result,
546            }
547        }
548        WaveFunctionCollapseResult::Impossible
549    }
550
551    pub fn collapse_step<R>(&mut self, gen_range: R) -> WaveFunctionCollapseResult<T>
552    where
553        R: FnMut(Scalar, Scalar) -> Scalar,
554    {
555        let coord = if let Ok(coord) = self.get_uncollapsed_coord() {
556            coord
557        } else {
558            return WaveFunctionCollapseResult::Impossible;
559        };
560        let (col, row) = if let Some(coord) = coord {
561            coord
562        } else if let Ok(collapsed) =
563            Self::superposition_to_collapsed_world(&self.model, &self.superposition)
564        {
565            return WaveFunctionCollapseResult::Collapsed(collapsed);
566        } else {
567            return WaveFunctionCollapseResult::Impossible;
568        };
569        self.lately_updated.clear();
570        if !self.collapse_cell(col, row, gen_range) {
571            return WaveFunctionCollapseResult::Impossible;
572        }
573        self.lately_updated.insert((col, row));
574        let (cols, rows) = self.superposition.size();
575        self.cached_open.push_back(((col + cols - 1) % cols, row));
576        self.cached_open.push_back(((col + 1) % cols, row));
577        self.cached_open.push_back((col, (row + rows - 1) % rows));
578        self.cached_open.push_back((col, (row + 1) % rows));
579        while !self.cached_open.is_empty() {
580            self.partially_reduce_superposition();
581        }
582        self.cached_progress =
583            self.superposition
584                .iter()
585                .fold(0, |a, c| if c.patterns.len() == 1 { a + 1 } else { a });
586        WaveFunctionCollapseResult::Incomplete
587    }
588
    /// Returns `(collapsed, total)`: the cached number of cells holding exactly
    /// one candidate (refreshed by each step that returns `Incomplete`) and
    /// the total number of cells.
    pub fn progress(&self) -> (usize, usize) {
        (self.cached_progress, self.superposition.len())
    }
592
593    fn collapse_cell<R>(&mut self, col: usize, row: usize, mut gen_range: R) -> bool
594    where
595        R: FnMut(Scalar, Scalar) -> Scalar,
596    {
597        let patterns = self.model.patterns();
598        let cell = self.superposition.cell(col, row).unwrap();
599        let total = cell
600            .patterns
601            .iter()
602            .fold(0.0, |accum, index| accum + patterns[*index].1);
603        let mut selected = gen_range(0.0, total);
604        for index in cell.patterns.iter() {
605            let weight = patterns[*index].1;
606            if selected <= weight {
607                let mut patterns = HashSet::with_capacity(1);
608                patterns.insert(*index);
609                self.superposition.set(
610                    col,
611                    row,
612                    Cell {
613                        patterns,
614                        entropy: 0.0,
615                    },
616                );
617                return true;
618            }
619            selected -= weight;
620        }
621        false
622    }
623
624    fn get_uncollapsed_coord(&self) -> Result<Option<(usize, usize)>, ()> {
625        if self
626            .superposition
627            .iter()
628            .any(|cell| cell.patterns.is_empty())
629        {
630            return Err(());
631        }
632        let cols = self.superposition.cols();
633        let result = {
634            let mut result = None;
635            for (index, cell) in self.superposition.iter().enumerate() {
636                let col = index % cols;
637                let row = index / cols;
638                if cell.patterns.len() > 1 {
639                    if let Some((_, _, entropy)) = result {
640                        if cell.entropy < entropy {
641                            result = Some((col, row, cell.entropy));
642                        }
643                    } else {
644                        result = Some((col, row, cell.entropy));
645                    }
646                }
647            }
648            result
649        };
650        Ok(result.map(|(col, row, _)| (col, row)))
651    }
652
    /// Processes one coordinate from the open queue.
    ///
    /// If the cell at that coordinate has more than one candidate, its
    /// candidates are filtered against the four wrapping neighbor cells. A
    /// candidate is kept only if, for every direction, the neighbor cell on
    /// that side holds at least one pattern the model allows on that side.
    /// When the filter removes candidates, the cell is recorded in
    /// `lately_updated`, rewritten with its new entropy, and each neighbor
    /// that still has more than one candidate is queued (unless it is queued
    /// already). Does nothing when the queue is empty.
    fn partially_reduce_superposition(&mut self) {
        if self.cached_open.is_empty() {
            return;
        }
        let (col, row) = self.cached_open.pop_front().unwrap();
        let (cols, rows) = self.superposition.size();
        let patterns = &self.superposition.cell(col, row).unwrap().patterns;
        let count = patterns.len();
        if count > 1 {
            // Neighbor cells in the same order as `NEIGHBOR_COORD_DIRS`
            // (left, right, top, bottom), wrapping around the grid edges.
            let samples = [
                self.superposition
                    .cell((cols + col - 1) % cols, row)
                    .unwrap(),
                self.superposition.cell((col + 1) % cols, row).unwrap(),
                self.superposition
                    .cell(col, (rows + row - 1) % rows)
                    .unwrap(),
                self.superposition.cell(col, (row + 1) % rows).unwrap(),
            ];
            let neighbors = self.model.neighbors();
            #[cfg(not(feature = "parallel"))]
            let patterns = patterns.iter();
            #[cfg(feature = "parallel")]
            let patterns = patterns.par_iter();
            let patterns = patterns
                .filter(|index| {
                    let neighbors = neighbors.get(**index).unwrap();
                    // A pattern without any allowed neighbor can never be placed.
                    if neighbors.is_empty() {
                        return false;
                    }
                    NEIGHBOR_COORD_DIRS.iter().enumerate().all(|(i, d)| {
                        samples[i].patterns.iter().any(|n| {
                            neighbors
                                .iter()
                                .any(|(neighbor, direction)| direction == d && neighbor == n)
                        })
                    })
                })
                .cloned()
                .collect::<HashSet<_>>();
            // Only a real reduction changes the cell and wakes up its neighbors.
            if patterns.len() < count {
                self.lately_updated.insert((col, row));
                let coord = ((col + cols - 1) % cols, row);
                if samples[0].patterns.len() > 1 && !self.cached_open.contains(&coord) {
                    self.cached_open.push_back(coord);
                }
                let coord = ((col + 1) % cols, row);
                if samples[1].patterns.len() > 1 && !self.cached_open.contains(&coord) {
                    self.cached_open.push_back(coord);
                }
                let coord = (col, (row + rows - 1) % rows);
                if samples[2].patterns.len() > 1 && !self.cached_open.contains(&coord) {
                    self.cached_open.push_back(coord);
                }
                let coord = (col, (row + 1) % rows);
                if samples[3].patterns.len() > 1 && !self.cached_open.contains(&coord) {
                    self.cached_open.push_back(coord);
                }
                let entropy = calculate_entropy(&self.model, &patterns);
                self.superposition.set(col, row, Cell { patterns, entropy });
            }
        }
    }
716
717    pub fn get_uncollapsed_world(&self) -> Grid2d<Vec<T>> {
718        let cols = self.superposition.cols();
719        let cells = self
720            .superposition
721            .iter()
722            .map(|cell| {
723                cell.patterns
724                    .iter()
725                    .map(|index| self.model.patterns()[*index].0.cell(0, 0).unwrap().clone())
726                    .collect::<Vec<_>>()
727            })
728            .collect::<Vec<_>>();
729        Grid2d::with_cells(cols, cells)
730    }
731
732    fn superposition_to_collapsed_world(
733        model: &WaveFunctionCollapseModel<T>,
734        superposition: &Grid2d<Cell>,
735    ) -> Result<Grid2d<T>, WaveFunctionCollapseError> {
736        let cols = superposition.cols();
737        let cells = superposition
738            .iter()
739            .map(|cell| {
740                if cell.patterns.len() == 1 {
741                    let index = cell.patterns.iter().next().unwrap();
742                    Ok(model.patterns()[*index].0.cell(0, 0).unwrap().clone())
743                } else {
744                    Err(WaveFunctionCollapseError::FoundUncollapsedCell)
745                }
746            })
747            .collect::<Result<Vec<_>, _>>()?;
748        Ok(Grid2d::with_cells(cols, cells))
749    }
750}
751
752fn calculate_entropy<T>(model: &WaveFunctionCollapseModel<T>, patterns: &HashSet<usize>) -> Scalar
753where
754    T: Clone + Send + Sync + PartialEq,
755{
756    if patterns.is_empty() {
757        return 0.0;
758    }
759    let mut total_weight = 0.0;
760    let mut total_weight_log = 0.0;
761    for index in patterns {
762        let weight = model.patterns()[*index].1;
763        total_weight += weight;
764        total_weight_log += weight * weight.log2();
765    }
766    total_weight.log2() - (total_weight_log / total_weight)
767}
768
769#[cfg(test)]
770mod tests {
771    use super::*;
772
773    #[allow(dead_code)]
774    fn parse_view(data: &str) -> Grid2d<Option<char>> {
775        let lines = data
776            .split(|c| c == '\n' || c == '\r')
777            .filter(|l| !l.is_empty())
778            .collect::<Vec<_>>();
779        let cols = lines.iter().fold(0, |a, l| a.max(l.len()));
780        let rows = lines.len();
781        let mut result = Grid2d::new(cols, rows, None);
782        for (row, line) in lines.into_iter().enumerate() {
783            for (col, character) in line.chars().enumerate() {
784                if !character.is_whitespace() {
785                    result.set(col, row, Some(character));
786                }
787            }
788        }
789        print_view("= VIEW:", &result);
790        result
791    }
792
793    #[allow(dead_code)]
794    fn print_view(msg: &str, pattern: &Grid2d<Option<char>>) {
795        println!("{}", msg);
796        for row in 0..pattern.rows() {
797            for cell in pattern.get_row_cells(row).unwrap() {
798                if let Some(cell) = cell {
799                    print!("{}", cell);
800                } else {
801                    print!(" ");
802                }
803            }
804            println!();
805        }
806    }
807
808    #[allow(dead_code)]
809    fn print_collapsed_world(msg: &str, world: &Grid2d<char>) {
810        println!("{}", msg);
811        for row in 0..world.rows() {
812            for cell in world.get_row_cells(row).unwrap() {
813                print!("{}", cell);
814            }
815            println!();
816        }
817    }
818
819    #[allow(dead_code)]
820    fn print_uncollapsed_world(msg: &str, world: &Grid2d<Vec<char>>, uncertain: char) {
821        println!("{}", msg);
822        for row in 0..world.rows() {
823            for cell in world.get_row_cells(row).unwrap() {
824                if cell.len() == 1 {
825                    print!("{}", cell[0]);
826                } else {
827                    print!("{}", uncertain);
828                }
829            }
830            println!();
831        }
832    }
833
834    #[allow(dead_code)]
835    fn print_pattern(msg: &str, pattern: &Grid2d<char>) {
836        println!("{}", msg);
837        for row in 0..pattern.rows() {
838            for cell in pattern.get_row_cells(row).unwrap() {
839                print!("{}", cell);
840            }
841            println!();
842        }
843    }
844
    /// End-to-end run: builds a seamless 3x3 model from the sample view, then
    /// collapses a 75x75 world in which every cell may initially take any
    /// value found in the view. Only compiled with the `longrun` feature.
    /// Progress is printed at most every 400 ms; the test panics unless the
    /// world collapses.
    #[test]
    #[cfg(feature = "longrun")]
    fn test_general() {
        use rand::{thread_rng, Rng};
        use std::time::Instant;

        let view = parse_view(include_str!("../resources/view.txt"));
        // Sorted, de-duplicated list of all values present in the view.
        let values = {
            let mut values = view.iter().filter_map(|c| c.clone()).collect::<Vec<_>>();
            values.sort();
            values.dedup();
            values
        };
        println!("= VALUES: {:?}", values);
        let model = WaveFunctionCollapseModel::from_views((3, 3), true, vec![view]).unwrap();
        let world = Grid2d::new(75, 75, values);
        // `timer` measures the whole phase; `timer2` throttles progress output.
        let timer = Instant::now();
        let mut timer2 = Instant::now();
        let mut solver = WaveFunctionCollapseSolver::new_inspect(model, world, None, |p, m| {
            if timer2.elapsed().as_millis() >= 400 {
                timer2 = Instant::now();
                println!(
                    "= INITIALIZE: {} / {} ({}%)",
                    p,
                    m,
                    100.0 * p as Scalar / m as Scalar
                );
            }
        })
        .unwrap();
        println!("= INITIALIZED IN: {:?}", timer.elapsed());
        let timer = Instant::now();
        let mut timer2 = Instant::now();
        let mut rng = thread_rng();
        // Largest number of cells changed by a single collapse step.
        let mut max_changes = 0;
        let result = solver.collapse_inspect(
            move |f, t| rng.gen_range(f..t),
            |p, m, s| {
                max_changes = max_changes.max(s.lately_updated().len());
                if timer2.elapsed().as_millis() >= 400 {
                    timer2 = Instant::now();
                    println!();
                    println!();
                    print_uncollapsed_world(
                        "= UNCOLLAPSED WORLD:",
                        &s.get_uncollapsed_world(),
                        ' ',
                    );
                    println!(
                        "= PROGRESS: {} / {} ({}%)",
                        p,
                        m,
                        100.0 * p as Scalar / m as Scalar
                    )
                }
            },
        );
        match result {
            WaveFunctionCollapseResult::Collapsed(world) => {
                println!();
                println!();
                println!(
                    "= COLLAPSED IN: {:?} | MAX CHANGES: {}",
                    timer.elapsed(),
                    max_changes
                );
                print_collapsed_world("= COLLAPSED WORLD:", &world);
            }
            _ => panic!("= IMPOSSIBLE WORLD"),
        }
    }
916}