use crate::core_types::RasterData;
use crate::core_types::RasterType;
use crate::metadata::RasterDataBlock;
use crate::rasterdataset::RasterDataset;
use async_tiff::decoder::DecoderRegistry;
use async_tiff::metadata::cache::ReadaheadMetadataCache;
use async_tiff::metadata::TiffMetadataReader;
use async_tiff::reader::ObjectReader;
use async_tiff::TIFF;
use ndarray::{s, Array2, Array4};
use object_store::aws::AmazonS3Builder;
// AWS region used for every S3 store built by `build_s3_store` (anonymous reads).
const S3_REGION: &str = "ap-southeast-2";
/// Builds an unsigned (anonymous) S3 object store for `bucket` in `S3_REGION`.
///
/// # Panics
/// Panics if the store configuration is rejected by `object_store`.
fn build_s3_store(bucket: &str) -> std::sync::Arc<dyn object_store::ObjectStore> {
    // Skip request signing: the target buckets allow anonymous access.
    let builder = AmazonS3Builder::new()
        .with_bucket_name(bucket)
        .with_region(S3_REGION)
        .with_skip_signature(true);
    std::sync::Arc::new(builder.build().expect("failed to build S3 store"))
}
/// An open remote TIFF: parsed metadata plus the reader and decoder
/// registry needed to fetch and decode individual tiles on demand.
struct CachedTiffReader {
    tiff: TIFF,               // parsed IFDs + endianness
    reader: ObjectReader,     // range-reads against the backing S3 object
    decoder: DecoderRegistry, // codec lookup used to decompress tiles
}
impl CachedTiffReader {
    /// Opens the TIFF at `s3_url` ("s3://bucket/key") and eagerly reads all
    /// IFD metadata through a readahead cache, so later window reads only
    /// touch pixel data.
    async fn open(s3_url: &str) -> anyhow::Result<Self> {
        let url = url::Url::parse(s3_url)?;
        let bucket = url
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("No bucket in URL"))?;
        let path = url.path().trim_start_matches('/');
        let store = build_s3_store(bucket);
        let reader = ObjectReader::new(store, path.into());
        // The readahead cache coalesces the many small metadata reads into
        // fewer, larger range requests.
        let cache = ReadaheadMetadataCache::new(reader.clone());
        let mut meta = TiffMetadataReader::try_open(&cache).await?;
        let ifds = meta.read_all_ifds(&cache).await?;
        let tiff = TIFF::new(ifds, meta.endianness());
        Ok(Self {
            tiff,
            reader,
            decoder: DecoderRegistry::default(),
        })
    }

    /// Reads a `window_size` = (width, height) pixel window whose top-left
    /// corner is `offset` = (x, y), from the 1-based IFD `band_index`.
    ///
    /// # Errors
    /// Returns an error for negative offsets, an out-of-range band index,
    /// tile fetch/decode failures, or tiles whose samples are not `u16`
    /// (previously non-u16 tiles were silently left as zeros in the output).
    ///
    /// # Panics
    /// Panics if the selected IFD is not tiled.
    async fn read_window<T: RasterType>(
        &self,
        band_index: usize,
        offset: (isize, isize),
        window_size: (usize, usize),
    ) -> anyhow::Result<Array2<T>> {
        // Reject negative offsets up front: the previous `as usize` cast
        // silently wrapped them into huge indices.
        if offset.0 < 0 || offset.1 < 0 {
            anyhow::bail!("window offset must be non-negative, got {:?}", offset);
        }
        let ifds = self.tiff.ifds();
        // `band_index` is 1-based; validate instead of panicking on the
        // slice index below (band_index == 0 would underflow).
        if band_index == 0 || band_index > ifds.len() {
            anyhow::bail!("band index {} out of range 1..={}", band_index, ifds.len());
        }
        let ifd = &ifds[band_index - 1];
        let tile_h = ifd.tile_height().expect("not tiled") as usize;
        let tile_w = ifd.tile_width().expect("not tiled") as usize;
        let (x_off, y_off) = (offset.0 as usize, offset.1 as usize);
        let (width, height) = window_size;
        // An empty window covers no tiles; the end-tile arithmetic below
        // would underflow, so return the empty array early.
        if width == 0 || height == 0 {
            return Ok(Array2::<T>::zeros((height, width)));
        }
        // Inclusive range of tile indices covering the window.
        let start_ty = y_off / tile_h;
        let start_tx = x_off / tile_w;
        let end_ty = (y_off + height - 1) / tile_h;
        let end_tx = (x_off + width - 1) / tile_w;
        let tile_coords: Vec<_> = (start_ty..=end_ty)
            .flat_map(|ty| (start_tx..=end_tx).map(move |tx| (tx, ty)))
            .collect();
        // Fetch all covering tiles in one batched request.
        let tiles = ifd.fetch_tiles(&tile_coords, &self.reader).await?;
        let mut output = Array2::<T>::zeros((height, width));
        for (tile_idx, tile) in tiles.into_iter().enumerate() {
            let (tx, ty) = tile_coords[tile_idx];
            let array = tile.decode(&self.decoder)?;
            let (typed, shape, _dtype) = array.into_inner();
            let t_h = shape[0];
            let t_w = shape[1];
            // Pixel coordinates of this tile's top-left corner.
            let tile_pixel_y = ty * tile_h;
            let tile_pixel_x = tx * tile_w;
            // Intersection of tile and window, in tile-local coordinates...
            let tile_row_start = y_off.saturating_sub(tile_pixel_y);
            let tile_col_start = x_off.saturating_sub(tile_pixel_x);
            // ...and in output coordinates.
            let out_row_start = tile_pixel_y.saturating_sub(y_off);
            let out_col_start = tile_pixel_x.saturating_sub(x_off);
            let copy_rows = (out_row_start + (t_h - tile_row_start)).min(height) - out_row_start;
            let copy_cols = (out_col_start + (t_w - tile_col_start)).min(width) - out_col_start;
            match typed {
                async_tiff::TypedArray::UInt16(data) => {
                    let tile_arr = Array2::from_shape_vec((t_h, t_w), data)?;
                    let tile_slice = tile_arr.slice(s![
                        tile_row_start..tile_row_start + copy_rows,
                        tile_col_start..tile_col_start + copy_cols
                    ]);
                    for i in 0..copy_rows {
                        for j in 0..copy_cols {
                            // NumCast keeps the output generic over RasterType;
                            // values that don't fit T are left as zero.
                            if let Some(val) = num_traits::NumCast::from(tile_slice[[i, j]]) {
                                output[[out_row_start + i, out_col_start + j]] = val;
                            }
                        }
                    }
                }
                // Previously other sample types fell through silently,
                // leaving zeros in the output window.
                _ => anyhow::bail!("unsupported TIFF sample type; only u16 tiles are handled"),
            }
        }
        Ok(output)
    }
}
/// Convenience entry point: opens the TIFF at `s3_url` and reads a single
/// `window_size` = (width, height) window at pixel `offset` = (x, y) from
/// the 1-based `band_index`.
pub async fn read_raster_band_async<T: RasterType>(
    s3_url: &str,
    band_index: usize,
    offset: (isize, isize),
    window_size: (usize, usize),
) -> anyhow::Result<Array2<T>> {
    let tiff = CachedTiffReader::open(s3_url).await?;
    tiff.read_window(band_index, offset, window_size).await
}
impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Reads block `block_id` across all layers, deriving each layer's
    /// "s3://" URL from its GDAL VSI source path via `vsi_to_s3_url`.
    pub async fn read_block_async<T: RasterType>(
        &self,
        block_id: usize,
    ) -> RasterData<T> {
        let s3_urls: Vec<String> = self.metadata.layers
            .iter()
            .map(|layer| vsi_to_s3_url(layer.source.to_str().unwrap_or_default()))
            .collect();
        self.read_block_async_with_urls::<T>(&s3_urls, block_id).await
    }
    /// Reads block `block_id` using one URL per layer (`s3_urls` is parallel
    /// to `self.metadata.layers`) into a (times, layers, rows, cols) array.
    ///
    /// NOTE(review): a layer whose reader fails to open is skipped silently,
    /// leaving zeros in its slice, while a failed window read panics via
    /// `expect` — consider propagating both as errors.
    pub async fn read_block_async_with_urls<T: RasterType>(
        &self,
        s3_urls: &[String],
        block_id: usize,
    ) -> RasterData<T> {
        let block = &self.blocks[block_id];
        let read_window = block.read_window;
        let rows = read_window.size.rows as usize;
        let cols = read_window.size.cols as usize;
        let data_shape = (
            self.metadata.shape.times,
            self.metadata.shape.layers,
            rows,
            cols,
        );
        let mut data: RasterData<T> = RasterData::zeros(data_shape);
        // One reader per layer index. NOTE(review): each index is visited
        // exactly once, so the `contains_key` guard never reuses a reader;
        // keying by URL would deduplicate repeated sources.
        let mut readers: std::collections::HashMap<usize, CachedTiffReader> = std::collections::HashMap::new();
        for (idx, layer) in self.metadata.layers.iter().enumerate() {
            if !readers.contains_key(&idx) {
                let reader = CachedTiffReader::open(&s3_urls[idx]).await;
                if let Ok(r) = reader {
                    readers.insert(idx, r);
                } else {
                    // Open failed: this layer's slice stays zero-filled.
                    continue;
                }
            }
            let reader = readers.get(&idx).unwrap();
            // `read_window` takes (x, y), i.e. (cols, rows) offsets.
            let window = (read_window.offset.cols, read_window.offset.rows);
            let window_size = (cols, rows);
            // Band index 1 is hard-coded — assumes single-band layer files;
            // TODO confirm against the dataset sources.
            let layer_data = reader.read_window::<T>(1, window, window_size).await;
            let layer_data = layer_data.expect("async read failed");
            // Place the 2-D window at this layer's (time, layer) position.
            let slice = s![
                layer.time_pos,
                layer.layer_pos,
                ..,
                ..
            ];
            data.slice_mut(slice).assign(&layer_data);
        }
        data
    }
    /// Convenience wrapper around `apply_async_with_urls` that derives the
    /// per-layer URLs from each layer's GDAL VSI source path.
    pub async fn apply_async<U>(
        &self,
        worker: fn(&RasterDataBlock<R>) -> anyhow::Result<Array4<U>>,
        n_cpus: usize,
        out_file: &std::path::Path,
    ) -> anyhow::Result<()>
    where
        U: RasterType,
    {
        let s3_urls: Vec<String> = self.metadata.layers
            .iter()
            .map(|layer| vsi_to_s3_url(layer.source.to_str().unwrap_or_default()))
            .collect();
        self.apply_async_with_urls(&s3_urls, worker, n_cpus, out_file).await
    }
    /// For every block: reads its data, runs `worker` on it, and writes the
    /// per-time-step results to temporary files; finally mosaics/translates
    /// the pieces into `out_file` on a rayon pool of `n_cpus` threads.
    ///
    /// # Errors
    /// Fails if any block's worker or write step fails (checked after all
    /// block futures have completed).
    pub async fn apply_async_with_urls<U>(
        &self,
        s3_urls: &[String],
        worker: fn(&RasterDataBlock<R>) -> anyhow::Result<Array4<U>>,
        n_cpus: usize,
        out_file: &std::path::Path,
    ) -> anyhow::Result<()>
    where
        U: RasterType,
    {
        use crate::gdal_utils::{create_temp_file, file_stem_str, mosaic_translate_cleanup_time_steps};
        use num_traits::NumCast;
        use std::path::PathBuf;
        let tmp_file = PathBuf::from(create_temp_file("vrt"));
        let n_times = self.metadata.shape.times;
        let epsg_code = self.metadata.epsg_code;
        // One future per block; all of them are awaited concurrently below.
        let block_futures: Vec<_> = self
            .blocks
            .iter()
            .map(|raster_block| {
                let tmp_file_clone = tmp_file.clone();
                let s3_urls_clone = s3_urls.to_vec();
                async move {
                    let bid = raster_block.block_index;
                    let file_stem = file_stem_str(&tmp_file_clone);
                    let block_data: RasterData<R> = self.read_block_async_with_urls::<R>(&s3_urls_clone, bid).await;
                    let raster_data_block = RasterDataBlock {
                        data: block_data,
                        metadata: self.metadata.clone(),
                        no_data: NumCast::from(0).unwrap(),
                    };
                    // NOTE(review): `worker` runs on the async executor; a
                    // CPU-heavy worker will stall other block futures.
                    let result = worker(&raster_data_block)?;
                    let block_fns = raster_block.write_time_step_blocks(
                        &result,
                        &tmp_file_clone,
                        file_stem,
                        bid,
                    );
                    anyhow::Result::Ok(block_fns)
                }
            })
            .collect();
        // Drive all block futures to completion concurrently.
        let collected: Vec<anyhow::Result<Vec<PathBuf>>> =
            futures::future::join_all(block_futures).await;
        // Short-circuit on the first per-block error.
        let collected: Vec<Vec<PathBuf>> = collected
            .into_iter()
            .collect::<anyhow::Result<_>>()?;
        let pool = crate::gdal_utils::create_rayon_pool(n_cpus);
        pool.install(|| {
            mosaic_translate_cleanup_time_steps(&collected, out_file, epsg_code, n_times);
        });
        Ok(())
    }
}
/// Converts a GDAL VSI path ("/vsis3/…" or "/vsicurl/https://bucket.s3…")
/// to a plain "s3://bucket/key" URL. Anything unrecognised — including
/// paths that are already "s3://" URLs — is returned unchanged.
fn vsi_to_s3_url(vsi_path: &str) -> String {
    if let Some(stripped) = vsi_path.strip_prefix("/vsis3/") {
        return format!("s3://{}", stripped);
    }
    if let Some(stripped) = vsi_path.strip_prefix("/vsicurl/") {
        // Virtual-hosted-style URL: the bucket is the first label of the host.
        if let Ok(url) = url::Url::parse(stripped) {
            if let Some(host) = url.host_str() {
                if let Some(bucket) = host.split('.').next() {
                    let path = url.path().trim_start_matches('/');
                    return format!("s3://{}/{}", bucket, path);
                }
            }
        }
        // Unparseable URL: return it with the VSI prefix removed.
        return stripped.to_string();
    }
    // Already an s3:// URL, a local path, etc. — pass through unchanged.
    // (The original had a separate `starts_with("s3://")` branch here that
    // returned the same value as this fallthrough; it was redundant.)
    vsi_path.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio;
    // Integration test: requires network access to the public
    // dea-public-data bucket (anonymous reads, ap-southeast-2).
    #[tokio::test]
    async fn test_read_raster_band_async() {
        let s3_url = "s3://dea-public-data/baseline/ga_s2bm_ard_3/56/JNS/2021/01/15/20210116T010541/ga_s2bm_nbart_3-2-1_56JNS_2021-01-15_final_band04.tif";
        // Read a 512x512 window of band 1 starting at the image origin.
        let result = read_raster_band_async::<u16>(s3_url, 1, (0, 0), (512, 512)).await;
        assert!(result.is_ok(), "Failed to read raster band async: {:?}", result.err());
        let data = result.unwrap();
        // Output shape is (height, width).
        assert_eq!(data.shape(), &[512, 512]);
    }
}