use std::collections::{BTreeMap, HashSet};
use std::io::{Read, Seek};
use grib::LatLons as _;
use grib::codetables::Code;
use grib::codetables::grib2::Table4_4;
use grib::def::grib2::RefTime;
use rayon::prelude::*;
use crate::wind_map::{TimedWindMap, WindMap};
use crate::{WeatherRow, WindSample};
// GRIB2 identification codes of the fields this loader consumes.

/// GRIB2 Code Table 0.0 discipline: meteorological products.
const DISCIPLINE_METEOROLOGICAL: u8 = 0;
/// GRIB2 Code Table 4.1 parameter category (discipline 0): momentum.
const PARAM_CATEGORY_MOMENTUM: u8 = 2;
/// GRIB2 Code Table 4.2 parameter number (discipline 0, category 2): u-component of wind (UGRD).
const PARAM_NUMBER_UGRD: u8 = 2;
/// GRIB2 Code Table 4.2 parameter number (discipline 0, category 2): v-component of wind (VGRD).
const PARAM_NUMBER_VGRD: u8 = 3;
/// GRIB2 Code Table 4.5 fixed-surface type: specified height level above ground.
const SURFACE_HEIGHT_ABOVE_GROUND: u8 = 103;

// Vertical level selection.

/// Height of the wanted wind level, in metres above ground.
const TARGET_HEIGHT_METRES: f64 = 10.0;
/// Largest accepted distance from `TARGET_HEIGHT_METRES`, in metres.
const HEIGHT_TOLERANCE_METRES: f64 = 0.5;
/// Inclusive latitude/longitude rectangle used to restrict which grid points are loaded.
///
/// Coordinates are compared directly against the grid's lat/lon values with no
/// longitude wrapping. NOTE(review): the caller presumably has to use the same
/// longitude convention as the file (e.g. 0..360 vs -180..180) — TODO confirm.
/// A box with `min > max` on either axis contains nothing.
#[derive(Clone, Copy, Debug)]
pub struct Grib2Bbox {
    /// Southern edge (inclusive).
    pub lat_min: f32,
    /// Northern edge (inclusive).
    pub lat_max: f32,
    /// Western edge (inclusive).
    pub lon_min: f32,
    /// Eastern edge (inclusive).
    pub lon_max: f32,
}

impl Grib2Bbox {
    /// Returns `true` when `(lat, lon)` lies on or inside the box; NaN coordinates never do.
    fn contains(self, lat: f32, lon: f32) -> bool {
        let lat_range = self.lat_min..=self.lat_max;
        let lon_range = self.lon_min..=self.lon_max;
        lat_range.contains(&lat) && lon_range.contains(&lon)
    }
}
/// Failure modes of [`TimedWindMap::from_grib2_reader`].
#[derive(Debug)]
#[non_exhaustive]
pub enum LoadError {
    /// The `grib` crate rejected the input while reading or parsing it.
    Grib(grib::GribError),
    /// No wind frame could be produced. Either no submessage matched the 10 m
    /// UGRD/VGRD selection with a readable grid, or every candidate time was
    /// dropped: it failed to decode, lacked U or V, or had no points left after
    /// the bbox/stride filter.
    NoFrames,
}

impl std::fmt::Display for LoadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoFrames => f.write_str(
                "GRIB2 file contained no complete UGRD/VGRD pair at 10 m above ground",
            ),
            Self::Grib(inner) => write!(f, "GRIB2 error: {inner}"),
        }
    }
}

impl std::error::Error for LoadError {
    /// Exposes the wrapped `grib` error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NoFrames => None,
            Self::Grib(inner) => Some(inner),
        }
    }
}

impl From<grib::GribError> for LoadError {
    /// Lets `?` lift `grib` errors into [`LoadError::Grib`].
    fn from(source: grib::GribError) -> Self {
        Self::Grib(source)
    }
}
/// The U and V component grids collected for one valid time.
///
/// Either side may still be `None` while submessages are being grouped. Only
/// times where both are present become frames.
#[derive(Default)]
struct Pair {
    /// Decoded UGRD values, one per grid point.
    u: Option<Vec<f32>>,
    /// Decoded VGRD values, one per grid point.
    v: Option<Vec<f32>>,
}

