// eorst 1.0.0
//
// Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! Rasterization methods for RasterDataset.
//!
//! This module contains methods for burning vector values into raster blocks.

use crate::core_types::{RasterData, RasterType};
use crate::blocks::RasterBlock;
use crate::types::BlockSize;
use crate::gdal_utils::{create_rayon_pool, create_temp_file, file_stem_str, mosaic_translate_cleanup_time_steps, open_for_update, raster_from_size};
use crate::rasterdataset::RasterDataset;

use gdal::Dataset;
use gdal::vector::{Geometry, Layer, LayerAccess};
use gdal::spatial_ref::SpatialRef;
use gdal::spatial_ref::CoordTransform;
use anyhow::Result as GdalResult;
use gdal::raster::RasterizeOptions;
use ndarray::{Array4, Axis};
use num_traits::NumCast;
use rayon::prelude::*;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

/// Looks up a vector layer in `dataset`.
///
/// When `layer_name` is `Some`, the layer with that name is requested; otherwise the
/// layer at index 0 is used. Any GDAL error from the lookup is converted into an
/// `anyhow` error.
fn get_layer<'a>(dataset: &'a Dataset, layer_name: Option<&str>) -> GdalResult<Layer<'a>> {
    let lookup = if let Some(name) = layer_name {
        dataset.layer_by_name(name)
    } else {
        dataset.layer(0)
    };
    lookup.map_err(|err| anyhow::anyhow!(err))
}

fn rasterize_and_write<R>(
    raster_block: &RasterBlock<R>,
    geometries: &[Geometry],
    burn_values: Vec<R>,
    rasterize_options: Option<RasterizeOptions>,
    block_fn: &Path,
) -> GdalResult<()>
where
    R: RasterType,
{
    let block_size = BlockSize {
        rows: raster_block.read_window.size.rows as usize,
        cols: raster_block.read_window.size.cols as usize,
    };
    raster_from_size::<R>(
        block_fn,
        raster_block.geo_transform,
        raster_block.epsg_code as u32,
        block_size,
        1,
        R::zero(),
    );
    let mut dataset = open_for_update(block_fn);

    let opts = rasterize_options.unwrap_or_else(|| RasterizeOptions::default());
    let burn_values_f64: Vec<f64> = burn_values
        .iter()
        .map(|v| NumCast::from(*v).unwrap_or(0.0))
        .collect();
    gdal::raster::rasterize(
        &mut dataset,
        &[1],
        geometries,
        burn_values_f64.as_slice(),
        Some(opts),
    )
    .expect("Rasterize failed");

    Ok(())
}

impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Burns vector values into the raster dataset.
    pub fn rasterize(
        &self,
        vector_fn: &PathBuf,
        column_name: &str,
        map: Option<HashMap<String, i32>>,
        layer_name: Option<&str>,
        n_cpus: usize,
        rasterize_options: Option<RasterizeOptions>,
        out_file: &Path,
    ) {
        log::debug!("Starting rasterize.");
        gdal::config::set_config_option("OSR_DEFAULT_AXIS_MAPPING_STRATEGY", "TRADITIONAL_GIS_ORDER").unwrap();
        let pool = create_rayon_pool(n_cpus);

        let tmp_file = PathBuf::from(create_temp_file("vrt"));

        let handle = pool.install(|| {
            self.blocks.to_owned().into_par_iter().map(
                |raster_block: RasterBlock<R>| -> Vec<PathBuf> {
                    let dataset = Dataset::open(vector_fn).expect("unable to open dataset");
                    let mut layer = get_layer(&dataset, layer_name).unwrap();
                    let bid = raster_block.block_index;

                    let file_stem = file_stem_str(&tmp_file);
                    let mut result: Array4<i32> = RasterData::zeros((
                        1,
                        1,
                        raster_block.read_window.size.cols as usize,
                        raster_block.read_window.size.rows as usize,
                    ));
                    for c in result.iter_mut() {
                        *c = 1
                    }

                    let x_ul = raster_block.geo_transform.x_ul;
                    let y_ul = raster_block.geo_transform.y_ul;

                    let x2 = x_ul
                        + (raster_block.geo_transform.x_res
                            * raster_block.read_window.size.cols as f64);

                    let y2 = y_ul
                        + (raster_block.geo_transform.y_res
                            * raster_block.read_window.size.rows as f64);

                    let (xmin, xmax) = if x_ul < x2 { (x_ul, x2) } else { (x2, x_ul) };
                    let (ymin, ymax) = if y_ul < y2 { (y_ul, y2) } else { (y2, y_ul) };

                    let src = SpatialRef::from_epsg(raster_block.epsg_code as u32)
                        .expect("Invalid raster EPSG code");

                    let dst = layer
                        .spatial_ref()
                        .expect("Vector layer has no CRS defined");

                    let mut geom = gdal::vector::Geometry::from_wkt(&format!(
                        "POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))",
                        xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin, xmin, ymin
                    ))
                    .unwrap();

                    geom.set_spatial_ref(src.clone());

                    let geom = if src != dst {
                        let tx = CoordTransform::new(&src, &dst)
                            .expect("Failed to create coordinate transform");
                        geom.transform(&tx)
                            .expect("Failed to transform bbox into layer CRS")
                    } else {
                        geom
                    };

                    layer.set_spatial_filter(&geom);
                    let mut geometries = Vec::new();
                    let mut burn_values: Vec<R> = Vec::new();

                    for feature in layer.features() {
                        let g = feature.geometry().unwrap().to_owned();
                        let g = g.buffer(0., 8).unwrap();

                        let g = if src != dst {
                            let tx = CoordTransform::new(&dst, &src)
                                .expect("Failed to create coordinate transform");
                            g.transform(&tx)
                                .expect("Failed to transform bbox into layer CRS")
                        } else {
                            g
                        };

                        let field_index = feature
                            .field_index(column_name)
                            .expect("Invalid column name,");
                        if let Some(field_value) =
                            feature.field(field_index).expect("Invalid column name!")
                        {
                            let v = match field_value {
                                gdal::vector::FieldValue::IntegerValue(i) => NumCast::from(i).unwrap_or(0),
                                gdal::vector::FieldValue::RealValue(r) => NumCast::from(r).unwrap(),
                                gdal::vector::FieldValue::StringValue(s) => match &map {
                                    Some(hmap) => match hmap.get(&s) {
                                        Some(v) => *v,
                                        None => panic!("String '{}' not found in lookup table", s),
                                    },
                                    None => panic!(
                                        "Field '{}' is a string ('{}'), but no string→number map was supplied",
                                        column_name, s
                                    ),
                                },
                                _ => panic!("The field to rasterize has to be numeric"),
                            };
                            burn_values.push(NumCast::from(v).expect("Cant cast!"));
                            let zero = R::from(0).unwrap();
                            let _no_null: Vec<_> =
                                burn_values.iter().filter(|m| **m != zero).collect();
                        } else {
                            panic!("{}", format!("Field {column_name} does not exist"));
                        }
                        geometries.push(g);
                    }

                    let mut block_fns: Vec<PathBuf> = Vec::new();
                    for (tid, _) in result.axis_iter(Axis(0)).enumerate() {
                        let block_fn =
                            tmp_file.with_file_name(format!("{}_{}_{}.tif", file_stem, tid, bid));

                        rasterize_and_write::<R>(
                            &raster_block,
                            &geometries,
                            burn_values.clone(),
                            rasterize_options,
                            &block_fn,
                        )
                        .unwrap();
                        block_fns.push(block_fn);
                    }
                    block_fns
                },
            )
        });

        let collected: Vec<Vec<PathBuf>> = handle.collect();
        mosaic_translate_cleanup_time_steps(
            &collected, out_file, self.metadata.epsg_code, self.metadata.shape.times,
        );
    }
}