// ad_plugins_rs/transform.rs

1use std::sync::Arc;
2
3use ad_core_rs::color::NDColorMode;
4use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDimension};
5use ad_core_rs::ndarray_pool::NDArrayPool;
6use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
7
/// Geometric transform selected by the plugin, numbered exactly like the
/// C++ `NDPluginTransformType_t` enum.
///
/// C++ order: `None=0, Rotate90=1, Rotate180=2, Rotate270=3, Mirror=4,
/// Rotate90Mirror=5, Rotate180Mirror=6, Rotate270Mirror=7`.
///
/// The four "mirror" codes correspond to:
/// - `Mirror` (4): horizontal flip ([`TransformType::FlipHoriz`]).
/// - `Rotate90Mirror` (5): transpose across the main diagonal ([`TransformType::FlipDiag`]).
/// - `Rotate180Mirror` (6): vertical flip ([`TransformType::FlipVert`]).
/// - `Rotate270Mirror` (7): flip across the anti-diagonal ([`TransformType::FlipAntiDiag`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TransformType {
    /// Identity; the array passes through unchanged.
    None = 0,
    /// C++ `Rotate90`: quarter turn clockwise.
    Rot90CW = 1,
    /// C++ `Rotate180`: half turn.
    Rot180 = 2,
    /// C++ `Rotate270`: quarter turn counter-clockwise.
    Rot90CCW = 3,
    /// C++ `Mirror`: horizontal flip.
    FlipHoriz = 4,
    /// C++ `Rotate90Mirror`: transpose / main-diagonal flip.
    FlipDiag = 5,
    /// C++ `Rotate180Mirror`: vertical flip.
    FlipVert = 6,
    /// C++ `Rotate270Mirror`: anti-diagonal flip.
    FlipAntiDiag = 7,
}
33
34impl TransformType {
35    pub fn from_u8(v: u8) -> Self {
36        match v {
37            1 => Self::Rot90CW,
38            2 => Self::Rot180,
39            3 => Self::Rot90CCW,
40            4 => Self::FlipHoriz,
41            // C++ TransformRotate90Mirror == transpose.
42            5 => Self::FlipDiag,
43            // C++ TransformRotate180Mirror == vertical flip.
44            6 => Self::FlipVert,
45            7 => Self::FlipAntiDiag,
46            _ => Self::None,
47        }
48    }
49
50    /// Whether this transform swaps x and y dimensions.
51    pub fn swaps_dims(&self) -> bool {
52        matches!(
53            self,
54            Self::Rot90CW | Self::Rot90CCW | Self::FlipDiag | Self::FlipAntiDiag
55        )
56    }
57}
58
59/// Map source (x, y) to destination (x, y) for the given transform.
60fn map_coords(
61    sx: usize,
62    sy: usize,
63    src_w: usize,
64    src_h: usize,
65    transform: TransformType,
66) -> (usize, usize) {
67    match transform {
68        TransformType::None => (sx, sy),
69        TransformType::Rot90CW => (src_h - 1 - sy, sx),
70        TransformType::Rot180 => (src_w - 1 - sx, src_h - 1 - sy),
71        TransformType::Rot90CCW => (sy, src_w - 1 - sx),
72        TransformType::FlipHoriz => (src_w - 1 - sx, sy),
73        TransformType::FlipVert => (sx, src_h - 1 - sy),
74        TransformType::FlipDiag => (sy, sx),
75        TransformType::FlipAntiDiag => (src_h - 1 - sy, src_w - 1 - sx),
76    }
77}
78
79/// Per-color-mode element strides for a 2-D or 3-D image of the given
80/// X/Y/color sizes. Mirrors C++ `NDArray::getInfo` stride layout: returns
81/// `(x_stride, y_stride, color_stride)` and the destination dimension order.
82fn strides_for(color_mode: NDColorMode, xs: usize, ys: usize, cs: usize) -> (usize, usize, usize) {
83    match color_mode {
84        NDColorMode::RGB1 => (cs, xs * cs, 1),
85        NDColorMode::RGB2 => (1, xs * cs, xs),
86        // RGB3 / Mono / others: planar X-fastest layout.
87        _ => (1, xs, xs * ys),
88    }
89}
90
91/// Build the destination dimension vector for `color_mode` with the given
92/// X/Y/color sizes, matching the C++ dimension order per color mode.
93fn dims_for(
94    color_mode: NDColorMode,
95    xs: usize,
96    ys: usize,
97    cs: usize,
98    ndims: usize,
99) -> Vec<NDDimension> {
100    if ndims < 3 {
101        return vec![NDDimension::new(xs), NDDimension::new(ys)];
102    }
103    match color_mode {
104        NDColorMode::RGB1 => vec![
105            NDDimension::new(cs),
106            NDDimension::new(xs),
107            NDDimension::new(ys),
108        ],
109        NDColorMode::RGB2 => vec![
110            NDDimension::new(xs),
111            NDDimension::new(cs),
112            NDDimension::new(ys),
113        ],
114        _ => vec![
115            NDDimension::new(xs),
116            NDDimension::new(ys),
117            NDDimension::new(cs),
118        ],
119    }
120}
121
122/// Apply a transform to an NDArray.
123///
124/// Handles 2-D mono images and 3-D RGB1/RGB2/RGB3 color images. The per-color
125/// reindexing mirrors C++ `transformNDArray`: source `(x, y)` is geometrically
126/// mapped to destination `(x, y)` and every color component is copied with the
127/// destination strides recomputed for the (possibly swapped) X/Y sizes.
128pub fn apply_transform(src: &NDArray, transform: TransformType) -> NDArray {
129    if transform == TransformType::None || src.dims.len() < 2 {
130        return src.clone();
131    }
132
133    let info = src.info();
134    let src_w = info.x_size;
135    let src_h = info.y_size;
136    let color = info.color_size.max(1);
137    if src_w == 0 || src_h == 0 {
138        return src.clone();
139    }
140
141    let (dst_w, dst_h) = if transform.swaps_dims() {
142        (src_h, src_w)
143    } else {
144        (src_w, src_h)
145    };
146
147    let (sxs, sys, scs) = (
148        info.x_stride,
149        info.y_stride.max(1),
150        info.color_stride.max(1),
151    );
152    let (dxs, dys, dcs) = strides_for(info.color_mode, dst_w, dst_h, color);
153    let total = dst_w * dst_h * color;
154
155    macro_rules! transform_buf {
156        ($vec:expr, $zero:expr) => {{
157            let mut out = vec![$zero; total];
158            for sy in 0..src_h {
159                for sx in 0..src_w {
160                    let (dx, dy) = map_coords(sx, sy, src_w, src_h, transform);
161                    let s_base = sy * sys + sx * sxs;
162                    let d_base = dy * dys + dx * dxs;
163                    for c in 0..color {
164                        out[d_base + c * dcs] = $vec[s_base + c * scs];
165                    }
166                }
167            }
168            out
169        }};
170    }
171
172    let out_data = match &src.data {
173        NDDataBuffer::U8(v) => NDDataBuffer::U8(transform_buf!(v, 0)),
174        NDDataBuffer::U16(v) => NDDataBuffer::U16(transform_buf!(v, 0)),
175        NDDataBuffer::I8(v) => NDDataBuffer::I8(transform_buf!(v, 0)),
176        NDDataBuffer::I16(v) => NDDataBuffer::I16(transform_buf!(v, 0)),
177        NDDataBuffer::I32(v) => NDDataBuffer::I32(transform_buf!(v, 0)),
178        NDDataBuffer::U32(v) => NDDataBuffer::U32(transform_buf!(v, 0)),
179        NDDataBuffer::I64(v) => NDDataBuffer::I64(transform_buf!(v, 0)),
180        NDDataBuffer::U64(v) => NDDataBuffer::U64(transform_buf!(v, 0)),
181        NDDataBuffer::F32(v) => NDDataBuffer::F32(transform_buf!(v, 0.0)),
182        NDDataBuffer::F64(v) => NDDataBuffer::F64(transform_buf!(v, 0.0)),
183    };
184
185    let dims = dims_for(info.color_mode, dst_w, dst_h, color, src.dims.len());
186    let mut arr = NDArray::new(dims, src.data.data_type());
187    arr.data = out_data;
188    arr.unique_id = src.unique_id;
189    arr.timestamp = src.timestamp;
190    arr.time_stamp = src.time_stamp;
191    arr.attributes = src.attributes.clone();
192    arr
193}
194
// --- TransformProcessor: NDPluginProcess adapter around apply_transform ---
196
197/// Pure transform processing logic.
198pub struct TransformProcessor {
199    transform: TransformType,
200    transform_type_idx: Option<usize>,
201}
202
203impl TransformProcessor {
204    pub fn new(transform: TransformType) -> Self {
205        Self {
206            transform,
207            transform_type_idx: None,
208        }
209    }
210}
211
212impl NDPluginProcess for TransformProcessor {
213    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
214        let out = apply_transform(array, self.transform);
215        ProcessResult::arrays(vec![Arc::new(out)])
216    }
217
218    fn plugin_type(&self) -> &str {
219        "NDPluginTransform"
220    }
221
222    fn register_params(
223        &mut self,
224        base: &mut asyn_rs::port::PortDriverBase,
225    ) -> asyn_rs::error::AsynResult<()> {
226        use asyn_rs::param::ParamType;
227        base.create_param("TRANSFORM_TYPE", ParamType::Int32)?;
228        self.transform_type_idx = base.find_param("TRANSFORM_TYPE");
229        Ok(())
230    }
231
232    fn on_param_change(
233        &mut self,
234        reason: usize,
235        params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
236    ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
237        if Some(reason) == self.transform_type_idx {
238            self.transform = TransformType::from_u8(params.value.as_i32() as u8);
239        }
240        ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::NDDataType;

    /// Create a 3x2 array:
    /// [1, 2, 3]
    /// [4, 5, 6]
    fn make_3x2() -> NDArray {
        let mut arr = NDArray::new(
            vec![NDDimension::new(3), NDDimension::new(2)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(ref mut v) = arr.data {
            *v = vec![1, 2, 3, 4, 5, 6];
        }
        arr
    }

    fn get_u8(arr: &NDArray) -> &[u8] {
        match &arr.data {
            NDDataBuffer::U8(v) => v,
            _ => panic!("not u8"),
        }
    }

    #[test]
    fn test_none() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::None);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_rot90cw() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::Rot90CW);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        // Expected:
        // [4, 1]
        // [5, 2]
        // [6, 3]
        assert_eq!(get_u8(&out), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_rot180() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::Rot180);
        assert_eq!(out.dims[0].size, 3);
        assert_eq!(out.dims[1].size, 2);
        assert_eq!(get_u8(&out), &[6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_rot90ccw() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::Rot90CCW);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        // Expected:
        // [3, 6]
        // [2, 5]
        // [1, 4]
        assert_eq!(get_u8(&out), &[3, 6, 2, 5, 1, 4]);
    }

    #[test]
    fn test_flip_horiz() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipHoriz);
        assert_eq!(get_u8(&out), &[3, 2, 1, 6, 5, 4]);
    }

    #[test]
    fn test_flip_vert() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipVert);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_flip_diag() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipDiag);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        // Transpose:
        // [1, 4]
        // [2, 5]
        // [3, 6]
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flip_anti_diag() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipAntiDiag);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        // Anti-transpose:
        // [6, 3]
        // [5, 2]
        // [4, 1]
        assert_eq!(get_u8(&out), &[6, 3, 5, 2, 4, 1]);
    }

    #[test]
    fn test_rot90_roundtrip() {
        let arr = make_3x2();
        let r1 = apply_transform(&arr, TransformType::Rot90CW);
        let r2 = apply_transform(&r1, TransformType::Rot90CW);
        let r3 = apply_transform(&r2, TransformType::Rot90CW);
        let r4 = apply_transform(&r3, TransformType::Rot90CW);
        assert_eq!(get_u8(&r4), get_u8(&arr));
        assert_eq!(r4.dims[0].size, arr.dims[0].size);
        assert_eq!(r4.dims[1].size, arr.dims[1].size);
    }

    #[test]
    fn test_from_u8_cpp_enum_order() {
        // C++ NDPluginTransformType_t order: value 5 is Rotate90Mirror
        // (transpose), value 6 is Rotate180Mirror (vertical flip).
        assert_eq!(TransformType::from_u8(0), TransformType::None);
        assert_eq!(TransformType::from_u8(1), TransformType::Rot90CW);
        assert_eq!(TransformType::from_u8(2), TransformType::Rot180);
        assert_eq!(TransformType::from_u8(3), TransformType::Rot90CCW);
        assert_eq!(TransformType::from_u8(4), TransformType::FlipHoriz);
        assert_eq!(TransformType::from_u8(5), TransformType::FlipDiag);
        assert_eq!(TransformType::from_u8(6), TransformType::FlipVert);
        assert_eq!(TransformType::from_u8(7), TransformType::FlipAntiDiag);
    }

    #[test]
    fn test_transform_5_is_transpose() {
        // Selecting transform 5 from EPICS must produce a transpose.
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::from_u8(5));
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]); // transpose
    }

    #[test]
    fn test_transform_6_is_vertical_flip() {
        // Selecting transform 6 from EPICS must produce a vertical flip.
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::from_u8(6));
        assert_eq!(out.dims[0].size, 3);
        assert_eq!(out.dims[1].size, 2);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]); // vertical flip
    }

    /// Build a `w` x `h` three-channel UInt8 image tagged with `mode` via the
    /// `ColorMode` attribute.
    ///
    /// `dims` must be in `mode`'s dimension order. `index(x, y, c)` must give
    /// the flat buffer position of channel `c` of pixel `(x, y)` in that
    /// layout. Every element holds `100*y + 10*x + c`.
    fn make_rgb(
        mode: NDColorMode,
        dims: Vec<NDDimension>,
        w: usize,
        h: usize,
        index: impl Fn(usize, usize, usize) -> usize,
    ) -> NDArray {
        use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
        let mut arr = NDArray::new(dims, NDDataType::UInt8);
        arr.attributes.add(NDAttribute::new_static(
            "ColorMode",
            "",
            NDAttrSource::Driver,
            NDAttrValue::Int32(mode as i32),
        ));
        if let NDDataBuffer::U8(ref mut v) = arr.data {
            for y in 0..h {
                for x in 0..w {
                    for c in 0..3 {
                        v[index(x, y, c)] = (100 * y + 10 * x + c) as u8;
                    }
                }
            }
        }
        arr
    }

    /// Build a 2x2 RGB1 image (color-interleaved): dims = [color=3, x=2, y=2],
    /// so channel c of pixel (x, y) lives at `y*6 + x*3 + c`
    /// (x_stride=3, y_stride=6). Pixel value encodes 100*y + 10*x + c.
    fn make_rgb1_2x2() -> NDArray {
        make_rgb(
            NDColorMode::RGB1,
            vec![
                NDDimension::new(3),
                NDDimension::new(2),
                NDDimension::new(2),
            ],
            2,
            2,
            |x, y, c| y * 6 + x * 3 + c,
        )
    }

    #[test]
    fn test_rgb1_flip_horiz_keeps_color_grouping() {
        // Horizontal flip of an RGB1 image: each pixel's 3 channels stay
        // together; only the x coordinate is mirrored.
        let arr = make_rgb1_2x2();
        let out = apply_transform(&arr, TransformType::FlipHoriz);
        // dims unchanged for a non-swapping transform
        assert_eq!(out.dims[0].size, 3);
        assert_eq!(out.dims[1].size, 2);
        assert_eq!(out.dims[2].size, 2);
        if let NDDataBuffer::U8(v) = &out.data {
            // pixel (x=0,y=0) should now hold source (x=1,y=0): 10,11,12
            assert_eq!(&v[0..3], &[10, 11, 12]);
            // pixel (x=1,y=0) holds source (x=0,y=0): 0,1,2
            assert_eq!(&v[3..6], &[0, 1, 2]);
            // pixel (x=0,y=1) holds source (x=1,y=1): 110,111,112
            assert_eq!(&v[6..9], &[110, 111, 112]);
        } else {
            panic!("not u8");
        }
    }

    #[test]
    fn test_rgb1_rot90cw_swaps_dims_and_keeps_color() {
        let arr = make_rgb1_2x2();
        let out = apply_transform(&arr, TransformType::Rot90CW);
        // x/y swapped (both 2 here), color dim preserved
        assert_eq!(out.dims[0].size, 3);
        assert_eq!(out.dims[1].size, 2);
        assert_eq!(out.dims[2].size, 2);
        if let NDDataBuffer::U8(v) = &out.data {
            // Rot90CW maps src (sx,sy) -> (src_h-1-sy, sx).
            // dest (0,0) <- src (sx,sy) with src_h-1-sy=0, sx=0 => sy=1,sx=0
            // src (0,1) = 100,101,102
            assert_eq!(&v[0..3], &[100, 101, 102]);
        } else {
            panic!("not u8");
        }
    }

    #[test]
    fn test_rgb1_rot90cw_non_square() {
        // 3 wide x 2 high RGB1: dims [color=3, x=3, y=2], index y*9 + x*3 + c.
        let arr = make_rgb(
            NDColorMode::RGB1,
            vec![
                NDDimension::new(3),
                NDDimension::new(3),
                NDDimension::new(2),
            ],
            3,
            2,
            |x, y, c| y * 9 + x * 3 + c,
        );
        let out = apply_transform(&arr, TransformType::Rot90CW);
        // Color stays first; X/Y become 2 wide x 3 high.
        assert_eq!(out.dims[0].size, 3);
        assert_eq!(out.dims[1].size, 2);
        assert_eq!(out.dims[2].size, 3);
        // dest (dx, dy) <- src (dy, 1 - dx), stored at dy*6 + dx*3 + c.
        assert_eq!(
            get_u8(&out),
            &[
                100, 101, 102, 0, 1, 2, // row 0
                110, 111, 112, 10, 11, 12, // row 1
                120, 121, 122, 20, 21, 22, // row 2
            ]
        );
    }

    #[test]
    fn test_rgb2_rot90cw_non_square() {
        // 3 wide x 2 high RGB2 (row-interleaved): dims [x=3, color=3, y=2],
        // index y*9 + c*3 + x.
        let arr = make_rgb(
            NDColorMode::RGB2,
            vec![
                NDDimension::new(3),
                NDDimension::new(3),
                NDDimension::new(2),
            ],
            3,
            2,
            |x, y, c| y * 9 + c * 3 + x,
        );
        let out = apply_transform(&arr, TransformType::Rot90CW);
        // dims become [x=2, color=3, y=3]
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(out.dims[2].size, 3);
        // dest (dx, dy) <- src (dy, 1 - dx), stored at dy*6 + c*2 + dx.
        assert_eq!(
            get_u8(&out),
            &[
                100, 0, 101, 1, 102, 2, // row 0
                110, 10, 111, 11, 112, 12, // row 1
                120, 20, 121, 21, 122, 22, // row 2
            ]
        );
    }

    #[test]
    fn test_rgb3_transpose_non_square() {
        // 3 wide x 2 high RGB3 (planar): dims [x=3, y=2, color=3],
        // index c*6 + y*3 + x.
        let arr = make_rgb(
            NDColorMode::RGB3,
            vec![
                NDDimension::new(3),
                NDDimension::new(2),
                NDDimension::new(3),
            ],
            3,
            2,
            |x, y, c| c * 6 + y * 3 + x,
        );
        let out = apply_transform(&arr, TransformType::FlipDiag);
        // dims become [x=2, y=3, color=3]
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(out.dims[2].size, 3);
        // dest (dx, dy) <- src (dy, dx), stored at c*6 + dy*2 + dx.
        assert_eq!(
            get_u8(&out),
            &[
                0, 100, 10, 110, 20, 120, // plane 0
                1, 101, 11, 111, 21, 121, // plane 1
                2, 102, 12, 112, 22, 122, // plane 2
            ]
        );
    }

    // --- TransformProcessor tests ---

    #[test]
    fn test_transform_processor() {
        let mut proc = TransformProcessor::new(TransformType::Rot90CW);
        let pool = NDArrayPool::new(1_000_000);

        let arr = make_3x2();
        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert_eq!(result.output_arrays[0].dims[0].size, 2); // swapped
        assert_eq!(result.output_arrays[0].dims[1].size, 3);
        assert_eq!(get_u8(&result.output_arrays[0]), &[4, 1, 5, 2, 6, 3]);
    }
}