eorst 1.0.1

Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! Rasterization methods for RasterDataset.
//!
//! This module contains methods for burning vector values into raster blocks.

use crate::blocks::RasterBlock;
use crate::core_types::RasterType;
use crate::gdal_utils::create_rayon_pool;
use crate::parallel_writer;
use crate::rasterdataset::RasterDataset;

use anyhow::Result as GdalResult;
use gdal::raster::RasterizeOptions;
use gdal::spatial_ref::CoordTransform;
use gdal::spatial_ref::SpatialRef;
use gdal::vector::{Layer, LayerAccess};
use gdal::{Dataset, DriverManager};
use kdam::par_tqdm;
use ndarray::Array3;
use num_traits::NumCast;
use rayon::prelude::*;
use std::collections::HashMap;
use std::path::Path;

/// Looks up a layer of `dataset`: the layer called `layer_name` when a name is
/// given, otherwise the first layer (index 0).
///
/// Any GDAL error from the lookup is converted into an [`anyhow::Error`].
fn get_layer<'a>(dataset: &'a Dataset, layer_name: Option<&str>) -> GdalResult<Layer<'a>> {
    let lookup = if let Some(name) = layer_name {
        dataset.layer_by_name(name)
    } else {
        dataset.layer(0)
    };
    lookup.map_err(|err| anyhow::anyhow!(err))
}

impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Burns vector values into the raster dataset.
    ///
    /// Rasterizes geometries from a vector file by processing blocks in parallel
    /// and writing directly to the output GeoTIFF via [`ParallelGeoTiffWriter`],
    /// avoiding intermediate files and a mosaic step.
    pub fn rasterize(
        &self,
        vector_fn: &Path,
        column_name: Option<&str>,
        map: Option<HashMap<String, i32>>,
        layer_name: Option<&str>,
        n_cpus: usize,
        rasterize_options: Option<RasterizeOptions>,
        out_file: &Path,
    ) {
        log::debug!("Starting rasterize.");
        gdal::config::set_config_option("OSR_DEFAULT_AXIS_MAPPING_STRATEGY", "TRADITIONAL_GIS_ORDER").unwrap();

        if column_name.is_none() {
            log::warn!("No --column specified, using feature FID as burn value");
        }

        let n_bands = self.metadata.shape.layers * self.metadata.shape.times;

        // Pre-create output GeoTIFF
        parallel_writer::create_output_geotiff::<R>(
            out_file,
            &self.metadata.geo_transform,
            self.metadata.epsg_code as u32,
            self.metadata.shape.cols,
            self.metadata.shape.rows,
            n_bands,
            R::zero(),
        )
        .expect("Failed to create output GeoTIFF");

        let writer = parallel_writer::ParallelGeoTiffWriter::new(
            out_file.to_path_buf(),
            self.metadata.geo_transform,
            self.metadata.epsg_code as u32,
            self.metadata.shape.cols,
            self.metadata.shape.rows,
            n_bands,
        );

        let pool = create_rayon_pool(n_cpus);

        pool.install(|| {
            par_tqdm!(self.blocks.to_owned().into_par_iter()).for_each(
                |raster_block: RasterBlock<R>| {
                    let dataset = Dataset::open(vector_fn).expect("unable to open dataset");
                    let mut layer = get_layer(&dataset, layer_name).unwrap();

                    // Compute block bounding box in raster CRS
                    let x_ul = raster_block.geo_transform.x_ul;
                    let y_ul = raster_block.geo_transform.y_ul;
                    let x2 = x_ul
                        + (raster_block.geo_transform.x_res
                            * raster_block.read_window.size.cols as f64);
                    let y2 = y_ul
                        + (raster_block.geo_transform.y_res
                            * raster_block.read_window.size.rows as f64);
                    let (xmin, xmax) = if x_ul < x2 { (x_ul, x2) } else { (x2, x_ul) };
                    let (ymin, ymax) = if y_ul < y2 { (y_ul, y2) } else { (y2, y_ul) };

                    let src = SpatialRef::from_epsg(raster_block.epsg_code as u32)
                        .expect("Invalid raster EPSG code");
                    let dst = layer
                        .spatial_ref()
                        .expect("Vector layer has no CRS defined");

                    // Create and transform spatial filter polygon
                    let mut filter_geom = gdal::vector::Geometry::from_wkt(&format!(
                        "POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))",
                        xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin, xmin, ymin
                    ))
                    .unwrap();
                    filter_geom.set_spatial_ref(src.clone());
                    let filter_geom = if src != dst {
                        let tx = CoordTransform::new(&src, &dst)
                            .expect("Failed to create coordinate transform");
                        filter_geom
                            .transform(&tx)
                            .expect("Failed to transform bbox into layer CRS")
                    } else {
                        filter_geom
                    };
                    layer.set_spatial_filter(&filter_geom);

                    // Collect geometries and burn values from features in this block
                    let mut geometries = Vec::new();
                    let mut burn_values: Vec<R> = Vec::new();
                    let available_fields: Vec<String> = if column_name.is_some() {
                        layer.defn().fields().map(|f| f.name()).collect()
                    } else {
                        Vec::new()
                    };

                    for feature in layer.features() {
                        let g = feature.geometry().unwrap().to_owned();
                        let g = g.buffer(0., 8).unwrap();
                        let g = if src != dst {
                            let tx = CoordTransform::new(&dst, &src)
                                .expect("Failed to create coordinate transform");
                            g.transform(&tx)
                                .expect("Failed to transform geom into raster CRS")
                        } else {
                            g
                        };

                        if let Some(col) = column_name {
                            let field_index = match feature.field_index(col) {
                                Ok(idx) => idx,
                                Err(_) => panic!(
                                    "Column '{}' not found in vector layer. Available fields: {}",
                                    col,
                                    available_fields.join(", ")
                                ),
                            };
                            if let Some(field_value) =
                                feature.field(field_index).expect("Invalid column name!")
                            {
                                let v = match field_value {
                                    gdal::vector::FieldValue::IntegerValue(i) => {
                                        NumCast::from(i).unwrap_or(0)
                                    }
                                    gdal::vector::FieldValue::RealValue(r) => {
                                        NumCast::from(r).unwrap()
                                    }
                                    gdal::vector::FieldValue::StringValue(s) => match &map {
                                        Some(hmap) => match hmap.get(&s) {
                                            Some(v) => *v,
                                            None => panic!(
                                                "String '{}' not found in lookup table",
                                                s
                                            ),
                                        },
                                        None => panic!(
                                            "Field '{}' is a string ('{}'), but no string→number map was supplied",
                                            col, s
                                        ),
                                    },
                                    _ => panic!("The field to rasterize has to be numeric"),
                                };
                                burn_values.push(NumCast::from(v).expect("Cant cast!"));
                            } else {
                                panic!("{}", format!("Field {col} does not exist"));
                            }
                        } else {
                            let fid = feature.fid().expect("Feature has no FID") as i32;
                            burn_values.push(NumCast::from(fid).expect("Cant cast!"));
                        }
                        geometries.push(g);
                    }

                    // Create an in-memory raster for this block, rasterize into it
                    let block_rows = raster_block.read_window.size.rows as usize;
                    let block_cols = raster_block.read_window.size.cols as usize;

                    let mem_driver = DriverManager::get_driver_by_name("MEM")
                        .expect("GDAL MEM driver not available");
                    let mut mem_dataset = mem_driver
                        .create_with_band_type::<R, _>("", block_cols, block_rows, 1)
                        .expect("Failed to create in-memory dataset");
                    mem_dataset
                        .set_geo_transform(&raster_block.geo_transform.to_array())
                        .expect("Failed to set geo transform");
                    let srs = SpatialRef::from_epsg(raster_block.epsg_code as u32)
                        .expect("Invalid EPSG code");
                    mem_dataset
                        .set_spatial_ref(&srs)
                        .expect("Failed to set CRS");

                    let opts = rasterize_options.unwrap_or_else(|| RasterizeOptions::default());
                    let burn_values_f64: Vec<f64> = burn_values
                        .iter()
                        .map(|v| NumCast::from(*v).unwrap_or(0.0))
                        .collect();
                    gdal::raster::rasterize(
                        &mut mem_dataset,
                        &[1],
                        &geometries,
                        burn_values_f64.as_slice(),
                        Some(opts),
                    )
                    .expect("Rasterize failed");

                    // Read rasterized data back from in-memory dataset
                    let band = mem_dataset.rasterband(1).expect("No band 1");
                    let buffer: gdal::raster::Buffer<R> = band
                        .read_as::<R>(
                            (0, 0),
                            (block_cols, block_rows),
                            (block_cols, block_rows),
                            None,
                        )
                        .expect("Failed to read rasterized data");
                    let (_shape, data) = buffer.into_shape_and_vec();

                    let array = Array3::from_shape_vec((1, block_rows, block_cols), data)
                        .expect("Shape mismatch reading rasterized data");

                    // Write to output via parallel writer (no mosaic step needed)
                    parallel_writer::write_block(
                        &writer,
                        array.view(),
                        raster_block.read_window.clone(),
                    )
                    .expect("Failed to write block to output");
                },
            )
        });
    }

    /// Burns vector values into the raster dataset, outputting a Cloud Optimized GeoTIFF.
    ///
    /// COG variant of [`rasterize`](Self::rasterize). Follows the same rasterization
    /// logic but produces a proper COG with overviews and IFD reordering.
    pub fn rasterize_cog(
        &self,
        vector_fn: &Path,
        column_name: Option<&str>,
        map: Option<HashMap<String, i32>>,
        layer_name: Option<&str>,
        n_cpus: usize,
        rasterize_options: Option<RasterizeOptions>,
        out_file: &Path,
        config: &crate::types::OutputConfig,
    ) {
        use crate::types::OutputFormat;

        match config.format {
            OutputFormat::GeoTiff => {
                self.rasterize(vector_fn, column_name, map, layer_name, n_cpus, rasterize_options, out_file);
            }
            OutputFormat::GeoTiffOverviews => {
                self.rasterize(vector_fn, column_name, map, layer_name, n_cpus, rasterize_options, out_file);

                let n_bands = self.metadata.shape.layers * self.metadata.shape.times;
                let writer = parallel_writer::ParallelGeoTiffWriter::new(
                    out_file.to_path_buf(),
                    self.metadata.geo_transform,
                    self.metadata.epsg_code as u32,
                    self.metadata.shape.cols,
                    self.metadata.shape.rows,
                    n_bands,
                );
                writer.build_overviews(&config.overview_resampling, &config.overview_levels).ok();
            }
            OutputFormat::COG => {
                let intermediate = std::path::PathBuf::from(crate::gdal_utils::create_temp_file("tif"));

                self.rasterize(vector_fn, column_name, map, layer_name, n_cpus, rasterize_options, &intermediate);

                crate::gdal_utils::translate_to_cog(
                    &intermediate, out_file,
                    &config.compression,
                    &config.overview_resampling,
                ).ok();
                std::fs::remove_file(&intermediate).ok();
            }
        }
    }
}