//! eorst 1.0.1
//!
//! Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines.
//! Module to handle async I/O for raster datasets using async-tiff.
//!
//! This module provides async variants of the read functions,
//! allowing direct reading from S3 via async-tiff instead of GDAL VSI.

use crate::core_types::RasterData;
use crate::core_types::RasterType;
use crate::metadata::RasterDataBlock;
use crate::rasterdataset::RasterDataset;
use async_tiff::decoder::DecoderRegistry;
use async_tiff::metadata::cache::ReadaheadMetadataCache;
use async_tiff::metadata::TiffMetadataReader;
use async_tiff::reader::ObjectReader;
use async_tiff::TIFF;
use ndarray::{s, Array2, Array4};
use object_store::aws::AmazonS3Builder;


/// AWS region for the public DEA (Digital Earth Australia) S3 buckets.
/// All stores built by `build_s3_store` are pinned to this region.
const S3_REGION: &str = "ap-southeast-2";

/// Build an `object_store` S3 client for the given bucket.
///
/// Requests are unsigned (anonymous access, `with_skip_signature(true)`)
/// and pinned to [`S3_REGION`], matching the public DEA data buckets.
///
/// # Panics
/// Panics if the S3 builder rejects the configuration.
fn build_s3_store(bucket: &str) -> std::sync::Arc<dyn object_store::ObjectStore> {
    std::sync::Arc::new(
        AmazonS3Builder::new()
            .with_bucket_name(bucket)
            .with_region(S3_REGION)
            .with_skip_signature(true)
            .build()
            .expect("failed to build S3 store"),
    )
}

/// Cached async TIFF reader for efficient multi-block reads.
///
/// Opens the TIFF once and reuses the handle for subsequent reads.
struct CachedTiffReader {
    // Parsed TIFF structure; all IFDs are read up front in `open`.
    tiff: TIFF,
    // Byte-range reader over the S3 object, used to fetch tile data.
    reader: ObjectReader,
    // Decoder registry used to decompress fetched tiles in `read_window`.
    decoder: DecoderRegistry,
}

impl CachedTiffReader {
    /// Open a COG from S3 and parse all TIFF metadata once.
    ///
    /// `s3_url` must look like `s3://bucket/key`. The object is accessed
    /// anonymously (see `build_s3_store`); header/IFD reads go through a
    /// readahead cache to minimize S3 round-trips.
    ///
    /// # Errors
    /// Returns an error if the URL cannot be parsed, contains no bucket
    /// (host), or the TIFF header/IFDs cannot be read.
    async fn open(s3_url: &str) -> anyhow::Result<Self> {
        let url = url::Url::parse(s3_url)?;
        let bucket = url
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("No bucket in URL"))?;
        let path = url.path().trim_start_matches('/');

        let store = build_s3_store(bucket);
        let reader = ObjectReader::new(store, path.into());
        let cache = ReadaheadMetadataCache::new(reader.clone());

        let mut meta = TiffMetadataReader::try_open(&cache).await?;
        let ifds = meta.read_all_ifds(&cache).await?;
        let tiff = TIFF::new(ifds, meta.endianness());

