texture_synthesis/
ms.rs

1use rand::{Rng, SeedableRng};
2use rand_pcg::Pcg32;
3use rstar::RTree;
4use std::cmp::max;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::{Mutex, RwLock};
7
8use crate::{
9    img_pyramid::*,
10    session::{GeneratorProgress, ProgressStat},
11    unsync::*,
12    CoordinateTransform, Dims, SamplingMethod,
13};
14
/// Fraction of the output width/height, measured inwards from each edge, inside
/// which resolved points are additionally mirrored to the opposite side when
/// tiling mode is enabled. Also sizes the extra margin of the tree grid.
const TILING_BOUNDARY_PERCENTAGE: f32 = 5e-2;
16
/// Tunable parameters of a single synthesis run.
#[derive(Debug)]
pub struct GeneratorParams {
    /// Number of neighbouring pixels each pixel is aware of during generation
    /// (larger values capture more global structure).
    pub(crate) nearest_neighbors: u32,
    /// Number of random example locations considered when resolving a pixel, on
    /// top of the candidates derived from its immediate neighbours (if unsure,
    /// use the same value as `nearest_neighbors`).
    pub(crate) random_sample_locations: u64,
    /// Dispersion of the distribution used when picking the best candidate; it
    /// controls how flat the distribution's tail is. Values close to 0.0 give
    /// harsh borders between generated chunks, values closer to 1.0 give a
    /// smoother gradient across those borders.
    pub(crate) cauchy_dispersion: f32,
    /// Fraction of pixels to backtrack during each p-stage, in the range (0, 1).
    pub(crate) p: f32,
    /// Number of backtracking stages. Backtracking prevents "garbage" generation.
    pub(crate) p_stages: i32,
    /// Base random seed.
    pub(crate) seed: u64,
    /// Trade-off between the guide maps and the example maps.
    pub(crate) alpha: f32,
    /// Upper bound on the number of worker threads used while resolving.
    pub(crate) max_thread_count: usize,
    /// When true, resolved points near the output borders are also inserted at
    /// their wrapped position on the opposite side, so neighbourhoods wrap
    /// around the image edges.
    pub(crate) tiling_mode: bool,
}
43
44#[derive(Debug, Default, Clone)]
45struct CandidateStruct {
46    coord: (SignedCoord2D, MapId), //X, Y, and map_id
47    k_neighs: Vec<(SignedCoord2D, MapId)>,
48    id: (PatchId, MapId),
49}
50
51impl CandidateStruct {
52    fn clear(&mut self) {
53        self.k_neighs.clear();
54    }
55}
56
57struct GuidesStruct<'a> {
58    pub example_guides: Vec<ImageBuffer<'a>>, // as many as there are examples
59    pub target_guide: ImageBuffer<'a>,        //single for final color_map
60}
61
62pub(crate) struct GuidesPyramidStruct {
63    pub example_guides: Vec<ImagePyramid>, // as many as there are examples
64    pub target_guide: ImagePyramid,        //single for final color_map
65}
66
67impl GuidesPyramidStruct {
68    fn to_guides_struct(&self, level: usize) -> GuidesStruct<'_> {
69        let tar_guide = ImageBuffer::from(&self.target_guide.pyramid[level]);
70        let ex_guide = self
71            .example_guides
72            .iter()
73            .map(|a| ImageBuffer::from(&a.pyramid[level]))
74            .collect();
75
76        GuidesStruct {
77            example_guides: ex_guide,
78            target_guide: tar_guide,
79        }
80    }
81}
82
/// Euclidean remainder of `a` divided by `b`: always in `0..b.abs()`.
///
/// Used to wrap (possibly negative or out-of-range) coordinates back into an
/// image of positive dimension `b`, e.g. `modulo(-1, 5) == 4`. The result is
/// non-negative for a negative `b` as well.
///
/// # Panics
/// Panics if `b == 0`.
#[inline]
fn modulo(a: i32, b: i32) -> i32 {
    a.rem_euclid(b)
}
92
93// for k-neighbors
94#[derive(Clone, Copy, Debug, Default)]
95struct SignedCoord2D {
96    x: i32,
97    y: i32,
98}
99
100impl SignedCoord2D {
101    fn from(x: i32, y: i32) -> Self {
102        Self { x, y }
103    }
104
105    fn to_unsigned(self) -> Coord2D {
106        Coord2D::from(self.x as u32, self.y as u32)
107    }
108
109    #[inline]
110    fn wrap(self, (dimx, dimy): (i32, i32)) -> Self {
111        let mut c = self;
112        c.x = modulo(c.x, dimx);
113        c.y = modulo(c.y, dimy);
114        c
115    }
116}
117
118#[derive(Clone, Copy, Debug)]
119struct Coord2D {
120    x: u32,
121    y: u32,
122}
123
124impl Coord2D {
125    fn from(x: u32, y: u32) -> Self {
126        Self { x, y }
127    }
128
129    fn to_flat(self, dims: Dims) -> CoordFlat {
130        CoordFlat(dims.width * self.y + self.x)
131    }
132
133    fn to_signed(self) -> SignedCoord2D {
134        SignedCoord2D {
135            x: self.x as i32,
136            y: self.y as i32,
137        }
138    }
139}
140#[derive(Clone, Copy, Debug)]
141struct CoordFlat(u32);
142
143impl CoordFlat {
144    fn to_2d(self, dims: Dims) -> Coord2D {
145        let y = self.0 / dims.width;
146        let x = self.0 - y * dims.width;
147        Coord2D::from(x, y)
148    }
149}
150
/// Identifier of the "copy island" a pixel belongs to; pixels that continue a
/// neighbour's patch inherit that neighbour's id.
#[derive(Clone, Copy, Debug, Default)]
struct PatchId(u32);
/// Index of an example map.
#[derive(Clone, Copy, Debug, Default)]
struct MapId(u32);
/// Score recorded for a resolved pixel (higher values render redder in the
/// uncertainty map).
#[derive(Clone, Copy, Debug, Default)]
struct Score(f32);
157
/// Owned flat byte buffer — presumably colour samples, per the name (confirm
/// at call sites).
#[derive(Clone, Debug, Default)]
struct ColorPattern(Vec<u8>);