/// Grid geometry taken from the first matching submessage and shared by every frame.
struct GridCache {
    /// `(lat, lon)` of each grid point, in the submessage's point order.
    latlons: Vec<(f32, f32)>,
    /// `f32::to_bits` patterns of the latitude and longitude axis values kept by
    /// stride thinning; `None` when no thinning is requested (stride <= 1).
    kept_axes: Option<(HashSet<u32>, HashSet<u32>)>,
}
/// A selected UGRD/VGRD submessage whose values have not been decoded yet.
struct PendingSubmessage {
    /// Valid time in Unix seconds: reference time plus forecast offset.
    absolute_t: i64,
    /// `true` for a UGRD (u-component) field, `false` for VGRD.
    is_u: bool,
    /// Decoder prepared during the sequential scan and run later in parallel.
    decoder: grib::Grib2SubmessageDecoder,
}
impl TimedWindMap {
    /// Builds a time-indexed wind map from the 10 m UGRD/VGRD fields of a GRIB2 stream.
    ///
    /// A submessage is used only if it meets all of the following:
    /// * it is in the meteorological discipline;
    /// * it is in the momentum category;
    /// * it is a U or V wind component;
    /// * it lies on a height-above-ground surface within [`HEIGHT_TOLERANCE_METRES`]
    ///   of [`TARGET_HEIGHT_METRES`].
    ///
    /// The valid time of a submessage is its reference time plus its forecast
    /// offset. One frame is produced for each valid time that has both a U and
    /// a V field.
    ///
    /// * `stride` — keep every `stride`-th distinct latitude and longitude value
    ///   (after `bbox` filtering); `0` and `1` both mean "keep every point".
    /// * `bbox` — when set, only grid points inside the inclusive box are kept.
    ///
    /// The grid geometry comes from the first matching submessage whose lat/lons
    /// can be read. Later submessages are only checked for having the same number
    /// of values. Submessages that cannot be interpreted or decoded are logged as
    /// warnings and skipped.
    ///
    /// Per-phase timings are logged at `info` level and appended to a log file
    /// in the system temp directory.
    ///
    /// # Errors
    ///
    /// * [`LoadError::Grib`] if the `grib` crate fails to parse the stream.
    /// * [`LoadError::NoFrames`] if no matching grid was found, or if no frame
    ///   survived decoding, U/V pairing and bbox/stride filtering.
    pub fn from_grib2_reader<R: Read + Seek>(
        reader: R,
        stride: usize,
        bbox: Option<Grib2Bbox>,
    ) -> Result<Self, LoadError> {
        // 0 and 1 both mean "no thinning"; normalise so later code only sees >= 1.
        let stride = stride.max(1);
        let total_start = std::time::Instant::now();

        let parse_start = std::time::Instant::now();
        let grib2 = grib::from_reader(reader)?;
        let parse_elapsed = parse_start.elapsed();

        // Phase 1 (sequential): select the 10 m U/V submessages, compute each
        // one's valid time, and prepare a decoder for it.
        let phase1_start = std::time::Instant::now();
        let mut pending: Vec<PendingSubmessage> = Vec::new();
        let mut cache: Option<GridCache> = None;
        for (_index, sub) in &grib2 {
            if sub.indicator().discipline != DISCIPLINE_METEOROLOGICAL {
                continue;
            }
            let pd = sub.prod_def();
            let Some(category) = pd.parameter_category() else {
                continue;
            };
            let Some(number) = pd.parameter_number() else {
                continue;
            };
            if category != PARAM_CATEGORY_MOMENTUM {
                continue;
            }
            let is_u = number == PARAM_NUMBER_UGRD;
            let is_v = number == PARAM_NUMBER_VGRD;
            if !is_u && !is_v {
                continue;
            }
            let Some((first_surface, _second_surface)) = pd.fixed_surfaces() else {
                continue;
            };
            if first_surface.surface_type != SURFACE_HEIGHT_ABOVE_GROUND {
                continue;
            }
            if (first_surface.value() - TARGET_HEIGHT_METRES).abs() > HEIGHT_TOLERANCE_METRES {
                continue;
            }
            let Some(forecast_time) = pd.forecast_time() else {
                continue;
            };
            let Some(forecast_seconds) =
                forecast_unit_to_seconds(&forecast_time.unit, forecast_time.value)
            else {
                log::warn!(
                    "GRIB2: skipping submessage with unsupported forecast-time unit {:?}",
                    forecast_time.unit,
                );
                continue;
            };
            // Valid time = reference time + forecast offset, in Unix seconds.
            let ref_time = sub.identification().ref_time_unchecked();
            let Some(ref_seconds) = ref_time_to_unix_seconds(&ref_time) else {
                log::warn!("GRIB2: skipping submessage with invalid reference time {ref_time:?}");
                continue;
            };
            let absolute_t = ref_seconds.saturating_add(i64::from(forecast_seconds));
            if cache.is_none() {
                // The first readable grid defines the geometry for every frame.
                let latlons: Vec<(f32, f32)> = match sub.latlons() {
                    Ok(iter) => iter.collect(),
                    Err(e) => {
                        log::warn!("GRIB2: skipping submessage at t={absolute_t}s ({e})");
                        continue;
                    }
                };
                // Axis thinning is only computed when an actual stride is requested.
                let kept_axes = (stride > 1).then(|| compute_kept_axes(&latlons, stride, bbox));
                cache = Some(GridCache { latlons, kept_axes });
            }
            let decoder = match grib::Grib2SubmessageDecoder::from(sub) {
                Ok(d) => d,
                Err(e) => {
                    log::warn!("GRIB2: skipping submessage at t={absolute_t}s ({e})");
                    continue;
                }
            };
            pending.push(PendingSubmessage {
                absolute_t,
                is_u,
                decoder,
            });
        }
        let phase1_elapsed = phase1_start.elapsed();
        let pending_count = pending.len();
        // No matching submessage with a readable grid: nothing can be decoded.
        let Some(cache) = cache else {
            return Err(LoadError::NoFrames);
        };

        // Phase 2 (parallel): decode the values of every pending submessage.
        let phase2_start = std::time::Instant::now();
        let decoded = decode_pending_parallel(pending, cache.latlons.len());
        let phase2_elapsed = phase2_start.elapsed();
        let decoded_count = decoded.len();

        // Phase 3 (parallel): pair U with V per valid time and build one frame each.
        let phase3_start = std::time::Instant::now();
        let (times, frames, build_rows_total, wind_map_new_total) =
            assemble_frames_parallel(decoded, &cache, bbox);
        let phase3_elapsed = phase3_start.elapsed();
        if frames.is_empty() {
            return Err(LoadError::NoFrames);
        }

        // `times` is ascending (see `assemble_frames_parallel`), so first/last
        // give the covered span.
        let step_seconds = compute_median_step_seconds(times.as_slice());
        let mut map = Self::new(frames, step_seconds);
        if let (Some(&min_unix), Some(&max_unix)) = (times.first(), times.last())
            && let (Some(start), Some(end)) = (
                chrono::DateTime::<chrono::Utc>::from_timestamp(min_unix, 0),
                chrono::DateTime::<chrono::Utc>::from_timestamp(max_unix, 0),
            )
        {
            map = map.with_time_range(start, end);
        }
        log_grib2_load_summary(&LoadSummary {
            total_elapsed: total_start.elapsed(),
            parse_elapsed,
            phase1_elapsed,
            phase2_elapsed,
            phase3_elapsed,
            grid_size: cache.latlons.len(),
            frame_count: map.frame_count(),
            pending_count,
            decoded_count,
            build_rows_total,
            wind_map_new_total,
        });
        Ok(map)
    }
}
/// Timings and counters gathered during one GRIB2 load, for `log_grib2_load_summary`.
struct LoadSummary {
    /// Wall-clock time of the whole load.
    total_elapsed: std::time::Duration,
    /// Wall-clock time spent in `grib::from_reader`.
    parse_elapsed: std::time::Duration,
    /// Wall-clock time of the sequential submessage selection pass.
    phase1_elapsed: std::time::Duration,
    /// Wall-clock time of the parallel decode pass.
    phase2_elapsed: std::time::Duration,
    /// Wall-clock time of the parallel frame-assembly pass.
    phase3_elapsed: std::time::Duration,
    /// Number of points in the shared grid.
    grid_size: usize,
    /// Number of frames in the resulting map.
    frame_count: usize,
    /// Submessages selected for decoding.
    pending_count: usize,
    /// Submessages that decoded successfully with the expected length.
    decoded_count: usize,
    /// `build_rows` time summed across all parallel workers.
    build_rows_total: std::time::Duration,
    /// `WindMap::new` time summed across all parallel workers.
    wind_map_new_total: std::time::Duration,
}
fn decode_pending_parallel(
pending: Vec<PendingSubmessage>,
expected_len: usize,
) -> Vec<(i64, bool, Vec<f32>)> {
pending
.into_par_iter()
.filter_map(|p| {
let values: Vec<f32> = match p.decoder.dispatch() {
Ok(iter) => iter.collect(),
Err(e) => {
log::warn!(
"GRIB2: skipping submessage at t={}s ({:?})",
p.absolute_t,
e,
);
return None;
}
};
if values.len() != expected_len {
log::warn!(
"GRIB2: values ({}) / grid ({}) length mismatch at t={}s, skipping",
values.len(),
expected_len,
p.absolute_t,
);
return None;
}
Some((p.absolute_t, p.is_u, values))
})
.collect()
}
fn assemble_frames_parallel(
decoded: Vec<(i64, bool, Vec<f32>)>,
cache: &GridCache,
bbox: Option<Grib2Bbox>,
) -> (
Vec<i64>,
Vec<WindMap>,
std::time::Duration,
std::time::Duration,
) {
use std::sync::atomic::{AtomicU64, Ordering};
let mut by_time: BTreeMap<i64, Pair> = BTreeMap::new();
for (absolute_t, is_u, values) in decoded {
let entry = by_time.entry(absolute_t).or_default();
if is_u {
entry.u = Some(values);
} else {
entry.v = Some(values);
}
}
let by_time_vec: Vec<(i64, Pair)> = by_time.into_iter().collect();
let build_rows_total = AtomicU64::new(0);
let wind_map_new_total = AtomicU64::new(0);
let frames_and_times: Vec<(i64, WindMap)> = by_time_vec
.into_par_iter()
.filter_map(|(absolute_t, pair)| {
let (Some(u), Some(v)) = (pair.u, pair.v) else {
log::warn!("GRIB2: incomplete UGRD/VGRD pair at t={absolute_t}s, skipping",);
return None;
};
let br_start = std::time::Instant::now();
let rows = build_rows(&u, &v, &cache.latlons, bbox, cache.kept_axes.as_ref());
build_rows_total.fetch_add(br_start.elapsed().as_nanos() as u64, Ordering::Relaxed);
if rows.is_empty() {
log::warn!(
"GRIB2: no points survived bbox/stride filter at t={absolute_t}s, skipping",
);
return None;
}
let wm_start = std::time::Instant::now();
let map = WindMap::new(rows);
wind_map_new_total.fetch_add(wm_start.elapsed().as_nanos() as u64, Ordering::Relaxed);
Some((absolute_t, map))
})
.collect();
let (times, frames): (Vec<i64>, Vec<WindMap>) = frames_and_times.into_iter().unzip();
let build_rows_total =
std::time::Duration::from_nanos(build_rows_total.load(Ordering::Relaxed));
let wind_map_new_total =
std::time::Duration::from_nanos(wind_map_new_total.load(Ordering::Relaxed));
(times, frames, build_rows_total, wind_map_new_total)
}
/// Returns the typical spacing between consecutive frame times, in seconds.
///
/// Uses the median of the strictly positive gaps between neighbouring entries
/// of `times`, taking the upper median for an even count. Falls back to `1.0`
/// when there are fewer than two times or no positive gap.
fn compute_median_step_seconds(times: &[i64]) -> f32 {
    // Fewer than two timestamps leaves no gap to measure.
    let [_, _, ..] = times else {
        return 1.0;
    };

    let mut positive_gaps: Vec<i64> = Vec::with_capacity(times.len() - 1);
    for pair in times.windows(2) {
        let gap = pair[1] - pair[0];
        if gap > 0 {
            positive_gaps.push(gap);
        }
    }
    positive_gaps.sort_unstable();

    let middle = positive_gaps.len() / 2;
    positive_gaps.get(middle).map_or(1.0, |&gap| gap as f32)
}
/// Reports the per-phase timings and counts of one GRIB2 load.
///
/// The summary is emitted via `log::info!`. It is also appended, best-effort, to
/// `bywind-grib2-load.log` in `std::env::temp_dir()`; any failure to open or
/// write that file is deliberately ignored.
///
/// NOTE(review): this writes to a fixed, predictable path in the shared temp
/// directory on every load; consider gating it behind a feature flag or env var.
fn log_grib2_load_summary(s: &LoadSummary) {
    use std::io::Write as _;

    let line = format!(
        "GRIB2 load: total={:.2}s (parse={:.2}s, phase1={:.2}s, phase2={:.2}s, phase3={:.2}s) \
         grid={} frames={} pending={} decoded={} \
         build_rows_cpu_sum={:.2}s wind_map_new_cpu_sum={:.2}s\n",
        s.total_elapsed.as_secs_f64(),
        s.parse_elapsed.as_secs_f64(),
        s.phase1_elapsed.as_secs_f64(),
        s.phase2_elapsed.as_secs_f64(),
        s.phase3_elapsed.as_secs_f64(),
        s.grid_size,
        s.frame_count,
        s.pending_count,
        s.decoded_count,
        s.build_rows_total.as_secs_f64(),
        s.wind_map_new_total.as_secs_f64(),
    );
    log::info!("{}", line.trim_end());

    // Best-effort append: both open and write errors are intentionally discarded.
    let appended = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(std::env::temp_dir().join("bywind-grib2-load.log"))
        .and_then(|mut file| file.write_all(line.as_bytes()));
    drop(appended);
}
/// Converts a GRIB2 reference time (treated as UTC) to Unix seconds.
///
/// Returns `None` for an invalid calendar date (as judged by `days_from_civil`)
/// or an out-of-range clock. Seconds may be 60, which allows a leap second.
fn ref_time_to_unix_seconds(rt: &RefTime) -> Option<i64> {
    let clock_is_valid = rt.hour <= 23 && rt.minute <= 59 && rt.second <= 60;
    if !clock_is_valid {
        return None;
    }
    let day_number = days_from_civil(i32::from(rt.year), rt.month, rt.day)?;
    let seconds_into_day =
        (i64::from(rt.hour) * 60 + i64::from(rt.minute)) * 60 + i64::from(rt.second);
    Some(day_number * 86_400 + seconds_into_day)
}
/// Converts a proleptic-Gregorian civil date to a day count relative to 1970-01-01.
///
/// This is Howard Hinnant's `days_from_civil` algorithm. Years are shifted to
/// start in March, so the leap day falls at the end of the shifted year.
///
/// Returns `None` when `m` is not in `1..=12`, or when `d` is not a valid day of
/// that month, e.g. February 29 in a non-leap year or April 31.
fn days_from_civil(y: i32, m: u8, d: u8) -> Option<i64> {
    let is_leap = (y % 4 == 0 && y % 100 != 0) || y % 400 == 0;
    let days_in_month: u8 = match m {
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        4 | 6 | 9 | 11 => 30,
        2 if is_leap => 29,
        2 => 28,
        _ => return None,
    };
    // Without this check impossible dates would silently roll into the next month.
    if !(1..=days_in_month).contains(&d) {
        return None;
    }

    // January and February belong to the previous March-based year.
    let y = if m <= 2 { i64::from(y) - 1 } else { i64::from(y) };
    let era = y.div_euclid(400); // floor division, also correct for negative years
    let yoe = y - era * 400; // year of era, [0, 399]
    let m = i64::from(m);
    let d = i64::from(d);
    let mp = if m > 2 { m - 3 } else { m + 9 }; // month index with March = 0, [0, 11]
    let doy = (153 * mp + 2) / 5 + d - 1; // day of the March-based year, [0, 365]
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day of era, [0, 146096]
    Some(era * 146_097 + doe - 719_468)
}
/// Converts a GRIB2 forecast-time offset to seconds.
///
/// `unit` is the GRIB2 Code Table 4.4 time unit and `value` is the offset in
/// that unit. Returns `None` in these cases:
/// * the unit has no fixed length in seconds (month or longer);
/// * the unit is `Missing`;
/// * the code is numeric and was not recognised by the `grib` crate;
/// * the result overflows `u32`.
fn forecast_unit_to_seconds(unit: &Code<Table4_4, u8>, value: u32) -> Option<u32> {
    const MINUTE: u32 = 60;
    const HOUR: u32 = 60 * MINUTE;
    const DAY: u32 = 24 * HOUR;

    let Code::Name(named) = unit else {
        return None;
    };
    let seconds_per_unit = match named {
        Table4_4::Second => Some(1),
        Table4_4::Minute => Some(MINUTE),
        Table4_4::Hour => Some(HOUR),
        Table4_4::ThreeHours => Some(3 * HOUR),
        Table4_4::SixHours => Some(6 * HOUR),
        Table4_4::TwelveHours => Some(12 * HOUR),
        Table4_4::Day => Some(DAY),
        // Month-or-longer units have no fixed length in seconds; Missing carries no unit.
        Table4_4::Month
        | Table4_4::Year
        | Table4_4::Decade
        | Table4_4::Normal
        | Table4_4::Century
        | Table4_4::Missing => None,
    };
    seconds_per_unit.and_then(|factor| value.checked_mul(factor))
}
/// Turns one frame's U/V component grids into wind rows for the selected points.
///
/// Point `i` is described by `latlons[i]`, `u_values[i]` and `v_values[i]`.
/// A point is kept only when all of the following hold:
/// * it lies inside `bbox` (if any);
/// * both of its coordinates are on the kept stride axes (if any);
/// * both of its components are finite.
///
/// Speed is the vector magnitude. Direction is
/// `(270° − atan2(v, u)).rem_euclid(360°)`. By the round-trip tests this is
/// the direction the wind blows from (a southward-blowing wind gives 0°).
/// The value arrays are expected to be at least as long as `latlons`; indexing
/// panics otherwise.
fn build_rows(
    u_values: &[f32],
    v_values: &[f32],
    latlons: &[(f32, f32)],
    bbox: Option<Grib2Bbox>,
    kept_axes: Option<&(HashSet<u32>, HashSet<u32>)>,
) -> Vec<WeatherRow> {
    // Spatial filter: inside the optional bbox AND on a kept stride axis (if any).
    // Axis membership is exact, by f32 bit pattern.
    let passes_filters = |lat: f32, lon: f32| {
        bbox.is_none_or(|b| b.contains(lat, lon))
            && kept_axes.is_none_or(|(kept_lats, kept_lons)| {
                kept_lats.contains(&lat.to_bits()) && kept_lons.contains(&lon.to_bits())
            })
    };

    let mut rows = Vec::with_capacity(latlons.len());
    rows.extend(
        latlons
            .iter()
            .enumerate()
            .filter(|&(_, &(lat, lon))| passes_filters(lat, lon))
            .filter_map(|(idx, &(lat, lon))| {
                let (east, north) = (u_values[idx], v_values[idx]);
                if !(east.is_finite() && north.is_finite()) {
                    return None;
                }
                Some(WeatherRow {
                    lon,
                    lat,
                    sample: WindSample {
                        speed: east.hypot(north),
                        direction: (270.0 - north.atan2(east).to_degrees()).rem_euclid(360.0),
                    },
                })
            }),
    );
    rows
}
/// Chooses which latitude and longitude axis values survive `stride` thinning.
///
/// The function works in four steps:
/// 1. Collect the latitudes and longitudes of the points inside `bbox` (all
///    points when `bbox` is `None`).
/// 2. Sort each axis ascending.
/// 3. Drop duplicates.
/// 4. Keep every `stride`-th distinct value, starting from the smallest.
///
/// The kept values are returned as `f32::to_bits` patterns so callers can test
/// membership exactly. For consistency, distinctness is also judged by bit
/// pattern. A `stride` of `0` is treated as `1` (keep every value).
fn compute_kept_axes(
    latlons: &[(f32, f32)],
    stride: usize,
    bbox: Option<Grib2Bbox>,
) -> (HashSet<u32>, HashSet<u32>) {
    /// Sorts, dedups (by bit pattern) and keeps every `stride`-th value of one axis.
    fn every_nth_distinct(mut values: Vec<f32>, stride: usize) -> HashSet<u32> {
        // `total_cmp` is a genuine total order even if NaNs slip through. A
        // `partial_cmp` fallback is not, and `sort_by` may panic on an
        // inconsistent comparator.
        values.sort_unstable_by(f32::total_cmp);
        values.dedup_by_key(|v| v.to_bits());
        values.into_iter().step_by(stride).map(f32::to_bits).collect()
    }

    // `step_by(0)` panics; treat a zero stride as "keep everything".
    let stride = stride.max(1);
    // Single pass: both axes are derived from the same bbox-filtered points.
    let (lats, lons): (Vec<f32>, Vec<f32>) = latlons
        .iter()
        .copied()
        .filter(|&(lat, lon)| bbox.is_none_or(|b| b.contains(lat, lon)))
        .unzip();
    (every_nth_distinct(lats, stride), every_nth_distinct(lons, stride))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forecast_unit_to_seconds_handles_common_units() {
        let hour = Code::Name(Table4_4::Hour);
        assert_eq!(forecast_unit_to_seconds(&hour, 0), Some(0));
        assert_eq!(forecast_unit_to_seconds(&hour, 6), Some(21_600));
        let three_hours = Code::Name(Table4_4::ThreeHours);
        assert_eq!(forecast_unit_to_seconds(&three_hours, 4), Some(43_200));
        let day = Code::Name(Table4_4::Day);
        assert_eq!(forecast_unit_to_seconds(&day, 2), Some(172_800));
    }

    #[test]
    fn forecast_unit_to_seconds_rejects_calendar_units() {
        let month = Code::Name(Table4_4::Month);
        assert_eq!(forecast_unit_to_seconds(&month, 1), None);
        let missing = Code::Name(Table4_4::Missing);
        assert_eq!(forecast_unit_to_seconds(&missing, 1), None);
        let unknown: Code<Table4_4, u8> = Code::Num(99);
        assert_eq!(forecast_unit_to_seconds(&unknown, 1), None);
    }

    /// Mirrors how `from_grib2_reader` decides whether to compute kept axes.
    fn kept(
        latlons: &[(f32, f32)],
        stride: usize,
        bbox: Option<Grib2Bbox>,
    ) -> Option<(HashSet<u32>, HashSet<u32>)> {
        (stride > 1).then(|| compute_kept_axes(latlons, stride, bbox))
    }

    #[test]
    fn bbox_contains_is_inclusive_on_every_edge() {
        let bbox = Grib2Bbox {
            lat_min: -1.0,
            lat_max: 1.0,
            lon_min: 10.0,
            lon_max: 20.0,
        };
        assert!(bbox.contains(-1.0, 10.0));
        assert!(bbox.contains(1.0, 20.0));
        assert!(!bbox.contains(1.5, 15.0));
        assert!(!bbox.contains(0.0, 20.5));
        assert!(!bbox.contains(f32::NAN, 15.0));
    }

    #[test]
    fn build_rows_bbox_drops_points_outside_rectangle() {
        let mut latlons = Vec::new();
        for j in 0..3 {
            for i in 0..3 {
                latlons.push((j as f32, i as f32));
            }
        }
        let values = vec![1.0_f32; 9];
        let bbox = Grib2Bbox {
            lat_min: 1.0,
            lat_max: 2.0,
            lon_min: 1.0,
            lon_max: 1.0,
        };
        let rows = build_rows(&values, &values, &latlons, Some(bbox), None);
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn build_rows_bbox_and_stride_compose() {
        let mut latlons = Vec::new();
        for j in 0..5 {
            for i in 0..5 {
                latlons.push((j as f32, i as f32));
            }
        }
        let values = vec![1.0_f32; 25];
        let bbox = Grib2Bbox {
            lat_min: 1.0,
            lat_max: 3.0,
            lon_min: 1.0,
            lon_max: 3.0,
        };
        let kept_axes = kept(&latlons, 2, Some(bbox));
        let rows = build_rows(&values, &values, &latlons, Some(bbox), kept_axes.as_ref());
        assert_eq!(rows.len(), 4);
    }

    #[test]
    fn build_rows_round_trips_north_wind() {
        let latlons = vec![(45.0, 0.0)];
        let u = vec![0.0_f32];
        let v = vec![-10.0_f32];
        let rows = build_rows(&u, &v, &latlons, None, None);
        assert_eq!(rows.len(), 1);
        assert!((rows[0].sample.speed - 10.0).abs() < 1e-4);
        assert!((rows[0].sample.direction - 0.0).abs() < 1e-3);
    }

    #[test]
    fn build_rows_round_trips_east_wind() {
        let latlons = vec![(45.0, 0.0)];
        let u = vec![-10.0_f32];
        let v = vec![0.0_f32];
        let rows = build_rows(&u, &v, &latlons, None, None);
        assert!((rows[0].sample.direction - 90.0).abs() < 1e-3);
    }

    #[test]
    fn build_rows_skips_nan_samples() {
        let latlons = vec![(45.0, 0.0), (45.5, 0.0)];
        let u = vec![f32::NAN, 5.0];
        let v = vec![0.0, 0.0];
        let rows = build_rows(&u, &v, &latlons, None, None);
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn build_rows_emits_lon_lat_directly() {
        let latlons = vec![(45.0_f32, -10.0_f32), (60.0, 5.0)];
        let u = vec![0.0_f32; 2];
        let v = vec![0.0_f32; 2];
        let rows = build_rows(&u, &v, &latlons, None, None);
        assert_eq!(rows.len(), 2);
        assert!((rows[0].lon - (-10.0)).abs() < 1e-6);
        assert!((rows[0].lat - 45.0).abs() < 1e-6);
        assert!((rows[1].lon - 5.0).abs() < 1e-6);
        assert!((rows[1].lat - 60.0).abs() < 1e-6);
    }

    #[test]
    fn build_rows_stride_keeps_every_nth_unique_axis_value() {
        let mut latlons = Vec::new();
        for j in 0..4 {
            for i in 0..4 {
                latlons.push((j as f32, 10.0 + i as f32));
            }
        }
        let values = vec![1.0_f32; 16];
        let kept_axes = kept(&latlons, 2, None);
        let rows = build_rows(&values, &values, &latlons, None, kept_axes.as_ref());
        assert_eq!(rows.len(), 4);
        let mut lats: Vec<f32> = rows.iter().map(|r| r.lat).collect();
        let mut lons: Vec<f32> = rows.iter().map(|r| r.lon).collect();
        lats.sort_by(|a, b| a.partial_cmp(b).unwrap());
        lats.dedup();
        lons.sort_by(|a, b| a.partial_cmp(b).unwrap());
        lons.dedup();
        assert_eq!(lats, vec![0.0, 2.0]);
        assert_eq!(lons, vec![10.0, 12.0]);
    }

    #[test]
    fn build_rows_stride_one_is_no_op() {
        let latlons: Vec<(f32, f32)> = (0..9).map(|k| (k as f32, k as f32)).collect();
        let values = vec![1.0_f32; 9];

        // Strides 0 and 1 must not produce any axis filter at all.
        let stride_one = kept(&latlons, 1, None);
        let stride_zero = kept(&latlons, 0, None);
        assert!(stride_one.is_none());
        assert!(stride_zero.is_none());
        let rows_one = build_rows(&values, &values, &latlons, None, stride_one.as_ref());
        let rows_zero = build_rows(&values, &values, &latlons, None, stride_zero.as_ref());
        assert_eq!(rows_one.len(), 9);
        assert_eq!(rows_zero.len(), 9);

        // Even when computed explicitly, a stride-1 axis set keeps every point.
        let all_axes = compute_kept_axes(&latlons, 1, None);
        let rows_all = build_rows(&values, &values, &latlons, None, Some(&all_axes));
        assert_eq!(rows_all.len(), 9);
    }

    #[test]
    fn compute_median_step_seconds_picks_median_positive_gap() {
        assert_eq!(compute_median_step_seconds(&[]), 1.0);
        assert_eq!(compute_median_step_seconds(&[42]), 1.0);
        assert_eq!(compute_median_step_seconds(&[10, 10]), 1.0);
        assert_eq!(
            compute_median_step_seconds(&[0, 3_600, 7_200, 18_000]),
            3_600.0
        );
    }

    #[test]
    fn days_from_civil_matches_known_anchors() {
        assert_eq!(days_from_civil(1970, 1, 1), Some(0));
        assert_eq!(days_from_civil(2000, 1, 1), Some(10_957));
        assert_eq!(days_from_civil(2000, 12, 31), Some(10_957 + 365));
        assert_eq!(days_from_civil(1969, 12, 31), Some(-1));
        assert_eq!(days_from_civil(1969, 1, 1), Some(-365));
    }

    #[test]
    fn ref_time_to_unix_seconds_combines_date_and_time() {
        let rt = RefTime::new(1970, 1, 1, 0, 0, 0);
        assert_eq!(ref_time_to_unix_seconds(&rt), Some(0));
        let rt = RefTime::new(1970, 1, 1, 1, 30, 15);
        assert_eq!(ref_time_to_unix_seconds(&rt), Some(3600 + 30 * 60 + 15));
        let rt = RefTime::new(1970, 1, 2, 0, 0, 0);
        assert_eq!(ref_time_to_unix_seconds(&rt), Some(86_400));
        let a = ref_time_to_unix_seconds(&RefTime::new(2026, 3, 1, 0, 0, 0)).unwrap();
        let b = ref_time_to_unix_seconds(&RefTime::new(2026, 3, 1, 6, 0, 0)).unwrap();
        assert_eq!(b - a, 21_600);
    }

    #[test]
    fn ref_time_to_unix_seconds_rejects_invalid_clock() {
        let bad_hour = RefTime::new(2026, 3, 1, 25, 0, 0);
        assert_eq!(ref_time_to_unix_seconds(&bad_hour), None);
        let bad_minute = RefTime::new(2026, 3, 1, 0, 60, 0);
        assert_eq!(ref_time_to_unix_seconds(&bad_minute), None);
        let bad_month = RefTime::new(2026, 13, 1, 0, 0, 0);
        assert_eq!(ref_time_to_unix_seconds(&bad_month), None);
    }
}