// diamond_square/ds/mod.rs

use ndarray::{Array2, ArrayView2};
use rand::{rngs::SmallRng, Rng, SeedableRng};

/// Configuration for enlarging a 2-D grid with the diamond-square algorithm.
///
/// Every input sample is placed `2^iteration` cells apart in the output grid, and the cells
/// in between are filled by repeated diamond and square steps. Each filled cell is the
/// average of four neighbours multiplied by a random factor drawn from
/// `[1 / roughness, roughness)`. The random number generator is seeded from `seed`, so the
/// same configuration and input always produce the same output.
///
/// With the `serde` feature enabled, deserialisation fills the fields directly and does not
/// run the checks performed by the setters.
///
/// # Examples
///
/// ```
/// # use diamond_square::DiamondSquare;
/// let cfg = DiamondSquare::default().with_iteration(3).with_roughness(1.2).with_seed(7);
/// assert_eq!(cfg.get_iteration(), 3);
/// assert_eq!(cfg.get_roughness(), 1.2);
/// assert_eq!(cfg.get_seed(), 7);
/// ```
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiamondSquare {
    /// Number of subdivision passes; the grid is enlarged by `2^iteration` along each axis.
    iteration: u32,
    /// Upper bound of the random multiplier (the lower bound is its reciprocal).
    roughness: f32,
    /// Seed of the random number generator.
    seed: u64,
}

