// ad_plugins/file_netcdf.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use ad_core::error::{ADError, ADResult};
5use ad_core::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
6use ad_core::ndarray_pool::NDArrayPool;
7use ad_core::plugin::file_base::{NDFileMode, NDFileWriter, NDPluginFileBase};
8use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
9
10use netcdf3::{DataSet, FileReader, FileWriter, Version};
11
12const VAR_NAME: &str = "array_data";
13const DIM_UNLIMITED: &str = "numArrays";
14
15/// A single buffered frame captured from an NDArray.
16struct FrameData {
17    dims: Vec<usize>,
18    data: NDDataBuffer,
19    data_type: NDDataType,
20    attrs: Vec<(String, String)>,
21}
22
23/// NetCDF-3 file writer.
24///
25/// Because `netcdf3::FileWriter` is `!Send` (uses `Rc` internally), we cannot
26/// store it as a field on a `Send + Sync` struct.  Instead we buffer frame data
27/// in memory and materialise the `FileWriter` only inside `close_file()`, where
28/// it is created, used, and dropped within a single method call.  The same
29/// approach is used for `read_file()` with `FileReader`.
30pub struct NetcdfWriter {
31    current_path: Option<PathBuf>,
32    frames: Vec<FrameData>,
33}
34
35impl NetcdfWriter {
36    pub fn new() -> Self {
37        Self {
38            current_path: None,
39            frames: Vec::new(),
40        }
41    }
42}
43
44/// Map NDDataType → netcdf3 DataType.  Returns error for 64-bit integers
45/// which NetCDF-3 classic format does not support.
46fn nc_data_type(dt: NDDataType) -> ADResult<netcdf3::DataType> {
47    match dt {
48        NDDataType::Int8 => Ok(netcdf3::DataType::I8),
49        NDDataType::UInt8 => Ok(netcdf3::DataType::U8),
50        NDDataType::Int16 | NDDataType::UInt16 => Ok(netcdf3::DataType::I16),
51        NDDataType::Int32 | NDDataType::UInt32 => Ok(netcdf3::DataType::I32),
52        NDDataType::Float32 => Ok(netcdf3::DataType::F32),
53        NDDataType::Float64 => Ok(netcdf3::DataType::F64),
54        NDDataType::Int64 | NDDataType::UInt64 => Err(ADError::UnsupportedConversion(
55            "NetCDF-3 does not support 64-bit integer types".into(),
56        )),
57    }
58}
59
60/// Write a single frame's data to a fixed-dimension variable.
61fn write_var_data(
62    writer: &mut FileWriter,
63    data: &NDDataBuffer,
64) -> ADResult<()> {
65    let err = |e: netcdf3::error::WriteError| {
66        ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
67    };
68    match data {
69        NDDataBuffer::I8(v) => writer.write_var_i8(VAR_NAME, v).map_err(err),
70        NDDataBuffer::U8(v) => writer.write_var_u8(VAR_NAME, v).map_err(err),
71        NDDataBuffer::I16(v) => writer.write_var_i16(VAR_NAME, v).map_err(err),
72        NDDataBuffer::U16(v) => {
73            let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
74            writer.write_var_i16(VAR_NAME, &reinterp).map_err(err)
75        }
76        NDDataBuffer::I32(v) => writer.write_var_i32(VAR_NAME, v).map_err(err),
77        NDDataBuffer::U32(v) => {
78            let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
79            writer.write_var_i32(VAR_NAME, &reinterp).map_err(err)
80        }
81        NDDataBuffer::F32(v) => writer.write_var_f32(VAR_NAME, v).map_err(err),
82        NDDataBuffer::F64(v) => writer.write_var_f64(VAR_NAME, v).map_err(err),
83        _ => Err(ADError::UnsupportedConversion(
84            "NetCDF-3 does not support 64-bit integer types".into(),
85        )),
86    }
87}
88
89/// Write a single record (one frame) to a record variable.
90fn write_record_data(
91    writer: &mut FileWriter,
92    record_index: usize,
93    data: &NDDataBuffer,
94) -> ADResult<()> {
95    let err = |e: netcdf3::error::WriteError| {
96        ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
97    };
98    match data {
99        NDDataBuffer::I8(v) => writer.write_record_i8(VAR_NAME, record_index, v).map_err(err),
100        NDDataBuffer::U8(v) => writer.write_record_u8(VAR_NAME, record_index, v).map_err(err),
101        NDDataBuffer::I16(v) => writer.write_record_i16(VAR_NAME, record_index, v).map_err(err),
102        NDDataBuffer::U16(v) => {
103            let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
104            writer.write_record_i16(VAR_NAME, record_index, &reinterp).map_err(err)
105        }
106        NDDataBuffer::I32(v) => writer.write_record_i32(VAR_NAME, record_index, v).map_err(err),
107        NDDataBuffer::U32(v) => {
108            let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
109            writer.write_record_i32(VAR_NAME, record_index, &reinterp).map_err(err)
110        }
111        NDDataBuffer::F32(v) => writer.write_record_f32(VAR_NAME, record_index, v).map_err(err),
112        NDDataBuffer::F64(v) => writer.write_record_f64(VAR_NAME, record_index, v).map_err(err),
113        _ => Err(ADError::UnsupportedConversion(
114            "NetCDF-3 does not support 64-bit integer types".into(),
115        )),
116    }
117}
118
119impl NDFileWriter for NetcdfWriter {
120    fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
121        self.current_path = Some(path.to_path_buf());
122        self.frames.clear();
123        Ok(())
124    }
125
126    fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
127        // Validate data type early
128        nc_data_type(array.data.data_type())?;
129
130        let dims: Vec<usize> = array.dims.iter().map(|d| d.size).collect();
131        let attrs: Vec<(String, String)> = array
132            .attributes
133            .iter()
134            .map(|a| (a.name.clone(), a.value.as_string()))
135            .collect();
136
137        self.frames.push(FrameData {
138            dims,
139            data: array.data.clone(),
140            data_type: array.data.data_type(),
141            attrs,
142        });
143        Ok(())
144    }
145
146    fn close_file(&mut self) -> ADResult<()> {
147        let path = match self.current_path.take() {
148            Some(p) => p,
149            None => return Ok(()),
150        };
151
152        if self.frames.is_empty() {
153            return Ok(());
154        }
155
156        let map_def = |e: netcdf3::error::InvalidDataSet| {
157            ADError::UnsupportedConversion(format!("NetCDF definition error: {:?}", e))
158        };
159        let map_write = |e: netcdf3::error::WriteError| {
160            ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
161        };
162
163        let first = &self.frames[0];
164        let nc_dt = nc_data_type(first.data_type)?;
165        let multi = self.frames.len() > 1;
166
167        // Build DataSet definition
168        let mut ds = DataSet::new();
169
170        // Fixed dimensions: dim0, dim1, ...
171        let mut dim_names: Vec<String> = Vec::new();
172        for (i, &size) in first.dims.iter().enumerate() {
173            let name = format!("dim{}", i);
174            ds.add_fixed_dim(&name, size).map_err(map_def)?;
175            dim_names.push(name);
176        }
177
178        // Variable dimensions list
179        let var_dims: Vec<String> = if multi {
180            // Unlimited dimension first for record variables
181            ds.set_unlimited_dim(DIM_UNLIMITED, self.frames.len()).map_err(map_def)?;
182            let mut v = vec![DIM_UNLIMITED.to_string()];
183            v.extend(dim_names.iter().cloned());
184            v
185        } else {
186            dim_names.clone()
187        };
188
189        let var_dim_refs: Vec<&str> = var_dims.iter().map(|s| s.as_str()).collect();
190        ds.add_var(VAR_NAME, &var_dim_refs, nc_dt).map_err(map_def)?;
191
192        // Store NDArray attributes as variable attributes on array_data
193        // Merge attributes from all frames (first frame wins on duplicates)
194        let mut seen_attrs = std::collections::HashSet::new();
195        for frame in &self.frames {
196            for (name, value) in &frame.attrs {
197                if seen_attrs.insert(name.clone()) {
198                    let _ = ds.add_var_attr_string(VAR_NAME, name, value);
199                }
200            }
201        }
202
203        // Global attributes
204        ds.add_global_attr_i32("uniqueId", vec![0]).map_err(map_def)?;
205        ds.add_global_attr_i32("dataType", vec![first.data_type as i32]).map_err(map_def)?;
206        ds.add_global_attr_i32("numArrays", vec![self.frames.len() as i32]).map_err(map_def)?;
207
208        // Write
209        let mut writer = FileWriter::open(&path).map_err(map_write)?;
210        writer.set_def(&ds, Version::Classic, 0).map_err(map_write)?;
211
212        if multi {
213            for (i, frame) in self.frames.iter().enumerate() {
214                write_record_data(&mut writer, i, &frame.data)?;
215            }
216        } else {
217            write_var_data(&mut writer, &self.frames[0].data)?;
218        }
219
220        writer.close().map_err(map_write)?;
221        self.frames.clear();
222        Ok(())
223    }
224
225    fn read_file(&mut self) -> ADResult<NDArray> {
226        let path = self.current_path.as_ref()
227            .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
228
229        let map_read = |e: netcdf3::error::ReadError| {
230            ADError::UnsupportedConversion(format!("NetCDF read error: {:?}", e))
231        };
232
233        let mut reader = FileReader::open(path).map_err(map_read)?;
234
235        // Extract metadata from data_set() before any mutable read calls
236        let (is_record, dims, original_type_ordinal) = {
237            let ds = reader.data_set();
238            let var = ds.get_var(VAR_NAME)
239                .ok_or_else(|| ADError::UnsupportedConversion(
240                    format!("variable '{}' not found in NetCDF file", VAR_NAME),
241                ))?;
242
243            let is_record = ds.is_record_var(VAR_NAME).unwrap_or(false);
244
245            let var_dims_rc = var.get_dims();
246            let mut dims: Vec<NDDimension> = Vec::new();
247            for d in &var_dims_rc {
248                if d.is_unlimited() {
249                    continue;
250                }
251                dims.push(NDDimension::new(d.size()));
252            }
253
254            let original_type_ordinal = ds
255                .get_global_attr_i32("dataType")
256                .and_then(|slice| slice.first().copied());
257
258            (is_record, dims, original_type_ordinal)
259        };
260
261        // Read first frame (record 0 if record variable, else full var)
262        let data_vec = if is_record {
263            reader.read_record(VAR_NAME, 0).map_err(map_read)?
264        } else {
265            reader.read_var(VAR_NAME).map_err(map_read)?
266        };
267
268        let (nd_type, buf) = match data_vec {
269            netcdf3::DataVector::I8(v) => (NDDataType::Int8, NDDataBuffer::I8(v)),
270            netcdf3::DataVector::U8(v) => (NDDataType::UInt8, NDDataBuffer::U8(v)),
271            netcdf3::DataVector::I16(v) => (NDDataType::Int16, NDDataBuffer::I16(v)),
272            netcdf3::DataVector::I32(v) => (NDDataType::Int32, NDDataBuffer::I32(v)),
273            netcdf3::DataVector::F32(v) => (NDDataType::Float32, NDDataBuffer::F32(v)),
274            netcdf3::DataVector::F64(v) => (NDDataType::Float64, NDDataBuffer::F64(v)),
275        };
276
277        // Check global attr "dataType" to recover original NDDataType
278        let actual_type = original_type_ordinal
279            .and_then(|v| NDDataType::from_ordinal(v as u8))
280            .unwrap_or(nd_type);
281
282        // Re-interpret if the original type was unsigned and stored as signed
283        let buf = match (actual_type, buf) {
284            (NDDataType::UInt16, NDDataBuffer::I16(v)) => {
285                NDDataBuffer::U16(v.into_iter().map(|x| x as u16).collect())
286            }
287            (NDDataType::UInt32, NDDataBuffer::I32(v)) => {
288                NDDataBuffer::U32(v.into_iter().map(|x| x as u32).collect())
289            }
290            (_, buf) => buf,
291        };
292
293        let mut arr = NDArray::new(dims, actual_type);
294        arr.data = buf;
295        Ok(arr)
296    }
297
298    fn supports_multiple_arrays(&self) -> bool {
299        true
300    }
301}
302
303/// NetCDF file processor wrapping NDPluginFileBase + NetcdfWriter.
304pub struct NetcdfFileProcessor {
305    file_base: NDPluginFileBase,
306    writer: NetcdfWriter,
307}
308
309impl NetcdfFileProcessor {
310    pub fn new() -> Self {
311        Self {
312            file_base: NDPluginFileBase::new(),
313            writer: NetcdfWriter::new(),
314        }
315    }
316
317    pub fn file_base_mut(&mut self) -> &mut NDPluginFileBase {
318        &mut self.file_base
319    }
320}
321
322impl Default for NetcdfFileProcessor {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328impl NDPluginProcess for NetcdfFileProcessor {
329    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
330        let _ = self
331            .file_base
332            .process_array(Arc::new(array.clone()), &mut self.writer);
333        ProcessResult::empty()
334    }
335
336    fn plugin_type(&self) -> &str {
337        "NDFileNetCDF"
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Counter that keeps temp file names distinct within one test process.
    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Build a unique `.nc` path inside the system temp directory.
    fn temp_path(prefix: &str) -> PathBuf {
        let serial = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let file_name = format!("adcore_test_{}_{}.nc", prefix, serial);
        std::env::temp_dir().join(file_name)
    }

    /// Freshly allocated 4x4 array of the given element type.
    fn square_4x4(data_type: NDDataType) -> NDArray {
        NDArray::new(vec![NDDimension::new(4), NDDimension::new(4)], data_type)
    }

    /// 4x4 `UInt8` array whose element `i` is `i + offset` (wrapping).
    fn u8_ramp(offset: u8) -> NDArray {
        let mut arr = square_4x4(NDDataType::UInt8);
        if let NDDataBuffer::U8(pixels) = &mut arr.data {
            for (i, px) in pixels[..16].iter_mut().enumerate() {
                *px = (i as u8).wrapping_add(offset);
            }
        }
        arr
    }

    /// Open `path`, write every frame, close, and hand back the writer.
    fn write_frames(path: &Path, mode: NDFileMode, frames: &[&NDArray]) -> NetcdfWriter {
        let mut writer = NetcdfWriter::new();
        writer.open_file(path, mode, frames[0]).unwrap();
        for frame in frames {
            writer.write_file(frame).unwrap();
        }
        writer.close_file().unwrap();
        writer
    }

    /// Point the writer back at `path` and read the first stored frame.
    fn read_first_frame(writer: &mut NetcdfWriter, path: &Path) -> NDArray {
        writer.current_path = Some(path.to_path_buf());
        writer.read_file().unwrap()
    }

    #[test]
    fn test_write_u8_mono() {
        let path = temp_path("nc_u8");
        let arr = u8_ramp(0);

        write_frames(&path, NDFileMode::Single, &[&arr]);

        // Verify file exists and has NetCDF magic bytes: "CDF\x01" or "CDF\x02"
        let bytes = std::fs::read(&path).unwrap();
        assert!(bytes.len() > 16);
        assert_eq!(&bytes[0..3], b"CDF", "Expected NetCDF magic bytes");

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_write_u16() {
        let path = temp_path("nc_u16");

        let mut arr = square_4x4(NDDataType::UInt16);
        if let NDDataBuffer::U16(pixels) = &mut arr.data {
            for (i, px) in pixels[..16].iter_mut().enumerate() {
                *px = (i * 1000) as u16;
            }
        }

        write_frames(&path, NDFileMode::Single, &[&arr]);

        let bytes = std::fs::read(&path).unwrap();
        assert!(bytes.len() > 32);
        assert_eq!(&bytes[0..3], b"CDF");

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_roundtrip_u8() {
        let path = temp_path("nc_rt_u8");

        let mut arr = square_4x4(NDDataType::UInt8);
        if let NDDataBuffer::U8(pixels) = &mut arr.data {
            for (i, px) in pixels[..16].iter_mut().enumerate() {
                *px = (i * 10) as u8;
            }
        }

        let mut writer = write_frames(&path, NDFileMode::Single, &[&arr]);
        let read_back = read_first_frame(&mut writer, &path);

        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U8(expected), NDDataBuffer::U8(actual)) => {
                assert_eq!(expected, actual);
            }
            _ => panic!("data type mismatch on roundtrip"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_roundtrip_i16() {
        let path = temp_path("nc_rt_i16");

        let mut arr = square_4x4(NDDataType::Int16);
        if let NDDataBuffer::I16(pixels) = &mut arr.data {
            for (i, px) in pixels[..16].iter_mut().enumerate() {
                *px = (i as i16) * 100 - 500;
            }
        }

        let mut writer = write_frames(&path, NDFileMode::Single, &[&arr]);
        let read_back = read_first_frame(&mut writer, &path);

        match (&arr.data, &read_back.data) {
            (NDDataBuffer::I16(expected), NDDataBuffer::I16(actual)) => {
                assert_eq!(expected, actual);
            }
            _ => panic!("data type mismatch on roundtrip"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_roundtrip_f32() {
        let path = temp_path("nc_rt_f32");

        let mut arr = square_4x4(NDDataType::Float32);
        if let NDDataBuffer::F32(pixels) = &mut arr.data {
            for (i, px) in pixels[..16].iter_mut().enumerate() {
                *px = i as f32 * 0.5;
            }
        }

        let mut writer = write_frames(&path, NDFileMode::Single, &[&arr]);
        let read_back = read_first_frame(&mut writer, &path);

        match (&arr.data, &read_back.data) {
            (NDDataBuffer::F32(expected), NDDataBuffer::F32(actual)) => {
                assert_eq!(expected, actual);
            }
            _ => panic!("data type mismatch on roundtrip"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_multiple_frames() {
        let path = temp_path("nc_multi");

        let arr1 = u8_ramp(0);
        let arr2 = u8_ramp(100);
        let arr3 = u8_ramp(200);

        let mut writer = write_frames(&path, NDFileMode::Stream, &[&arr1, &arr2, &arr3]);

        // Read back first frame
        let read_back = read_first_frame(&mut writer, &path);
        match &read_back.data {
            NDDataBuffer::U8(pixels) => {
                assert_eq!(pixels.len(), 16);
                for (i, &got) in pixels.iter().enumerate() {
                    assert_eq!(got, i as u8, "mismatch at index {}", i);
                }
            }
            _ => panic!("expected U8 data"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_attributes_stored() {
        let path = temp_path("nc_attrs");

        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        let exposure_attr = NDAttribute {
            name: "exposure".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Float64(0.5),
        };
        arr.attributes.add(exposure_attr);
        let gain_attr = NDAttribute {
            name: "gain".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Int32(42),
        };
        arr.attributes.add(gain_attr);

        write_frames(&path, NDFileMode::Single, &[&arr]);

        // Verify attributes via FileReader
        let reader = FileReader::open(&path).unwrap();
        let ds = reader.data_set();

        assert_eq!(
            ds.get_var_attr_as_string(VAR_NAME, "exposure"),
            Some("0.5".to_string())
        );
        assert_eq!(
            ds.get_var_attr_as_string(VAR_NAME, "gain"),
            Some("42".to_string())
        );

        drop(reader);
        std::fs::remove_file(&path).ok();
    }
}