Skip to main content

ad_plugins_rs/
transform.rs

1use std::sync::Arc;
2
3use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDimension};
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Geometric transform applied to each incoming 2D array.
///
/// The set of transforms is meant to match the C++ NDPluginTransform plugin.
/// Discriminants are pinned with `#[repr(u8)]` so that `TransformType::from_u8(t as u8) == t`
/// for every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TransformType {
    /// Identity; the array is returned unchanged.
    None = 0,
    /// Rotate by 90 degrees clockwise.
    Rot90CW = 1,
    /// Rotate by 180 degrees.
    Rot180 = 2,
    /// Rotate by 90 degrees counter-clockwise.
    Rot90CCW = 3,
    /// Mirror left <-> right.
    FlipHoriz = 4,
    /// Mirror top <-> bottom.
    FlipVert = 5,
    /// Transpose about the main diagonal.
    FlipDiag = 6,
    /// Transpose about the anti-diagonal.
    FlipAntiDiag = 7,
}

impl TransformType {
    /// Every variant, stored at the index equal to its discriminant.
    const ALL: [TransformType; 8] = [
        TransformType::None,
        TransformType::Rot90CW,
        TransformType::Rot180,
        TransformType::Rot90CCW,
        TransformType::FlipHoriz,
        TransformType::FlipVert,
        TransformType::FlipDiag,
        TransformType::FlipAntiDiag,
    ];

    /// Decode a raw integer into a transform.
    ///
    /// `0..=7` select the variant with that discriminant; any other value
    /// falls back to [`TransformType::None`].
    pub fn from_u8(v: u8) -> Self {
        Self::ALL
            .get(usize::from(v))
            .copied()
            .unwrap_or(Self::None)
    }

