1use crate::error::{AnalyticsError, Result};
6use scirs2_core::ndarray::{ArrayView2, ArrayView3};
7use std::collections::HashMap;
8
/// A summary statistic that can be computed over the cells of one zone.
///
/// Values of this type are used both to request statistics from a calculator
/// and as the keys of the per-zone result maps, which is why the type is
/// `Copy`, `Eq` and `Hash`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ZonalStatistic {
    /// Arithmetic mean of the zone's values.
    Mean,
    /// Middle value of the sorted samples (average of the two middle values
    /// when the sample count is even).
    Median,
    /// Smallest value in the zone.
    Min,
    /// Largest value in the zone.
    Max,
    /// Sum of the zone's values.
    Sum,
    /// Number of values that contributed to the zone, reported as `f64`.
    Count,
    /// Square root of [`ZonalStatistic::Variance`].
    StdDev,
    /// Population variance (squared deviations divided by `n`, not `n - 1`).
    Variance,
    /// Coefficient of variation, `std_dev / mean * 100`; `NaN` when the mean
    /// is (numerically) zero.
    CoeffVar,
    /// Percentile in the inclusive range `0..=100`, linearly interpolated
    /// between neighbouring samples. Values above 100 are rejected when the
    /// statistic is evaluated.
    Percentile(u8),
}
33
34#[derive(Debug, Clone)]
36pub struct ZonalResult {
37 pub zones: HashMap<i32, HashMap<ZonalStatistic, f64>>,
39 pub zone_ids: Vec<i32>,
41}
42
43pub struct ZonalCalculator {
45 statistics: Vec<ZonalStatistic>,
46 no_data_value: Option<f64>,
47}
48
49impl ZonalCalculator {
50 pub fn new() -> Self {
52 Self {
53 statistics: vec![
54 ZonalStatistic::Mean,
55 ZonalStatistic::Min,
56 ZonalStatistic::Max,
57 ZonalStatistic::Count,
58 ],
59 no_data_value: None,
60 }
61 }
62
63 pub fn with_statistics(mut self, stats: Vec<ZonalStatistic>) -> Self {
65 self.statistics = stats;
66 self
67 }
68
69 pub fn with_no_data(mut self, value: f64) -> Self {
71 self.no_data_value = Some(value);
72 self
73 }
74
75 pub fn calculate(
84 &self,
85 values: &ArrayView2<f64>,
86 zones: &ArrayView2<i32>,
87 ) -> Result<ZonalResult> {
88 if values.dim() != zones.dim() {
89 return Err(AnalyticsError::dimension_mismatch(
90 format!("{:?}", values.dim()),
91 format!("{:?}", zones.dim()),
92 ));
93 }
94
95 let mut zone_values: HashMap<i32, Vec<f64>> = HashMap::new();
97
98 for ((i, j), &zone_id) in zones.indexed_iter() {
99 let value = values[[i, j]];
100
101 if let Some(no_data) = self.no_data_value {
103 if (value - no_data).abs() < f64::EPSILON {
104 continue;
105 }
106 }
107
108 zone_values.entry(zone_id).or_default().push(value);
109 }
110
111 let mut result_zones = HashMap::new();
113 let mut zone_ids: Vec<i32> = zone_values.keys().copied().collect();
114 zone_ids.sort_unstable();
115
116 for (&zone_id, values_in_zone) in &zone_values {
117 let mut stats = HashMap::new();
118
119 for &statistic in &self.statistics {
120 let value = self.calculate_statistic(statistic, values_in_zone)?;
121 stats.insert(statistic, value);
122 }
123
124 result_zones.insert(zone_id, stats);
125 }
126
127 Ok(ZonalResult {
128 zones: result_zones,
129 zone_ids,
130 })
131 }
132
133 pub fn calculate_multiband(
142 &self,
143 values: &ArrayView3<f64>,
144 zones: &ArrayView2<i32>,
145 ) -> Result<Vec<ZonalResult>> {
146 let (height, width, n_bands) = values.dim();
147
148 if (height, width) != zones.dim() {
149 return Err(AnalyticsError::dimension_mismatch(
150 format!("{}x{}", height, width),
151 format!("{:?}", zones.dim()),
152 ));
153 }
154
155 let mut results = Vec::with_capacity(n_bands);
156
157 for band in 0..n_bands {
158 let band_values = values.slice(s![.., .., band]);
159 let result = self.calculate(&band_values, zones)?;
160 results.push(result);
161 }
162
163 Ok(results)
164 }
165
166 fn calculate_statistic(&self, stat: ZonalStatistic, values: &[f64]) -> Result<f64> {
168 if values.is_empty() {
169 return Ok(f64::NAN);
170 }
171
172 match stat {
173 ZonalStatistic::Mean => Ok(values.iter().sum::<f64>() / values.len() as f64),
174 ZonalStatistic::Median => self.calculate_median(values),
175 ZonalStatistic::Min => values
176 .iter()
177 .copied()
178 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
179 .ok_or_else(|| AnalyticsError::zonal_stats_error("Failed to compute min")),
180 ZonalStatistic::Max => values
181 .iter()
182 .copied()
183 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
184 .ok_or_else(|| AnalyticsError::zonal_stats_error("Failed to compute max")),
185 ZonalStatistic::Sum => Ok(values.iter().sum()),
186 ZonalStatistic::Count => Ok(values.len() as f64),
187 ZonalStatistic::StdDev => self.calculate_std_dev(values),
188 ZonalStatistic::Variance => self.calculate_variance(values),
189 ZonalStatistic::CoeffVar => {
190 let mean = values.iter().sum::<f64>() / values.len() as f64;
191 let std_dev = self.calculate_std_dev(values)?;
192 Ok(if mean.abs() > f64::EPSILON {
193 (std_dev / mean) * 100.0
194 } else {
195 f64::NAN
196 })
197 }
198 ZonalStatistic::Percentile(p) => self.calculate_percentile(values, p),
199 }
200 }
201
202 fn calculate_median(&self, values: &[f64]) -> Result<f64> {
203 let mut sorted = values.to_vec();
204 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
205
206 let n = sorted.len();
207 if n % 2 == 0 {
208 Ok((sorted[n / 2 - 1] + sorted[n / 2]) / 2.0)
209 } else {
210 Ok(sorted[n / 2])
211 }
212 }
213
214 fn calculate_variance(&self, values: &[f64]) -> Result<f64> {
215 let n = values.len() as f64;
216 let mean = values.iter().sum::<f64>() / n;
217 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
218 Ok(variance)
219 }
220
221 fn calculate_std_dev(&self, values: &[f64]) -> Result<f64> {
222 Ok(self.calculate_variance(values)?.sqrt())
223 }
224
225 fn calculate_percentile(&self, values: &[f64], percentile: u8) -> Result<f64> {
226 if percentile > 100 {
227 return Err(AnalyticsError::invalid_parameter(
228 "percentile",
229 "must be between 0 and 100",
230 ));
231 }
232
233 let mut sorted = values.to_vec();
234 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
235
236 let n = sorted.len();
237 let rank = (percentile as f64 / 100.0) * ((n - 1) as f64);
238 let lower_idx = rank.floor() as usize;
239 let upper_idx = rank.ceil() as usize;
240 let fraction = rank - (lower_idx as f64);
241
242 Ok(sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]))
243 }
244}
245
246impl Default for ZonalCalculator {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
252pub struct WeightedZonalCalculator {
254 calculator: ZonalCalculator,
255}
256
257impl WeightedZonalCalculator {
258 pub fn new() -> Self {
260 Self {
261 calculator: ZonalCalculator::new(),
262 }
263 }
264
265 pub fn with_statistics(mut self, stats: Vec<ZonalStatistic>) -> Self {
267 self.calculator = self.calculator.with_statistics(stats);
268 self
269 }
270
271 pub fn calculate(
281 &self,
282 values: &ArrayView2<f64>,
283 weights: &ArrayView2<f64>,
284 zones: &ArrayView2<i32>,
285 ) -> Result<ZonalResult> {
286 if values.dim() != weights.dim() || values.dim() != zones.dim() {
287 return Err(AnalyticsError::dimension_mismatch(
288 format!("{:?}", values.dim()),
289 "all inputs must have same dimensions".to_string(),
290 ));
291 }
292
293 let mut zone_data: HashMap<i32, (Vec<f64>, Vec<f64>)> = HashMap::new();
295
296 for ((i, j), &zone_id) in zones.indexed_iter() {
297 let value = values[[i, j]];
298 let weight = weights[[i, j]];
299
300 if weight > 0.0 {
301 let entry = zone_data
302 .entry(zone_id)
303 .or_insert_with(|| (Vec::new(), Vec::new()));
304 entry.0.push(value);
305 entry.1.push(weight);
306 }
307 }
308
309 let mut result_zones = HashMap::new();
311 let mut zone_ids: Vec<i32> = zone_data.keys().copied().collect();
312 zone_ids.sort_unstable();
313
314 for (&zone_id, (values_in_zone, weights_in_zone)) in &zone_data {
315 let mut stats = HashMap::new();
316
317 let weighted_sum: f64 = values_in_zone
319 .iter()
320 .zip(weights_in_zone.iter())
321 .map(|(v, w)| v * w)
322 .sum();
323 let weight_sum: f64 = weights_in_zone.iter().sum();
324
325 if weight_sum > f64::EPSILON {
326 stats.insert(ZonalStatistic::Mean, weighted_sum / weight_sum);
327 }
328
329 stats.insert(ZonalStatistic::Count, values_in_zone.len() as f64);
331
332 if let Some(&min) = values_in_zone
334 .iter()
335 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
336 {
337 stats.insert(ZonalStatistic::Min, min);
338 }
339
340 if let Some(&max) = values_in_zone
341 .iter()
342 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
343 {
344 stats.insert(ZonalStatistic::Max, max);
345 }
346
347 result_zones.insert(zone_id, stats);
348 }
349
350 Ok(ZonalResult {
351 zones: result_zones,
352 zone_ids,
353 })
354 }
355}
356
357impl Default for WeightedZonalCalculator {
358 fn default() -> Self {
359 Self::new()
360 }
361}
362
363use scirs2_core::ndarray::s;
365
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array, array};

    /// Two zones with the default statistic set.
    #[test]
    fn test_zonal_basic() {
        let values = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let zones = array![[1, 1, 2], [1, 2, 2], [2, 2, 2]];

        let calculator = ZonalCalculator::new();
        let result = calculator
            .calculate(&values.view(), &zones.view())
            .expect("Zonal statistics calculation should succeed");

        assert_eq!(result.zone_ids, vec![1, 2]);
        assert!(result.zones.contains_key(&1));
        assert!(result.zones.contains_key(&2));

        let zone1_stats = &result.zones[&1];
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Mean],
            (1.0 + 2.0 + 4.0) / 3.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Count], 3.0, epsilon = 1e-10);

        let zone2_stats = &result.zones[&2];
        assert_abs_diff_eq!(
            zone2_stats[&ZonalStatistic::Mean],
            (3.0 + 5.0 + 6.0 + 7.0 + 8.0 + 9.0) / 6.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(zone2_stats[&ZonalStatistic::Min], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone2_stats[&ZonalStatistic::Max], 9.0, epsilon = 1e-10);
    }

    /// Every requested statistic is present and has the expected value.
    #[test]
    fn test_zonal_statistics() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let zones = array![[1, 1], [1, 1]];

        let calculator = ZonalCalculator::new().with_statistics(vec![
            ZonalStatistic::Mean,
            ZonalStatistic::Median,
            ZonalStatistic::Min,
            ZonalStatistic::Max,
            ZonalStatistic::Sum,
            ZonalStatistic::Count,
            ZonalStatistic::Variance,
            ZonalStatistic::StdDev,
        ]);

        let result = calculator
            .calculate(&values.view(), &zones.view())
            .expect("Zonal statistics with multiple stats should succeed");
        let zone1_stats = &result.zones[&1];

        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Mean], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Median], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Min], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Max], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Sum], 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Count], 4.0, epsilon = 1e-10);
        // Population variance: ((1.5^2 + 0.5^2) * 2) / 4 = 1.25.
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Variance], 1.25, epsilon = 1e-10);
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::StdDev],
            1.25_f64.sqrt(),
            epsilon = 1e-10
        );
    }

    /// Cells equal to the no-data marker do not contribute to any statistic.
    #[test]
    fn test_no_data_is_skipped() {
        let values = array![[1.0, -9999.0], [3.0, 5.0]];
        let zones = array![[1, 1], [1, 1]];

        let result = ZonalCalculator::new()
            .with_no_data(-9999.0)
            .calculate(&values.view(), &zones.view())
            .expect("Zonal statistics with no-data should succeed");
        let zone1_stats = &result.zones[&1];

        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Count], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Mean], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Min], 1.0, epsilon = 1e-10);
    }

    /// Rasters of different shape are rejected.
    #[test]
    fn test_dimension_mismatch() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let zones = array![[1, 1]];

        let result = ZonalCalculator::new().calculate(&values.view(), &zones.view());
        assert!(result.is_err());
    }

    /// Each band of a cube is summarised independently.
    #[test]
    fn test_multiband() {
        // Band 0 holds 0..=3, band 1 holds 10..=13.
        let cube = Array::from_shape_fn((2, 2, 2), |(row, col, band)| {
            (row * 2 + col) as f64 + 10.0 * band as f64
        });
        let zones = array![[1, 1], [1, 1]];

        let results = ZonalCalculator::new()
            .calculate_multiband(&cube.view(), &zones.view())
            .expect("Multiband zonal statistics should succeed");

        assert_eq!(results.len(), 2);
        assert_abs_diff_eq!(
            results[0].zones[&1][&ZonalStatistic::Mean],
            1.5,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(
            results[1].zones[&1][&ZonalStatistic::Mean],
            11.5,
            epsilon = 1e-10
        );
    }

    /// Uniform weights reproduce the plain mean.
    #[test]
    fn test_weighted_zonal() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = array![[1.0, 1.0], [1.0, 1.0]];
        let zones = array![[1, 1], [1, 1]];

        let calculator = WeightedZonalCalculator::new();
        let result = calculator
            .calculate(&values.view(), &weights.view(), &zones.view())
            .expect("Weighted zonal statistics should succeed");

        let zone1_stats = &result.zones[&1];
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Mean], 2.5, epsilon = 1e-10);
    }

    /// Zero-weight cells are excluded and the remaining weights are honoured.
    #[test]
    fn test_weighted_zonal_non_uniform() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = array![[1.0, 0.0], [0.0, 3.0]];
        let zones = array![[1, 1], [1, 1]];

        let result = WeightedZonalCalculator::new()
            .calculate(&values.view(), &weights.view(), &zones.view())
            .expect("Weighted zonal statistics should succeed");

        let zone1_stats = &result.zones[&1];
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Mean],
            (1.0 + 4.0 * 3.0) / 4.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Count], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Min], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Max], 4.0, epsilon = 1e-10);
    }

    /// Percentiles are linearly interpolated between neighbouring samples.
    #[test]
    fn test_percentile() {
        let values = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let zones = array![[1, 1, 1], [1, 1, 1]];

        let calculator = ZonalCalculator::new().with_statistics(vec![
            ZonalStatistic::Percentile(50),
            ZonalStatistic::Percentile(25),
            ZonalStatistic::Percentile(75),
        ]);

        let result = calculator
            .calculate(&values.view(), &zones.view())
            .expect("Percentile calculation should succeed");
        let zone1_stats = &result.zones[&1];

        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Percentile(50)],
            3.5,
            epsilon = 1e-10
        );
        // rank = 0.25 * 5 = 1.25 -> 2.0 + 0.25 * (3.0 - 2.0)
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Percentile(25)],
            2.25,
            epsilon = 1e-10
        );
        // rank = 0.75 * 5 = 3.75 -> 4.0 + 0.75 * (5.0 - 4.0)
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Percentile(75)],
            4.75,
            epsilon = 1e-10
        );
    }

    /// A percentile above 100 is reported as an error.
    #[test]
    fn test_percentile_out_of_range() {
        let values = array![[1.0, 2.0]];
        let zones = array![[1, 1]];

        let result = ZonalCalculator::new()
            .with_statistics(vec![ZonalStatistic::Percentile(101)])
            .calculate(&values.view(), &zones.view());
        assert!(result.is_err());
    }
}