eorst 0.1.0

An Earth Observation and Remote Sensing Toolkit
use crate::common::*;
use crate::singlerasterdataset::*;
use colored::*;
use core::fmt::Debug;
use gdal::{raster::Buffer, Dataset, DatasetOptions, GdalOpenFlags};
use ndarray::{s, Array2, Array3, Array4};
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;

/// A stack of [`SingleRasterDataset`]s processed together, block by block.
///
/// Instances are normally produced by `MultiRasterDatasetBuilder::build`, which
/// copies the block layout, EPSG code and block count from the first dataset.
#[derive(PartialEq, Debug, Clone, Default)]
pub struct MultiRasterDataset {
    // The stacked datasets; `read_window` places dataset `i` at index `i` of the
    // first axis of the returned array.
    datasets: Vec<SingleRasterDataset>,
    /// EPSG code passed to `raster_from_size` when output rasters are created.
    pub epsg_code: u32,
    /// Number of processing blocks.
    pub n_blocks: usize,
    /// Per-block attributes (read window, geotransform, block index) used to
    /// read and write each block.
    pub blocks_attributes: Vec<BlockAttributes>,
}

impl MultiRasterDataset {
    // this function will save a vector in the samplesxfeatures format to
    // a raster with n_features bands
    pub fn write_samples_feature<T>(&self, block_id: usize, data: &Array2<T>, file_name: PathBuf)
    where
        T: gdal::raster::GdalType + Copy + num_traits::identities::Zero + Debug,
    {
        let n_bands = data.shape()[1] as isize;
        let block_geotransform = self.blocks_attributes[block_id].geo_transform;
        let epsg_code = self.epsg_code;
        let blocks_attributes = self.blocks_attributes[block_id];
        let size_x = blocks_attributes.read_window.size.cols as usize;
        let size_y = blocks_attributes.read_window.size.rows as usize;
        let block_size: BlockSize = BlockSize {
            rows: size_y,
            cols: size_x,
        };

        raster_from_size::<T>(
            &PathBuf::from(&file_name),
            block_geotransform,
            epsg_code,
            block_size,
            n_bands,
        );
        let ds_options = DatasetOptions {
            open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
            ..DatasetOptions::default()
        };
        let out_ds = Dataset::open_ex(Path::new(&file_name), ds_options).unwrap();

        // each feature goes to a band
        for feature in 1..n_bands + 1 {
            println!("Saving feature{:?}", feature);
            println!("Shape {:?} - {:?}", size_x, size_y);
            let mut out_band = out_ds.rasterband(feature).unwrap();

            let out_data = data.slice(s![.., feature - 1]);
            let out_data_u8: Vec<T> = out_data.into_iter().map(|v| *v as T).collect();
            let data_buffer = Buffer {
                size: (size_x, size_y),
                data: out_data_u8,
            };
            out_band
                .write((0, 0), (size_x, size_y), &data_buffer)
                .unwrap();
        }
    }

    pub fn builder() -> MultiRasterDataset {
        MultiRasterDataset::default()
    }
    pub fn write_window3<T>(&self, block_index: usize, data: Array3<T>, out_fn: &PathBuf)
    where
        T: gdal::raster::GdalType + Copy + num_traits::identities::Zero + Debug,
    {
        // create and empty raster with the right size
        let block_geotransform = self.blocks_attributes[block_index].geo_transform;
        let block_size: BlockSize = BlockSize {
            rows: self.blocks_attributes[block_index].read_window.size.rows as usize,
            cols: self.blocks_attributes[block_index].read_window.size.cols as usize,
        };
        let epsg_code = self.epsg_code;
        let dataset_options: DatasetOptions = DatasetOptions {
            open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
            allowed_drivers: None,
            open_options: None,
            sibling_files: None,
        };
        let n_bands = data.shape()[0] as isize;

        raster_from_size::<T>(out_fn, block_geotransform, epsg_code, block_size, n_bands);
        let out_ds = Dataset::open_ex(out_fn, dataset_options).unwrap();
        for band in 0..data.shape()[0] {
            let b = (band + 1) as isize;
            let mut out_band = out_ds.rasterband(b).unwrap();
            let data_vec: Vec<T> = data
                .slice(s![band, .., ..])
                .into_iter()
                .map(|v| *v)
                .collect();
            let data_buffer = Buffer {
                size: (block_size.cols, block_size.rows),
                data: data_vec,
            };

            out_band
                .write((0, 0), (block_size.cols, block_size.rows), &data_buffer)
                .unwrap();
        }
    }

    pub fn read_window_as_vec_obs<T>(&self, read_window: ReadWindow) -> Array4<T>
    where
        T: gdal::raster::GdalType + Copy + num_traits::identities::Zero,
    {
        let bands = self.datasets[0].bands.to_vec();
        let mut multi_data: Array4<T> = Array4::zeros((
            self.datasets.len(),
            bands.len(),
            read_window.size.rows as usize,
            read_window.size.cols as usize,
        ));
        let mut ds_index = 0;
        for dataset in &self.datasets {
            let single_data = dataset.read_window::<T>(read_window); //Array3
            let mut sub = multi_data.slice_mut(s![ds_index, .., .., ..]);
            sub.assign(&single_data);
            ds_index += 1;
        }
        multi_data
    }

