use std::collections::BTreeMap;
use std::fmt::Debug;
use std::path::Path;
use ciborium::Value as CborValue;
use eccodes::{CodesFile, FallibleIterator, KeyRead, ProductKind};
use tensogram::pipeline::apply_pipeline;
use tensogram::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
use tensogram::{DataPipeline, Dtype, EncodeOptions, encode};
use crate::error::GribError;
use crate::metadata::{GribKeySet, extract_all_namespace_keys, extract_mars_keys};
/// Settings that steer a GRIB-to-tensogram conversion.
#[derive(Debug, Clone)]
pub struct ConvertOptions {
    /// How the GRIB messages are distributed over the encoded outputs.
    pub grouping: Grouping,
    /// Options forwarded unchanged to every `encode` call.
    pub encode_options: EncodeOptions,
    /// When `true`, the keys gathered by `extract_all_namespace_keys` are
    /// collected for each message and stored under the `"grib"` entry.
    pub preserve_all_keys: bool,
    /// Pipeline handed to `apply_pipeline` for every data object.
    pub pipeline: DataPipeline,
}
/// Strategy for mapping GRIB messages onto encoded output messages.
///
/// The variants carry no data, so the type also derives `PartialEq`/`Eq`,
/// which lets callers and tests compare a configured grouping directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Grouping {
    /// Every GRIB message is encoded into its own output message.
    OneToOne,
    /// All GRIB messages are encoded together into a single output message.
    MergeAll,
}
impl Default for ConvertOptions {
    /// Merges all messages into one output, keeps only the default key set,
    /// and uses the default encode options and pipeline.
    fn default() -> Self {
        let grouping = Grouping::MergeAll;
        let encode_options = EncodeOptions::default();
        let pipeline = DataPipeline::default();
        Self {
            grouping,
            encode_options,
            preserve_all_keys: false,
            pipeline,
        }
    }
}
/// Everything pulled out of a single GRIB message before encoding.
pub(crate) struct GribExtracted {
    /// Key set returned by `extract_mars_keys`, plus a `"grid"` entry when
    /// the message exposes `gridType`.
    pub(crate) keys: GribKeySet,
    /// Contents of the message's `values` key.
    pub(crate) values: Vec<f64>,
    /// Tensor shape describing `values`.
    pub(crate) shape: Vec<u64>,
    /// Result of `extract_all_namespace_keys` (outer key → inner key → value);
    /// `None` unless key preservation was requested.
    pub(crate) grib_keys: Option<BTreeMap<String, BTreeMap<String, CborValue>>>,
}
/// Converts the GRIB file at `path` into encoded tensogram messages.
///
/// # Errors
///
/// Returns an error when the file cannot be opened as GRIB, when message
/// extraction fails, when no messages are found, or when encoding fails.
pub fn convert_grib_file(path: &Path, options: &ConvertOptions) -> Result<Vec<Vec<u8>>, GribError> {
    let mut grib_file = CodesFile::new_from_file(path, ProductKind::GRIB)?;
    extract_messages(&mut grib_file, options.preserve_all_keys)
        .and_then(|extracted| finish_conversion(extracted, options))
}
/// Converts GRIB data held in memory into encoded tensogram messages.
///
/// Takes ownership of `buffer`, which is handed to the GRIB reader.
///
/// # Errors
///
/// Returns an error when the buffer cannot be opened as GRIB, when message
/// extraction fails, when no messages are found, or when encoding fails.
pub fn convert_grib_buffer(
    buffer: Vec<u8>,
    options: &ConvertOptions,
) -> Result<Vec<Vec<u8>>, GribError> {
    let mut grib_source = CodesFile::new_from_memory(buffer, ProductKind::GRIB)?;
    extract_messages(&mut grib_source, options.preserve_all_keys)
        .and_then(|extracted| finish_conversion(extracted, options))
}
/// Extracts payload and metadata from every GRIB message reachable via `handle`.
///
/// For each message this collects the keys returned by `extract_mars_keys`,
/// the `values` array, the `gridType` key (stored under `"grid"` when it can
/// be read) and, if `preserve_all_keys` is set, the result of
/// `extract_all_namespace_keys`. The tensor shape is chosen by [`grid_shape`].
///
/// # Errors
///
/// Propagates failures from message iteration, from the key extractors and
/// from reading `values`. Failures reading `gridType`, `Ni` or `Nj` are
/// tolerated and do not abort the extraction.
fn extract_messages<D: Debug>(
    handle: &mut CodesFile<D>,
    preserve_all_keys: bool,
) -> Result<Vec<GribExtracted>, GribError> {
    let mut grib_messages = Vec::new();
    let mut iter = handle.ref_message_iter();
    while let Some(mut msg) = iter.next()? {
        let mut keys = extract_mars_keys(&mut msg)?;
        let values: Vec<f64> = msg.read_key("values")?;
        if let Ok(grid_type) = KeyRead::<String>::read_key(&msg, "gridType") {
            keys.keys
                .insert("grid".to_string(), CborValue::Text(grid_type));
        }
        let grib_keys = if preserve_all_keys {
            Some(extract_all_namespace_keys(&mut msg)?)
        } else {
            None
        };
        // An unreadable dimension key is treated as "no regular grid".
        let ni: i64 = msg.read_key("Ni").unwrap_or(0);
        let nj: i64 = msg.read_key("Nj").unwrap_or(0);
        let shape = grid_shape(ni, nj, values.len())?;
        grib_messages.push(GribExtracted {
            keys,
            values,
            shape,
            grib_keys,
        });
    }
    Ok(grib_messages)
}

