eorst 1.0.1

Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! Parallel GeoTIFF writer inspired by Dask/rioxarray's RasterioWriter.
//!
//! Pre-creates an output GeoTIFF and allows multiple threads to write
//! their blocks directly to windows in the file, serialized by a Mutex.
//!
//! Unlike the rioxarray approach which opens/closes per chunk, this writer
//! keeps the GDAL dataset open across writes (protected by a Mutex),
//! avoiding repeated header parsing overhead.
//!
//! This replaces the mosaic phase (gdalwarp → gdalbuildvrt → gdal_translate
//! subprocess chain) with direct windowed writes, eliminating subprocess
//! spawning and intermediate files.

use crate::core_types::RasterType;
use crate::types::{GeoTransform, ReadWindow};
use anyhow::{Context, Result};
use gdal::{
    config, Dataset, DatasetOptions, DriverManager, GdalOpenFlags,
    raster::{Buffer, RasterCreationOptions},
    spatial_ref::SpatialRef,
};
use std::path::{Path, PathBuf};
use std::sync::Mutex;

/// Writer that allows parallel block writes to a pre-created GeoTIFF.
///
/// The GDAL dataset handle is stored in `dataset` behind a `Mutex`. The slot
/// starts out as `None` (see `new`) and is filled by the first operation that
/// needs the file (`write_block()` or `build_overviews()`); later calls reuse
/// the same handle. Each of those operations holds the lock for its whole
/// duration, so accesses to the dataset are serialized.
///
/// Nothing here closes the dataset explicitly: the handle is released when
/// the writer, and with it the `Option<Dataset>`, is dropped.
///
/// NOTE(review): the writer only ever *opens* `output_path` in update mode, so
/// the file presumably has to exist already (e.g. via `create_output_geotiff`)
/// before the first write — confirm against callers.
pub struct ParallelGeoTiffWriter {
    /// Path to the output GeoTIFF file; this is the path opened in update mode.
    pub output_path: PathBuf,
    /// Geographic transformation parameters.
    /// Stored as given; not consulted by `write_block()` or `build_overviews()`.
    pub geo_transform: GeoTransform,
    /// EPSG coordinate reference system code.
    /// Stored as given; not consulted by `write_block()` or `build_overviews()`.
    pub epsg_code: u32,
    /// Total image width in pixels (number of columns).
    pub total_cols: usize,
    /// Total image height in pixels (number of rows).
    pub total_rows: usize,
    /// Number of bands in the output
    pub n_bands: usize,
    /// Mutex-guarded cached GDAL dataset.
    /// `None` until the first `write_block()` / `build_overviews()` call opens the file.
    pub dataset: Mutex<Option<Dataset>>,
}

impl ParallelGeoTiffWriter {
    /// Creates a new writer for the given output path and parameters.
    ///
    /// No I/O is performed: the arguments are stored unchanged and the dataset
    /// slot starts empty. The file at `output_path` is opened lazily by the
    /// first call that needs it.
    pub fn new(
        output_path: PathBuf,
        geo_transform: GeoTransform,
        epsg_code: u32,
        total_cols: usize,
        total_rows: usize,
        n_bands: usize,
    ) -> Self {
        Self {
            output_path,
            geo_transform,
            epsg_code,
            total_cols,
            total_rows,
            n_bands,
            dataset: Mutex::new(None),
        }
    }

    /// Builds overviews on the written GeoTIFF.
    ///
    /// This must be called after all `write_block()` calls are complete.
    /// The cached dataset handle is reused when one exists; otherwise the file
    /// is opened in update mode and the handle is cached. The writer's mutex is
    /// held for the whole call.
    ///
    /// Before building, the GDAL config option `GDAL_NUM_THREADS` is set to
    /// `ALL_CPUS`. The option is not reset afterwards, so it stays in effect
    /// once this method returns.
    ///
    /// # Arguments
    /// * `resampling` - Resampling method, e.g. "CUBIC", "NEAREST", "AVERAGE"
    /// * `levels` - Overview decimation factors, e.g. `[2, 4, 8, 16, 32]`
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened for update, if the config
    /// option cannot be set, or if GDAL fails to build the overviews.
    ///
    /// # Panics
    /// Panics if the dataset mutex is poisoned.
    pub fn build_overviews(&self, resampling: &str, levels: &[i32]) -> Result<()> {
        let mut slot = self.dataset.lock().expect("Mutex poisoned");

        // Reuse the cached handle, or open the file once and cache it. A single
        // `match` avoids the separate `is_none()` check + `unwrap()` pair.
        let dataset = match slot.as_mut() {
            Some(cached) => cached,
            None => {
                let opts = DatasetOptions {
                    open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
                    ..DatasetOptions::default()
                };
                let opened = Dataset::open_ex(&self.output_path, opts).with_context(|| {
                    format!("Failed to open {:?} for overview building", self.output_path)
                })?;
                slot.insert(opened)
            }
        };

        // Enable multithreaded overview building — uses all available CPUs
        config::set_config_option("GDAL_NUM_THREADS", "ALL_CPUS")
            .context("Failed to set GDAL_NUM_THREADS")?;

        // An empty band list is passed straight through to GDAL.
        dataset
            .build_overviews(resampling, levels, &[])
            .with_context(|| format!("Failed to build overviews on {:?}", self.output_path))
    }
}