        Ok(Self {
            tiff,
            reader,
            decoder: DecoderRegistry::default(),
        })
    }

    /// Read a window from a specific band (1-based index like GDAL).
    ///
    /// Batch-fetches every tile overlapping the window and stitches the
    /// overlaps into a `(height, width)` array, casting pixels to `T` via
    /// `NumCast`.
    ///
    /// # Arguments
    /// * `band_index` - 1-based IFD index (GDAL convention)
    /// * `offset` - `(x_offset, y_offset)` in pixels; must be non-negative
    /// * `window_size` - `(width, height)` in pixels
    ///
    /// # Errors
    /// Returns an error if the offset is negative, the IFD is not tiled,
    /// tile fetching/decoding fails, or tile data is not `UInt16`
    /// (currently the only supported sample format).
    async fn read_window<T: RasterType>(
        &self,
        band_index: usize,
        offset: (isize, isize),
        window_size: (usize, usize),
    ) -> anyhow::Result<Array2<T>> {
        let ifd = &self.tiff.ifds()[band_index - 1];
        // Error instead of panicking on striped (non-tiled) TIFFs: this fn
        // already returns Result, so callers can recover.
        let tile_h = ifd
            .tile_height()
            .ok_or_else(|| anyhow::anyhow!("TIFF IFD is not tiled (no tile height)"))?
            as usize;
        let tile_w = ifd
            .tile_width()
            .ok_or_else(|| anyhow::anyhow!("TIFF IFD is not tiled (no tile width)"))?
            as usize;

        // A negative offset would silently wrap to a huge value under
        // `as usize` and request nonsense tiles — reject it explicitly.
        if offset.0 < 0 || offset.1 < 0 {
            anyhow::bail!("window offset must be non-negative, got {:?}", offset);
        }
        let (x_off, y_off) = (offset.0 as usize, offset.1 as usize);
        let (width, height) = window_size;

        // An empty window needs no tiles; returning early also avoids the
        // `- 1` underflow in the tile-range computation below.
        if width == 0 || height == 0 {
            return Ok(Array2::<T>::zeros((height, width)));
        }

        // Inclusive range of tile indices covering the window.
        let start_ty = y_off / tile_h;
        let start_tx = x_off / tile_w;
        let end_ty = (y_off + height - 1) / tile_h;
        let end_tx = (x_off + width - 1) / tile_w;

        let tile_coords: Vec<(usize, usize)> = (start_ty..=end_ty)
            .flat_map(|ty| (start_tx..=end_tx).map(move |tx| (tx, ty)))
            .collect();

        // Batch-fetch all required tiles in parallel.
        let tiles = ifd.fetch_tiles(&tile_coords, &self.reader).await?;

        // Output buffer; pixels not written below stay zero.
        let mut output = Array2::<T>::zeros((height, width));

        // Stitch each tile's overlap with the window into the output.
        for (tile_idx, tile) in tiles.into_iter().enumerate() {
            let (tx, ty) = tile_coords[tile_idx];
            let array = tile.decode(&self.decoder)?;
            let (typed, shape, _dtype) = array.into_inner();

            // Decoded tile dimensions as (rows, cols).
            let t_h = shape[0];
            let t_w = shape[1];

            // Top-left pixel of this tile in full-image coordinates.
            let tile_pixel_y = ty * tile_h;
            let tile_pixel_x = tx * tile_w;

            // Start of the overlap inside the tile...
            let tile_row_start = y_off.saturating_sub(tile_pixel_y);
            let tile_col_start = x_off.saturating_sub(tile_pixel_x);

            // ...and inside the output window.
            let out_row_start = tile_pixel_y.saturating_sub(y_off);
            let out_col_start = tile_pixel_x.saturating_sub(x_off);

            // Overlap extent, clipped to the window bounds.
            let copy_rows = (out_row_start + (t_h - tile_row_start)).min(height) - out_row_start;
            let copy_cols = (out_col_start + (t_w - tile_col_start)).min(width) - out_col_start;

            if let async_tiff::TypedArray::UInt16(data) = typed {
                let tile_arr = Array2::from_shape_vec((t_h, t_w), data)?;
                let tile_slice = tile_arr.slice(s![
                    tile_row_start..tile_row_start + copy_rows,
                    tile_col_start..tile_col_start + copy_cols
                ]);
                for i in 0..copy_rows {
                    for j in 0..copy_cols {
                        // Pixels that fail the numeric cast are left at zero.
                        if let Some(val) = num_traits::NumCast::from(tile_slice[[i, j]]) {
                            output[[out_row_start + i, out_col_start + j]] = val;
                        }
                    }
                }
            } else {
                // Previously other sample formats were silently skipped,
                // leaving all-zero output — surface that as an error instead.
                anyhow::bail!("unsupported TIFF sample format: only UInt16 is implemented");
            }
        }

        Ok(output)
    }
}

/// Read a raster band window using async-tiff (S3 direct, no GDAL).
///
/// This is the async equivalent of `gdal_utils::read_raster_band()`.
/// Note that a fresh `CachedTiffReader` is opened for every call; reader
/// reuse across blocks happens in the dataset-level methods instead.
///
/// # Arguments
/// * `s3_url` - S3 URL like "s3://bucket/path/to/file.tif"
/// * `band_index` - Band index (1-based, like GDAL)
/// * `offset` - (x_offset, y_offset) in pixels
/// * `window_size` - (width, height) in pixels
///
/// # Errors
/// Propagates any error from opening the TIFF or reading the window.
pub async fn read_raster_band_async<T: RasterType>(
    s3_url: &str,
    band_index: usize,
    offset: (isize, isize),
    window_size: (usize, usize),
) -> anyhow::Result<Array2<T>> {
    CachedTiffReader::open(s3_url)
        .await?
        .read_window(band_index, offset, window_size)
        .await
}

