Skip to main content

ad_plugins_rs/
file_nexus.rs

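//! NeXus file writer plugin.
//!
//! Writes NDArray data in NeXus/HDF5 format using the rust-hdf5 library.
//! Follows the simplified NXdata convention:
//!
//! ```text
//! /entry (NX_class=NXentry)
//!   /instrument (NX_class=NXinstrument)
//!     /detector (NX_class=NXdetector)
//!       /data → dataset [frames × Y × X]
//!   /data (NX_class=NXdata)
//!     /data → same dataset
//! ```
//!
//! Each group's class is recorded in a one-element `u8` marker dataset named
//! `NX_class` whose `value` attribute holds the class name.
//!
//! Note: the tree above is the intended layout. The writer in this file creates
//! the `/entry/data` group but stores the frames only under
//! `/entry/instrument/detector/data`; no link from `/entry/data/data` to that
//! dataset is written here.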
14
15use std::path::{Path, PathBuf};
16
17use ad_core_rs::error::{ADError, ADResult};
18use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
19use ad_core_rs::ndarray_pool::NDArrayPool;
20use ad_core_rs::plugin::file_base::{NDFileMode, NDFileWriter};
21use ad_core_rs::plugin::file_controller::FilePluginController;
22use ad_core_rs::plugin::runtime::{
23    NDPluginProcess, ParamChangeResult, PluginParamSnapshot, ProcessResult,
24};
25
26use rust_hdf5::{H5Dataset, H5File};
27
28/// NeXus file writer using HDF5 with NeXus group structure.
29pub struct NexusWriter {
30    current_path: Option<PathBuf>,
31    file: Option<H5File>,
32    frame_count: usize,
33    /// Reusable dataset handle for multi-frame writes.
34    dataset: Option<H5Dataset>,
35}
36
37impl NexusWriter {
38    pub fn new() -> Self {
39        Self {
40            current_path: None,
41            file: None,
42            frame_count: 0,
43            dataset: None,
44        }
45    }
46
47    pub fn frame_count(&self) -> usize {
48        self.frame_count
49    }
50
51    /// Write an NX_class marker dataset into a group.
52    ///
53    /// rust-hdf5 does not support group-level attributes, so we create a
54    /// scalar u8 dataset named "NX_class" and attach the class name as a
55    /// string attribute on it.
56    fn write_nx_class(group: &rust_hdf5::H5Group, class_name: &str) -> ADResult<()> {
57        let ds = group
58            .new_dataset::<u8>()
59            .shape([1usize])
60            .create("NX_class")
61            .map_err(|e| {
62                ADError::UnsupportedConversion(format!("NX_class dataset error: {}", e))
63            })?;
64        ds.write_raw(&[0u8])
65            .map_err(|e| ADError::UnsupportedConversion(format!("NX_class write error: {}", e)))?;
66        let attr = ds
67            .new_attr::<rust_hdf5::types::VarLenUnicode>()
68            .shape(())
69            .create("value")
70            .map_err(|e| ADError::UnsupportedConversion(format!("NX_class attr error: {}", e)))?;
71        attr.write_string(class_name).map_err(|e| {
72            ADError::UnsupportedConversion(format!("NX_class attr write error: {}", e))
73        })?;
74        Ok(())
75    }
76}
77
78impl Default for NexusWriter {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl NDFileWriter for NexusWriter {
85    fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
86        self.current_path = Some(path.to_path_buf());
87        self.frame_count = 0;
88
89        let h5file = H5File::create(path)
90            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus create error: {}", e)))?;
91
92        // Create NeXus group hierarchy with NX_class marker datasets.
93        // Note: rust-hdf5 does not support group-level attributes, so we store
94        // NX_class as a scalar u8 dataset within each group. NeXus-aware readers
95        // should use the path hierarchy for group identification.
96        let entry = h5file
97            .create_group("entry")
98            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
99        Self::write_nx_class(&entry, "NXentry")?;
100        let instrument = entry
101            .create_group("instrument")
102            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
103        Self::write_nx_class(&instrument, "NXinstrument")?;
104        let _detector = instrument
105            .create_group("detector")
106            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
107        Self::write_nx_class(&_detector, "NXdetector")?;
108        let _data_group = entry
109            .create_group("data")
110            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
111        Self::write_nx_class(&_data_group, "NXdata")?;
112
113        self.file = Some(h5file);
114        Ok(())
115    }
116
117    fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
118        let h5file = self
119            .file
120            .as_ref()
121            .ok_or_else(|| ADError::UnsupportedConversion("no NeXus file open".into()))?;
122
123        let frame_shape = array.dims.iter().rev().map(|d| d.size).collect::<Vec<_>>();
124
125        if self.frame_count == 0 {
126            // First frame: create a chunked dataset with leading frame dimension.
127            let detector_group = h5file
128                .root_group()
129                .group("entry")
130                .map_err(|e| ADError::UnsupportedConversion(e.to_string()))?
131                .group("instrument")
132                .map_err(|e| ADError::UnsupportedConversion(e.to_string()))?
133                .group("detector")
134                .map_err(|e| ADError::UnsupportedConversion(e.to_string()))?;
135
136            // Shape: [1, dim0, dim1, ...], chunk: [1, dim0, dim1, ...]
137            let mut ds_shape = vec![1usize];
138            ds_shape.extend_from_slice(&frame_shape);
139            let chunk_dims = ds_shape.clone();
140
141            macro_rules! create_chunked {
142                ($t:ty, $v:expr) => {{
143                    let ds = detector_group
144                        .new_dataset::<$t>()
145                        .shape(&ds_shape[..])
146                        .chunk(&chunk_dims[..])
147                        .resizable()
148                        .create("data")
149                        .map_err(|e| {
150                            ADError::UnsupportedConversion(format!("NeXus dataset error: {}", e))
151                        })?;
152                    let raw = unsafe {
153                        std::slice::from_raw_parts(
154                            $v.as_ptr() as *const u8,
155                            $v.len() * std::mem::size_of::<$t>(),
156                        )
157                    };
158                    ds.write_chunk(0, raw).map_err(|e| {
159                        ADError::UnsupportedConversion(format!("NeXus write error: {}", e))
160                    })?;
161                    // Write NDArray attributes on the first frame
162                    for attr in array.attributes.iter() {
163                        let val_str = attr.value.as_string();
164                        let _ = ds
165                            .new_attr::<rust_hdf5::types::VarLenUnicode>()
166                            .shape(())
167                            .create(attr.name.as_str())
168                            .and_then(|a| {
169                                let s: rust_hdf5::types::VarLenUnicode =
170                                    val_str.parse().unwrap_or_default();
171                                a.write_scalar(&s)
172                            });
173                    }
174                    ds
175                }};
176            }
177
178            let ds = match &array.data {
179                NDDataBuffer::U8(v) => create_chunked!(u8, v),
180                NDDataBuffer::U16(v) => create_chunked!(u16, v),
181                NDDataBuffer::I16(v) => create_chunked!(i16, v),
182                NDDataBuffer::I32(v) => create_chunked!(i32, v),
183                NDDataBuffer::U32(v) => create_chunked!(u32, v),
184                NDDataBuffer::F32(v) => create_chunked!(f32, v),
185                NDDataBuffer::F64(v) => create_chunked!(f64, v),
186                _ => {
187                    let raw = array.data.as_u8_slice();
188                    let ds = detector_group
189                        .new_dataset::<u8>()
190                        .shape(&ds_shape[..])
191                        .chunk(&chunk_dims[..])
192                        .resizable()
193                        .create("data")
194                        .map_err(|e| {
195                            ADError::UnsupportedConversion(format!("NeXus dataset error: {}", e))
196                        })?;
197                    ds.write_chunk(0, raw).map_err(|e| {
198                        ADError::UnsupportedConversion(format!("NeXus write error: {}", e))
199                    })?;
200                    ds
201                }
202            };
203
204            self.dataset = Some(ds);
205        } else {
206            // Subsequent frames: extend dataset and write new chunk.
207            let ds = self.dataset.as_ref().ok_or_else(|| {
208                ADError::UnsupportedConversion("no dataset for multi-frame write".into())
209            })?;
210
211            let new_frame_count = self.frame_count + 1;
212            let mut new_shape = vec![new_frame_count];
213            new_shape.extend_from_slice(&frame_shape);
214            ds.extend(&new_shape).map_err(|e| {
215                ADError::UnsupportedConversion(format!("NeXus extend error: {}", e))
216            })?;
217
218            let raw = array.data.as_u8_slice();
219            ds.write_chunk(self.frame_count, raw)
220                .map_err(|e| ADError::UnsupportedConversion(format!("NeXus write error: {}", e)))?;
221        }
222
223        // Write per-frame uniqueId and timeStamp as attributes on the dataset
224        if let Some(ref ds) = self.dataset {
225            let uid_name = format!("uniqueId_{}", self.frame_count);
226            let _ = ds
227                .new_attr::<i32>()
228                .shape(())
229                .create(&uid_name)
230                .and_then(|a| a.write_numeric(&array.unique_id));
231            let ts_name = format!("timeStamp_{}", self.frame_count);
232            let _ = ds
233                .new_attr::<f64>()
234                .shape(())
235                .create(&ts_name)
236                .and_then(|a| a.write_numeric(&array.time_stamp));
237        }
238
239        self.frame_count += 1;
240        Ok(())
241    }
242
243    fn read_file(&mut self) -> ADResult<NDArray> {
244        let path = self
245            .current_path
246            .as_ref()
247            .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
248
249        let h5file = H5File::open(path)
250            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus open error: {}", e)))?;
251
252        // Try reading from /entry/instrument/detector/data
253        let ds = h5file
254            .dataset("entry/instrument/detector/data")
255            .map_err(|e| ADError::UnsupportedConversion(format!("NeXus dataset error: {}", e)))?;
256
257        let shape = ds.shape();
258        let dims: Vec<NDDimension> = shape.iter().rev().map(|&s| NDDimension::new(s)).collect();
259
260        if let Ok(data) = ds.read_raw::<u8>() {
261            let mut arr = NDArray::new(dims, NDDataType::UInt8);
262            arr.data = NDDataBuffer::U8(data);
263            return Ok(arr);
264        }
265        if let Ok(data) = ds.read_raw::<u16>() {
266            let mut arr = NDArray::new(dims, NDDataType::UInt16);
267            arr.data = NDDataBuffer::U16(data);
268            return Ok(arr);
269        }
270        if let Ok(data) = ds.read_raw::<f64>() {
271            let mut arr = NDArray::new(dims, NDDataType::Float64);
272            arr.data = NDDataBuffer::F64(data);
273            return Ok(arr);
274        }
275
276        Err(ADError::UnsupportedConversion(
277            "unsupported data type in NeXus file".into(),
278        ))
279    }
280
281    fn close_file(&mut self) -> ADResult<()> {
282        self.dataset = None;
283        self.file = None;
284        self.current_path = None;
285        Ok(())
286    }
287
288    fn supports_multiple_arrays(&self) -> bool {
289        true
290    }
291}
292
293// ============================================================
294// Processor
295// ============================================================
296
297pub struct NexusFileProcessor {
298    ctrl: FilePluginController<NexusWriter>,
299}
300
301impl NexusFileProcessor {
302    pub fn new() -> Self {
303        Self {
304            ctrl: FilePluginController::new(NexusWriter::new()),
305        }
306    }
307}
308
309impl Default for NexusFileProcessor {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
315impl NDPluginProcess for NexusFileProcessor {
316    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
317        self.ctrl.process_array(array)
318    }
319
320    fn plugin_type(&self) -> &str {
321        "NDFileNexus"
322    }
323
324    fn register_params(
325        &mut self,
326        base: &mut asyn_rs::port::PortDriverBase,
327    ) -> asyn_rs::error::AsynResult<()> {
328        self.ctrl.register_params(base)?;
329        use asyn_rs::param::ParamType;
330        base.create_param("NEXUS_TEMPLATE_PATH", ParamType::Octet)?;
331        base.create_param("NEXUS_TEMPLATE_FILE", ParamType::Octet)?;
332        base.create_param("NEXUS_TEMPLATE_VALID", ParamType::Int32)?;
333        base.create_param("TEMPLATE_FILE_PATH", ParamType::Octet)?;
334        base.create_param("TEMPLATE_FILE_NAME", ParamType::Octet)?;
335        base.create_param("TEMPLATE_FILE_VALID", ParamType::Int32)?;
336        Ok(())
337    }
338
339    fn on_param_change(
340        &mut self,
341        reason: usize,
342        params: &PluginParamSnapshot,
343    ) -> ParamChangeResult {
344        self.ctrl.on_param_change(reason, params)
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Path of the dataset every test inspects.
    const DATASET_PATH: &str = "entry/instrument/detector/data";

    /// Returns a temp-dir path that is unique within this test process.
    fn temp_path(prefix: &str) -> PathBuf {
        use std::sync::atomic::{AtomicU32, Ordering};
        static COUNTER: AtomicU32 = AtomicU32::new(0);
        let serial = COUNTER.fetch_add(1, Ordering::Relaxed);
        let file_name = format!("adcore_test_{}_{}.nxs", prefix, serial);
        std::env::temp_dir().join(file_name)
    }

    /// Builds a 4×4 UInt8 frame whose 16 elements count up from `offset`.
    fn ramp_frame(offset: usize) -> NDArray {
        let mut frame = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(ref mut pixels) = frame.data {
            for (i, px) in pixels[..16].iter_mut().enumerate() {
                *px = (i + offset) as u8;
            }
        }
        frame
    }

    /// A single frame written in `Single` mode can be read back byte for byte.
    #[test]
    fn test_nexus_write_read() {
        let path = temp_path("nexus_basic");
        let frame = ramp_frame(0);

        let mut writer = NexusWriter::new();
        writer.open_file(&path, NDFileMode::Single, &frame).unwrap();
        writer.write_file(&frame).unwrap();
        writer.close_file().unwrap();

        // Verify NeXus structure
        let h5file = H5File::open(&path).unwrap();
        let ds = h5file.dataset(DATASET_PATH).unwrap();
        let stored: Vec<u8> = ds.read_raw().unwrap();
        assert_eq!(stored.len(), 16);
        assert_eq!(stored[0], 0);
        assert_eq!(stored[15], 15);

        std::fs::remove_file(&path).ok();
    }

    /// Two frames written in `Stream` mode land in one [2, 4, 4] dataset.
    #[test]
    fn test_nexus_multiple_frames() {
        let path = temp_path("nexus_multi");
        let first = ramp_frame(0);
        let second = ramp_frame(100);

        let mut writer = NexusWriter::new();
        writer.open_file(&path, NDFileMode::Stream, &first).unwrap();
        for frame in [&first, &second] {
            writer.write_file(frame).unwrap();
        }
        writer.close_file().unwrap();

        assert_eq!(writer.frame_count(), 2);

        // Verify single dataset with leading frame dimension [2, 4, 4]
        let h5file = H5File::open(&path).unwrap();
        let ds = h5file.dataset(DATASET_PATH).unwrap();
        let shape = ds.shape();
        assert_eq!(shape, vec![2, 4, 4]);

        let stored: Vec<u8> = ds.read_raw().unwrap();
        assert_eq!(stored.len(), 32);
        // First frame
        assert_eq!(stored[0], 0);
        assert_eq!(stored[15], 15);
        // Second frame
        assert_eq!(stored[16], 100);
        assert_eq!(stored[31], 115);

        std::fs::remove_file(&path).ok();
    }
}