use std::collections::HashSet;
use crate::{math::Point, FloatNumber};
/// Strategy used by [`SeedGenerator::generate`] to choose initial seed pixels.
///
/// The enum is fieldless, so it is `Copy`, and it also derives `Eq` and `Hash`
/// so a strategy can be stored in hash-based collections or used as a map key.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub enum SeedGenerator {
    /// Place seeds on a regular square grid laid over the image (the default).
    #[default]
    RegularGrid,
}
impl SeedGenerator {
    /// Picks up to `k` seed pixel indices for an image of `width` x `height`.
    ///
    /// Only indices whose `mask` entry is `true` are ever returned.
    ///
    /// * `k == 0` yields an empty set.
    /// * `k` larger than the number of pixels yields every masked-in index.
    /// * Otherwise the choice is delegated to the strategy named by `self`.
    ///
    /// # Panics
    ///
    /// Panics if `pixels` and `mask` differ in length.
    #[must_use]
    pub fn generate<T, const N: usize>(
        &self,
        width: usize,
        height: usize,
        pixels: &[Point<T, N>],
        mask: &[bool],
        k: usize,
    ) -> HashSet<usize>
    where
        T: FloatNumber,
    {
        assert_eq!(
            pixels.len(),
            mask.len(),
            "pixels and mask must have the same length"
        );

        let total = pixels.len();
        match k {
            0 => HashSet::new(),
            // More seeds requested than pixels available: take every enabled pixel.
            _ if k > total => mask
                .iter()
                .enumerate()
                .filter_map(|(position, &enabled)| enabled.then_some(position))
                .collect(),
            _ => match self {
                Self::RegularGrid => regular_grid(width, height, pixels, mask, k),
            },
        }
    }
}
/// Chooses up to `k` seeds by sampling the image on a regular square grid.
///
/// The grid spacing is `round(sqrt(pixels.len() / k))`, clamped to at least 1.
/// The first sample sits half a step in from the top-left corner on both axes.
/// Grid positions are visited row by row, with `index = x + y * width`. A position
/// becomes a seed only if it addresses an existing pixel and its `mask` entry
/// is `true`. Sampling stops as soon as `k` seeds have been collected.
///
/// Masked-out positions are skipped rather than replaced, so the result may
/// hold fewer than `k` indices. Returns an empty set when `k` is zero.
#[inline]
#[must_use]
fn regular_grid<T, const N: usize>(
    width: usize,
    height: usize,
    pixels: &[Point<T, N>],
    mask: &[bool],
    k: usize,
) -> HashSet<usize>
where
    T: FloatNumber,
{
    // The step computation divides by `k`; nothing to place anyway.
    if k == 0 {
        return HashSet::new();
    }

    let step = (T::from_usize(pixels.len()) / T::from_usize(k))
        .sqrt()
        .round()
        .trunc_to_usize()
        .max(1);
    let half = step / 2;

    let mut seeds = HashSet::with_capacity(k);
    'outer: for y in (half..height).step_by(step) {
        for x in (half..width).step_by(step) {
            let index = x + y * width;
            // Bounds-check before reading `mask`. If `width * height` exceeds the
            // pixel buffer, `index` can point past its end; such positions are
            // skipped instead of panicking on an out-of-range index.
            if index < pixels.len() && mask.get(index) == Some(&true) {
                seeds.insert(index);
                if seeds.len() == k {
                    break 'outer;
                }
            }
        }
    }
    seeds
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Builds a `cols * rows` buffer of zero-valued points. Seed placement only
    /// depends on the buffer length, so the point values do not matter.
    #[must_use]
    fn sample_points<T>(cols: usize, rows: usize) -> Vec<Point<T, 2>>
    where
        T: FloatNumber,
    {
        vec![[T::zero(); 2]; cols * rows]
    }

    #[test]
    fn test_default() {
        let generator = SeedGenerator::default();
        assert_eq!(generator, SeedGenerator::RegularGrid);
    }

    // The 12x9 image has 108 pixels. The grid step is round(sqrt(108 / k)) and
    // the first sample sits at step / 2 on both axes:
    //   k = 1: step 10, offset 5 -> (5, 5)                          -> 65
    //   k = 2: step 7,  offset 3 -> x in {3, 10}, y = 3             -> 39, 46
    //   k = 4: step 5,  offset 2 -> x, y in {2, 7}                  -> 26, 31, 86, 91
    //   k = 6: step 4,  offset 2 -> x in {2, 6, 10}, y in {2, 6}    -> 26, 30, 34, 74, 78, 82
    #[rstest]
    #[case(0, vec![])]
    #[case(1, vec![65])]
    #[case(2, vec![39, 46])]
    #[case(4, vec![26, 31, 86, 91])]
    #[case(6, vec![26, 30, 34, 74, 78, 82])]
    fn test_regular_grid_generate(#[case] k: usize, #[case] expected: Vec<usize>) {
        let width = 12;
        let height = 9;
        let points = sample_points::<f64>(width, height);
        let mask = vec![true; width * height];
        let generator = SeedGenerator::RegularGrid;
        let actual = generator.generate(width, height, &points, &mask, k);
        assert_eq!(actual.len(), expected.len());
        assert_eq!(actual, HashSet::from_iter(expected));
    }

    #[test]
    fn test_generate_zero_seeds() {
        let width = 4;
        let height = 3;
        let points = sample_points::<f64>(width, height);
        let mask = vec![true; width * height];
        let generator = SeedGenerator::default();
        let actual = generator.generate(width, height, &points, &mask, 0);
        assert_eq!(actual.len(), 0);
    }

    #[test]
    fn test_generate_too_many_seeds() {
        let width = 4;
        let height = 3;
        let points = sample_points::<f64>(width, height);
        let mask = vec![true; width * height];
        let generator = SeedGenerator::default();
        let actual = generator.generate(width, height, &points, &mask, 13);
        assert_eq!(actual.len(), 12);
    }

    /// When more seeds are requested than there are pixels, only the
    /// masked-in pixels are returned.
    #[test]
    fn test_generate_too_many_seeds_with_mask() {
        let width = 4;
        let height = 3;
        let points = sample_points::<f64>(width, height);
        let mut mask = vec![true; width * height];
        mask[5] = false;
        let generator = SeedGenerator::default();
        let actual = generator.generate(width, height, &points, &mask, 13);
        let expected: HashSet<usize> = (0..width * height).filter(|&i| i != 5).collect();
        assert_eq!(actual, expected);
    }

    /// Requesting exactly one seed per pixel gives a grid step of 1 and an
    /// offset of 0, so every pixel becomes a seed.
    #[test]
    fn test_generate_seeds_equal_to_pixels() {
        let width = 4;
        let height = 3;
        let points = sample_points::<f64>(width, height);
        let mask = vec![true; width * height];
        let generator = SeedGenerator::default();
        let actual = generator.generate(width, height, &points, &mask, width * height);
        let expected: HashSet<usize> = (0..width * height).collect();
        assert_eq!(actual, expected);
    }

    // 4x3 image, k = 2: step round(sqrt(6)) = 2, offset 1. The grid visits
    // indices 5 and 7; index 5 is masked out, so only 7 remains.
    #[test]
    fn test_generate_with_mask() {
        let width = 4;
        let height = 3;
        let points = sample_points::<f64>(width, height);
        let mask = vec![
            true, true, true, true, true, false, true, true, true, true, true, true,
        ];
        let generator = SeedGenerator::default();
        let actual = generator.generate(width, height, &points, &mask, 2);
        assert_eq!(actual.len(), 1);
        assert_eq!(actual, HashSet::from_iter([7]));
    }
}