use anyhow::{Result, anyhow};
/// Geospatial foundation model (GFM) preprocessing profiles known to this module.
///
/// Parse one from user input with [`GfmProfile::from_name`] and obtain its
/// preprocessing parameters with [`GfmProfile::spec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GfmProfile {
    /// Prithvi-EO-2.0 (canonical profile name `prithvi-v2`).
    PrithviV2,
    /// Clay v1.5 (canonical profile name `clay-v1.5`).
    ClayV15,
}
/// Preprocessing parameters that describe how to feed imagery to one GFM profile.
///
/// `bands_order` and `band_norm` are index-aligned: entry `i` of `band_norm`
/// holds the `(mean, std)` pair for the band named `bands_order[i]`.
#[derive(Debug, Clone)]
pub struct GfmProfileSpec {
    /// Canonical profile name.
    pub name: &'static str,
    /// Identifier of the upstream model this profile targets.
    pub model_target: &'static str,
    /// Band identifiers in the channel order the model expects.
    pub bands_order: Vec<&'static str>,
    /// Side length, in pixels, of the square input tiles.
    pub tile_size: usize,
    /// Per-band `(mean, std)` statistics used for z-score normalization.
    pub band_norm: Vec<(f32, f32)>,
    /// Tag naming the pixel units the statistics assume (e.g. `"reflectance_0_1"`).
    pub expected_unit: &'static str,
    /// Upstream reference for the model and its statistics.
    pub source_url: &'static str,
}
impl GfmProfile {
    /// Parses a profile from its canonical name or one of its aliases, ignoring case.
    ///
    /// Accepted names:
    /// - [`GfmProfile::PrithviV2`]: `prithvi-v2`, `prithvi`, `prithvi-eo-2`, `prithvi-eo-2.0`
    /// - [`GfmProfile::ClayV15`]: `clay-v1.5`, `clay`, `clay-v15`
    ///
    /// # Errors
    ///
    /// Returns an error quoting `s` and listing the supported profiles when `s`
    /// matches none of the accepted names.
    pub fn from_name(s: &str) -> Result<Self> {
        let key = s.to_lowercase();
        let profile = match key.as_str() {
            "prithvi-v2" | "prithvi" | "prithvi-eo-2" | "prithvi-eo-2.0" => Self::PrithviV2,
            "clay-v1.5" | "clay" | "clay-v15" => Self::ClayV15,
            _ => {
                return Err(anyhow!(
                    "Unknown GFM profile '{}'. Supported: prithvi-v2, clay-v1.5",
                    s
                ));
            }
        };
        Ok(profile)
    }

    /// Returns the preprocessing specification for this profile.
    ///
    /// A fresh, owned spec is built on every call.
    pub fn spec(&self) -> GfmProfileSpec {
        match self {
            Self::PrithviV2 => Self::prithvi_v2_spec(),
            Self::ClayV15 => Self::clay_v15_spec(),
        }
    }

    /// Builds the Prithvi-EO-2.0-300M spec: six bands, 224-pixel tiles, unit tag `DN_sr_x10000`.
    fn prithvi_v2_spec() -> GfmProfileSpec {
        // NOTE(review): B05/B06/B07 paired with these statistics look like HLS/Landsat
        // naming (NIR, SWIR1, SWIR2) rather than Sentinel-2 red-edge bands — confirm upstream.
        GfmProfileSpec {
            name: "prithvi-v2",
            model_target: "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
            bands_order: vec!["B02", "B03", "B04", "B05", "B06", "B07"],
            tile_size: 224,
            band_norm: vec![
                (1087.0, 2248.0), // B02
                (1342.0, 2179.0), // B03
                (1433.0, 2178.0), // B04
                (2734.0, 1850.0), // B05
                (1958.0, 1242.0), // B06
                (1363.0, 1049.0), // B07
            ],
            expected_unit: "DN_sr_x10000",
            source_url: "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
        }
    }