/// Chooses the tensor shape for a field holding `value_count` points.
///
/// Returns `[nj, ni]` only when both are positive and `nj * ni` equals
/// `value_count`; in every other case (including an overflowing product) the
/// field is described as a flat `[value_count]`. This keeps the shape
/// consistent with the payload length.
///
/// NOTE(review): presumably some grids report a positive placeholder for `Ni`
/// even though they have no regular `Ni x Nj` layout — confirm against the
/// ecCodes key documentation.
fn grid_shape(ni: i64, nj: i64, value_count: usize) -> Result<Vec<u64>, GribError> {
    let total = u64::try_from(value_count)
        .map_err(|_| GribError::InvalidData("numberOfPoints out of u64 range".into()))?;
    // `try_from` already rejects negative dimensions.
    let regular = match (u64::try_from(nj), u64::try_from(ni)) {
        (Ok(rows), Ok(cols)) if rows > 0 && cols > 0 => {
            (rows.checked_mul(cols) == Some(total)).then(|| vec![rows, cols])
        }
        _ => None,
    };
    Ok(regular.unwrap_or_else(|| vec![total]))
}
fn finish_conversion(
grib_messages: Vec<GribExtracted>,
options: &ConvertOptions,
) -> Result<Vec<Vec<u8>>, GribError> {
if grib_messages.is_empty() {
return Err(GribError::NoMessages);
}
match options.grouping {
Grouping::OneToOne => {
convert_one_to_one(&grib_messages, &options.encode_options, &options.pipeline)
}
Grouping::MergeAll => {
convert_merge_all(&grib_messages, &options.encode_options, &options.pipeline)
}
}
}
/// Encodes each extracted GRIB message as its own output message.
///
/// Every output carries a single metadata entry holding the message's
/// `"mars"` keys and, when present and non-empty, its `"grib"` keys.
///
/// # Errors
///
/// Stops at the first message whose data object cannot be built or whose
/// encoding fails and returns that error.
fn convert_one_to_one(
    messages: &[GribExtracted],
    encode_options: &EncodeOptions,
    pipeline: &DataPipeline,
) -> Result<Vec<Vec<u8>>, GribError> {
    messages
        .iter()
        .map(|field| -> Result<Vec<u8>, GribError> {
            let mut entry = BTreeMap::new();
            let mars = &field.keys.keys;
            if !mars.is_empty() {
                entry.insert("mars".to_string(), btree_to_cbor_map(mars));
            }
            let namespaces = field.grib_keys.as_ref().filter(|ns| !ns.is_empty());
            if let Some(namespaces) = namespaces {
                entry.insert("grib".to_string(), nested_btree_to_cbor_map(namespaces));
            }
            let global_meta = GlobalMetadata {
                base: vec![entry],
                ..Default::default()
            };
            let (desc, data_bytes) = build_data_object(&field.values, &field.shape, pipeline)?;
            encode(&global_meta, &[(&desc, &data_bytes)], encode_options)
                .map_err(|e| GribError::Encode(e.to_string()))
        })
        .collect()
}
fn convert_merge_all(
messages: &[GribExtracted],
encode_options: &EncodeOptions,
pipeline: &DataPipeline,
) -> Result<Vec<Vec<u8>>, GribError> {
let base: Vec<BTreeMap<String, CborValue>> = messages
.iter()
.map(|msg| {
let mut entry = BTreeMap::new();
if !msg.keys.keys.is_empty() {
entry.insert("mars".to_string(), btree_to_cbor_map(&msg.keys.keys));
}
if let Some(grib) = &msg.grib_keys
&& !grib.is_empty()
{
entry.insert("grib".to_string(), nested_btree_to_cbor_map(grib));
}
entry
})
.collect();
let global_meta = GlobalMetadata {
base,
..Default::default()
};
let mut descriptors_and_data = Vec::with_capacity(messages.len());
for msg in messages {
let (desc, data_bytes) = build_data_object(&msg.values, &msg.shape, pipeline)?;
descriptors_and_data.push((desc, data_bytes));
}
let refs: Vec<_> = descriptors_and_data
.iter()
.map(|(desc, data)| (desc, data.as_slice()))
.collect();
let encoded = encode(&global_meta, &refs, encode_options)
.map_err(|e| GribError::Encode(e.to_string()))?;
Ok(vec![encoded])
}
/// Copies a string-keyed map into a CBOR map value, preserving the map's
/// iteration order and turning each key into a CBOR text value.
fn btree_to_cbor_map(map: &BTreeMap<String, CborValue>) -> CborValue {
    let mut pairs = Vec::with_capacity(map.len());
    for (key, value) in map {
        pairs.push((CborValue::Text(key.clone()), value.clone()));
    }
    CborValue::Map(pairs)
}
/// Converts a two-level map into a CBOR map whose values are themselves CBOR
/// maps, one per outer key.
fn nested_btree_to_cbor_map(map: &BTreeMap<String, BTreeMap<String, CborValue>>) -> CborValue {
    let mut namespaces = Vec::with_capacity(map.len());
    for (namespace, entries) in map {
        namespaces.push((CborValue::Text(namespace.clone()), btree_to_cbor_map(entries)));
    }
    CborValue::Map(namespaces)
}
/// Builds the descriptor and raw payload for one field.
///
/// The descriptor declares a little-endian `Float64` tensor of the given
/// `shape` with row-major element strides, then `apply_pipeline` is run on it
/// with `values`. The payload is `values` serialized as little-endian bytes.
///
/// # Errors
///
/// Returns `GribError::InvalidData` when the strides for `shape` do not fit
/// in `u64`, or when `apply_pipeline` reports an error.
fn build_data_object(
    values: &[f64],
    shape: &[u64],
    pipeline: &DataPipeline,
) -> Result<(DataObjectDescriptor, Vec<u8>), GribError> {
    let strides = row_major_strides(shape)
        .ok_or_else(|| GribError::InvalidData("tensor strides overflow u64".into()))?;
    let mut desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        // usize -> u64 is lossless on every supported target.
        ndim: shape.len() as u64,
        shape: shape.to_vec(),
        strides,
        dtype: Dtype::Float64,
        byte_order: ByteOrder::Little,
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    apply_pipeline(&mut desc, Some(values), pipeline, "GRIB message")
        .map_err(GribError::InvalidData)?;

    // Size the buffer once; must stay little-endian to match `byte_order`.
    let mut data_bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f64>());
    for value in values {
        data_bytes.extend_from_slice(&value.to_le_bytes());
    }
    Ok((desc, data_bytes))
}