/// Pre-creates a GeoTIFF file with the given parameters.
///
/// Steps, in order:
/// 1. creates the parent directory of `path` (and any missing ancestors);
/// 2. creates the raster with the driver looked up as `"GTIFF"`, using the
///    creation options `COMPRESS=LZW`, `TILED=YES`, `BLOCKXSIZE=512`,
///    `BLOCKYSIZE=512` and `BIGTIFF=YES`, with band type `T`;
/// 3. sets the geotransform and the spatial reference derived from `epsg_code`;
/// 4. sets the no-data value on every band, provided `na_value` converts to `f64`.
///    If the conversion yields `None`, the bands are left without a no-data value.
///
/// The dataset handle is local to this function and goes out of scope on return.
///
/// # Arguments
/// * `path` - Destination file path
/// * `geo_transform` - Geotransform applied to the new dataset
/// * `epsg_code` - EPSG code used to build the spatial reference
/// * `total_cols`, `total_rows` - Raster size in pixels
/// * `n_bands` - Number of bands to create
/// * `na_value` - No-data value applied to each band
///
/// # Errors
/// Returns an error if the parent directory cannot be created, the driver is
/// unavailable, the file cannot be created, the EPSG code is rejected, or any
/// of the geotransform / spatial reference / band / no-data calls fails.
pub fn create_output_geotiff<T: RasterType>(
    path: &Path,
    geo_transform: &GeoTransform,
    epsg_code: u32,
    total_cols: usize,
    total_rows: usize,
    n_bands: usize,
    na_value: T,
) -> Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("Failed to create parent directory for {:?}", path))?;
    }

    let driver = DriverManager::get_driver_by_name("GTIFF")
        .context("GTiff driver not available")?;

    let options = RasterCreationOptions::from_iter([
        "COMPRESS=LZW",
        "TILED=YES",
        "BLOCKXSIZE=512",
        "BLOCKYSIZE=512",
        "BIGTIFF=YES",
    ]);

    let mut dataset = driver
        .create_with_band_type_with_options::<T, _>(
            path,
            total_cols,
            total_rows,
            n_bands,
            &options,
        )
        .with_context(|| format!("Failed to create GeoTIFF at {:?}", path))?;

    dataset
        .set_geo_transform(&geo_transform.to_array())
        .context("Failed to set geo transform")?;

    // `with_context` builds the message only on failure; `context(format!(..))`
    // would allocate it on every successful call as well.
    let srs = SpatialRef::from_epsg(epsg_code)
        .with_context(|| format!("Invalid EPSG code: {}", epsg_code))?;
    dataset
        .set_spatial_ref(&srs)
        .context("Failed to set spatial reference")?;

    // The conversion does not depend on the band, so do it once up front.
    let na_f64 = na_value.to_f64();

    for band_idx in 1..=n_bands {
        // Bands are 1-indexed. Each band is still fetched even when there is no
        // no-data value to set, so a band access failure is reported either way.
        let mut band = dataset
            .rasterband(band_idx)
            .with_context(|| format!("Failed to access band {}", band_idx))?;
        if let Some(na) = na_f64 {
            band.set_no_data_value(Some(na))
                .context("Failed to set no-data value")?;
        }
    }

    Ok(())
}

/// Writes a single block's data (all bands) to the pre-created GeoTIFF.
///
/// This function acquires the writer's mutex and holds it for the whole call.
/// If the writer has no cached dataset yet, the file at `writer.output_path`
/// is opened in update mode and the handle is cached for later calls.
/// Then each band of `data` is written, in order, to output bands `1..=bands`.
///
/// The written region starts at `(window.offset.cols, window.offset.rows)` and
/// its size is taken from the shape of `data`.
///
/// # Arguments
/// * `writer` - Writer holding the output path and the cached dataset
/// * `data` - 3D array with shape (bands, rows, cols)
/// * `window` - The read window defining where this block belongs in the output
///
/// # Errors
/// Returns an error if the file cannot be opened for update, a band cannot be
/// accessed, or a band write fails. Bands written before the failing one are
/// not rolled back.
///
/// # Panics
/// Panics if the dataset mutex is poisoned.
pub fn write_block<T: RasterType>(
    writer: &ParallelGeoTiffWriter,
    data: ndarray::ArrayView3<T>,
    window: ReadWindow,
) -> Result<()> {
    let mut slot = writer.dataset.lock().expect("Mutex poisoned");

    // Reuse the cached handle, or open the file once and cache it. A single
    // `match` avoids the separate `is_none()` check + `unwrap()` pair.
    let dataset = match slot.as_mut() {
        Some(cached) => cached,
        None => {
            let opts = DatasetOptions {
                open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
                ..DatasetOptions::default()
            };
            let opened = Dataset::open_ex(&writer.output_path, opts)
                .with_context(|| format!("Failed to open {:?} for update", writer.output_path))?;
            slot.insert(opened)
        }
    };

    // Axis order is (bands, rows, cols); the band count is implied by the iteration below.
    let (_, block_rows, block_cols) = data.dim();

    for (band_idx, band_data) in data.outer_iter().enumerate() {
        // Raster bands are 1-indexed.
        let band_no = band_idx + 1;
        let mut band = dataset
            .rasterband(band_no)
            .with_context(|| format!("Failed to access band {}", band_no))?;

        // Copy the 2D slice in logical (row-major) order into an owned buffer;
        // this is correct even when the view is not contiguous in memory.
        let data_vec: Vec<T> = band_data.iter().copied().collect();
        // Sizes are given as (cols, rows), i.e. (x, y).
        let mut buffer = Buffer::new((block_cols, block_rows), data_vec);

        band.write(
            (window.offset.cols, window.offset.rows),
            (block_cols, block_rows),
            &mut buffer,
        )
        .with_context(|| {
            format!(
                "Failed to write band {} to window {:?}",
                band_no,
                window
            )
        })?;
    }

    Ok(())
}