1use crate::error::{Result, TemporalError};
7#[cfg(feature = "timeseries")]
8use crate::timeseries::TimeSeriesRaster;
9use scirs2_core::ndarray::{Array3, Array4, Axis};
10use serde::{Deserialize, Serialize};
11use tracing::{debug, info};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct StackConfig {
16 pub axis: usize,
18 pub interpolation: InterpolationMethod,
20 pub fill_value: Option<f64>,
22}
23
24impl Default for StackConfig {
25 fn default() -> Self {
26 Self {
27 axis: 0,
28 interpolation: InterpolationMethod::Nearest,
29 fill_value: Some(f64::NAN),
30 }
31 }
32}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum InterpolationMethod {
37 Nearest,
39 Bilinear,
41 Cubic,
43}
44
45#[derive(Debug, Clone)]
50pub struct RasterStack {
51 data: Array4<f64>,
53 metadata: StackMetadata,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct StackMetadata {
60 pub n_time: usize,
62 pub height: usize,
64 pub width: usize,
66 pub n_bands: usize,
68 pub band_names: Vec<String>,
70 pub nodata: Option<f64>,
72}
73
74impl RasterStack {
75 pub fn new(data: Array4<f64>) -> Result<Self> {
80 let shape = data.shape();
81 if shape.len() != 4 {
82 return Err(TemporalError::dimension_mismatch(
83 "4D array",
84 format!("{}D array", shape.len()),
85 ));
86 }
87
88 let metadata = StackMetadata {
89 n_time: shape[0],
90 height: shape[1],
91 width: shape[2],
92 n_bands: shape[3],
93 band_names: (0..shape[3]).map(|i| format!("Band_{}", i + 1)).collect(),
94 nodata: None,
95 };
96
97 Ok(Self { data, metadata })
98 }
99
100 #[cfg(feature = "timeseries")]
105 pub fn from_timeseries(ts: &TimeSeriesRaster) -> Result<Self> {
106 if ts.is_empty() {
107 return Err(TemporalError::insufficient_data("Time series is empty"));
108 }
109
110 let (height, width, n_bands) = ts
112 .expected_shape()
113 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
114
115 let n_time = ts.len();
116
117 let mut data = Array4::zeros((n_time, height, width, n_bands));
119
120 for (t, (_, entry)) in ts.iter().enumerate() {
122 let entry_data = entry.data.as_ref().ok_or_else(|| {
123 TemporalError::invalid_input("Data not loaded in time series entry")
124 })?;
125
126 for i in 0..height {
128 for j in 0..width {
129 for k in 0..n_bands {
130 data[[t, i, j, k]] = entry_data[[i, j, k]];
131 }
132 }
133 }
134 }
135
136 let metadata = StackMetadata {
137 n_time,
138 height,
139 width,
140 n_bands,
141 band_names: (0..n_bands).map(|i| format!("Band_{}", i + 1)).collect(),
142 nodata: None,
143 };
144
145 info!(
146 "Created raster stack with shape ({}, {}, {}, {})",
147 n_time, height, width, n_bands
148 );
149
150 Ok(Self { data, metadata })
151 }
152
153 #[must_use]
155 pub fn shape(&self) -> (usize, usize, usize, usize) {
156 (
157 self.metadata.n_time,
158 self.metadata.height,
159 self.metadata.width,
160 self.metadata.n_bands,
161 )
162 }
163
164 #[must_use]
166 pub fn data(&self) -> &Array4<f64> {
167 &self.data
168 }
169
170 pub fn data_mut(&mut self) -> &mut Array4<f64> {
172 &mut self.data
173 }
174
175 #[must_use]
177 pub fn metadata(&self) -> &StackMetadata {
178 &self.metadata
179 }
180
181 pub fn set_band_names(&mut self, names: Vec<String>) -> Result<()> {
183 if names.len() != self.metadata.n_bands {
184 return Err(TemporalError::dimension_mismatch(
185 format!("{} bands", self.metadata.n_bands),
186 format!("{} names", names.len()),
187 ));
188 }
189 self.metadata.band_names = names;
190 Ok(())
191 }
192
193 pub fn set_nodata(&mut self, nodata: f64) {
195 self.metadata.nodata = Some(nodata);
196 }
197
198 pub fn get_time_slice(&self, time_index: usize) -> Result<Array3<f64>> {
203 if time_index >= self.metadata.n_time {
204 return Err(TemporalError::time_index_out_of_bounds(
205 time_index,
206 0,
207 self.metadata.n_time,
208 ));
209 }
210
211 Ok(self.data.index_axis(Axis(0), time_index).to_owned())
212 }
213
214 pub fn get_band_timeseries(&self, band_index: usize) -> Result<Array3<f64>> {
219 if band_index >= self.metadata.n_bands {
220 return Err(TemporalError::invalid_parameter(
221 "band_index",
222 format!(
223 "index {} out of bounds (max: {})",
224 band_index,
225 self.metadata.n_bands - 1
226 ),
227 ));
228 }
229
230 let mut result = Array3::zeros((
232 self.metadata.n_time,
233 self.metadata.height,
234 self.metadata.width,
235 ));
236
237 for t in 0..self.metadata.n_time {
238 for i in 0..self.metadata.height {
239 for j in 0..self.metadata.width {
240 result[[t, i, j]] = self.data[[t, i, j, band_index]];
241 }
242 }
243 }
244
245 Ok(result)
246 }
247
248 pub fn get_pixel_timeseries(&self, row: usize, col: usize, band: usize) -> Result<Vec<f64>> {
253 if row >= self.metadata.height {
254 return Err(TemporalError::invalid_parameter(
255 "row",
256 format!(
257 "index {} out of bounds (max: {})",
258 row,
259 self.metadata.height - 1
260 ),
261 ));
262 }
263 if col >= self.metadata.width {
264 return Err(TemporalError::invalid_parameter(
265 "col",
266 format!(
267 "index {} out of bounds (max: {})",
268 col,
269 self.metadata.width - 1
270 ),
271 ));
272 }
273 if band >= self.metadata.n_bands {
274 return Err(TemporalError::invalid_parameter(
275 "band",
276 format!(
277 "index {} out of bounds (max: {})",
278 band,
279 self.metadata.n_bands - 1
280 ),
281 ));
282 }
283
284 let mut values = Vec::with_capacity(self.metadata.n_time);
285 for t in 0..self.metadata.n_time {
286 values.push(self.data[[t, row, col, band]]);
287 }
288
289 Ok(values)
290 }
291
292 pub fn stack_bands(bands: Vec<Array3<f64>>) -> Result<Self> {
297 if bands.is_empty() {
298 return Err(TemporalError::insufficient_data("No bands to stack"));
299 }
300
301 let first_shape = bands[0].shape();
303 for (i, band) in bands.iter().enumerate().skip(1) {
304 if band.shape() != first_shape {
305 return Err(TemporalError::dimension_mismatch(
306 format!("{:?}", first_shape),
307 format!("{:?} (band {})", band.shape(), i),
308 ));
309 }
310 }
311
312 let n_time = first_shape[0];
313 let height = first_shape[1];
314 let width = first_shape[2];
315 let n_bands = bands.len();
316
317 let mut data = Array4::zeros((n_time, height, width, n_bands));
319
320 for (band_idx, band_data) in bands.iter().enumerate() {
321 for t in 0..n_time {
322 for i in 0..height {
323 for j in 0..width {
324 data[[t, i, j, band_idx]] = band_data[[t, i, j]];
325 }
326 }
327 }
328 }
329
330 let metadata = StackMetadata {
331 n_time,
332 height,
333 width,
334 n_bands,
335 band_names: (0..n_bands).map(|i| format!("Band_{}", i + 1)).collect(),
336 nodata: None,
337 };
338
339 debug!(
340 "Stacked {} bands into shape ({}, {}, {}, {})",
341 n_bands, n_time, height, width, n_bands
342 );
343
344 Ok(Self { data, metadata })
345 }
346
347 pub fn concatenate_time(stacks: Vec<Self>) -> Result<Self> {
352 if stacks.is_empty() {
353 return Err(TemporalError::insufficient_data("No stacks to concatenate"));
354 }
355
356 let first = &stacks[0];
358 let (_, height, width, n_bands) = first.shape();
359
360 for (i, stack) in stacks.iter().enumerate().skip(1) {
361 let (_, h, w, b) = stack.shape();
362 if h != height || w != width || b != n_bands {
363 return Err(TemporalError::dimension_mismatch(
364 format!("(?, {}, {}, {})", height, width, n_bands),
365 format!("(?, {}, {}, {}) at stack {}", h, w, b, i),
366 ));
367 }
368 }
369
370 let total_time: usize = stacks.iter().map(|s| s.metadata.n_time).sum();
372
373 let mut data = Array4::zeros((total_time, height, width, n_bands));
375 let mut current_time = 0;
376
377 for stack in &stacks {
378 let stack_time = stack.metadata.n_time;
379 for t in 0..stack_time {
380 for i in 0..height {
381 for j in 0..width {
382 for k in 0..n_bands {
383 data[[current_time + t, i, j, k]] = stack.data[[t, i, j, k]];
384 }
385 }
386 }
387 }
388 current_time += stack_time;
389 }
390
391 let metadata = StackMetadata {
392 n_time: total_time,
393 height,
394 width,
395 n_bands,
396 band_names: first.metadata.band_names.clone(),
397 nodata: first.metadata.nodata,
398 };
399
400 info!(
401 "Concatenated {} stacks into shape ({}, {}, {}, {})",
402 stacks.len(),
403 total_time,
404 height,
405 width,
406 n_bands
407 );
408
409 Ok(Self { data, metadata })
410 }
411
412 pub fn subset_time(&self, start: usize, end: usize) -> Result<Self> {
417 if start >= end {
418 return Err(TemporalError::invalid_time_range(
419 start.to_string(),
420 end.to_string(),
421 ));
422 }
423 if end > self.metadata.n_time {
424 return Err(TemporalError::time_index_out_of_bounds(
425 end,
426 0,
427 self.metadata.n_time,
428 ));
429 }
430
431 let n_time = end - start;
432 let mut data = Array4::zeros((
433 n_time,
434 self.metadata.height,
435 self.metadata.width,
436 self.metadata.n_bands,
437 ));
438
439 for (t_out, t_in) in (start..end).enumerate() {
440 for i in 0..self.metadata.height {
441 for j in 0..self.metadata.width {
442 for k in 0..self.metadata.n_bands {
443 data[[t_out, i, j, k]] = self.data[[t_in, i, j, k]];
444 }
445 }
446 }
447 }
448
449 let metadata = StackMetadata {
450 n_time,
451 height: self.metadata.height,
452 width: self.metadata.width,
453 n_bands: self.metadata.n_bands,
454 band_names: self.metadata.band_names.clone(),
455 nodata: self.metadata.nodata,
456 };
457
458 Ok(Self { data, metadata })
459 }
460
461 pub fn subset_bands(&self, band_indices: &[usize]) -> Result<Self> {
466 if band_indices.is_empty() {
467 return Err(TemporalError::insufficient_data("No bands selected"));
468 }
469
470 for &idx in band_indices {
472 if idx >= self.metadata.n_bands {
473 return Err(TemporalError::invalid_parameter(
474 "band_index",
475 format!(
476 "index {} out of bounds (max: {})",
477 idx,
478 self.metadata.n_bands - 1
479 ),
480 ));
481 }
482 }
483
484 let n_bands = band_indices.len();
485 let mut data = Array4::zeros((
486 self.metadata.n_time,
487 self.metadata.height,
488 self.metadata.width,
489 n_bands,
490 ));
491
492 for t in 0..self.metadata.n_time {
493 for i in 0..self.metadata.height {
494 for j in 0..self.metadata.width {
495 for (k_out, &k_in) in band_indices.iter().enumerate() {
496 data[[t, i, j, k_out]] = self.data[[t, i, j, k_in]];
497 }
498 }
499 }
500 }
501
502 let band_names = band_indices
503 .iter()
504 .map(|&i| self.metadata.band_names[i].clone())
505 .collect();
506
507 let metadata = StackMetadata {
508 n_time: self.metadata.n_time,
509 height: self.metadata.height,
510 width: self.metadata.width,
511 n_bands,
512 band_names,
513 nodata: self.metadata.nodata,
514 };
515
516 Ok(Self { data, metadata })
517 }
518
519 pub fn apply_temporal<F>(&self, func: F) -> Result<Array3<f64>>
524 where
525 F: Fn(&[f64]) -> f64,
526 {
527 let mut result = Array3::zeros((
528 self.metadata.height,
529 self.metadata.width,
530 self.metadata.n_bands,
531 ));
532
533 for i in 0..self.metadata.height {
534 for j in 0..self.metadata.width {
535 for k in 0..self.metadata.n_bands {
536 let timeseries: Vec<f64> = (0..self.metadata.n_time)
537 .map(|t| self.data[[t, i, j, k]])
538 .collect();
539 result[[i, j, k]] = func(×eries);
540 }
541 }
542 }
543
544 Ok(result)
545 }
546
547 pub fn mean_temporal(&self) -> Result<Array3<f64>> {
552 self.apply_temporal(|values| values.iter().sum::<f64>() / values.len() as f64)
553 }
554
555 pub fn median_temporal(&self) -> Result<Array3<f64>> {
560 self.apply_temporal(|values| {
561 let mut sorted = values.to_vec();
562 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
563 let mid = sorted.len() / 2;
564 if sorted.len() % 2 == 0 {
565 (sorted[mid - 1] + sorted[mid]) / 2.0
566 } else {
567 sorted[mid]
568 }
569 })
570 }
571
572 pub fn std_temporal(&self) -> Result<Array3<f64>> {
577 self.apply_temporal(|values| {
578 let mean = values.iter().sum::<f64>() / values.len() as f64;
579 let variance =
580 values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
581 variance.sqrt()
582 })
583 }
584
585 pub fn min_temporal(&self) -> Result<Array3<f64>> {
590 self.apply_temporal(|values| {
591 values
592 .iter()
593 .copied()
594 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
595 .unwrap_or(f64::NAN)
596 })
597 }
598
599 pub fn max_temporal(&self) -> Result<Array3<f64>> {
604 self.apply_temporal(|values| {
605 values
606 .iter()
607 .copied()
608 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
609 .unwrap_or(f64::NAN)
610 })
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// A freshly wrapped array reports its own dimensions.
    #[test]
    fn test_raster_stack_creation() {
        let cube = Array4::<f64>::zeros((10, 100, 100, 3));

        let stack = RasterStack::new(cube).expect("should create stack");

        assert_eq!(stack.shape(), (10, 100, 100, 3));
    }

    /// A time slice has the spatial/band shape and carries the stored value.
    #[test]
    fn test_get_time_slice() {
        let mut cube = Array4::<f64>::zeros((10, 5, 5, 2));
        cube[[3, 2, 2, 0]] = 42.0;
        let stack = RasterStack::new(cube).expect("should create stack");

        let slice = stack.get_time_slice(3).expect("should get slice");

        assert_eq!(slice.shape(), &[5, 5, 2]);
        assert_abs_diff_eq!(slice[[2, 2, 0]], 42.0);
    }

    /// A pixel series returns one value per time step, in time order.
    #[test]
    fn test_get_pixel_timeseries() {
        let mut cube = Array4::<f64>::zeros((10, 5, 5, 2));
        for step in 0..10 {
            cube[[step, 2, 3, 1]] = step as f64;
        }
        let stack = RasterStack::new(cube).expect("should create stack");

        let series = stack
            .get_pixel_timeseries(2, 3, 1)
            .expect("should get timeseries");

        assert_eq!(series.len(), 10);
        for (step, &value) in series.iter().enumerate() {
            assert_abs_diff_eq!(value, step as f64);
        }
    }

    /// Stacking places each input array into its own band.
    #[test]
    fn test_stack_bands() {
        let bands: Vec<Array3<f64>> = vec![
            Array3::from_elem((5, 10, 10), 1.0),
            Array3::from_elem((5, 10, 10), 2.0),
            Array3::from_elem((5, 10, 10), 3.0),
        ];

        let stack = RasterStack::stack_bands(bands).expect("should stack");

        assert_eq!(stack.shape(), (5, 10, 10, 3));
        let cube = stack.data();
        assert_abs_diff_eq!(cube[[0, 0, 0, 0]], 1.0);
        assert_abs_diff_eq!(cube[[0, 0, 0, 1]], 2.0);
        assert_abs_diff_eq!(cube[[0, 0, 0, 2]], 3.0);
    }

    /// Concatenation adds up the time lengths of its inputs.
    #[test]
    fn test_concatenate_time() {
        let head =
            RasterStack::new(Array4::from_elem((5, 10, 10, 2), 1.0)).expect("should create");
        let tail =
            RasterStack::new(Array4::from_elem((3, 10, 10, 2), 2.0)).expect("should create");

        let concatenated =
            RasterStack::concatenate_time(vec![head, tail]).expect("should concatenate");

        assert_eq!(concatenated.shape(), (8, 10, 10, 2));
    }

    /// A time subset keeps `end - start` steps and all other dimensions.
    #[test]
    fn test_subset_time() {
        let stack = RasterStack::new(Array4::<f64>::zeros((10, 5, 5, 2))).expect("should create");

        let subset = stack.subset_time(2, 7).expect("should subset");

        assert_eq!(subset.shape(), (5, 5, 5, 2));
    }

    /// A band subset keeps exactly the requested number of bands.
    #[test]
    fn test_subset_bands() {
        let stack = RasterStack::new(Array4::<f64>::zeros((10, 5, 5, 5))).expect("should create");

        let subset = stack.subset_bands(&[0, 2, 4]).expect("should subset");

        assert_eq!(subset.shape(), (10, 5, 5, 3));
    }

    /// The temporal mean of 1, 2, 3 is 2.
    #[test]
    fn test_mean_temporal() {
        let mut cube = Array4::<f64>::zeros((3, 2, 2, 1));
        for (step, value) in [1.0, 2.0, 3.0].iter().enumerate() {
            cube[[step, 0, 0, 0]] = *value;
        }
        let stack = RasterStack::new(cube).expect("should create");

        let mean = stack.mean_temporal().expect("should calculate mean");

        assert_abs_diff_eq!(mean[[0, 0, 0]], 2.0);
    }

    /// The temporal median of 1..=5 is 3.
    #[test]
    fn test_median_temporal() {
        let mut cube = Array4::<f64>::zeros((5, 2, 2, 1));
        for (step, value) in [1.0, 2.0, 3.0, 4.0, 5.0].iter().enumerate() {
            cube[[step, 0, 0, 0]] = *value;
        }
        let stack = RasterStack::new(cube).expect("should create");

        let median = stack.median_temporal().expect("should calculate median");

        assert_abs_diff_eq!(median[[0, 0, 0]], 3.0);
    }

    /// Temporal min and max pick the extremes of an unsorted series.
    #[test]
    fn test_min_max_temporal() {
        let mut cube = Array4::<f64>::zeros((5, 2, 2, 1));
        for (step, value) in [1.0, 5.0, 3.0, 2.0, 4.0].iter().enumerate() {
            cube[[step, 0, 0, 0]] = *value;
        }
        let stack = RasterStack::new(cube).expect("should create");

        let min = stack.min_temporal().expect("should calculate min");
        let max = stack.max_temporal().expect("should calculate max");

        assert_abs_diff_eq!(min[[0, 0, 0]], 1.0);
        assert_abs_diff_eq!(max[[0, 0, 0]], 5.0);
    }
}