29impl Default for DiamondSquare {
30    fn default() -> Self {
31        unsafe { Self { iteration: 2, seed: 42, roughness: 1.1 } }
32    }
33}
34
35impl DiamondSquare {
36    /// Generate a grid using diamond square algorithm
37    ///
38    /// # Arguments
39    ///
40    /// * `rng`:
41    /// * `vs`:
42    ///
43    /// returns: f32
44    ///
45    /// # Examples
46    ///
47    /// ```
48    /// # use diamond_square::DiamondSquare;
49    /// let mut cfg = DiamondSquare::default();
50    /// assert_eq!(cfg.get_iteration(), 2);
51    /// ```
52    pub fn get_iteration(&self) -> u32 {
53        self.iteration
54    }
55    /// Generate a grid using diamond square algorithm
56    ///
57    /// # Arguments
58    ///
59    /// * `rng`:
60    /// * `vs`:
61    ///
62    /// returns: f32
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// # use diamond_square::DiamondSquare;
68    /// let mut cfg = DiamondSquare::default();
69    /// cfg.set_iteration(5);
70    /// assert_eq!(cfg.get_iteration(), 5);
71    /// ```
72    pub fn set_iteration(&mut self, iteration: u32) {
73        assert!(iteration < 30, "iteration too high, out of memory");
74        self.iteration = iteration;
75    }
76    /// Generate a grid using diamond square algorithm
77    ///
78    /// # Arguments
79    ///
80    /// * `rng`:
81    /// * `vs`:
82    ///
83    /// returns: f32
84    ///
85    /// # Examples
86    ///
87    /// ```
88    /// # use diamond_square::DiamondSquare;
89    /// let mut cfg = DiamondSquare::default().with_iteration(5);
90    /// assert_eq!(cfg.get_iteration(), 5);
91    /// ```
92    pub fn with_iteration(mut self, iteration: u32) -> Self {
93        self.set_iteration(iteration);
94        self
95    }
96    /// Generate a grid using diamond square algorithm
97    ///
98    /// # Arguments
99    ///
100    /// * `rng`:
101    /// * `vs`:
102    ///
103    /// returns: f32
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// # use diamond_square::DiamondSquare;
109    /// let mut cfg = DiamondSquare::default();
110    /// assert_eq!(cfg.get_roughness(), 1.1);
111    /// ```
112    pub fn get_roughness(&self) -> f32 {
113        self.roughness
114    }
115    /// Generate a grid using diamond square algorithm
116    ///
117    /// # Arguments
118    ///
119    /// * `rng`:
120    /// * `vs`:
121    ///
122    /// returns: f32
123    ///
124    /// # Examples
125    ///
126    /// ```
127    /// # use diamond_square::DiamondSquare;
128    /// let mut cfg = DiamondSquare::default();
129    /// cfg.set_roughness(1.5);
130    /// assert_eq!(cfg.get_roughness(), 1.5);
131    /// ```
132    pub fn set_roughness(&mut self, roughness: f32) {
133        assert!(roughness >= 1.0, "roughness must be greater than 1.0");
134        self.roughness = roughness;
135    }
136    /// Generate a grid using diamond square algorithm
137    ///
138    /// # Arguments
139    ///
140    /// * `rng`:
141    /// * `vs`:
142    ///
143    /// returns: f32
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// # use diamond_square::DiamondSquare;
149    /// let mut cfg = DiamondSquare::default().with_roughness(1.5);
150    /// assert_eq!(cfg.get_roughness(), 1.5);
151    /// ```
152    pub fn with_roughness(mut self, roughness: f32) -> Self {
153        self.set_roughness(roughness);
154        self
155    }
156    /// Generate a grid using diamond square algorithm
157    ///
158    /// # Arguments
159    ///
160    /// * `rng`:
161    /// * `vs`:
162    ///
163    /// returns: f32
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// # use diamond_square::DiamondSquare;
169    /// let mut cfg = DiamondSquare::default();
170    /// assert_eq!(cfg.get_seed(), 42);
171    /// ```
172    pub fn get_seed(&self) -> u64 {
173        self.seed
174    }
175    /// Generate a grid using diamond square algorithm
176    ///
177    /// # Arguments
178    ///
179    /// * `rng`:
180    /// * `vs`:
181    ///
182    /// returns: f32
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// # use diamond_square::DiamondSquare;
188    /// let mut cfg = DiamondSquare::default();
189    /// cfg.set_seed(0);
190    /// assert_eq!(cfg.get_seed(), 0);
191    /// ```
192    pub fn set_seed(&mut self, seed: u64) {
193        self.seed = seed;
194    }
195    /// Generate a grid using diamond square algorithm
196    ///
197    /// # Arguments
198    ///
199    /// * `rng`:
200    /// * `vs`:
201    ///
202    /// returns: f32
203    ///
204    /// # Examples
205    ///
206    /// ```
207    /// # use diamond_square::DiamondSquare;
208    /// let mut cfg = DiamondSquare::default().with_seed(0);
209    /// assert_eq!(cfg.get_seed(), 0);
210    /// ```
211    pub fn with_seed(mut self, seed: u64) -> Self {
212        self.set_seed(seed);
213        self
214    }
215}
216
217impl DiamondSquare {
218    /// Generate a grid using diamond square algorithm
219    ///
220    /// # Arguments
221    ///
222    /// * `rng`:
223    /// * `vs`:
224    ///
225    /// returns: f32
226    ///
227    /// # Examples
228    ///
229    /// ```
230    /// use ndarray::Array2;
231    /// ```
232    pub fn enlarge(&self, matrix: ArrayView2<f32>) -> Array2<f32> {
233        let mut rng = SmallRng::seed_from_u64(self.seed);
234        let mut step = 2usize.pow(self.iteration);
235        unsafe {
236            let (mut output, w, h) = self.enlarge_map(matrix);
237            for iteration in 0..self.iteration {
238                println!("Iteration: {}, step: {} in ({}, {})", iteration + 1, step, h, w);
239                // diamond step
240                let half = step / 2;
241                for j in (half..h).step_by(step) {
242                    for i in (half..w).step_by(step) {
243                        let lu = *output.uget([i - half, j - half]);
244                        let ru = *output.uget([(i + half) % w, j - half]);
245                        let ld = *output.uget([i - half, (j + half) % h]);
246                        let rd = *output.uget([(i + half) % w, (j + half) % h]);
247                        *output.uget_mut([i, j]) = self.random_average(&mut rng, [lu, ru, ld, rd]);
248                    }
249                }
250                // square step even rows
251                for j in (half..h).step_by(step) {
252                    for i in (0..w).step_by(step) {
253                        let l = *output.uget([(w + i - half) % w, j]);
254                        let r = *output.uget([(0 + i + half), j]);
255                        let u = *output.uget([i, (h + j - half) % h]);
256                        let d = *output.uget([i, (0 + j + half) % h]);
257                        *output.uget_mut([i, j]) = self.random_average(&mut rng, [l, r, u, d]);
258                    }
259                }
260                // square step old rows
261                for j in (0..h).step_by(step) {
262                    for i in (half..w).step_by(step) {
263                        let l = *output.get([i - half, j]).unwrap();
264                        let r = *output.get([(i + half) % w, j]).unwrap();
265                        let u = *output.get([i, (h + j - half) % h]).unwrap();
266                        let d = *output.get([i, (0 + j + half) % h]).unwrap();
267                        *output.uget_mut([i, j]) = self.random_average(&mut rng, [l, r, u, d]);
268                    }
269                }
270                step = half;
271            }
272            // drop last colomn and last row
273            output
274        }
275    }
276
277    unsafe fn enlarge_map(&self, matrix: ArrayView2<f32>) -> (Array2<f32>, usize, usize) {
278        let step = 2usize.pow(self.iteration);
279        let width = matrix.shape().get_unchecked(0) * step;
280        let height = matrix.shape().get_unchecked(1) * step;
281        let mut output = Array2::zeros((width, height));
282        // fill the corners
283        for ((x, y), v) in matrix.indexed_iter() {
284            *output.uget_mut((x * step, y * step)) = *v;
285        }
286        (output, width, height)
287    }
288
289    /// Calculate the average of the given values with random multiplier
290    fn random_average(&self, rng: &mut SmallRng, vs: [f32; 4]) -> f32 {
291        let avg = vs.iter().sum::<f32>() / 4.0;
292        let r_roughness = self.roughness.recip();
293        avg * rng.gen_range(r_roughness..self.roughness)
294    }
295}