// diamond_square/md/mod.rs

1use ndarray::Array1;
2use rand::{rngs::SmallRng, Rng, SeedableRng};
3use std::{num::NonZeroUsize, ops::Range};
4
/// Configuration for one-dimensional midpoint displacement.
///
/// `length + 1` anchor points, spaced `2^iteration` samples apart, are drawn
/// uniformly from `range`. `iteration` refinement passes then set each
/// midpoint to the average of its two neighbours, multiplied by a random
/// factor between `1 / roughness` and `roughness`. The scaling is skipped
/// when `roughness == 1.0`. The random generator is seeded from `seed`, so a
/// given configuration always produces the same line.
///
/// # Examples
///
/// ```
/// # use diamond_square::MidpointDisplacement;
/// let cfg = MidpointDisplacement::default();
/// assert_eq!(cfg.get_map_length(), 17);
/// ```
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MidpointDisplacement {
    /// Number of segments between the initial anchor points (never zero).
    length: NonZeroUsize,
    /// Number of midpoint refinement passes; each pass halves the spacing
    /// between computed samples.
    iteration: u32,
    /// Range the initial anchor values are drawn from; must be non-empty.
    range: Range<f32>,
    /// Upper bound of the random factor applied to each midpoint average;
    /// the other bound is its reciprocal.
    roughness: f32,
    /// Seed of the random number generator used during generation.
    seed: u64,
}

