1use std::path::{Path, PathBuf};
2
3use ad_core_rs::error::{ADError, ADResult};
4use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
5use ad_core_rs::ndarray_pool::NDArrayPool;
6use ad_core_rs::plugin::file_base::{NDFileMode, NDFileWriter};
7use ad_core_rs::plugin::file_controller::FilePluginController;
8use ad_core_rs::plugin::runtime::{
9 NDPluginProcess, ParamChangeResult, PluginParamSnapshot, ProcessResult,
10};
11
12use netcdf3::{DataSet, FileReader, FileWriter, Version};
13
/// Name of the NetCDF variable that holds the array data (one record per
/// array when a file contains several arrays).
const VAR_NAME: &str = "array_data";
/// Name of the unlimited (record) dimension, declared only when a file holds
/// more than one array.
const DIM_UNLIMITED: &str = "numArrays";
16
/// Per-dimension metadata copied from an `NDDimension`, written out as the
/// `dimSize`/`dimOffset`/`dimBinning`/`dimReverse` global attributes.
///
/// All fields are plain values, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct DimMeta {
    /// Number of elements along this dimension.
    size: usize,
    /// Offset of this dimension, as reported by the source `NDDimension`.
    offset: usize,
    /// Binning factor, as reported by the source `NDDimension`.
    binning: usize,
    /// Whether the source `NDDimension` was flagged as reversed.
    reverse: bool,
}
24
25struct FrameData {
27 dims: Vec<usize>,
28 dim_meta: Vec<DimMeta>,
29 data: NDDataBuffer,
30 data_type: NDDataType,
31 attrs: Vec<(String, String)>,
32 unique_id: i32,
33 time_stamp: f64,
34}
35
36pub struct NetcdfWriter {
44 current_path: Option<PathBuf>,
45 frames: Vec<FrameData>,
46}
47
48impl NetcdfWriter {
49 pub fn new() -> Self {
50 Self {
51 current_path: None,
52 frames: Vec::new(),
53 }
54 }
55}
56
57fn nc_data_type(dt: NDDataType) -> ADResult<netcdf3::DataType> {
60 match dt {
61 NDDataType::Int8 => Ok(netcdf3::DataType::I8),
62 NDDataType::UInt8 => Ok(netcdf3::DataType::U8),
63 NDDataType::Int16 | NDDataType::UInt16 => Ok(netcdf3::DataType::I16),
64 NDDataType::Int32 | NDDataType::UInt32 => Ok(netcdf3::DataType::I32),
65 NDDataType::Float32 => Ok(netcdf3::DataType::F32),
66 NDDataType::Float64 => Ok(netcdf3::DataType::F64),
67 NDDataType::Int64 | NDDataType::UInt64 => Ok(netcdf3::DataType::F64),
68 }
69}
70
71fn write_var_data(writer: &mut FileWriter, data: &NDDataBuffer) -> ADResult<()> {
73 let err = |e: netcdf3::error::WriteError| {
74 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
75 };
76 match data {
77 NDDataBuffer::I8(v) => writer.write_var_i8(VAR_NAME, v).map_err(err),
78 NDDataBuffer::U8(v) => writer.write_var_u8(VAR_NAME, v).map_err(err),
79 NDDataBuffer::I16(v) => writer.write_var_i16(VAR_NAME, v).map_err(err),
80 NDDataBuffer::U16(v) => {
81 let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
82 writer.write_var_i16(VAR_NAME, &reinterp).map_err(err)
83 }
84 NDDataBuffer::I32(v) => writer.write_var_i32(VAR_NAME, v).map_err(err),
85 NDDataBuffer::U32(v) => {
86 let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
87 writer.write_var_i32(VAR_NAME, &reinterp).map_err(err)
88 }
89 NDDataBuffer::F32(v) => writer.write_var_f32(VAR_NAME, v).map_err(err),
90 NDDataBuffer::F64(v) => writer.write_var_f64(VAR_NAME, v).map_err(err),
91 NDDataBuffer::I64(v) => {
92 let reinterp: Vec<f64> = v.iter().map(|&x| x as f64).collect();
93 writer.write_var_f64(VAR_NAME, &reinterp).map_err(err)
94 }
95 NDDataBuffer::U64(v) => {
96 let reinterp: Vec<f64> = v.iter().map(|&x| x as f64).collect();
97 writer.write_var_f64(VAR_NAME, &reinterp).map_err(err)
98 }
99 }
100}
101
102fn write_record_data(
104 writer: &mut FileWriter,
105 record_index: usize,
106 data: &NDDataBuffer,
107) -> ADResult<()> {
108 let err = |e: netcdf3::error::WriteError| {
109 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
110 };
111 match data {
112 NDDataBuffer::I8(v) => writer
113 .write_record_i8(VAR_NAME, record_index, v)
114 .map_err(err),
115 NDDataBuffer::U8(v) => writer
116 .write_record_u8(VAR_NAME, record_index, v)
117 .map_err(err),
118 NDDataBuffer::I16(v) => writer
119 .write_record_i16(VAR_NAME, record_index, v)
120 .map_err(err),
121 NDDataBuffer::U16(v) => {
122 let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
123 writer
124 .write_record_i16(VAR_NAME, record_index, &reinterp)
125 .map_err(err)
126 }
127 NDDataBuffer::I32(v) => writer
128 .write_record_i32(VAR_NAME, record_index, v)
129 .map_err(err),
130 NDDataBuffer::U32(v) => {
131 let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
132 writer
133 .write_record_i32(VAR_NAME, record_index, &reinterp)
134 .map_err(err)
135 }
136 NDDataBuffer::F32(v) => writer
137 .write_record_f32(VAR_NAME, record_index, v)
138 .map_err(err),
139 NDDataBuffer::F64(v) => writer
140 .write_record_f64(VAR_NAME, record_index, v)
141 .map_err(err),
142 NDDataBuffer::I64(v) => {
143 let reinterp: Vec<f64> = v.iter().map(|&x| x as f64).collect();
144 writer
145 .write_record_f64(VAR_NAME, record_index, &reinterp)
146 .map_err(err)
147 }
148 NDDataBuffer::U64(v) => {
149 let reinterp: Vec<f64> = v.iter().map(|&x| x as f64).collect();
150 writer
151 .write_record_f64(VAR_NAME, record_index, &reinterp)
152 .map_err(err)
153 }
154 }
155}
156
157impl NDFileWriter for NetcdfWriter {
158 fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
159 self.current_path = Some(path.to_path_buf());
160 self.frames.clear();
161 Ok(())
162 }
163
164 fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
165 nc_data_type(array.data.data_type())?;
167
168 let dims: Vec<usize> = array.dims.iter().map(|d| d.size).collect();
169 let dim_meta: Vec<DimMeta> = array
170 .dims
171 .iter()
172 .map(|d| DimMeta {
173 size: d.size,
174 offset: d.offset,
175 binning: d.binning,
176 reverse: d.reverse,
177 })
178 .collect();
179 let attrs: Vec<(String, String)> = array
180 .attributes
181 .iter()
182 .map(|a| (a.name.clone(), a.value.as_string()))
183 .collect();
184
185 self.frames.push(FrameData {
186 dims,
187 dim_meta,
188 data: array.data.clone(),
189 data_type: array.data.data_type(),
190 attrs,
191 unique_id: array.unique_id,
192 time_stamp: array.time_stamp,
193 });
194 Ok(())
195 }
196
197 fn close_file(&mut self) -> ADResult<()> {
198 let path = match self.current_path.take() {
199 Some(p) => p,
200 None => return Ok(()),
201 };
202
203 if self.frames.is_empty() {
204 return Ok(());
205 }
206
207 let map_def = |e: netcdf3::error::InvalidDataSet| {
208 ADError::UnsupportedConversion(format!("NetCDF definition error: {:?}", e))
209 };
210 let map_write = |e: netcdf3::error::WriteError| {
211 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
212 };
213
214 let first = &self.frames[0];
215 let nc_dt = nc_data_type(first.data_type)?;
216 let multi = self.frames.len() > 1;
217
218 let mut ds = DataSet::new();
220
221 let ndims = first.dims.len();
223 let mut dim_names: Vec<String> = Vec::new();
224 for i in 0..ndims {
225 let dim_idx = ndims - 1 - i;
226 let name = format!("dim{}", i);
227 ds.add_fixed_dim(&name, first.dims[dim_idx])
228 .map_err(map_def)?;
229 dim_names.push(name);
230 }
231
232 let var_dims: Vec<String> = if multi {
234 ds.set_unlimited_dim(DIM_UNLIMITED, self.frames.len())
236 .map_err(map_def)?;
237 let mut v = vec![DIM_UNLIMITED.to_string()];
238 v.extend(dim_names.iter().cloned());
239 v
240 } else {
241 dim_names.clone()
242 };
243
244 let var_dim_refs: Vec<&str> = var_dims.iter().map(|s| s.as_str()).collect();
245 ds.add_var(VAR_NAME, &var_dim_refs, nc_dt)
246 .map_err(map_def)?;
247
248 let mut seen_attrs = std::collections::HashSet::new();
251 for frame in &self.frames {
252 for (name, value) in &frame.attrs {
253 if seen_attrs.insert(name.clone()) {
254 let _ = ds.add_var_attr_string(VAR_NAME, name, value);
255 }
256 }
257 }
258
259 if multi {
261 ds.add_var("uniqueId", &[DIM_UNLIMITED], netcdf3::DataType::I32)
262 .map_err(map_def)?;
263 ds.add_var("timeStamp", &[DIM_UNLIMITED], netcdf3::DataType::F64)
264 .map_err(map_def)?;
265 }
266
267 ds.add_global_attr_i32("uniqueId", vec![first.unique_id])
269 .map_err(map_def)?;
270 ds.add_global_attr_i32("dataType", vec![first.data_type as i32])
271 .map_err(map_def)?;
272 ds.add_global_attr_i32("numArrays", vec![self.frames.len() as i32])
273 .map_err(map_def)?;
274
275 ds.add_global_attr_i32("numArrayDims", vec![ndims as i32])
277 .map_err(map_def)?;
278 let dim_size: Vec<i32> = first.dim_meta.iter().map(|d| d.size as i32).collect();
279 ds.add_global_attr_i32("dimSize", dim_size)
280 .map_err(map_def)?;
281 let dim_offset: Vec<i32> = first.dim_meta.iter().map(|d| d.offset as i32).collect();
282 ds.add_global_attr_i32("dimOffset", dim_offset)
283 .map_err(map_def)?;
284 let dim_binning: Vec<i32> = first.dim_meta.iter().map(|d| d.binning as i32).collect();
285 ds.add_global_attr_i32("dimBinning", dim_binning)
286 .map_err(map_def)?;
287 let dim_reverse: Vec<i32> = first
288 .dim_meta
289 .iter()
290 .map(|d| if d.reverse { 1 } else { 0 })
291 .collect();
292 ds.add_global_attr_i32("dimReverse", dim_reverse)
293 .map_err(map_def)?;
294
295 let mut writer = FileWriter::open(&path).map_err(map_write)?;
297 writer
298 .set_def(&ds, Version::Classic, 0)
299 .map_err(map_write)?;
300
301 if multi {
302 for (i, frame) in self.frames.iter().enumerate() {
303 write_record_data(&mut writer, i, &frame.data)?;
304 writer
305 .write_record_i32("uniqueId", i, &[frame.unique_id])
306 .map_err(map_write)?;
307 writer
308 .write_record_f64("timeStamp", i, &[frame.time_stamp])
309 .map_err(map_write)?;
310 }
311 } else {
312 write_var_data(&mut writer, &self.frames[0].data)?;
313 }
314
315 writer.close().map_err(map_write)?;
316 self.frames.clear();
317 Ok(())
318 }
319
320 fn read_file(&mut self) -> ADResult<NDArray> {
321 let path = self
322 .current_path
323 .as_ref()
324 .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
325
326 let map_read = |e: netcdf3::error::ReadError| {
327 ADError::UnsupportedConversion(format!("NetCDF read error: {:?}", e))
328 };
329
330 let mut reader = FileReader::open(path).map_err(map_read)?;
331
332 let (is_record, dims, original_type_ordinal) = {
334 let ds = reader.data_set();
335 let var = ds.get_var(VAR_NAME).ok_or_else(|| {
336 ADError::UnsupportedConversion(format!(
337 "variable '{}' not found in NetCDF file",
338 VAR_NAME
339 ))
340 })?;
341
342 let is_record = ds.is_record_var(VAR_NAME).unwrap_or(false);
343
344 let var_dims_rc = var.get_dims();
345 let mut dims: Vec<NDDimension> = Vec::new();
346 for d in &var_dims_rc {
347 if d.is_unlimited() {
348 continue;
349 }
350 dims.push(NDDimension::new(d.size()));
351 }
352
353 let original_type_ordinal = ds
354 .get_global_attr_i32("dataType")
355 .and_then(|slice| slice.first().copied());
356
357 (is_record, dims, original_type_ordinal)
358 };
359
360 let data_vec = if is_record {
362 reader.read_record(VAR_NAME, 0).map_err(map_read)?
363 } else {
364 reader.read_var(VAR_NAME).map_err(map_read)?
365 };
366
367 let (nd_type, buf) = match data_vec {
368 netcdf3::DataVector::I8(v) => (NDDataType::Int8, NDDataBuffer::I8(v)),
369 netcdf3::DataVector::U8(v) => (NDDataType::UInt8, NDDataBuffer::U8(v)),
370 netcdf3::DataVector::I16(v) => (NDDataType::Int16, NDDataBuffer::I16(v)),
371 netcdf3::DataVector::I32(v) => (NDDataType::Int32, NDDataBuffer::I32(v)),
372 netcdf3::DataVector::F32(v) => (NDDataType::Float32, NDDataBuffer::F32(v)),
373 netcdf3::DataVector::F64(v) => (NDDataType::Float64, NDDataBuffer::F64(v)),
374 };
375
376 let actual_type = original_type_ordinal
378 .and_then(|v| NDDataType::from_ordinal(v as u8))
379 .unwrap_or(nd_type);
380
381 let buf = match (actual_type, buf) {
383 (NDDataType::UInt16, NDDataBuffer::I16(v)) => {
384 NDDataBuffer::U16(v.into_iter().map(|x| x as u16).collect())
385 }
386 (NDDataType::UInt32, NDDataBuffer::I32(v)) => {
387 NDDataBuffer::U32(v.into_iter().map(|x| x as u32).collect())
388 }
389 (_, buf) => buf,
390 };
391
392 let mut arr = NDArray::new(dims, actual_type);
393 arr.data = buf;
394 Ok(arr)
395 }
396
397 fn supports_multiple_arrays(&self) -> bool {
398 true
399 }
400}
401
402pub struct NetcdfFileProcessor {
404 ctrl: FilePluginController<NetcdfWriter>,
405}
406
407impl NetcdfFileProcessor {
408 pub fn new() -> Self {
409 Self {
410 ctrl: FilePluginController::new(NetcdfWriter::new()),
411 }
412 }
413}
414
415impl Default for NetcdfFileProcessor {
416 fn default() -> Self {
417 Self::new()
418 }
419}
420
421impl NDPluginProcess for NetcdfFileProcessor {
422 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
423 self.ctrl.process_array(array)
424 }
425
426 fn plugin_type(&self) -> &str {
427 "NDFileNetCDF"
428 }
429
430 fn register_params(
431 &mut self,
432 base: &mut asyn_rs::port::PortDriverBase,
433 ) -> asyn_rs::error::AsynResult<()> {
434 self.ctrl.register_params(base)
435 }
436
437 fn on_param_change(
438 &mut self,
439 reason: usize,
440 params: &PluginParamSnapshot,
441 ) -> ParamChangeResult {
442 self.ctrl.on_param_change(reason, params)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Counter that keeps temp file names distinct across tests in one process.
    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Returns a fresh `.nc` path in the system temp directory.
    fn temp_path(prefix: &str) -> PathBuf {
        let seq = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let file_name = format!("adcore_test_{}_{}.nc", prefix, seq);
        std::env::temp_dir().join(file_name)
    }

    /// Creates a 4x4 array of the given element type.
    fn square4(data_type: NDDataType) -> NDArray {
        NDArray::new(vec![NDDimension::new(4), NDDimension::new(4)], data_type)
    }

    /// 4x4 UInt8 array whose element `i` is `i + offset` (wrapping).
    fn u8_frame(offset: u8) -> NDArray {
        let mut arr = square4(NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, slot) in v[..16].iter_mut().enumerate() {
                *slot = (i as u8).wrapping_add(offset);
            }
        }
        arr
    }

    /// Writes `arr` as the only array of a single-mode file at `path`.
    fn write_single(writer: &mut NetcdfWriter, path: &Path, arr: &NDArray) {
        writer.open_file(path, NDFileMode::Single, arr).unwrap();
        writer.write_file(arr).unwrap();
        writer.close_file().unwrap();
    }

    /// Re-points `writer` at `path` (which `close_file` cleared) and reads it.
    fn read_back(writer: &mut NetcdfWriter, path: &Path) -> NDArray {
        writer.current_path = Some(path.to_path_buf());
        writer.read_file().unwrap()
    }

    #[test]
    fn test_write_u8_mono() {
        let path = temp_path("nc_u8");
        let arr = u8_frame(0);
        write_single(&mut NetcdfWriter::new(), &path, &arr);

        let bytes = std::fs::read(&path).unwrap();
        assert!(bytes.len() > 16);
        assert_eq!(&bytes[0..3], b"CDF", "Expected NetCDF magic bytes");

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_write_u16() {
        let path = temp_path("nc_u16");
        let mut arr = square4(NDDataType::UInt16);
        if let NDDataBuffer::U16(v) = &mut arr.data {
            for (i, slot) in v[..16].iter_mut().enumerate() {
                *slot = (i * 1000) as u16;
            }
        }
        write_single(&mut NetcdfWriter::new(), &path, &arr);

        let bytes = std::fs::read(&path).unwrap();
        assert!(bytes.len() > 32);
        assert_eq!(&bytes[0..3], b"CDF");

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_roundtrip_u8() {
        let path = temp_path("nc_rt_u8");
        let mut writer = NetcdfWriter::new();
        let mut arr = square4(NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, slot) in v[..16].iter_mut().enumerate() {
                *slot = (i * 10) as u8;
            }
        }
        write_single(&mut writer, &path, &arr);

        let back = read_back(&mut writer, &path);
        match (&arr.data, &back.data) {
            (NDDataBuffer::U8(orig), NDDataBuffer::U8(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_roundtrip_i16() {
        let path = temp_path("nc_rt_i16");
        let mut writer = NetcdfWriter::new();
        let mut arr = square4(NDDataType::Int16);
        if let NDDataBuffer::I16(v) = &mut arr.data {
            for (i, slot) in v[..16].iter_mut().enumerate() {
                *slot = (i as i16) * 100 - 500;
            }
        }
        write_single(&mut writer, &path, &arr);

        let back = read_back(&mut writer, &path);
        match (&arr.data, &back.data) {
            (NDDataBuffer::I16(orig), NDDataBuffer::I16(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_roundtrip_f32() {
        let path = temp_path("nc_rt_f32");
        let mut writer = NetcdfWriter::new();
        let mut arr = square4(NDDataType::Float32);
        if let NDDataBuffer::F32(v) = &mut arr.data {
            for (i, slot) in v[..16].iter_mut().enumerate() {
                *slot = i as f32 * 0.5;
            }
        }
        write_single(&mut writer, &path, &arr);

        let back = read_back(&mut writer, &path);
        match (&arr.data, &back.data) {
            (NDDataBuffer::F32(orig), NDDataBuffer::F32(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_multiple_frames() {
        let path = temp_path("nc_multi");
        let mut writer = NetcdfWriter::new();
        let frames = [u8_frame(0), u8_frame(100), u8_frame(200)];

        writer.open_file(&path, NDFileMode::Stream, &frames[0]).unwrap();
        for frame in &frames {
            writer.write_file(frame).unwrap();
        }
        writer.close_file().unwrap();

        // Only the first record is read back.
        let back = read_back(&mut writer, &path);
        match &back.data {
            NDDataBuffer::U8(v) => {
                assert_eq!(v.len(), 16);
                for (i, &x) in v.iter().enumerate() {
                    assert_eq!(x, i as u8, "mismatch at index {}", i);
                }
            }
            _ => panic!("expected U8 data"),
        }

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_attributes_stored() {
        let path = temp_path("nc_attrs");
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.attributes.add(NDAttribute {
            name: "exposure".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Float64(0.5),
        });
        arr.attributes.add(NDAttribute {
            name: "gain".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Int32(42),
        });
        write_single(&mut NetcdfWriter::new(), &path, &arr);

        let reader = FileReader::open(&path).unwrap();
        let ds = reader.data_set();
        assert_eq!(
            ds.get_var_attr_as_string(VAR_NAME, "exposure"),
            Some("0.5".to_string())
        );
        assert_eq!(
            ds.get_var_attr_as_string(VAR_NAME, "gain"),
            Some("42".to_string())
        );

        drop(reader);
        std::fs::remove_file(&path).ok();
    }
}
694}