1use std::path::{Path, PathBuf};
2
3use ad_core_rs::error::{ADError, ADResult};
4use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
5use ad_core_rs::ndarray_pool::NDArrayPool;
6use ad_core_rs::plugin::file_base::{NDFileMode, NDFileWriter};
7use ad_core_rs::plugin::file_controller::FilePluginController;
8use ad_core_rs::plugin::runtime::{
9 NDPluginProcess, ParamChangeResult, PluginParamSnapshot, ProcessResult,
10};
11
12use netcdf3::{DataSet, FileReader, FileWriter, Version};
13
14const VAR_NAME: &str = "array_data";
15const DIM_UNLIMITED: &str = "numArrays";
16
17struct FrameData {
19 dims: Vec<usize>,
20 data: NDDataBuffer,
21 data_type: NDDataType,
22 attrs: Vec<(String, String)>,
23}
24
25pub struct NetcdfWriter {
33 current_path: Option<PathBuf>,
34 frames: Vec<FrameData>,
35}
36
37impl NetcdfWriter {
38 pub fn new() -> Self {
39 Self {
40 current_path: None,
41 frames: Vec::new(),
42 }
43 }
44}
45
46fn nc_data_type(dt: NDDataType) -> ADResult<netcdf3::DataType> {
49 match dt {
50 NDDataType::Int8 => Ok(netcdf3::DataType::I8),
51 NDDataType::UInt8 => Ok(netcdf3::DataType::U8),
52 NDDataType::Int16 | NDDataType::UInt16 => Ok(netcdf3::DataType::I16),
53 NDDataType::Int32 | NDDataType::UInt32 => Ok(netcdf3::DataType::I32),
54 NDDataType::Float32 => Ok(netcdf3::DataType::F32),
55 NDDataType::Float64 => Ok(netcdf3::DataType::F64),
56 NDDataType::Int64 | NDDataType::UInt64 => Err(ADError::UnsupportedConversion(
57 "NetCDF-3 does not support 64-bit integer types".into(),
58 )),
59 }
60}
61
62fn write_var_data(writer: &mut FileWriter, data: &NDDataBuffer) -> ADResult<()> {
64 let err = |e: netcdf3::error::WriteError| {
65 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
66 };
67 match data {
68 NDDataBuffer::I8(v) => writer.write_var_i8(VAR_NAME, v).map_err(err),
69 NDDataBuffer::U8(v) => writer.write_var_u8(VAR_NAME, v).map_err(err),
70 NDDataBuffer::I16(v) => writer.write_var_i16(VAR_NAME, v).map_err(err),
71 NDDataBuffer::U16(v) => {
72 let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
73 writer.write_var_i16(VAR_NAME, &reinterp).map_err(err)
74 }
75 NDDataBuffer::I32(v) => writer.write_var_i32(VAR_NAME, v).map_err(err),
76 NDDataBuffer::U32(v) => {
77 let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
78 writer.write_var_i32(VAR_NAME, &reinterp).map_err(err)
79 }
80 NDDataBuffer::F32(v) => writer.write_var_f32(VAR_NAME, v).map_err(err),
81 NDDataBuffer::F64(v) => writer.write_var_f64(VAR_NAME, v).map_err(err),
82 _ => Err(ADError::UnsupportedConversion(
83 "NetCDF-3 does not support 64-bit integer types".into(),
84 )),
85 }
86}
87
88fn write_record_data(
90 writer: &mut FileWriter,
91 record_index: usize,
92 data: &NDDataBuffer,
93) -> ADResult<()> {
94 let err = |e: netcdf3::error::WriteError| {
95 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
96 };
97 match data {
98 NDDataBuffer::I8(v) => writer
99 .write_record_i8(VAR_NAME, record_index, v)
100 .map_err(err),
101 NDDataBuffer::U8(v) => writer
102 .write_record_u8(VAR_NAME, record_index, v)
103 .map_err(err),
104 NDDataBuffer::I16(v) => writer
105 .write_record_i16(VAR_NAME, record_index, v)
106 .map_err(err),
107 NDDataBuffer::U16(v) => {
108 let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
109 writer
110 .write_record_i16(VAR_NAME, record_index, &reinterp)
111 .map_err(err)
112 }
113 NDDataBuffer::I32(v) => writer
114 .write_record_i32(VAR_NAME, record_index, v)
115 .map_err(err),
116 NDDataBuffer::U32(v) => {
117 let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
118 writer
119 .write_record_i32(VAR_NAME, record_index, &reinterp)
120 .map_err(err)
121 }
122 NDDataBuffer::F32(v) => writer
123 .write_record_f32(VAR_NAME, record_index, v)
124 .map_err(err),
125 NDDataBuffer::F64(v) => writer
126 .write_record_f64(VAR_NAME, record_index, v)
127 .map_err(err),
128 _ => Err(ADError::UnsupportedConversion(
129 "NetCDF-3 does not support 64-bit integer types".into(),
130 )),
131 }
132}
133
134impl NDFileWriter for NetcdfWriter {
135 fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
136 self.current_path = Some(path.to_path_buf());
137 self.frames.clear();
138 Ok(())
139 }
140
141 fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
142 nc_data_type(array.data.data_type())?;
144
145 let dims: Vec<usize> = array.dims.iter().map(|d| d.size).collect();
146 let attrs: Vec<(String, String)> = array
147 .attributes
148 .iter()
149 .map(|a| (a.name.clone(), a.value.as_string()))
150 .collect();
151
152 self.frames.push(FrameData {
153 dims,
154 data: array.data.clone(),
155 data_type: array.data.data_type(),
156 attrs,
157 });
158 Ok(())
159 }
160
161 fn close_file(&mut self) -> ADResult<()> {
162 let path = match self.current_path.take() {
163 Some(p) => p,
164 None => return Ok(()),
165 };
166
167 if self.frames.is_empty() {
168 return Ok(());
169 }
170
171 let map_def = |e: netcdf3::error::InvalidDataSet| {
172 ADError::UnsupportedConversion(format!("NetCDF definition error: {:?}", e))
173 };
174 let map_write = |e: netcdf3::error::WriteError| {
175 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
176 };
177
178 let first = &self.frames[0];
179 let nc_dt = nc_data_type(first.data_type)?;
180 let multi = self.frames.len() > 1;
181
182 let mut ds = DataSet::new();
184
185 let mut dim_names: Vec<String> = Vec::new();
187 for (i, &size) in first.dims.iter().enumerate() {
188 let name = format!("dim{}", i);
189 ds.add_fixed_dim(&name, size).map_err(map_def)?;
190 dim_names.push(name);
191 }
192
193 let var_dims: Vec<String> = if multi {
195 ds.set_unlimited_dim(DIM_UNLIMITED, self.frames.len())
197 .map_err(map_def)?;
198 let mut v = vec![DIM_UNLIMITED.to_string()];
199 v.extend(dim_names.iter().cloned());
200 v
201 } else {
202 dim_names.clone()
203 };
204
205 let var_dim_refs: Vec<&str> = var_dims.iter().map(|s| s.as_str()).collect();
206 ds.add_var(VAR_NAME, &var_dim_refs, nc_dt)
207 .map_err(map_def)?;
208
209 let mut seen_attrs = std::collections::HashSet::new();
212 for frame in &self.frames {
213 for (name, value) in &frame.attrs {
214 if seen_attrs.insert(name.clone()) {
215 let _ = ds.add_var_attr_string(VAR_NAME, name, value);
216 }
217 }
218 }
219
220 ds.add_global_attr_i32("uniqueId", vec![0])
222 .map_err(map_def)?;
223 ds.add_global_attr_i32("dataType", vec![first.data_type as i32])
224 .map_err(map_def)?;
225 ds.add_global_attr_i32("numArrays", vec![self.frames.len() as i32])
226 .map_err(map_def)?;
227
228 let mut writer = FileWriter::open(&path).map_err(map_write)?;
230 writer
231 .set_def(&ds, Version::Classic, 0)
232 .map_err(map_write)?;
233
234 if multi {
235 for (i, frame) in self.frames.iter().enumerate() {
236 write_record_data(&mut writer, i, &frame.data)?;
237 }
238 } else {
239 write_var_data(&mut writer, &self.frames[0].data)?;
240 }
241
242 writer.close().map_err(map_write)?;
243 self.frames.clear();
244 Ok(())
245 }
246
247 fn read_file(&mut self) -> ADResult<NDArray> {
248 let path = self
249 .current_path
250 .as_ref()
251 .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
252
253 let map_read = |e: netcdf3::error::ReadError| {
254 ADError::UnsupportedConversion(format!("NetCDF read error: {:?}", e))
255 };
256
257 let mut reader = FileReader::open(path).map_err(map_read)?;
258
259 let (is_record, dims, original_type_ordinal) = {
261 let ds = reader.data_set();
262 let var = ds.get_var(VAR_NAME).ok_or_else(|| {
263 ADError::UnsupportedConversion(format!(
264 "variable '{}' not found in NetCDF file",
265 VAR_NAME
266 ))
267 })?;
268
269 let is_record = ds.is_record_var(VAR_NAME).unwrap_or(false);
270
271 let var_dims_rc = var.get_dims();
272 let mut dims: Vec<NDDimension> = Vec::new();
273 for d in &var_dims_rc {
274 if d.is_unlimited() {
275 continue;
276 }
277 dims.push(NDDimension::new(d.size()));
278 }
279
280 let original_type_ordinal = ds
281 .get_global_attr_i32("dataType")
282 .and_then(|slice| slice.first().copied());
283
284 (is_record, dims, original_type_ordinal)
285 };
286
287 let data_vec = if is_record {
289 reader.read_record(VAR_NAME, 0).map_err(map_read)?
290 } else {
291 reader.read_var(VAR_NAME).map_err(map_read)?
292 };
293
294 let (nd_type, buf) = match data_vec {
295 netcdf3::DataVector::I8(v) => (NDDataType::Int8, NDDataBuffer::I8(v)),
296 netcdf3::DataVector::U8(v) => (NDDataType::UInt8, NDDataBuffer::U8(v)),
297 netcdf3::DataVector::I16(v) => (NDDataType::Int16, NDDataBuffer::I16(v)),
298 netcdf3::DataVector::I32(v) => (NDDataType::Int32, NDDataBuffer::I32(v)),
299 netcdf3::DataVector::F32(v) => (NDDataType::Float32, NDDataBuffer::F32(v)),
300 netcdf3::DataVector::F64(v) => (NDDataType::Float64, NDDataBuffer::F64(v)),
301 };
302
303 let actual_type = original_type_ordinal
305 .and_then(|v| NDDataType::from_ordinal(v as u8))
306 .unwrap_or(nd_type);
307
308 let buf = match (actual_type, buf) {
310 (NDDataType::UInt16, NDDataBuffer::I16(v)) => {
311 NDDataBuffer::U16(v.into_iter().map(|x| x as u16).collect())
312 }
313 (NDDataType::UInt32, NDDataBuffer::I32(v)) => {
314 NDDataBuffer::U32(v.into_iter().map(|x| x as u32).collect())
315 }
316 (_, buf) => buf,
317 };
318
319 let mut arr = NDArray::new(dims, actual_type);
320 arr.data = buf;
321 Ok(arr)
322 }
323
324 fn supports_multiple_arrays(&self) -> bool {
325 true
326 }
327}
328
329pub struct NetcdfFileProcessor {
331 ctrl: FilePluginController<NetcdfWriter>,
332}
333
334impl NetcdfFileProcessor {
335 pub fn new() -> Self {
336 Self {
337 ctrl: FilePluginController::new(NetcdfWriter::new()),
338 }
339 }
340}
341
342impl Default for NetcdfFileProcessor {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348impl NDPluginProcess for NetcdfFileProcessor {
349 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
350 self.ctrl.process_array(array)
351 }
352
353 fn plugin_type(&self) -> &str {
354 "NDFileNetCDF"
355 }
356
357 fn register_params(
358 &mut self,
359 base: &mut asyn_rs::port::PortDriverBase,
360 ) -> asyn_rs::error::AsynResult<()> {
361 self.ctrl.register_params(base)
362 }
363
364 fn on_param_change(
365 &mut self,
366 reason: usize,
367 params: &PluginParamSnapshot,
368 ) -> ParamChangeResult {
369 self.ctrl.on_param_change(reason, params)
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
    use std::sync::atomic::{AtomicU32, Ordering};

    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Returns a temp `.nc` path unique within this process and, thanks to
    /// the pid, across concurrently running test processes.
    fn temp_path(prefix: &str) -> PathBuf {
        let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!(
            "adcore_test_{}_{}_{}.nc",
            prefix,
            std::process::id(),
            n
        ))
    }

    /// Creates an array with dims `[x, y]` of the given type.
    fn array_2d(x: usize, y: usize, data_type: NDDataType) -> NDArray {
        NDArray::new(vec![NDDimension::new(x), NDDimension::new(y)], data_type)
    }

    /// Opens `path`, writes every array in `arrays`, closes the file and
    /// returns the writer so callers can read the file back.
    fn write_arrays(path: &Path, mode: NDFileMode, arrays: &[&NDArray]) -> NetcdfWriter {
        let mut writer = NetcdfWriter::new();
        writer.open_file(path, mode, arrays[0]).unwrap();
        for arr in arrays {
            writer.write_file(arr).unwrap();
        }
        writer.close_file().unwrap();
        writer
    }

    /// Writes `arr` as a single-frame file, reads it back, and deletes the
    /// file before returning (so a failed assertion does not leak it).
    fn roundtrip(prefix: &str, arr: &NDArray) -> NDArray {
        let path = temp_path(prefix);
        let mut writer = write_arrays(&path, NDFileMode::Single, &[arr]);
        writer.current_path = Some(path.clone());
        let read_back = writer.read_file();
        std::fs::remove_file(&path).ok();
        read_back.unwrap()
    }

    fn dim_sizes(arr: &NDArray) -> Vec<usize> {
        arr.dims.iter().map(|d| d.size).collect()
    }

    #[test]
    fn test_write_u8_mono() {
        let path = temp_path("nc_u8");
        let mut arr = array_2d(4, 4, NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = i as u8;
            }
        }

        write_arrays(&path, NDFileMode::Single, &[&arr]);
        let data = std::fs::read(&path).unwrap();
        std::fs::remove_file(&path).ok();

        assert!(data.len() > 16);
        assert_eq!(&data[0..3], b"CDF", "Expected NetCDF magic bytes");
    }

    #[test]
    fn test_write_u16() {
        let path = temp_path("nc_u16");
        let mut arr = array_2d(4, 4, NDDataType::UInt16);
        if let NDDataBuffer::U16(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i * 1000) as u16;
            }
        }

        write_arrays(&path, NDFileMode::Single, &[&arr]);
        let data = std::fs::read(&path).unwrap();
        std::fs::remove_file(&path).ok();

        assert!(data.len() > 32);
        assert_eq!(&data[0..3], b"CDF");
    }

    #[test]
    fn test_roundtrip_u8() {
        let mut arr = array_2d(4, 4, NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i * 10) as u8;
            }
        }

        let read_back = roundtrip("nc_rt_u8", &arr);
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U8(orig), NDDataBuffer::U8(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }
    }

    #[test]
    fn test_roundtrip_i16() {
        let mut arr = array_2d(4, 4, NDDataType::Int16);
        if let NDDataBuffer::I16(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i as i16) * 100 - 500;
            }
        }

        let read_back = roundtrip("nc_rt_i16", &arr);
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::I16(orig), NDDataBuffer::I16(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let mut arr = array_2d(4, 4, NDDataType::Float32);
        if let NDDataBuffer::F32(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = i as f32 * 0.5;
            }
        }

        let read_back = roundtrip("nc_rt_f32", &arr);
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::F32(orig), NDDataBuffer::F32(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }
    }

    #[test]
    fn test_roundtrip_u16_non_square() {
        // Values above i16::MAX exercise the signed reinterpretation; the
        // non-square shape checks that dimension order survives the trip.
        let mut arr = array_2d(4, 2, NDDataType::UInt16);
        if let NDDataBuffer::U16(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = (i * 9000) as u16;
            }
        }

        let read_back = roundtrip("nc_rt_u16", &arr);
        assert_eq!(dim_sizes(&read_back), vec![4, 2]);
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U16(orig), NDDataBuffer::U16(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }
    }

    #[test]
    fn test_roundtrip_u32() {
        let mut arr = array_2d(3, 1, NDDataType::UInt32);
        if let NDDataBuffer::U32(v) = &mut arr.data {
            for (i, x) in v.iter_mut().enumerate() {
                *x = i as u32 * 2_000_000_000;
            }
        }

        let read_back = roundtrip("nc_rt_u32", &arr);
        assert_eq!(dim_sizes(&read_back), vec![3, 1]);
        match (&arr.data, &read_back.data) {
            (NDDataBuffer::U32(orig), NDDataBuffer::U32(read)) => assert_eq!(orig, read),
            _ => panic!("data type mismatch on roundtrip"),
        }
    }

    #[test]
    fn test_multiple_frames() {
        let path = temp_path("nc_multi");
        let offsets = [0u8, 100, 200];
        let arrays: Vec<NDArray> = offsets
            .iter()
            .map(|&off| {
                let mut arr = array_2d(4, 4, NDDataType::UInt8);
                if let NDDataBuffer::U8(v) = &mut arr.data {
                    for (i, x) in v.iter_mut().enumerate() {
                        *x = (i as u8).wrapping_add(off);
                    }
                }
                arr
            })
            .collect();
        let refs: Vec<&NDArray> = arrays.iter().collect();
        let mut writer = write_arrays(&path, NDFileMode::Stream, &refs);

        // read_file returns the first record only.
        writer.current_path = Some(path.clone());
        let read_back = writer.read_file().unwrap();
        match &read_back.data {
            NDDataBuffer::U8(v) => {
                let expected: Vec<u8> = (0..16).collect();
                assert_eq!(v, &expected);
            }
            _ => panic!("expected U8 data"),
        }

        // Every frame must be stored as its own record.
        let mut reader = FileReader::open(&path).unwrap();
        for (record, &off) in offsets.iter().enumerate() {
            match reader.read_record(VAR_NAME, record).unwrap() {
                netcdf3::DataVector::U8(v) => {
                    let expected: Vec<u8> = (0..16u8).map(|i| i.wrapping_add(off)).collect();
                    assert_eq!(v, expected, "record {} mismatch", record);
                }
                _ => panic!("expected U8 record"),
            }
        }

        drop(reader);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_attributes_stored() {
        let path = temp_path("nc_attrs");

        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.attributes.add(NDAttribute {
            name: "exposure".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Float64(0.5),
        });
        arr.attributes.add(NDAttribute {
            name: "gain".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Int32(42),
        });

        write_arrays(&path, NDFileMode::Single, &[&arr]);

        let reader = FileReader::open(&path).unwrap();
        let ds = reader.data_set();

        let exposure = ds.get_var_attr_as_string(VAR_NAME, "exposure");
        assert_eq!(exposure, Some("0.5".to_string()));

        let gain = ds.get_var_attr_as_string(VAR_NAME, "gain");
        assert_eq!(gain, Some("42".to_string()));

        drop(reader);
        std::fs::remove_file(&path).ok();
    }
}
621}