// eorst 1.0.0
//
// Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! Data source definitions for raster datasets.
//!
//! This module provides types for describing the source files and bands
//! that make up a raster dataset.

use chrono::{DateTime, FixedOffset};
use gdal::Dataset;
use std::path::PathBuf;

/// Type for representing dates in raster data.
///
/// A date is either a concrete timestamp or a position along a time
/// dimension. The derived ordering compares variants in declaration order
/// before comparing their payloads, so every `Date` sorts before every
/// `Index`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DateType {
    /// Specific date, stored as a `DateTime` with a fixed offset.
    Date(DateTime<FixedOffset>),
    /// Index into a time dimension.
    Index(usize),
}

/// Data source for a raster dataset.
///
/// Describes one source file: where it lives, which of its bands are used,
/// and the names given to the layers. Values of this type are produced by
/// `DataSourceBuilder::build`.
#[derive(Debug, PartialEq, Clone)]
pub struct DataSource {
    /// Path to the source file.
    pub source: PathBuf,
    /// Band indices to use (`DataSourceBuilder::from_file` numbers them from 1).
    pub bands: Vec<usize>,
    /// Names of the layers.
    pub layer_names: Vec<String>,
}

/// Builder for creating a `DataSource`.
///
/// Holds the same fields as `DataSource` plus the date indices, which are
/// only tracked while building. `Debug` and `Clone` are derived so a
/// partially configured builder can be inspected and reused as a template.
#[derive(Debug, Clone)]
pub struct DataSourceBuilder {
    /// Path to the source file.
    pub source: PathBuf,
    /// Band indices to use.
    pub bands: Vec<usize>,
    /// Names of the layers.
    pub layer_names: Vec<String>,
    /// Date indices for temporal data.
    pub date_indices: Vec<DateType>,
}

impl DataSourceBuilder {
    /// Creates a `DataSourceBuilder` from a file path.
    ///
    /// The file is opened once to read its raster band count `n`. The builder
    /// starts with every band selected (`1..=n`), layer names `layer_0` up to
    /// `layer_{n-1}`, and a single date index `DateType::Index(0)`.
    ///
    /// # Panics
    ///
    /// Panics if the dataset at `file_name` cannot be opened.
    pub fn from_file(file_name: &PathBuf) -> Self {
        let n_bands = Self::get_n_bands(file_name);

        DataSourceBuilder {
            source: file_name.clone(),
            // Band indices start at 1, layer names at 0.
            bands: (1..=n_bands).collect(),
            layer_names: Self::default_layer_names(n_bands),
            date_indices: vec![DateType::Index(0)],
        }
    }

    /// Sets the band dimension (layer or time).
    ///
    /// * `Dimension::Layer` leaves the builder untouched.
    /// * `Dimension::Time` collapses the layer names to the single name
    ///   `"Layer_0"` and creates one `DateType::Index` per band of the file.
    ///
    /// # Panics
    ///
    /// For `Dimension::Time`, panics if the source dataset cannot be opened.
    pub fn band_dimension(mut self, dimension: crate::types::Dimension) -> Self {
        match dimension {
            // Nothing to recompute, so the file is not touched at all.
            crate::types::Dimension::Layer => {}
            crate::types::Dimension::Time => {
                // NOTE(review): this uses the band count of the file, not
                // `self.bands.len()` - confirm that is intended when a band
                // subset was selected beforehand.
                let n_bands = Self::get_n_bands(&self.source);
                self.layer_names = vec!["Layer_0".to_string()];
                self.date_indices = (0..n_bands).map(DateType::Index).collect();
            }
        }
        self
    }

    /// Sets the date indices for temporal data.
    ///
    /// # Panics
    ///
    /// Panics if `dates` does not have the same length as the current
    /// date indices.
    pub fn set_dates(mut self, dates: Vec<DateType>) -> Self {
        assert_eq!(
            self.date_indices.len(),
            dates.len(),
            "number of dates must match the number of existing date indices"
        );
        self.date_indices = dates;
        self
    }

    /// Sets the layer names.
    ///
    /// # Panics
    ///
    /// Panics if `names` does not have the same length as the current
    /// layer names.
    pub fn set_names(mut self, names: Vec<&str>) -> Self {
        assert_eq!(
            self.layer_names.len(),
            names.len(),
            "number of names must match the number of existing layer names"
        );
        self.layer_names = names.into_iter().map(String::from).collect();
        self
    }

    /// Builds the `DataSource`.
    ///
    /// The source path, bands and layer names are moved into the result;
    /// `date_indices` is not part of `DataSource` and is dropped here.
    pub fn build(self) -> DataSource {
        DataSource {
            source: self.source,
            bands: self.bands,
            layer_names: self.layer_names,
        }
    }

    /// Sets the band indices.
    ///
    /// The layer names are regenerated as `layer_0`, `layer_1`, ... with one
    /// name per selected band.
    pub fn bands(mut self, bands: Vec<usize>) -> Self {
        // NOTE(review): this replaces names set earlier through `set_names`
        // or `band_dimension(Time)` - confirm that call order is intended.
        self.layer_names = Self::default_layer_names(bands.len());
        self.bands = bands;
        self
    }

    /// Generates the default layer names `layer_0` .. `layer_{count-1}`.
    fn default_layer_names(count: usize) -> Vec<String> {
        (0..count).map(|index| format!("layer_{}", index)).collect()
    }

    /// Opens `source` and returns its raster band count.
    ///
    /// # Panics
    ///
    /// Panics, naming the offending path, if the dataset cannot be opened.
    fn get_n_bands(source: &std::path::Path) -> usize {
        let ds = Dataset::open(source).unwrap_or_else(|err| {
            panic!(
                "failed to open raster dataset {}: {}",
                source.display(),
                err
            )
        });
        ds.raster_count()
    }
}