use std::fs::File;
use std::path::Path;
use memmap2::Mmap;
use crate::error::{Result, SrtmError};
/// Samples along each edge of an SRTM1 (1 arc-second) tile.
const SRTM1_SAMPLES: usize = 3601;
/// Samples along each edge of an SRTM3 (3 arc-second) tile.
const SRTM3_SAMPLES: usize = 1201;
/// Byte length of an SRTM1 tile: `3601 × 3601` two-byte samples.
const SRTM1_SIZE: usize = SRTM1_SAMPLES * SRTM1_SAMPLES * std::mem::size_of::<i16>();
/// Byte length of an SRTM3 tile: `1201 × 1201` two-byte samples.
const SRTM3_SIZE: usize = SRTM3_SAMPLES * SRTM3_SAMPLES * std::mem::size_of::<i16>();
/// Sample value marking a cell with no elevation data.
pub const VOID_VALUE: i16 = i16::MIN;
/// Grid resolution of an SRTM tile, as inferred from its file size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SrtmResolution {
    /// 3601 × 3601 samples per tile (~30 m spacing).
    Srtm1,
    /// 1201 × 1201 samples per tile (~90 m spacing).
    Srtm3,
}

impl SrtmResolution {
    /// Per-resolution parameters: (samples per edge, approximate spacing in metres).
    fn spec(self) -> (usize, f64) {
        match self {
            Self::Srtm1 => (SRTM1_SAMPLES, 30.0),
            Self::Srtm3 => (SRTM3_SAMPLES, 90.0),
        }
    }

    /// Number of samples along each edge of a tile at this resolution.
    pub fn samples(&self) -> usize {
        let (edge_samples, _) = self.spec();
        edge_samples
    }

