Skip to main content

proj_core/
grid.rs

1use crate::operation::{AreaOfUse, GridId, GridInterpolation, GridShiftDirection};
2use smallvec::SmallVec;
3use std::collections::HashMap;
4use std::f64::consts::PI;
5use std::path::PathBuf;
6use std::sync::{Arc, Mutex, OnceLock};
7use thiserror::Error;
8
/// On-disk encoding of a grid resource; selects the parser used when the grid is loaded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GridFormat {
    /// NTv2 (`.gsb`) horizontal shift grid, parsed into an `Ntv2GridSet`.
    Ntv2,
    /// A format this module has no parser for; loading it yields `GridError::UnsupportedFormat`.
    Unsupported,
}
14
/// Description of a grid resource: its identity, encoding, and the resource names
/// under which providers may find its data.
#[derive(Debug, Clone, PartialEq)]
pub struct GridDefinition {
    /// Grid identifier; part of the `GridRuntime` cache key.
    pub id: GridId,
    /// Human-readable name, used in error messages.
    pub name: String,
    /// Encoding of the grid data; selects the parser.
    pub format: GridFormat,
    /// Requested interpolation method.
    /// NOTE(review): not read in this module; the NTv2 sampler here is always bilinear.
    pub interpolation: GridInterpolation,
    /// Region where the grid is valid, if known.
    /// NOTE(review): not read in this module.
    pub area_of_use: Option<AreaOfUse>,
    /// Candidate resource/file names; the providers in this module try them in order.
    pub resource_names: SmallVec<[String; 2]>,
}
24
/// Horizontal shift interpolated from a grid at a single position.
///
/// Both components are in radians and are added to the input coordinates for a
/// forward shift. Longitude shifts are positive east: NTv2's positive-west values
/// are negated when the grid is parsed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GridSample {
    /// Longitude shift in radians (positive east).
    pub lon_shift_radians: f64,
    /// Latitude shift in radians (positive north).
    pub lat_shift_radians: f64,
}
30
/// Errors raised while resolving, loading, parsing or sampling grids.
#[derive(Debug, Error, Clone)]
pub enum GridError {
    /// The grid is not known.
    /// NOTE(review): not constructed anywhere in this module.
    #[error("grid not found: {0}")]
    NotFound(String),
    /// No provider could supply the grid, or its backing file could not be read.
    #[error("grid resource unavailable: {0}")]
    Unavailable(String),
    /// The grid file is malformed or uses an unsupported NTv2 feature.
    #[error("grid parse error: {0}")]
    Parse(String),
    /// The queried position is not covered by the grid.
    #[error("grid point outside coverage: {0}")]
    OutsideCoverage(String),
    /// The grid's `GridFormat` has no parser.
    #[error("unsupported grid format: {0}")]
    UnsupportedFormat(String),
}
44
/// A source of grid definitions and grid data.
///
/// `GridRuntime` consults its providers in priority order and uses the first one
/// that answers `Some`. Providers must be `Send + Sync` because the runtime shares
/// them through `Arc<dyn GridProvider>`.
pub trait GridProvider: Send + Sync {
    /// Returns the definition this provider would serve for `grid`, or `Ok(None)`
    /// if it does not recognise the grid. A provider may return a modified
    /// definition, for example with a different `name`.
    ///
    /// An `Err` aborts resolution in `GridRuntime`; it does not fall through to the
    /// next provider.
    fn definition(
        &self,
        grid: &GridDefinition,
    ) -> std::result::Result<Option<GridDefinition>, GridError>;
    /// Loads the data for `grid`, or returns `Ok(None)` if this provider cannot supply it.
    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError>;
}
52
/// A loaded grid, ready for sampling.
///
/// Cloning copies the definition but shares the parsed grid data through an `Arc`.
#[derive(Clone)]
pub struct GridHandle {
    /// Definition the grid was loaded for.
    definition: GridDefinition,
    /// Parsed grid payload, shared with the providers' data caches.
    data: Arc<GridData>,
}
58
59impl GridHandle {
60    pub fn definition(&self) -> &GridDefinition {
61        &self.definition
62    }
63
64    pub fn sample(
65        &self,
66        lon_radians: f64,
67        lat_radians: f64,
68    ) -> std::result::Result<GridSample, GridError> {
69        match self.data.as_ref() {
70            GridData::Ntv2(set) => set.sample(lon_radians, lat_radians),
71        }
72    }
73
74    pub fn apply(
75        &self,
76        lon_radians: f64,
77        lat_radians: f64,
78        direction: GridShiftDirection,
79    ) -> std::result::Result<(f64, f64), GridError> {
80        match self.data.as_ref() {
81            GridData::Ntv2(set) => set.apply(lon_radians, lat_radians, direction),
82        }
83    }
84}
85
/// Resolves grid definitions and handles through an ordered chain of providers,
/// caching both results.
pub(crate) struct GridRuntime {
    /// Providers in priority order; the first to answer `Some` wins.
    providers: Vec<Arc<dyn GridProvider>>,
    /// Resolved definitions, keyed by `grid_runtime_cache_key`.
    definition_cache: Mutex<HashMap<String, GridDefinition>>,
    /// Loaded handles, keyed by `grid_runtime_cache_key`.
    handle_cache: Mutex<HashMap<String, GridHandle>>,
}
91
92impl GridRuntime {
93    pub(crate) fn new(app_provider: Option<Arc<dyn GridProvider>>) -> Self {
94        let mut providers: Vec<Arc<dyn GridProvider>> = Vec::with_capacity(2);
95        if let Some(provider) = app_provider {
96            providers.push(provider);
97        }
98        providers.push(Arc::new(EmbeddedGridProvider));
99        Self {
100            providers,
101            definition_cache: Mutex::new(HashMap::new()),
102            handle_cache: Mutex::new(HashMap::new()),
103        }
104    }
105
106    pub(crate) fn resolve_definition(
107        &self,
108        grid: &GridDefinition,
109    ) -> std::result::Result<GridDefinition, GridError> {
110        let cache_key = grid_runtime_cache_key(grid);
111        if let Some(cached) = self
112            .definition_cache
113            .lock()
114            .expect("grid definition cache poisoned")
115            .get(&cache_key)
116            .cloned()
117        {
118            return Ok(cached);
119        }
120
121        for provider in &self.providers {
122            if let Some(definition) = provider.definition(grid)? {
123                self.definition_cache
124                    .lock()
125                    .expect("grid definition cache poisoned")
126                    .insert(cache_key, definition.clone());
127                return Ok(definition);
128            }
129        }
130
131        Err(GridError::Unavailable(grid.name.clone()))
132    }
133
134    pub(crate) fn resolve_handle(
135        &self,
136        grid: &GridDefinition,
137    ) -> std::result::Result<GridHandle, GridError> {
138        let cache_key = grid_runtime_cache_key(grid);
139        if let Some(cached) = self
140            .handle_cache
141            .lock()
142            .expect("grid handle cache poisoned")
143            .get(&cache_key)
144            .cloned()
145        {
146            return Ok(cached);
147        }
148
149        let definition = self.resolve_definition(grid)?;
150        for provider in &self.providers {
151            if let Some(handle) = provider.load(&definition)? {
152                self.handle_cache
153                    .lock()
154                    .expect("grid handle cache poisoned")
155                    .insert(cache_key, handle.clone());
156                return Ok(handle);
157            }
158        }
159
160        Err(GridError::Unavailable(definition.name))
161    }
162}
163
164fn grid_runtime_cache_key(grid: &GridDefinition) -> String {
165    let mut key = format!("{}|{:?}", grid.id.0, grid.format);
166    for resource in &grid.resource_names {
167        key.push('|');
168        key.push_str(resource);
169    }
170    key
171}
172
/// Provider for grids compiled into the crate with `include_bytes!`.
///
/// Parsed data lives in a process-wide cache, so every instance shares it.
#[derive(Default)]
pub struct EmbeddedGridProvider;
175
176impl GridProvider for EmbeddedGridProvider {
177    fn definition(
178        &self,
179        grid: &GridDefinition,
180    ) -> std::result::Result<Option<GridDefinition>, GridError> {
181        if embedded_grid_resource(&grid.resource_names).is_some() {
182            return Ok(Some(grid.clone()));
183        }
184        Ok(None)
185    }
186
187    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError> {
188        let Some((resource_name, bytes)) = embedded_grid_resource(&grid.resource_names) else {
189            return Ok(None);
190        };
191
192        let key = GridDataCacheKey::new(grid.format, resource_name);
193        let data = cached_grid_data(embedded_grid_data_cache(), key, || {
194            parse_grid_data(grid.format, &grid.name, bytes)
195        })?;
196
197        Ok(Some(GridHandle {
198            definition: grid.clone(),
199            data,
200        }))
201    }
202}
203
/// Provider that loads grid files from a list of directories.
pub struct FilesystemGridProvider {
    /// Directories searched, in order, for grid resource files.
    roots: Vec<PathBuf>,
    /// Parsed grids keyed by format and file path (canonicalised when possible).
    /// Entries are only ever inserted, so a file changed on disk is not re-read.
    data_cache: Mutex<HashMap<GridDataCacheKey, Arc<GridData>>>,
}
208
209impl FilesystemGridProvider {
210    pub fn new<I>(roots: I) -> Self
211    where
212        I: IntoIterator<Item = PathBuf>,
213    {
214        Self {
215            roots: roots.into_iter().collect(),
216            data_cache: Mutex::new(HashMap::new()),
217        }
218    }
219
220    fn locate(&self, grid: &GridDefinition) -> Option<PathBuf> {
221        for root in &self.roots {
222            for name in &grid.resource_names {
223                let candidate = root.join(name);
224                if candidate.exists() {
225                    return Some(candidate);
226                }
227            }
228        }
229        None
230    }
231}
232
233impl GridProvider for FilesystemGridProvider {
234    fn definition(
235        &self,
236        grid: &GridDefinition,
237    ) -> std::result::Result<Option<GridDefinition>, GridError> {
238        if self.locate(grid).is_some() {
239            return Ok(Some(grid.clone()));
240        }
241        Ok(None)
242    }
243
244    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError> {
245        let Some(path) = self.locate(grid) else {
246            return Ok(None);
247        };
248
249        let cache_path = path.canonicalize().unwrap_or_else(|_| path.clone());
250        let key = GridDataCacheKey::new(grid.format, cache_path.to_string_lossy());
251        let data = cached_grid_data(&self.data_cache, key, || {
252            let bytes = std::fs::read(&path)
253                .map_err(|err| GridError::Unavailable(format!("{}: {err}", path.display())))?;
254            parse_grid_data(grid.format, &grid.name, &bytes)
255        })?;
256
257        Ok(Some(GridHandle {
258            definition: grid.clone(),
259            data,
260        }))
261    }
262}
263
/// Parsed grid payload, one variant per parseable `GridFormat`.
enum GridData {
    /// NTv2 subgrid hierarchy.
    Ntv2(Ntv2GridSet),
}
267
/// Key for the parsed-grid caches: a format plus a resource identifier
/// (an embedded resource name or a file path).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GridDataCacheKey {
    /// Format the resource is parsed as.
    format: GridFormat,
    /// Embedded resource name or file path string.
    resource: String,
}
273
274impl GridDataCacheKey {
275    fn new(format: GridFormat, resource: impl AsRef<str>) -> Self {
276        Self {
277            format,
278            resource: resource.as_ref().to_string(),
279        }
280    }
281}
282
283fn embedded_grid_data_cache() -> &'static Mutex<HashMap<GridDataCacheKey, Arc<GridData>>> {
284    static CACHE: OnceLock<Mutex<HashMap<GridDataCacheKey, Arc<GridData>>>> = OnceLock::new();
285    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
286}
287
288fn cached_grid_data(
289    cache: &Mutex<HashMap<GridDataCacheKey, Arc<GridData>>>,
290    key: GridDataCacheKey,
291    parse: impl FnOnce() -> std::result::Result<GridData, GridError>,
292) -> std::result::Result<Arc<GridData>, GridError> {
293    if let Some(cached) = cache
294        .lock()
295        .expect("grid data cache poisoned")
296        .get(&key)
297        .cloned()
298    {
299        return Ok(cached);
300    }
301
302    let parsed = Arc::new(parse()?);
303    let mut cache = cache.lock().expect("grid data cache poisoned");
304    let cached = cache.entry(key).or_insert_with(|| Arc::clone(&parsed));
305    Ok(Arc::clone(cached))
306}
307
308fn parse_grid_data(
309    format: GridFormat,
310    name: &str,
311    bytes: &[u8],
312) -> std::result::Result<GridData, GridError> {
313    match format {
314        GridFormat::Ntv2 => Ok(GridData::Ntv2(Ntv2GridSet::parse(bytes)?)),
315        GridFormat::Unsupported => Err(GridError::UnsupportedFormat(name.into())),
316    }
317}
318
319fn embedded_grid_resource(names: &[String]) -> Option<(&'static str, &'static [u8])> {
320    for name in names {
321        if name.eq_ignore_ascii_case("ntv2_0.gsb") {
322            return Some(("ntv2_0.gsb", include_bytes!("../data/grids/ntv2_0.gsb")));
323        }
324    }
325    None
326}
327
/// All subgrids of one NTv2 file, plus their parent/child hierarchy.
#[derive(Clone)]
struct Ntv2GridSet {
    /// Subgrids in file order; `Ntv2Grid::children` indexes into this vector.
    grids: Vec<Ntv2Grid>,
    /// Indices of subgrids without a parent, searched in order when sampling.
    roots: Vec<usize>,
}
333
334impl Ntv2GridSet {
335    fn parse(bytes: &[u8]) -> std::result::Result<Self, GridError> {
336        const HEADER_LEN: usize = 11 * 16;
337
338        if bytes.len() < HEADER_LEN {
339            return Err(GridError::Parse("NTv2 file too small".into()));
340        }
341
342        let endian = if u32::from_le_bytes(bytes[8..12].try_into().expect("slice length checked"))
343            == 11
344        {
345            Endian::Little
346        } else if u32::from_be_bytes(bytes[8..12].try_into().expect("slice length checked")) == 11 {
347            Endian::Big
348        } else {
349            return Err(GridError::Parse(
350                "invalid NTv2 header endianness marker".into(),
351            ));
352        };
353
354        if &bytes[56..63] != b"SECONDS" {
355            return Err(GridError::Parse(
356                "only NTv2 GS_TYPE=SECONDS is supported".into(),
357            ));
358        }
359
360        let num_subfiles = read_u32(bytes, 40, endian)? as usize;
361        let mut offset = HEADER_LEN;
362        let mut grids = Vec::with_capacity(num_subfiles);
363        let mut name_to_index = HashMap::new();
364        let mut parent_links: Vec<Option<String>> = Vec::with_capacity(num_subfiles);
365
366        for _ in 0..num_subfiles {
367            let header = bytes
368                .get(offset..offset + HEADER_LEN)
369                .ok_or_else(|| GridError::Parse("truncated NTv2 subfile header".into()))?;
370            if &header[0..8] != b"SUB_NAME" {
371                return Err(GridError::Parse("invalid NTv2 subfile header tag".into()));
372            }
373
374            let name = parse_label(&header[8..16]);
375            let parent = parse_label(&header[24..32]);
376            let south = read_f64(header, 72, endian)? * PI / 180.0 / 3600.0;
377            let north = read_f64(header, 88, endian)? * PI / 180.0 / 3600.0;
378            let east = -read_f64(header, 104, endian)? * PI / 180.0 / 3600.0;
379            let west = -read_f64(header, 120, endian)? * PI / 180.0 / 3600.0;
380            let res_y = read_f64(header, 136, endian)? * PI / 180.0 / 3600.0;
381            let res_x = read_f64(header, 152, endian)? * PI / 180.0 / 3600.0;
382            let gs_count = read_u32(header, 168, endian)? as usize;
383
384            if !(west < east && south < north && res_x > 0.0 && res_y > 0.0) {
385                return Err(GridError::Parse(format!(
386                    "invalid NTv2 georeferencing for subgrid {name}"
387                )));
388            }
389
390            let width = (((east - west) / res_x).abs() + 0.5).floor() as usize + 1;
391            let height = (((north - south) / res_y).abs() + 0.5).floor() as usize + 1;
392            if width * height != gs_count {
393                return Err(GridError::Parse(format!(
394                    "NTv2 subgrid {name} cell count mismatch: expected {} got {gs_count}",
395                    width * height
396                )));
397            }
398
399            let data_len = gs_count
400                .checked_mul(4)
401                .and_then(|count| count.checked_mul(4))
402                .ok_or_else(|| GridError::Parse("NTv2 data size overflow".into()))?;
403            let data = bytes
404                .get(offset + HEADER_LEN..offset + HEADER_LEN + data_len)
405                .ok_or_else(|| {
406                    GridError::Parse(format!("truncated NTv2 data for subgrid {name}"))
407                })?;
408
409            let mut lat_shift = vec![0.0f64; gs_count];
410            let mut lon_shift = vec![0.0f64; gs_count];
411            for y in 0..height {
412                for x in 0..width {
413                    let source_x = width - 1 - x;
414                    let record_offset = (y * width + source_x) * 16;
415                    let lat = read_f32(data, record_offset, endian)? as f64 * PI / 180.0 / 3600.0;
416                    let lon =
417                        -(read_f32(data, record_offset + 4, endian)? as f64) * PI / 180.0 / 3600.0;
418                    let dest = y * width + x;
419                    lat_shift[dest] = lat;
420                    lon_shift[dest] = lon;
421                }
422            }
423
424            let index = grids.len();
425            name_to_index.insert(name.clone(), index);
426            parent_links.push(
427                if parent.eq_ignore_ascii_case("none") || parent.is_empty() {
428                    None
429                } else {
430                    Some(parent)
431                },
432            );
433            grids.push(Ntv2Grid {
434                name,
435                extent: GridExtent {
436                    west,
437                    south,
438                    east,
439                    north,
440                    res_x,
441                    res_y,
442                },
443                width,
444                height,
445                lat_shift,
446                lon_shift,
447                children: Vec::new(),
448            });
449            offset += HEADER_LEN + data_len;
450        }
451
452        let mut roots = Vec::new();
453        for (idx, parent) in parent_links.into_iter().enumerate() {
454            if let Some(parent_name) = parent {
455                let Some(parent_idx) = name_to_index.get(&parent_name).copied() else {
456                    return Err(GridError::Parse(format!(
457                        "missing NTv2 parent subgrid {parent_name} for {}",
458                        grids[idx].name
459                    )));
460                };
461                grids[parent_idx].children.push(idx);
462            } else {
463                roots.push(idx);
464            }
465        }
466
467        Ok(Self { grids, roots })
468    }
469
470    fn sample(
471        &self,
472        lon_radians: f64,
473        lat_radians: f64,
474    ) -> std::result::Result<GridSample, GridError> {
475        let (grid_idx, local_lon, local_lat) = self.grid_at(lon_radians, lat_radians)?;
476        let (lon_shift, lat_shift) = interpolate(&self.grids[grid_idx], local_lon, local_lat)?;
477        Ok(GridSample {
478            lon_shift_radians: lon_shift,
479            lat_shift_radians: lat_shift,
480        })
481    }
482
483    fn apply(
484        &self,
485        lon_radians: f64,
486        lat_radians: f64,
487        direction: GridShiftDirection,
488    ) -> std::result::Result<(f64, f64), GridError> {
489        match direction {
490            GridShiftDirection::Forward => {
491                let shift = self.sample(lon_radians, lat_radians)?;
492                Ok((
493                    lon_radians + shift.lon_shift_radians,
494                    lat_radians + shift.lat_shift_radians,
495                ))
496            }
497            GridShiftDirection::Reverse => self.apply_inverse(lon_radians, lat_radians),
498        }
499    }
500
501    fn apply_inverse(
502        &self,
503        lon_radians: f64,
504        lat_radians: f64,
505    ) -> std::result::Result<(f64, f64), GridError> {
506        const MAX_ITERATIONS: usize = 10;
507        const TOLERANCE: f64 = 1e-12;
508
509        let mut estimate_lon = lon_radians;
510        let mut estimate_lat = lat_radians;
511
512        for _ in 0..MAX_ITERATIONS {
513            let shift = self.sample(estimate_lon, estimate_lat)?;
514            let next_lon = lon_radians - shift.lon_shift_radians;
515            let next_lat = lat_radians - shift.lat_shift_radians;
516            let diff_lon = next_lon - estimate_lon;
517            let diff_lat = next_lat - estimate_lat;
518            estimate_lon = next_lon;
519            estimate_lat = next_lat;
520            if diff_lon * diff_lon + diff_lat * diff_lat <= TOLERANCE * TOLERANCE {
521                return Ok((estimate_lon, estimate_lat));
522            }
523        }
524
525        Ok((estimate_lon, estimate_lat))
526    }
527
528    fn grid_at(
529        &self,
530        lon_radians: f64,
531        lat_radians: f64,
532    ) -> std::result::Result<(usize, f64, f64), GridError> {
533        for &root in &self.roots {
534            if self.grids[root].extent.contains(lon_radians, lat_radians) {
535                let idx = self.deepest_child(root, lon_radians, lat_radians);
536                let extent = &self.grids[idx].extent;
537                return Ok((idx, lon_radians - extent.west, lat_radians - extent.south));
538            }
539        }
540        Err(GridError::OutsideCoverage(format!(
541            "longitude {:.8} latitude {:.8}",
542            lon_radians.to_degrees(),
543            lat_radians.to_degrees()
544        )))
545    }
546
547    fn deepest_child(&self, index: usize, lon_radians: f64, lat_radians: f64) -> usize {
548        for &child in &self.grids[index].children {
549            if self.grids[child].extent.contains(lon_radians, lat_radians) {
550                return self.deepest_child(child, lon_radians, lat_radians);
551            }
552        }
553        index
554    }
555}
556
/// One NTv2 subgrid, with its bounds and shifts converted to radians.
#[derive(Clone)]
struct Ntv2Grid {
    /// Subgrid name from the `SUB_NAME` record.
    name: String,
    /// Bounds and node spacing, in radians.
    extent: GridExtent,
    /// Number of nodes along longitude (columns, indexed west to east).
    width: usize,
    /// Number of nodes along latitude (rows, indexed south to north).
    height: usize,
    /// Latitude shifts in radians, row-major (`y * width + x`).
    lat_shift: Vec<f64>,
    /// Longitude shifts in radians (positive east), row-major (`y * width + x`).
    lon_shift: Vec<f64>,
    /// Indices of child subgrids within `Ntv2GridSet::grids`.
    children: Vec<usize>,
}
567
/// Axis-aligned bounds of a subgrid and its node spacing, all in radians
/// (longitudes positive east).
#[derive(Clone, Copy)]
struct GridExtent {
    /// Western longitude bound.
    west: f64,
    /// Southern latitude bound.
    south: f64,
    /// Eastern longitude bound.
    east: f64,
    /// Northern latitude bound.
    north: f64,
    /// Longitude spacing between nodes.
    res_x: f64,
    /// Latitude spacing between nodes.
    res_y: f64,
}
577
578impl GridExtent {
579    fn contains(&self, lon_radians: f64, lat_radians: f64) -> bool {
580        let epsilon = (self.res_x + self.res_y) * 1e-10;
581        lon_radians >= self.west - epsilon
582            && lon_radians <= self.east + epsilon
583            && lat_radians >= self.south - epsilon
584            && lat_radians <= self.north + epsilon
585    }
586}
587
588fn interpolate(
589    grid: &Ntv2Grid,
590    local_lon: f64,
591    local_lat: f64,
592) -> std::result::Result<(f64, f64), GridError> {
593    let lam = local_lon / grid.extent.res_x;
594    let phi = local_lat / grid.extent.res_y;
595    let mut x = lam.floor() as isize;
596    let mut y = phi.floor() as isize;
597    let mut fx = lam - x as f64;
598    let mut fy = phi - y as f64;
599
600    if x < 0 {
601        if x == -1 && fx > 1.0 - 1e-9 {
602            x = 0;
603            fx = 0.0;
604        } else {
605            return Err(GridError::OutsideCoverage(grid.name.clone()));
606        }
607    }
608    if y < 0 {
609        if y == -1 && fy > 1.0 - 1e-9 {
610            y = 0;
611            fy = 0.0;
612        } else {
613            return Err(GridError::OutsideCoverage(grid.name.clone()));
614        }
615    }
616    if x as usize + 1 >= grid.width {
617        if x as usize + 1 == grid.width && fx < 1e-9 {
618            x -= 1;
619            fx = 1.0;
620        } else {
621            return Err(GridError::OutsideCoverage(grid.name.clone()));
622        }
623    }
624    if y as usize + 1 >= grid.height {
625        if y as usize + 1 == grid.height && fy < 1e-9 {
626            y -= 1;
627            fy = 1.0;
628        } else {
629            return Err(GridError::OutsideCoverage(grid.name.clone()));
630        }
631    }
632
633    let idx = |xx: usize, yy: usize| yy * grid.width + xx;
634    let x0 = x as usize;
635    let y0 = y as usize;
636    let x1 = x0 + 1;
637    let y1 = y0 + 1;
638
639    let m00 = (1.0 - fx) * (1.0 - fy);
640    let m10 = fx * (1.0 - fy);
641    let m01 = (1.0 - fx) * fy;
642    let m11 = fx * fy;
643
644    let lon = m00 * grid.lon_shift[idx(x0, y0)]
645        + m10 * grid.lon_shift[idx(x1, y0)]
646        + m01 * grid.lon_shift[idx(x0, y1)]
647        + m11 * grid.lon_shift[idx(x1, y1)];
648    let lat = m00 * grid.lat_shift[idx(x0, y0)]
649        + m10 * grid.lat_shift[idx(x1, y0)]
650        + m01 * grid.lat_shift[idx(x0, y1)]
651        + m11 * grid.lat_shift[idx(x1, y1)];
652
653    Ok((lon, lat))
654}
655
/// Byte order of an NTv2 file, detected from its overview header.
#[derive(Clone, Copy)]
enum Endian {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first.
    Big,
}
661
/// Decodes a fixed-width NTv2 text field.
///
/// The field is cut at the first NUL byte (if any), decoded as UTF-8 with invalid
/// sequences replaced, and stripped of surrounding whitespace.
fn parse_label(bytes: &[u8]) -> String {
    let label = bytes.split(|&byte| byte == 0).next().unwrap_or_default();
    String::from_utf8_lossy(label).trim().to_owned()
}
669
670fn read_u32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<u32, GridError> {
671    let slice = bytes
672        .get(offset..offset + 4)
673        .ok_or_else(|| GridError::Parse("truncated integer".into()))?;
674    Ok(match endian {
675        Endian::Little => u32::from_le_bytes(slice.try_into().expect("slice length checked")),
676        Endian::Big => u32::from_be_bytes(slice.try_into().expect("slice length checked")),
677    })
678}
679
680fn read_f64(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<f64, GridError> {
681    let slice = bytes
682        .get(offset..offset + 8)
683        .ok_or_else(|| GridError::Parse("truncated float64".into()))?;
684    Ok(match endian {
685        Endian::Little => f64::from_le_bytes(slice.try_into().expect("slice length checked")),
686        Endian::Big => f64::from_be_bytes(slice.try_into().expect("slice length checked")),
687    })
688}
689
690fn read_f32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<f32, GridError> {
691    let slice = bytes
692        .get(offset..offset + 4)
693        .ok_or_else(|| GridError::Parse("truncated float32".into()))?;
694    Ok(match endian {
695        Endian::Little => f32::from_le_bytes(slice.try_into().expect("slice length checked")),
696        Endian::Big => f32::from_be_bytes(slice.try_into().expect("slice length checked")),
697    })
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Definition of the sample NTv2 grid embedded in the crate.
    fn test_grid_definition() -> GridDefinition {
        GridDefinition {
            id: GridId(1),
            name: "ntv2_0.gsb".into(),
            format: GridFormat::Ntv2,
            interpolation: GridInterpolation::Bilinear,
            area_of_use: None,
            resource_names: SmallVec::from_vec(vec!["ntv2_0.gsb".into()]),
        }
    }

    /// Application-provider stub that counts calls and can claim every grid under a new name.
    struct TrackingGridProvider {
        override_definition: bool,
        definition_calls: Arc<AtomicUsize>,
        load_calls: Arc<AtomicUsize>,
    }

    impl TrackingGridProvider {
        /// Returns a provider together with its `definition` and `load` call counters.
        fn with_counters(override_definition: bool) -> (Self, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let definition_calls = Arc::new(AtomicUsize::new(0));
            let load_calls = Arc::new(AtomicUsize::new(0));
            let provider = Self {
                override_definition,
                definition_calls: Arc::clone(&definition_calls),
                load_calls: Arc::clone(&load_calls),
            };
            (provider, definition_calls, load_calls)
        }
    }

    impl GridProvider for TrackingGridProvider {
        fn definition(
            &self,
            grid: &GridDefinition,
        ) -> std::result::Result<Option<GridDefinition>, GridError> {
            self.definition_calls.fetch_add(1, Ordering::SeqCst);
            if !self.override_definition {
                return Ok(None);
            }
            let mut overridden = grid.clone();
            overridden.name = "custom override".into();
            Ok(Some(overridden))
        }

        fn load(
            &self,
            grid: &GridDefinition,
        ) -> std::result::Result<Option<GridHandle>, GridError> {
            self.load_calls.fetch_add(1, Ordering::SeqCst);
            EmbeddedGridProvider.load(grid)
        }
    }

    #[test]
    fn embedded_ntv2_grid_samples_known_point() {
        let handle = EmbeddedGridProvider
            .load(&test_grid_definition())
            .unwrap()
            .expect("embedded grid");
        let (lon, lat) = handle
            .apply(
                (-80.5041667f64).to_radians(),
                44.5458333f64.to_radians(),
                GridShiftDirection::Forward,
            )
            .unwrap();
        assert!(
            (lon.to_degrees() - (-80.50401615833)).abs() < 1e-6,
            "lon={lon}"
        );
        assert!((lat.to_degrees() - 44.5458827236).abs() < 3e-6, "lat={lat}");
    }

    #[test]
    fn embedded_provider_reuses_parsed_grid_data() {
        let original = test_grid_definition();
        let mut renamed = original.clone();
        renamed.name = "renamed ntv2 grid".into();

        let first = EmbeddedGridProvider.load(&original).unwrap().expect("embedded grid");
        let second = EmbeddedGridProvider.load(&renamed).unwrap().expect("embedded grid");

        // Same resource, different name: the parsed data must be shared, the definition not.
        assert!(Arc::ptr_eq(&first.data, &second.data));
        assert_eq!(second.definition().name, "renamed ntv2 grid");
    }

    #[test]
    fn app_grid_provider_can_override_embedded_grid() {
        let (provider, definition_calls, load_calls) = TrackingGridProvider::with_counters(true);
        let runtime = GridRuntime::new(Some(Arc::new(provider)));

        let handle = runtime
            .resolve_handle(&test_grid_definition())
            .expect("grid should resolve");

        assert_eq!(handle.definition().name, "custom override");
        assert_eq!(definition_calls.load(Ordering::SeqCst), 1);
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn app_grid_provider_falls_back_to_embedded_grid() {
        let (provider, definition_calls, load_calls) = TrackingGridProvider::with_counters(false);
        let runtime = GridRuntime::new(Some(Arc::new(provider)));

        let handle = runtime
            .resolve_handle(&test_grid_definition())
            .expect("embedded grid should remain available");

        assert_eq!(handle.definition().name, "ntv2_0.gsb");
        assert_eq!(definition_calls.load(Ordering::SeqCst), 1);
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }
}