Skip to main content

ad_plugins_rs/
file_netcdf.rs

1use std::path::{Path, PathBuf};
2
3use ad_core_rs::error::{ADError, ADResult};
4use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
5use ad_core_rs::ndarray_pool::NDArrayPool;
6use ad_core_rs::plugin::file_base::{NDFileMode, NDFileWriter};
7use ad_core_rs::plugin::file_controller::FilePluginController;
8use ad_core_rs::plugin::runtime::{
9    NDPluginProcess, ParamChangeResult, PluginParamSnapshot, ProcessResult,
10};
11
12use netcdf3::{DataSet, FileReader, FileWriter, Version};
13
/// Name of the NetCDF variable that stores the NDArray pixel data.
const VAR_NAME: &str = "array_data";
/// Name of the unlimited (record) dimension used when a file holds several arrays.
const DIM_UNLIMITED: &str = "numArrays";
16
17/// A single buffered frame captured from an NDArray.
18struct FrameData {
19    dims: Vec<usize>,
20    data: NDDataBuffer,
21    data_type: NDDataType,
22    attrs: Vec<(String, String)>,
23}
24
25/// NetCDF-3 file writer.
26///
27/// Because `netcdf3::FileWriter` is `!Send` (uses `Rc` internally), we cannot
28/// store it as a field on a `Send + Sync` struct.  Instead we buffer frame data
29/// in memory and materialise the `FileWriter` only inside `close_file()`, where
30/// it is created, used, and dropped within a single method call.  The same
31/// approach is used for `read_file()` with `FileReader`.
32pub struct NetcdfWriter {
33    current_path: Option<PathBuf>,
34    frames: Vec<FrameData>,
35}
36
37impl NetcdfWriter {
38    pub fn new() -> Self {
39        Self {
40            current_path: None,
41            frames: Vec::new(),
42        }
43    }
44}
45
46/// Map NDDataType → netcdf3 DataType.  Returns error for 64-bit integers
47/// which NetCDF-3 classic format does not support.
48fn nc_data_type(dt: NDDataType) -> ADResult<netcdf3::DataType> {
49    match dt {
50        NDDataType::Int8 => Ok(netcdf3::DataType::I8),
51        NDDataType::UInt8 => Ok(netcdf3::DataType::U8),
52        NDDataType::Int16 | NDDataType::UInt16 => Ok(netcdf3::DataType::I16),
53        NDDataType::Int32 | NDDataType::UInt32 => Ok(netcdf3::DataType::I32),
54        NDDataType::Float32 => Ok(netcdf3::DataType::F32),
55        NDDataType::Float64 => Ok(netcdf3::DataType::F64),
56        NDDataType::Int64 | NDDataType::UInt64 => Err(ADError::UnsupportedConversion(
57            "NetCDF-3 does not support 64-bit integer types".into(),
58        )),
59    }
60}
61
62/// Write a single frame's data to a fixed-dimension variable.
63fn write_var_data(writer: &mut FileWriter, data: &NDDataBuffer) -> ADResult<()> {
64    let err = |e: netcdf3::error::WriteError| {
65        ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
66    };
67    match data {
68        NDDataBuffer::I8(v) => writer.write_var_i8(VAR_NAME, v).map_err(err),
69        NDDataBuffer::U8(v) => writer.write_var_u8(VAR_NAME, v).map_err(err),
70        NDDataBuffer::I16(v) => writer.write_var_i16(VAR_NAME, v).map_err(err),
71        NDDataBuffer::U16(v) => {
72            let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
73            writer.write_var_i16(VAR_NAME, &reinterp).map_err(err)
74        }
75        NDDataBuffer::I32(v) => writer.write_var_i32(VAR_NAME, v).map_err(err),
76        NDDataBuffer::U32(v) => {
77            let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
78            writer.write_var_i32(VAR_NAME, &reinterp).map_err(err)
79        }
80        NDDataBuffer::F32(v) => writer.write_var_f32(VAR_NAME, v).map_err(err),
81        NDDataBuffer::F64(v) => writer.write_var_f64(VAR_NAME, v).map_err(err),
82        _ => Err(ADError::UnsupportedConversion(
83            "NetCDF-3 does not support 64-bit integer types".into(),
84        )),
85    }
86}
87
88/// Write a single record (one frame) to a record variable.
89fn write_record_data(
90    writer: &mut FileWriter,
91    record_index: usize,
92    data: &NDDataBuffer,
93) -> ADResult<()> {
94    let err = |e: netcdf3::error::WriteError| {
95        ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
96    };
97    match data {
98        NDDataBuffer::I8(v) => writer
99            .write_record_i8(VAR_NAME, record_index, v)
100            .map_err(err),
101        NDDataBuffer::U8(v) => writer
102            .write_record_u8(VAR_NAME, record_index, v)
103            .map_err(err),
104        NDDataBuffer::I16(v) => writer
105            .write_record_i16(VAR_NAME, record_index, v)
106            .map_err(err),
107        NDDataBuffer::U16(v) => {
108            let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
109            writer
110                .write_record_i16(VAR_NAME, record_index, &reinterp)
111                .map_err(err)
112        }
113        NDDataBuffer::I32(v) => writer
114            .write_record_i32(VAR_NAME, record_index, v)
115            .map_err(err),
116        NDDataBuffer::U32(v) => {
117            let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
118            writer
119                .write_record_i32(VAR_NAME, record_index, &reinterp)
120                .map_err(err)
121        }
122        NDDataBuffer::F32(v) => writer
123            .write_record_f32(VAR_NAME, record_index, v)
124            .map_err(err),
125        NDDataBuffer::F64(v) => writer
126            .write_record_f64(VAR_NAME, record_index, v)
127            .map_err(err),
128        _ => Err(ADError::UnsupportedConversion(
129            "NetCDF-3 does not support 64-bit integer types".into(),
130        )),
131    }
132}
133
134impl NDFileWriter for NetcdfWriter {
135    fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
136        self.current_path = Some(path.to_path_buf());
137        self.frames.clear();
138        Ok(())
139    }
140
141    fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
142        // Validate data type early
143        nc_data_type(array.data.data_type())?;
144
145        let dims: Vec<usize> = array.dims.iter().map(|d| d.size).collect();
146        let attrs: Vec<(String, String)> = array
147            .attributes
148            .iter()
149            .map(|a| (a.name.clone(), a.value.as_string()))
150            .collect();
151
152        self.frames.push(FrameData {
153            dims,
154            data: array.data.clone(),
155            data_type: array.data.data_type(),
156            attrs,
157        });
158        Ok(())
159    }
160
161    fn close_file(&mut self) -> ADResult<()> {
162        let path = match self.current_path.take() {
163            Some(p) => p,
164            None => return Ok(()),
165        };
166
167        if self.frames.is_empty() {
168            return Ok(());
169        }
170
171        let map_def = |e: netcdf3::error::InvalidDataSet| {
172            ADError::UnsupportedConversion(format!("NetCDF definition error: {:?}", e))
173        };
174        let map_write = |e: netcdf3::error::WriteError| {
175            ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
176        };
177
178        let first = &self.frames[0];
179        let nc_dt = nc_data_type(first.data_type)?;
180        let multi = self.frames.len() > 1;
181
182        // Build DataSet definition
183        let mut ds = DataSet::new();
184
185        // Fixed dimensions: dim0, dim1, ...
186        let mut dim_names: Vec<String> = Vec::new();
187        for (i, &size) in first.dims.iter().enumerate() {
188            let name = format!("dim{}", i);
189            ds.add_fixed_dim(&name, size).map_err(map_def)?;
190            dim_names.push(name);
191        }
192
193        // Variable dimensions list
194        let var_dims: Vec<String> = if multi {
195            // Unlimited dimension first for record variables
196            ds.set_unlimited_dim(DIM_UNLIMITED, self.frames.len())
197                .map_err(map_def)?;
198            let mut v = vec![DIM_UNLIMITED.to_string()];
199            v.extend(dim_names.iter().cloned());
200            v
201        } else {
202            dim_names.clone()
203        };
204
205        let var_dim_refs: Vec<&str> = var_dims.iter().map(|s| s.as_str()).collect();
206        ds.add_var(VAR_NAME, &var_dim_refs, nc_dt)
207            .map_err(map_def)?;
208
209        // Store NDArray attributes as variable attributes on array_data
210        // Merge attributes from all frames (first frame wins on duplicates)
211        let mut seen_attrs = std::collections::HashSet::new();
212        for frame in &self.frames {
213            for (name, value) in &frame.attrs {
214                if seen_attrs.insert(name.clone()) {
215                    let _ = ds.add_var_attr_string(VAR_NAME, name, value);
216                }
217            }
218        }
219
220        // Global attributes
221        ds.add_global_attr_i32("uniqueId", vec![0])
222            .map_err(map_def)?;
223        ds.add_global_attr_i32("dataType", vec![first.data_type as i32])
224            .map_err(map_def)?;
225        ds.add_global_attr_i32("numArrays", vec![self.frames.len() as i32])
226            .map_err(map_def)?;
227
228        // Write
229        let mut writer = FileWriter::open(&path).map_err(map_write)?;
230        writer
231            .set_def(&ds, Version::Classic, 0)
232            .map_err(map_write)?;
233
234        if multi {
235            for (i, frame) in self.frames.iter().enumerate() {
236                write_record_data(&mut writer, i, &frame.data)?;
237            }
238        } else {
239            write_var_data(&mut writer, &self.frames[0].data)?;
240        }
241
242        writer.close().map_err(map_write)?;
243        self.frames.clear();
244        Ok(())
245    }
246
247    fn read_file(&mut self) -> ADResult<NDArray> {
248        let path = self
249            .current_path
250            .as_ref()
251            .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
252
253        let map_read = |e: netcdf3::error::ReadError| {
254            ADError::UnsupportedConversion(format!("NetCDF read error: {:?}", e))
255        };
256
257        let mut reader = FileReader::open(path).map_err(map_read)?;
258
259        // Extract metadata from data_set() before any mutable read calls
260        let (is_record, dims, original_type_ordinal) = {
261            let ds = reader.data_set();
262            let var = ds.get_var(VAR_NAME).ok_or_else(|| {
263                ADError::UnsupportedConversion(format!(
264                    "variable '{}' not found in NetCDF file",
265                    VAR_NAME
266                ))
267            })?;
268
269            let is_record = ds.is_record_var(VAR_NAME).unwrap_or(false);
270
271            let var_dims_rc = var.get_dims();
272            let mut dims: Vec<NDDimension> = Vec::new();
273            for d in &var_dims_rc {
274                if d.is_unlimited() {
275                    continue;
276                }
277                dims.push(NDDimension::new(d.size()));
278            }
279
280            let original_type_ordinal = ds
281                .get_global_attr_i32("dataType")
282                .and_then(|slice| slice.first().copied());
283
284            (is_record, dims, original_type_ordinal)
285        };
286
287        // Read first frame (record 0 if record variable, else full var)
288        let data_vec = if is_record {
289            reader.read_record(VAR_NAME, 0).map_err(map_read)?
290        } else {
291            reader.read_var(VAR_NAME).map_err(map_read)?
292        };
293
294        let (nd_type, buf) = match data_vec {
295            netcdf3::DataVector::I8(v) => (NDDataType::Int8, NDDataBuffer::I8(v)),
296            netcdf3::DataVector::U8(v) => (NDDataType::UInt8, NDDataBuffer::U8(v)),
297            netcdf3::DataVector::I16(v) => (NDDataType::Int16, NDDataBuffer::I16(v)),
298            netcdf3::DataVector::I32(v) => (NDDataType::Int32, NDDataBuffer::I32(v)),
299            netcdf3::DataVector::F32(v) => (NDDataType::Float32, NDDataBuffer::F32(v)),
300            netcdf3::DataVector::F64(v) => (NDDataType::Float64, NDDataBuffer::F64(v)),
301        };
302
303        // Check global attr "dataType" to recover original NDDataType
304        let actual_type = original_type_ordinal
305            .and_then(|v| NDDataType::from_ordinal(v as u8))
306            .unwrap_or(nd_type);
307
308        // Re-interpret if the original type was unsigned and stored as signed
309        let buf = match (actual_type, buf) {
310            (NDDataType::UInt16, NDDataBuffer::I16(v)) => {
311                NDDataBuffer::U16(v.into_iter().map(|x| x as u16).collect())
312            }
313            (NDDataType::UInt32, NDDataBuffer::I32(v)) => {
314                NDDataBuffer::U32(v.into_iter().map(|x| x as u32).collect())
315            }
316            (_, buf) => buf,
317        };
318
319        let mut arr = NDArray::new(dims, actual_type);
320        arr.data = buf;
321        Ok(arr)
322    }
323
324    fn supports_multiple_arrays(&self) -> bool {
325        true
326    }
327}
328
329/// NetCDF file processor wrapping NDPluginFileBase + NetcdfWriter.
330pub struct NetcdfFileProcessor {
331    ctrl: FilePluginController<NetcdfWriter>,
332}
333
334impl NetcdfFileProcessor {
335    pub fn new() -> Self {
336        Self {
337            ctrl: FilePluginController::new(NetcdfWriter::new()),
338        }
339    }
340}
341
342impl Default for NetcdfFileProcessor {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348impl NDPluginProcess for NetcdfFileProcessor {
349    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
350        self.ctrl.process_array(array)
351    }
352
353    fn plugin_type(&self) -> &str {
354        "NDFileNetCDF"
355    }
356
357    fn register_params(
358        &mut self,
359        base: &mut asyn_rs::port::PortDriverBase,
360    ) -> asyn_rs::error::AsynResult<()> {
361        self.ctrl.register_params(base)
362    }
363
364    fn on_param_change(
365        &mut self,
366        reason: usize,
367        params: &PluginParamSnapshot,
368    ) -> ParamChangeResult {
369        self.ctrl.on_param_change(reason, params)
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
    use std::sync::atomic::{AtomicU32, Ordering};

    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Unique temporary `.nc` path, deleted on drop so files are cleaned up
    /// even when an assertion panics mid-test.
    struct TempFile(PathBuf);

    impl TempFile {
        /// The process id separates concurrently running test binaries; the
        /// counter separates parallel tests within one binary.
        fn new(prefix: &str) -> Self {
            let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
            TempFile(std::env::temp_dir().join(format!(
                "adcore_test_{}_{}_{}.nc",
                prefix,
                std::process::id(),
                n
            )))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// A 4x4 array of `data_type`, as created by `NDArray::new`.
    fn image_4x4(data_type: NDDataType) -> NDArray {
        NDArray::new(vec![NDDimension::new(4), NDDimension::new(4)], data_type)
    }

    /// Write `arr` as the only frame of a single-mode file at `path`.
    fn write_single(path: &Path, arr: &NDArray) {
        let mut writer = NetcdfWriter::new();
        writer.open_file(path, NDFileMode::Single, arr).unwrap();
        writer.write_file(arr).unwrap();
        writer.close_file().unwrap();
    }

    /// Read the first frame of the file at `path`.
    fn read_first_frame(path: &Path) -> NDArray {
        let mut reader = NetcdfWriter::new();
        reader.current_path = Some(path.to_path_buf());
        reader.read_file().unwrap()
    }

    fn dim_sizes(arr: &NDArray) -> Vec<usize> {
        arr.dims.iter().map(|d| d.size).collect()
    }

    #[test]
    fn test_write_u8_mono() {
        let tmp = TempFile::new("nc_u8");
        let mut arr = image_4x4(NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = i as u8;
            }
        }

        write_single(tmp.path(), &arr);

        // Verify the file exists and has NetCDF magic bytes: "CDF\x01" or "CDF\x02"
        let data = std::fs::read(tmp.path()).unwrap();
        assert!(data.len() > 16);
        assert_eq!(&data[0..3], b"CDF", "Expected NetCDF magic bytes");
    }

    #[test]
    fn test_write_u16() {
        let tmp = TempFile::new("nc_u16");
        let mut arr = image_4x4(NDDataType::UInt16);
        if let NDDataBuffer::U16(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i * 1000) as u16;
            }
        }

        write_single(tmp.path(), &arr);

        let data = std::fs::read(tmp.path()).unwrap();
        assert!(data.len() > 32);
        assert_eq!(&data[0..3], b"CDF");
    }

    #[test]
    fn test_roundtrip_u8() {
        let tmp = TempFile::new("nc_rt_u8");
        let mut arr = image_4x4(NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i * 10) as u8;
            }
        }

        write_single(tmp.path(), &arr);

        let read_back = read_first_frame(tmp.path());
        assert_eq!(dim_sizes(&read_back), vec![4, 4]);
        if let (NDDataBuffer::U8(orig), NDDataBuffer::U8(read)) = (&arr.data, &read_back.data) {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    #[test]
    fn test_roundtrip_i16() {
        let tmp = TempFile::new("nc_rt_i16");
        let mut arr = image_4x4(NDDataType::Int16);
        if let NDDataBuffer::I16(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i as i16) * 100 - 500;
            }
        }

        write_single(tmp.path(), &arr);

        let read_back = read_first_frame(tmp.path());
        if let (NDDataBuffer::I16(orig), NDDataBuffer::I16(read)) = (&arr.data, &read_back.data) {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let tmp = TempFile::new("nc_rt_f32");
        let mut arr = image_4x4(NDDataType::Float32);
        if let NDDataBuffer::F32(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = i as f32 * 0.5;
            }
        }

        write_single(tmp.path(), &arr);

        let read_back = read_first_frame(tmp.path());
        if let (NDDataBuffer::F32(orig), NDDataBuffer::F32(read)) = (&arr.data, &read_back.data) {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    /// Values above i16::MAX exercise the unsigned->signed->unsigned
    /// reinterpretation driven by the "dataType" global attribute.
    #[test]
    fn test_roundtrip_u16_high_values() {
        let tmp = TempFile::new("nc_rt_u16");
        let mut arr = image_4x4(NDDataType::UInt16);
        if let NDDataBuffer::U16(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = u16::MAX - (i as u16) * 1000;
            }
        }

        write_single(tmp.path(), &arr);

        let read_back = read_first_frame(tmp.path());
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U16(orig), NDDataBuffer::U16(read)) => assert_eq!(orig, read),
            _ => panic!("expected U16 data on roundtrip"),
        }
    }

    /// Values above i32::MAX exercise the U32 reinterpretation path.
    #[test]
    fn test_roundtrip_u32_high_values() {
        let tmp = TempFile::new("nc_rt_u32");
        let mut arr = image_4x4(NDDataType::UInt32);
        if let NDDataBuffer::U32(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = u32::MAX - (i as u32) * 7;
            }
        }

        write_single(tmp.path(), &arr);

        let read_back = read_first_frame(tmp.path());
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U32(orig), NDDataBuffer::U32(read)) => assert_eq!(orig, read),
            _ => panic!("expected U32 data on roundtrip"),
        }
    }

    /// A non-square shape must come back with the same dimension order.
    #[test]
    fn test_roundtrip_non_square() {
        let tmp = TempFile::new("nc_rt_rect");
        let mut arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(2)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = i as u8;
            }
        }

        write_single(tmp.path(), &arr);

        let read_back = read_first_frame(tmp.path());
        assert_eq!(dim_sizes(&read_back), vec![4, 2]);
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U8(orig), NDDataBuffer::U8(read)) => assert_eq!(orig, read),
            _ => panic!("expected U8 data on roundtrip"),
        }
    }

    #[test]
    fn test_multiple_frames() {
        let tmp = TempFile::new("nc_multi");
        let frames: Vec<NDArray> = [0u8, 100, 200]
            .iter()
            .map(|&offset| {
                let mut arr = image_4x4(NDDataType::UInt8);
                if let NDDataBuffer::U8(v) = &mut arr.data {
                    for (i, x) in v.iter_mut().enumerate() {
                        *x = (i as u8).wrapping_add(offset);
                    }
                }
                arr
            })
            .collect();

        let mut writer = NetcdfWriter::new();
        writer
            .open_file(tmp.path(), NDFileMode::Stream, &frames[0])
            .unwrap();
        for arr in &frames {
            writer.write_file(arr).unwrap();
        }
        writer.close_file().unwrap();

        // read_file returns record 0; the unlimited dimension is not part of
        // the returned shape.
        let read_back = read_first_frame(tmp.path());
        assert_eq!(dim_sizes(&read_back), vec![4, 4]);
        if let NDDataBuffer::U8(v) = &read_back.data {
            assert_eq!(v.len(), 16);
            for (i, &x) in v.iter().enumerate() {
                assert_eq!(x, i as u8, "mismatch at index {}", i);
            }
        } else {
            panic!("expected U8 data");
        }
    }

    #[test]
    fn test_attributes_stored() {
        let tmp = TempFile::new("nc_attrs");
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.attributes.add(NDAttribute {
            name: "exposure".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Float64(0.5),
        });
        arr.attributes.add(NDAttribute {
            name: "gain".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Int32(42),
        });

        write_single(tmp.path(), &arr);

        // Verify attributes directly via FileReader.
        let reader = FileReader::open(tmp.path()).unwrap();
        let ds = reader.data_set();

        let exposure = ds.get_var_attr_as_string(VAR_NAME, "exposure");
        assert_eq!(exposure, Some("0.5".to_string()));

        let gain = ds.get_var_attr_as_string(VAR_NAME, "gain");
        assert_eq!(gain, Some("42".to_string()));

        // Close the reader before `tmp` is dropped and removes the file.
        drop(reader);
    }
}