Skip to main content

proj_core/
grid.rs

1use crate::operation::{AreaOfUse, GridId, GridInterpolation, GridShiftDirection};
2use smallvec::SmallVec;
3use std::collections::HashMap;
4use std::f64::consts::PI;
5use std::path::{Component, Path, PathBuf};
6use std::sync::{Arc, Mutex, OnceLock};
7use thiserror::Error;
8
/// On-disk encoding of a grid resource, used to pick the parser.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GridFormat {
    /// NTv2 horizontal datum-shift grid (`.gsb`).
    Ntv2,
    /// NOAA/VDatum binary GTX vertical offset grid (`.gtx`).
    Gtx,
    /// Any other encoding; attempting to parse it fails with an
    /// unsupported-format error.
    Unsupported,
}
17
18#[derive(Debug, Clone, PartialEq)]
19pub struct GridDefinition {
20    pub id: GridId,
21    pub name: String,
22    pub format: GridFormat,
23    pub interpolation: GridInterpolation,
24    pub area_of_use: Option<AreaOfUse>,
25    pub resource_names: SmallVec<[String; 2]>,
26}
27
/// Horizontal shift sampled from a grid, in radians.
///
/// A forward transform adds these values to the input longitude/latitude.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GridSample {
    /// Longitude shift, radians (positive east).
    pub lon_shift_radians: f64,
    /// Latitude shift, radians (positive north).
    pub lat_shift_radians: f64,
}