    /// Builds the Clay v1.5 spec: ten bands, 256-pixel tiles, unit tag `reflectance_0_1`.
    fn clay_v15_spec() -> GfmProfileSpec {
        GfmProfileSpec {
            name: "clay-v1.5",
            model_target: "made-with-clay/Clay/v1.5",
            bands_order: vec![
                "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12",
            ],
            tile_size: 256,
            band_norm: vec![
                (0.1182, 0.0461), // B02
                (0.1262, 0.0468), // B03
                (0.1389, 0.0596), // B04
                (0.1497, 0.0635), // B05
                (0.2010, 0.0780), // B06
                (0.2353, 0.0950), // B07
                (0.2455, 0.0987), // B08
                (0.2547, 0.1023), // B8A
                (0.2099, 0.1153), // B11
                (0.1421, 0.0890), // B12
            ],
            expected_unit: "reflectance_0_1",
            source_url: "https://github.com/Clay-foundation/model",
        }
    }
}
/// Z-score normalizes, in place, a band-major buffer of square `tile_size` x `tile_size` tiles.
///
/// Equivalent to calling `apply_band_norm_block` with a block of `tile_size²` values per band.
///
/// # Panics
///
/// Panics if `buf` holds fewer than `band_norm.len() * tile_size²` values.
pub fn apply_band_norm(buf: &mut [f32], band_norm: &[(f32, f32)], tile_size: usize) {
    apply_band_norm_block(buf, band_norm, tile_size.pow(2));
}
/// Z-score normalizes a band-major buffer in place, one contiguous block per band.
///
/// Band `i` occupies `buf[i * block_size..(i + 1) * block_size]`. Each value there becomes
/// `(v - mean_i) / std_i`, where `band_norm[i] = (mean_i, std_i)`.
///
/// - Non-finite values (NaN, ±inf) are left untouched.
/// - Elements of `buf` past the last band's block are not modified.
/// - A `block_size` of 0 is a no-op.
///
/// Standard deviations that are not greater than `1e-10` are clamped to `1e-10` to avoid
/// dividing by zero. This includes zero, negative and NaN values.
///
/// # Panics
///
/// Panics if `buf` holds fewer than `band_norm.len() * block_size` values.
pub fn apply_band_norm_block(buf: &mut [f32], band_norm: &[(f32, f32)], block_size: usize) {
    if block_size == 0 {
        // Nothing to normalize, and `chunks_exact_mut` rejects a zero chunk size.
        return;
    }
    // Saturating: a product that overflows can never be satisfied by a real slice length.
    let required = band_norm.len().saturating_mul(block_size);
    assert!(
        buf.len() >= required,
        "apply_band_norm_block: buffer holds {} values but {} bands x {} values per band requires {}",
        buf.len(),
        band_norm.len(),
        block_size,
        required
    );
    // Slicing once up front lets the per-band chunks be processed without per-pixel bounds checks.
    for (band, &(mean, std)) in buf[..required].chunks_exact_mut(block_size).zip(band_norm) {
        let std_safe = if std > 1e-10 { std } else { 1e-10 };
        for v in band.iter_mut().filter(|v| v.is_finite()) {
            *v = (*v - mean) / std_safe;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two float slices have equal length and are element-wise within 1e-6.
    ///
    /// The length check matters: a bare `zip` would silently ignore extra elements.
    fn assert_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len(), "length mismatch");
        for (g, w) in got.iter().zip(want) {
            assert!((g - w).abs() < 1e-6, "got {} want {}", g, w);
        }
    }

    #[test]
    fn from_name_aliases() {
        for alias in ["prithvi-v2", "prithvi", "prithvi-eo-2", "PRITHVI-EO-2.0"] {
            assert_eq!(GfmProfile::from_name(alias).unwrap(), GfmProfile::PrithviV2);
        }
        for alias in ["clay-v1.5", "Clay", "clay-v15", "CLAY-V1.5"] {
            assert_eq!(GfmProfile::from_name(alias).unwrap(), GfmProfile::ClayV15);
        }
        assert!(GfmProfile::from_name("bogus").is_err());
    }

    #[test]
    fn from_name_error_names_input_and_supported_profiles() {
        let msg = GfmProfile::from_name("bogus").unwrap_err().to_string();
        assert!(msg.contains("bogus"), "{}", msg);
        assert!(msg.contains("prithvi-v2"), "{}", msg);
        assert!(msg.contains("clay-v1.5"), "{}", msg);
    }

    #[test]
    fn prithvi_v2_spec_consistent() {
        let s = GfmProfile::PrithviV2.spec();
        assert_eq!(s.bands_order.len(), s.band_norm.len());
        assert_eq!(s.bands_order.len(), 6);
        assert_eq!(s.tile_size, 224);
        assert_eq!(GfmProfile::from_name(s.name).unwrap(), GfmProfile::PrithviV2);
    }

    #[test]
    fn clay_v15_spec_consistent() {
        let s = GfmProfile::ClayV15.spec();
        assert_eq!(s.bands_order.len(), s.band_norm.len());
        assert_eq!(s.bands_order.len(), 10);
        assert_eq!(s.tile_size, 256);
        assert_eq!(GfmProfile::from_name(s.name).unwrap(), GfmProfile::ClayV15);
    }

    #[test]
    fn spec_stds_are_positive() {
        for profile in [GfmProfile::PrithviV2, GfmProfile::ClayV15] {
            for (mean, std) in profile.spec().band_norm {
                assert!(mean.is_finite(), "{:?}: non-finite mean", profile);
                assert!(std > 0.0, "{:?}: non-positive std {}", profile, std);
            }
        }
    }

    #[test]
    fn apply_band_norm_zscore() {
        let mut buf: Vec<f32> = vec![2.0, 3.0, 4.0, 5.0, 10.0, 12.0, 14.0, 16.0];
        let norm = vec![(2.0_f32, 1.0_f32), (10.0_f32, 2.0_f32)];
        apply_band_norm(&mut buf, &norm, 2);
        assert_close(&buf, &[0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn apply_band_norm_block_temporal() {
        let mut buf: Vec<f32> = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let norm = vec![(2.0_f32, 1.0_f32)];
        apply_band_norm_block(&mut buf, &norm, 8);
        assert_close(&buf, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn apply_band_norm_preserves_nan() {
        let mut buf: Vec<f32> = vec![2.0, f32::NAN, 4.0, 5.0];
        let norm = vec![(2.0_f32, 1.0_f32)];
        apply_band_norm(&mut buf, &norm, 2);
        assert_eq!(buf[0], 0.0);
        assert!(buf[1].is_nan());
        assert_eq!(buf[2], 2.0);
        assert_eq!(buf[3], 3.0);
    }

    #[test]
    fn apply_band_norm_preserves_infinities() {
        let mut buf: Vec<f32> = vec![f32::INFINITY, f32::NEG_INFINITY, 3.0, 4.0];
        apply_band_norm(&mut buf, &[(2.0, 1.0)], 2);
        assert_eq!(buf[0], f32::INFINITY);
        assert_eq!(buf[1], f32::NEG_INFINITY);
        assert_eq!(buf[2], 1.0);
        assert_eq!(buf[3], 2.0);
    }

    #[test]
    fn apply_band_norm_block_clamps_zero_std() {
        let mut buf: Vec<f32> = vec![1.0, 2.0];
        apply_band_norm_block(&mut buf, &[(1.0, 0.0)], 2);
        assert_eq!(buf[0], 0.0);
        assert_eq!(buf[1], (2.0_f32 - 1.0) / 1e-10_f32);
    }

    #[test]
    fn apply_band_norm_block_leaves_trailing_values() {
        let mut buf: Vec<f32> = vec![2.0, 3.0, 99.0];
        apply_band_norm_block(&mut buf, &[(2.0, 1.0)], 2);
        assert_close(&buf, &[0.0, 1.0, 99.0]);
    }

    #[test]
    fn apply_band_norm_block_zero_block_size_is_noop() {
        let mut buf: Vec<f32> = vec![1.0, 2.0];
        apply_band_norm_block(&mut buf, &[(5.0, 1.0)], 0);
        assert_close(&buf, &[1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn apply_band_norm_block_short_buffer_panics() {
        let mut buf: Vec<f32> = vec![0.0; 3];
        apply_band_norm_block(&mut buf, &[(0.0, 1.0), (0.0, 1.0)], 2);
    }
}