diamond_square/ds/mod.rs
1use ndarray::{Array2, ArrayView2};
2use rand::{rngs::SmallRng, Rng, SeedableRng};
3
/// Configuration for upscaling a grid with the diamond-square algorithm.
///
/// The value only stores parameters; [`DiamondSquare::enlarge`] performs the
/// actual upscaling. Use the `with_*` builders or `set_*` setters to change the
/// parameters; they validate their input.
///
/// # Examples
///
/// ```
/// # use diamond_square::DiamondSquare;
/// let cfg = DiamondSquare::default()
///     .with_iteration(3)
///     .with_roughness(1.2)
///     .with_seed(7);
/// assert_eq!(cfg.get_iteration(), 3);
/// assert_eq!(cfg.get_seed(), 7);
/// ```
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiamondSquare {
    /// Number of refinement iterations; the enlarged grid is `2^iteration`
    /// times larger than the input along each axis.
    iteration: u32,
    /// Bound of the random factor applied to interpolated values: factors are
    /// drawn from `1 / roughness .. roughness`.
    roughness: f32,
    /// Seed of the random number generator, making the output reproducible.
    seed: u64,
}
28
29impl Default for DiamondSquare {
30 fn default() -> Self {
31 unsafe { Self { iteration: 2, seed: 42, roughness: 1.1 } }
32 }
33}
34
35impl DiamondSquare {
36 /// Generate a grid using diamond square algorithm
37 ///
38 /// # Arguments
39 ///
40 /// * `rng`:
41 /// * `vs`:
42 ///
43 /// returns: f32
44 ///
45 /// # Examples
46 ///
47 /// ```
48 /// # use diamond_square::DiamondSquare;
49 /// let mut cfg = DiamondSquare::default();
50 /// assert_eq!(cfg.get_iteration(), 2);
51 /// ```
52 pub fn get_iteration(&self) -> u32 {
53 self.iteration
54 }
55 /// Generate a grid using diamond square algorithm
56 ///
57 /// # Arguments
58 ///
59 /// * `rng`:
60 /// * `vs`:
61 ///
62 /// returns: f32
63 ///
64 /// # Examples
65 ///
66 /// ```
67 /// # use diamond_square::DiamondSquare;
68 /// let mut cfg = DiamondSquare::default();
69 /// cfg.set_iteration(5);
70 /// assert_eq!(cfg.get_iteration(), 5);
71 /// ```
72 pub fn set_iteration(&mut self, iteration: u32) {
73 assert!(iteration < 30, "iteration too high, out of memory");
74 self.iteration = iteration;
75 }
76 /// Generate a grid using diamond square algorithm
77 ///
78 /// # Arguments
79 ///
80 /// * `rng`:
81 /// * `vs`:
82 ///
83 /// returns: f32
84 ///
85 /// # Examples
86 ///
87 /// ```
88 /// # use diamond_square::DiamondSquare;
89 /// let mut cfg = DiamondSquare::default().with_iteration(5);
90 /// assert_eq!(cfg.get_iteration(), 5);
91 /// ```
92 pub fn with_iteration(mut self, iteration: u32) -> Self {
93 self.set_iteration(iteration);
94 self
95 }
96 /// Generate a grid using diamond square algorithm
97 ///
98 /// # Arguments
99 ///
100 /// * `rng`:
101 /// * `vs`:
102 ///
103 /// returns: f32
104 ///
105 /// # Examples
106 ///
107 /// ```
108 /// # use diamond_square::DiamondSquare;
109 /// let mut cfg = DiamondSquare::default();
110 /// assert_eq!(cfg.get_roughness(), 1.1);
111 /// ```
112 pub fn get_roughness(&self) -> f32 {
113 self.roughness
114 }
115 /// Generate a grid using diamond square algorithm
116 ///
117 /// # Arguments
118 ///
119 /// * `rng`:
120 /// * `vs`:
121 ///
122 /// returns: f32
123 ///
124 /// # Examples
125 ///
126 /// ```
127 /// # use diamond_square::DiamondSquare;
128 /// let mut cfg = DiamondSquare::default();
129 /// cfg.set_roughness(1.5);
130 /// assert_eq!(cfg.get_roughness(), 1.5);
131 /// ```
132 pub fn set_roughness(&mut self, roughness: f32) {
133 assert!(roughness >= 1.0, "roughness must be greater than 1.0");
134 self.roughness = roughness;
135 }
136 /// Generate a grid using diamond square algorithm
137 ///
138 /// # Arguments
139 ///
140 /// * `rng`:
141 /// * `vs`:
142 ///
143 /// returns: f32
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// # use diamond_square::DiamondSquare;
149 /// let mut cfg = DiamondSquare::default().with_roughness(1.5);
150 /// assert_eq!(cfg.get_roughness(), 1.5);
151 /// ```
152 pub fn with_roughness(mut self, roughness: f32) -> Self {
153 self.set_roughness(roughness);
154 self
155 }
156 /// Generate a grid using diamond square algorithm
157 ///
158 /// # Arguments
159 ///
160 /// * `rng`:
161 /// * `vs`:
162 ///
163 /// returns: f32
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// # use diamond_square::DiamondSquare;
169 /// let mut cfg = DiamondSquare::default();
170 /// assert_eq!(cfg.get_seed(), 42);
171 /// ```
172 pub fn get_seed(&self) -> u64 {
173 self.seed
174 }
175 /// Generate a grid using diamond square algorithm
176 ///
177 /// # Arguments
178 ///
179 /// * `rng`:
180 /// * `vs`:
181 ///
182 /// returns: f32
183 ///
184 /// # Examples
185 ///
186 /// ```
187 /// # use diamond_square::DiamondSquare;
188 /// let mut cfg = DiamondSquare::default();
189 /// cfg.set_seed(0);
190 /// assert_eq!(cfg.get_seed(), 0);
191 /// ```
192 pub fn set_seed(&mut self, seed: u64) {
193 self.seed = seed;
194 }
195 /// Generate a grid using diamond square algorithm
196 ///
197 /// # Arguments
198 ///
199 /// * `rng`:
200 /// * `vs`:
201 ///
202 /// returns: f32
203 ///
204 /// # Examples
205 ///
206 /// ```
207 /// # use diamond_square::DiamondSquare;
208 /// let mut cfg = DiamondSquare::default().with_seed(0);
209 /// assert_eq!(cfg.get_seed(), 0);
210 /// ```
211 pub fn with_seed(mut self, seed: u64) -> Self {
212 self.set_seed(seed);
213 self
214 }
215}
216
217impl DiamondSquare {
218 /// Generate a grid using diamond square algorithm
219 ///
220 /// # Arguments
221 ///
222 /// * `rng`:
223 /// * `vs`:
224 ///
225 /// returns: f32
226 ///
227 /// # Examples
228 ///
229 /// ```
230 /// use ndarray::Array2;
231 /// ```
232 pub fn enlarge(&self, matrix: ArrayView2<f32>) -> Array2<f32> {
233 let mut rng = SmallRng::seed_from_u64(self.seed);
234 let mut step = 2usize.pow(self.iteration);
235 unsafe {
236 let (mut output, w, h) = self.enlarge_map(matrix);
237 for iteration in 0..self.iteration {
238 println!("Iteration: {}, step: {} in ({}, {})", iteration + 1, step, h, w);
239 // diamond step
240 let half = step / 2;
241 for j in (half..h).step_by(step) {
242 for i in (half..w).step_by(step) {
243 let lu = *output.uget([i - half, j - half]);
244 let ru = *output.uget([(i + half) % w, j - half]);
245 let ld = *output.uget([i - half, (j + half) % h]);
246 let rd = *output.uget([(i + half) % w, (j + half) % h]);
247 *output.uget_mut([i, j]) = self.random_average(&mut rng, [lu, ru, ld, rd]);
248 }
249 }
250 // square step even rows
251 for j in (half..h).step_by(step) {
252 for i in (0..w).step_by(step) {
253 let l = *output.uget([(w + i - half) % w, j]);
254 let r = *output.uget([(0 + i + half), j]);
255 let u = *output.uget([i, (h + j - half) % h]);
256 let d = *output.uget([i, (0 + j + half) % h]);
257 *output.uget_mut([i, j]) = self.random_average(&mut rng, [l, r, u, d]);
258 }
259 }
260 // square step old rows
261 for j in (0..h).step_by(step) {
262 for i in (half..w).step_by(step) {
263 let l = *output.get([i - half, j]).unwrap();
264 let r = *output.get([(i + half) % w, j]).unwrap();
265 let u = *output.get([i, (h + j - half) % h]).unwrap();
266 let d = *output.get([i, (0 + j + half) % h]).unwrap();
267 *output.uget_mut([i, j]) = self.random_average(&mut rng, [l, r, u, d]);
268 }
269 }
270 step = half;
271 }
272 // drop last colomn and last row
273 output
274 }
275 }
276
277 unsafe fn enlarge_map(&self, matrix: ArrayView2<f32>) -> (Array2<f32>, usize, usize) {
278 let step = 2usize.pow(self.iteration);
279 let width = matrix.shape().get_unchecked(0) * step;
280 let height = matrix.shape().get_unchecked(1) * step;
281 let mut output = Array2::zeros((width, height));
282 // fill the corners
283 for ((x, y), v) in matrix.indexed_iter() {
284 *output.uget_mut((x * step, y * step)) = *v;
285 }
286 (output, width, height)
287 }
288
289 /// Calculate the average of the given values with random multiplier
290 fn random_average(&self, rng: &mut SmallRng, vs: [f32; 4]) -> f32 {
291 let avg = vs.iter().sum::<f32>() / 4.0;
292 let r_roughness = self.roughness.recip();
293 avg * rng.gen_range(r_roughness..self.roughness)
294 }
295}