1use crate::operation::{AreaOfUse, GridId, GridInterpolation, GridShiftDirection};
2use smallvec::SmallVec;
3use std::collections::HashMap;
4use std::f64::consts::PI;
5use std::path::PathBuf;
6use std::sync::{Arc, Mutex, OnceLock};
7use thiserror::Error;
8
/// On-disk encoding of a grid-shift resource.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GridFormat {
    /// NTv2 grid-shift file (e.g. the embedded `ntv2_0.gsb`).
    Ntv2,
    /// A format this module cannot parse; loading it yields
    /// [`GridError::UnsupportedFormat`].
    Unsupported,
}

/// Description of a grid-shift resource, as requested by callers and as
/// (possibly) rewritten by a [`GridProvider`].
#[derive(Debug, Clone, PartialEq)]
pub struct GridDefinition {
    /// Identifier; part of the runtime memoisation key together with
    /// `format` and `resource_names`.
    pub id: GridId,
    /// Human-readable name; reported in `Unavailable`/`UnsupportedFormat`
    /// errors. Not part of the runtime memoisation key.
    pub name: String,
    /// Encoding used to parse the resource bytes.
    pub format: GridFormat,
    /// Requested interpolation method. NOTE(review): not consulted by the
    /// sampling code in this module, which always interpolates bilinearly.
    pub interpolation: GridInterpolation,
    /// Optional area of use; not consulted by this module.
    pub area_of_use: Option<AreaOfUse>,
    /// Candidate resource (file) names that providers try to match.
    pub resource_names: SmallVec<[String; 2]>,
}