impl ColorPattern {
    /// Creates an empty pattern.
    pub fn new() -> Self {
        Self::default()
    }
}
166
167#[derive(Clone)]
168pub(crate) struct ImageBuffer<'a> {
169    buffer: &'a [u8],
170    width: usize,
171    height: usize,
172}
173
174impl<'a> ImageBuffer<'a> {
175    #[inline]
176    fn is_in_bounds(&self, coord: SignedCoord2D) -> bool {
177        coord.x >= 0 && coord.y >= 0 && coord.x < self.width as i32 && coord.y < self.height as i32
178    }
179
180    #[inline]
181    fn get_pixel(&self, x: u32, y: u32) -> &'a image::Rgba<u8> {
182        let ind = (y as usize * self.width + x as usize) * 4;
183        unsafe {
184            &*((&self.buffer[ind..ind + 4])
185                .as_ptr()
186                .cast::<image::Rgba<u8>>())
187        }
188    }
189
190    #[inline]
191    fn dimensions(&self) -> (u32, u32) {
192        (self.width as u32, self.height as u32)
193    }
194}
195
196impl<'a> From<&'a image::RgbaImage> for ImageBuffer<'a> {
197    fn from(img: &'a image::RgbaImage) -> Self {
198        let (width, height) = img.dimensions();
199        Self {
200            buffer: img,
201            width: width as usize,
202            height: height as usize,
203        }
204    }
205}
206
207pub struct Generator {
208    pub(crate) color_map: UnsyncRgbaImage,
209    coord_map: UnsyncVec<(Coord2D, MapId)>, //list of samples coordinates from example map
210    id_map: UnsyncVec<(PatchId, MapId)>,    // list of all id maps of our generated image
211    pub(crate) output_size: Dims,           // size of the generated image
212    unresolved: Mutex<Vec<CoordFlat>>,      //for us to pick from
213    resolved: RwLock<Vec<(CoordFlat, Score)>>, //a list of resolved coordinates in our canvas and their scores
214    tree_grid: TreeGrid,                       // grid of R*Trees
215    locked_resolved: usize,                    //used for inpainting, to not backtrack these pixels
216    input_dimensions: Vec<Dims>,
217}
218
219impl Generator {
220    pub(crate) fn new(size: Dims) -> Self {
221        let s = (size.width as usize) * (size.height as usize);
222        let unresolved: Vec<CoordFlat> = (0..(s as u32)).map(CoordFlat).collect();
223        Self {
224            color_map: UnsyncRgbaImage::new(image::RgbaImage::new(size.width, size.height)),
225            coord_map: UnsyncVec::new(vec![(Coord2D::from(0, 0), MapId(0)); s]),
226            id_map: UnsyncVec::new(vec![(PatchId(0), MapId(0)); s]),
227            output_size: size,
228            unresolved: Mutex::new(unresolved),
229            resolved: RwLock::new(Vec::new()),
230            tree_grid: TreeGrid::new(size.width, size.height, max(size.width, size.height), 0, 0),
231            locked_resolved: 0,
232            input_dimensions: Vec::new(),
233        }
234    }
235
236    pub(crate) fn new_from_inpaint(
237        size: Dims,
238        inpaint_map: image::RgbaImage,
239        color_map: image::RgbaImage,
240        color_map_index: usize,
241    ) -> Self {
242        let inpaint_map =
243            if inpaint_map.width() != size.width || inpaint_map.height() != size.height {
244                image::imageops::resize(
245                    &inpaint_map,
246                    size.width,
247                    size.height,
248                    image::imageops::Triangle,
249                )
250            } else {
251                inpaint_map
252            };
253
254        let color_map = if color_map.width() != size.width || color_map.height() != size.height {
255            image::imageops::resize(
256                &color_map,
257                size.width,
258                size.height,
259                image::imageops::Triangle,
260            )
261        } else {
262            color_map
263        };
264
265        let s = (size.width as usize) * (size.height as usize);
266        let mut unresolved: Vec<CoordFlat> = Vec::new();
267        let mut resolved: Vec<(CoordFlat, Score)> = Vec::new();
268        let mut coord_map = vec![(Coord2D::from(0, 0), MapId(0)); s];
269        let tree_grid = TreeGrid::new(size.width, size.height, max(size.width, size.height), 0, 0);
270        //populate resolved, unresolved and coord map
271        for (i, pixel) in inpaint_map.pixels().enumerate() {
272            if pixel[0] < 255 {
273                unresolved.push(CoordFlat(i as u32));
274            } else {
275                resolved.push((CoordFlat(i as u32), Score(0.0)));
276                let coord = CoordFlat(i as u32).to_2d(size);
277                coord_map[i] = (coord, MapId(color_map_index as u32)); //this absolutely requires the input image and output image to be the same size!!!!
278            }
279        }
280
281        let locked_resolved = resolved.len();
282        Self {
283            color_map: UnsyncRgbaImage::new(color_map),
284            coord_map: UnsyncVec::new(coord_map),
285            id_map: UnsyncVec::new(vec![(PatchId(0), MapId(0)); s]),
286            output_size: size,
287            unresolved: Mutex::new(unresolved),
288            resolved: RwLock::new(resolved),
289            tree_grid,
290            locked_resolved,
291            input_dimensions: Vec::new(),
292        }
293    }
294
295    // Write resolved pixels from the update queue to an already write-locked `rtree` and `resolved` array.
296    fn flush_resolved(
297        &self,
298        my_resolved_list: &mut Vec<(CoordFlat, Score)>,
299        tree_grid: &TreeGrid,
300        update_queue: &[([i32; 2], CoordFlat, Score)],
301        is_tiling_mode: bool,
302    ) {
303        for (a, b, score) in update_queue.iter() {
304            tree_grid.insert(a[0], a[1]);
305
306            if is_tiling_mode {
307                //if close to border add additional mirrors
308                let x_l = ((self.output_size.width as f32) * TILING_BOUNDARY_PERCENTAGE) as i32;
309                let x_r = self.output_size.width as i32 - x_l;
310                let y_b = ((self.output_size.height as f32) * TILING_BOUNDARY_PERCENTAGE) as i32;
311                let y_t = self.output_size.height as i32 - y_b;
312
313                if a[0] < x_l {
314                    tree_grid.insert(a[0] + (self.output_size.width as i32), a[1]);
315                // +x
316                } else if a[0] > x_r {
317                    tree_grid.insert(a[0] - (self.output_size.width as i32), a[1]);
318                    // -x
319                }
320
321                if a[1] < y_b {
322                    tree_grid.insert(a[0], a[1] + (self.output_size.height as i32));
323                // +Y
324                } else if a[1] > y_t {
325                    tree_grid.insert(a[0], a[1] - (self.output_size.height as i32));
326                    // -Y
327                }
328            }
329            my_resolved_list.push((*b, *score));
330        }
331    }
332
333    #[allow(clippy::too_many_arguments)]
334    fn update(
335        &self,
336        my_resolved_list: &mut Vec<(CoordFlat, Score)>,
337        update_coord: Coord2D,
338        (example_coord, example_map_id): (Coord2D, MapId),
339        example_maps: &[ImageBuffer<'_>],
340        update_resolved_list: bool,
341        score: Score,
342        island_id: (PatchId, MapId),
343        is_tiling_mode: bool,
344    ) {
345        let flat_coord = update_coord.to_flat(self.output_size);
346
347        // A little cheat to avoid taking excessive locks.
348        //
349        // Access to `coord_map` and `color_map` is governed by values in `self.resolved`,
350        // in such a way that any values in the former will not be accessed until the latter is updated.
351        // Since `coord_map` and `color_map` also contain 'plain old data', we can set them directly
352        // by getting the raw pointers. The subsequent access to `self.resolved` goes through a lock,
353        // and ensures correct memory ordering.
354        unsafe {
355            self.coord_map
356                .assign_at(flat_coord.0 as usize, (example_coord, example_map_id));
357            self.id_map.assign_at(flat_coord.0 as usize, island_id);
358        }
359        self.color_map.put_pixel(
360            update_coord.x,
361            update_coord.y,
362            *example_maps[example_map_id.0 as usize].get_pixel(example_coord.x, example_coord.y),
363        );
364
365        if update_resolved_list {
366            self.flush_resolved(
367                my_resolved_list,
368                &self.tree_grid,
369                &[(
370                    [update_coord.x as i32, update_coord.y as i32],
371                    flat_coord,
372                    score,
373                )],
374                is_tiling_mode,
375            );
376        }
377    }
378
379    //returns flat coord
380    fn pick_random_unresolved(&self, seed: u64) -> Option<CoordFlat> {
381        let mut unresolved = self.unresolved.lock().unwrap();
382
383        if unresolved.len() == 0 {
384            None //return fail
385        } else {
386            let rand_index = Pcg32::seed_from_u64(seed).gen_range(0..unresolved.len());
387            Some(unresolved.swap_remove(rand_index)) //return success
388        }
389    }
390
391    fn find_k_nearest_resolved_neighs(
392        &self,
393        coord: Coord2D,
394        k: u32,
395        k_neighs_2d: &mut Vec<SignedCoord2D>,
396    ) -> bool {
397        self.tree_grid
398            .get_k_nearest_neighbors(coord.x, coord.y, k as usize, k_neighs_2d);
399        if k_neighs_2d.is_empty() {
400            return false;
401        }
402        true
403    }
404
405    fn get_distances_to_k_neighs(&self, coord: Coord2D, k_neighs_2d: &[SignedCoord2D]) -> Vec<f64> {
406        let (dimx, dimy) = (
407            f64::from(self.output_size.width),
408            f64::from(self.output_size.height),
409        );
410        let (x2, y2) = (f64::from(coord.x) / dimx, f64::from(coord.y) / dimy);
411        let mut k_neighs_dist: Vec<f64> = Vec::with_capacity(k_neighs_2d.len() * 4);
412
413        for coord in k_neighs_2d.iter() {
414            let (x1, y1) = ((f64::from(coord.x)) / dimx, (f64::from(coord.y)) / dimy);
415            let dist = (x1 - x2).mul_add(x1 - x2, (y1 - y2) * (y1 - y2));
416            // Duplicate the distance for each of our 4 channels
417            k_neighs_dist.extend_from_slice(&[dist, dist, dist, dist]);
418        }
419
420        //divide by avg
421        let avg: f64 = k_neighs_dist.iter().sum::<f64>() / (k_neighs_dist.len() as f64);
422
423        k_neighs_dist.iter_mut().for_each(|d| *d /= avg);
424        k_neighs_dist
425    }
426
427    pub(crate) fn resolve_random_batch(
428        &mut self,
429        steps: usize,
430        example_maps: &[ImageBuffer<'_>],
431        seed: u64,
432    ) {
433        for i in 0..steps {
434            if let Some(ref unresolved_flat) = self.pick_random_unresolved(seed + i as u64) {
435                //no resolved neighs? resolve at random!
436                self.resolve_at_random(
437                    &mut self.resolved.write().unwrap(),
438                    unresolved_flat.to_2d(self.output_size),
439                    example_maps,
440                    seed + i as u64 + u64::from(unresolved_flat.0),
441                );
442            }
443        }
444        self.locked_resolved += steps; //lock these pixels from being re-resolved
445    }
446
447    fn resolve_at_random(
448        &self,
449        my_resolved_list: &mut Vec<(CoordFlat, Score)>,
450        my_coord: Coord2D,
451        example_maps: &[ImageBuffer<'_>],
452        seed: u64,
453    ) {
454        let rand_map: u32 = Pcg32::seed_from_u64(seed).gen_range(0..example_maps.len()) as u32;
455        let rand_x: u32 =
456            Pcg32::seed_from_u64(seed).gen_range(0..example_maps[rand_map as usize].width as u32);
457        let rand_y: u32 =
458            Pcg32::seed_from_u64(seed).gen_range(0..example_maps[rand_map as usize].height as u32);
459
460        self.update(
461            my_resolved_list,
462            my_coord,
463            (Coord2D::from(rand_x, rand_y), MapId(rand_map)),
464            example_maps,
465            true,
466            // NOTE: giving score 0.0 which is absolutely imaginery since we're randomly
467            // initializing
468            Score(0.0),
469            (
470                PatchId(my_coord.to_flat(self.output_size).0),
471                MapId(rand_map),
472            ),
473            false,
474        );
475    }
476
477    #[allow(clippy::too_many_arguments)]
478    fn find_candidates<'a>(
479        &self,
480        candidates_vec: &'a mut Vec<CandidateStruct>,
481        unresolved_coord: Coord2D,
482        k_neighs: &[SignedCoord2D],
483        example_maps: &[ImageBuffer<'_>],
484        valid_non_ignored_samples_mask: &[&SamplingMethod],
485        m_rand: u32,
486        m_seed: u64,
487    ) -> &'a [CandidateStruct] {
488        let mut candidate_count = 0;
489        let unresolved_coord = unresolved_coord.to_signed();
490
491        let wrap_dim = (
492            self.output_size.width as i32,
493            self.output_size.height as i32,
494        );
495
496        //neighborhood based candidates
497        for neigh_coord in k_neighs {
498            //calculate the shift between the center coord and its found neighbor
499            let shift = (
500                unresolved_coord.x - (*neigh_coord).x,
501                unresolved_coord.y - (*neigh_coord).y,
502            );
503
504            //find center coord original location in the example map
505            let n_flat_coord = neigh_coord
506                .wrap(wrap_dim)
507                .to_unsigned()
508                .to_flat(self.output_size)
509                .0 as usize;
510            let (n_original_coord, _) = self.coord_map.as_ref()[n_flat_coord];
511            let (n_patch_id, n_map_id) = self.id_map.as_ref()[n_flat_coord];
512            //candidate coord is the original location of the neighbor + neighbor's shift to the center
513            let candidate_coord = SignedCoord2D::from(
514                n_original_coord.x as i32 + shift.0,
515                n_original_coord.y as i32 + shift.1,
516            );
517            //check if the shifted coord is valid (discard if not)
518            if check_coord_validity(
519                candidate_coord,
520                n_map_id,
521                example_maps,
522                valid_non_ignored_samples_mask[n_map_id.0 as usize],
523            ) {
524                //lets construct the full candidate pattern of neighbors identical to the center coord
525                candidates_vec[candidate_count]
526                    .k_neighs
527                    .resize(k_neighs.len(), (SignedCoord2D::from(0, 0), MapId(0)));
528
529                for (output, n2) in candidates_vec[candidate_count]
530                    .k_neighs
531                    .iter_mut()
532                    .zip(k_neighs)
533                {
534                    let shift = (n2.x - unresolved_coord.x, n2.y - unresolved_coord.y);
535                    let n2_coord = SignedCoord2D::from(
536                        candidate_coord.x + shift.0,
537                        candidate_coord.y + shift.1,
538                    );
539
540                    *output = (n2_coord, n_map_id);
541                }
542                //record the candidate info
543                candidates_vec[candidate_count].coord = (candidate_coord, n_map_id);
544                candidates_vec[candidate_count].id = (n_patch_id, n_map_id);
545                candidate_count += 1;
546            }
547        }
548
549        let mut rng = Pcg32::seed_from_u64(m_seed);
550        //random candidates
551        for _ in 0..m_rand {
552            let rand_map = (rng.gen_range(0..example_maps.len())) as u32;
553            let dims = example_maps[rand_map as usize].dimensions();
554            let dims = Dims {
555                width: dims.0,
556                height: dims.1,
557            };
558            let mut rand_x: i32;
559            let mut rand_y: i32;
560            let mut candidate_coord;
561            //generate a random valid candidate
562            loop {
563                rand_x = rng.gen_range(0..dims.width) as i32;
564                rand_y = rng.gen_range(0..dims.height) as i32;
565                candidate_coord = SignedCoord2D::from(rand_x, rand_y);
566                if check_coord_validity(
567                    candidate_coord,
568                    MapId(rand_map),
569                    example_maps,
570                    valid_non_ignored_samples_mask[rand_map as usize],
571                ) {
572                    break;
573                }
574            }
575            //for patch id (since we are not copying from a generated patch anymore), we take the pixel location in the example map
576            let map_id = MapId(rand_map);
577            let patch_id = PatchId(candidate_coord.to_unsigned().to_flat(dims).0);
578            //lets construct the full neighborhood pattern
579            candidates_vec[candidate_count]
580                .k_neighs
581                .resize(k_neighs.len(), (SignedCoord2D::from(0, 0), MapId(0)));
582
583            for (output, n2) in candidates_vec[candidate_count]
584                .k_neighs
585                .iter_mut()
586                .zip(k_neighs)
587            {
588                let shift = (unresolved_coord.x - n2.x, unresolved_coord.y - n2.y);
589                let n2_coord =
590                    SignedCoord2D::from(candidate_coord.x + shift.0, candidate_coord.y + shift.1);
591
592                *output = (n2_coord, map_id);
593            }
594
595            //record the candidate info
596            candidates_vec[candidate_count].coord = (candidate_coord, map_id);
597            candidates_vec[candidate_count].id = (patch_id, map_id);
598            candidate_count += 1;
599        }
600
601        &candidates_vec[0..candidate_count]
602    }
603
604    /// Returns an image of Ids for visualizing the 'copy islands' and map ids of those islands
605    pub fn get_id_maps(&self) -> [image::RgbaImage; 2] {
606        //init empty image
607        let mut map_id_map = image::RgbaImage::new(self.output_size.width, self.output_size.height);
608        let mut patch_id_map =
609            image::RgbaImage::new(self.output_size.width, self.output_size.height);
610        //populate the image with colors
611        for (i, (patch_id, map_id)) in self.id_map.as_ref().iter().enumerate() {
612            //get 2d coord
613            let coord = CoordFlat(i as u32).to_2d(self.output_size);
614            //get random color based on id
615            let color: image::Rgba<u8> = image::Rgba([
616                Pcg32::seed_from_u64(u64::from(patch_id.0)).gen_range(0..255),
617                Pcg32::seed_from_u64(u64::from((patch_id.0) * 5 + 21)).gen_range(0..255),
618                Pcg32::seed_from_u64(u64::from((patch_id.0) / 4 + 12)).gen_range(0..255),
619                255,
620            ]);
621            //write image
622            patch_id_map.put_pixel(coord.x, coord.y, color);
623            //get random color based on id
624            let color: image::Rgba<u8> = image::Rgba([
625                Pcg32::seed_from_u64(u64::from(map_id.0) * 200).gen_range(0..255),
626                Pcg32::seed_from_u64(u64::from((map_id.0) * 5 + 341)).gen_range(0..255),
627                Pcg32::seed_from_u64(u64::from((map_id.0) * 1200 - 35412)).gen_range(0..255),
628                255,
629            ]);
630            map_id_map.put_pixel(coord.x, coord.y, color);
631        }
632        [patch_id_map, map_id_map]
633    }
634
635    pub fn get_uncertainty_map(&self) -> image::RgbaImage {
636        let mut uncertainty_map =
637            image::RgbaImage::new(self.output_size.width, self.output_size.height);
638
639        for (flat_coord, score) in self.resolved.read().unwrap().iter() {
640            //get coord
641            let coord = flat_coord.to_2d(self.output_size);
642            //get value normalized
643            let normalized_score = (score.0.min(1.0) * 255.0) as u8;
644
645            let color: image::Rgba<u8> =
646                image::Rgba([normalized_score, 255 - normalized_score, 0, 255]);
647
648            //write image
649            uncertainty_map.put_pixel(coord.x, coord.y, color);
650        }
651
652        uncertainty_map
653    }
654
655    pub fn get_coord_transform(&self) -> CoordinateTransform {
656        // init empty 32bit image
657        let coord_map = self.coord_map.as_ref();
658
659        let mut buffer: Vec<u32> = Vec::new();
660
661        // presize the vector for our final size
662        buffer.resize(coord_map.len() * 3, 0);
663
664        //populate the image with colors
665        for (i, (coord, map_id)) in self.coord_map.as_ref().iter().enumerate() {
666            let b = map_id.0;
667
668            //record the color
669            let ind = i * 3;
670            let color = &mut buffer[ind..ind + 3];
671
672            color[0] = coord.x;
673            color[1] = coord.y;
674            color[2] = b;
675        }
676
677        let original_maps = self.input_dimensions.clone();
678
679        CoordinateTransform {
680            buffer,
681            output_size: Dims::new(self.output_size.width, self.output_size.height),
682            original_maps,
683        }
684    }
685
686    //replace every resolved pixel with a pixel from a new level
687    fn next_pyramid_level(&mut self, example_maps: &[ImageBuffer<'_>]) {
688        for (coord_flat, _) in self.resolved.read().unwrap().iter() {
689            let resolved_2d = coord_flat.to_2d(self.output_size);
690            let (example_map_coord, example_map_id) =
691                self.coord_map.as_ref()[coord_flat.0 as usize]; //so where the current pixel came from
692
693            self.color_map.put_pixel(
694                resolved_2d.x,
695                resolved_2d.y,
696                *example_maps[example_map_id.0 as usize]
697                    .get_pixel(example_map_coord.x, example_map_coord.y),
698            );
699        }
700    }
701
702    pub(crate) fn resolve(
703        &mut self,
704        params: &GeneratorParams,
705        example_maps_pyramid: &[ImagePyramid],
706        mut progress: Option<Box<dyn GeneratorProgress>>,
707        guides_pyramid: &Option<GuidesPyramidStruct>,
708        valid_samples: &[SamplingMethod],
709    ) {
710        let total_pixels_to_resolve = self.unresolved.lock().unwrap().len();
711
712        let mut pyramid_level = 0;
713
714        let valid_non_ignored_samples: Vec<&SamplingMethod> = valid_samples[..]
715            .iter()
716            .filter(|s| !s.is_ignore())
717            .collect();
718
719        // Get the dimensions for each input example, this is only used when
720        // saving a coordinate transform, so that the transform can be repeated
721        // with different inputs that can be resized to avoid various problems
722        self.input_dimensions = example_maps_pyramid
723            .iter()
724            .map(|ip| {
725                let original = ip.bottom();
726                Dims {
727                    width: original.width(),
728                    height: original.height(),
729                }
730            })
731            .collect();
732
733        let stage_pixels_to_resolve = |p_stage: i32| {
734            (params.p.powf(p_stage as f32) * (total_pixels_to_resolve as f32)) as usize
735        };
736
737        let is_tiling_mode = params.tiling_mode;
738
739        let cauchy_precomputed = PrerenderedU8Function::new(|a, b| {
740            metric_cauchy(a, b, params.cauchy_dispersion * params.cauchy_dispersion)
741        });
742        let l2_precomputed = PrerenderedU8Function::new(metric_l2);
743        let max_workers = params.max_thread_count;
744        // Use a single R*-tree initially, and fan out to a grid of them later?
745        let mut has_fanned_out = false;
746
747        {
748            // now that we have all of the parameters we can setup our initial tree grid
749            let tile_adjusted_width = (self.output_size.width as f32
750                * TILING_BOUNDARY_PERCENTAGE.mul_add(2.0, 1.0))
751                as u32
752                + 1;
753            let tile_adjusted_height = (self.output_size.height as f32
754                * TILING_BOUNDARY_PERCENTAGE.mul_add(2.0, 1.0))
755                as u32
756                + 1;
757            self.tree_grid = TreeGrid::new(
758                tile_adjusted_width,
759                tile_adjusted_height,
760                max(tile_adjusted_width, tile_adjusted_height),
761                (self.output_size.width as f32 * TILING_BOUNDARY_PERCENTAGE) as u32 + 1,
762                (self.output_size.height as f32 * TILING_BOUNDARY_PERCENTAGE) as u32 + 1,
763            );
764            // if we already have resolved pixels from an inpaint or multiexample add them to this tree grid
765            let resolved_queue = &mut self.resolved.write().unwrap();
766            let pixels_to_update: Vec<([i32; 2], CoordFlat, Score)> = resolved_queue
767                .drain(..)
768                .map(|a| {
769                    let coord_2d = a.0.to_2d(self.output_size);
770                    ([coord_2d.x as i32, coord_2d.y as i32], a.0, a.1)
771                })
772                .collect();
773            self.flush_resolved(
774                resolved_queue,
775                &self.tree_grid,
776                &pixels_to_update[..],
777                is_tiling_mode,
778            );
779        }
780
781        let mut progress_notifier = progress.as_mut().map(|progress| {
782            let overall_total: usize = (0..=params.p_stages).map(stage_pixels_to_resolve).sum();
783            ProgressNotifier::new(progress, overall_total)
784        });
785
786        for p_stage in (0..=params.p_stages).rev() {
787            //get maps from current pyramid level (for now it will be p-stage dependant)
788            let example_maps = get_single_example_level(
789                example_maps_pyramid,
790                valid_samples,
791                pyramid_level as usize,
792            );
793            let guides = get_single_guide_level(guides_pyramid, pyramid_level as usize);
794
795            //update pyramid level
796            if pyramid_level > 0 {
797                self.next_pyramid_level(&example_maps);
798            }
799            pyramid_level += 1;
800            pyramid_level = pyramid_level.min(params.p_stages - 1); //dont go beyond
801
802            //get seed
803            let p_stage_seed: u64 =
804                u64::from(Pcg32::seed_from_u64(params.seed + p_stage as u64).gen::<u32>());
805
806            //how many pixels do we need to resolve in this stage
807            let pixels_to_resolve = stage_pixels_to_resolve(p_stage);
808            if let Some(ref mut notifier) = progress_notifier {
809                notifier.start_stage(pixels_to_resolve);
810            };
811
812            let redo_count = self.resolved.get_mut().unwrap().len() - self.locked_resolved;
813
814            // Start with serial execution for the first few pixels, then go wide
815            let n_workers = if redo_count < 1000 { 1 } else { max_workers };
816            if !has_fanned_out && n_workers > 1 {
817                has_fanned_out = true;
818                let tile_adjusted_width = (self.output_size.width as f32
819                    * TILING_BOUNDARY_PERCENTAGE.mul_add(2.0, 1.0))
820                    as u32
821                    + 1;
822                let tile_adjusted_height = (self.output_size.height as f32
823                    * TILING_BOUNDARY_PERCENTAGE.mul_add(2.0, 1.0))
824                    as u32
825                    + 1;
826                // heuristic: pick a cell size so that the expected number of resolved points in any cell is 4 * k
827                // this seems to be a safe overestimate
828                let grid_cell_size =
829                    ((params.nearest_neighbors * self.output_size.width * self.output_size.height
830                        / redo_count as u32) as f64)
831                        .sqrt() as u32
832                        * 2
833                        + 1;
834                let new_tree_grid = TreeGrid::new(
835                    tile_adjusted_width,
836                    tile_adjusted_height,
837                    grid_cell_size,
838                    (self.output_size.width as f32 * TILING_BOUNDARY_PERCENTAGE) as u32 + 1,
839                    (self.output_size.height as f32 * TILING_BOUNDARY_PERCENTAGE) as u32 + 1,
840                );
841                self.tree_grid.clone_into_new_tree_grid(&new_tree_grid);
842                self.tree_grid = new_tree_grid;
843            }
844
845            //calculate the guidance alpha
846            let adaptive_alpha = if guides.is_some() && p_stage > 0 {
847                let total_resolved = self.resolved.read().unwrap().len() as f32;
848                (params.alpha * (1.0 - (total_resolved / (total_pixels_to_resolve as f32)))).powi(3)
849            } else {
850                0.0 //only care for content, not guidance
851            };
852
853            let guide_cost_precomputed =
854                PrerenderedU8Function::new(|a, b| adaptive_alpha * l2_precomputed.get(a, b));
855
856            let my_inverse_alpha_cost_precomputed = PrerenderedU8Function::new(|a, b| {
857                (1.0 - adaptive_alpha) * cauchy_precomputed.get(a, b)
858            });
859
860            // Keep track of how many items have been processed. Goes up to `pixels_to_resolve`
861            let processed_pixel_count = AtomicUsize::new(0);
862            let remaining_threads = AtomicUsize::new(n_workers);
863
864            let mut pixels_resolved_this_stage: Vec<Mutex<Vec<(CoordFlat, Score)>>> = Vec::new();
865            pixels_resolved_this_stage.resize_with(n_workers, || Mutex::new(Vec::new()));
866            let thread_counter = AtomicUsize::new(0);
867
868            let worker_fn = |mut progress_notifier: Option<&mut ProgressNotifier<'_>>| {
869                let mut candidates: Vec<CandidateStruct> = Vec::new();
870                let mut my_pattern: ColorPattern = ColorPattern::new();
871                let mut k_neighs: Vec<SignedCoord2D> =
872                    Vec::with_capacity(params.nearest_neighbors as usize);
873
874                let max_candidate_count =
875                    params.nearest_neighbors as usize + params.random_sample_locations as usize;
876
877                let my_thread_id = thread_counter.fetch_add(1, Ordering::Relaxed);
878                let mut my_resolved_list = pixels_resolved_this_stage[my_thread_id].lock().unwrap();
879
880                candidates.resize(max_candidate_count, CandidateStruct::default());
881
882                //alloc storage for our guides (regardless of whether we have them or not)
883                let mut my_guide_pattern: ColorPattern = ColorPattern::new();
884
885                let out_color_map = &[ImageBuffer::from(self.color_map.as_ref())];
886
887                loop {
888                    // Get the next work item
889                    let i = processed_pixel_count.fetch_add(1, Ordering::Relaxed);
890
891                    let update_resolved_list: bool;
892
893                    if let Some(notifier) = progress_notifier.as_mut() {
894                        notifier.update(i, &self.color_map);
895                    }
896
897                    if i >= pixels_to_resolve {
898                        // We've processed everything, so finish the worker
899                        break;
900                    }
901
902                    let loop_seed = p_stage_seed + i as u64;
903
904                    // 1. Get a pixel to resolve. Check if we have already resolved pixel i; if yes, resolve again; if no, pick a new one
905                    let next_unresolved = if i < redo_count {
906                        update_resolved_list = false;
907                        self.resolved.read().unwrap()[i + self.locked_resolved].0
908                    } else {
909                        update_resolved_list = true;
910                        if let Some(pixel) = self.pick_random_unresolved(loop_seed) {
911                            pixel
912                        } else {
913                            break;
914                        }
915                    };
916
917                    let unresolved_2d = next_unresolved.to_2d(self.output_size);
918
919                    // Clear previously found candidate neighbors
920                    for cand in candidates.iter_mut() {
921                        cand.clear();
922                    }
923                    k_neighs.clear();
924
925                    // 2. find K nearest resolved neighs
926                    if self.find_k_nearest_resolved_neighs(
927                        unresolved_2d,
928                        params.nearest_neighbors,
929                        &mut k_neighs,
930                    ) {
931                        //2.1 get distances to the pattern of neighbors
932                        let k_neighs_dist =
933                            self.get_distances_to_k_neighs(unresolved_2d, &k_neighs);
934                        let k_neighs_w_map_id =
935                            k_neighs.iter().map(|a| (*a, MapId(0))).collect::<Vec<_>>();
936
937                        // 3. find candidate for each resolved neighs + m random locations
938                        let candidates: &[CandidateStruct] = self.find_candidates(
939                            &mut candidates,
940                            unresolved_2d,
941                            &k_neighs,
942                            &example_maps,
943                            &valid_non_ignored_samples,
944                            params.random_sample_locations as u32,
945                            loop_seed + 1,
946                        );
947
948                        k_neighs_to_precomputed_reference_pattern(
949                            &k_neighs_w_map_id, //feed into the function with always 0 index of the sample map
950                            image::Rgba([0, 0, 0, 255]),
951                            out_color_map,
952                            &mut my_pattern,
953                            is_tiling_mode,
954                        );
955
956                        // 3.2 get pattern for guide map if we have them
957                        let (my_cost, guide_cost) = if let Some(ref in_guides) = guides {
958                            //get example pattern to compare to
959                            k_neighs_to_precomputed_reference_pattern(
960                                &k_neighs_w_map_id,
961                                image::Rgba([0, 0, 0, 255]),
962                                &[in_guides.target_guide.clone()],
963                                &mut my_guide_pattern,
964                                is_tiling_mode,
965                            );
966
967                            (
968                                &my_inverse_alpha_cost_precomputed,
969                                Some(&guide_cost_precomputed),
970                            )
971                        } else {
972                            (&cauchy_precomputed, None)
973                        };
974
975                        // 4. find best match based on the candidate patterns
976                        let (best_match, score) = find_best_match(
977                            image::Rgba([0, 0, 0, 255]),
978                            &example_maps,
979                            &guides,
980                            candidates,
981                            &my_pattern,
982                            &my_guide_pattern,
983                            &k_neighs_dist,
984                            my_cost,
985                            guide_cost,
986                        );
987
988                        let best_match_coord = best_match.coord.0.to_unsigned();
989                        let best_match_map_id = best_match.coord.1;
990
991                        // 5. resolve our pixel
992                        self.update(
993                            &mut my_resolved_list,
994                            unresolved_2d,
995                            (best_match_coord, best_match_map_id),
996                            &example_maps,
997                            update_resolved_list,
998                            score,
999                            best_match.id,
1000                            is_tiling_mode,
1001                        );
1002                    } else {
1003                        //no resolved neighs? resolve at random!
1004                        self.resolve_at_random(
1005                            &mut my_resolved_list,
1006                            unresolved_2d,
1007                            &example_maps,
1008                            p_stage_seed,
1009                        );
1010                    }
1011                }
1012                remaining_threads.fetch_sub(1, Ordering::Relaxed);
1013            };
1014
1015            // For WASM we do not have threads and crossbeam panics,
1016            // so let's just run the worker function directly.
1017            #[cfg(target_arch = "wasm32")]
1018            (worker_fn)(progress_notifier.as_mut());
1019
1020            #[cfg(not(target_arch = "wasm32"))]
1021            {
1022                crossbeam_utils::thread::scope(|scope| {
1023                    for _ in 0..n_workers {
1024                        scope.spawn(|_| (worker_fn)(None));
1025                    }
1026                    if let Some(ref mut notifier) = progress_notifier {
1027                        loop {
1028                            if remaining_threads.load(Ordering::Relaxed) == 0 {
1029                                break;
1030                            }
1031                            let stage_current = processed_pixel_count.load(Ordering::Relaxed);
1032                            notifier.update(stage_current, &self.color_map);
1033                        }
1034                    }
1035                })
1036                .unwrap();
1037            }
1038
1039            if let Some(ref mut notifier) = progress_notifier {
1040                notifier.finish_stage(pixels_to_resolve);
1041            }
1042
1043            {
1044                // append all per-thread resolved lists to the global list
1045                let mut resolved = self.resolved.write().unwrap();
1046                for thread_resolved in pixels_resolved_this_stage {
1047                    resolved.append(&mut thread_resolved.into_inner().unwrap());
1048                }
1049            }
1050        }
1051    }
1052}
1053
1054struct ProgressNotifier<'a> {
1055    progress: &'a mut Box<dyn GeneratorProgress>,
1056    /// The total of pixels to resolve for the current stage.
1057    stage_total: usize,
1058    /// The number of pixels currently resolved, across all stages.
1059    overall_current: usize,
1060    /// The overall total of pixels to resolve, across all stages.
1061    overall_total: usize,
1062    /// The current percentage value.
1063    pcnt: u32,
1064}
1065
1066impl<'a> ProgressNotifier<'a> {
1067    fn new(progress: &'a mut Box<dyn GeneratorProgress>, overall_total: usize) -> Self {
1068        Self {
1069            progress,
1070            stage_total: 0,
1071            overall_current: 0,
1072            overall_total,
1073            pcnt: 0,
1074        }
1075    }
1076
1077    fn start_stage(&mut self, stage_total: usize) {
1078        debug_assert_eq!(self.stage_total, 0);
1079        self.stage_total = stage_total;
1080    }
1081
1082    fn update(&mut self, stage_current: usize, color_map: &UnsyncRgbaImage) {
1083        let pcnt = ((self.overall_current + stage_current) as f32 / self.overall_total as f32
1084            * 100f32)
1085            .round() as u32;
1086
1087        if pcnt != self.pcnt {
1088            self.progress.update(crate::session::ProgressUpdate {
1089                image: color_map.as_ref(),
1090                total: ProgressStat {
1091                    total: self.overall_total,
1092                    current: self.overall_current + stage_current,
1093                },
1094                stage: ProgressStat {
1095                    total: self.stage_total,
1096                    current: stage_current,
1097                },
1098            });
1099            self.pcnt = pcnt;
1100        }
1101    }
1102
1103    fn finish_stage(&mut self, total_pixels_stage: usize) {
1104        self.overall_current += total_pixels_stage;
1105        self.stage_total = 0;
1106    }
1107}
1108
/// Cauchy (Lorentzian) dissimilarity of two channel values.
///
/// Computes `ln(1 + d^2 / sig2)`, where `d` is the channel difference divided by 255.
#[inline]
fn metric_cauchy(a: u8, b: u8, sig2: f32) -> f32 {
    let delta = (f32::from(a) - f32::from(b)) / 255.0;
    let delta_sq = delta * delta;
    (delta_sq / sig2).ln_1p()
}
1115
/// Squared difference of two channel values, with the difference divided by 255
/// so the result lies in `[0, 1]`.
#[inline]
fn metric_l2(a: u8, b: u8) -> f32 {
    let normalized_delta = (f32::from(a) - f32::from(b)) / 255.0;
    normalized_delta * normalized_delta
}
1121
1122#[inline]
1123fn get_color_of_neighbor(
1124    outside_color: image::Rgba<u8>,
1125    source_maps: &[ImageBuffer<'_>],
1126    n_coord: SignedCoord2D,
1127    n_map: MapId,
1128    neighbor_color: &mut [u8],
1129    is_wrap_mode: bool,
1130    wrap_dim: (i32, i32),
1131) {
1132    let coord = if is_wrap_mode {
1133        n_coord.wrap(wrap_dim)
1134    } else {
1135        n_coord
1136    };
1137
1138    //check if he haven't gone outside the possible bounds
1139    if source_maps[n_map.0 as usize].is_in_bounds(coord) {
1140        neighbor_color.copy_from_slice(
1141            &(source_maps[n_map.0 as usize])
1142                .get_pixel(coord.x as u32, coord.y as u32)
1143                .0[..4],
1144        );
1145    } else {
1146        // if we have gone out of bounds, then just fill as outside color
1147        neighbor_color.copy_from_slice(&outside_color.0[..]);
1148    }
1149}
1150
1151fn k_neighs_to_precomputed_reference_pattern(
1152    k_neighs: &[(SignedCoord2D, MapId)],
1153    outside_color: image::Rgba<u8>,
1154    source_maps: &[ImageBuffer<'_>],
1155    pattern: &mut ColorPattern,
1156    is_wrap_mode: bool,
1157) {
1158    pattern.0.resize(k_neighs.len() * 4, 0);
1159    let mut i = 0;
1160
1161    let wrap_dim = (
1162        source_maps[0].dimensions().0 as i32,
1163        source_maps[0].dimensions().1 as i32,
1164    );
1165
1166    for (n_coord, n_map) in k_neighs {
1167        let end = i + 4;
1168
1169        get_color_of_neighbor(
1170            outside_color,
1171            source_maps,
1172            *n_coord,
1173            *n_map,
1174            &mut (pattern.0[i..end]),
1175            is_wrap_mode,
1176            wrap_dim,
1177        );
1178
1179        i = end;
1180    }
1181}
1182
1183#[allow(clippy::too_many_arguments)]
1184fn find_best_match<'a>(
1185    outside_color: image::Rgba<u8>,
1186    source_maps: &[ImageBuffer<'_>],
1187    guides: &Option<GuidesStruct<'_>>,
1188    candidates: &'a [CandidateStruct],
1189    my_precomputed_pattern: &ColorPattern,
1190    my_precomputed_guide_pattern: &ColorPattern,
1191    k_distances: &[f64], //weight by distance
1192    my_cost: &PrerenderedU8Function,
1193    guide_cost: Option<&PrerenderedU8Function>,
1194) -> (&'a CandidateStruct, Score) {
1195    let mut best_match = 0;
1196    let mut lowest_cost = std::f32::MAX;
1197
1198    let distance_gaussians: Vec<f32> = k_distances
1199        .iter()
1200        .copied()
1201        .map(|d| f64::exp(-1.0f64 * d))
1202        .map(|d| d as f32)
1203        .collect();
1204
1205    for (i, cand) in candidates.iter().enumerate() {
1206        if let Some(cost) = better_match(
1207            &cand.k_neighs,
1208            outside_color,
1209            source_maps,
1210            guides,
1211            my_precomputed_pattern,
1212            my_precomputed_guide_pattern,
1213            distance_gaussians.as_slice(),
1214            my_cost,
1215            guide_cost,
1216            lowest_cost,
1217        ) {
1218            lowest_cost = cost;
1219            best_match = i;
1220        }
1221    }
1222
1223    (&candidates[best_match], Score(lowest_cost))
1224}
1225
1226#[allow(clippy::too_many_arguments)]
1227fn better_match(
1228    k_neighs: &[(SignedCoord2D, MapId)],
1229    outside_color: image::Rgba<u8>,
1230    source_maps: &[ImageBuffer<'_>],
1231    guides: &Option<GuidesStruct<'_>>,
1232    my_precomputed_pattern: &ColorPattern,
1233    my_precomputed_guide_pattern: &ColorPattern,
1234    distance_gaussians: &[f32], //weight by distance
1235    my_cost: &PrerenderedU8Function,
1236    guide_cost: Option<&PrerenderedU8Function>,
1237    current_best: f32,
1238) -> Option<f32> {
1239    let mut score: f32 = 0.0; //minimize score
1240
1241    let mut i = 0;
1242    let mut next_pixel = [0; 4];
1243    let mut next_pixel_score: f32;
1244    for (n_coord, n_map) in k_neighs {
1245        next_pixel_score = 0.0;
1246        let end = i + 4;
1247
1248        //check if he haven't gone outside the possible bounds
1249        get_color_of_neighbor(
1250            outside_color,
1251            source_maps,
1252            *n_coord,
1253            *n_map,
1254            &mut next_pixel,
1255            false,
1256            (0, 0),
1257        );
1258
1259        for (channel_n, &channel) in next_pixel.iter().enumerate() {
1260            next_pixel_score += my_cost.get(my_precomputed_pattern.0[i + channel_n], channel);
1261        }
1262
1263        if let Some(guide_cost) = guide_cost {
1264            let example_guides = &(guides.as_ref().unwrap().example_guides);
1265            get_color_of_neighbor(
1266                outside_color,
1267                example_guides,
1268                *n_coord,
1269                *n_map,
1270                &mut next_pixel,
1271                false,
1272                (0, 0),
1273            );
1274
1275            for (channel_n, &channel) in next_pixel.iter().enumerate() {
1276                next_pixel_score +=
1277                    guide_cost.get(my_precomputed_guide_pattern.0[i + channel_n], channel);
1278            }
1279        }
1280        score += next_pixel_score * distance_gaussians[i];
1281        if score >= current_best {
1282            return None;
1283        }
1284        i = end;
1285    }
1286
1287    Some(score)
1288}
1289
/// Lookup table caching a binary function evaluated over every pair of `u8` inputs.
struct PrerenderedU8Function {
    /// Row-major 256x256 table: entry `a * 256 + b` holds `function(a, b)`.
    data: Vec<f32>,
}

impl PrerenderedU8Function {
    /// Evaluates `function` once for each of the 65536 `(a, b)` pairs.
    pub fn new<F: Fn(u8, u8) -> f32>(function: F) -> Self {
        let function = &function;
        let data: Vec<f32> = (0..=255u8)
            .flat_map(move |a| (0..=255u8).map(move |b| function(a, b)))
            .collect();

        Self { data }
    }

    /// Returns the cached value of `function(a, b)`.
    #[inline]
    pub fn get(&self, a: u8, b: u8) -> f32 {
        self.data[(usize::from(a) << 8) | usize::from(b)]
    }
}
1312
1313struct TreeGrid {
1314    grid_width: u32,
1315    grid_height: u32,
1316    offset_x: i32,
1317    offset_y: i32,
1318    chunk_size: u32,
1319    rtrees: Vec<RwLock<RTree<[i32; 2]>>>,
1320}
1321
1322// This is a grid of rtrees
1323// The idea is that most pixels after the first couple steps will have their neighbors close by
1324impl TreeGrid {
1325    pub fn new(width: u32, height: u32, chunk_size: u32, offset_x: u32, offset_y: u32) -> Self {
1326        let mut rtrees: Vec<RwLock<RTree<[i32; 2]>>> = Vec::new();
1327        let grid_width = max((width + chunk_size - 1) / chunk_size, 1);
1328        let grid_height = max((height + chunk_size - 1) / chunk_size, 1);
1329        rtrees.resize_with((grid_width * grid_height) as usize, || {
1330            RwLock::new(RTree::new())
1331        });
1332        Self {
1333            rtrees,
1334            grid_width,
1335            grid_height,
1336            offset_x: offset_x as i32,
1337            offset_y: offset_y as i32,
1338            chunk_size,
1339        }
1340    }
1341
1342    #[inline]
1343    fn get_tree_index(&self, x: u32, y: u32) -> usize {
1344        (x * self.grid_height + y) as usize
1345    }
1346
1347    pub fn insert(&self, x: i32, y: i32) {
1348        let my_tree_index = self.get_tree_index(
1349            ((x + self.offset_x) as u32) / self.chunk_size,
1350            ((y + self.offset_y) as u32) / self.chunk_size,
1351        );
1352        self.rtrees[my_tree_index].write().unwrap().insert([x, y]);
1353    }
1354
1355    pub fn clone_into_new_tree_grid(&self, other: &Self) {
1356        for tree in &self.rtrees {
1357            for coord in tree.read().unwrap().iter() {
1358                other.insert((*coord)[0], (*coord)[1]);
1359            }
1360        }
1361    }
1362
1363    pub fn get_k_nearest_neighbors(
1364        &self,
1365        x: u32,
1366        y: u32,
1367        k: usize,
1368        result: &mut Vec<SignedCoord2D>,
1369    ) {
1370        let offset_x = x as i32 + self.offset_x;
1371        let offset_y = y as i32 + self.offset_y;
1372
1373        let chunk_x = offset_x / self.chunk_size as i32;
1374        let chunk_y = offset_y / self.chunk_size as i32;
1375
1376        struct ChunkSearchInfo {
1377            x: i32,
1378            y: i32,
1379            center: bool,
1380            closest_point_on_boundary_x: i64,
1381            closest_point_on_boundary_y: i64,
1382        }
1383
1384        // Assume that all k nearest neighbors are in these cells
1385        // it looks like we are rarely wrong once enough pixels are filled in
1386        let places_to_look = [
1387            ChunkSearchInfo {
1388                x: chunk_x,
1389                y: chunk_y,
1390                center: true,
1391                closest_point_on_boundary_x: 0,
1392                closest_point_on_boundary_y: 0,
1393            },
1394            ChunkSearchInfo {
1395                x: chunk_x + 1,
1396                y: chunk_y,
1397                center: false,
1398                closest_point_on_boundary_x: ((chunk_x + 1) * self.chunk_size as i32
1399                    - self.offset_x) as i64,
1400                closest_point_on_boundary_y: y as i64,
1401            },
1402            ChunkSearchInfo {
1403                x: chunk_x - 1,
1404                y: chunk_y,
1405                center: false,
1406                closest_point_on_boundary_x: (chunk_x * self.chunk_size as i32 - self.offset_x)
1407                    as i64,
1408                closest_point_on_boundary_y: y as i64,
1409            },
1410            ChunkSearchInfo {
1411                x: chunk_x,
1412                y: chunk_y - 1,
1413                center: false,
1414                closest_point_on_boundary_x: x as i64,
1415                closest_point_on_boundary_y: (chunk_y * self.chunk_size as i32 - self.offset_y)
1416                    as i64,
1417            },
1418            ChunkSearchInfo {
1419                x: chunk_x,
1420                y: chunk_y + 1,
1421                center: false,
1422                closest_point_on_boundary_x: x as i64,
1423                closest_point_on_boundary_y: ((chunk_y + 1) * self.chunk_size as i32
1424                    - self.offset_y) as i64,
1425            },
1426            ChunkSearchInfo {
1427                x: chunk_x + 1,
1428                y: chunk_y + 1,
1429                center: false,
1430                closest_point_on_boundary_x: ((chunk_x + 1) * self.chunk_size as i32
1431                    - self.offset_x) as i64,
1432                closest_point_on_boundary_y: ((chunk_y + 1) * self.chunk_size as i32
1433                    - self.offset_y) as i64,
1434            },
1435            ChunkSearchInfo {
1436                x: chunk_x - 1,
1437                y: chunk_y + 1,
1438                center: false,
1439                closest_point_on_boundary_x: (chunk_x * self.chunk_size as i32 - self.offset_x)
1440                    as i64,
1441                closest_point_on_boundary_y: ((chunk_y + 1) * self.chunk_size as i32
1442                    - self.offset_y) as i64,
1443            },
1444            ChunkSearchInfo {
1445                x: chunk_x + 1,
1446                y: chunk_y - 1,
1447                center: false,
1448                closest_point_on_boundary_x: ((chunk_x + 1) * self.chunk_size as i32
1449                    - self.offset_x) as i64,
1450                closest_point_on_boundary_y: (chunk_y * self.chunk_size as i32 - self.offset_y)
1451                    as i64,
1452            },
1453            ChunkSearchInfo {
1454                x: chunk_x - 1,
1455                y: chunk_y - 1,
1456                center: false,
1457                closest_point_on_boundary_x: (chunk_x * self.chunk_size as i32 - self.offset_x)
1458                    as i64,
1459                closest_point_on_boundary_y: (chunk_y * self.chunk_size as i32 - self.offset_y)
1460                    as i64,
1461            },
1462        ];
1463        // Note locking all of them at different times seems to be the best way
1464        // Naively trying to lock all at once could easily result in deadlocks
1465        let mut tmp_result: Vec<(i32, i32, i64)> = Vec::with_capacity(k * 9);
1466        result.clear();
1467        result.reserve(k);
1468
1469        // an upper bound is good enough here
1470        let mut upper_bound_kth_best_squared_distance = i64::max_value();
1471        for place_to_look in places_to_look.iter() {
1472            if place_to_look.x >= 0
1473                && place_to_look.x < self.grid_width as i32
1474                && place_to_look.y >= 0
1475                && place_to_look.y < self.grid_height as i32
1476            {
1477                let is_center = place_to_look.center;
1478
1479                // a tiny optimization to help us throw far away neighbors
1480                // saves us a decent amount of reads
1481                if !is_center {
1482                    let squared_distance_to_closest_possible_point_on_chunk = (x as i64
1483                        - place_to_look.closest_point_on_boundary_x)
1484                        * (x as i64 - place_to_look.closest_point_on_boundary_x)
1485                        + (y as i64 - place_to_look.closest_point_on_boundary_y)
1486                            * (y as i64 - place_to_look.closest_point_on_boundary_y);
1487
1488                    if squared_distance_to_closest_possible_point_on_chunk
1489                        > upper_bound_kth_best_squared_distance
1490                    {
1491                        continue;
1492                    }
1493                }
1494
1495                let my_tree_index =
1496                    self.get_tree_index(place_to_look.x as u32, place_to_look.y as u32);
1497                let my_rtree = &self.rtrees[my_tree_index];
1498                tmp_result.extend(
1499                    my_rtree
1500                        .read()
1501                        .unwrap()
1502                        .nearest_neighbor_iter(&[x as i32, y as i32])
1503                        .take(k)
1504                        .map(|a| {
1505                            (
1506                                (*a)[0],
1507                                (*a)[1],
1508                                ((*a)[0] as i64 - x as i64) * ((*a)[0] as i64 - x as i64)
1509                                    + ((*a)[1] as i64 - y as i64) * ((*a)[1] as i64 - y as i64),
1510                            )
1511                        }),
1512                );
1513
1514                // this isn't really the kth best distance but it's an okay approximation
1515                if tmp_result.len() >= k {
1516                    let furthest_dist_for_chunk = tmp_result[tmp_result.len() - 1].2;
1517                    if furthest_dist_for_chunk < upper_bound_kth_best_squared_distance {
1518                        upper_bound_kth_best_squared_distance = furthest_dist_for_chunk;
1519                    }
1520                }
1521            }
1522        }
1523        tmp_result.sort_by_key(|k| k.2);
1524        result.extend(
1525            tmp_result
1526                .iter()
1527                .take(k)
1528                .map(|a| SignedCoord2D::from(a.0, a.1)),
1529        );
1530    }
1531}
1532
1533#[inline]
1534fn check_coord_validity(
1535    coord: SignedCoord2D,
1536    map_id: MapId,
1537    example_maps: &[ImageBuffer<'_>],
1538    mask: &SamplingMethod,
1539) -> bool {
1540    if !example_maps[map_id.0 as usize].is_in_bounds(coord) {
1541        return false;
1542    }
1543
1544    match mask {
1545        SamplingMethod::All => true,
1546        SamplingMethod::Image(ref img) => img[(coord.x as u32, coord.y as u32)][0] != 0,
1547        SamplingMethod::Ignore => unreachable!(),
1548    }
1549}
1550
1551//get all the example images from a single pyramid level
1552fn get_single_example_level<'a>(
1553    example_maps_pyramid: &'a [ImagePyramid],
1554    valid_samples_mask: &[SamplingMethod],
1555    pyramid_level: usize,
1556) -> Vec<ImageBuffer<'a>> {
1557    example_maps_pyramid
1558        .iter()
1559        .enumerate()
1560        .filter(|&(i, _)| !valid_samples_mask[i].is_ignore())
1561        .map(|(_, a)| ImageBuffer::from(&a.pyramid[pyramid_level]))
1562        .collect()
1563}
1564
1565//get all the guide images from a single pyramid level
1566fn get_single_guide_level(
1567    guides_pyramid: &Option<GuidesPyramidStruct>,
1568    pyramid_level: usize,
1569) -> Option<GuidesStruct<'_>> {
1570    guides_pyramid
1571        .as_ref()
1572        .map(|guides_pyr| guides_pyr.to_guides_struct(pyramid_level))
1573}