1use std::path::{Path, PathBuf};
16
17use ad_core_rs::error::{ADError, ADResult};
18use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
19use ad_core_rs::ndarray_pool::NDArrayPool;
20use ad_core_rs::plugin::file_base::{NDFileMode, NDFileWriter};
21use ad_core_rs::plugin::file_controller::FilePluginController;
22use ad_core_rs::plugin::runtime::{
23 NDPluginProcess, ParamChangeResult, PluginParamSnapshot, ProcessResult,
24};
25
26use rust_hdf5::{H5Dataset, H5File};
27
28pub struct NexusWriter {
30 current_path: Option<PathBuf>,
31 file: Option<H5File>,
32 frame_count: usize,
33 dataset: Option<H5Dataset>,
35}
36
37impl NexusWriter {
38 pub fn new() -> Self {
39 Self {
40 current_path: None,
41 file: None,
42 frame_count: 0,
43 dataset: None,
44 }
45 }
46
47 pub fn frame_count(&self) -> usize {
48 self.frame_count
49 }
50
51 fn write_nx_class(group: &rust_hdf5::H5Group, class_name: &str) -> ADResult<()> {
57 let ds = group
58 .new_dataset::<u8>()
59 .shape([1usize])
60 .create("NX_class")
61 .map_err(|e| {
62 ADError::UnsupportedConversion(format!("NX_class dataset error: {}", e))
63 })?;
64 ds.write_raw(&[0u8])
65 .map_err(|e| ADError::UnsupportedConversion(format!("NX_class write error: {}", e)))?;
66 let attr = ds
67 .new_attr::<rust_hdf5::types::VarLenUnicode>()
68 .shape(())
69 .create("value")
70 .map_err(|e| ADError::UnsupportedConversion(format!("NX_class attr error: {}", e)))?;
71 attr.write_string(class_name).map_err(|e| {
72 ADError::UnsupportedConversion(format!("NX_class attr write error: {}", e))
73 })?;
74 Ok(())
75 }
76}
77
78impl Default for NexusWriter {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl NDFileWriter for NexusWriter {
85 fn open_file(&mut self, path: &Path, _mode: NDFileMode, _array: &NDArray) -> ADResult<()> {
86 self.current_path = Some(path.to_path_buf());
87 self.frame_count = 0;
88
89 let h5file = H5File::create(path)
90 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus create error: {}", e)))?;
91
92 let entry = h5file
97 .create_group("entry")
98 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
99 Self::write_nx_class(&entry, "NXentry")?;
100 let instrument = entry
101 .create_group("instrument")
102 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
103 Self::write_nx_class(&instrument, "NXinstrument")?;
104 let _detector = instrument
105 .create_group("detector")
106 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
107 Self::write_nx_class(&_detector, "NXdetector")?;
108 let _data_group = entry
109 .create_group("data")
110 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus group error: {}", e)))?;
111 Self::write_nx_class(&_data_group, "NXdata")?;
112
113 self.file = Some(h5file);
114 Ok(())
115 }
116
117 fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
118 let h5file = self
119 .file
120 .as_ref()
121 .ok_or_else(|| ADError::UnsupportedConversion("no NeXus file open".into()))?;
122
123 let frame_shape = array.dims.iter().rev().map(|d| d.size).collect::<Vec<_>>();
124
125 if self.frame_count == 0 {
126 let detector_group = h5file
128 .root_group()
129 .group("entry")
130 .map_err(|e| ADError::UnsupportedConversion(e.to_string()))?
131 .group("instrument")
132 .map_err(|e| ADError::UnsupportedConversion(e.to_string()))?
133 .group("detector")
134 .map_err(|e| ADError::UnsupportedConversion(e.to_string()))?;
135
136 let mut ds_shape = vec![1usize];
138 ds_shape.extend_from_slice(&frame_shape);
139 let chunk_dims = ds_shape.clone();
140
141 macro_rules! create_chunked {
142 ($t:ty, $v:expr) => {{
143 let ds = detector_group
144 .new_dataset::<$t>()
145 .shape(&ds_shape[..])
146 .chunk(&chunk_dims[..])
147 .resizable()
148 .create("data")
149 .map_err(|e| {
150 ADError::UnsupportedConversion(format!("NeXus dataset error: {}", e))
151 })?;
152 let raw = unsafe {
153 std::slice::from_raw_parts(
154 $v.as_ptr() as *const u8,
155 $v.len() * std::mem::size_of::<$t>(),
156 )
157 };
158 ds.write_chunk(0, raw).map_err(|e| {
159 ADError::UnsupportedConversion(format!("NeXus write error: {}", e))
160 })?;
161 for attr in array.attributes.iter() {
163 let val_str = attr.value.as_string();
164 let _ = ds
165 .new_attr::<rust_hdf5::types::VarLenUnicode>()
166 .shape(())
167 .create(attr.name.as_str())
168 .and_then(|a| {
169 let s: rust_hdf5::types::VarLenUnicode =
170 val_str.parse().unwrap_or_default();
171 a.write_scalar(&s)
172 });
173 }
174 ds
175 }};
176 }
177
178 let ds = match &array.data {
179 NDDataBuffer::U8(v) => create_chunked!(u8, v),
180 NDDataBuffer::U16(v) => create_chunked!(u16, v),
181 NDDataBuffer::I16(v) => create_chunked!(i16, v),
182 NDDataBuffer::I32(v) => create_chunked!(i32, v),
183 NDDataBuffer::U32(v) => create_chunked!(u32, v),
184 NDDataBuffer::F32(v) => create_chunked!(f32, v),
185 NDDataBuffer::F64(v) => create_chunked!(f64, v),
186 _ => {
187 let raw = array.data.as_u8_slice();
188 let ds = detector_group
189 .new_dataset::<u8>()
190 .shape(&ds_shape[..])
191 .chunk(&chunk_dims[..])
192 .resizable()
193 .create("data")
194 .map_err(|e| {
195 ADError::UnsupportedConversion(format!("NeXus dataset error: {}", e))
196 })?;
197 ds.write_chunk(0, raw).map_err(|e| {
198 ADError::UnsupportedConversion(format!("NeXus write error: {}", e))
199 })?;
200 ds
201 }
202 };
203
204 self.dataset = Some(ds);
205 } else {
206 let ds = self.dataset.as_ref().ok_or_else(|| {
208 ADError::UnsupportedConversion("no dataset for multi-frame write".into())
209 })?;
210
211 let new_frame_count = self.frame_count + 1;
212 let mut new_shape = vec![new_frame_count];
213 new_shape.extend_from_slice(&frame_shape);
214 ds.extend(&new_shape).map_err(|e| {
215 ADError::UnsupportedConversion(format!("NeXus extend error: {}", e))
216 })?;
217
218 let raw = array.data.as_u8_slice();
219 ds.write_chunk(self.frame_count, raw)
220 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus write error: {}", e)))?;
221 }
222
223 if let Some(ref ds) = self.dataset {
225 let uid_name = format!("uniqueId_{}", self.frame_count);
226 let _ = ds
227 .new_attr::<i32>()
228 .shape(())
229 .create(&uid_name)
230 .and_then(|a| a.write_numeric(&array.unique_id));
231 let ts_name = format!("timeStamp_{}", self.frame_count);
232 let _ = ds
233 .new_attr::<f64>()
234 .shape(())
235 .create(&ts_name)
236 .and_then(|a| a.write_numeric(&array.time_stamp));
237 }
238
239 self.frame_count += 1;
240 Ok(())
241 }
242
243 fn read_file(&mut self) -> ADResult<NDArray> {
244 let path = self
245 .current_path
246 .as_ref()
247 .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
248
249 let h5file = H5File::open(path)
250 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus open error: {}", e)))?;
251
252 let ds = h5file
254 .dataset("entry/instrument/detector/data")
255 .map_err(|e| ADError::UnsupportedConversion(format!("NeXus dataset error: {}", e)))?;
256
257 let shape = ds.shape();
258 let dims: Vec<NDDimension> = shape.iter().rev().map(|&s| NDDimension::new(s)).collect();
259
260 if let Ok(data) = ds.read_raw::<u8>() {
261 let mut arr = NDArray::new(dims, NDDataType::UInt8);
262 arr.data = NDDataBuffer::U8(data);
263 return Ok(arr);
264 }
265 if let Ok(data) = ds.read_raw::<u16>() {
266 let mut arr = NDArray::new(dims, NDDataType::UInt16);
267 arr.data = NDDataBuffer::U16(data);
268 return Ok(arr);
269 }
270 if let Ok(data) = ds.read_raw::<f64>() {
271 let mut arr = NDArray::new(dims, NDDataType::Float64);
272 arr.data = NDDataBuffer::F64(data);
273 return Ok(arr);
274 }
275
276 Err(ADError::UnsupportedConversion(
277 "unsupported data type in NeXus file".into(),
278 ))
279 }
280
281 fn close_file(&mut self) -> ADResult<()> {
282 self.dataset = None;
283 self.file = None;
284 self.current_path = None;
285 Ok(())
286 }
287
288 fn supports_multiple_arrays(&self) -> bool {
289 true
290 }
291}
292
293pub struct NexusFileProcessor {
298 ctrl: FilePluginController<NexusWriter>,
299}
300
301impl NexusFileProcessor {
302 pub fn new() -> Self {
303 Self {
304 ctrl: FilePluginController::new(NexusWriter::new()),
305 }
306 }
307}
308
309impl Default for NexusFileProcessor {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
315impl NDPluginProcess for NexusFileProcessor {
316 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
317 self.ctrl.process_array(array)
318 }
319
320 fn plugin_type(&self) -> &str {
321 "NDFileNexus"
322 }
323
324 fn register_params(
325 &mut self,
326 base: &mut asyn_rs::port::PortDriverBase,
327 ) -> asyn_rs::error::AsynResult<()> {
328 self.ctrl.register_params(base)?;
329 use asyn_rs::param::ParamType;
330 base.create_param("NEXUS_TEMPLATE_PATH", ParamType::Octet)?;
331 base.create_param("NEXUS_TEMPLATE_FILE", ParamType::Octet)?;
332 base.create_param("NEXUS_TEMPLATE_VALID", ParamType::Int32)?;
333 base.create_param("TEMPLATE_FILE_PATH", ParamType::Octet)?;
334 base.create_param("TEMPLATE_FILE_NAME", ParamType::Octet)?;
335 base.create_param("TEMPLATE_FILE_VALID", ParamType::Int32)?;
336 Ok(())
337 }
338
339 fn on_param_change(
340 &mut self,
341 reason: usize,
342 params: &PluginParamSnapshot,
343 ) -> ParamChangeResult {
344 self.ctrl.on_param_change(reason, params)
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// HDF5 path of the frame dataset written by `NexusWriter`.
    const DATA_PATH: &str = "entry/instrument/detector/data";

    /// Returns a `.nxs` path in the system temp directory; a process-wide
    /// counter keeps names distinct between calls within one test run.
    fn temp_path(prefix: &str) -> PathBuf {
        use std::sync::atomic::{AtomicU32, Ordering};
        static COUNTER: AtomicU32 = AtomicU32::new(0);
        let serial = COUNTER.fetch_add(1, Ordering::Relaxed);
        let file_name = format!("adcore_test_{}_{}.nxs", prefix, serial);
        std::env::temp_dir().join(file_name)
    }

    /// Builds a 4x4 `UInt8` frame whose i-th pixel holds `i + offset`.
    fn ramp_frame(offset: usize) -> NDArray {
        let mut frame = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(ref mut pixels) = frame.data {
            for (i, px) in pixels.iter_mut().enumerate().take(16) {
                *px = (i + offset) as u8;
            }
        }
        frame
    }

    /// A single frame written through the writer is readable from the file.
    #[test]
    fn test_nexus_write_read() {
        let path = temp_path("nexus_basic");
        let mut writer = NexusWriter::new();
        let frame = ramp_frame(0);

        writer.open_file(&path, NDFileMode::Single, &frame).unwrap();
        writer.write_file(&frame).unwrap();
        writer.close_file().unwrap();

        let reopened = H5File::open(&path).unwrap();
        let dataset = reopened.dataset(DATA_PATH).unwrap();
        let bytes: Vec<u8> = dataset.read_raw().unwrap();
        assert_eq!(bytes.len(), 16);
        assert_eq!(bytes[0], 0);
        assert_eq!(bytes[15], 15);

        std::fs::remove_file(&path).ok();
    }

    /// Two frames written in stream mode end up stacked along the first axis.
    #[test]
    fn test_nexus_multiple_frames() {
        let path = temp_path("nexus_multi");
        let mut writer = NexusWriter::new();
        let first = ramp_frame(0);
        let second = ramp_frame(100);

        writer.open_file(&path, NDFileMode::Stream, &first).unwrap();
        writer.write_file(&first).unwrap();
        writer.write_file(&second).unwrap();
        writer.close_file().unwrap();

        assert_eq!(writer.frame_count(), 2);

        let reopened = H5File::open(&path).unwrap();
        let dataset = reopened.dataset(DATA_PATH).unwrap();
        let shape = dataset.shape();
        assert_eq!(shape, vec![2, 4, 4]);

        let bytes: Vec<u8> = dataset.read_raw().unwrap();
        assert_eq!(bytes.len(), 32);
        // First frame occupies bytes 0..16, second frame bytes 16..32.
        assert_eq!(bytes[0], 0);
        assert_eq!(bytes[15], 15);
        assert_eq!(bytes[16], 100);
        assert_eq!(bytes[31], 115);

        std::fs::remove_file(&path).ok();
    }
}