impl Default for MidpointDisplacement {
    /// Returns 4 segments, 2 refinement passes, anchor values in
    /// `-1.0..1.0`, roughness `1.1` and seed `42`.
    fn default() -> Self {
        Self {
            // `4` is a non-zero literal, so the checked constructor can never
            // fail; no `unsafe` is needed.
            length: NonZeroUsize::new(4).expect("4 is non-zero"),
            iteration: 2,
            range: -1.0..1.0,
            roughness: 1.1,
            seed: 42,
        }
    }
}
40
41impl MidpointDisplacement {
42    /// Get the initial size of the terrain.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// # use diamond_square::MidpointDisplacement;
48    /// let mut cfg = MidpointDisplacement::default();
49    /// assert_eq!(cfg.get_length(), 4);
50    /// ```
51    pub fn get_length(&self) -> usize {
52        self.length.get()
53    }
54    /// Get the initial size of the terrain.
55    ///
56    /// # Examples
57    ///
58    /// ```
59    /// # use diamond_square::MidpointDisplacement;
60    /// let mut cfg = MidpointDisplacement::default();
61    /// assert_eq!(cfg.get_map_length(), 17);
62    /// ```
63    pub fn get_map_length(&self) -> usize {
64        let step = 2usize.pow(self.iteration);
65        self.length.get() * step + 1
66    }
67    /// Get the initial size of the terrain.
68    ///
69    /// # Examples
70    ///
71    /// ```
72    /// # use diamond_square::MidpointDisplacement;
73    /// let mut cfg = MidpointDisplacement::default();
74    /// cfg.set_length(8);
75    /// assert_eq!(cfg.get_length(), 8);
76    /// ```
77    pub fn set_length(&mut self, length: usize) {
78        self.length = NonZeroUsize::new(length).unwrap();
79    }
80    /// Get the initial size of the terrain.
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// # use diamond_square::MidpointDisplacement;
86    /// let mut cfg = MidpointDisplacement::default().with_length(10);
87    /// assert_eq!(cfg.get_length(), 10);
88    /// ```
89    pub fn with_length(mut self, length: usize) -> Self {
90        self.set_length(length);
91        self
92    }
93    /// Get the initial size of the terrain.
94    ///
95    /// # Examples
96    ///
97    /// ```
98    /// # use diamond_square::MidpointDisplacement;
99    /// let mut cfg = MidpointDisplacement::default();
100    /// assert_eq!(cfg.get_iteration(), 2);
101    /// ```
102    pub fn get_iteration(&self) -> u32 {
103        self.iteration
104    }
105    /// Get the initial size of the terrain.
106    ///
107    /// # Examples
108    ///
109    /// ```
110    /// # use diamond_square::MidpointDisplacement;
111    /// let mut cfg = MidpointDisplacement::default();
112    /// cfg.set_iteration(10);
113    /// assert_eq!(cfg.get_iteration(), 10);
114    /// ```
115    pub fn set_iteration(&mut self, iteration: u32) {
116        self.iteration = iteration;
117    }
118    /// Get the initial size of the terrain.
119    ///
120    /// # Examples
121    ///
122    /// ```
123    /// # use diamond_square::MidpointDisplacement;
124    /// let mut cfg = MidpointDisplacement::default().with_iteration(10);
125    /// assert_eq!(cfg.get_iteration(), 10);
126    /// ```
127    pub fn with_iteration(mut self, iteration: u32) -> Self {
128        self.set_iteration(iteration);
129        self
130    }
131    /// Get the initial size of the terrain.
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// # use diamond_square::MidpointDisplacement;
137    /// let mut cfg = MidpointDisplacement::default();
138    /// assert_eq!(cfg.get_range().start, -1.0);
139    /// assert_eq!(cfg.get_range().end, 1.0);
140    /// ```
141    pub fn get_range(&self) -> Range<f32> {
142        self.range.clone()
143    }
144    /// Get the initial size of the terrain.
145    ///
146    /// # Examples
147    ///
148    /// ```
149    /// # use diamond_square::MidpointDisplacement;
150    /// let mut cfg = MidpointDisplacement::default();
151    /// cfg.set_range(0.0..1.0);
152    /// assert_eq!(cfg.get_range().start, 0.0);
153    /// assert_eq!(cfg.get_range().end, 1.0);
154    /// ```
155    pub fn set_range(&mut self, range: Range<f32>) {
156        self.range = range;
157    }
158    /// Get the initial size of the terrain.
159    ///
160    /// # Examples
161    ///
162    /// ```
163    /// # use diamond_square::MidpointDisplacement;
164    /// let mut cfg = MidpointDisplacement::default().with_range(0.0..1.0);
165    /// assert_eq!(cfg.get_range().start, 0.0);
166    /// assert_eq!(cfg.get_range().end, 1.0);
167    /// ```
168    pub fn with_range(mut self, range: Range<f32>) -> Self {
169        self.set_range(range);
170        self
171    }
172    /// Get the initial size of the terrain.
173    ///
174    /// # Examples
175    ///
176    /// ```
177    /// # use diamond_square::MidpointDisplacement;
178    /// let mut cfg = MidpointDisplacement::default();
179    /// assert_eq!(cfg.get_roughness(), 1.1);
180    /// ```
181    pub fn get_roughness(&self) -> f32 {
182        self.roughness
183    }
184    /// Get the initial size of the terrain.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// # use diamond_square::MidpointDisplacement;
190    /// let mut cfg = MidpointDisplacement::default();
191    /// cfg.set_roughness(2.0);
192    /// assert_eq!(cfg.get_roughness(), 2.0);
193    /// ```
194    pub fn set_roughness(&mut self, roughness: f32) {
195        self.roughness = roughness;
196    }
197    /// Get the initial size of the terrain.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// # use diamond_square::MidpointDisplacement;
203    /// let mut cfg = MidpointDisplacement::default().with_roughness(2.0);
204    /// assert_eq!(cfg.get_roughness(), 2.0);
205    /// ```
206    pub fn with_roughness(mut self, roughness: f32) -> Self {
207        self.set_roughness(roughness);
208        self
209    }
210    /// Get the initial size of the terrain.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// # use diamond_square::MidpointDisplacement;
216    /// let mut cfg = MidpointDisplacement::default();
217    /// assert_eq!(cfg.get_seed(), 42);
218    /// ```
219    pub fn get_seed(&self) -> u64 {
220        self.seed
221    }
222    /// Get the initial size of the terrain.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// # use diamond_square::MidpointDisplacement;
228    /// let mut cfg = MidpointDisplacement::default();
229    /// cfg.set_seed(0);
230    /// assert_eq!(cfg.get_seed(), 0);
231    /// ```
232    pub fn set_seed(&mut self, seed: u64) {
233        self.seed = seed;
234    }
235    /// Get the initial size of the terrain.
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// # use diamond_square::MidpointDisplacement;
241    /// let mut cfg = MidpointDisplacement::default().with_seed(0);
242    /// assert_eq!(cfg.get_seed(), 0);
243    /// ```
244    pub fn with_seed(mut self, seed: u64) -> Self {
245        self.set_seed(seed);
246        self
247    }
248}
249
250impl MidpointDisplacement {
251    /// Generate a grid using diamond square algorithm
252    ///
253    /// # Examples
254    ///
255    /// ```
256    /// use ndarray::Array2;
257    /// ```
258    pub fn generate(&self) -> Array1<f32> {
259        let mut rng = SmallRng::seed_from_u64(self.seed);
260        let mut step = 2usize.pow(self.iteration);
261        let length = self.get_map_length();
262        let mut grid = Array1::zeros((length,));
263        for x in (0..length).step_by(step) {
264            // SAFETY: x obviously in range
265            unsafe {
266                *grid.uget_mut(x) = rng.gen_range(self.range.clone());
267            }
268        }
269        for iteration in 0..self.iteration {
270            tracing::trace!("Iteration: {}, step: {} in {}", iteration + 1, step, length);
271            let half = step / 2;
272            for i in (half..length).step_by(step) {
273                // SAFETY: left and right must in range
274                unsafe {
275                    let l = *grid.uget(i - half);
276                    let r = *grid.uget(i + half);
277                    *grid.uget_mut(i) = self.random_average(&mut rng, [l, r]);
278                }
279            }
280            step /= 2;
281        }
282        grid
283    }
284    /// Calculate the average of the given values with random multiplier
285    fn random_average(&self, rng: &mut SmallRng, vs: [f32; 2]) -> f32 {
286        let mut avg = vs.iter().sum::<f32>() / 2.0;
287        if self.roughness != 1.0 {
288            let r_roughness = self.roughness.recip();
289            avg *= rng.gen_range(r_roughness..self.roughness);
290        }
291        avg
292    }
293}