Skip to main content

ad_plugins/
transform.rs

1use std::sync::Arc;
2
3use ad_core::ndarray::{NDArray, NDDataBuffer, NDDimension};
4use ad_core::ndarray_pool::NDArrayPool;
5use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// The eight supported 2D image transforms, with discriminants matching the
/// numeric codes of the C++ NDPluginTransform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TransformType {
    /// Leave the image untouched.
    None = 0,
    /// Rotate by 90 degrees clockwise (width and height swap).
    Rot90CW = 1,
    /// Rotate by 180 degrees.
    Rot180 = 2,
    /// Rotate by 90 degrees counter-clockwise (width and height swap).
    Rot90CCW = 3,
    /// Mirror along the x axis direction (columns are reversed).
    FlipHoriz = 4,
    /// Mirror along the y axis direction (rows are reversed).
    FlipVert = 5,
    /// Flip about the main diagonal, i.e. transpose (width and height swap).
    FlipDiag = 6,
    /// Flip about the anti-diagonal (width and height swap).
    FlipAntiDiag = 7,
}

impl TransformType {
    /// Decode a numeric transform code.
    ///
    /// Codes `1..=7` select the variant with that discriminant; every other
    /// value (including `0`) yields [`TransformType::None`].
    pub fn from_u8(v: u8) -> Self {
        // Indexed by discriminant, so the position of each entry must match
        // the value assigned in the enum definition.
        const BY_CODE: [TransformType; 8] = [
            TransformType::None,
            TransformType::Rot90CW,
            TransformType::Rot180,
            TransformType::Rot90CCW,
            TransformType::FlipHoriz,
            TransformType::FlipVert,
            TransformType::FlipDiag,
            TransformType::FlipAntiDiag,
        ];
        BY_CODE
            .get(usize::from(v))
            .copied()
            .unwrap_or(TransformType::None)
    }