    /// Whether the output's width and height are the input's height and width.
    pub fn swaps_dims(&self) -> bool {
        match *self {
            Self::Rot90CW | Self::Rot90CCW | Self::FlipDiag | Self::FlipAntiDiag => true,
            Self::None | Self::Rot180 | Self::FlipHoriz | Self::FlipVert => false,
        }
    }
}
43
44/// Map source (x, y) to destination (x, y) for the given transform.
45fn map_coords(
46    sx: usize,
47    sy: usize,
48    src_w: usize,
49    src_h: usize,
50    transform: TransformType,
51) -> (usize, usize) {
52    match transform {
53        TransformType::None => (sx, sy),
54        TransformType::Rot90CW => (src_h - 1 - sy, sx),
55        TransformType::Rot180 => (src_w - 1 - sx, src_h - 1 - sy),
56        TransformType::Rot90CCW => (sy, src_w - 1 - sx),
57        TransformType::FlipHoriz => (src_w - 1 - sx, sy),
58        TransformType::FlipVert => (sx, src_h - 1 - sy),
59        TransformType::FlipDiag => (sy, sx),
60        TransformType::FlipAntiDiag => (src_h - 1 - sy, src_w - 1 - sx),
61    }
62}
63
64/// Apply a 2D transform to an NDArray.
65pub fn apply_transform(src: &NDArray, transform: TransformType) -> NDArray {
66    if transform == TransformType::None || src.dims.len() < 2 {
67        return src.clone();
68    }
69
70    let src_w = src.dims[0].size;
71    let src_h = src.dims[1].size;
72    let (dst_w, dst_h) = if transform.swaps_dims() {
73        (src_h, src_w)
74    } else {
75        (src_w, src_h)
76    };
77
78    macro_rules! transform_buf {
79        ($vec:expr, $T:ty, $zero:expr) => {{
80            let mut out = vec![$zero; dst_w * dst_h];
81            for sy in 0..src_h {
82                for sx in 0..src_w {
83                    let (dx, dy) = map_coords(sx, sy, src_w, src_h, transform);
84                    out[dy * dst_w + dx] = $vec[sy * src_w + sx];
85                }
86            }
87            out
88        }};
89    }
90
91    let out_data = match &src.data {
92        NDDataBuffer::U8(v) => NDDataBuffer::U8(transform_buf!(v, u8, 0)),
93        NDDataBuffer::U16(v) => NDDataBuffer::U16(transform_buf!(v, u16, 0)),
94        NDDataBuffer::I8(v) => NDDataBuffer::I8(transform_buf!(v, i8, 0)),
95        NDDataBuffer::I16(v) => NDDataBuffer::I16(transform_buf!(v, i16, 0)),
96        NDDataBuffer::I32(v) => NDDataBuffer::I32(transform_buf!(v, i32, 0)),
97        NDDataBuffer::U32(v) => NDDataBuffer::U32(transform_buf!(v, u32, 0)),
98        NDDataBuffer::I64(v) => NDDataBuffer::I64(transform_buf!(v, i64, 0)),
99        NDDataBuffer::U64(v) => NDDataBuffer::U64(transform_buf!(v, u64, 0)),
100        NDDataBuffer::F32(v) => NDDataBuffer::F32(transform_buf!(v, f32, 0.0)),
101        NDDataBuffer::F64(v) => NDDataBuffer::F64(transform_buf!(v, f64, 0.0)),
102    };
103
104    let dims = vec![NDDimension::new(dst_w), NDDimension::new(dst_h)];
105    let mut arr = NDArray::new(dims, src.data.data_type());
106    arr.data = out_data;
107    arr.unique_id = src.unique_id;
108    arr.timestamp = src.timestamp;
109    arr.attributes = src.attributes.clone();
110    arr
111}
112
113// --- New TransformProcessor (NDPluginProcess-based) ---
114
115/// Pure transform processing logic.
116pub struct TransformProcessor {
117    transform: TransformType,
118    transform_type_idx: Option<usize>,
119}
120
121impl TransformProcessor {
122    pub fn new(transform: TransformType) -> Self {
123        Self {
124            transform,
125            transform_type_idx: None,
126        }
127    }
128}
129
130impl NDPluginProcess for TransformProcessor {
131    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
132        let out = apply_transform(array, self.transform);
133        ProcessResult::arrays(vec![Arc::new(out)])
134    }
135
136    fn plugin_type(&self) -> &str {
137        "NDPluginTransform"
138    }
139
140    fn register_params(
141        &mut self,
142        base: &mut asyn_rs::port::PortDriverBase,
143    ) -> asyn_rs::error::AsynResult<()> {
144        use asyn_rs::param::ParamType;
145        base.create_param("TRANSFORM_TYPE", ParamType::Int32)?;
146        self.transform_type_idx = base.find_param("TRANSFORM_TYPE");
147        Ok(())
148    }
149
150    fn on_param_change(
151        &mut self,
152        reason: usize,
153        params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
154    ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
155        if Some(reason) == self.transform_type_idx {
156            self.transform = TransformType::from_u8(params.value.as_i32() as u8);
157        }
158        ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::NDDataType;

    /// Fixture: a 3-wide, 2-high `UInt8` image, stored row by row as
    ///
    /// ```text
    /// [1, 2, 3]
    /// [4, 5, 6]
    /// ```
    fn make_3x2() -> NDArray {
        let dims = vec![NDDimension::new(3), NDDimension::new(2)];
        let mut image = NDArray::new(dims, NDDataType::UInt8);
        if let NDDataBuffer::U8(pixels) = &mut image.data {
            pixels.clear();
            pixels.extend_from_slice(&[1, 2, 3, 4, 5, 6]);
        }
        image
    }

    /// Borrow the `u8` samples of `arr`; panics for any other element type.
    fn get_u8(arr: &NDArray) -> &[u8] {
        if let NDDataBuffer::U8(samples) = &arr.data {
            samples.as_slice()
        } else {
            panic!("not u8")
        }
    }

    /// Run `transform` over a fresh copy of the 3x2 fixture.
    fn transformed(transform: TransformType) -> NDArray {
        apply_transform(&make_3x2(), transform)
    }

    /// Assert that `arr` is `width` elements wide and `height` elements high.
    fn assert_shape(arr: &NDArray, width: usize, height: usize) {
        assert_eq!(arr.dims[0].size, width);
        assert_eq!(arr.dims[1].size, height);
    }

    #[test]
    fn test_none() {
        let out = transformed(TransformType::None);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_rot90cw() {
        let out = transformed(TransformType::Rot90CW);
        assert_shape(&out, 2, 3);
        // [4, 1]
        // [5, 2]
        // [6, 3]
        assert_eq!(get_u8(&out), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_rot180() {
        let out = transformed(TransformType::Rot180);
        assert_shape(&out, 3, 2);
        assert_eq!(get_u8(&out), &[6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_rot90ccw() {
        let out = transformed(TransformType::Rot90CCW);
        assert_shape(&out, 2, 3);
        // [3, 6]
        // [2, 5]
        // [1, 4]
        assert_eq!(get_u8(&out), &[3, 6, 2, 5, 1, 4]);
    }

    #[test]
    fn test_flip_horiz() {
        let out = transformed(TransformType::FlipHoriz);
        assert_eq!(get_u8(&out), &[3, 2, 1, 6, 5, 4]);
    }

    #[test]
    fn test_flip_vert() {
        let out = transformed(TransformType::FlipVert);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_flip_diag() {
        let out = transformed(TransformType::FlipDiag);
        assert_shape(&out, 2, 3);
        // Transpose:
        // [1, 4]
        // [2, 5]
        // [3, 6]
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flip_anti_diag() {
        let out = transformed(TransformType::FlipAntiDiag);
        assert_shape(&out, 2, 3);
        // Anti-transpose:
        // [6, 3]
        // [5, 2]
        // [4, 1]
        assert_eq!(get_u8(&out), &[6, 3, 5, 2, 4, 1]);
    }

    #[test]
    fn test_rot90_roundtrip() {
        let original = make_3x2();
        let mut current = original.clone();
        for _ in 0..4 {
            current = apply_transform(&current, TransformType::Rot90CW);
        }
        assert_eq!(get_u8(&current), get_u8(&original));
        assert_shape(&current, original.dims[0].size, original.dims[1].size);
    }

    // --- New TransformProcessor tests ---

    #[test]
    fn test_transform_processor() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = TransformProcessor::new(TransformType::Rot90CW);

        let result = processor.process_array(&make_3x2(), &pool);
        assert_eq!(result.output_arrays.len(), 1);

        let rotated = &result.output_arrays[0];
        assert_shape(rotated, 2, 3); // axes swapped
        assert_eq!(get_u8(rotated), &[4, 1, 5, 2, 6, 3]);
    }
}