/// Reads a block of raster data using async-tiff.
///
/// This is the async equivalent of `io::read_block()`.
/// Reads from S3 directly using async-tiff instead of opening via GDAL.
impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Async variant of `read_block()` using async-tiff for S3 access.
    ///
    /// Returns `RasterData<T>` (4D array: times × layers × rows × cols).
    pub async fn read_block_async<T: RasterType>(
        &self,
        block_id: usize,
    ) -> RasterData<T> {
        // Convert VSI paths to S3 URLs
        let s3_urls: Vec<String> = self.metadata.layers
            .iter()
            .map(|layer| vsi_to_s3_url(layer.source.to_str().unwrap_or_default()))
            .collect();
        
        self.read_block_async_with_urls::<T>(&s3_urls, block_id).await
    }

    /// Async read block using pre-computed S3 URLs.
    ///
    /// Caches TIFF handles for efficiency across multiple block reads.
    pub async fn read_block_async_with_urls<T: RasterType>(
        &self,
        s3_urls: &[String],
        block_id: usize,
    ) -> RasterData<T> {
        let block = &self.blocks[block_id];
        let read_window = block.read_window;

        let rows = read_window.size.rows as usize;
        let cols = read_window.size.cols as usize;
        let data_shape = (
            self.metadata.shape.times,
            self.metadata.shape.layers,
            rows,
            cols,
        );

        let mut data: RasterData<T> = RasterData::zeros(data_shape);

        // Cache readers in a HashMap to avoid reopening
        let mut readers: std::collections::HashMap<usize, CachedTiffReader> = std::collections::HashMap::new();
        
        for (idx, layer) in self.metadata.layers.iter().enumerate() {
            // Open reader if not already cached
            if !readers.contains_key(&idx) {
                let reader = CachedTiffReader::open(&s3_urls[idx]).await;
                if let Ok(r) = reader {
                    readers.insert(idx, r);
                } else {
                    continue;
                }
            }
            
            let reader = readers.get(&idx).unwrap();
            let window = (read_window.offset.cols, read_window.offset.rows);
            let window_size = (cols, rows);

            let layer_data = reader.read_window::<T>(1, window, window_size).await;
            // For now, unwrap — production code should handle errors better
            let layer_data = layer_data.expect("async read failed");

            let slice = s![
                layer.time_pos,
                layer.layer_pos,
                ..,
                ..
            ];
            data.slice_mut(slice).assign(&layer_data);
        }
        data
    }

    /// Async variant of `apply()` using async-tiff for S3 access.
    ///
    /// Applies a worker function to each block asynchronously.
    /// This is the async equivalent of `processing::apply()`.
    /// Reads from S3 directly using async-tiff (COG-aware, parallel tile fetching).
    ///
    /// # Arguments
    /// * `worker` - Function that processes a `RasterDataBlock<T>` and returns `Array4<U>`
    /// * `n_cpus` - Number of parallel workers (used for writing phase)
    /// * `out_file` - Output file path
    pub async fn apply_async<U>(
        &self,
        worker: fn(&RasterDataBlock<R>) -> anyhow::Result<Array4<U>>,
        n_cpus: usize,
        out_file: &std::path::Path,
    ) -> anyhow::Result<()>
    where
        U: RasterType,
    {
        // Convert VSI paths to S3 URLs
        let s3_urls: Vec<String> = self.metadata.layers
            .iter()
            .map(|layer| vsi_to_s3_url(layer.source.to_str().unwrap_or_default()))
            .collect();
        
        self.apply_async_with_urls(&s3_urls, worker, n_cpus, out_file).await
    }

    /// Async variant that accepts pre-computed S3 URLs.
    ///
    /// Use this when you have the S3 URLs directly (e.g., from STAC assets).
    pub async fn apply_async_with_urls<U>(
        &self,
        s3_urls: &[String],
        worker: fn(&RasterDataBlock<R>) -> anyhow::Result<Array4<U>>,
        n_cpus: usize,
        out_file: &std::path::Path,
    ) -> anyhow::Result<()>
    where
        U: RasterType,
    {
        use crate::gdal_utils::{create_temp_file, file_stem_str, mosaic_translate_cleanup_time_steps};
        use num_traits::NumCast;
 
        use std::path::PathBuf;

        let tmp_file = PathBuf::from(create_temp_file("vrt"));
        let n_times = self.metadata.shape.times;
        let epsg_code = self.metadata.epsg_code;

        // Phase 1: Read blocks asynchronously using async-tiff and apply worker
        let block_futures: Vec<_> = self
            .blocks
            .iter()
            .map(|raster_block| {
                let tmp_file_clone = tmp_file.clone();
                let s3_urls_clone = s3_urls.to_vec();
                async move {
                    let bid = raster_block.block_index;
                    let file_stem = file_stem_str(&tmp_file_clone);

                    // Read block using async-tiff (S3 direct)
                    let block_data: RasterData<R> = self.read_block_async_with_urls::<R>(&s3_urls_clone, bid).await;

                    // Create RasterDataBlock with metadata
                    let raster_data_block = RasterDataBlock {
                        data: block_data,
                        metadata: self.metadata.clone(),
                        no_data: NumCast::from(0).unwrap(),
                    };

                    // Apply worker function
                    let result = worker(&raster_data_block)?;

                    // Write block to temp files
                    let block_fns = raster_block.write_time_step_blocks(
                        &result,
                        &tmp_file_clone,
                        file_stem,
                        bid,
                    );

                    anyhow::Result::Ok(block_fns)
                }
            })
            .collect();

        // Execute all block processing concurrently
        let collected: Vec<anyhow::Result<Vec<PathBuf>>> =
            futures::future::join_all(block_futures).await;

        let collected: Vec<Vec<PathBuf>> = collected
            .into_iter()
            .collect::<anyhow::Result<_>>()?;

        // Phase 2: Assemble time steps in parallel (use rayon for writing)
        let pool = crate::gdal_utils::create_rayon_pool(n_cpus);
        pool.install(|| {
            mosaic_translate_cleanup_time_steps(&collected, out_file, epsg_code, n_times);
        });

        Ok(())
    }
}