    /// Approximate ground distance between adjacent samples, in metres.
    pub fn meters(&self) -> f64 {
        let (_, spacing) = self.spec();
        spacing
    }
}
/// One SRTM elevation tile whose samples are read directly from a memory-mapped file.
pub struct SrtmTile {
    /// Integer latitude supplied at construction (0 when opened via `from_file`).
    base_lat: i32,
    /// Integer longitude supplied at construction (0 when opened via `from_file`).
    base_lon: i32,
    /// Resolution inferred from the mapped file's length.
    resolution: SrtmResolution,
    /// Samples per tile edge (matches `resolution`).
    samples: usize,
    /// Raw tile bytes: `samples × samples` big-endian `i16` values, row-major,
    /// with row 0 at the northern edge.
    data: Mmap,
}
impl SrtmTile {
    /// Opens an SRTM tile without recording its geographic origin.
    ///
    /// Equivalent to [`SrtmTile::from_file_with_coords`] with `base_lat = 0` and
    /// `base_lon = 0`. Lookups never consult the base coordinates, so a tile opened
    /// this way can still be queried normally.
    ///
    /// # Errors
    ///
    /// Same as [`SrtmTile::from_file_with_coords`].
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_file_with_coords(path, 0, 0)
    }

    /// Memory-maps an SRTM tile and records the caller-supplied base coordinates.
    ///
    /// The resolution is inferred from the file length: `SRTM1_SIZE` bytes →
    /// [`SrtmResolution::Srtm1`], `SRTM3_SIZE` bytes → [`SrtmResolution::Srtm3`].
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or memory-mapped, and
    /// [`SrtmError::InvalidFileSize`] if its length matches neither tile size.
    pub fn from_file_with_coords<P: AsRef<Path>>(
        path: P,
        base_lat: i32,
        base_lon: i32,
    ) -> Result<Self> {
        let file = File::open(&path)?;
        // SAFETY: the map is read-only, but memory-mapped data can still change (or
        // fault) if another process truncates or rewrites the file while this tile is
        // alive. Callers must not modify tile files in place while they are in use.
        let mmap = unsafe { Mmap::map(&file)? };
        let (samples, resolution) = match mmap.len() {
            SRTM1_SIZE => (SRTM1_SAMPLES, SrtmResolution::Srtm1),
            SRTM3_SIZE => (SRTM3_SAMPLES, SrtmResolution::Srtm3),
            size => return Err(SrtmError::InvalidFileSize { size }),
        };
        // Each query touches at most four samples, so kernel read-ahead is mostly
        // wasted I/O. The hint is best-effort: failing to apply it is harmless.
        #[cfg(unix)]
        {
            use memmap2::Advice;
            let _ = mmap.advise(Advice::Random);
        }
        Ok(Self {
            data: mmap,
            samples,
            resolution,
            base_lat,
            base_lon,
        })
    }

    /// Returns the sample nearest to (`lat`, `lon`).
    ///
    /// Only the fractional parts of `lat` and `lon` select the sample. The integer
    /// parts and this tile's `base_lat`/`base_lon` are NOT checked, so the caller must
    /// pass coordinates that lie in this tile. Void cells are returned as
    /// [`VOID_VALUE`]; they are not filtered.
    ///
    /// # Errors
    ///
    /// Returns [`SrtmError::OutOfBounds`] if `lat` or `lon` is NaN or infinite.
    pub fn get_elevation(&self, lat: f64, lon: f64) -> Result<i16> {
        self.get_elevation_inner(lat, lon, f64::round)
    }

    /// Like [`SrtmTile::get_elevation`], but truncates instead of rounding.
    ///
    /// Selects the sample at or immediately north-west of the point.
    ///
    /// # Errors
    ///
    /// Returns [`SrtmError::OutOfBounds`] if `lat` or `lon` is NaN or infinite.
    pub fn get_elevation_floor(&self, lat: f64, lon: f64) -> Result<i16> {
        self.get_elevation_inner(lat, lon, f64::floor)
    }

    /// Shared implementation of the single-sample lookups; `rounding_fn` maps a
    /// fractional grid position to an integer index.
    fn get_elevation_inner(&self, lat: f64, lon: f64, rounding_fn: fn(f64) -> f64) -> Result<i16> {
        let (row_pos, col_pos) = self.grid_position(lat, lon)?;
        // Positions lie in [0, samples - 1], so these casts never saturate or wrap.
        let row = rounding_fn(row_pos) as usize;
        let col = rounding_fn(col_pos) as usize;
        Ok(self.get_elevation_at(row, col))
    }

    /// Bilinearly interpolates the elevation at (`lat`, `lon`) from the four
    /// surrounding samples.
    ///
    /// Like [`SrtmTile::get_elevation`], only the fractional parts of the
    /// coordinates are used. Returns `Ok(None)` when any of the four samples is
    /// [`VOID_VALUE`].
    ///
    /// # Errors
    ///
    /// Returns [`SrtmError::OutOfBounds`] if `lat` or `lon` is NaN or infinite.
    pub fn get_elevation_interpolated(&self, lat: f64, lon: f64) -> Result<Option<f64>> {
        let (row_pos, col_pos) = self.grid_position(lat, lon)?;
        let last = self.samples - 1;

        let row0 = row_pos.floor() as usize;
        let col0 = col_pos.floor() as usize;
        // On the southern/eastern edge there is no next sample. Reusing the edge
        // sample is exact because its weight is 0 there (row_pos == last).
        let row1 = (row0 + 1).min(last);
        let col1 = (col0 + 1).min(last);
        let row_weight = row_pos - row0 as f64;
        let col_weight = col_pos - col0 as f64;

        let v00 = self.get_elevation_at(row0, col0);
        let v10 = self.get_elevation_at(row0, col1);
        let v01 = self.get_elevation_at(row1, col0);
        let v11 = self.get_elevation_at(row1, col1);
        if [v00, v10, v01, v11].contains(&VOID_VALUE) {
            return Ok(None);
        }

        let lerp = |a: i16, b: i16, t: f64| f64::from(a) + (f64::from(b) - f64::from(a)) * t;
        // Interpolate along each row (west→east), then between the rows (north→south).
        let north = lerp(v00, v10, col_weight);
        let south = lerp(v01, v11, col_weight);
        Ok(Some(north + (south - north) * row_weight))
    }

    /// Maps coordinates to fractional `(row, col)` grid positions in `[0, samples - 1]`.
    ///
    /// Row 0 is the northern edge and column 0 the western edge of the tile.
    fn grid_position(&self, lat: f64, lon: f64) -> Result<(f64, f64)> {
        // For any finite `x`, `x - x.floor()` lies in [0, 1]. It can round up to
        // exactly 1.0 for tiny negative inputs. So non-finite input is the only thing
        // that can be rejected here.
        if !lat.is_finite() || !lon.is_finite() {
            return Err(SrtmError::OutOfBounds { lat, lon });
        }
        let lat_frac = lat - lat.floor();
        let lon_frac = lon - lon.floor();
        let max_index = (self.samples - 1) as f64;
        // Rows are stored north-first, hence the inversion of the latitude fraction.
        Ok(((1.0 - lat_frac) * max_index, lon_frac * max_index))
    }

    /// Reads the raw big-endian sample at (`row`, `col`), clamping both indices
    /// to the last row/column.
    #[inline(always)]
    fn get_elevation_at(&self, row: usize, col: usize) -> i16 {
        let last = self.samples - 1;
        let row = row.min(last);
        let col = col.min(last);
        let offset = (row * self.samples + col) * 2;
        debug_assert!(offset + 1 < self.data.len());
        // SAFETY: `row, col <= samples - 1`, so
        // `offset + 1 <= 2 * (samples * samples - 1) + 1 < 2 * samples * samples`.
        // The constructor only accepts maps whose length is exactly
        // `2 * samples * samples`, so both reads are in bounds.
        unsafe {
            i16::from_be_bytes([
                *self.data.get_unchecked(offset),
                *self.data.get_unchecked(offset + 1),
            ])
        }
    }

    /// Resolution inferred from the file size.
    pub fn resolution(&self) -> SrtmResolution {
        self.resolution
    }

    /// Number of samples along each tile edge.
    pub fn samples(&self) -> usize {
        self.samples
    }

    /// Latitude supplied to [`SrtmTile::from_file_with_coords`] (0 for [`SrtmTile::from_file`]).
    pub fn base_lat(&self) -> i32 {
        self.base_lat
    }

    /// Longitude supplied to [`SrtmTile::from_file_with_coords`] (0 for [`SrtmTile::from_file`]).
    pub fn base_lon(&self) -> i32 {
        self.base_lon
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `elev` as a big-endian sample at (`row`, `col`) of an SRTM3 buffer.
    fn set_elevation(data: &mut [u8], row: usize, col: usize, elev: i16) {
        let offset = (row * SRTM3_SAMPLES + col) * 2;
        data[offset..offset + 2].copy_from_slice(&elev.to_be_bytes());
    }

    /// Builds an all-zero SRTM3 tile with the given `(row, col, elevation)` cells
    /// overridden, and writes it to a fresh temp file.
    fn write_srtm3_file(cells: &[(usize, usize, i16)]) -> NamedTempFile {
        let mut data = vec![0u8; SRTM3_SIZE];
        for &(row, col, elev) in cells {
            set_elevation(&mut data, row, col, elev);
        }
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&data).unwrap();
        file
    }

    /// NW corner = 1000, centre = 500, SE corner = 100, everything else 0.
    fn create_test_srtm3_file() -> NamedTempFile {
        write_srtm3_file(&[(0, 0, 1000), (600, 600, 500), (1200, 1200, 100)])
    }

    #[test]
    fn test_load_srtm3_file() {
        let file = create_test_srtm3_file();
        let tile = SrtmTile::from_file(file.path()).unwrap();
        assert_eq!(tile.resolution(), SrtmResolution::Srtm3);
        assert_eq!(tile.samples(), SRTM3_SAMPLES);
    }

    #[test]
    fn test_invalid_file_size() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&[0u8; 1000]).unwrap();
        match SrtmTile::from_file(file.path()) {
            Err(SrtmError::InvalidFileSize { size }) => assert_eq!(size, 1000),
            _ => panic!("Expected InvalidFileSize error"),
        }
    }

    #[test]
    fn test_get_elevation_corners() {
        let file = create_test_srtm3_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        // Just inside the NW corner: the nearest sample is (row 0, col 0).
        assert_eq!(tile.get_elevation(35.9999, 138.0001).unwrap(), 1000);
        // Just inside the SE corner: the nearest sample is (row 1200, col 1200).
        assert_eq!(tile.get_elevation(35.0001, 138.9999).unwrap(), 100);
    }

    #[test]
    fn test_get_elevation_center() {
        let file = create_test_srtm3_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        assert_eq!(tile.get_elevation(35.5, 138.5).unwrap(), 500);
    }

    #[test]
    fn test_non_finite_coordinates_rejected() {
        let file = create_test_srtm3_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        assert!(matches!(
            tile.get_elevation(f64::NAN, 138.5),
            Err(SrtmError::OutOfBounds { .. })
        ));
        assert!(matches!(
            tile.get_elevation_interpolated(35.5, f64::INFINITY),
            Err(SrtmError::OutOfBounds { .. })
        ));
    }

    #[test]
    fn test_resolution_info() {
        assert_eq!(SrtmResolution::Srtm1.samples(), 3601);
        assert_eq!(SrtmResolution::Srtm3.samples(), 1201);
        assert_eq!(SrtmResolution::Srtm1.meters(), 30.0);
        assert_eq!(SrtmResolution::Srtm3.meters(), 90.0);
    }

    /// A 2×2 patch around the centre: 100 200 / 300 400 (north row first).
    fn create_interpolation_test_file() -> NamedTempFile {
        write_srtm3_file(&[
            (600, 600, 100),
            (600, 601, 200),
            (601, 600, 300),
            (601, 601, 400),
        ])
    }

    #[test]
    fn test_interpolation_at_grid_points() {
        let file = create_interpolation_test_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        let elev = tile
            .get_elevation_interpolated(35.5, 138.5)
            .unwrap()
            .expect("no void cells in the patch");
        assert!((elev - 100.0).abs() < 1e-6, "Expected ~100, got {}", elev);
    }

    #[test]
    fn test_interpolation_midpoint() {
        let file = create_interpolation_test_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        let lat = 35.0 + (1.0 - 600.5 / 1200.0);
        let lon = 138.0 + 600.5 / 1200.0;
        let elev = tile
            .get_elevation_interpolated(lat, lon)
            .unwrap()
            .expect("no void cells in the patch");
        assert!((elev - 250.0).abs() < 1e-6, "Expected ~250, got {}", elev);
    }

    #[test]
    fn test_interpolation_horizontal() {
        let file = create_interpolation_test_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        let lat = 35.0 + (1.0 - 600.0 / 1200.0);
        let lon = 138.0 + 600.5 / 1200.0;
        let elev = tile
            .get_elevation_interpolated(lat, lon)
            .unwrap()
            .expect("no void cells in the patch");
        assert!((elev - 150.0).abs() < 1e-6, "Expected ~150, got {}", elev);
    }

    #[test]
    fn test_interpolation_void_value() {
        let file = write_srtm3_file(&[
            (600, 600, VOID_VALUE),
            (600, 601, 200),
            (601, 600, 300),
            (601, 601, 400),
        ]);
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        let lat = 35.0 + (1.0 - 600.5 / 1200.0);
        let lon = 138.0 + 600.5 / 1200.0;
        let elev = tile.get_elevation_interpolated(lat, lon).unwrap();
        assert!(elev.is_none(), "Expected None for void area");
    }

    /// Two horizontally adjacent cells whose values differ, near (33.3448, -96.1592).
    fn create_rounding_test_file() -> NamedTempFile {
        write_srtm3_file(&[(786, 1008, 191), (786, 1009, 190)])
    }

    #[test]
    fn test_floor_vs_round_different_results() {
        let file = create_rounding_test_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 33, -97).unwrap();
        let lat = 33.3448;
        let lon = -96.1592;
        let elev_round = tile.get_elevation(lat, lon).unwrap();
        let elev_floor = tile.get_elevation_floor(lat, lon).unwrap();
        assert_eq!(elev_round, 190, "round should select col 1009");
        assert_eq!(elev_floor, 191, "floor should select col 1008");
    }

    #[test]
    fn test_floor_matches_round_at_exact_grid() {
        let file = create_test_srtm3_file();
        let tile = SrtmTile::from_file_with_coords(file.path(), 35, 138).unwrap();
        let elev_round = tile.get_elevation(35.5, 138.5).unwrap();
        let elev_floor = tile.get_elevation_floor(35.5, 138.5).unwrap();
        assert_eq!(elev_round, elev_floor);
        assert_eq!(elev_round, 500);
    }
}