    /// Returns `true` when the transform exchanges the x and y extents of the
    /// image (the two quarter-turn rotations and the two diagonal flips).
    pub fn swaps_dims(&self) -> bool {
        match *self {
            Self::Rot90CW | Self::Rot90CCW | Self::FlipDiag | Self::FlipAntiDiag => true,
            Self::None | Self::Rot180 | Self::FlipHoriz | Self::FlipVert => false,
        }
    }
}
40
41/// Map source (x, y) to destination (x, y) for the given transform.
42fn map_coords(
43    sx: usize,
44    sy: usize,
45    src_w: usize,
46    src_h: usize,
47    transform: TransformType,
48) -> (usize, usize) {
49    match transform {
50        TransformType::None => (sx, sy),
51        TransformType::Rot90CW => (src_h - 1 - sy, sx),
52        TransformType::Rot180 => (src_w - 1 - sx, src_h - 1 - sy),
53        TransformType::Rot90CCW => (sy, src_w - 1 - sx),
54        TransformType::FlipHoriz => (src_w - 1 - sx, sy),
55        TransformType::FlipVert => (sx, src_h - 1 - sy),
56        TransformType::FlipDiag => (sy, sx),
57        TransformType::FlipAntiDiag => (src_h - 1 - sy, src_w - 1 - sx),
58    }
59}
60
61/// Apply a 2D transform to an NDArray.
62pub fn apply_transform(src: &NDArray, transform: TransformType) -> NDArray {
63    if transform == TransformType::None || src.dims.len() < 2 {
64        return src.clone();
65    }
66
67    let src_w = src.dims[0].size;
68    let src_h = src.dims[1].size;
69    let (dst_w, dst_h) = if transform.swaps_dims() {
70        (src_h, src_w)
71    } else {
72        (src_w, src_h)
73    };
74
75    macro_rules! transform_buf {
76        ($vec:expr, $T:ty, $zero:expr) => {{
77            let mut out = vec![$zero; dst_w * dst_h];
78            for sy in 0..src_h {
79                for sx in 0..src_w {
80                    let (dx, dy) = map_coords(sx, sy, src_w, src_h, transform);
81                    out[dy * dst_w + dx] = $vec[sy * src_w + sx];
82                }
83            }
84            out
85        }};
86    }
87
88    let out_data = match &src.data {
89        NDDataBuffer::U8(v) => NDDataBuffer::U8(transform_buf!(v, u8, 0)),
90        NDDataBuffer::U16(v) => NDDataBuffer::U16(transform_buf!(v, u16, 0)),
91        NDDataBuffer::I8(v) => NDDataBuffer::I8(transform_buf!(v, i8, 0)),
92        NDDataBuffer::I16(v) => NDDataBuffer::I16(transform_buf!(v, i16, 0)),
93        NDDataBuffer::I32(v) => NDDataBuffer::I32(transform_buf!(v, i32, 0)),
94        NDDataBuffer::U32(v) => NDDataBuffer::U32(transform_buf!(v, u32, 0)),
95        NDDataBuffer::I64(v) => NDDataBuffer::I64(transform_buf!(v, i64, 0)),
96        NDDataBuffer::U64(v) => NDDataBuffer::U64(transform_buf!(v, u64, 0)),
97        NDDataBuffer::F32(v) => NDDataBuffer::F32(transform_buf!(v, f32, 0.0)),
98        NDDataBuffer::F64(v) => NDDataBuffer::F64(transform_buf!(v, f64, 0.0)),
99    };
100
101    let dims = vec![NDDimension::new(dst_w), NDDimension::new(dst_h)];
102    let mut arr = NDArray::new(dims, src.data.data_type());
103    arr.data = out_data;
104    arr.unique_id = src.unique_id;
105    arr.timestamp = src.timestamp;
106    arr.attributes = src.attributes.clone();
107    arr
108}
109
110// --- New TransformProcessor (NDPluginProcess-based) ---
111
112/// Pure transform processing logic.
113pub struct TransformProcessor {
114    transform: TransformType,
115}
116
117impl TransformProcessor {
118    pub fn new(transform: TransformType) -> Self {
119        Self { transform }
120    }
121}
122
123impl NDPluginProcess for TransformProcessor {
124    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
125        let out = apply_transform(array, self.transform);
126        ProcessResult::arrays(vec![Arc::new(out)])
127    }
128
129    fn plugin_type(&self) -> &str {
130        "NDPluginTransform"
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::NDDataType;

    /// Shared fixture: a UInt8 array, 3 wide and 2 high, stored row by row:
    /// [1, 2, 3]
    /// [4, 5, 6]
    fn make_3x2() -> NDArray {
        let shape = vec![NDDimension::new(3), NDDimension::new(2)];
        let mut fixture = NDArray::new(shape, NDDataType::UInt8);
        if let NDDataBuffer::U8(ref mut pixels) = fixture.data {
            *pixels = vec![1, 2, 3, 4, 5, 6];
        }
        fixture
    }

    /// Borrow the pixel data of a UInt8 array; panics for any other type.
    fn get_u8(arr: &NDArray) -> &[u8] {
        if let NDDataBuffer::U8(pixels) = &arr.data {
            return &pixels[..];
        }
        panic!("not u8")
    }

    /// Apply `kind` to the shared 3x2 fixture.
    fn transformed(kind: TransformType) -> NDArray {
        apply_transform(&make_3x2(), kind)
    }

    /// Check the first two extents of `arr`.
    fn assert_shape(arr: &NDArray, width: usize, height: usize) {
        assert_eq!(arr.dims[0].size, width);
        assert_eq!(arr.dims[1].size, height);
    }

    #[test]
    fn test_none() {
        let out = transformed(TransformType::None);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_rot90cw() {
        let out = transformed(TransformType::Rot90CW);
        assert_shape(&out, 2, 3);
        // Rows of the rotated image:
        // [4, 1]
        // [5, 2]
        // [6, 3]
        assert_eq!(get_u8(&out), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_rot180() {
        let out = transformed(TransformType::Rot180);
        assert_shape(&out, 3, 2);
        assert_eq!(get_u8(&out), &[6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_rot90ccw() {
        let out = transformed(TransformType::Rot90CCW);
        assert_shape(&out, 2, 3);
        // Rows of the rotated image:
        // [3, 6]
        // [2, 5]
        // [1, 4]
        assert_eq!(get_u8(&out), &[3, 6, 2, 5, 1, 4]);
    }

    #[test]
    fn test_flip_horiz() {
        let out = transformed(TransformType::FlipHoriz);
        assert_eq!(get_u8(&out), &[3, 2, 1, 6, 5, 4]);
    }

    #[test]
    fn test_flip_vert() {
        let out = transformed(TransformType::FlipVert);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_flip_diag() {
        let out = transformed(TransformType::FlipDiag);
        assert_shape(&out, 2, 3);
        // Transposed rows:
        // [1, 4]
        // [2, 5]
        // [3, 6]
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flip_anti_diag() {
        let out = transformed(TransformType::FlipAntiDiag);
        assert_shape(&out, 2, 3);
        // Anti-transposed rows:
        // [6, 3]
        // [5, 2]
        // [4, 1]
        assert_eq!(get_u8(&out), &[6, 3, 5, 2, 4, 1]);
    }

    /// Four successive quarter turns must reproduce the input exactly.
    #[test]
    fn test_rot90_roundtrip() {
        let original = make_3x2();
        let mut current = apply_transform(&original, TransformType::Rot90CW);
        for _ in 1..4 {
            current = apply_transform(&current, TransformType::Rot90CW);
        }
        assert_eq!(get_u8(&current), get_u8(&original));
        assert_eq!(current.dims[0].size, original.dims[0].size);
        assert_eq!(current.dims[1].size, original.dims[1].size);
    }

    // --- New TransformProcessor tests ---

    /// The processor must emit exactly one array carrying the rotated image.
    #[test]
    fn test_transform_processor() {
        let pool = NDArrayPool::new(1_000_000);
        let mut proc = TransformProcessor::new(TransformType::Rot90CW);

        let input = make_3x2();
        let result = proc.process_array(&input, &pool);
        assert_eq!(result.output_arrays.len(), 1);

        let produced = &result.output_arrays[0];
        assert_eq!(produced.dims[0].size, 2); // swapped
        assert_eq!(produced.dims[1].size, 3);
        assert_eq!(get_u8(produced), &[4, 1, 5, 2, 6, 3]);
    }
}