/// Converts a GDAL VSI path to an S3 URL.
///
/// Examples:
/// * `/vsis3/dea-public-data/x.tif` → `s3://dea-public-data/x.tif`
/// * `/vsicurl/https://bucket.s3.region.amazonaws.com/x.tif` → `s3://bucket/x.tif`
/// * `s3://...` or anything unrecognised → returned unchanged
fn vsi_to_s3_url(vsi_path: &str) -> String {
    // Direct S3 VSI paths: swap the prefix for the s3 scheme.
    if let Some(rest) = vsi_path.strip_prefix("/vsis3/") {
        return format!("s3://{}", rest);
    }

    // HTTP VSI paths: recover the bucket from the virtual-hosted hostname
    // (first dotted component, e.g. "dea-public-data" in
    // "dea-public-data.s3.ap-southeast-2.amazonaws.com").
    if let Some(http_url) = vsi_path.strip_prefix("/vsicurl/") {
        let converted = url::Url::parse(http_url).ok().and_then(|url| {
            let bucket = url.host_str()?.split('.').next()?.to_string();
            let path = url.path().trim_start_matches('/').to_string();
            Some(format!("s3://{}/{}", bucket, path))
        });
        // Fall back to the raw URL when it cannot be parsed.
        return converted.unwrap_or_else(|| http_url.to_string());
    }

    // Already an s3:// URL, or an unknown scheme: pass through untouched.
    vsi_path.to_string()
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio;

    // Integration test: performs a real unsigned read against the public
    // DEA S3 bucket, so it requires network access and will fail offline
    // or if the referenced object is moved/removed.
    #[tokio::test]
    async fn test_read_raster_band_async() {
        // Test reading a small block from DEA Sentinel-2 COG
        let s3_url = "s3://dea-public-data/baseline/ga_s2bm_ard_3/56/JNS/2021/01/15/20210116T010541/ga_s2bm_nbart_3-2-1_56JNS_2021-01-15_final_band04.tif";
        let result = read_raster_band_async::<u16>(s3_url, 1, (0, 0), (512, 512)).await;
        assert!(result.is_ok(), "Failed to read raster band async: {:?}", result.err());
        let data = result.unwrap();
        // read_window returns (rows, cols) = (height, width).
        assert_eq!(data.shape(), &[512, 512]);
    }
}
}