use crate::metadata::Extent;
use crate::types::{BlockSize, GeoTransform, ImageResolution};
use crate::core_types::RasterType;
use anyhow::Result;
use anyhow::Context;
use gdal::{Dataset, DatasetOptions, DriverManager, GdalOpenFlags};
use gdal::{raster::GdalDataType, raster::RasterCreationOptions, spatial_ref::{CoordTransform, SpatialRef}};
use gdal::vector::{FieldValue, LayerAccess};
use kdam::par_tqdm;
use kdam::rayon::prelude::*;
use std::env;
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::process::Command;
use uuid::Uuid;
/// Builds a rayon thread pool with `n_cpus` worker threads.
///
/// Before the pool is built, the `RAYON_NUM_THREADS` environment variable is set to the
/// same value.
///
/// # Panics
///
/// Panics if the thread pool cannot be built.
pub fn create_rayon_pool(n_cpus: usize) -> rayon::ThreadPool {
    let thread_count = n_cpus.to_string();
    // SAFETY: NOTE(review) — mutating the process environment is only sound while no other
    // thread accesses it concurrently; confirm this runs before worker threads are spawned.
    unsafe { std::env::set_var("RAYON_NUM_THREADS", thread_count) };

    let builder = rayon::ThreadPoolBuilder::new().num_threads(n_cpus);
    builder
        .build()
        .expect("Failed to create rayon thread pool")
}
/// Returns the file stem of `path` (file name without its last extension) as a `&str`.
///
/// # Panics
///
/// Panics if `path` has no file stem or if the stem is not valid UTF-8.
pub fn file_stem_str(path: &Path) -> &str {
    let stem = path.file_stem().expect("Path has no file stem");
    stem.to_str().expect("File stem is not valid UTF-8")
}
/// Opens the dataset at `path` with the `GDAL_OF_UPDATE` open flag.
///
/// All other dataset options are left at their defaults.
///
/// # Panics
///
/// Panics if the dataset cannot be opened.
pub fn open_for_update(path: &Path) -> Dataset {
    Dataset::open_ex(
        path,
        DatasetOptions {
            open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
            ..Default::default()
        },
    )
    .expect("Failed to open dataset for update")
}
/// Writes every 2-D slice along the first axis of `data` into the matching band of `out_ds`.
///
/// Slice `i` (0-based) of `data` is written to raster band `i + 1`. Each slice is copied
/// into a buffer of size `write_size` and written at `write_offset`.
///
/// # Panics
///
/// Panics if a band cannot be fetched from `out_ds` or if writing a band fails.
pub fn write_bands_to_file<T: RasterType>(
    out_ds: &Dataset,
    data: ndarray::ArrayView3<T>,
    write_offset: (isize, isize),
    write_size: (usize, usize),
) {
    use gdal::raster::Buffer;

    for (layer_idx, plane) in data.outer_iter().enumerate() {
        // GDAL band numbers start at 1.
        let mut target_band = out_ds
            .rasterband(layer_idx + 1)
            .expect("Failed to get raster band");
        let pixels: Vec<T> = plane.iter().copied().collect();
        let mut buffer = Buffer::new(write_size, pixels);
        target_band
            .write(write_offset, write_size, &mut buffer)
            .expect("Failed to write band");
    }
}
/// Runs an external command and blocks until it exits.
///
/// `argv[0]` is the program to execute and the remaining elements are its arguments.
/// The exit status of the child process is not inspected.
///
/// # Panics
///
/// Panics if `argv` is empty, if the process cannot be started, or if waiting on it fails.
pub fn run_gdal_command(argv: &[&str]) {
    let program = argv[0];
    let arguments = &argv[1..];

    let mut child = Command::new(program)
        .args(arguments)
        .spawn()
        .expect("failed to start gdal command");
    child.wait().expect("failed to wait for gdal command");
}
/// Reads a window of one raster band into a 2-D array.
///
/// * `raster_path` - raster file to open.
/// * `band_index` - band number passed to GDAL's `rasterband`.
/// * `offset` - window offset passed to `read_as`.
/// * `window_size` - window size; the output buffer has the same size, so no resampling
///   algorithm is requested.
///
/// # Panics
///
/// Panics if the dataset cannot be opened, the band cannot be fetched, the read fails, or
/// the buffer cannot be converted into an array.
pub fn read_raster_band<T: RasterType>(
    raster_path: &Path,
    band_index: usize,
    offset: (isize, isize),
    window_size: (usize, usize),
) -> ndarray::Array2<T> {
    // Build the panic message lazily so the success path does not allocate a String.
    let ds = Dataset::open(raster_path)
        .unwrap_or_else(|err| panic!("Unable to open {:?}: {:?}", raster_path, err));
    let raster_band = ds.rasterband(band_index).expect("Failed to get raster band");
    raster_band
        .read_as::<T>(offset, window_size, window_size, None)
        .expect("Failed to read raster band")
        .to_array()
        .expect("Failed to convert to array")
}
/// Mosaics `collected` into `tmp_file`, translates that to `out_file`, then deletes
/// `tmp_file` and every input in `collected`.
///
/// # Panics
///
/// Panics if mosaicking or translating returns an error, or if any of the files cannot be
/// removed.
pub fn mosaic_translate_cleanup(
    collected: &[PathBuf],
    tmp_file: &Path,
    out_file: &Path,
    epsg_code: u32,
) {
    mosaic(collected, tmp_file, epsg_code, None, None).expect("Could not mosaic to vrt");
    translate(tmp_file, out_file).expect("Could not translate to geotiff");

    std::fs::remove_file(tmp_file).expect("Unable to remove the temporary file");
    for input in collected {
        std::fs::remove_file(input).expect("Unable to remove file");
    }
}
/// Produces one mosaic per time step, in parallel, and removes the per-block inputs.
///
/// For each `time_index` in `0..n_times`, the `time_index`-th path of every entry in
/// `collected` is mosaicked through a temporary VRT and translated to
/// `<out_file stem>_<time_index>.tif` next to `out_file`. The temporary VRT and the input
/// files of that time step are then deleted.
///
/// # Panics
///
/// Panics if an entry of `collected` has no element at `time_index`, if mosaicking or
/// translating returns an error, or if a file cannot be removed.
pub fn mosaic_translate_cleanup_time_steps(
    collected: &[Vec<PathBuf>],
    out_file: &Path,
    epsg_code: u32,
    n_times: usize,
) {
    par_tqdm!((0..n_times).into_par_iter()).for_each(|time_index| {
        let step_inputs: Vec<PathBuf> = collected
            .iter()
            .map(|block| block[time_index].clone())
            .collect();

        let scratch_vrt = PathBuf::from(create_temp_file("vrt"));
        let step_name = format!("{}_{}.tif", file_stem_str(out_file), time_index);
        let step_output = out_file.with_file_name(step_name);

        mosaic(&step_inputs, &scratch_vrt, epsg_code, None, None)
            .expect("Could not mosaic to vrt");
        translate(&scratch_vrt, &step_output).expect("Could not translate to geotiff");

        std::fs::remove_file(scratch_vrt).expect("Unable to remove the temporary file");
        for input in &step_inputs {
            std::fs::remove_file(input).expect("Unable to remove file");
        }
    });
}
/// Returns a unique temporary file path with the given extension.
///
/// The directory is taken from the `TMP_DIR` environment variable, falling back to `/tmp`
/// when the variable cannot be read. The file name is `eorst_<uuid>.<ext>`. No file is
/// created; only the path string is produced.
pub fn create_temp_file(ext: &str) -> String {
    let base_dir = match env::var("TMP_DIR") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => PathBuf::from("/tmp"),
    };
    let unique_name = format!("eorst_{}.{}", Uuid::new_v4().simple(), ext);

    base_dir
        .join(unique_name)
        .into_os_string()
        .into_string()
        .unwrap()
}
/// Reprojects `source` to `EPSG:<target_epsg>` with `gdalwarp` and returns the path of the
/// temporary `.vrt` output.
///
/// When `target_resolution` is given, `-tr <x> <y> -tap` is added to the command line.
pub(crate) fn warp(
    source: PathBuf,
    target_resolution: Option<ImageResolution>,
    target_epsg: u32,
) -> PathBuf {
    let destination = create_temp_file("vrt");

    let mut command: Vec<String> = Vec::new();
    command.push("gdalwarp".to_string());
    command.push("-q".to_string());
    if let Some(tr) = target_resolution {
        command.push("-tr".to_string());
        command.push(tr.x.to_string());
        command.push(tr.y.to_string());
        command.push("-tap".to_string());
    }
    command.push("-t_srs".to_string());
    command.push(format!("EPSG:{}", target_epsg));
    command.push(source.to_string_lossy().to_string());
    command.push(destination.clone());

    let argv: Vec<&str> = command.iter().map(String::as_str).collect();
    run_gdal_command(&argv);
    PathBuf::from(destination)
}
/// Warps `source` onto the given extent and resolution with `gdalwarp` (nearest-neighbour
/// resampling) and returns the path of the temporary `.vrt` output.
///
/// Extent and resolution values are rounded to two decimal places before being placed on
/// the command line.
pub(crate) fn warp_with_te_tr(
    source: PathBuf,
    target_te: &Extent,
    target_resolution: ImageResolution,
) -> PathBuf {
    // Round to 2 decimals and render as a command-line argument.
    let round2 = |value: f64| ((value * 100.).round() / 100.).to_string();

    let destination = create_temp_file("vrt");
    let extent_args = [
        round2(target_te.xmin),
        round2(target_te.ymin),
        round2(target_te.xmax),
        round2(target_te.ymax),
    ];
    let res_x = round2(target_resolution.x);
    let res_y = round2(target_resolution.y);
    let input = source.to_string_lossy();

    let argv: [&str; 14] = [
        "gdalwarp",
        "-q",
        "-te",
        extent_args[0].as_str(),
        extent_args[1].as_str(),
        extent_args[2].as_str(),
        extent_args[3].as_str(),
        "-tr",
        res_x.as_str(),
        res_y.as_str(),
        "-r",
        "nearest",
        input.as_ref(),
        destination.as_str(),
    ];
    run_gdal_command(&argv);
    PathBuf::from(destination)
}
/// Resamples `source` to `target_resolution` with `gdalwarp -tr` and returns the path of
/// the temporary `.vrt` output.
pub(crate) fn change_res(source: PathBuf, target_resolution: ImageResolution) -> PathBuf {
    let destination = create_temp_file("vrt");
    let res_x = target_resolution.x.to_string();
    let res_y = target_resolution.y.to_string();
    let input = source.to_string_lossy();

    let argv: [&str; 7] = [
        "gdalwarp",
        "-q",
        "-tr",
        res_x.as_str(),
        res_y.as_str(),
        input.as_ref(),
        destination.as_str(),
    ];
    run_gdal_command(&argv);
    PathBuf::from(destination)
}
/// Creates a temporary single-band `.vrt` that exposes band `band` of `source` and returns
/// its path.
///
/// When `source` can be opened, the VRT XML is written by hand from the dataset's size,
/// geotransform, spatial reference (as WKT) and the band's data type and no-data value;
/// any of these that cannot be read is left out (the data type falls back to `"Unknown"`).
/// When `source` cannot be opened, `gdal_translate -b <band>` is run instead.
///
/// # Panics
///
/// Panics if the VRT file cannot be written, or, on the fallback path, if `source` is not
/// valid UTF-8.
pub(crate) fn extract_band(source: &Path, band: usize) -> PathBuf {
    let src_ds = match Dataset::open(source) {
        Ok(ds) => ds,
        Err(_) => {
            // The dataset could not be opened here, so delegate to gdal_translate.
            let new_source = create_temp_file("vrt");
            let band_arg = band.to_string();
            let argv = [
                "gdal_translate",
                "-b",
                band_arg.as_str(),
                "-q",
                source.to_str().unwrap(),
                new_source.as_str(),
            ];
            run_gdal_command(&argv);
            return PathBuf::from(new_source);
        }
    };

    let (xsize, ysize) = src_ds.raster_size();
    let gt = src_ds.geo_transform().ok();
    let srs_wkt = src_ds
        .spatial_ref()
        .ok()
        .and_then(|srs| srs.to_wkt().ok());
    let (band_dtype, band_no_data) = match src_ds.rasterband(band) {
        Ok(b) => (Some(b.band_type()), b.no_data_value()),
        Err(_) => (None, None),
    };

    let wkt_part = match &srs_wkt {
        Some(wkt) => format!(" <SRS>{}</SRS>\n", wkt),
        None => String::new(),
    };
    let gt_part = match &gt {
        Some(arr) => format!(
            " <GeoTransform> {:.15}, {:.15}, {:.15}, {:.15}, {:.15}, {:.15} </GeoTransform>\n",
            arr[0], arr[1], arr[2], arr[3], arr[4], arr[5]
        ),
        None => String::new(),
    };
    // NOTE(review): this element is emitted inside <SimpleSource>; confirm against the GDAL
    // VRT schema that the no-data value is honoured at this position.
    let nd_part = match band_no_data {
        Some(nd) => format!("<NoDataValue>{}</NoDataValue>\n", nd),
        None => String::new(),
    };
    // Prefer an absolute source path because the VRT is written with relativeToVRT="0".
    let source_abs = std::fs::canonicalize(source)
        .unwrap_or_else(|_| source.to_path_buf())
        .to_string_lossy()
        .to_string();
    let dtype_str = match band_dtype {
        Some(dt) => dt.to_string(),
        None => "Unknown".to_string(),
    };

    let vrt_xml = format!(
        r#"<VRTDataset rasterXSize="{}" rasterYSize="{}">
{wkt_part}{gt_part} <VRTRasterBand dataType="{dtype}" band="1">
<SimpleSource>
<SourceFilename relativeToVRT="0">{source}</SourceFilename>
<SourceBand>{band}</SourceBand>
{nd_part} </SimpleSource>
</VRTRasterBand>
</VRTDataset>"#,
        xsize, ysize,
        wkt_part = wkt_part,
        gt_part = gt_part,
        dtype = dtype_str,
        source = source_abs,
        band = band,
        nd_part = nd_part,
    );

    let new_source = PathBuf::from(create_temp_file("vrt"));
    std::fs::write(&new_source, vrt_xml).expect("Failed to write VRT XML");
    new_source
}
/// Creates a new GeoTIFF at `file_name` with the given geometry and no-data value.
///
/// The file is created with `COMPRESS=LZW` and 512x512 blocks, `block_size.cols` x
/// `block_size.rows` pixels, `n_bands` bands of type `T`, the supplied geotransform and the
/// spatial reference for `epsg_code`. Every band gets `na_value` as its no-data value.
///
/// Missing parent directories of `file_name` are created on a best-effort basis; a failure
/// there is ignored and surfaces when the dataset itself is created.
///
/// # Panics
///
/// Panics if the driver is unavailable, the dataset cannot be created, the geotransform or
/// spatial reference cannot be set, `na_value` cannot be converted to `f64`, or a band's
/// no-data value cannot be set.
pub fn raster_from_size<T>(
    file_name: &Path,
    geo_transform: GeoTransform,
    epsg_code: u32,
    block_size: BlockSize,
    n_bands: isize,
    na_value: T,
) where
    T: RasterType,
{
    if let Some(parent_path) = file_name.parent() {
        // Best effort: the directory usually exists already. `create_dir_all` also copes
        // with several missing levels, which a plain `create_dir` would silently skip.
        std::fs::create_dir_all(parent_path).ok();
    }

    let size_x = block_size.cols;
    let size_y = block_size.rows;
    let driver = DriverManager::get_driver_by_name("GTIFF").unwrap();
    let options =
        RasterCreationOptions::from_iter(["COMPRESS=LZW", "BLOCKXSIZE=512", "BLOCKYSIZE=512"]);
    let mut dataset = driver
        .create_with_band_type_with_options::<T, _>(
            file_name,
            size_x,
            size_y,
            n_bands as usize,
            &options,
        )
        .unwrap();
    dataset
        .set_geo_transform(&geo_transform.to_array())
        .unwrap();
    let srs = SpatialRef::from_epsg(epsg_code).unwrap();
    dataset.set_spatial_ref(&srs).unwrap();

    // The no-data value is the same for every band, so convert it once.
    let no_data_f64 = na_value
        .to_f64()
        .expect("Failed to convert no_data value to f64");
    for band_index in 1..=n_bands {
        let mut raster_band = dataset.rasterband(band_index as usize).unwrap();
        raster_band.set_no_data_value(Some(no_data_f64)).unwrap();
    }
}
/// Reprojects every input to `EPSG:<epsg_code>` and combines the results into a VRT mosaic
/// at `tmp_file`.
///
/// Each path in `collected` is warped in parallel with `gdalwarp` into its own temporary
/// `.vrt`. When `extent` is given, `-te xmin ymin xmax ymax` is added; when `resolution`
/// is given, `-tr res res` is added. The warped files are then passed to `gdalbuildvrt`.
///
/// The temporary warped files are not removed by this function.
pub fn mosaic(
    collected: &[PathBuf],
    tmp_file: &Path,
    epsg_code: u32,
    extent: Option<Extent>,
    resolution: Option<f64>,
) -> Result<()> {
    let collected_reproj: Vec<PathBuf> = par_tqdm!(collected
        .par_iter())
    .map(|image| {
        let warped = create_temp_file("vrt");

        let mut command = vec![
            "gdalwarp".to_string(),
            "-q".to_string(),
            "-t_srs".to_string(),
            format!("EPSG:{}", epsg_code),
        ];
        if let Some(ext) = &extent {
            command.push("-te".to_string());
            command.extend(
                [ext.xmin, ext.ymin, ext.xmax, ext.ymax]
                    .iter()
                    .map(|bound| bound.to_string()),
            );
        }
        if let Some(res) = resolution {
            command.extend(["-tr".to_string(), res.to_string(), res.to_string()]);
        }
        command.push(image.to_string_lossy().to_string());
        command.push(warped.clone());

        let argv: Vec<&str> = command.iter().map(String::as_str).collect();
        run_gdal_command(&argv);
        PathBuf::from(warped)
    })
    .collect();

    let mut build_command = vec![
        "gdalbuildvrt".to_string(),
        "-q".to_string(),
        tmp_file.to_string_lossy().to_string(),
    ];
    for warped in &collected_reproj {
        build_command.push(warped.to_string_lossy().to_string());
    }
    let argv: Vec<&str> = build_command.iter().map(String::as_str).collect();
    run_gdal_command(&argv);
    Ok(())
}
/// Mosaics `collected` into `out_file` through a temporary VRT, leaving the inputs in place.
///
/// The temporary VRT is removed afterwards on a best-effort basis; it is not removed when
/// mosaicking or translating returns an error.
///
/// # Errors
///
/// Returns any error produced by [`mosaic`] or [`translate`].
pub fn mosaic_keep_inputs(
    collected: &[PathBuf],
    out_file: &Path,
    epsg_code: u32,
    extent: Option<Extent>,
    resolution: Option<f64>,
) -> Result<()> {
    let scratch_vrt: PathBuf = create_temp_file("vrt").into();

    mosaic(collected, &scratch_vrt, epsg_code, extent, resolution)?;
    translate(&scratch_vrt, out_file)?;

    // Failing to delete the scratch file is not treated as an error.
    let _ = std::fs::remove_file(&scratch_vrt);
    Ok(())
}
/// Converts `tmp_fn` to `image_fn` with `gdal_translate`, using the output driver
/// `driver_name`.
///
/// The command is run with the creation options `BIGTIFF=YES`, `COMPRESS=DEFLATE` and
/// `NUM_THREADS=16`.
///
/// # Errors
///
/// Returns an error if either path is not valid UTF-8 and therefore cannot be passed on the
/// command line.
pub fn translate_with_driver(tmp_fn: &Path, image_fn: &Path, driver_name: &str) -> Result<()> {
    // Report a non-UTF-8 path as an error instead of panicking inside a fallible function.
    let non_utf8 = |role: &str| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("{} path is not valid UTF-8", role),
        )
    };
    let src = tmp_fn.to_str().ok_or_else(|| non_utf8("Source"))?;
    let dst = image_fn.to_str().ok_or_else(|| non_utf8("Destination"))?;

    let argv = [
        "gdal_translate",
        "-q",
        "-of",
        driver_name,
        "-co",
        "BIGTIFF=YES",
        "-co",
        "COMPRESS=DEFLATE",
        "-co",
        "NUM_THREADS=16",
        src,
        dst,
    ];
    run_gdal_command(&argv);
    Ok(())
}
/// Converts `tmp_fn` to a GeoTIFF at `image_fn`.
///
/// # Errors
///
/// Returns any error reported by [`translate_with_driver`].
pub fn translate(tmp_fn: &Path, image_fn: &Path) -> Result<()> {
    // Propagate the error to the caller rather than unwrapping it into a panic.
    translate_with_driver(tmp_fn, image_fn, "GTiff")
}
/// Copies `src` to `dst` using GDAL's `COG` driver.
///
/// The copy is created with `COMPRESS=<compression>`, `BIGTIFF=YES`, `OVERVIEWS=AUTO`,
/// `RESAMPLING=<overview_resampling>` and `NUM_THREADS=ALL_CPUS`.
///
/// # Errors
///
/// Returns an error if the `COG` driver is unavailable, `src` cannot be opened, `dst` is
/// not valid UTF-8, or the copy fails.
pub fn translate_to_cog(
    src: &Path,
    dst: &Path,
    compression: &str,
    overview_resampling: &str,
) -> Result<()> {
    let driver = DriverManager::get_driver_by_name("COG")
        .context("COG driver not available (requires GDAL 3.1+)")?;
    let input = Dataset::open(src)
        .with_context(|| format!("Failed to open source GeoTIFF {:?}", src))?;

    let compress_opt = format!("COMPRESS={}", compression);
    let resampling_opt = format!("RESAMPLING={}", overview_resampling);
    let creation_options = RasterCreationOptions::from_iter([
        compress_opt,
        "BIGTIFF=YES".to_string(),
        "OVERVIEWS=AUTO".to_string(),
        resampling_opt,
        "NUM_THREADS=ALL_CPUS".to_string(),
    ]);

    let destination = dst
        .to_str()
        .context("Destination path is not valid UTF-8")?;
    input
        .create_copy(&driver, destination, &creation_options)
        .with_context(|| format!("Failed to create COG at {:?}", dst))?;
    Ok(())
}
/// Returns the union of the data types of all bands in `source`.
///
/// A warning is logged for every band whose type differs from the type accumulated from
/// the bands before it.
///
/// # Panics
///
/// Panics if the dataset cannot be opened, a band cannot be read, or the dataset has no
/// bands.
#[allow(dead_code)]
pub(crate) fn get_widest_type(source: &PathBuf) -> GdalDataType {
    use log::warn;

    let dataset = Dataset::open(source).unwrap();
    let mut widest: Option<GdalDataType> = None;
    for i in 1..=dataset.raster_count() {
        let dtype = dataset
            .rasterband(i)
            .expect("Failed to read band")
            .band_type();
        widest = Some(match widest {
            None => dtype,
            Some(existing) => {
                if dtype != existing {
                    warn!(
                        "Band {} has different type ({:?}) than first band ({:?})",
                        i, dtype, existing
                    );
                }
                existing.union(dtype)
            }
        });
    }
    widest.expect("Dataset has no bands")
}
/// Builds an id → class lookup from the first layer of a vector dataset.
///
/// For every feature, the integer in `id_column` becomes the key and the integer in
/// `class_column` the value; a missing class value is stored as `-1`. Both are cast to
/// `i16` with `as`. When an id occurs more than once, the last feature wins.
///
/// # Panics
///
/// Panics if the dataset or its first layer cannot be opened, a column name is unknown,
/// the id field is missing, or a field is not an integer.
#[allow(dead_code)]
pub fn get_class(
    dataset_path: &PathBuf,
    id_column: &str,
    class_column: &str,
) -> BTreeMap<i16, i16> {
    let dataset = Dataset::open(dataset_path).unwrap();
    let mut layer = dataset.layer(0).unwrap();
    let _fields_defn: Vec<_> = layer
        .defn()
        .fields()
        .map(|field| (field.name(), field.field_type(), field.width()))
        .collect();

    layer
        .features()
        .map(|feature| {
            let id_idx = feature.field_index(id_column).expect("Bad column name");
            let class_idx = feature.field_index(class_column).expect("Bad column name");

            let id = feature
                .field(id_idx)
                .unwrap()
                .unwrap()
                .into_int()
                .unwrap();
            let condition = feature
                .field(class_idx)
                .unwrap()
                .unwrap_or(FieldValue::IntegerValue(-1))
                .into_int()
                .unwrap();
            (id as i16, condition as i16)
        })
        .collect()
}
/// Basic metadata describing a raster dataset.
#[derive(Debug, Clone)]
pub struct BasicRasterInfo {
/// Geotransform coefficients, in the order
/// `[x_ul, x_res, x_rot, y_ul, y_rot, y_res]`.
pub geo_transform: [f64; 6],
/// Raster size as `(cols, rows)`.
pub size: (usize, usize),
/// EPSG code of the spatial reference.
pub epsg_code: u32,
/// No-data value, if one is known.
pub no_data: Option<f64>,
/// Number of raster bands.
pub n_bands: usize,
}
impl BasicRasterInfo {
    /// Returns the pixel resolution taken from geotransform elements 1 (x) and 5 (y).
    pub fn resolution(&self) -> crate::types::ImageResolution {
        let [_, x, _, _, _, y] = self.geo_transform;
        crate::types::ImageResolution { x, y }
    }

