cu_apriltag/
lib.rs

1#[cfg(unix)]
2use std::mem::ManuallyDrop;
3
4#[cfg(unix)]
5use apriltag::{Detector, DetectorBuilder, Family, Image, TagParams};
6
7#[cfg(unix)]
8use apriltag_sys::image_u8_t;
9
10use bincode::de::Decoder;
11use bincode::error::DecodeError;
12use cu29::bincode::{Decode, Encode};
13use cu29::prelude::*;
14use cu_sensor_payloads::CuImage;
15use cu_spatial_payloads::Pose as CuPose;
16use serde::ser::SerializeTuple;
17use serde::{Deserialize, Deserializer, Serialize};
18
// Capacity of each fixed-size array in `AprilTagDetections` (ids, poses, margins).
// This is a limit of the output storage, not of the detector: detections beyond this
// count in a single frame cannot be stored.
const MAX_DETECTIONS: usize = 16;

// Defaults used by `AprilTags::new` when a config key (or the whole config) is absent.
// NOTE(review): units are not stated in this file; TAG_SIZE is presumably the tag edge
// length in meters and FX/FY/CX/CY the camera intrinsics in pixels — confirm.
#[cfg(not(windows))]
const TAG_SIZE: f64 = 0.14;
#[cfg(not(windows))]
const FX: f64 = 2600.0;
#[cfg(not(windows))]
const FY: f64 = 2600.0;
#[cfg(not(windows))]
const CX: f64 = 900.0;
#[cfg(not(windows))]
const CY: f64 = 520.0;
// Default tag family name; parsed into an apriltag `Family` in `AprilTags::new`.
#[cfg(not(windows))]
const FAMILY: &str = "tag16h5";
35
/// AprilTag detections produced for one input image.
///
/// The three arrays are parallel: entry `i` of `ids`, `poses` and `decision_margins`
/// describes the same detected tag. Each array holds at most `MAX_DETECTIONS` entries.
///
/// Field order is the order the derived `Encode` writes them in, and must stay in sync
/// with the hand-written `Decode` impl (ids, then poses, then decision margins).
#[derive(Default, Debug, Clone, Encode)]
pub struct AprilTagDetections {
    /// Tag id reported by the detector, one per detection.
    pub ids: CuArrayVec<usize, MAX_DETECTIONS>,
    /// Estimated pose of each detected tag.
    pub poses: CuArrayVec<CuPose<f32>, MAX_DETECTIONS>,
    /// Detector decision margin of each detection; `filtered_by_decision_margin`
    /// compares these against a threshold.
    pub decision_margins: CuArrayVec<f32, MAX_DETECTIONS>,
}
42
43impl Decode<()> for AprilTagDetections {
44    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
45        let ids = CuArrayVec::<usize, MAX_DETECTIONS>::decode(decoder)?;
46        let poses = CuArrayVec::<CuPose<f32>, MAX_DETECTIONS>::decode(decoder)?;
47        let decision_margins = CuArrayVec::<f32, MAX_DETECTIONS>::decode(decoder)?;
48        Ok(AprilTagDetections {
49            ids,
50            poses,
51            decision_margins,
52        })
53    }
54}
55
56// implement serde support for AprilTagDetections
57// This is so it can be logged with debug!.
58impl Serialize for AprilTagDetections {
59    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
60        let CuArrayVec(ids) = &self.ids;
61        let CuArrayVec(poses) = &self.poses;
62        let CuArrayVec(decision_margins) = &self.decision_margins;
63        let mut tup = serializer.serialize_tuple(ids.len())?;
64
65        ids.iter()
66            .zip(poses.iter())
67            .zip(decision_margins.iter())
68            .map(|((id, pose), margin)| (id, pose, margin))
69            .for_each(|(id, pose, margin)| {
70                tup.serialize_element(&(id, pose, margin)).unwrap();
71            });
72
73        tup.end()
74    }
75}
76
77impl<'de> Deserialize<'de> for AprilTagDetections {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: Deserializer<'de>,
81    {
82        struct AprilTagDetectionsVisitor;
83
84        impl<'de> serde::de::Visitor<'de> for AprilTagDetectionsVisitor {
85            type Value = AprilTagDetections;
86
87            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
88                formatter.write_str("a tuple of (id, pose, decision_margin)")
89            }
90
91            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
92            where
93                A: serde::de::SeqAccess<'de>,
94            {
95                let mut detections = AprilTagDetections::new();
96                while let Some((id, pose, decision_margin)) = seq.next_element()? {
97                    let CuArrayVec(ids) = &mut detections.ids;
98                    ids.push(id);
99                    let CuArrayVec(poses) = &mut detections.poses;
100                    poses.push(pose);
101                    let CuArrayVec(decision_margins) = &mut detections.decision_margins;
102                    decision_margins.push(decision_margin);
103                }
104                Ok(detections)
105            }
106        }
107
108        deserializer.deserialize_tuple(MAX_DETECTIONS, AprilTagDetectionsVisitor)
109    }
110}
111
112impl AprilTagDetections {
113    fn new() -> Self {
114        Self::default()
115    }
116    pub fn filtered_by_decision_margin(
117        &self,
118        threshold: f32,
119    ) -> impl Iterator<Item = (usize, &CuPose<f32>, f32)> {
120        let CuArrayVec(ids) = &self.ids;
121        let CuArrayVec(poses) = &self.poses;
122        let CuArrayVec(decision_margins) = &self.decision_margins;
123
124        ids.iter()
125            .zip(poses.iter())
126            .zip(decision_margins.iter())
127            .filter_map(move |((id, pose), margin)| {
128                (*margin > threshold).then_some((*id, pose, *margin))
129            })
130    }
131}
132
/// Copper task that runs the apriltag detector on 8-bit images and estimates the pose
/// of each detected tag.
#[cfg(unix)]
pub struct AprilTags {
    // apriltag detector, configured with a tag family and bits-corrected setting.
    detector: Detector,
    // Camera intrinsics (fx, fy, cx, cy) and tag size used for pose estimation.
    tag_params: TagParams,
}

