use std::path::Path;
use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use ferray_io::npy::{NpyElement, load as npy_load, save as npy_save};
use crate::MaskedArray;
/// Persists a masked array as two `.npy` files: one for the data, one for the mask.
///
/// The data array is written to `data_path` first, then the boolean mask to
/// `mask_path`. Only these two arrays are written; no other `MaskedArray`
/// settings (such as the fill value or hard-mask flag) are stored.
///
/// # Errors
///
/// Returns the first error reported by the `.npy` writer. Because the writes
/// happen in sequence, a failure while writing the mask can leave the data
/// file already on disk.
pub fn save_masked<T, D, P1, P2>(
    data_path: P1,
    mask_path: P2,
    ma: &MaskedArray<T, D>,
) -> FerrayResult<()>
where
    T: Element + NpyElement,
    D: Dimension,
    P1: AsRef<Path>,
    P2: AsRef<Path>,
{
    let data_file: &Path = data_path.as_ref();
    let mask_file: &Path = mask_path.as_ref();

    // Payload first, then the mask; `?` stops at the first failure.
    npy_save(data_file, ma.data())?;
    npy_save(mask_file, ma.mask())?;

    Ok(())
}
/// Loads a masked array from the data/mask `.npy` pair written by [`save_masked`].
///
/// The data is read from `data_path` first, then the boolean mask from
/// `mask_path`. The two arrays are combined with [`MaskedArray::new`].
/// Settings that `save_masked` does not write (such as the fill value or the
/// hard-mask flag) are not restored. They take whatever values
/// `MaskedArray::new` produces.
///
/// # Errors
///
/// Fails if either file cannot be loaded as an array of the requested element
/// type and dimensionality. It also fails if `MaskedArray::new` rejects the
/// pair, e.g. when the data and mask shapes differ.
pub fn load_masked<T, D, P1, P2>(data_path: P1, mask_path: P2) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + NpyElement,
    D: Dimension,
    P1: AsRef<Path>,
    P2: AsRef<Path>,
{
    let data_file: &Path = data_path.as_ref();
    let mask_file: &Path = mask_path.as_ref();

    // The annotations pick the element type / dimensionality that `npy_load`
    // must decode each file as.
    let values: Array<T, D> = npy_load(data_file)?;
    let flags: Array<bool, D> = npy_load(mask_file)?;

    MaskedArray::new(values, flags)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Array;
    use ferray_core::dimension::{Ix1, Ix2};
    use std::path::PathBuf;

    /// Data/mask `.npy` paths for one test, inside a per-process temp directory.
    ///
    /// Both files are removed when the guard is dropped, so they are cleaned
    /// up even when an assertion in the test panics.
    struct TempNpyPair {
        data: PathBuf,
        mask: PathBuf,
    }

    impl TempNpyPair {
        /// Ensures the temp directory exists.
        ///
        /// Returns the `<stem>.data.npy` and `<stem>.mask.npy` paths inside it.
        fn new(stem: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("ferray_ma_io_{}", std::process::id()));
            // Fail loudly here rather than with a confusing save error later.
            std::fs::create_dir_all(&dir).expect("failed to create test temp directory");
            Self {
                data: dir.join(format!("{stem}.data.npy")),
                mask: dir.join(format!("{stem}.mask.npy")),
            }
        }
    }

    impl Drop for TempNpyPair {
        fn drop(&mut self) {
            // Best-effort: a file may not exist if the test failed before writing it.
            let _ = std::fs::remove_file(&self.data);
            let _ = std::fs::remove_file(&self.mask);
        }
    }

    #[test]
    fn save_and_load_roundtrips_1d_f64_masked_array() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, false, true])
            .unwrap();
        let ma = MaskedArray::new(d, m).unwrap();
        let files = TempNpyPair::new("test_1d");
        save_masked(&files.data, &files.mask, &ma).unwrap();
        let loaded: MaskedArray<f64, Ix1> = load_masked(&files.data, &files.mask).unwrap();
        assert_eq!(loaded.shape(), &[5]);
        assert_eq!(
            loaded.data().iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0]
        );
        assert_eq!(
            loaded.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false, false, true]
        );
    }

    #[test]
    fn save_and_load_roundtrips_2d_i32_masked_array() {
        let d =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let m = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, false, true, true, false, false],
        )
        .unwrap();
        let ma = MaskedArray::new(d, m).unwrap();
        let files = TempNpyPair::new("test_2d");
        save_masked(&files.data, &files.mask, &ma).unwrap();
        let loaded: MaskedArray<i32, Ix2> = load_masked(&files.data, &files.mask).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(
            loaded.data().iter().copied().collect::<Vec<_>>(),
            vec![10, 20, 30, 40, 50, 60]
        );
        assert_eq!(
            loaded.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, true, false, false]
        );
    }

    #[test]
    fn load_masked_defaults_soft_mask_and_zero_fill() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        let mut ma = MaskedArray::new(d, m).unwrap();
        ma.set_fill_value(-999.0);
        ma.harden_mask().unwrap();
        let files = TempNpyPair::new("test_defaults");
        save_masked(&files.data, &files.mask, &ma).unwrap();
        let loaded: MaskedArray<f64, Ix1> = load_masked(&files.data, &files.mask).unwrap();
        // Fill value and hardness are not persisted, so the loaded array has defaults.
        assert_eq!(loaded.fill_value(), 0.0);
        assert!(!loaded.is_hard_mask());
    }

    #[test]
    fn load_masked_rejects_mismatched_shapes() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let wrong_mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false; 4]).unwrap();
        let files = TempNpyPair::new("test_bad_shape");
        ferray_io::npy::save(&files.data, &d).unwrap();
        ferray_io::npy::save(&files.mask, &wrong_mask).unwrap();
        let result: FerrayResult<MaskedArray<f64, Ix1>> = load_masked(&files.data, &files.mask);
        assert!(result.is_err());
    }

    #[test]
    fn load_masked_rejects_wrong_dtype() {
        let d = Array::<f32, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        let files = TempNpyPair::new("test_wrong_dtype");
        ferray_io::npy::save(&files.data, &d).unwrap();
        ferray_io::npy::save(&files.mask, &m).unwrap();
        // The data file holds f32 values but is requested as f64.
        let result: FerrayResult<MaskedArray<f64, Ix1>> = load_masked(&files.data, &files.mask);
        assert!(result.is_err());
    }
}