use ndarray::{Array2, ArrayView2};
use rand::{rngs::SmallRng, Rng, SeedableRng};
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiamondSquare {
iteration: u32,
roughness: f32,
seed: u64,
}
impl Default for DiamondSquare {
fn default() -> Self {
unsafe { Self { iteration: 2, seed: 42, roughness: 1.1 } }
}
}
impl DiamondSquare {
pub fn get_iteration(&self) -> u32 {
self.iteration
}
pub fn set_iteration(&mut self, iteration: u32) {
assert!(iteration < 30, "iteration too high, out of memory");
self.iteration = iteration;
}
pub fn with_iteration(mut self, iteration: u32) -> Self {
self.set_iteration(iteration);
self
}
pub fn get_roughness(&self) -> f32 {
self.roughness
}
pub fn set_roughness(&mut self, roughness: f32) {
assert!(roughness >= 1.0, "roughness must be greater than 1.0");
self.roughness = roughness;
}
pub fn with_roughness(mut self, roughness: f32) -> Self {
self.set_roughness(roughness);
self
}
pub fn get_seed(&self) -> u64 {
self.seed
}
pub fn set_seed(&mut self, seed: u64) {
self.seed = seed;
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.set_seed(seed);
self
}
}
impl DiamondSquare {
pub fn enlarge(&self, matrix: ArrayView2<f32>) -> Array2<f32> {
let mut rng = SmallRng::seed_from_u64(self.seed);
let mut step = 2usize.pow(self.iteration);
unsafe {
let (mut output, w, h) = self.enlarge_map(matrix);
for iteration in 0..self.iteration {
println!("Iteration: {}, step: {} in ({}, {})", iteration + 1, step, h, w);
let half = step / 2;
for j in (half..h).step_by(step) {
for i in (half..w).step_by(step) {
let lu = *output.uget([i - half, j - half]);
let ru = *output.uget([(i + half) % w, j - half]);
let ld = *output.uget([i - half, (j + half) % h]);
let rd = *output.uget([(i + half) % w, (j + half) % h]);
*output.uget_mut([i, j]) = self.random_average(&mut rng, [lu, ru, ld, rd]);
}
}
for j in (half..h).step_by(step) {
for i in (0..w).step_by(step) {
let l = *output.uget([(w + i - half) % w, j]);
let r = *output.uget([(0 + i + half), j]);
let u = *output.uget([i, (h + j - half) % h]);
let d = *output.uget([i, (0 + j + half) % h]);
*output.uget_mut([i, j]) = self.random_average(&mut rng, [l, r, u, d]);
}
}
for j in (0..h).step_by(step) {
for i in (half..w).step_by(step) {
let l = *output.get([i - half, j]).unwrap();
let r = *output.get([(i + half) % w, j]).unwrap();
let u = *output.get([i, (h + j - half) % h]).unwrap();
let d = *output.get([i, (0 + j + half) % h]).unwrap();
*output.uget_mut([i, j]) = self.random_average(&mut rng, [l, r, u, d]);
}
}
step = half;
}
output
}
}
unsafe fn enlarge_map(&self, matrix: ArrayView2<f32>) -> (Array2<f32>, usize, usize) {
let step = 2usize.pow(self.iteration);
let width = matrix.shape().get_unchecked(0) * step;
let height = matrix.shape().get_unchecked(1) * step;
let mut output = Array2::zeros((width, height));
for ((x, y), v) in matrix.indexed_iter() {
*output.uget_mut((x * step, y * step)) = *v;
}
(output, width, height)
}
fn random_average(&self, rng: &mut SmallRng, vs: [f32; 4]) -> f32 {
let avg = vs.iter().sum::<f32>() / 4.0;
let r_roughness = self.roughness.recip();
avg * rng.gen_range(r_roughness..self.roughness)
}
}