1use std::sync::Arc;
2
3use ad_core::ndarray::{NDArray, NDDataBuffer, NDDimension};
4use ad_core::ndarray_pool::NDArrayPool;
5use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Geometric transform applied to the x/y plane of an image.
///
/// The explicit discriminants are the `u8` codes understood by
/// [`TransformType::from_u8`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransformType {
    /// Identity: every pixel stays where it is.
    None = 0,
    /// Rotate by 90 degrees clockwise (row 0 treated as the top row).
    Rot90CW = 1,
    /// Rotate by 180 degrees.
    Rot180 = 2,
    /// Rotate by 90 degrees counter-clockwise (row 0 treated as the top row).
    Rot90CCW = 3,
    /// Mirror the x axis (left <-> right).
    FlipHoriz = 4,
    /// Mirror the y axis (top <-> bottom).
    FlipVert = 5,
    /// Transpose across the main diagonal (x and y swap).
    FlipDiag = 6,
    /// Transpose across the anti-diagonal.
    FlipAntiDiag = 7,
}
20
21impl TransformType {
22 pub fn from_u8(v: u8) -> Self {
23 match v {
24 1 => Self::Rot90CW,
25 2 => Self::Rot180,
26 3 => Self::Rot90CCW,
27 4 => Self::FlipHoriz,
28 5 => Self::FlipVert,
29 6 => Self::FlipDiag,
30 7 => Self::FlipAntiDiag,
31 _ => Self::None,
32 }
33 }
34
35 pub fn swaps_dims(&self) -> bool {
37 matches!(self, Self::Rot90CW | Self::Rot90CCW | Self::FlipDiag | Self::FlipAntiDiag)
38 }
39}
40
41fn map_coords(
43 sx: usize,
44 sy: usize,
45 src_w: usize,
46 src_h: usize,
47 transform: TransformType,
48) -> (usize, usize) {
49 match transform {
50 TransformType::None => (sx, sy),
51 TransformType::Rot90CW => (src_h - 1 - sy, sx),
52 TransformType::Rot180 => (src_w - 1 - sx, src_h - 1 - sy),
53 TransformType::Rot90CCW => (sy, src_w - 1 - sx),
54 TransformType::FlipHoriz => (src_w - 1 - sx, sy),
55 TransformType::FlipVert => (sx, src_h - 1 - sy),
56 TransformType::FlipDiag => (sy, sx),
57 TransformType::FlipAntiDiag => (src_h - 1 - sy, src_w - 1 - sx),
58 }
59}
60
61pub fn apply_transform(src: &NDArray, transform: TransformType) -> NDArray {
63 if transform == TransformType::None || src.dims.len() < 2 {
64 return src.clone();
65 }
66
67 let src_w = src.dims[0].size;
68 let src_h = src.dims[1].size;
69 let (dst_w, dst_h) = if transform.swaps_dims() {
70 (src_h, src_w)
71 } else {
72 (src_w, src_h)
73 };
74
75 macro_rules! transform_buf {
76 ($vec:expr, $T:ty, $zero:expr) => {{
77 let mut out = vec![$zero; dst_w * dst_h];
78 for sy in 0..src_h {
79 for sx in 0..src_w {
80 let (dx, dy) = map_coords(sx, sy, src_w, src_h, transform);
81 out[dy * dst_w + dx] = $vec[sy * src_w + sx];
82 }
83 }
84 out
85 }};
86 }
87
88 let out_data = match &src.data {
89 NDDataBuffer::U8(v) => NDDataBuffer::U8(transform_buf!(v, u8, 0)),
90 NDDataBuffer::U16(v) => NDDataBuffer::U16(transform_buf!(v, u16, 0)),
91 NDDataBuffer::I8(v) => NDDataBuffer::I8(transform_buf!(v, i8, 0)),
92 NDDataBuffer::I16(v) => NDDataBuffer::I16(transform_buf!(v, i16, 0)),
93 NDDataBuffer::I32(v) => NDDataBuffer::I32(transform_buf!(v, i32, 0)),
94 NDDataBuffer::U32(v) => NDDataBuffer::U32(transform_buf!(v, u32, 0)),
95 NDDataBuffer::I64(v) => NDDataBuffer::I64(transform_buf!(v, i64, 0)),
96 NDDataBuffer::U64(v) => NDDataBuffer::U64(transform_buf!(v, u64, 0)),
97 NDDataBuffer::F32(v) => NDDataBuffer::F32(transform_buf!(v, f32, 0.0)),
98 NDDataBuffer::F64(v) => NDDataBuffer::F64(transform_buf!(v, f64, 0.0)),
99 };
100
101 let dims = vec![NDDimension::new(dst_w), NDDimension::new(dst_h)];
102 let mut arr = NDArray::new(dims, src.data.data_type());
103 arr.data = out_data;
104 arr.unique_id = src.unique_id;
105 arr.timestamp = src.timestamp;
106 arr.attributes = src.attributes.clone();
107 arr
108}
109
110pub struct TransformProcessor {
114 transform: TransformType,
115}
116
117impl TransformProcessor {
118 pub fn new(transform: TransformType) -> Self {
119 Self { transform }
120 }
121}
122
123impl NDPluginProcess for TransformProcessor {
124 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
125 let out = apply_transform(array, self.transform);
126 ProcessResult::arrays(vec![Arc::new(out)])
127 }
128
129 fn plugin_type(&self) -> &str {
130 "NDPluginTransform"
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::NDDataType;

    /// Builds a 3-wide, 2-high UInt8 fixture holding 1..=6 row by row:
    ///
    /// ```text
    /// 1 2 3
    /// 4 5 6
    /// ```
    fn make_3x2() -> NDArray {
        let mut arr = NDArray::new(
            vec![NDDimension::new(3), NDDimension::new(2)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(v) = &mut arr.data {
            *v = (1..=6).collect();
        }
        arr
    }

    /// Borrows the UInt8 payload of `arr`; panics for any other element type.
    fn get_u8(arr: &NDArray) -> &[u8] {
        if let NDDataBuffer::U8(v) = &arr.data {
            v
        } else {
            panic!("not u8")
        }
    }

    /// Sizes of the first two dimensions as `(width, height)`.
    fn size_2d(arr: &NDArray) -> (usize, usize) {
        (arr.dims[0].size, arr.dims[1].size)
    }

    /// Applies `transform` to a fresh 3x2 fixture.
    fn transformed(transform: TransformType) -> NDArray {
        apply_transform(&make_3x2(), transform)
    }

    #[test]
    fn test_none() {
        let out = transformed(TransformType::None);
        assert_eq!(get_u8(&out), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_rot90cw() {
        let out = transformed(TransformType::Rot90CW);
        assert_eq!(size_2d(&out), (2, 3));
        assert_eq!(get_u8(&out), &[4, 1, 5, 2, 6, 3]);
    }

    #[test]
    fn test_rot180() {
        let out = transformed(TransformType::Rot180);
        assert_eq!(size_2d(&out), (3, 2));
        assert_eq!(get_u8(&out), &[6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_rot90ccw() {
        let out = transformed(TransformType::Rot90CCW);
        assert_eq!(size_2d(&out), (2, 3));
        assert_eq!(get_u8(&out), &[3, 6, 2, 5, 1, 4]);
    }

    #[test]
    fn test_flip_horiz() {
        let out = transformed(TransformType::FlipHoriz);
        assert_eq!(get_u8(&out), &[3, 2, 1, 6, 5, 4]);
    }

    #[test]
    fn test_flip_vert() {
        let out = transformed(TransformType::FlipVert);
        assert_eq!(get_u8(&out), &[4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_flip_diag() {
        let out = transformed(TransformType::FlipDiag);
        assert_eq!(size_2d(&out), (2, 3));
        assert_eq!(get_u8(&out), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flip_anti_diag() {
        let out = transformed(TransformType::FlipAntiDiag);
        assert_eq!(size_2d(&out), (2, 3));
        assert_eq!(get_u8(&out), &[6, 3, 5, 2, 4, 1]);
    }

    #[test]
    fn test_rot90_roundtrip() {
        let arr = make_3x2();
        // Four quarter turns clockwise must reproduce the original image.
        let r4 = (0..4).fold(arr.clone(), |acc, _| {
            apply_transform(&acc, TransformType::Rot90CW)
        });
        assert_eq!(get_u8(&r4), get_u8(&arr));
        assert_eq!(size_2d(&r4), size_2d(&arr));
    }

    #[test]
    fn test_transform_processor() {
        let mut proc = TransformProcessor::new(TransformType::Rot90CW);
        let pool = NDArrayPool::new(1_000_000);

        let result = proc.process_array(&make_3x2(), &pool);
        assert_eq!(result.output_arrays.len(), 1);

        let out = &result.output_arrays[0];
        assert_eq!(size_2d(out), (2, 3));
        assert_eq!(get_u8(out), &[4, 1, 5, 2, 6, 3]);
    }
}