1use std::sync::Arc;
2
3#[cfg(feature = "parallel")]
4use rayon::prelude::*;
5#[cfg(feature = "parallel")]
6use crate::par_util;
7
8use ad_core::ndarray::{NDArray, NDDataBuffer, NDDataType};
9use ad_core::ndarray_pool::NDArrayPool;
10use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
11
/// Coefficients and control settings for the recursive (temporal) filter
/// applied by `ProcessState::process` when `enable_filter` is set.
///
/// For every frame after the first of a sequence, with `I` the (already
/// per-element-processed) input, `F` the filter value and `O` the output:
///
/// ```text
/// F(n) = f_offset + f_scale * (fc[0]*I(n) + fc[1]*F(n-1)
///                              + fc[2]*(F(n-2) - f_offset) + fc[3]*(F(n-3) - f_offset))
/// O(n) = o_offset + o_scale * (oc[0]*F(n) + oc[1]*F(n-1)
///                              + oc[2]*(O(n-1) - o_offset) + oc[3]*(O(n-2) - o_offset))
/// ```
///
/// The first frame of a sequence instead uses `F = rc[0]*I + rc[1]*F_prev`
/// (with `F_prev` taken as 0 when no previous filter value is stored) and
/// `O = o_offset + o_scale * (oc[0]*F)`.
#[derive(Clone, Debug)]
pub struct FilterConfig {
    /// Frame count at which `auto_reset` clears the filter history; the
    /// auto-reset is skipped when this is 0.
    pub num_filter: usize,
    /// When true, reset the filter once `num_filter` frames have been filtered.
    pub auto_reset: bool,
    /// NOTE(review): not read by the processing code in this file; presumably
    /// selects which filtered arrays are emitted — confirm with the caller.
    pub filter_callbacks: usize,
    /// Output coefficients `oc[0..4]` (see the formula above).
    pub oc: [f64; 4],
    /// Filter coefficients `fc[0..4]` (see the formula above).
    pub fc: [f64; 4],
    /// Coefficients used on the first frame of a sequence: `[input, previous filter]`.
    pub rc: [f64; 2],
    /// Offset added to (and subtracted from past) outputs.
    pub o_offset: f64,
    /// Scale applied to the output sum.
    pub o_scale: f64,
    /// Offset added to (and subtracted from past) filter values.
    pub f_offset: f64,
    /// Scale applied to the filter sum.
    pub f_scale: f64,
}