    pub fn read_window<T>(&self, read_window: ReadWindow) -> Array4<T>
    where
        T: gdal::raster::GdalType + Copy + num_traits::identities::Zero,
    {
        let bands = self.datasets[0].bands.to_vec();
        let mut multi_data: Array4<T> = Array4::zeros((
            self.datasets.len(),
            bands.len(),
            read_window.size.rows as usize,
            read_window.size.cols as usize,
        ));
        let mut ds_index = 0;
        for dataset in &self.datasets {
            let single_data = dataset.read_window::<T>(read_window); //Array3
            let mut sub = multi_data.slice_mut(s![ds_index, .., .., ..]);
            sub.assign(&single_data);
            ds_index += 1;
        }
        multi_data
    }

    pub fn oom_apply<T>(&self, f: fn(&Array4<T>) -> Array4<T>)
    where
        T: gdal::raster::GdalType + Copy + num_traits::Zero + Debug,
    {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.datasets[0].num_threads)
            .build()
            .unwrap();

        pool.install(|| {
            self.blocks_attributes
                .to_owned()
                .into_par_iter()
                .for_each(|block_attributes| {
                    let w = block_attributes.read_window;
                    let block_data = self.read_window::<T>(w);
                    let _ = f(&block_data);
                    // todo!
                })
        })
    }
    pub fn oom_apply_time<T>(&self, f: fn(&Array4<T>) -> Array3<T>, out_file: &PathBuf)
    where
        T: gdal::raster::GdalType + Copy + num_traits::Zero + Debug,
    {
        // todo!: Using first ds as a templete. This should be actually saved as mrd attributes!
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.datasets[0].num_threads)
            .build()
            .unwrap();

        // this will be used to save the intermediate vrt and block files
        let tmp_file = PathBuf::from(create_temp_file("vrt"));

        let handle = pool.install(|| {
            self.blocks_attributes.to_owned().into_par_iter().map(
                |block_attributes: BlockAttributes| -> PathBuf {
                    let id = block_attributes.block_index;
                    let file_stem = tmp_file.file_stem().unwrap().to_str().unwrap();
                    let block_fn = tmp_file.with_file_name(format!("{}_{}.tif", file_stem, id));
                    let w = block_attributes.read_window;
                    let block_data = self.read_window::<T>(w);
                    let result = f(&block_data);
                    // save block to disk
                    self.write_window3(block_attributes.block_index, result, &block_fn);
                    block_fn
                },
            )
        });
        let collected: Vec<PathBuf> = handle.collect();

        // build a vrt with all the collected files
        Command::new("gdalbuildvrt")
            .arg(&tmp_file)
            .args(&collected)
            .output()
            .expect("failed to create vrts");

        let tif_fn = out_file;
        // translate to gtiff
        Command::new("gdal_translate")
            .arg(&tmp_file)
            .arg(tif_fn)
            .output()
            .expect("failed to create vrts");
        // remove all files
        fs::remove_file(tmp_file).expect("Unable to remove the temporary file");
        collected
            .iter()
            .for_each(|f| fs::remove_file(f).expect("Unable to remove file"))
    }
}

/// Builder for [`MultiRasterDataset`].
///
/// Create one with `MultiRasterDatasetBuilder::from_datasets` and finish it with
/// `MultiRasterDatasetBuilder::build`.
#[derive(Default)]
pub struct MultiRasterDatasetBuilder {
    // Holds exactly the fields `build` moves into `MultiRasterDataset`;
    // optional settings may be added here later.
    datasets: Vec<SingleRasterDataset>,
    /// Block layout, copied from the first dataset by `from_datasets`.
    pub blocks_attributes: Vec<BlockAttributes>,
    /// EPSG code, copied from the first dataset by `from_datasets`.
    pub epsg_code: u32,
    /// Number of blocks, copied from the first dataset by `from_datasets`.
    pub n_blocks: usize,
}

impl MultiRasterDatasetBuilder {
    /// Creates a builder from a stack of datasets that share the same grid.
    ///
    /// All datasets must have the same image size, geotransform and EPSG code. The
    /// block layout, EPSG code and block count are copied from the first dataset.
    ///
    /// # Panics
    ///
    /// Panics if `datasets` is empty, or if the datasets differ in image size,
    /// geotransform or EPSG code.
    pub fn from_datasets(datasets: Vec<SingleRasterDataset>) -> MultiRasterDatasetBuilder {
        // TODO: wrap inputs in VRTs to support datasets with different CRS and/or
        // dimensions; for now only identical grids are accepted.
        let first = datasets
            .first()
            .expect("MultiRasterDatasetBuilder::from_datasets requires at least one dataset");
        let equal_datasets = datasets.iter().all(|item| {
            item.image_size == first.image_size
                && item.geo_transform == first.geo_transform
                && item.epsg_code == first.epsg_code
        });

        assert!(
            equal_datasets,
            "Datasets have to have same image size, geo transform and epsg codes"
        );
        let blocks_attributes = first.blocks_attributes.to_owned();
        let epsg_code = first.epsg_code;
        let n_blocks = first.n_blocks;
        MultiRasterDatasetBuilder {
            datasets,
            blocks_attributes,
            epsg_code,
            n_blocks,
        }
    }

    /// Consumes the builder and produces the [`MultiRasterDataset`].
    pub fn build(self) -> MultiRasterDataset {
        MultiRasterDataset {
            datasets: self.datasets,
            blocks_attributes: self.blocks_attributes,
            epsg_code: self.epsg_code,
            n_blocks: self.n_blocks,
        }
    }
}

impl std::fmt::Display for MultiRasterDataset {
    /// Emits a single bold placeholder line; a real summary is not implemented yet.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let placeholder = "To Be implemented: ".bold();
        writeln!(f, "{}", placeholder)
    }
}