/// Stand-in for platforms where the apriltag backend is not compiled; holds no state.
#[cfg(not(unix))]
pub struct AprilTags {}
141
142#[cfg(not(windows))]
143fn image_from_cuimage<A>(cu_image: &CuImage<A>) -> ManuallyDrop<Image>
144where
145    A: ArrayLike<Element = u8>,
146{
147    unsafe {
148        // Try to emulate what the C code is doing on the heap to avoid double free
149        let buffer_ptr = cu_image.buffer_handle.with_inner(|inner| inner.as_ptr());
150        let low_level_img = Box::new(image_u8_t {
151            buf: buffer_ptr as *mut u8,
152            width: cu_image.format.width as i32,
153            height: cu_image.format.height as i32,
154            stride: cu_image.format.stride as i32,
155        });
156        let ptr = Box::into_raw(low_level_img);
157        ManuallyDrop::new(Image::from_raw(ptr))
158    }
159}
160
// No custom freeze/thaw logic: the empty impl relies on the `Freezable` trait's
// default methods.
impl Freezable for AprilTags {}
162
/// Windows stub: the apriltag backend is not compiled on this platform.
///
/// The task accepts configuration and images but performs no detection, and leaves
/// its output message untouched.
#[cfg(windows)]
impl CuTask for AprilTags {
    type Input<'m> = input_msg!(CuImage<Vec<u8>>);
    type Output<'m> = output_msg!(AprilTagDetections);

    fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
    where
        Self: Sized,
    {
        // Nothing to configure without a detector.
        let stub = AprilTags {};
        Ok(stub)
    }