/// Shift obtained by sampling a grid at one point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GridSample {
    /// Longitude shift in radians; a forward shift adds it to the longitude.
    pub lon_shift_radians: f64,
    /// Latitude shift in radians; a forward shift adds it to the latitude.
    pub lat_shift_radians: f64,
}
30
/// Errors produced while resolving, loading, parsing or sampling grids.
#[derive(Debug, Error, Clone)]
pub enum GridError {
    /// NOTE(review): not constructed in this module; presumably used by
    /// callers that look grids up by name — confirm.
    #[error("grid not found: {0}")]
    NotFound(String),
    /// No provider could supply the grid, or its file could not be read.
    #[error("grid resource unavailable: {0}")]
    Unavailable(String),
    /// The grid bytes are truncated, inconsistent or an unsupported NTv2 variant.
    #[error("grid parse error: {0}")]
    Parse(String),
    /// The requested point is not covered by any (sub)grid.
    #[error("grid point outside coverage: {0}")]
    OutsideCoverage(String),
    /// The definition's [`GridFormat`] cannot be parsed.
    #[error("unsupported grid format: {0}")]
    UnsupportedFormat(String),
}
44
/// Source of grid definitions and grid data.
///
/// The grid runtime asks providers in order: `Ok(None)` means "not handled
/// here, try the next provider", while `Err` aborts resolution immediately.
pub trait GridProvider: Send + Sync {
    /// Returns the definition this provider would serve for `grid` (it may
    /// differ from the request), or `None` if it cannot serve it.
    fn definition(
        &self,
        grid: &GridDefinition,
    ) -> std::result::Result<Option<GridDefinition>, GridError>;
    /// Loads data for an already-resolved definition, or returns `None` if
    /// this provider cannot supply it.
    fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError>;
}
52
53#[derive(Clone)]
54pub struct GridHandle {
55 definition: GridDefinition,
56 data: Arc<GridData>,
57}
58
59impl GridHandle {
60 pub fn definition(&self) -> &GridDefinition {
61 &self.definition
62 }
63
64 pub fn sample(
65 &self,
66 lon_radians: f64,
67 lat_radians: f64,
68 ) -> std::result::Result<GridSample, GridError> {
69 match self.data.as_ref() {
70 GridData::Ntv2(set) => set.sample(lon_radians, lat_radians),
71 }
72 }
73
74 pub fn apply(
75 &self,
76 lon_radians: f64,
77 lat_radians: f64,
78 direction: GridShiftDirection,
79 ) -> std::result::Result<(f64, f64), GridError> {
80 match self.data.as_ref() {
81 GridData::Ntv2(set) => set.apply(lon_radians, lat_radians, direction),
82 }
83 }
84}
85
86pub(crate) struct GridRuntime {
87 providers: Vec<Arc<dyn GridProvider>>,
88 definition_cache: Mutex<HashMap<String, GridDefinition>>,
89 handle_cache: Mutex<HashMap<String, GridHandle>>,
90}
91
92impl GridRuntime {
93 pub(crate) fn new(app_provider: Option<Arc<dyn GridProvider>>) -> Self {
94 let mut providers: Vec<Arc<dyn GridProvider>> = Vec::with_capacity(2);
95 if let Some(provider) = app_provider {
96 providers.push(provider);
97 }
98 providers.push(Arc::new(EmbeddedGridProvider));
99 Self {
100 providers,
101 definition_cache: Mutex::new(HashMap::new()),
102 handle_cache: Mutex::new(HashMap::new()),
103 }
104 }
105
106 pub(crate) fn resolve_definition(
107 &self,
108 grid: &GridDefinition,
109 ) -> std::result::Result<GridDefinition, GridError> {
110 let cache_key = grid_runtime_cache_key(grid);
111 if let Some(cached) = self
112 .definition_cache
113 .lock()
114 .expect("grid definition cache poisoned")
115 .get(&cache_key)
116 .cloned()
117 {
118 return Ok(cached);
119 }
120
121 for provider in &self.providers {
122 if let Some(definition) = provider.definition(grid)? {
123 self.definition_cache
124 .lock()
125 .expect("grid definition cache poisoned")
126 .insert(cache_key, definition.clone());
127 return Ok(definition);
128 }
129 }
130
131 Err(GridError::Unavailable(grid.name.clone()))
132 }
133
134 pub(crate) fn resolve_handle(
135 &self,
136 grid: &GridDefinition,
137 ) -> std::result::Result<GridHandle, GridError> {
138 let cache_key = grid_runtime_cache_key(grid);
139 if let Some(cached) = self
140 .handle_cache
141 .lock()
142 .expect("grid handle cache poisoned")
143 .get(&cache_key)
144 .cloned()
145 {
146 return Ok(cached);
147 }
148
149 let definition = self.resolve_definition(grid)?;
150 for provider in &self.providers {
151 if let Some(handle) = provider.load(&definition)? {
152 self.handle_cache
153 .lock()
154 .expect("grid handle cache poisoned")
155 .insert(cache_key, handle.clone());
156 return Ok(handle);
157 }
158 }
159
160 Err(GridError::Unavailable(definition.name))
161 }
162}
163
164fn grid_runtime_cache_key(grid: &GridDefinition) -> String {
165 let mut key = format!("{}|{:?}", grid.id.0, grid.format);
166 for resource in &grid.resource_names {
167 key.push('|');
168 key.push_str(resource);
169 }
170 key
171}
172
173#[derive(Default)]
174pub struct EmbeddedGridProvider;
175
176impl GridProvider for EmbeddedGridProvider {
177 fn definition(
178 &self,
179 grid: &GridDefinition,
180 ) -> std::result::Result<Option<GridDefinition>, GridError> {
181 if embedded_grid_resource(&grid.resource_names).is_some() {
182 return Ok(Some(grid.clone()));
183 }
184 Ok(None)
185 }
186
187 fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError> {
188 let Some((resource_name, bytes)) = embedded_grid_resource(&grid.resource_names) else {
189 return Ok(None);
190 };
191
192 let key = GridDataCacheKey::new(grid.format, resource_name);
193 let data = cached_grid_data(embedded_grid_data_cache(), key, || {
194 parse_grid_data(grid.format, &grid.name, bytes)
195 })?;
196
197 Ok(Some(GridHandle {
198 definition: grid.clone(),
199 data,
200 }))
201 }
202}
203
204pub struct FilesystemGridProvider {
205 roots: Vec<PathBuf>,
206 data_cache: Mutex<HashMap<GridDataCacheKey, Arc<GridData>>>,
207}
208
209impl FilesystemGridProvider {
210 pub fn new<I>(roots: I) -> Self
211 where
212 I: IntoIterator<Item = PathBuf>,
213 {
214 Self {
215 roots: roots.into_iter().collect(),
216 data_cache: Mutex::new(HashMap::new()),
217 }
218 }
219
220 fn locate(&self, grid: &GridDefinition) -> Option<PathBuf> {
221 for root in &self.roots {
222 for name in &grid.resource_names {
223 let candidate = root.join(name);
224 if candidate.exists() {
225 return Some(candidate);
226 }
227 }
228 }
229 None
230 }
231}
232
233impl GridProvider for FilesystemGridProvider {
234 fn definition(
235 &self,
236 grid: &GridDefinition,
237 ) -> std::result::Result<Option<GridDefinition>, GridError> {
238 if self.locate(grid).is_some() {
239 return Ok(Some(grid.clone()));
240 }
241 Ok(None)
242 }
243
244 fn load(&self, grid: &GridDefinition) -> std::result::Result<Option<GridHandle>, GridError> {
245 let Some(path) = self.locate(grid) else {
246 return Ok(None);
247 };
248
249 let cache_path = path.canonicalize().unwrap_or_else(|_| path.clone());
250 let key = GridDataCacheKey::new(grid.format, cache_path.to_string_lossy());
251 let data = cached_grid_data(&self.data_cache, key, || {
252 let bytes = std::fs::read(&path)
253 .map_err(|err| GridError::Unavailable(format!("{}: {err}", path.display())))?;
254 parse_grid_data(grid.format, &grid.name, &bytes)
255 })?;
256
257 Ok(Some(GridHandle {
258 definition: grid.clone(),
259 data,
260 }))
261 }
262}
263
/// Parsed grid payload; one variant per parseable [`GridFormat`].
enum GridData {
    /// A parsed NTv2 file with its subgrid hierarchy.
    Ntv2(Ntv2GridSet),
}
267
268#[derive(Debug, Clone, PartialEq, Eq, Hash)]
269struct GridDataCacheKey {
270 format: GridFormat,
271 resource: String,
272}
273
274impl GridDataCacheKey {
275 fn new(format: GridFormat, resource: impl AsRef<str>) -> Self {
276 Self {
277 format,
278 resource: resource.as_ref().to_string(),
279 }
280 }
281}
282
283fn embedded_grid_data_cache() -> &'static Mutex<HashMap<GridDataCacheKey, Arc<GridData>>> {
284 static CACHE: OnceLock<Mutex<HashMap<GridDataCacheKey, Arc<GridData>>>> = OnceLock::new();
285 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
286}
287
288fn cached_grid_data(
289 cache: &Mutex<HashMap<GridDataCacheKey, Arc<GridData>>>,
290 key: GridDataCacheKey,
291 parse: impl FnOnce() -> std::result::Result<GridData, GridError>,
292) -> std::result::Result<Arc<GridData>, GridError> {
293 if let Some(cached) = cache
294 .lock()
295 .expect("grid data cache poisoned")
296 .get(&key)
297 .cloned()
298 {
299 return Ok(cached);
300 }
301
302 let parsed = Arc::new(parse()?);
303 let mut cache = cache.lock().expect("grid data cache poisoned");
304 let cached = cache.entry(key).or_insert_with(|| Arc::clone(&parsed));
305 Ok(Arc::clone(cached))
306}
307
308fn parse_grid_data(
309 format: GridFormat,
310 name: &str,
311 bytes: &[u8],
312) -> std::result::Result<GridData, GridError> {
313 match format {
314 GridFormat::Ntv2 => Ok(GridData::Ntv2(Ntv2GridSet::parse(bytes)?)),
315 GridFormat::Unsupported => Err(GridError::UnsupportedFormat(name.into())),
316 }
317}
318
319fn embedded_grid_resource(names: &[String]) -> Option<(&'static str, &'static [u8])> {
320 for name in names {
321 if name.eq_ignore_ascii_case("ntv2_0.gsb") {
322 return Some(("ntv2_0.gsb", include_bytes!("../data/grids/ntv2_0.gsb")));
323 }
324 }
325 None
326}
327
328#[derive(Clone)]
329struct Ntv2GridSet {
330 grids: Vec<Ntv2Grid>,
331 roots: Vec<usize>,
332}
333
334impl Ntv2GridSet {
335 fn parse(bytes: &[u8]) -> std::result::Result<Self, GridError> {
336 const HEADER_LEN: usize = 11 * 16;
337
338 if bytes.len() < HEADER_LEN {
339 return Err(GridError::Parse("NTv2 file too small".into()));
340 }
341
342 let endian = if u32::from_le_bytes(bytes[8..12].try_into().expect("slice length checked"))
343 == 11
344 {
345 Endian::Little
346 } else if u32::from_be_bytes(bytes[8..12].try_into().expect("slice length checked")) == 11 {
347 Endian::Big
348 } else {
349 return Err(GridError::Parse(
350 "invalid NTv2 header endianness marker".into(),
351 ));
352 };
353
354 if &bytes[56..63] != b"SECONDS" {
355 return Err(GridError::Parse(
356 "only NTv2 GS_TYPE=SECONDS is supported".into(),
357 ));
358 }
359
360 let num_subfiles = read_u32(bytes, 40, endian)? as usize;
361 let mut offset = HEADER_LEN;
362 let mut grids = Vec::with_capacity(num_subfiles);
363 let mut name_to_index = HashMap::new();
364 let mut parent_links: Vec<Option<String>> = Vec::with_capacity(num_subfiles);
365
366 for _ in 0..num_subfiles {
367 let header = bytes
368 .get(offset..offset + HEADER_LEN)
369 .ok_or_else(|| GridError::Parse("truncated NTv2 subfile header".into()))?;
370 if &header[0..8] != b"SUB_NAME" {
371 return Err(GridError::Parse("invalid NTv2 subfile header tag".into()));
372 }
373
374 let name = parse_label(&header[8..16]);
375 let parent = parse_label(&header[24..32]);
376 let south = read_f64(header, 72, endian)? * PI / 180.0 / 3600.0;
377 let north = read_f64(header, 88, endian)? * PI / 180.0 / 3600.0;
378 let east = -read_f64(header, 104, endian)? * PI / 180.0 / 3600.0;
379 let west = -read_f64(header, 120, endian)? * PI / 180.0 / 3600.0;
380 let res_y = read_f64(header, 136, endian)? * PI / 180.0 / 3600.0;
381 let res_x = read_f64(header, 152, endian)? * PI / 180.0 / 3600.0;
382 let gs_count = read_u32(header, 168, endian)? as usize;
383
384 if !(west < east && south < north && res_x > 0.0 && res_y > 0.0) {
385 return Err(GridError::Parse(format!(
386 "invalid NTv2 georeferencing for subgrid {name}"
387 )));
388 }
389
390 let width = (((east - west) / res_x).abs() + 0.5).floor() as usize + 1;
391 let height = (((north - south) / res_y).abs() + 0.5).floor() as usize + 1;
392 if width * height != gs_count {
393 return Err(GridError::Parse(format!(
394 "NTv2 subgrid {name} cell count mismatch: expected {} got {gs_count}",
395 width * height
396 )));
397 }
398
399 let data_len = gs_count
400 .checked_mul(4)
401 .and_then(|count| count.checked_mul(4))
402 .ok_or_else(|| GridError::Parse("NTv2 data size overflow".into()))?;
403 let data = bytes
404 .get(offset + HEADER_LEN..offset + HEADER_LEN + data_len)
405 .ok_or_else(|| {
406 GridError::Parse(format!("truncated NTv2 data for subgrid {name}"))
407 })?;
408
409 let mut lat_shift = vec![0.0f64; gs_count];
410 let mut lon_shift = vec![0.0f64; gs_count];
411 for y in 0..height {
412 for x in 0..width {
413 let source_x = width - 1 - x;
414 let record_offset = (y * width + source_x) * 16;
415 let lat = read_f32(data, record_offset, endian)? as f64 * PI / 180.0 / 3600.0;
416 let lon =
417 -(read_f32(data, record_offset + 4, endian)? as f64) * PI / 180.0 / 3600.0;
418 let dest = y * width + x;
419 lat_shift[dest] = lat;
420 lon_shift[dest] = lon;
421 }
422 }
423
424 let index = grids.len();
425 name_to_index.insert(name.clone(), index);
426 parent_links.push(
427 if parent.eq_ignore_ascii_case("none") || parent.is_empty() {
428 None
429 } else {
430 Some(parent)
431 },
432 );
433 grids.push(Ntv2Grid {
434 name,
435 extent: GridExtent {
436 west,
437 south,
438 east,
439 north,
440 res_x,
441 res_y,
442 },
443 width,
444 height,
445 lat_shift,
446 lon_shift,
447 children: Vec::new(),
448 });
449 offset += HEADER_LEN + data_len;
450 }
451
452 let mut roots = Vec::new();
453 for (idx, parent) in parent_links.into_iter().enumerate() {
454 if let Some(parent_name) = parent {
455 let Some(parent_idx) = name_to_index.get(&parent_name).copied() else {
456 return Err(GridError::Parse(format!(
457 "missing NTv2 parent subgrid {parent_name} for {}",
458 grids[idx].name
459 )));
460 };
461 grids[parent_idx].children.push(idx);
462 } else {
463 roots.push(idx);
464 }
465 }
466
467 Ok(Self { grids, roots })
468 }
469
470 fn sample(
471 &self,
472 lon_radians: f64,
473 lat_radians: f64,
474 ) -> std::result::Result<GridSample, GridError> {
475 let (grid_idx, local_lon, local_lat) = self.grid_at(lon_radians, lat_radians)?;
476 let (lon_shift, lat_shift) = interpolate(&self.grids[grid_idx], local_lon, local_lat)?;
477 Ok(GridSample {
478 lon_shift_radians: lon_shift,
479 lat_shift_radians: lat_shift,
480 })
481 }
482
483 fn apply(
484 &self,
485 lon_radians: f64,
486 lat_radians: f64,
487 direction: GridShiftDirection,
488 ) -> std::result::Result<(f64, f64), GridError> {
489 match direction {
490 GridShiftDirection::Forward => {
491 let shift = self.sample(lon_radians, lat_radians)?;
492 Ok((
493 lon_radians + shift.lon_shift_radians,
494 lat_radians + shift.lat_shift_radians,
495 ))
496 }
497 GridShiftDirection::Reverse => self.apply_inverse(lon_radians, lat_radians),
498 }
499 }
500
501 fn apply_inverse(
502 &self,
503 lon_radians: f64,
504 lat_radians: f64,
505 ) -> std::result::Result<(f64, f64), GridError> {
506 const MAX_ITERATIONS: usize = 10;
507 const TOLERANCE: f64 = 1e-12;
508
509 let mut estimate_lon = lon_radians;
510 let mut estimate_lat = lat_radians;
511
512 for _ in 0..MAX_ITERATIONS {
513 let shift = self.sample(estimate_lon, estimate_lat)?;
514 let next_lon = lon_radians - shift.lon_shift_radians;
515 let next_lat = lat_radians - shift.lat_shift_radians;
516 let diff_lon = next_lon - estimate_lon;
517 let diff_lat = next_lat - estimate_lat;
518 estimate_lon = next_lon;
519 estimate_lat = next_lat;
520 if diff_lon * diff_lon + diff_lat * diff_lat <= TOLERANCE * TOLERANCE {
521 return Ok((estimate_lon, estimate_lat));
522 }
523 }
524
525 Ok((estimate_lon, estimate_lat))
526 }
527
528 fn grid_at(
529 &self,
530 lon_radians: f64,
531 lat_radians: f64,
532 ) -> std::result::Result<(usize, f64, f64), GridError> {
533 for &root in &self.roots {
534 if self.grids[root].extent.contains(lon_radians, lat_radians) {
535 let idx = self.deepest_child(root, lon_radians, lat_radians);
536 let extent = &self.grids[idx].extent;
537 return Ok((idx, lon_radians - extent.west, lat_radians - extent.south));
538 }
539 }
540 Err(GridError::OutsideCoverage(format!(
541 "longitude {:.8} latitude {:.8}",
542 lon_radians.to_degrees(),
543 lat_radians.to_degrees()
544 )))
545 }
546
547 fn deepest_child(&self, index: usize, lon_radians: f64, lat_radians: f64) -> usize {
548 for &child in &self.grids[index].children {
549 if self.grids[child].extent.contains(lon_radians, lat_radians) {
550 return self.deepest_child(child, lon_radians, lat_radians);
551 }
552 }
553 index
554 }
555}
556
/// One NTv2 subgrid with its node shifts converted to radians.
#[derive(Clone)]
struct Ntv2Grid {
    /// SUB_NAME label from the subfile header.
    name: String,
    /// Bounds and node spacing, in radians.
    extent: GridExtent,
    /// Nodes per row; column 0 is the node at `extent.west`.
    width: usize,
    /// Number of rows; row 0 is the row at `extent.south`.
    height: usize,
    /// Latitude shifts in radians, row-major (`y * width + x`).
    lat_shift: Vec<f64>,
    /// Longitude shifts in radians, row-major (`y * width + x`).
    lon_shift: Vec<f64>,
    /// Indices (into the owning set's `grids`) of child subgrids.
    children: Vec<usize>,
}
567
/// Axis-aligned bounds of a subgrid plus its node spacing, all in radians.
#[derive(Clone, Copy)]
struct GridExtent {
    west: f64,
    south: f64,
    east: f64,
    north: f64,
    res_x: f64,
    res_y: f64,
}

