1use std::sync::Arc;
2
3use ad_core_rs::color::NDColorMode;
4use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDimension};
5use ad_core_rs::ndarray_pool::NDArrayPool;
6use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
7
/// Geometric transform applied to each image by [`apply_transform`].
///
/// The explicit discriminants are the integer codes decoded by
/// `TransformType::from_u8` (and therefore by the `TRANSFORM_TYPE` parameter).
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransformType {
    /// Leave the image untouched.
    None = 0,
    /// Rotate by 90 degrees clockwise.
    Rot90CW = 1,
    /// Rotate by 180 degrees.
    Rot180 = 2,
    /// Rotate by 90 degrees counter-clockwise.
    Rot90CCW = 3,
    /// Mirror left-to-right.
    FlipHoriz = 4,
    /// Transpose: exchange the x and y axes.
    FlipDiag = 5,
    /// Mirror top-to-bottom.
    FlipVert = 6,
    /// Transpose across the anti-diagonal.
    FlipAntiDiag = 7,
}
33
34impl TransformType {
35 pub fn from_u8(v: u8) -> Self {
36 match v {
37 1 => Self::Rot90CW,
38 2 => Self::Rot180,
39 3 => Self::Rot90CCW,
40 4 => Self::FlipHoriz,
41 5 => Self::FlipDiag,
43 6 => Self::FlipVert,
45 7 => Self::FlipAntiDiag,
46 _ => Self::None,
47 }
48 }
49
50 pub fn swaps_dims(&self) -> bool {
52 matches!(
53 self,
54 Self::Rot90CW | Self::Rot90CCW | Self::FlipDiag | Self::FlipAntiDiag
55 )
56 }
57}
58
59fn map_coords(
61 sx: usize,
62 sy: usize,
63 src_w: usize,
64 src_h: usize,
65 transform: TransformType,
66) -> (usize, usize) {
67 match transform {
68 TransformType::None => (sx, sy),
69 TransformType::Rot90CW => (src_h - 1 - sy, sx),
70 TransformType::Rot180 => (src_w - 1 - sx, src_h - 1 - sy),
71 TransformType::Rot90CCW => (sy, src_w - 1 - sx),
72 TransformType::FlipHoriz => (src_w - 1 - sx, sy),
73 TransformType::FlipVert => (sx, src_h - 1 - sy),
74 TransformType::FlipDiag => (sy, sx),
75 TransformType::FlipAntiDiag => (src_h - 1 - sy, src_w - 1 - sx),
76 }
77}
78
79fn strides_for(color_mode: NDColorMode, xs: usize, ys: usize, cs: usize) -> (usize, usize, usize) {
83 match color_mode {
84 NDColorMode::RGB1 => (cs, xs * cs, 1),
85 NDColorMode::RGB2 => (1, xs * cs, xs),
86 _ => (1, xs, xs * ys),
88 }
89}
90
91fn dims_for(
94 color_mode: NDColorMode,
95 xs: usize,
96 ys: usize,
97 cs: usize,
98 ndims: usize,
99) -> Vec<NDDimension> {
100 if ndims < 3 {
101 return vec![NDDimension::new(xs), NDDimension::new(ys)];
102 }
103 match color_mode {
104 NDColorMode::RGB1 => vec![
105 NDDimension::new(cs),
106 NDDimension::new(xs),
107 NDDimension::new(ys),
108 ],
109 NDColorMode::RGB2 => vec![
110 NDDimension::new(xs),
111 NDDimension::new(cs),
112 NDDimension::new(ys),
113 ],
114 _ => vec![
115 NDDimension::new(xs),
116 NDDimension::new(ys),
117 NDDimension::new(cs),
118 ],
119 }
120}
121
122pub fn apply_transform(src: &NDArray, transform: TransformType) -> NDArray {
129 if transform == TransformType::None || src.dims.len() < 2 {
130 return src.clone();
131 }
132
133 let info = src.info();
134 let src_w = info.x_size;
135 let src_h = info.y_size;
136 let color = info.color_size.max(1);
137 if src_w == 0 || src_h == 0 {
138 return src.clone();
139 }
140
141 let (dst_w, dst_h) = if transform.swaps_dims() {
142 (src_h, src_w)
143 } else {
144 (src_w, src_h)
145 };
146
147 let (sxs, sys, scs) = (
148 info.x_stride,
149 info.y_stride.max(1),
150 info.color_stride.max(1),
151 );
152 let (dxs, dys, dcs) = strides_for(info.color_mode, dst_w, dst_h, color);
153 let total = dst_w * dst_h * color;
154
155 macro_rules! transform_buf {
156 ($vec:expr, $zero:expr) => {{
157 let mut out = vec![$zero; total];
158 for sy in 0..src_h {
159 for sx in 0..src_w {
160 let (dx, dy) = map_coords(sx, sy, src_w, src_h, transform);
161 let s_base = sy * sys + sx * sxs;
162 let d_base = dy * dys + dx * dxs;
163 for c in 0..color {
164 out[d_base + c * dcs] = $vec[s_base + c * scs];
165 }
166 }
167 }
168 out
169 }};
170 }
171
172 let out_data = match &src.data {
173 NDDataBuffer::U8(v) => NDDataBuffer::U8(transform_buf!(v, 0)),
174 NDDataBuffer::U16(v) => NDDataBuffer::U16(transform_buf!(v, 0)),
175 NDDataBuffer::I8(v) => NDDataBuffer::I8(transform_buf!(v, 0)),
176 NDDataBuffer::I16(v) => NDDataBuffer::I16(transform_buf!(v, 0)),
177 NDDataBuffer::I32(v) => NDDataBuffer::I32(transform_buf!(v, 0)),
178 NDDataBuffer::U32(v) => NDDataBuffer::U32(transform_buf!(v, 0)),
179 NDDataBuffer::I64(v) => NDDataBuffer::I64(transform_buf!(v, 0)),
180 NDDataBuffer::U64(v) => NDDataBuffer::U64(transform_buf!(v, 0)),
181 NDDataBuffer::F32(v) => NDDataBuffer::F32(transform_buf!(v, 0.0)),
182 NDDataBuffer::F64(v) => NDDataBuffer::F64(transform_buf!(v, 0.0)),
183 };
184
185 let dims = dims_for(info.color_mode, dst_w, dst_h, color, src.dims.len());
186 let mut arr = NDArray::new(dims, src.data.data_type());
187 arr.data = out_data;
188 arr.unique_id = src.unique_id;
189 arr.timestamp = src.timestamp;
190 arr.time_stamp = src.time_stamp;
191 arr.attributes = src.attributes.clone();
192 arr
193}
194
195pub struct TransformProcessor {
199 transform: TransformType,
200 transform_type_idx: Option<usize>,
201}
202
203impl TransformProcessor {
204 pub fn new(transform: TransformType) -> Self {
205 Self {
206 transform,
207 transform_type_idx: None,
208 }
209 }
210}
211
212impl NDPluginProcess for TransformProcessor {
213 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
214 let out = apply_transform(array, self.transform);
215 ProcessResult::arrays(vec![Arc::new(out)])
216 }
217
218 fn plugin_type(&self) -> &str {
219 "NDPluginTransform"
220 }
221
222 fn register_params(
223 &mut self,
224 base: &mut asyn_rs::port::PortDriverBase,
225 ) -> asyn_rs::error::AsynResult<()> {
226 use asyn_rs::param::ParamType;
227 base.create_param("TRANSFORM_TYPE", ParamType::Int32)?;
228 self.transform_type_idx = base.find_param("TRANSFORM_TYPE");
229 Ok(())
230 }
231
232 fn on_param_change(
233 &mut self,
234 reason: usize,
235 params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
236 ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
237 if Some(reason) == self.transform_type_idx {
238 self.transform = TransformType::from_u8(params.value.as_i32() as u8);
239 }
240 ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::NDDataType;

    /// Builds a 3-wide, 2-high mono `u8` image whose rows are `1 2 3` and `4 5 6`.
    fn make_3x2() -> NDArray {
        let dims = vec![NDDimension::new(3), NDDimension::new(2)];
        let mut arr = NDArray::new(dims, NDDataType::UInt8);
        if let NDDataBuffer::U8(v) = &mut arr.data {
            *v = vec![1, 2, 3, 4, 5, 6];
        }
        arr
    }

    /// Returns the `u8` samples of `arr`, panicking for any other element type.
    fn get_u8(arr: &NDArray) -> &[u8] {
        if let NDDataBuffer::U8(v) = &arr.data {
            v.as_slice()
        } else {
            panic!("not u8")
        }
    }

    #[test]
    fn test_none() {
        let out = apply_transform(&make_3x2(), TransformType::None);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_rot90cw() {
        let out = apply_transform(&make_3x2(), TransformType::Rot90CW);
        assert_eq!((out.dims[0].size, out.dims[1].size), (2, 3));
        assert_eq!(get_u8(&out), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_rot180() {
        let out = apply_transform(&make_3x2(), TransformType::Rot180);
        assert_eq!((out.dims[0].size, out.dims[1].size), (3, 2));
        assert_eq!(get_u8(&out), &[6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_rot90ccw() {
        let out = apply_transform(&make_3x2(), TransformType::Rot90CCW);
        assert_eq!((out.dims[0].size, out.dims[1].size), (2, 3));
        assert_eq!(get_u8(&out), &[3, 6, 2, 5, 1, 4]);
    }

    #[test]
    fn test_flip_horiz() {
        let out = apply_transform(&make_3x2(), TransformType::FlipHoriz);
        assert_eq!(get_u8(&out), &[3, 2, 1, 6, 5, 4]);
    }

    #[test]
    fn test_flip_vert() {
        let out = apply_transform(&make_3x2(), TransformType::FlipVert);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_flip_diag() {
        let out = apply_transform(&make_3x2(), TransformType::FlipDiag);
        assert_eq!((out.dims[0].size, out.dims[1].size), (2, 3));
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flip_anti_diag() {
        let out = apply_transform(&make_3x2(), TransformType::FlipAntiDiag);
        assert_eq!((out.dims[0].size, out.dims[1].size), (2, 3));
        assert_eq!(get_u8(&out), &[6, 3, 5, 2, 4, 1]);
    }

    #[test]
    fn test_rot90_roundtrip() {
        let arr = make_3x2();
        // Four clockwise quarter turns must reproduce the input exactly.
        let r4 = (0..4).fold(arr.clone(), |acc, _| {
            apply_transform(&acc, TransformType::Rot90CW)
        });
        assert_eq!(get_u8(&r4), get_u8(&arr));
        assert_eq!(r4.dims[0].size, arr.dims[0].size);
        assert_eq!(r4.dims[1].size, arr.dims[1].size);
    }

    #[test]
    fn test_from_u8_cpp_enum_order() {
        let expected = [
            TransformType::None,
            TransformType::Rot90CW,
            TransformType::Rot180,
            TransformType::Rot90CCW,
            TransformType::FlipHoriz,
            TransformType::FlipDiag,
            TransformType::FlipVert,
            TransformType::FlipAntiDiag,
        ];
        for (code, want) in (0u8..).zip(expected) {
            assert_eq!(TransformType::from_u8(code), want);
        }
    }

    #[test]
    fn test_transform_5_is_transpose() {
        let out = apply_transform(&make_3x2(), TransformType::from_u8(5));
        assert_eq!((out.dims[0].size, out.dims[1].size), (2, 3));
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_transform_6_is_vertical_flip() {
        let out = apply_transform(&make_3x2(), TransformType::from_u8(6));
        assert_eq!((out.dims[0].size, out.dims[1].size), (3, 2));
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    /// Builds a 2x2 RGB1 (`[color, x, y]`) `u8` image tagged with a
    /// `ColorMode` attribute. The sample at `(x, y, c)` is `100 * y + 10 * x + c`.
    fn make_rgb1_2x2() -> NDArray {
        use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};

        let dims = vec![
            NDDimension::new(3),
            NDDimension::new(2),
            NDDimension::new(2),
        ];
        let mut arr = NDArray::new(dims, NDDataType::UInt8);
        arr.attributes.add(NDAttribute::new_static(
            "ColorMode",
            "",
            NDAttrSource::Driver,
            NDAttrValue::Int32(NDColorMode::RGB1 as i32),
        ));
        if let NDDataBuffer::U8(v) = &mut arr.data {
            for y in 0..2 {
                for x in 0..2 {
                    let pixel = y * 6 + x * 3;
                    for c in 0..3 {
                        v[pixel + c] = (100 * y + 10 * x + c) as u8;
                    }
                }
            }
        }
        arr
    }

    #[test]
    fn test_rgb1_flip_horiz_keeps_color_grouping() {
        let out = apply_transform(&make_rgb1_2x2(), TransformType::FlipHoriz);
        assert_eq!(
            (out.dims[0].size, out.dims[1].size, out.dims[2].size),
            (3, 2, 2)
        );
        let v = get_u8(&out);
        // Each output pixel keeps its three color samples together.
        assert_eq!(&v[0..3], &[10, 11, 12]);
        assert_eq!(&v[3..6], &[0, 1, 2]);
        assert_eq!(&v[6..9], &[110, 111, 112]);
    }

    #[test]
    fn test_rgb1_rot90cw_swaps_dims_and_keeps_color() {
        let out = apply_transform(&make_rgb1_2x2(), TransformType::Rot90CW);
        assert_eq!(
            (out.dims[0].size, out.dims[1].size, out.dims[2].size),
            (3, 2, 2)
        );
        let v = get_u8(&out);
        assert_eq!(&v[0..3], &[100, 101, 102]);
    }

    #[test]
    fn test_transform_processor() {
        let mut processor = TransformProcessor::new(TransformType::Rot90CW);
        let pool = NDArrayPool::new(1_000_000);

        let result = processor.process_array(&make_3x2(), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let out = &result.output_arrays[0];
        assert_eq!((out.dims[0].size, out.dims[1].size), (2, 3));
        assert_eq!(get_u8(out), &[4, 1, 5, 2, 6, 3]);
    }
}