    fn process(
        &mut self,
        _clock: &RobotClock,
        _input: &Self::Input<'_>,
        _output: &mut Self::Output<'_>,
    ) -> CuResult<()> {
        // Intentionally a no-op.
        Ok(())
    }
}
184
185#[cfg(not(windows))]
186impl CuTask for AprilTags {
187    type Input<'m> = input_msg!(CuImage<Vec<u8>>);
188    type Output<'m> = output_msg!(AprilTagDetections);
189
190    fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
191    where
192        Self: Sized,
193    {
194        if let Some(config) = _config {
195            let family_cfg: String = config.get("family").unwrap_or(FAMILY.to_string());
196            let family: Family = family_cfg.parse().unwrap();
197            let bits_corrected: u32 = config.get("bits_corrected").unwrap_or(1);
198            let tagsize = config.get("tag_size").unwrap_or(TAG_SIZE);
199            let fx = config.get("fx").unwrap_or(FX);
200            let fy = config.get("fy").unwrap_or(FY);
201            let cx = config.get("cx").unwrap_or(CX);
202            let cy = config.get("cy").unwrap_or(CY);
203            let tag_params = TagParams {
204                fx,
205                fy,
206                cx,
207                cy,
208                tagsize,
209            };
210
211            let detector = DetectorBuilder::default()
212                .add_family_bits(family, bits_corrected as usize)
213                .build()
214                .unwrap();
215            return Ok(Self {
216                detector,
217                tag_params,
218            });
219        }
220        Ok(Self {
221            detector: DetectorBuilder::default()
222                .add_family_bits(FAMILY.parse::<Family>().unwrap(), 1)
223                .build()
224                .unwrap(),
225            tag_params: TagParams {
226                fx: FX,
227                fy: FY,
228                cx: CX,
229                cy: CY,
230                tagsize: TAG_SIZE,
231            },
232        })
233    }
234
235    fn process(
236        &mut self,
237        _clock: &RobotClock,
238        input: &Self::Input<'_>,
239        output: &mut Self::Output<'_>,
240    ) -> CuResult<()> {
241        let mut result = AprilTagDetections::new();
242        if let Some(payload) = input.payload() {
243            let image = image_from_cuimage(payload);
244            let detections = self.detector.detect(&image);
245            for detection in detections {
246                if let Some(aprilpose) = detection.estimate_tag_pose(&self.tag_params) {
247                    let translation = aprilpose.translation();
248                    let rotation = aprilpose.rotation();
249                    let mut mat: [[f32; 4]; 4] = [[0.0, 0.0, 0.0, 0.0]; 4];
250                    mat[0][3] = translation.data()[0] as f32;
251                    mat[1][3] = translation.data()[1] as f32;
252                    mat[2][3] = translation.data()[2] as f32;
253                    mat[0][0] = rotation.data()[0] as f32;
254                    mat[0][1] = rotation.data()[3] as f32;
255                    mat[0][2] = rotation.data()[2 * 3] as f32;
256                    mat[1][0] = rotation.data()[1] as f32;
257                    mat[1][1] = rotation.data()[1 + 3] as f32;
258                    mat[1][2] = rotation.data()[1 + 2 * 3] as f32;
259                    mat[2][0] = rotation.data()[2] as f32;
260                    mat[2][1] = rotation.data()[2 + 3] as f32;
261                    mat[2][2] = rotation.data()[2 + 2 * 3] as f32;
262
263                    let pose = CuPose::<f32>::from_matrix(mat);
264                    let CuArrayVec(detections) = &mut result.poses;
265                    detections.push(pose);
266                    let CuArrayVec(decision_margin) = &mut result.decision_margins;
267                    decision_margin.push(detection.decision_margin());
268                    let CuArrayVec(ids) = &mut result.ids;
269                    ids.push(detection.id());
270                }
271            }
272        };
273        output.tov = input.tov;
274        output.set_payload(result);
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use anyhow::Context;
    use anyhow::Result;
    use image::{imageops::crop, imageops::resize, imageops::FilterType, Luma};
    use image::{ImageBuffer, ImageReader};

    #[cfg(not(windows))]
    use cu_sensor_payloads::CuImageBufferFormat;

    /// Loads `path` as 8-bit grayscale and center-crops it vertically to a 16:9 frame
    /// at the original width. The result is resampled to 1920x1080.
    #[allow(dead_code)]
    fn process_image(path: &str) -> Result<ImageBuffer<Luma<u8>, Vec<u8>>> {
        let mut gray = ImageReader::open(path)
            .with_context(|| "Failed to open image")?
            .decode()
            .context("Failed to decode image")?
            .into_luma8();
        let (width, height) = gray.dimensions();

        // Height giving a 16:9 aspect ratio at the original width.
        let target_height = (width as f32 * 9.0 / 16.0) as u32;
        let top = (height - target_height) / 2;

        let cropped = crop(&mut gray, 0, top, width, target_height).to_image();
        Ok(resize(&cropped, 1920, 1080, FilterType::Lanczos3))
    }

    /// Runs the whole task on a fixture image and expects exactly one confident tag
    /// (id 4).
    #[test]
    #[cfg(not(windows))]
    fn test_end2end_apriltag() -> Result<()> {
        let img = process_image("tests/data/simple.jpg")?;
        let (width, height) = img.dimensions();
        let format = CuImageBufferFormat {
            width,
            height,
            stride: width,
            pixel_format: "GRAY".as_bytes().try_into()?,
        };
        let cuimage = CuImage::new(format, CuHandle::new_detached(img.into_raw()));

        let mut config = ComponentConfig::default();
        config.set("tag_size", 0.14);
        config.set("fx", 2600.0);
        config.set("fy", 2600.0);
        config.set("cx", 900.0);
        config.set("cy", 520.0);
        config.set("family", "tag16h5".to_string());

        let mut task = AprilTags::new(Some(&config))?;
        let input = CuMsg::<CuImage<Vec<u8>>>::new(Some(cuimage));
        let mut output = CuMsg::<AprilTagDetections>::default();

        let clock = RobotClock::new();
        assert!(task.process(&clock, &input, &mut output).is_ok());

        let Some(detections) = output.payload() else {
            return Err(anyhow::anyhow!("No output"));
        };
        let confident: Vec<_> = detections.filtered_by_decision_margin(150.0).collect();
        assert_eq!(confident.len(), 1);
        assert_eq!(confident[0].0, 4);
        Ok(())
    }
}