impl GridExtent {
    /// Returns true when the point lies within the bounds, widened on every
    /// side by `(res_x + res_y) * 1e-10` (presumably to absorb floating-point
    /// rounding at the edges). NaN coordinates are never contained.
    fn contains(&self, lon_radians: f64, lat_radians: f64) -> bool {
        let slack = (self.res_x + self.res_y) * 1e-10;
        let lon_span = (self.west - slack)..=(self.east + slack);
        let lat_span = (self.south - slack)..=(self.north + slack);
        lon_span.contains(&lon_radians) && lat_span.contains(&lat_radians)
    }
}
587
588fn interpolate(
589 grid: &Ntv2Grid,
590 local_lon: f64,
591 local_lat: f64,
592) -> std::result::Result<(f64, f64), GridError> {
593 let lam = local_lon / grid.extent.res_x;
594 let phi = local_lat / grid.extent.res_y;
595 let mut x = lam.floor() as isize;
596 let mut y = phi.floor() as isize;
597 let mut fx = lam - x as f64;
598 let mut fy = phi - y as f64;
599
600 if x < 0 {
601 if x == -1 && fx > 1.0 - 1e-9 {
602 x = 0;
603 fx = 0.0;
604 } else {
605 return Err(GridError::OutsideCoverage(grid.name.clone()));
606 }
607 }
608 if y < 0 {
609 if y == -1 && fy > 1.0 - 1e-9 {
610 y = 0;
611 fy = 0.0;
612 } else {
613 return Err(GridError::OutsideCoverage(grid.name.clone()));
614 }
615 }
616 if x as usize + 1 >= grid.width {
617 if x as usize + 1 == grid.width && fx < 1e-9 {
618 x -= 1;
619 fx = 1.0;
620 } else {
621 return Err(GridError::OutsideCoverage(grid.name.clone()));
622 }
623 }
624 if y as usize + 1 >= grid.height {
625 if y as usize + 1 == grid.height && fy < 1e-9 {
626 y -= 1;
627 fy = 1.0;
628 } else {
629 return Err(GridError::OutsideCoverage(grid.name.clone()));
630 }
631 }
632
633 let idx = |xx: usize, yy: usize| yy * grid.width + xx;
634 let x0 = x as usize;
635 let y0 = y as usize;
636 let x1 = x0 + 1;
637 let y1 = y0 + 1;
638
639 let m00 = (1.0 - fx) * (1.0 - fy);
640 let m10 = fx * (1.0 - fy);
641 let m01 = (1.0 - fx) * fy;
642 let m11 = fx * fy;
643
644 let lon = m00 * grid.lon_shift[idx(x0, y0)]
645 + m10 * grid.lon_shift[idx(x1, y0)]
646 + m01 * grid.lon_shift[idx(x0, y1)]
647 + m11 * grid.lon_shift[idx(x1, y1)];
648 let lat = m00 * grid.lat_shift[idx(x0, y0)]
649 + m10 * grid.lat_shift[idx(x1, y0)]
650 + m01 * grid.lat_shift[idx(x0, y1)]
651 + m11 * grid.lat_shift[idx(x1, y1)];
652
653 Ok((lon, lat))
654}
655
/// Byte order of a binary grid file, detected from its header.
#[derive(Clone, Copy)]
enum Endian {
    Little,
    Big,
}
661
/// Decodes a fixed-width header label.
///
/// The label ends at the first NUL byte (or at the end of the field). It is
/// decoded as UTF-8, with invalid sequences replaced, and surrounding
/// whitespace is trimmed.
fn parse_label(bytes: &[u8]) -> String {
    let label = bytes.split(|&byte| byte == 0).next().unwrap_or(bytes);
    String::from_utf8_lossy(label).trim().to_owned()
}
669
670fn read_u32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<u32, GridError> {
671 let slice = bytes
672 .get(offset..offset + 4)
673 .ok_or_else(|| GridError::Parse("truncated integer".into()))?;
674 Ok(match endian {
675 Endian::Little => u32::from_le_bytes(slice.try_into().expect("slice length checked")),
676 Endian::Big => u32::from_be_bytes(slice.try_into().expect("slice length checked")),
677 })
678}
679
680fn read_f64(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<f64, GridError> {
681 let slice = bytes
682 .get(offset..offset + 8)
683 .ok_or_else(|| GridError::Parse("truncated float64".into()))?;
684 Ok(match endian {
685 Endian::Little => f64::from_le_bytes(slice.try_into().expect("slice length checked")),
686 Endian::Big => f64::from_be_bytes(slice.try_into().expect("slice length checked")),
687 })
688}
689
690fn read_f32(bytes: &[u8], offset: usize, endian: Endian) -> std::result::Result<f32, GridError> {
691 let slice = bytes
692 .get(offset..offset + 4)
693 .ok_or_else(|| GridError::Parse("truncated float32".into()))?;
694 Ok(match endian {
695 Endian::Little => f32::from_le_bytes(slice.try_into().expect("slice length checked")),
696 Endian::Big => f32::from_be_bytes(slice.try_into().expect("slice length checked")),
697 })
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Definition that resolves to the embedded `ntv2_0.gsb` grid.
    fn test_grid_definition() -> GridDefinition {
        GridDefinition {
            id: GridId(1),
            name: "ntv2_0.gsb".into(),
            format: GridFormat::Ntv2,
            interpolation: GridInterpolation::Bilinear,
            area_of_use: None,
            resource_names: SmallVec::from_vec(vec!["ntv2_0.gsb".into()]),
        }
    }

    #[test]
    fn embedded_ntv2_grid_samples_known_point() {
        let handle = EmbeddedGridProvider
            .load(&test_grid_definition())
            .unwrap()
            .expect("embedded grid");
        let (lon, lat) = handle
            .apply(
                (-80.5041667f64).to_radians(),
                44.5458333f64.to_radians(),
                GridShiftDirection::Forward,
            )
            .unwrap();

        let lon_error = (lon.to_degrees() - (-80.50401615833)).abs();
        let lat_error = (lat.to_degrees() - 44.5458827236).abs();
        assert!(lon_error < 1e-6, "lon={lon}");
        assert!(lat_error < 3e-6, "lat={lat}");
    }

    #[test]
    fn embedded_provider_reuses_parsed_grid_data() {
        let original = test_grid_definition();
        let renamed = GridDefinition {
            name: "renamed ntv2 grid".into(),
            ..original.clone()
        };

        let first = EmbeddedGridProvider
            .load(&original)
            .unwrap()
            .expect("embedded grid");
        let second = EmbeddedGridProvider
            .load(&renamed)
            .unwrap()
            .expect("embedded grid");

        assert!(Arc::ptr_eq(&first.data, &second.data));
        assert_eq!(second.definition().name, "renamed ntv2 grid");
    }

    /// Counts calls, optionally renames definitions, and always loads
    /// through the embedded provider.
    struct TrackingGridProvider {
        override_definition: bool,
        definition_calls: Arc<AtomicUsize>,
        load_calls: Arc<AtomicUsize>,
    }

    impl GridProvider for TrackingGridProvider {
        fn definition(
            &self,
            grid: &GridDefinition,
        ) -> std::result::Result<Option<GridDefinition>, GridError> {
            self.definition_calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.override_definition.then(|| GridDefinition {
                name: "custom override".into(),
                ..grid.clone()
            }))
        }

        fn load(
            &self,
            grid: &GridDefinition,
        ) -> std::result::Result<Option<GridHandle>, GridError> {
            self.load_calls.fetch_add(1, Ordering::SeqCst);
            EmbeddedGridProvider.load(grid)
        }
    }

    /// Resolves the test grid through a runtime fronted by a tracking provider.
    ///
    /// Returns the resolved name and the `(definition, load)` call counts.
    fn resolve_with_tracking(
        override_definition: bool,
        expectation: &str,
    ) -> (String, usize, usize) {
        let definition_calls = Arc::new(AtomicUsize::new(0));
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = TrackingGridProvider {
            override_definition,
            definition_calls: Arc::clone(&definition_calls),
            load_calls: Arc::clone(&load_calls),
        };
        let runtime = GridRuntime::new(Some(Arc::new(provider)));

        let handle = runtime
            .resolve_handle(&test_grid_definition())
            .expect(expectation);

        (
            handle.definition().name.clone(),
            definition_calls.load(Ordering::SeqCst),
            load_calls.load(Ordering::SeqCst),
        )
    }

    #[test]
    fn app_grid_provider_can_override_embedded_grid() {
        let (name, definition_calls, load_calls) =
            resolve_with_tracking(true, "grid should resolve");

        assert_eq!(name, "custom override");
        assert_eq!(definition_calls, 1);
        assert_eq!(load_calls, 1);
    }

    #[test]
    fn app_grid_provider_falls_back_to_embedded_grid() {
        let (name, definition_calls, load_calls) =
            resolve_with_tracking(false, "embedded grid should remain available");

        assert_eq!(name, "ntv2_0.gsb");
        assert_eq!(definition_calls, 1);
        assert_eq!(load_calls, 1);
    }
}