1use std::sync::Arc;
2
3use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDimension};
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Geometric transform applied to the first two axes (x = `dims[0]`,
/// y = `dims[1]`) of an image.
///
/// The explicit discriminants are the raw values decoded by
/// `TransformType::from_u8`; directions below assume row 0 is the top row.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransformType {
    /// Leave the image untouched.
    None = 0,
    /// Rotate 90 degrees clockwise; width and height are swapped.
    Rot90CW = 1,
    /// Rotate 180 degrees; dimensions are unchanged.
    Rot180 = 2,
    /// Rotate 90 degrees counter-clockwise; width and height are swapped.
    Rot90CCW = 3,
    /// Mirror left/right (reverse every row).
    FlipHoriz = 4,
    /// Mirror top/bottom (reverse the row order).
    FlipVert = 5,
    /// Transpose across the main diagonal (swap x and y).
    FlipDiag = 6,
    /// Transpose across the anti-diagonal.
    FlipAntiDiag = 7,
}
20
21impl TransformType {
22 pub fn from_u8(v: u8) -> Self {
23 match v {
24 1 => Self::Rot90CW,
25 2 => Self::Rot180,
26 3 => Self::Rot90CCW,
27 4 => Self::FlipHoriz,
28 5 => Self::FlipVert,
29 6 => Self::FlipDiag,
30 7 => Self::FlipAntiDiag,
31 _ => Self::None,
32 }
33 }
34
35 pub fn swaps_dims(&self) -> bool {
37 matches!(
38 self,
39 Self::Rot90CW | Self::Rot90CCW | Self::FlipDiag | Self::FlipAntiDiag
40 )
41 }
42}
43
44fn map_coords(
46 sx: usize,
47 sy: usize,
48 src_w: usize,
49 src_h: usize,
50 transform: TransformType,
51) -> (usize, usize) {
52 match transform {
53 TransformType::None => (sx, sy),
54 TransformType::Rot90CW => (src_h - 1 - sy, sx),
55 TransformType::Rot180 => (src_w - 1 - sx, src_h - 1 - sy),
56 TransformType::Rot90CCW => (sy, src_w - 1 - sx),
57 TransformType::FlipHoriz => (src_w - 1 - sx, sy),
58 TransformType::FlipVert => (sx, src_h - 1 - sy),
59 TransformType::FlipDiag => (sy, sx),
60 TransformType::FlipAntiDiag => (src_h - 1 - sy, src_w - 1 - sx),
61 }
62}
63
64pub fn apply_transform(src: &NDArray, transform: TransformType) -> NDArray {
66 if transform == TransformType::None || src.dims.len() < 2 {
67 return src.clone();
68 }
69
70 let src_w = src.dims[0].size;
71 let src_h = src.dims[1].size;
72 let (dst_w, dst_h) = if transform.swaps_dims() {
73 (src_h, src_w)
74 } else {
75 (src_w, src_h)
76 };
77
78 macro_rules! transform_buf {
79 ($vec:expr, $T:ty, $zero:expr) => {{
80 let mut out = vec![$zero; dst_w * dst_h];
81 for sy in 0..src_h {
82 for sx in 0..src_w {
83 let (dx, dy) = map_coords(sx, sy, src_w, src_h, transform);
84 out[dy * dst_w + dx] = $vec[sy * src_w + sx];
85 }
86 }
87 out
88 }};
89 }
90
91 let out_data = match &src.data {
92 NDDataBuffer::U8(v) => NDDataBuffer::U8(transform_buf!(v, u8, 0)),
93 NDDataBuffer::U16(v) => NDDataBuffer::U16(transform_buf!(v, u16, 0)),
94 NDDataBuffer::I8(v) => NDDataBuffer::I8(transform_buf!(v, i8, 0)),
95 NDDataBuffer::I16(v) => NDDataBuffer::I16(transform_buf!(v, i16, 0)),
96 NDDataBuffer::I32(v) => NDDataBuffer::I32(transform_buf!(v, i32, 0)),
97 NDDataBuffer::U32(v) => NDDataBuffer::U32(transform_buf!(v, u32, 0)),
98 NDDataBuffer::I64(v) => NDDataBuffer::I64(transform_buf!(v, i64, 0)),
99 NDDataBuffer::U64(v) => NDDataBuffer::U64(transform_buf!(v, u64, 0)),
100 NDDataBuffer::F32(v) => NDDataBuffer::F32(transform_buf!(v, f32, 0.0)),
101 NDDataBuffer::F64(v) => NDDataBuffer::F64(transform_buf!(v, f64, 0.0)),
102 };
103
104 let dims = vec![NDDimension::new(dst_w), NDDimension::new(dst_h)];
105 let mut arr = NDArray::new(dims, src.data.data_type());
106 arr.data = out_data;
107 arr.unique_id = src.unique_id;
108 arr.timestamp = src.timestamp;
109 arr.attributes = src.attributes.clone();
110 arr
111}
112
113pub struct TransformProcessor {
117 transform: TransformType,
118 transform_type_idx: Option<usize>,
119}
120
121impl TransformProcessor {
122 pub fn new(transform: TransformType) -> Self {
123 Self {
124 transform,
125 transform_type_idx: None,
126 }
127 }
128}
129
130impl NDPluginProcess for TransformProcessor {
131 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
132 let out = apply_transform(array, self.transform);
133 ProcessResult::arrays(vec![Arc::new(out)])
134 }
135
136 fn plugin_type(&self) -> &str {
137 "NDPluginTransform"
138 }
139
140 fn register_params(
141 &mut self,
142 base: &mut asyn_rs::port::PortDriverBase,
143 ) -> asyn_rs::error::AsynResult<()> {
144 use asyn_rs::param::ParamType;
145 base.create_param("TRANSFORM_TYPE", ParamType::Int32)?;
146 self.transform_type_idx = base.find_param("TRANSFORM_TYPE");
147 Ok(())
148 }
149
150 fn on_param_change(
151 &mut self,
152 reason: usize,
153 params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
154 ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
155 if Some(reason) == self.transform_type_idx {
156 self.transform = TransformType::from_u8(params.value.as_i32() as u8);
157 }
158 ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::NDDataType;

    /// Builds a 3-wide, 2-high `UInt8` image whose rows are `[1, 2, 3]` and `[4, 5, 6]`.
    fn make_3x2() -> NDArray {
        let mut arr = NDArray::new(
            vec![NDDimension::new(3), NDDimension::new(2)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(ref mut v) = arr.data {
            *v = vec![1, 2, 3, 4, 5, 6];
        }
        arr
    }

    /// Borrows the `U8` payload of `arr`, panicking for any other element type.
    fn get_u8(arr: &NDArray) -> &[u8] {
        match &arr.data {
            NDDataBuffer::U8(v) => v,
            _ => panic!("not u8"),
        }
    }

    #[test]
    fn test_none() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::None);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_rot90cw() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::Rot90CW);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(get_u8(&out), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_rot180() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::Rot180);
        assert_eq!(out.dims[0].size, 3);
        assert_eq!(out.dims[1].size, 2);
        assert_eq!(get_u8(&out), &[6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_rot90ccw() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::Rot90CCW);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(get_u8(&out), &[3, 6, 2, 5, 1, 4]);
    }

    #[test]
    fn test_flip_horiz() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipHoriz);
        assert_eq!(get_u8(&out), &[3, 2, 1, 6, 5, 4]);
    }

    #[test]
    fn test_flip_vert() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipVert);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_flip_diag() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipDiag);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flip_anti_diag() {
        let arr = make_3x2();
        let out = apply_transform(&arr, TransformType::FlipAntiDiag);
        assert_eq!(out.dims[0].size, 2);
        assert_eq!(out.dims[1].size, 3);
        assert_eq!(get_u8(&out), &[6, 3, 5, 2, 4, 1]);
    }

    #[test]
    fn test_rot90_roundtrip() {
        let arr = make_3x2();
        let r1 = apply_transform(&arr, TransformType::Rot90CW);
        let r2 = apply_transform(&r1, TransformType::Rot90CW);
        let r3 = apply_transform(&r2, TransformType::Rot90CW);
        let r4 = apply_transform(&r3, TransformType::Rot90CW);
        assert_eq!(get_u8(&r4), get_u8(&arr));
        assert_eq!(r4.dims[0].size, arr.dims[0].size);
        assert_eq!(r4.dims[1].size, arr.dims[1].size);
    }

    #[test]
    fn test_transform_processor() {
        let mut proc = TransformProcessor::new(TransformType::Rot90CW);
        let pool = NDArrayPool::new(1_000_000);

        let arr = make_3x2();
        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert_eq!(result.output_arrays[0].dims[0].size, 2);
        assert_eq!(result.output_arrays[0].dims[1].size, 3);
        assert_eq!(get_u8(&result.output_arrays[0]), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_from_u8() {
        // Every valid raw value decodes to the variant with that discriminant.
        for v in 0..=7u8 {
            assert_eq!(TransformType::from_u8(v) as u8, v);
        }
        // Anything else falls back to the identity transform.
        assert_eq!(TransformType::from_u8(8), TransformType::None);
        assert_eq!(TransformType::from_u8(u8::MAX), TransformType::None);
    }

    #[test]
    fn test_swaps_dims() {
        assert!(!TransformType::None.swaps_dims());
        assert!(TransformType::Rot90CW.swaps_dims());
        assert!(!TransformType::Rot180.swaps_dims());
        assert!(TransformType::Rot90CCW.swaps_dims());
        assert!(!TransformType::FlipHoriz.swaps_dims());
        assert!(!TransformType::FlipVert.swaps_dims());
        assert!(TransformType::FlipDiag.swaps_dims());
        assert!(TransformType::FlipAntiDiag.swaps_dims());
    }

    #[test]
    fn test_1d_passthrough() {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        if let NDDataBuffer::U8(ref mut v) = arr.data {
            *v = vec![1, 2, 3, 4];
        }
        let out = apply_transform(&arr, TransformType::Rot90CW);
        assert_eq!(out.dims.len(), 1);
        assert_eq!(out.dims[0].size, 4);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4]);
    }
}