/// Vertical offset sampled from a vertical grid.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct VerticalGridSample {
    /// Vertical offset in meters at the sampled horizontal position.
    pub offset_meters: f64,
}
39
40#[derive(Debug, Error, Clone)]
41pub enum GridError {
42    #[error("grid not found: {0}")]
43    NotFound(String),
44    #[error("grid resource unavailable: {0}")]
45    Unavailable(String),
46    #[error("grid parse error: {0}")]
47    Parse(String),
48    #[error("grid point outside coverage: {0}")]
49    OutsideCoverage(String),
50    #[error("unsupported grid format: {0}")]
51    UnsupportedFormat(String),
52}
53
54pub trait GridProvider: Send + Sync {
55    fn definition(
56        &self,
57        grid: &GridDefinition,
58    ) -> std::result::Result<Option<GridDefinition>, GridError>;
59    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError>;
60}
61
62#[derive(Clone)]
63pub struct GridHandle {
64    definition: GridDefinition,
65    data: Arc<CachedGridData>,
66}
67
68impl GridHandle {
69    /// Parse a grid resource into a handle.
70    ///
71    /// Custom [`GridProvider`] implementations can use this constructor after
72    /// loading bytes from their own package, object store, or manifest.
73    pub fn from_bytes(
74        definition: GridDefinition,
75        bytes: &[u8],
76    ) -> std::result::Result<Self, GridError> {
77        Ok(Self {
78            data: Arc::new(parse_cached_grid_data(
79                definition.format,
80                &definition.name,
81                bytes,
82            )?),
83            definition,
84        })
85    }
86
87    pub fn definition(&self) -> &GridDefinition {
88        &self.definition
89    }
90
91    pub fn checksum(&self) -> &str {
92        &self.data.checksum
93    }
94
95    pub fn sample(
96        &self,
97        lon_radians: f64,
98        lat_radians: f64,
99    ) -> std::result::Result<GridSample, GridError> {
100        match &self.data.data {
101            GridData::Ntv2(set) => set.sample(lon_radians, lat_radians),
102            GridData::Gtx(_) => Err(GridError::UnsupportedFormat(format!(
103                "{} is a vertical grid",
104                self.definition.name
105            ))),
106        }
107    }
108
109    pub fn sample_vertical_offset_meters(
110        &self,
111        lon_radians: f64,
112        lat_radians: f64,
113    ) -> std::result::Result<VerticalGridSample, GridError> {
114        match &self.data.data {
115            GridData::Gtx(grid) => grid.sample(lon_radians, lat_radians),
116            GridData::Ntv2(_) => Err(GridError::UnsupportedFormat(format!(
117                "{} is a horizontal grid",
118                self.definition.name
119            ))),
120        }
121    }
122
123    pub fn apply(
124        &self,
125        lon_radians: f64,
126        lat_radians: f64,
127        direction: GridShiftDirection,
128    ) -> std::result::Result<(f64, f64), GridError> {
129        match &self.data.data {
130            GridData::Ntv2(set) => set.apply(lon_radians, lat_radians, direction),
131            GridData::Gtx(_) => Err(GridError::UnsupportedFormat(format!(
132                "{} is a vertical grid",
133                self.definition.name
134            ))),
135        }
136    }
137}
138
139pub(crate) struct GridRuntime {
140    providers: Vec<Arc<dyn GridProvider>>,
141    definition_cache: Mutex<HashMap<String, GridDefinition>>,
142    handle_cache: Mutex<HashMap<String, GridHandle>>,
143}
144
145impl GridRuntime {
146    pub(crate) fn new(app_provider: Option<Arc<dyn GridProvider>>) -> Self {
147        let mut providers: Vec<Arc<dyn GridProvider>> = Vec::with_capacity(2);
148        if let Some(provider) = app_provider {
149            providers.push(provider);
150        }
151        providers.push(Arc::new(EmbeddedGridProvider));
152        Self {
153            providers,
154            definition_cache: Mutex::new(HashMap::new()),
155            handle_cache: Mutex::new(HashMap::new()),
156        }
157    }
158
159    pub(crate) fn resolve_definition(
160        &self,
161        grid: &GridDefinition,
162    ) -> std::result::Result<GridDefinition, GridError> {
163        let cache_key = grid_runtime_cache_key(grid);
164        if let Some(cached) = self
165            .definition_cache
166            .lock()
167            .expect("grid definition cache poisoned")
168            .get(&cache_key)
169            .cloned()
170        {
171            return Ok(cached);
172        }
173
174        for provider in &self.providers {
175            if let Some(definition) = provider.definition(grid)? {
176                self.definition_cache
177                    .lock()
178                    .expect("grid definition cache poisoned")
179                    .insert(cache_key, definition.clone());
180                return Ok(definition);
181            }
182        }
183
184        Err(GridError::Unavailable(grid.name.clone()))
185    }
186
187    pub(crate) fn resolve_handle(
188        &self,
189        grid: &GridDefinition,
190    ) -> std::result::Result<GridHandle, GridError> {
191        let cache_key = grid_runtime_cache_key(grid);
192        if let Some(cached) = self
193            .handle_cache
194            .lock()
195            .expect("grid handle cache poisoned")
196            .get(&cache_key)
197            .cloned()
198        {
199            return Ok(cached);
200        }
201
202        let definition = self.resolve_definition(grid)?;
203        for provider in &self.providers {
204            if let Some(handle) = provider.load(&definition)? {
205                self.handle_cache
206                    .lock()
207                    .expect("grid handle cache poisoned")
208                    .insert(cache_key, handle.clone());
209                return Ok(handle);
210            }
211        }
212
213        Err(GridError::Unavailable(definition.name))
214    }
215}
216
217fn grid_runtime_cache_key(grid: &GridDefinition) -> String {
218    let mut key = format!("{}|{:?}", grid.id.0, grid.format);
219    for resource in &grid.resource_names {
220        key.push('|');
221        key.push_str(resource);
222    }
223    key
224}
225
226#[derive(Default)]
227pub struct EmbeddedGridProvider;
228
229impl GridProvider for EmbeddedGridProvider {
230    fn definition(
231        &self,
232        grid: &GridDefinition,
233    ) -> std::result::Result<Option<GridDefinition>, GridError> {
234        if embedded_grid_resource(&grid.resource_names).is_some() {
235            return Ok(Some(grid.clone()));
236        }
237        Ok(None)
238    }
239
240    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError> {
241        let Some((resource_name, bytes)) = embedded_grid_resource(&grid.resource_names) else {
242            return Ok(None);
243        };
244
245        let key = GridDataCacheKey::new(grid.format, resource_name);
246        let data = cached_grid_data(embedded_grid_data_cache(), key, || {
247            parse_cached_grid_data(grid.format, &grid.name, bytes)
248        })?;
249
250        Ok(Some(GridHandle {
251            definition: grid.clone(),
252            data,
253        }))
254    }
255}
256
257pub struct FilesystemGridProvider {
258    roots: Vec<PathBuf>,
259    data_cache: Mutex<HashMap<GridDataCacheKey, Arc<CachedGridData>>>,
260}
261
262impl FilesystemGridProvider {
263    pub fn new<I>(roots: I) -> Self
264    where
265        I: IntoIterator<Item = PathBuf>,
266    {
267        Self {
268            roots: roots.into_iter().collect(),
269            data_cache: Mutex::new(HashMap::new()),
270        }
271    }
272
273    fn locate(&self, grid: &GridDefinition) -> Option<PathBuf> {
274        for root in &self.roots {
275            let Ok(root) = root.canonicalize() else {
276                continue;
277            };
278            for name in &grid.resource_names {
279                if !is_safe_grid_resource_name(name) {
280                    continue;
281                }
282                let candidate = root.join(name);
283                let Ok(canonical_candidate) = candidate.canonicalize() else {
284                    continue;
285                };
286                if canonical_candidate.starts_with(&root) && canonical_candidate.is_file() {
287                    return Some(canonical_candidate);
288                }
289            }
290        }
291        None
292    }
293}
294
295impl GridProvider for FilesystemGridProvider {
296    fn definition(
297        &self,
298        grid: &GridDefinition,
299    ) -> std::result::Result<Option<GridDefinition>, GridError> {
300        if self.locate(grid).is_some() {
301            return Ok(Some(grid.clone()));
302        }
303        Ok(None)
304    }
305
306    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError> {
307        let Some(path) = self.locate(grid) else {
308            return Ok(None);
309        };
310
311        let cache_path = path.canonicalize().unwrap_or_else(|_| path.clone());
312        let key = GridDataCacheKey::new(grid.format, cache_path.to_string_lossy());
313        let data = cached_grid_data(&self.data_cache, key, || {
314            let bytes = std::fs::read(&path)
315                .map_err(|err| GridError::Unavailable(format!("{}: {err}", path.display())))?;
316            parse_cached_grid_data(grid.format, &grid.name, &bytes)
317        })?;
318
319        Ok(Some(GridHandle {
320            definition: grid.clone(),
321            data,
322        }))
323    }
324}
325
/// Accept only non-empty relative paths built purely from normal segments.
///
/// This rejects root/absolute paths, `..` anywhere, a leading `.` component
/// and (on Windows) drive or UNC prefixes. As a result, joining the name onto a
/// root can never step outside that root lexically.
fn is_safe_grid_resource_name(name: &str) -> bool {
    !name.is_empty()
        && Path::new(name)
            .components()
            .all(|part| matches!(part, Component::Normal(_)))
}
334
335enum GridData {
336    Ntv2(Ntv2GridSet),
337    Gtx(GtxGrid),
338}
339
340struct CachedGridData {
341    data: GridData,
342    checksum: String,
343}
344
345#[derive(Debug, Clone, PartialEq, Eq, Hash)]
346struct GridDataCacheKey {
347    format: GridFormat,
348    resource: String,
349}
350
351impl GridDataCacheKey {
352    fn new(format: GridFormat, resource: impl AsRef<str>) -> Self {
353        Self {
354            format,
355            resource: resource.as_ref().to_string(),
356        }
357    }
358}
359
360fn embedded_grid_data_cache() -> &'static Mutex<HashMap<GridDataCacheKey, Arc<CachedGridData>>> {
361    static CACHE: OnceLock<Mutex<HashMap<GridDataCacheKey, Arc<CachedGridData>>>> = OnceLock::new();
362    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
363}
364
365fn cached_grid_data(
366    cache: &Mutex<HashMap<GridDataCacheKey, Arc<CachedGridData>>>,
367    key: GridDataCacheKey,
368    parse: impl FnOnce() -> std::result::Result<CachedGridData, GridError>,
369) -> std::result::Result<Arc<CachedGridData>, GridError> {
370    if let Some(cached) = cache
371        .lock()
372        .expect("grid data cache poisoned")
373        .get(&key)
374        .cloned()
375    {
376        return Ok(cached);
377    }
378
379    let parsed = Arc::new(parse()?);
380    let mut cache = cache.lock().expect("grid data cache poisoned");
381    let cached = cache.entry(key).or_insert_with(|| Arc::clone(&parsed));
382    Ok(Arc::clone(cached))
383}
384
385fn parse_grid_data(
386    format: GridFormat,
387    name: &str,
388    bytes: &[u8],
389) -> std::result::Result<GridData, GridError> {
390    match format {
391        GridFormat::Ntv2 => Ok(GridData::Ntv2(Ntv2GridSet::parse(bytes)?)),
392        GridFormat::Gtx => Ok(GridData::Gtx(GtxGrid::parse(bytes)?)),
393        GridFormat::Unsupported => Err(GridError::UnsupportedFormat(name.into())),
394    }
395}
396
397fn parse_cached_grid_data(
398    format: GridFormat,
399    name: &str,
400    bytes: &[u8],
401) -> std::result::Result<CachedGridData, GridError> {
402    Ok(CachedGridData {
403        data: parse_grid_data(format, name, bytes)?,
404        checksum: sha256_hex(bytes),
405    })
406}
407
/// SHA-256 digest (FIPS 180-4) of `bytes`, rendered as `sha256:` followed by
/// 64 lowercase hex digits.
fn sha256_hex(bytes: &[u8]) -> String {
    use std::fmt::Write as _;

    const INITIAL_STATE: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];
    const ROUND_CONSTANTS: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Padding: one 0x80 byte, zeros up to 56 mod 64, then the message length in
    // bits as a big-endian u64. `padded_len` is the smallest multiple of 64
    // that fits the message plus those 9 mandatory bytes.
    let message_bits = (bytes.len() as u64).wrapping_mul(8);
    let padded_len = (bytes.len() + 9).div_ceil(64) * 64;
    let mut message = Vec::with_capacity(padded_len);
    message.extend_from_slice(bytes);
    message.push(0x80);
    message.resize(padded_len - 8, 0);
    message.extend_from_slice(&message_bits.to_be_bytes());

    let mut state = INITIAL_STATE;
    let mut schedule = [0u32; 64];
    for block in message.chunks_exact(64) {
        for (word, quad) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([quad[0], quad[1], quad[2], quad[3]]);
        }
        for t in 16..64 {
            let w15 = schedule[t - 15];
            let w2 = schedule[t - 2];
            let sigma0 = w15.rotate_right(7) ^ w15.rotate_right(18) ^ (w15 >> 3);
            let sigma1 = w2.rotate_right(17) ^ w2.rotate_right(19) ^ (w2 >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(sigma0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(sigma1);
        }

        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;
        for (&k, &w) in ROUND_CONSTANTS.iter().zip(schedule.iter()) {
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let t1 = h
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(k)
                .wrapping_add(w);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_sigma0.wrapping_add(majority);

            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }

        for (slot, delta) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
            *slot = slot.wrapping_add(delta);
        }
    }

    let mut digest = String::with_capacity(7 + 64);
    digest.push_str("sha256:");
    for word in state {
        write!(digest, "{word:08x}").expect("writing to string cannot fail");
    }
    digest
}
503
504fn embedded_grid_resource(names: &[String]) -> Option<(&'static str, &'static [u8])> {
505    for name in names {
506        if name.eq_ignore_ascii_case("ntv2_0.gsb") {
507            return Some(("ntv2_0.gsb", include_bytes!("../data/grids/ntv2_0.gsb")));
508        }
509    }
510    None
511}
512
513#[derive(Clone)]
514struct Ntv2GridSet {
515    grids: Vec<Ntv2Grid>,
516    roots: Vec<usize>,
517}
518
519impl Ntv2GridSet {
520    fn parse(bytes: &[u8]) -> std::result::Result<Self, GridError> {
521        const HEADER_LEN: usize = 11 * 16;
522
523        if bytes.len() < HEADER_LEN {
524            return Err(GridError::Parse("NTv2 file too small".into()));
525        }
526
527        let endian = if u32::from_le_bytes(bytes[8..12].try_into().expect("slice length checked"))
528            == 11
529        {
530            Endian::Little
531        } else if u32::from_be_bytes(bytes[8..12].try_into().expect("slice length checked")) == 11 {
532            Endian::Big
533        } else {
534            return Err(GridError::Parse(
535                "invalid NTv2 header endianness marker".into(),
536            ));
537        };
538
539        if &bytes[56..63] != b"SECONDS" {
540            return Err(GridError::Parse(
541                "only NTv2 GS_TYPE=SECONDS is supported".into(),
542            ));
543        }
544
545        let num_subfiles = read_u32(bytes, 40, endian)? as usize;
546        let mut offset = HEADER_LEN;
547        let mut grids = Vec::with_capacity(num_subfiles);
548        let mut name_to_index = HashMap::new();
549        let mut parent_links: Vec<Option<String>> = Vec::with_capacity(num_subfiles);
550
551        for _ in 0..num_subfiles {
552            let header = bytes
553                .get(offset..offset + HEADER_LEN)
554                .ok_or_else(|| GridError::Parse("truncated NTv2 subfile header".into()))?;
555            if &header[0..8] != b"SUB_NAME" {
556                return Err(GridError::Parse("invalid NTv2 subfile header tag".into()));
557            }
558
559            let name = parse_label(&header[8..16]);
560            let parent = parse_label(&header[24..32]);
561            let south = read_f64(header, 72, endian)? * PI / 180.0 / 3600.0;
562            let north = read_f64(header, 88, endian)? * PI / 180.0 / 3600.0;
563            let east = -read_f64(header, 104, endian)? * PI / 180.0 / 3600.0;
564            let west = -read_f64(header, 120, endian)? * PI / 180.0 / 3600.0;
565            let res_y = read_f64(header, 136, endian)? * PI / 180.0 / 3600.0;
566            let res_x = read_f64(header, 152, endian)? * PI / 180.0 / 3600.0;
567            let gs_count = read_u32(header, 168, endian)? as usize;
568
569            if !(west < east && south < north && res_x > 0.0 && res_y > 0.0) {
570                return Err(GridError::Parse(format!(
571                    "invalid NTv2 georeferencing for subgrid {name}"
572                )));
573            }
574
575            let width = (((east - west) / res_x).abs() + 0.5).floor() as usize + 1;
576            let height = (((north - south) / res_y).abs() + 0.5).floor() as usize + 1;
577            if width * height != gs_count {
578                return Err(GridError::Parse(format!(
579                    "NTv2 subgrid {name} cell count mismatch: expected {} got {gs_count}",
580                    width * height
581                )));
582            }
583
584            let data_len = gs_count
585                .checked_mul(4)
586                .and_then(|count| count.checked_mul(4))
587                .ok_or_else(|| GridError::Parse("NTv2 data size overflow".into()))?;
588            let data = bytes
589                .get(offset + HEADER_LEN..offset + HEADER_LEN + data_len)
590                .ok_or_else(|| {
591                    GridError::Parse(format!("truncated NTv2 data for subgrid {name}"))
592                })?;
593
594            let mut lat_shift = vec![0.0f64; gs_count];
595            let mut lon_shift = vec![0.0f64; gs_count];
596            for y in 0..height {
597                for x in 0..width {
598                    let source_x = width - 1 - x;
599                    let record_offset = (y * width + source_x) * 16;
600                    let lat = read_f32(data, record_offset, endian)? as f64 * PI / 180.0 / 3600.0;
601                    let lon =
602                        -(read_f32(data, record_offset + 4, endian)? as f64) * PI / 180.0 / 3600.0;
603                    let dest = y * width + x;
604                    lat_shift[dest] = lat;
605                    lon_shift[dest] = lon;
606                }
607            }
608
609            let index = grids.len();
610            name_to_index.insert(name.clone(), index);
611            parent_links.push(
612                if parent.eq_ignore_ascii_case("none") || parent.is_empty() {
613                    None
614                } else {
615                    Some(parent)
616                },
617            );
618            grids.push(Ntv2Grid {
619                name,
620                extent: GridExtent {
621                    west,
622                    south,
623                    east,
624                    north,
625                    res_x,
626                    res_y,
627                },
628                width,
629                height,
630                lat_shift,
631                lon_shift,
632                children: Vec::new(),
633            });
634            offset += HEADER_LEN + data_len;
635        }
636
637        let mut roots = Vec::new();
638        for (idx, parent) in parent_links.into_iter().enumerate() {
639            if let Some(parent_name) = parent {
640                let Some(parent_idx) = name_to_index.get(&parent_name).copied() else {
641                    return Err(GridError::Parse(format!(
642                        "missing NTv2 parent subgrid {parent_name} for {}",
643                        grids[idx].name
644                    )));
645                };
646                grids[parent_idx].children.push(idx);
647            } else {
648                roots.push(idx);
649            }
650        }
651
652        Ok(Self { grids, roots })
653    }
654
655    fn sample(
656        &self,
657        lon_radians: f64,
658        lat_radians: f64,
659    ) -> std::result::Result<GridSample, GridError> {
660        let (grid_idx, local_lon, local_lat) = self.grid_at(lon_radians, lat_radians)?;
661        let (lon_shift, lat_shift) = interpolate(&self.grids[grid_idx], local_lon, local_lat)?;
662        Ok(GridSample {
663            lon_shift_radians: lon_shift,
664            lat_shift_radians: lat_shift,
665        })
666    }
667
668    fn apply(
669        &self,
670        lon_radians: f64,
671        lat_radians: f64,
672        direction: GridShiftDirection,
673    ) -> std::result::Result<(f64, f64), GridError> {
674        match direction {
675            GridShiftDirection::Forward => {
676                let shift = self.sample(lon_radians, lat_radians)?;
677                Ok((
678                    lon_radians + shift.lon_shift_radians,
679                    lat_radians + shift.lat_shift_radians,
680                ))
681            }
682            GridShiftDirection::Reverse => self.apply_inverse(lon_radians, lat_radians),
683        }
684    }
685
686    fn apply_inverse(
687        &self,
688        lon_radians: f64,
689        lat_radians: f64,
690    ) -> std::result::Result<(f64, f64), GridError> {
691        const MAX_ITERATIONS: usize = 10;
692        const TOLERANCE: f64 = 1e-12;
693
694        let mut estimate_lon = lon_radians;
695        let mut estimate_lat = lat_radians;
696
697        for _ in 0..MAX_ITERATIONS {
698            let shift = self.sample(estimate_lon, estimate_lat)?;
699            let next_lon = lon_radians - shift.lon_shift_radians;
700            let next_lat = lat_radians - shift.lat_shift_radians;
701            let diff_lon = next_lon - estimate_lon;
702            let diff_lat = next_lat - estimate_lat;
703            estimate_lon = next_lon;
704            estimate_lat = next_lat;
705            if diff_lon * diff_lon + diff_lat * diff_lat <= TOLERANCE * TOLERANCE {
706                return Ok((estimate_lon, estimate_lat));
707            }
708        }
709
710        Ok((estimate_lon, estimate_lat))
711    }
712
713    fn grid_at(
714        &self,
715        lon_radians: f64,
716        lat_radians: f64,
717    ) -> std::result::Result<(usize, f64, f64), GridError> {
718        for &root in &self.roots {
719            if self.grids[root].extent.contains(lon_radians, lat_radians) {
720                let idx = self.deepest_child(root, lon_radians, lat_radians);
721                let extent = &self.grids[idx].extent;
722                return Ok((idx, lon_radians - extent.west, lat_radians - extent.south));
723            }
724        }
725        Err(GridError::OutsideCoverage(format!(
726            "longitude {:.8} latitude {:.8}",
727            lon_radians.to_degrees(),
728            lat_radians.to_degrees()
729        )))
730    }
731
732    fn deepest_child(&self, index: usize, lon_radians: f64, lat_radians: f64) -> usize {
733        for &child in &self.grids[index].children {
734            if self.grids[child].extent.contains(lon_radians, lat_radians) {
735                return self.deepest_child(child, lon_radians, lat_radians);
736            }
737        }
738        index
739    }
740}
741
/// One NTv2 subgrid, resampled so that node `(x, y)` lives at index
/// `y * width + x`, with `x` counted eastward from `extent.west` and `y`
/// counted northward from `extent.south`.
#[derive(Clone)]
struct Ntv2Grid {
    /// SUB_NAME label; also the payload of coverage errors.
    name: String,
    extent: GridExtent,
    /// Node count along longitude.
    width: usize,
    /// Node count along latitude.
    height: usize,
    /// Latitude shifts in radians, row-major.
    lat_shift: Vec<f64>,
    /// Longitude shifts in radians (positive east), row-major.
    lon_shift: Vec<f64>,
    /// Indices of nested subgrids within the owning grid set.
    children: Vec<usize>,
}