    /// Returns the geotransform as a named-field [`crate::types::GeoTransform`].
    pub fn geo_transform_struct(&self) -> crate::types::GeoTransform {
        let [x_ul, x_res, x_rot, y_ul, y_rot, y_res] = self.geo_transform;
        crate::types::GeoTransform {
            x_ul,
            x_res,
            x_rot,
            y_ul,
            y_rot,
            y_res,
        }
    }

    /// Returns the raster size as rows/cols; `size` is stored as `(cols, rows)`.
    pub fn image_size(&self) -> crate::types::ImageSize {
        let (cols, rows) = self.size;
        crate::types::ImageSize { rows, cols }
    }

    /// Returns the no-data value cast to `T`.
    ///
    /// Falls back to `T::zero()` when no no-data value is stored or when it cannot be
    /// represented as `T`.
    pub fn na_value<T: crate::core_types::RasterType>(&self) -> T {
        let converted: Option<T> = self
            .no_data
            .and_then(|raw| num_traits::NumCast::from(raw));
        match converted {
            Some(value) => value,
            None => T::zero(),
        }
    }
}
/// Opens `source` and collects its basic metadata into a [`BasicRasterInfo`].
///
/// The EPSG code is taken from the spatial reference's authority code and is `0` when that
/// lookup fails. The no-data value is read from band 1 when the raster has a non-zero size
/// and the band can be fetched; otherwise it is `None`.
///
/// # Panics
///
/// Panics if the dataset cannot be opened or if its geotransform or spatial reference
/// cannot be read.
pub fn read_basic_raster_info(source: &Path) -> BasicRasterInfo {
    // Build the panic message lazily so the success path does not allocate a String.
    let ds = Dataset::open(source)
        .unwrap_or_else(|err| panic!("Unable to open {:?}: {:?}", source, err));
    let geo_transform = ds.geo_transform().expect("Failed to get geo transform");
    let size = ds.raster_size();
    let spatial_ref = ds.spatial_ref().expect("Failed to get spatial ref");
    let epsg_code = spatial_ref.auth_code().unwrap_or(0);

    let no_data = if size.0 > 0 && size.1 > 0 {
        ds.rasterband(1)
            .ok()
            .and_then(|band| band.no_data_value())
    } else {
        None
    };

    BasicRasterInfo {
        geo_transform,
        size,
        epsg_code: epsg_code as u32,
        no_data,
        n_bands: ds.raster_count(),
    }
}
/// Computes the bounding box of the raster at `source`, expressed in `EPSG:<target_epsg>`.
///
/// The four corner coordinates are derived from geotransform elements 0, 1, 3 and 5 and
/// the raster size. When the raster's authority code differs from `target_epsg`, the
/// corners are transformed into the target reference system first. The result is the
/// min/max envelope of the corner points.
///
/// # Errors
///
/// Returns an error if the dataset, its geotransform or spatial reference cannot be read,
/// or if building or applying the coordinate transformation fails.
fn compute_single_raster_extent(source: &Path, target_epsg: u32) -> Result<Extent> {
    let ds = Dataset::open(source)?;
    let gt = ds.geo_transform()?;
    let (cols, rows) = ds.raster_size();
    let src_srs = ds.spatial_ref()?;
    let src_epsg = src_srs.auth_code().unwrap_or(0) as u32;

    let left = gt[0];
    let top = gt[3];
    let right = gt[0] + gt[1] * cols as f64;
    let bottom = gt[3] + gt[5] * rows as f64;

    // Corners in the order: upper-left, upper-right, lower-right, lower-left.
    let mut xs = vec![left, right, right, left];
    let mut ys = vec![top, top, bottom, bottom];

    if src_epsg != target_epsg {
        let target_srs = SpatialRef::from_epsg(target_epsg)?;
        let ct = CoordTransform::new(&src_srs, &target_srs)?;
        let mut zs = vec![0.0; xs.len()];
        ct.transform_coords(&mut xs, &mut ys, &mut zs)?;
    }

    let lowest = |values: &[f64]| values.iter().copied().fold(f64::INFINITY, f64::min);
    let highest = |values: &[f64]| values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    Ok(Extent {
        xmin: lowest(&xs),
        ymin: lowest(&ys),
        xmax: highest(&xs),
        ymax: highest(&ys),
    })
}
/// Computes the union of the extents of all `files`, expressed in `EPSG:<target_epsg>`.
///
/// # Errors
///
/// Returns an error if `files` is empty or if the extent of any file cannot be computed.
pub fn compute_raster_union_extent(files: &[PathBuf], target_epsg: u32) -> Result<Extent> {
    // An empty list has no extent; report that as an error instead of panicking on `files[0]`.
    let (first, rest) = files.split_first().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "Cannot compute the union extent of an empty file list",
        )
    })?;

    let mut union_extent = compute_single_raster_extent(first, target_epsg)?;
    for file in rest {
        let extent = compute_single_raster_extent(file, target_epsg)?;
        union_extent = union_extent.union(&extent);
    }
    Ok(union_extent)
}
/// Computes the extent of the first layer of `vector_path`, expressed in
/// `EPSG:<target_epsg>`.
///
/// When the layer's authority code equals `target_epsg`, the layer extent is returned
/// unchanged. Otherwise its four corners are transformed into the target reference system
/// and the min/max envelope of the transformed corners is returned.
///
/// # Errors
///
/// Returns an error if the dataset, layer or layer extent cannot be read, if the layer has
/// no spatial reference, or if building or applying the coordinate transformation fails.
pub fn compute_vector_extent(vector_path: &Path, target_epsg: u32) -> Result<Extent> {
    let ds = Dataset::open(vector_path)?;
    let layer = ds.layer(0)?;
    let ext = layer.get_extent()?;
    let src_srs = match layer.spatial_ref() {
        Some(srs) => srs,
        None => return Err(anyhow::anyhow!("Vector layer has no spatial reference")),
    };
    let src_epsg = src_srs.auth_code().unwrap_or(0) as u32;

    // Same reference system: the layer extent can be used as is.
    if src_epsg == target_epsg {
        return Ok(Extent {
            xmin: ext.MinX,
            ymin: ext.MinY,
            xmax: ext.MaxX,
            ymax: ext.MaxY,
        });
    }

    let target_srs = SpatialRef::from_epsg(target_epsg)?;
    let ct = CoordTransform::new(&src_srs, &target_srs)?;

    // Corners in the order: (min, min), (max, min), (max, max), (min, max).
    let mut xs = vec![ext.MinX, ext.MaxX, ext.MaxX, ext.MinX];
    let mut ys = vec![ext.MinY, ext.MinY, ext.MaxY, ext.MaxY];
    let mut zs = vec![0.0; xs.len()];
    ct.transform_coords(&mut xs, &mut ys, &mut zs)?;

    let lowest = |values: &[f64]| values.iter().copied().fold(f64::INFINITY, f64::min);
    let highest = |values: &[f64]| values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    Ok(Extent {
        xmin: lowest(&xs),
        ymin: lowest(&ys),
        xmax: highest(&xs),
        ymax: highest(&ys),
    })
}