impl Default for FilterConfig {
    /// Pass-through filter: `F = I`, `O = F`, unit scales, zero offsets,
    /// no auto-reset.
    fn default() -> Self {
        let pass_through = [1.0, 0.0, 0.0, 0.0];
        FilterConfig {
            num_filter: 1,
            auto_reset: false,
            filter_callbacks: 0,
            oc: pass_through,
            fc: pass_through,
            rc: [1.0, 0.0],
            o_offset: 0.0,
            o_scale: 1.0,
            f_offset: 0.0,
            f_scale: 1.0,
        }
    }
}
64
65#[derive(Debug, Clone)]
67pub struct ProcessConfig {
68 pub enable_background: bool,
69 pub enable_flat_field: bool,
70 pub enable_offset_scale: bool,
71 pub offset: f64,
72 pub scale: f64,
73 pub enable_low_clip: bool,
74 pub low_clip: f64,
75 pub enable_high_clip: bool,
76 pub high_clip: f64,
77 pub enable_filter: bool,
78 pub filter: FilterConfig,
79 pub output_type: Option<NDDataType>,
80 pub save_background: bool,
82 pub save_flat_field: bool,
84 pub valid_background: bool,
86 pub valid_flat_field: bool,
88}
89
90impl Default for ProcessConfig {
91 fn default() -> Self {
92 Self {
93 enable_background: false,
94 enable_flat_field: false,
95 enable_offset_scale: false,
96 offset: 0.0,
97 scale: 1.0,
98 enable_low_clip: false,
99 low_clip: 0.0,
100 enable_high_clip: false,
101 high_clip: 0.0,
102 enable_filter: false,
103 filter: FilterConfig::default(),
104 output_type: None,
105 save_background: false,
106 save_flat_field: false,
107 valid_background: false,
108 valid_flat_field: false,
109 }
110 }
111}
112
113pub struct ProcessState {
115 pub config: ProcessConfig,
116 pub background: Option<Vec<f64>>,
117 pub flat_field: Option<Vec<f64>>,
118 pub filter_state: Option<Vec<f64>>,
120 pub filter_state_prev: Option<Vec<Vec<f64>>>,
122 pub output_state: Option<Vec<f64>>,
124 pub output_state_prev: Option<Vec<f64>>,
126 pub num_filtered: usize,
128}
129
130impl ProcessState {
131 pub fn new(config: ProcessConfig) -> Self {
132 Self {
133 config,
134 background: None,
135 flat_field: None,
136 filter_state: None,
137 filter_state_prev: None,
138 output_state: None,
139 output_state_prev: None,
140 num_filtered: 0,
141 }
142 }
143
144 pub fn save_background(&mut self, array: &NDArray) {
146 let n = array.data.len();
147 let mut bg = vec![0.0f64; n];
148 for i in 0..n {
149 bg[i] = array.data.get_as_f64(i).unwrap_or(0.0);
150 }
151 self.background = Some(bg);
152 self.config.valid_background = true;
153 }
154
155 pub fn save_flat_field(&mut self, array: &NDArray) {
157 let n = array.data.len();
158 let mut ff = vec![0.0f64; n];
159 for i in 0..n {
160 ff[i] = array.data.get_as_f64(i).unwrap_or(0.0);
161 }
162 self.flat_field = Some(ff);
163 self.config.valid_flat_field = true;
164 }
165
166 pub fn reset_filter(&mut self) {
168 self.filter_state = None;
169 self.filter_state_prev = None;
170 self.output_state = None;
171 self.output_state_prev = None;
172 self.num_filtered = 0;
173 }
174
175 pub fn process(&mut self, src: &NDArray) -> NDArray {
177 let n = src.data.len();
178 let mut values = vec![0.0f64; n];
179 for i in 0..n {
180 values[i] = src.data.get_as_f64(i).unwrap_or(0.0);
181 }
182
183 if self.config.save_background {
185 self.save_background(src);
186 self.config.save_background = false;
187 }
188 if self.config.save_flat_field {
189 self.save_flat_field(src);
190 self.config.save_flat_field = false;
191 }
192
193 let needs_element_ops = self.config.enable_background
196 || self.config.enable_flat_field
197 || self.config.enable_offset_scale
198 || self.config.enable_low_clip
199 || self.config.enable_high_clip;
200
201 if needs_element_ops {
202 let bg = if self.config.enable_background { self.background.as_ref() } else { None };
203 let (ff, ff_mean) = if self.config.enable_flat_field {
204 if let Some(ref ff) = self.flat_field {
205 let mean = ff.iter().sum::<f64>() / ff.len().max(1) as f64;
206 (Some(ff.as_slice()), mean)
207 } else {
208 (None, 0.0)
209 }
210 } else {
211 (None, 0.0)
212 };
213 let do_offset_scale = self.config.enable_offset_scale;
214 let scale = self.config.scale;
215 let offset = self.config.offset;
216 let do_low_clip = self.config.enable_low_clip;
217 let low_clip = self.config.low_clip;
218 let do_high_clip = self.config.enable_high_clip;
219 let high_clip = self.config.high_clip;
220
221 let apply_stages = |i: usize, v: &mut f64| {
222 if let Some(bg) = bg {
224 if i < bg.len() {
225 *v -= bg[i];
226 }
227 }
228 if let Some(ff) = ff {
230 if i < ff.len() && ff[i] != 0.0 {
231 *v = *v * ff_mean / ff[i];
232 }
233 }
234 if do_offset_scale {
236 *v = *v * scale + offset;
237 }
238 if do_low_clip && *v < low_clip {
240 *v = low_clip;
241 }
242 if do_high_clip && *v > high_clip {
243 *v = high_clip;
244 }
245 };
246
247 #[cfg(feature = "parallel")]
248 let use_parallel = par_util::should_parallelize(n);
249 #[cfg(not(feature = "parallel"))]
250 let use_parallel = false;
251
252 if use_parallel {
253 #[cfg(feature = "parallel")]
254 par_util::thread_pool().install(|| {
255 values.par_iter_mut().enumerate().for_each(|(i, v)| {
256 apply_stages(i, v);
257 });
258 });
259 } else {
260 for (i, v) in values.iter_mut().enumerate() {
261 apply_stages(i, v);
262 }
263 }
264 }
265
266 if self.config.enable_filter {
268 let fc = &self.config.filter;
269 let is_first_frame = self.filter_state.is_none();
270
271 if is_first_frame {
272 let rc1 = fc.rc[0];
274 let rc2 = fc.rc[1];
275
276 let mut f_new = vec![0.0f64; n];
277 let f_prev = self.filter_state.as_ref();
279 for i in 0..n {
280 let fp = f_prev.map_or(0.0, |p| p[i]);
281 f_new[i] = rc1 * values[i] + rc2 * fp;
282 }
283
284 let mut o_new = vec![0.0f64; n];
286 for i in 0..n {
287 o_new[i] = fc.oc[0] * f_new[i];
288 o_new[i] = fc.o_offset + fc.o_scale * o_new[i];
289 }
290
291 self.filter_state_prev = Some(vec![vec![0.0; n], vec![0.0; n]]);
293 self.output_state_prev = Some(vec![0.0; n]);
294 self.output_state = Some(o_new.clone());
295 self.filter_state = Some(f_new);
296 self.num_filtered = 1;
297
298 values = o_new;
299 } else {
300 let f_prev = self.filter_state.as_ref().unwrap(); let f_prev_history = self.filter_state_prev.as_ref().unwrap();
303 let f_prev2 = &f_prev_history[0]; let f_prev3 = &f_prev_history[1]; let o_prev = self.output_state.as_ref().unwrap(); let o_prev2 = self.output_state_prev.as_ref().unwrap(); let f_offset = fc.f_offset;
310 let f_scale = fc.f_scale;
311 let o_offset = fc.o_offset;
312 let o_scale = fc.o_scale;
313 let fc_coeffs = fc.fc;
314 let oc_coeffs = fc.oc;
315
316 let mut f_new = vec![0.0f64; n];
317 let mut o_new = vec![0.0f64; n];
318
319 for i in 0..n {
320 f_new[i] = fc_coeffs[0] * values[i]
322 + fc_coeffs[1] * f_prev[i]
323 + fc_coeffs[2] * (f_prev2[i] - f_offset)
324 + fc_coeffs[3] * (f_prev3[i] - f_offset);
325 f_new[i] = f_offset + f_scale * f_new[i];
327
328 o_new[i] = oc_coeffs[0] * f_new[i]
330 + oc_coeffs[1] * f_prev[i]
331 + oc_coeffs[2] * (o_prev[i] - o_offset)
332 + oc_coeffs[3] * (o_prev2[i] - o_offset);
333 o_new[i] = o_offset + o_scale * o_new[i];
335 }
336
337 let old_f_prev = f_prev.clone();
339 self.filter_state_prev = Some(vec![old_f_prev, f_prev2.clone()]);
340 self.output_state_prev = Some(o_prev.clone());
342 self.output_state = Some(o_new.clone());
343 self.filter_state = Some(f_new);
344
345 self.num_filtered += 1;
346
347 if fc.auto_reset && fc.num_filter > 0 && self.num_filtered >= fc.num_filter {
349 self.reset_filter();
350 }
351
352 values = o_new;
353 }
354 }
355
356 let out_type = self.config.output_type.unwrap_or(src.data.data_type());
358 let mut out_data = NDDataBuffer::zeros(out_type, n);
359 for i in 0..n {
360 out_data.set_from_f64(i, values[i]);
361 }
362
363 let mut arr = NDArray::new(src.dims.clone(), out_type);
364 arr.data = out_data;
365 arr.unique_id = src.unique_id;
366 arr.timestamp = src.timestamp;
367 arr.attributes = src.attributes.clone();
368 arr
369 }
370}
371
372pub struct ProcessProcessor {
376 state: ProcessState,
377}
378
379impl ProcessProcessor {
380 pub fn new(config: ProcessConfig) -> Self {
381 Self {
382 state: ProcessState::new(config),
383 }
384 }
385
386 pub fn state(&self) -> &ProcessState {
387 &self.state
388 }
389
390 pub fn state_mut(&mut self) -> &mut ProcessState {
391 &mut self.state
392 }
393}
394
395impl NDPluginProcess for ProcessProcessor {
396 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
397 let out = self.state.process(array);
398 ProcessResult::arrays(vec![Arc::new(out)])
399 }
400
401 fn plugin_type(&self) -> &str {
402 "NDPluginProcess"
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataBuffer, NDDimension};

    /// Builds a 1-D `UInt8` array holding `vals`.
    fn make_array(vals: &[u8]) -> NDArray {
        let mut arr = NDArray::new(vec![NDDimension::new(vals.len())], NDDataType::UInt8);
        match arr.data {
            NDDataBuffer::U8(ref mut v) => v.copy_from_slice(vals),
            _ => panic!("NDArray::new(UInt8) did not allocate a U8 buffer"),
        }
        arr
    }

    /// Builds a 1-D `Float64` array holding `vals`.
    fn make_f64_array(vals: &[f64]) -> NDArray {
        let mut arr = NDArray::new(vec![NDDimension::new(vals.len())], NDDataType::Float64);
        match arr.data {
            NDDataBuffer::F64(ref mut v) => v.copy_from_slice(vals),
            _ => panic!("NDArray::new(Float64) did not allocate an F64 buffer"),
        }
        arr
    }

    /// Returns the `U8` payload of `arr`.
    ///
    /// Panics on any other buffer type, so an unexpected output type fails the
    /// test instead of silently skipping its assertions.
    fn u8_data(arr: &NDArray) -> &[u8] {
        match &arr.data {
            NDDataBuffer::U8(v) => &v[..],
            other => panic!("expected U8 data, got {:?}", other.data_type()),
        }
    }

    #[test]
    fn test_background_subtraction() {
        let bg_arr = make_array(&[10, 20, 30]);
        let input = make_array(&[15, 25, 35]);

        let mut state = ProcessState::new(ProcessConfig {
            enable_background: true,
            ..Default::default()
        });
        state.save_background(&bg_arr);

        let result = state.process(&input);
        let v = u8_data(&result);
        assert_eq!(v[0], 5);
        assert_eq!(v[1], 5);
        assert_eq!(v[2], 5);
    }

    #[test]
    fn test_flat_field() {
        let ff_arr = make_array(&[100, 200, 50]);
        let input = make_array(&[100, 200, 50]);

        let mut state = ProcessState::new(ProcessConfig {
            enable_flat_field: true,
            ..Default::default()
        });
        state.save_flat_field(&ff_arr);

        // Flat-field mean is (100 + 200 + 50) / 3 ≈ 116.67. The input equals the
        // flat field, so every pixel becomes I * mean / FF = mean.
        let result = state.process(&input);
        let v = u8_data(&result);
        assert!((v[0] as f64 - 116.67).abs() < 1.0);
        assert!((v[1] as f64 - 116.67).abs() < 1.0);
        assert!((v[2] as f64 - 116.67).abs() < 1.0);
    }

    #[test]
    fn test_offset_scale() {
        let input = make_array(&[10, 20, 30]);
        let mut state = ProcessState::new(ProcessConfig {
            enable_offset_scale: true,
            scale: 2.0,
            offset: 5.0,
            ..Default::default()
        });

        let result = state.process(&input);
        let v = u8_data(&result);
        assert_eq!(v[0], 25);
        assert_eq!(v[1], 45);
        assert_eq!(v[2], 65);
    }

    #[test]
    fn test_clipping() {
        let input = make_array(&[5, 50, 200]);
        let mut state = ProcessState::new(ProcessConfig {
            enable_low_clip: true,
            low_clip: 10.0,
            enable_high_clip: true,
            high_clip: 100.0,
            ..Default::default()
        });

        let result = state.process(&input);
        let v = u8_data(&result);
        assert_eq!(v[0], 10);
        assert_eq!(v[1], 50);
        assert_eq!(v[2], 100);
    }

    #[test]
    fn test_recursive_filter() {
        let input1 = make_array(&[100, 100, 100]);
        let input2 = make_array(&[0, 0, 0]);

        let mut state = ProcessState::new(ProcessConfig {
            enable_filter: true,
            filter: FilterConfig {
                fc: [0.5, 0.5, 0.0, 0.0],
                oc: [1.0, 0.0, 0.0, 0.0],
                ..Default::default()
            },
            ..Default::default()
        });

        // Frame 1: F = rc[0] * 100 = 100. Frame 2: F = 0.5 * 0 + 0.5 * 100 = 50.
        let _ = state.process(&input1);
        let result = state.process(&input2);
        let v = u8_data(&result);
        assert_eq!(v[0], 50);
        assert_eq!(v[1], 50);
    }

    #[test]
    fn test_output_type_conversion() {
        let input = make_array(&[10, 20, 30]);
        let mut state = ProcessState::new(ProcessConfig {
            output_type: Some(NDDataType::Float64),
            ..Default::default()
        });

        let result = state.process(&input);
        assert_eq!(result.data.data_type(), NDDataType::Float64);
    }

    #[test]
    fn test_process_processor() {
        let mut proc = ProcessProcessor::new(ProcessConfig {
            enable_offset_scale: true,
            scale: 2.0,
            offset: 1.0,
            ..Default::default()
        });
        let pool = NDArrayPool::new(1_000_000);

        let input = make_array(&[10, 20, 30]);
        let result = proc.process_array(&input, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert_eq!(u8_data(&result.output_arrays[0])[0], 21);
    }

    #[test]
    fn test_4tap_filter_averaging() {
        let mut state = ProcessState::new(ProcessConfig {
            enable_filter: true,
            filter: FilterConfig {
                fc: [0.25, 0.75, 0.0, 0.0],
                oc: [1.0, 0.0, 0.0, 0.0],
                ..Default::default()
            },
            output_type: Some(NDDataType::Float64),
            ..Default::default()
        });

        let r1 = state.process(&make_f64_array(&[100.0]));
        let v1 = r1.data.get_as_f64(0).unwrap();
        assert!((v1 - 100.0).abs() < 1e-9, "frame 1: got {v1}");

        let r2 = state.process(&make_f64_array(&[100.0]));
        let v2 = r2.data.get_as_f64(0).unwrap();
        assert!((v2 - 100.0).abs() < 1e-9, "frame 2: got {v2}");

        // 0.25 * 0 + 0.75 * 100
        let r3 = state.process(&make_f64_array(&[0.0]));
        let v3 = r3.data.get_as_f64(0).unwrap();
        assert!((v3 - 75.0).abs() < 1e-9, "frame 3: got {v3}");

        // 0.25 * 0 + 0.75 * 75
        let r4 = state.process(&make_f64_array(&[0.0]));
        let v4 = r4.data.get_as_f64(0).unwrap();
        assert!((v4 - 56.25).abs() < 1e-9, "frame 4: got {v4}");
    }

    #[test]
    fn test_4tap_filter_all_taps() {
        let mut state = ProcessState::new(ProcessConfig {
            enable_filter: true,
            filter: FilterConfig {
                fc: [0.5, 0.3, 0.1, 0.1],
                oc: [0.7, 0.2, 0.05, 0.05],
                rc: [1.0, 0.0],
                f_offset: 0.0,
                f_scale: 1.0,
                o_offset: 0.0,
                o_scale: 1.0,
                ..Default::default()
            },
            output_type: Some(NDDataType::Float64),
            ..Default::default()
        });

        // Start frame: F = 10, O = 0.7 * 10 = 7; older history is zero.
        let _ = state.process(&make_f64_array(&[10.0]));

        // F = 0.5*20 + 0.3*10 = 13; O = 0.7*13 + 0.2*10 + 0.05*7 = 11.45
        let r1 = state.process(&make_f64_array(&[20.0]));
        let v1 = r1.data.get_as_f64(0).unwrap();
        assert!((v1 - 11.45).abs() < 1e-9, "frame 1: got {v1}");

        // F = 0.5*30 + 0.3*13 + 0.1*10 = 19.9;
        // O = 0.7*19.9 + 0.2*13 + 0.05*11.45 + 0.05*7 = 17.4525
        let r2 = state.process(&make_f64_array(&[30.0]));
        let v2 = r2.data.get_as_f64(0).unwrap();
        assert!((v2 - 17.4525).abs() < 1e-9, "frame 2: got {v2}");
    }

    #[test]
    fn test_save_background_one_shot() {
        let mut state = ProcessState::new(ProcessConfig {
            save_background: true,
            ..Default::default()
        });

        assert!(!state.config.valid_background);
        assert!(state.background.is_none());

        let input = make_array(&[10, 20, 30]);
        let _ = state.process(&input);

        assert!(!state.config.save_background, "save_background should be cleared");
        assert!(state.config.valid_background, "valid_background should be set");
        assert!(state.background.is_some());

        let bg = state.background.as_ref().unwrap();
        assert_eq!(bg.len(), 3);
        assert!((bg[0] - 10.0).abs() < 1e-9);
        assert!((bg[1] - 20.0).abs() < 1e-9);
        assert!((bg[2] - 30.0).abs() < 1e-9);

        // A second frame must not overwrite the saved background.
        let input2 = make_array(&[40, 50, 60]);
        let _ = state.process(&input2);

        assert!(!state.config.save_background, "save_background stays cleared");
        let bg2 = state.background.as_ref().unwrap();
        assert!((bg2[0] - 10.0).abs() < 1e-9);
    }

    #[test]
    fn test_save_flat_field_one_shot() {
        let mut state = ProcessState::new(ProcessConfig {
            save_flat_field: true,
            ..Default::default()
        });

        assert!(!state.config.valid_flat_field);
        assert!(state.flat_field.is_none());

        let input = make_array(&[50, 100, 150]);
        let _ = state.process(&input);

        assert!(!state.config.save_flat_field, "save_flat_field should be cleared");
        assert!(state.config.valid_flat_field, "valid_flat_field should be set");
        assert!(state.flat_field.is_some());

        let ff = state.flat_field.as_ref().unwrap();
        assert_eq!(ff.len(), 3);
        assert!((ff[0] - 50.0).abs() < 1e-9);
        assert!((ff[1] - 100.0).abs() < 1e-9);
        assert!((ff[2] - 150.0).abs() < 1e-9);
    }

    #[test]
    fn test_auto_reset_when_num_filter_reached() {
        let mut state = ProcessState::new(ProcessConfig {
            enable_filter: true,
            filter: FilterConfig {
                num_filter: 3,
                auto_reset: true,
                fc: [0.5, 0.5, 0.0, 0.0],
                oc: [1.0, 0.0, 0.0, 0.0],
                ..Default::default()
            },
            output_type: Some(NDDataType::Float64),
            ..Default::default()
        });

        let _ = state.process(&make_f64_array(&[100.0]));
        assert_eq!(state.num_filtered, 1);

        let _ = state.process(&make_f64_array(&[100.0]));
        assert_eq!(state.num_filtered, 2);

        let _ = state.process(&make_f64_array(&[100.0]));
        assert_eq!(state.num_filtered, 0, "auto_reset should have fired");
        assert!(state.filter_state.is_none(), "filter state should be cleared");
        assert!(state.output_state.is_none(), "output state should be cleared");

        let _ = state.process(&make_f64_array(&[200.0]));
        assert_eq!(state.num_filtered, 1, "fresh start after reset");
    }

    #[test]
    fn test_filter_with_offset_scale() {
        let mut state = ProcessState::new(ProcessConfig {
            enable_filter: true,
            filter: FilterConfig {
                fc: [1.0, 0.0, 0.0, 0.0],
                oc: [1.0, 0.0, 0.0, 0.0],
                f_offset: 10.0,
                f_scale: 2.0,
                o_offset: 5.0,
                o_scale: 3.0,
                ..Default::default()
            },
            output_type: Some(NDDataType::Float64),
            ..Default::default()
        });

        // Start frame: F = rc[0] * 50 = 50; O = 5 + 3 * 50 = 155.
        let r0 = state.process(&make_f64_array(&[50.0]));
        let v0 = r0.data.get_as_f64(0).unwrap();
        assert!((v0 - 155.0).abs() < 1e-9, "frame 0: got {v0}");

        // Step frame: F = 10 + 2 * 20 = 50; O = 5 + 3 * 50 = 155.
        let r1 = state.process(&make_f64_array(&[20.0]));
        let v1 = r1.data.get_as_f64(0).unwrap();
        assert!((v1 - 155.0).abs() < 1e-9, "frame 1: got {v1}");
    }

    #[test]
    fn test_reset_filter_manual() {
        let mut state = ProcessState::new(ProcessConfig {
            enable_filter: true,
            filter: FilterConfig {
                fc: [0.5, 0.5, 0.0, 0.0],
                oc: [1.0, 0.0, 0.0, 0.0],
                ..Default::default()
            },
            output_type: Some(NDDataType::Float64),
            ..Default::default()
        });

        let _ = state.process(&make_f64_array(&[100.0]));
        let _ = state.process(&make_f64_array(&[100.0]));
        assert!(state.filter_state.is_some());
        assert_eq!(state.num_filtered, 2);

        state.reset_filter();
        assert!(state.filter_state.is_none());
        assert!(state.output_state.is_none());
        assert_eq!(state.num_filtered, 0);

        let r = state.process(&make_f64_array(&[200.0]));
        let v = r.data.get_as_f64(0).unwrap();
        assert!((v - 200.0).abs() < 1e-9, "after reset, first frame: got {v}");
        assert_eq!(state.num_filtered, 1);
    }
}