/// Computes row-major element strides for `shape`.
///
/// The last dimension has stride 1 and each earlier stride is the product of
/// all later extents. Returns `None` if any of those products overflows
/// `u64`. The extent of the first dimension never enters a product.
fn row_major_strides(shape: &[u64]) -> Option<Vec<u64>> {
    let mut strides = vec![0u64; shape.len()];
    if let Some(innermost) = strides.last_mut() {
        *innermost = 1;
    }
    let mut running = 1u64;
    // Walk from the second-to-last stride backwards, pairing each stride with
    // the extent of the dimension immediately after it.
    for (stride, &extent) in strides.iter_mut().rev().skip(1).zip(shape.iter().rev()) {
        running = running.checked_mul(extent)?;
        *stride = running;
    }
    Some(strides)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;

    /// Resolves `name` inside the crate's `testdata` directory.
    fn testdata(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("testdata")
            .join(name)
    }

    /// Loads the raw bytes of the `2t.grib2` fixture.
    fn fixture_bytes() -> Vec<u8> {
        std::fs::read(testdata("2t.grib2")).expect("read 2t.grib2")
    }

    /// Converting from a path and from the same bytes in memory must agree.
    #[test]
    fn buffer_matches_file() {
        let fixture = testdata("2t.grib2");
        let raw = std::fs::read(&fixture).expect("read 2t.grib2");
        let convert_opts = ConvertOptions::default();
        let from_file = convert_grib_file(&fixture, &convert_opts).expect("convert file");
        let from_buf = convert_grib_buffer(raw, &convert_opts).expect("convert buffer");
        assert_eq!(from_file.len(), from_buf.len());

        let decode_opts = tensogram::DecodeOptions::default();
        for (file_msg, buf_msg) in from_file.iter().zip(from_buf.iter()) {
            let (_, file_objs) = tensogram::decode(file_msg, &decode_opts).expect("decode file");
            let (_, buf_objs) = tensogram::decode(buf_msg, &decode_opts).expect("decode buffer");
            assert_eq!(file_objs.len(), buf_objs.len());
            for ((file_desc, file_bytes), (buf_desc, buf_bytes)) in
                file_objs.iter().zip(buf_objs.iter())
            {
                assert_eq!(file_desc.shape, buf_desc.shape);
                assert_eq!(file_desc.dtype, buf_desc.dtype);
                assert_eq!(file_desc.encoding, buf_desc.encoding);
                assert_eq!(file_desc.filter, buf_desc.filter);
                assert_eq!(file_desc.compression, buf_desc.compression);
                assert_eq!(file_bytes, buf_bytes, "payload bytes must match");
            }
        }
    }

    /// Bytes that are not GRIB must be rejected.
    #[test]
    fn buffer_rejects_garbage() {
        let not_grib = b"this is not a grib message".to_vec();
        let outcome = convert_grib_buffer(not_grib, &ConvertOptions::default());
        assert!(outcome.is_err(), "expected error on garbage input");
    }

    /// Merging the fixture yields exactly one output message.
    #[test]
    fn buffer_merge_all_grouping() {
        let convert_opts = ConvertOptions {
            grouping: Grouping::MergeAll,
            ..Default::default()
        };
        let outputs = convert_grib_buffer(fixture_bytes(), &convert_opts).expect("convert buffer");
        assert_eq!(outputs.len(), 1);
    }

    /// One-to-one conversion of the fixture yields exactly one output message.
    #[test]
    fn buffer_one_to_one_grouping() {
        let convert_opts = ConvertOptions {
            grouping: Grouping::OneToOne,
            ..Default::default()
        };
        let outputs = convert_grib_buffer(fixture_bytes(), &convert_opts).expect("convert buffer");
        assert_eq!(outputs.len(), 1);
    }

    /// With key preservation on, the metadata must contain a `grib` entry.
    #[test]
    fn buffer_preserve_all_keys() {
        let convert_opts = ConvertOptions {
            preserve_all_keys: true,
            ..Default::default()
        };
        let outputs = convert_grib_buffer(fixture_bytes(), &convert_opts).expect("convert buffer");
        let meta = tensogram::decode_metadata(&outputs[0]).expect("decode metadata");
        let has_grib = meta.base.iter().any(|entry| entry.contains_key("grib"));
        assert!(
            has_grib,
            "preserve_all_keys should populate the grib sub-object"
        );
    }

    /// An empty buffer is an error, not an empty result.
    #[test]
    fn buffer_empty_returns_error() {
        let outcome = convert_grib_buffer(Vec::new(), &ConvertOptions::default());
        assert!(outcome.is_err(), "empty buffer must produce an error");
    }
}