1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use ad_core::error::{ADError, ADResult};
5use ad_core::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
6use ad_core::ndarray_pool::NDArrayPool;
7use ad_core::plugin::file_base::{NDFileMode, NDFileWriter, NDPluginFileBase};
8use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
9
10use netcdf3::{DataSet, FileReader, FileWriter, Version};
11
/// Name of the NetCDF variable that holds the array data.
const VAR_NAME: &str = "array_data";
/// Name of the unlimited (record) dimension used when a file holds more than one frame.
const DIM_UNLIMITED: &str = "numArrays";
14
15struct FrameData {
17 dims: Vec<usize>,
18 data: NDDataBuffer,
19 data_type: NDDataType,
20 attrs: Vec<(String, String)>,
21}
22
23pub struct NetcdfWriter {
31 current_path: Option<PathBuf>,
32 frames: Vec<FrameData>,
33}
34
35impl NetcdfWriter {
36 pub fn new() -> Self {
37 Self {
38 current_path: None,
39 frames: Vec::new(),
40 }
41 }
42}
43
44fn nc_data_type(dt: NDDataType) -> ADResult<netcdf3::DataType> {
47 match dt {
48 NDDataType::Int8 => Ok(netcdf3::DataType::I8),
49 NDDataType::UInt8 => Ok(netcdf3::DataType::U8),
50 NDDataType::Int16 | NDDataType::UInt16 => Ok(netcdf3::DataType::I16),
51 NDDataType::Int32 | NDDataType::UInt32 => Ok(netcdf3::DataType::I32),
52 NDDataType::Float32 => Ok(netcdf3::DataType::F32),
53 NDDataType::Float64 => Ok(netcdf3::DataType::F64),
54 NDDataType::Int64 | NDDataType::UInt64 => Err(ADError::UnsupportedConversion(
55 "NetCDF-3 does not support 64-bit integer types".into(),
56 )),
57 }
58}
59
60fn write_var_data(
62 writer: &mut FileWriter,
63 data: &NDDataBuffer,
64) -> ADResult<()> {
65 let err = |e: netcdf3::error::WriteError| {
66 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
67 };
68 match data {
69 NDDataBuffer::I8(v) => writer.write_var_i8(VAR_NAME, v).map_err(err),
70 NDDataBuffer::U8(v) => writer.write_var_u8(VAR_NAME, v).map_err(err),
71 NDDataBuffer::I16(v) => writer.write_var_i16(VAR_NAME, v).map_err(err),
72 NDDataBuffer::U16(v) => {
73 let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
74 writer.write_var_i16(VAR_NAME, &reinterp).map_err(err)
75 }
76 NDDataBuffer::I32(v) => writer.write_var_i32(VAR_NAME, v).map_err(err),
77 NDDataBuffer::U32(v) => {
78 let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
79 writer.write_var_i32(VAR_NAME, &reinterp).map_err(err)
80 }
81 NDDataBuffer::F32(v) => writer.write_var_f32(VAR_NAME, v).map_err(err),
82 NDDataBuffer::F64(v) => writer.write_var_f64(VAR_NAME, v).map_err(err),
83 _ => Err(ADError::UnsupportedConversion(
84 "NetCDF-3 does not support 64-bit integer types".into(),
85 )),
86 }
87}
88
89fn write_record_data(
91 writer: &mut FileWriter,
92 record_index: usize,
93 data: &NDDataBuffer,
94) -> ADResult<()> {
95 let err = |e: netcdf3::error::WriteError| {
96 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
97 };
98 match data {
99 NDDataBuffer::I8(v) => writer.write_record_i8(VAR_NAME, record_index, v).map_err(err),
100 NDDataBuffer::U8(v) => writer.write_record_u8(VAR_NAME, record_index, v).map_err(err),
101 NDDataBuffer::I16(v) => writer.write_record_i16(VAR_NAME, record_index, v).map_err(err),
102 NDDataBuffer::U16(v) => {
103 let reinterp: Vec<i16> = v.iter().map(|&x| x as i16).collect();
104 writer.write_record_i16(VAR_NAME, record_index, &reinterp).map_err(err)
105 }
106 NDDataBuffer::I32(v) => writer.write_record_i32(VAR_NAME, record_index, v).map_err(err),
107 NDDataBuffer::U32(v) => {
108 let reinterp: Vec<i32> = v.iter().map(|&x| x as i32).collect();
109 writer.write_record_i32(VAR_NAME, record_index, &reinterp).map_err(err)
110 }
111 NDDataBuffer::F32(v) => writer.write_record_f32(VAR_NAME, record_index, v).map_err(err),
112 NDDataBuffer::F64(v) => writer.write_record_f64(VAR_NAME, record_index, v).map_err(err),
113 _ => Err(ADError::UnsupportedConversion(
114 "NetCDF-3 does not support 64-bit integer types".into(),
115 )),
116 }
117}
118
119impl NDFileWriter for NetcdfWriter {
120 fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
121 self.current_path = Some(path.to_path_buf());
122 self.frames.clear();
123 Ok(())
124 }
125
126 fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
127 nc_data_type(array.data.data_type())?;
129
130 let dims: Vec<usize> = array.dims.iter().map(|d| d.size).collect();
131 let attrs: Vec<(String, String)> = array
132 .attributes
133 .iter()
134 .map(|a| (a.name.clone(), a.value.as_string()))
135 .collect();
136
137 self.frames.push(FrameData {
138 dims,
139 data: array.data.clone(),
140 data_type: array.data.data_type(),
141 attrs,
142 });
143 Ok(())
144 }
145
146 fn close_file(&mut self) -> ADResult<()> {
147 let path = match self.current_path.take() {
148 Some(p) => p,
149 None => return Ok(()),
150 };
151
152 if self.frames.is_empty() {
153 return Ok(());
154 }
155
156 let map_def = |e: netcdf3::error::InvalidDataSet| {
157 ADError::UnsupportedConversion(format!("NetCDF definition error: {:?}", e))
158 };
159 let map_write = |e: netcdf3::error::WriteError| {
160 ADError::UnsupportedConversion(format!("NetCDF write error: {:?}", e))
161 };
162
163 let first = &self.frames[0];
164 let nc_dt = nc_data_type(first.data_type)?;
165 let multi = self.frames.len() > 1;
166
167 let mut ds = DataSet::new();
169
170 let mut dim_names: Vec<String> = Vec::new();
172 for (i, &size) in first.dims.iter().enumerate() {
173 let name = format!("dim{}", i);
174 ds.add_fixed_dim(&name, size).map_err(map_def)?;
175 dim_names.push(name);
176 }
177
178 let var_dims: Vec<String> = if multi {
180 ds.set_unlimited_dim(DIM_UNLIMITED, self.frames.len()).map_err(map_def)?;
182 let mut v = vec![DIM_UNLIMITED.to_string()];
183 v.extend(dim_names.iter().cloned());
184 v
185 } else {
186 dim_names.clone()
187 };
188
189 let var_dim_refs: Vec<&str> = var_dims.iter().map(|s| s.as_str()).collect();
190 ds.add_var(VAR_NAME, &var_dim_refs, nc_dt).map_err(map_def)?;
191
192 let mut seen_attrs = std::collections::HashSet::new();
195 for frame in &self.frames {
196 for (name, value) in &frame.attrs {
197 if seen_attrs.insert(name.clone()) {
198 let _ = ds.add_var_attr_string(VAR_NAME, name, value);
199 }
200 }
201 }
202
203 ds.add_global_attr_i32("uniqueId", vec![0]).map_err(map_def)?;
205 ds.add_global_attr_i32("dataType", vec![first.data_type as i32]).map_err(map_def)?;
206 ds.add_global_attr_i32("numArrays", vec![self.frames.len() as i32]).map_err(map_def)?;
207
208 let mut writer = FileWriter::open(&path).map_err(map_write)?;
210 writer.set_def(&ds, Version::Classic, 0).map_err(map_write)?;
211
212 if multi {
213 for (i, frame) in self.frames.iter().enumerate() {
214 write_record_data(&mut writer, i, &frame.data)?;
215 }
216 } else {
217 write_var_data(&mut writer, &self.frames[0].data)?;
218 }
219
220 writer.close().map_err(map_write)?;
221 self.frames.clear();
222 Ok(())
223 }
224
225 fn read_file(&mut self) -> ADResult<NDArray> {
226 let path = self.current_path.as_ref()
227 .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
228
229 let map_read = |e: netcdf3::error::ReadError| {
230 ADError::UnsupportedConversion(format!("NetCDF read error: {:?}", e))
231 };
232
233 let mut reader = FileReader::open(path).map_err(map_read)?;
234
235 let (is_record, dims, original_type_ordinal) = {
237 let ds = reader.data_set();
238 let var = ds.get_var(VAR_NAME)
239 .ok_or_else(|| ADError::UnsupportedConversion(
240 format!("variable '{}' not found in NetCDF file", VAR_NAME),
241 ))?;
242
243 let is_record = ds.is_record_var(VAR_NAME).unwrap_or(false);
244
245 let var_dims_rc = var.get_dims();
246 let mut dims: Vec<NDDimension> = Vec::new();
247 for d in &var_dims_rc {
248 if d.is_unlimited() {
249 continue;
250 }
251 dims.push(NDDimension::new(d.size()));
252 }
253
254 let original_type_ordinal = ds
255 .get_global_attr_i32("dataType")
256 .and_then(|slice| slice.first().copied());
257
258 (is_record, dims, original_type_ordinal)
259 };
260
261 let data_vec = if is_record {
263 reader.read_record(VAR_NAME, 0).map_err(map_read)?
264 } else {
265 reader.read_var(VAR_NAME).map_err(map_read)?
266 };
267
268 let (nd_type, buf) = match data_vec {
269 netcdf3::DataVector::I8(v) => (NDDataType::Int8, NDDataBuffer::I8(v)),
270 netcdf3::DataVector::U8(v) => (NDDataType::UInt8, NDDataBuffer::U8(v)),
271 netcdf3::DataVector::I16(v) => (NDDataType::Int16, NDDataBuffer::I16(v)),
272 netcdf3::DataVector::I32(v) => (NDDataType::Int32, NDDataBuffer::I32(v)),
273 netcdf3::DataVector::F32(v) => (NDDataType::Float32, NDDataBuffer::F32(v)),
274 netcdf3::DataVector::F64(v) => (NDDataType::Float64, NDDataBuffer::F64(v)),
275 };
276
277 let actual_type = original_type_ordinal
279 .and_then(|v| NDDataType::from_ordinal(v as u8))
280 .unwrap_or(nd_type);
281
282 let buf = match (actual_type, buf) {
284 (NDDataType::UInt16, NDDataBuffer::I16(v)) => {
285 NDDataBuffer::U16(v.into_iter().map(|x| x as u16).collect())
286 }
287 (NDDataType::UInt32, NDDataBuffer::I32(v)) => {
288 NDDataBuffer::U32(v.into_iter().map(|x| x as u32).collect())
289 }
290 (_, buf) => buf,
291 };
292
293 let mut arr = NDArray::new(dims, actual_type);
294 arr.data = buf;
295 Ok(arr)
296 }
297
298 fn supports_multiple_arrays(&self) -> bool {
299 true
300 }
301}
302
303pub struct NetcdfFileProcessor {
305 file_base: NDPluginFileBase,
306 writer: NetcdfWriter,
307}
308
309impl NetcdfFileProcessor {
310 pub fn new() -> Self {
311 Self {
312 file_base: NDPluginFileBase::new(),
313 writer: NetcdfWriter::new(),
314 }
315 }
316
317 pub fn file_base_mut(&mut self) -> &mut NDPluginFileBase {
318 &mut self.file_base
319 }
320}
321
322impl Default for NetcdfFileProcessor {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328impl NDPluginProcess for NetcdfFileProcessor {
329 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
330 let _ = self
331 .file_base
332 .process_array(Arc::new(array.clone()), &mut self.writer);
333 ProcessResult::empty()
334 }
335
336 fn plugin_type(&self) -> &str {
337 "NDFileNetCDF"
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
    use std::sync::atomic::{AtomicU32, Ordering};

    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Builds a temp-file path that is unique per test (counter) and per
    /// test process (pid), so concurrent or repeated runs never collide.
    fn temp_path(prefix: &str) -> PathBuf {
        let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!(
            "adcore_test_{}_{}_{}.nc",
            prefix,
            std::process::id(),
            n
        ))
    }

    /// Temporary file that is removed on drop, including when a test panics.
    struct TempFile(PathBuf);

    impl TempFile {
        fn new(prefix: &str) -> Self {
            Self(temp_path(prefix))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // The file may not exist if the test failed before writing it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Allocates a 4x4 array of the given element type.
    fn array_4x4(data_type: NDDataType) -> NDArray {
        NDArray::new(vec![NDDimension::new(4), NDDimension::new(4)], data_type)
    }

    /// Sets every element of `buf` to `f(index)`.
    fn fill<T>(buf: &mut [T], f: impl Fn(usize) -> T) {
        for (i, x) in buf.iter_mut().enumerate() {
            *x = f(i);
        }
    }

    /// 4x4 `UInt8` array whose element `i` is `i + offset` (wrapping).
    fn u8_ramp(offset: u8) -> NDArray {
        let mut arr = array_4x4(NDDataType::UInt8);
        match arr.data {
            NDDataBuffer::U8(ref mut v) => fill(v, |i| (i as u8).wrapping_add(offset)),
            _ => panic!("expected U8 data"),
        }
        arr
    }

    /// Runs open → write each frame → close and returns the writer.
    fn write_frames(path: &Path, mode: NDFileMode, frames: &[&NDArray]) -> NetcdfWriter {
        let mut writer = NetcdfWriter::new();
        writer.open_file(path, mode, frames[0]).unwrap();
        for frame in frames {
            writer.write_file(frame).unwrap();
        }
        writer.close_file().unwrap();
        writer
    }

    /// Points the writer back at `path` and reads the stored array.
    fn read_back(writer: &mut NetcdfWriter, path: &Path) -> NDArray {
        writer.current_path = Some(path.to_path_buf());
        writer.read_file().unwrap()
    }

    /// Checks that the file is longer than `min_len` and starts with "CDF".
    fn assert_netcdf_file(path: &Path, min_len: usize) {
        let data = std::fs::read(path).unwrap();
        assert!(data.len() > min_len);
        assert_eq!(&data[0..3], b"CDF", "Expected NetCDF magic bytes");
    }

    #[test]
    fn test_write_u8_mono() {
        let file = TempFile::new("nc_u8");
        let arr = u8_ramp(0);

        write_frames(file.path(), NDFileMode::Single, &[&arr]);

        assert_netcdf_file(file.path(), 16);
    }

    #[test]
    fn test_write_u16() {
        let file = TempFile::new("nc_u16");
        let mut arr = array_4x4(NDDataType::UInt16);
        if let NDDataBuffer::U16(ref mut v) = arr.data {
            fill(v, |i| (i * 1000) as u16);
        }

        write_frames(file.path(), NDFileMode::Single, &[&arr]);

        assert_netcdf_file(file.path(), 32);
    }

    #[test]
    fn test_roundtrip_u8() {
        let file = TempFile::new("nc_rt_u8");
        let mut arr = array_4x4(NDDataType::UInt8);
        if let NDDataBuffer::U8(ref mut v) = arr.data {
            fill(v, |i| (i * 10) as u8);
        }

        let mut writer = write_frames(file.path(), NDFileMode::Single, &[&arr]);
        let restored = read_back(&mut writer, file.path());

        if let (NDDataBuffer::U8(ref orig), NDDataBuffer::U8(ref read)) =
            (&arr.data, &restored.data)
        {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    #[test]
    fn test_roundtrip_i16() {
        let file = TempFile::new("nc_rt_i16");
        let mut arr = array_4x4(NDDataType::Int16);
        if let NDDataBuffer::I16(ref mut v) = arr.data {
            fill(v, |i| (i as i16) * 100 - 500);
        }

        let mut writer = write_frames(file.path(), NDFileMode::Single, &[&arr]);
        let restored = read_back(&mut writer, file.path());

        if let (NDDataBuffer::I16(ref orig), NDDataBuffer::I16(ref read)) =
            (&arr.data, &restored.data)
        {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    /// Values above `i16::MAX` only survive if the unsigned type is restored
    /// from the `dataType` attribute on read.
    #[test]
    fn test_roundtrip_u16() {
        let file = TempFile::new("nc_rt_u16");
        let mut arr = array_4x4(NDDataType::UInt16);
        if let NDDataBuffer::U16(ref mut v) = arr.data {
            // 15 * 4369 == 65535, so the ramp spans the full u16 range.
            fill(v, |i| (i as u16) * 4369);
        }

        let mut writer = write_frames(file.path(), NDFileMode::Single, &[&arr]);
        let restored = read_back(&mut writer, file.path());

        if let (NDDataBuffer::U16(ref orig), NDDataBuffer::U16(ref read)) =
            (&arr.data, &restored.data)
        {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let file = TempFile::new("nc_rt_f32");
        let mut arr = array_4x4(NDDataType::Float32);
        if let NDDataBuffer::F32(ref mut v) = arr.data {
            fill(v, |i| i as f32 * 0.5);
        }

        let mut writer = write_frames(file.path(), NDFileMode::Single, &[&arr]);
        let restored = read_back(&mut writer, file.path());

        if let (NDDataBuffer::F32(ref orig), NDDataBuffer::F32(ref read)) =
            (&arr.data, &restored.data)
        {
            assert_eq!(orig, read);
        } else {
            panic!("data type mismatch on roundtrip");
        }
    }

    #[test]
    fn test_multiple_frames() {
        let file = TempFile::new("nc_multi");
        let offsets = [0u8, 100, 200];
        let arr1 = u8_ramp(offsets[0]);
        let arr2 = u8_ramp(offsets[1]);
        let arr3 = u8_ramp(offsets[2]);

        let mut writer =
            write_frames(file.path(), NDFileMode::Stream, &[&arr1, &arr2, &arr3]);

        // `read_file` returns the first record only.
        let restored = read_back(&mut writer, file.path());
        if let NDDataBuffer::U8(ref v) = restored.data {
            assert_eq!(v.len(), 16);
            for (i, &value) in v.iter().enumerate() {
                assert_eq!(value, i as u8, "mismatch at index {}", i);
            }
        } else {
            panic!("expected U8 data");
        }

        // Every frame must have been stored as its own record, in order.
        let mut reader = FileReader::open(file.path()).unwrap();
        for (record, offset) in offsets.iter().enumerate() {
            match reader.read_record(VAR_NAME, record).unwrap() {
                netcdf3::DataVector::U8(v) => {
                    let expected: Vec<u8> =
                        (0..16u8).map(|i| i.wrapping_add(*offset)).collect();
                    assert_eq!(v, expected, "mismatch in record {}", record);
                }
                _ => panic!("expected U8 data"),
            }
        }
    }

    #[test]
    fn test_attributes_stored() {
        let file = TempFile::new("nc_attrs");

        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.attributes.add(NDAttribute {
            name: "exposure".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Float64(0.5),
        });
        arr.attributes.add(NDAttribute {
            name: "gain".into(),
            description: "".into(),
            source: NDAttrSource::Driver,
            value: NDAttrValue::Int32(42),
        });

        write_frames(file.path(), NDFileMode::Single, &[&arr]);

        // `reader` is declared after `file`, so it is dropped (and the file
        // handle released) before the temp file is removed.
        let reader = FileReader::open(file.path()).unwrap();
        let ds = reader.data_set();

        let exposure = ds.get_var_attr_as_string(VAR_NAME, "exposure");
        assert_eq!(exposure, Some("0.5".to_string()));

        let gain = ds.get_var_attr_as_string(VAR_NAME, "gain");
        assert_eq!(gain, Some("42".to_string()));
    }
}