/// Bounds and node spacing of a subgrid, all in radians.
#[derive(Clone, Copy)]
struct GridExtent {
    west: f64,
    south: f64,
    east: f64,
    north: f64,
    res_x: f64,
    res_y: f64,
}

impl GridExtent {
    /// Inclusive containment test.
    ///
    /// The bounds are widened by a tiny tolerance proportional to the node
    /// spacing, so points exactly on an edge survive rounding. NaN inputs are
    /// never contained.
    fn contains(&self, lon_radians: f64, lat_radians: f64) -> bool {
        let slack = (self.res_x + self.res_y) * 1e-10;
        let lon_inside = (self.west - slack..=self.east + slack).contains(&lon_radians);
        let lat_inside = (self.south - slack..=self.north + slack).contains(&lat_radians);
        lon_inside && lat_inside
    }
}
772
773fn interpolate(
774    grid: &Ntv2Grid,
775    local_lon: f64,
776    local_lat: f64,
777) -> std::result::Result<(f64, f64), GridError> {
778    let lam = local_lon / grid.extent.res_x;
779    let phi = local_lat / grid.extent.res_y;
780    let mut x = lam.floor() as isize;
781    let mut y = phi.floor() as isize;
782    let mut fx = lam - x as f64;
783    let mut fy = phi - y as f64;
784
785    if x < 0 {
786        if x == -1 && fx > 1.0 - 1e-9 {
787            x = 0;
788            fx = 0.0;
789        } else {
790            return Err(GridError::OutsideCoverage(grid.name.clone()));
791        }
792    }
793    if y < 0 {
794        if y == -1 && fy > 1.0 - 1e-9 {
795            y = 0;
796            fy = 0.0;
797        } else {
798            return Err(GridError::OutsideCoverage(grid.name.clone()));
799        }
800    }
801    if x as usize + 1 >= grid.width {
802        if x as usize + 1 == grid.width && fx < 1e-9 {
803            x -= 1;
804            fx = 1.0;
805        } else {
806            return Err(GridError::OutsideCoverage(grid.name.clone()));
807        }
808    }
809    if y as usize + 1 >= grid.height {
810        if y as usize + 1 == grid.height && fy < 1e-9 {
811            y -= 1;
812            fy = 1.0;
813        } else {
814            return Err(GridError::OutsideCoverage(grid.name.clone()));
815        }
816    }
817
818    let idx = |xx: usize, yy: usize| yy * grid.width + xx;
819    let x0 = x as usize;
820    let y0 = y as usize;
821    let x1 = x0 + 1;
822    let y1 = y0 + 1;
823
824    let m00 = (1.0 - fx) * (1.0 - fy);
825    let m10 = fx * (1.0 - fy);
826    let m01 = (1.0 - fx) * fy;
827    let m11 = fx * fy;
828
829    let lon = m00 * grid.lon_shift[idx(x0, y0)]
830        + m10 * grid.lon_shift[idx(x1, y0)]
831        + m01 * grid.lon_shift[idx(x0, y1)]
832        + m11 * grid.lon_shift[idx(x1, y1)];
833    let lat = m00 * grid.lat_shift[idx(x0, y0)]
834        + m10 * grid.lat_shift[idx(x1, y0)]
835        + m01 * grid.lat_shift[idx(x0, y1)]
836        + m11 * grid.lat_shift[idx(x1, y1)];
837
838    Ok((lon, lat))
839}
840
/// A parsed GTX vertical offset grid, with all angles in degrees.
///
/// Offsets are stored in file order and addressed as `y * width + x`. Here `x`
/// counts `delta_lon_degrees` steps east of `west_degrees` and `y` counts
/// `delta_lat_degrees` steps north of `south_degrees`. Nodes flagged as
/// nodata in the file hold `NaN`.
#[derive(Clone)]
struct GtxGrid {
    west_degrees: f64,
    south_degrees: f64,
    /// Longitude of the last column.
    east_degrees: f64,
    /// Latitude of the last row.
    north_degrees: f64,
    delta_lon_degrees: f64,
    delta_lat_degrees: f64,
    /// Number of columns.
    width: usize,
    /// Number of rows.
    height: usize,
    /// Vertical offsets in meters.
    offsets_meters: Vec<f64>,
}
853
854impl GtxGrid {
855    fn parse(bytes: &[u8]) -> std::result::Result<Self, GridError> {
856        const HEADER_LEN: usize = 40;
857
858        if bytes.len() < HEADER_LEN {
859            return Err(GridError::Parse("GTX file too small".into()));
860        }
861
862        let south_degrees = read_f64(bytes, 0, Endian::Big)?;
863        let west_degrees = read_f64(bytes, 8, Endian::Big)?;
864        let delta_lat_degrees = read_f64(bytes, 16, Endian::Big)?;
865        let delta_lon_degrees = read_f64(bytes, 24, Endian::Big)?;
866        let height_i32 = read_i32(bytes, 32, Endian::Big)?;
867        let width_i32 = read_i32(bytes, 36, Endian::Big)?;
868
869        if !(west_degrees.is_finite()
870            && south_degrees.is_finite()
871            && delta_lon_degrees.is_finite()
872            && delta_lat_degrees.is_finite()
873            && delta_lon_degrees > 0.0
874            && delta_lat_degrees > 0.0
875            && width_i32 >= 2
876            && height_i32 >= 2)
877        {
878            return Err(GridError::Parse("invalid GTX georeferencing".into()));
879        }
880        let height = height_i32 as usize;
881        let width = width_i32 as usize;
882
883        let count = width
884            .checked_mul(height)
885            .ok_or_else(|| GridError::Parse("GTX data size overflow".into()))?;
886        let expected_len = HEADER_LEN
887            + count
888                .checked_mul(4)
889                .ok_or_else(|| GridError::Parse("GTX data size overflow".into()))?;
890        if bytes.len() < expected_len {
891            return Err(GridError::Parse("truncated GTX data".into()));
892        }
893
894        let mut offsets_meters = Vec::with_capacity(count);
895        for index in 0..count {
896            let value = read_f32(bytes, HEADER_LEN + index * 4, Endian::Big)? as f64;
897            if (value + 88.8888).abs() <= 1e-4 {
898                offsets_meters.push(f64::NAN);
899            } else {
900                offsets_meters.push(value);
901            }
902        }
903
904        let east_degrees = west_degrees + delta_lon_degrees * (width - 1) as f64;
905        let north_degrees = south_degrees + delta_lat_degrees * (height - 1) as f64;
906
907        Ok(Self {
908            west_degrees,
909            south_degrees,
910            east_degrees,
911            north_degrees,
912            delta_lon_degrees,
913            delta_lat_degrees,
914            width,
915            height,
916            offsets_meters,
917        })
918    }
919
920    fn sample(
921        &self,
922        lon_radians: f64,
923        lat_radians: f64,
924    ) -> std::result::Result<VerticalGridSample, GridError> {
925        let lon_degrees = self.normalize_lon_degrees(lon_radians.to_degrees());
926        let lat_degrees = lat_radians.to_degrees();
927
928        if !self.contains(lon_degrees, lat_degrees) {
929            return Err(GridError::OutsideCoverage(format!(
930                "longitude {:.8} latitude {:.8}",
931                lon_radians.to_degrees(),
932                lat_degrees
933            )));
934        }
935
936        let lam = (lon_degrees - self.west_degrees) / self.delta_lon_degrees;
937        let phi = (lat_degrees - self.south_degrees) / self.delta_lat_degrees;
938        let mut x = lam.floor() as isize;
939        let mut y = phi.floor() as isize;
940        let mut fx = lam - x as f64;
941        let mut fy = phi - y as f64;
942
943        if x as usize + 1 >= self.width {
944            if x as usize + 1 == self.width && fx < 1e-9 {
945                x -= 1;
946                fx = 1.0;
947            } else {
948                return Err(GridError::OutsideCoverage("GTX longitude edge".into()));
949            }
950        }
951        if y as usize + 1 >= self.height {
952            if y as usize + 1 == self.height && fy < 1e-9 {
953                y -= 1;
954                fy = 1.0;
955            } else {
956                return Err(GridError::OutsideCoverage("GTX latitude edge".into()));
957            }
958        }
959        if x < 0 || y < 0 {
960            return Err(GridError::OutsideCoverage("GTX negative grid index".into()));
961        }
962
963        let x0 = x as usize;
964        let y0 = y as usize;
965        let x1 = x0 + 1;
966        let y1 = y0 + 1;
967        let idx = |xx: usize, yy: usize| yy * self.width + xx;
968        let z00 = self.offsets_meters[idx(x0, y0)];
969        let z10 = self.offsets_meters[idx(x1, y0)];
970        let z01 = self.offsets_meters[idx(x0, y1)];
971        let z11 = self.offsets_meters[idx(x1, y1)];
972
973        if !(z00.is_finite() && z10.is_finite() && z01.is_finite() && z11.is_finite()) {
974            return Err(GridError::OutsideCoverage(
975                "GTX interpolation touches a null cell".into(),
976            ));
977        }
978
979        let m00 = (1.0 - fx) * (1.0 - fy);
980        let m10 = fx * (1.0 - fy);
981        let m01 = (1.0 - fx) * fy;
982        let m11 = fx * fy;
983        Ok(VerticalGridSample {
984            offset_meters: m00 * z00 + m10 * z10 + m01 * z01 + m11 * z11,
985        })
986    }
987
988    fn contains(&self, lon_degrees: f64, lat_degrees: f64) -> bool {
989        let epsilon = (self.delta_lon_degrees + self.delta_lat_degrees) * 1e-10;
990        lon_degrees >= self.west_degrees - epsilon
991            && lon_degrees <= self.east_degrees + epsilon
992            && lat_degrees >= self.south_degrees - epsilon
993            && lat_degrees <= self.north_degrees + epsilon
994    }
995
996    fn normalize_lon_degrees(&self, lon_degrees: f64) -> f64 {
997        if self.contains(lon_degrees, self.south_degrees) {
998            return lon_degrees;
999        }
1000
1001        let mut lon = lon_degrees;
1002        while lon < self.west_degrees {
1003            lon += 360.0;
1004        }
1005        while lon > self.east_degrees {
1006            lon -= 360.0;
1007        }
1008        lon
1009    }
1010}
1011
/// Byte order of a multi-byte field inside a grid file.
#[derive(Copy, Clone)]
enum Endian {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first; the GTX header and samples are read this way.
    Big,
}
1017
/// Decodes a fixed-width text field: keeps everything before the first NUL byte
/// (or the whole field if there is none), decodes it lossily as UTF-8 and trims
/// surrounding whitespace.
fn parse_label(bytes: &[u8]) -> String {
    let before_nul: &[u8] = bytes.split(|&byte| byte == 0).next().unwrap_or(&[]);
    String::from_utf8_lossy(before_nul).trim().to_owned()
}
1025
1026fn read_u32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<u32, GridError> {
1027    let slice = bytes
1028        .get(offset..offset + 4)
1029        .ok_or_else(|| GridError::Parse("truncated integer".into()))?;
1030    Ok(match endian {
1031        Endian::Little => u32::from_le_bytes(slice.try_into().expect("slice length checked")),
1032        Endian::Big => u32::from_be_bytes(slice.try_into().expect("slice length checked")),
1033    })
1034}
1035
1036fn read_i32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<i32, GridError> {
1037    let slice = bytes
1038        .get(offset..offset + 4)
1039        .ok_or_else(|| GridError::Parse("truncated integer".into()))?;
1040    Ok(match endian {
1041        Endian::Little => i32::from_le_bytes(slice.try_into().expect("slice length checked")),
1042        Endian::Big => i32::from_be_bytes(slice.try_into().expect("slice length checked")),
1043    })
1044}
1045
1046fn read_f64(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<f64, GridError> {
1047    let slice = bytes
1048        .get(offset..offset + 8)
1049        .ok_or_else(|| GridError::Parse("truncated float64".into()))?;
1050    Ok(match endian {
1051        Endian::Little => f64::from_le_bytes(slice.try_into().expect("slice length checked")),
1052        Endian::Big => f64::from_be_bytes(slice.try_into().expect("slice length checked")),
1053    })
1054}
1055
1056fn read_f32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<f32, GridError> {
1057    let slice = bytes
1058        .get(offset..offset + 4)
1059        .ok_or_else(|| GridError::Parse("truncated float32".into()))?;
1060    Ok(match endian {
1061        Endian::Little => f32::from_le_bytes(slice.try_into().expect("slice length checked")),
1062        Endian::Big => f32::from_be_bytes(slice.try_into().expect("slice length checked")),
1063    })
1064}
1065
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Deletes a scratch directory when dropped, so a test cleans up after itself
    /// even if an assertion fails part-way through.
    struct ScratchDir(PathBuf);

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort: a leftover temp directory must not turn a pass into a failure.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn embedded_ntv2_grid_samples_known_point() {
        let provider = EmbeddedGridProvider;
        let definition = test_grid_definition();
        let handle = provider.load(&definition).unwrap().expect("embedded grid");
        let (lon, lat) = handle
            .apply(
                (-80.5041667f64).to_radians(),
                44.5458333f64.to_radians(),
                GridShiftDirection::Forward,
            )
            .unwrap();
        assert!(
            (lon.to_degrees() - (-80.50401615833)).abs() < 1e-6,
            "lon={lon}"
        );
        assert!((lat.to_degrees() - 44.5458827236).abs() < 3e-6, "lat={lat}");
    }

    #[test]
    fn embedded_provider_reuses_parsed_grid_data() {
        let provider = EmbeddedGridProvider;
        let definition = test_grid_definition();

        let first = provider.load(&definition).unwrap().expect("embedded grid");
        let mut renamed = definition.clone();
        renamed.name = "renamed ntv2 grid".into();
        let second = provider.load(&renamed).unwrap().expect("embedded grid");

        assert!(Arc::ptr_eq(&first.data, &second.data));
        assert_eq!(second.definition().name, "renamed ntv2 grid");
    }

    #[test]
    fn grid_handle_reports_sha256_checksum() {
        let provider = EmbeddedGridProvider;
        let handle = provider
            .load(&test_grid_definition())
            .unwrap()
            .expect("embedded grid");

        assert!(handle.checksum().starts_with("sha256:"));
        assert_eq!(handle.checksum().len(), 71);
        assert_eq!(
            sha256_hex(b"abc"),
            "sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    /// Provider that counts `definition`/`load` calls and can optionally rename the
    /// grid definition to prove its override was used; loading delegates to the
    /// embedded provider.
    struct TrackingGridProvider {
        override_definition: bool,
        definition_calls: Arc<AtomicUsize>,
        load_calls: Arc<AtomicUsize>,
    }

    impl GridProvider for TrackingGridProvider {
        fn definition(
            &self,
            grid: &GridDefinition,
        ) -> std::result::Result<Option<GridDefinition>, GridError> {
            self.definition_calls.fetch_add(1, Ordering::SeqCst);
            if self.override_definition {
                let mut overridden = grid.clone();
                overridden.name = "custom override".into();
                Ok(Some(overridden))
            } else {
                Ok(None)
            }
        }

        fn load(
            &self,
            grid: &GridDefinition,
        ) -> std::result::Result<Option<GridHandle>, GridError> {
            self.load_calls.fetch_add(1, Ordering::SeqCst);
            EmbeddedGridProvider.load(grid)
        }
    }

    /// Definition of the embedded NTv2 test grid shared by the tests above and below.
    fn test_grid_definition() -> GridDefinition {
        GridDefinition {
            id: GridId(1),
            name: "ntv2_0.gsb".into(),
            format: GridFormat::Ntv2,
            interpolation: GridInterpolation::Bilinear,
            area_of_use: None,
            resource_names: SmallVec::from_vec(vec!["ntv2_0.gsb".into()]),
        }
    }

    #[test]
    fn filesystem_provider_rejects_unsafe_resource_names() {
        let root = std::env::temp_dir().join(format!("proj-core-grid-root-{}", std::process::id()));
        let _cleanup = ScratchDir(root.clone());
        std::fs::create_dir_all(&root).unwrap();
        std::fs::write(root.join("safe.gtx"), []).unwrap();

        let provider = FilesystemGridProvider::new(vec![root.clone()]);
        let mut definition = test_grid_definition();
        definition.format = GridFormat::Gtx;
        definition.resource_names = SmallVec::from_vec(vec!["../safe.gtx".into()]);
        assert!(provider.definition(&definition).unwrap().is_none());

        definition.resource_names =
            SmallVec::from_vec(vec![root.join("safe.gtx").to_string_lossy().into_owned()]);
        assert!(provider.definition(&definition).unwrap().is_none());

        definition.resource_names = SmallVec::from_vec(vec!["safe.gtx".into()]);
        assert!(provider.definition(&definition).unwrap().is_some());
    }

    /// Builds a 3x3 big-endian GTX grid with its south-west node at lat 10 / lon 20
    /// and 1-degree spacing, followed by `values` in row-major order.
    fn test_gtx_bytes(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&10.0f64.to_be_bytes());
        bytes.extend_from_slice(&20.0f64.to_be_bytes());
        bytes.extend_from_slice(&1.0f64.to_be_bytes());
        bytes.extend_from_slice(&1.0f64.to_be_bytes());
        bytes.extend_from_slice(&3i32.to_be_bytes());
        bytes.extend_from_slice(&3i32.to_be_bytes());
        for value in values {
            bytes.extend_from_slice(&value.to_be_bytes());
        }
        bytes
    }

    #[test]
    fn gtx_grid_samples_bilinear_offsets() {
        let bytes = test_gtx_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let data = parse_grid_data(GridFormat::Gtx, "test.gtx", &bytes).unwrap();
        let GridData::Gtx(grid) = data else {
            panic!("expected GTX grid");
        };

        let sample = grid
            .sample(20.5f64.to_radians(), 10.5f64.to_radians())
            .unwrap();
        assert!((sample.offset_meters - 2.0).abs() < 1e-12);
    }

    #[test]
    fn gtx_grid_samples_exact_north_east_corner() {
        let bytes = test_gtx_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let data = parse_grid_data(GridFormat::Gtx, "test.gtx", &bytes).unwrap();
        let GridData::Gtx(grid) = data else {
            panic!("expected GTX grid");
        };

        // The last node has no cell beyond it; sampling must snap onto the final cell.
        let sample = grid
            .sample(22.0f64.to_radians(), 12.0f64.to_radians())
            .unwrap();
        assert!((sample.offset_meters - 8.0).abs() < 1e-9);
    }

    #[test]
    fn gtx_grid_rejects_outside_or_null_cells() {
        let bytes = test_gtx_bytes(&[0.0, 1.0, 2.0, 3.0, -88.8888, 5.0, 6.0, 7.0, 8.0]);
        let data = parse_grid_data(GridFormat::Gtx, "test.gtx", &bytes).unwrap();
        let GridData::Gtx(grid) = data else {
            panic!("expected GTX grid");
        };

        let null_err = grid
            .sample(20.5f64.to_radians(), 10.5f64.to_radians())
            .unwrap_err();
        assert!(matches!(null_err, GridError::OutsideCoverage(_)));

        let outside_err = grid
            .sample(30.0f64.to_radians(), 10.5f64.to_radians())
            .unwrap_err();
        assert!(matches!(outside_err, GridError::OutsideCoverage(_)));
    }

    #[test]
    fn app_grid_provider_can_override_embedded_grid() {
        let definition_calls = Arc::new(AtomicUsize::new(0));
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = TrackingGridProvider {
            override_definition: true,
            definition_calls: Arc::clone(&definition_calls),
            load_calls: Arc::clone(&load_calls),
        };
        let runtime = GridRuntime::new(Some(Arc::new(provider)));

        let handle = runtime
            .resolve_handle(&test_grid_definition())
            .expect("grid should resolve");

        assert_eq!(handle.definition().name, "custom override");
        assert_eq!(definition_calls.load(Ordering::SeqCst), 1);
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn app_grid_provider_falls_back_to_embedded_grid() {
        let definition_calls = Arc::new(AtomicUsize::new(0));
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = TrackingGridProvider {
            override_definition: false,
            definition_calls: Arc::clone(&definition_calls),
            load_calls: Arc::clone(&load_calls),
        };
        let runtime = GridRuntime::new(Some(Arc::new(provider)));

        let handle = runtime
            .resolve_handle(&test_grid_definition())
            .expect("embedded grid should remain available");

        assert_eq!(handle.definition().name, "ntv2_0.gsb");
        assert_eq!(definition_calls.load(Ordering::SeqCst), 1);
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }
}