kornia_apriltag/
lib.rs

1#![deny(missing_docs)]
2//! # Kornia AprilTag
3
4use std::collections::HashMap;
5
6use kornia_image::{
7    allocator::{CpuAllocator, ImageAllocator},
8    Image, ImageSize,
9};
10use kornia_imgproc::resize::resize_fast_mono;
11
12use crate::{
13    decoder::{decode_tags, Detection, GrayModelPair},
14    errors::AprilTagError,
15    family::{TagFamily, TagFamilyKind},
16    quad::{fit_quads, FitQuadConfig},
17    segmentation::{find_connected_components, find_gradient_clusters, GradientInfo},
18    threshold::{adaptive_threshold, TileMinMax},
19    union_find::UnionFind,
20    utils::Pixel,
21};
22
23/// Error types for AprilTag detection.
24pub mod errors;
25
26/// Utility functions for AprilTag detection.
27pub mod utils;
28
29/// Thresholding utilities for AprilTag detection.
30pub mod threshold;
31
/// Image iteration utilities.
33pub(crate) mod iter;
34
35/// Segmentation utilities for AprilTag detection.
36pub mod segmentation;
37
38/// Union-find utilities for AprilTag detection.
39pub mod union_find;
40
41/// AprilTag family definitions and utilities.
42pub mod family;
43
44/// Quad detection utilities for AprilTag detection.
45pub mod quad;
46
47/// Decoding utilities for AprilTag detection.
48pub mod decoder;
49
/// Configuration for decoding AprilTags.
///
/// [`DecodeTagsConfig::new`] and [`DecodeTagsConfig::all`] derive `normal_border`,
/// `reversed_border` and `min_tag_width` from the selected tag families and fill the
/// remaining fields with default values. All fields are public and can be adjusted
/// after construction.
#[derive(Debug, Clone, PartialEq)]
pub struct DecodeTagsConfig {
    /// List of tag families to detect.
    pub tag_families: Vec<TagFamily>,
    /// Configuration for quad fitting.
    pub fit_quad_config: FitQuadConfig,
    /// Whether to enable edge refinement before decoding.
    ///
    /// [`DecodeTagsConfig::new`] sets this to `true`.
    pub refine_edges_enabled: bool,
    /// Sharpening factor applied during decoding.
    ///
    /// [`DecodeTagsConfig::new`] sets this to `0.25`.
    pub decode_sharpening: f32,
    /// Whether normal border tags are present.
    ///
    /// Set when at least one configured family has `reversed_border == false`.
    pub normal_border: bool,
    /// Whether reversed border tags are present.
    ///
    /// Set when at least one configured family has `reversed_border == true`.
    pub reversed_border: bool,
    /// Minimum tag width at border among all families.
    ///
    /// [`DecodeTagsConfig::new`] takes the smallest `width_at_border` of the families,
    /// divides it by its default downscale factor and never stores less than `3`.
    pub min_tag_width: usize,
    /// Minimum difference between white and black pixels for thresholding.
    ///
    /// [`DecodeTagsConfig::new`] sets this to `5`.
    pub min_white_black_difference: u8,
    /// Downscale factor for input images.
    ///
    /// [`AprilTagDecoder::new`] allocates no downscale buffer when this is `<= 1`; the
    /// source image is then thresholded directly. [`DecodeTagsConfig::new`] sets this
    /// to `2`.
    pub downscale_factor: usize,
}
72
73impl DecodeTagsConfig {
74    /// Creates a new `DecodeTagsConfig` with the given tag family kinds.
75    pub fn new(tag_family_kinds: Vec<TagFamilyKind>) -> Self {
76        const DEFAULT_DOWNSCALE_FACTOR: usize = 2;
77
78        let mut tag_families = Vec::with_capacity(tag_family_kinds.len());
79        let mut normal_border = false;
80        let mut reversed_border = false;
81        let mut min_tag_width = usize::MAX;
82
83        tag_family_kinds.iter().for_each(|family_kind| {
84            let family: TagFamily = family_kind.into();
85            if family.width_at_border < min_tag_width {
86                min_tag_width = family.width_at_border;
87            }
88            normal_border |= !family.reversed_border;
89            reversed_border |= family.reversed_border;
90
91            tag_families.push(family);
92        });
93
94        min_tag_width /= DEFAULT_DOWNSCALE_FACTOR;
95
96        if min_tag_width < 3 {
97            min_tag_width = 3;
98        }
99
100        Self {
101            tag_families,
102            fit_quad_config: Default::default(),
103            normal_border,
104            refine_edges_enabled: true,
105            decode_sharpening: 0.25,
106            reversed_border,
107            min_tag_width,
108            min_white_black_difference: 5,
109            downscale_factor: DEFAULT_DOWNSCALE_FACTOR,
110        }
111    }
112
113    /// Creates a `DecodeTagsConfig` with all supported tag families.
114    pub fn all() -> Self {
115        Self::new(TagFamilyKind::all())
116    }
117
118    /// Adds a tag family to the configuration.
119    pub fn add(&mut self, family: TagFamily) {
120        if family.width_at_border < self.min_tag_width {
121            self.min_tag_width = family.width_at_border;
122        }
123        self.normal_border |= !family.reversed_border;
124        self.reversed_border |= family.reversed_border;
125
126        self.tag_families.push(family);
127    }
128}
129
/// Decoder for AprilTag detection and decoding.
///
/// Besides the configuration, the decoder owns the working buffers that
/// [`AprilTagDecoder::decode`] uses. They are allocated once in
/// [`AprilTagDecoder::new`] for a fixed image size.
pub struct AprilTagDecoder {
    /// Decoding configuration; exposed through [`AprilTagDecoder::config`].
    config: DecodeTagsConfig,
    /// Destination of the resize step in `decode`; `None` when the configured
    /// downscale factor is `<= 1`.
    downscale_img: Option<Image<u8, 1, CpuAllocator>>,
    /// Output of the adaptive threshold step, read by the segmentation and quad
    /// fitting steps.
    bin_img: Image<Pixel, 1, CpuAllocator>,
    /// Buffer handed to `adaptive_threshold`; created with a tile size argument of `4`.
    tile_min_max: TileMinMax,
    /// Union-find structure with one element per pixel of the working image.
    uf: UnionFind,
    /// Clusters written by `find_gradient_clusters` and consumed by `fit_quads`.
    clusters: HashMap<(usize, usize), Vec<GradientInfo>>,
    /// State handed to `decode_tags`; reset by [`AprilTagDecoder::clear`].
    gray_model_pair: GrayModelPair,
}
140
141impl AprilTagDecoder {
142    /// Returns a reference to the decoder configuration.
143    #[inline]
144    pub fn config(&self) -> &DecodeTagsConfig {
145        &self.config
146    }
147
148    /// Adds a tag family to the decoder configuration.
149    #[inline]
150    pub fn add(&mut self, family: TagFamily) {
151        self.config.add(family);
152    }
153
154    /// Creates a new `AprilTagDecoder` with the given configuration and image size.
155    ///
156    /// # Arguments
157    ///
158    /// * `config` - The configuration for decoding AprilTags.
159    /// * `img_size` - The size of the image to be processed.
160    ///
161    /// # Returns
162    ///
163    /// Returns a `Result` containing the new `AprilTagDecoder` or an `AprilTagError`.
164    pub fn new(config: DecodeTagsConfig, img_size: ImageSize) -> Result<Self, AprilTagError> {
165        let (img_size, downscale_img) = if config.downscale_factor <= 1 {
166            (img_size, None)
167        } else {
168            let new_size = ImageSize {
169                width: img_size.width / config.downscale_factor,
170                height: img_size.height / config.downscale_factor,
171            };
172
173            (
174                new_size,
175                Some(Image::from_size_val(new_size, 0, CpuAllocator)?),
176            )
177        };
178
179        let bin_img = Image::from_size_val(img_size, Pixel::Skip, CpuAllocator)?;
180        let tile_min_max = TileMinMax::new(img_size, 4);
181        let uf = UnionFind::new(img_size.width * img_size.height);
182
183        Ok(Self {
184            config,
185            downscale_img,
186            bin_img,
187            tile_min_max,
188            uf,
189            clusters: HashMap::new(),
190            gray_model_pair: GrayModelPair::new(),
191        })
192    }
193
194    /// Decodes AprilTags from the provided grayscale image.
195    ///
196    /// # Arguments
197    ///
198    /// * `src` - The source grayscale image to decode tags from.
199    ///
200    /// # Returns
201    ///
202    /// Returns a `Result` containing a vector of `Detection` or an `AprilTagError`.
203    ///
204    /// # Note
205    ///
206    /// If you are running this method multiple times on the same decoder instance,
207    /// you should call [`AprilTagDecoder::clear`] between runs to reset internal state.
208    pub fn decode<A: ImageAllocator>(
209        &mut self,
210        src: &Image<u8, 1, A>,
211    ) -> Result<Vec<Detection>, AprilTagError> {
212        if let Some(downscale_img) = self.downscale_img.as_mut() {
213            resize_fast_mono(
214                src,
215                downscale_img,
216                kornia_imgproc::interpolation::InterpolationMode::Nearest,
217            )?;
218
219            // Step 1: Adaptive Threshold
220            adaptive_threshold(
221                downscale_img,
222                &mut self.bin_img,
223                &mut self.tile_min_max,
224                self.config.min_white_black_difference,
225            )?;
226        } else {
227            // Step 1: Adaptive Threshold
228            adaptive_threshold(
229                src,
230                &mut self.bin_img,
231                &mut self.tile_min_max,
232                self.config.min_white_black_difference,
233            )?;
234        }
235
236        // Step 2(a): Find Connected Components
237        find_connected_components(&self.bin_img, &mut self.uf)?;
238
239        // Step 2(b): Find Clusters
240        find_gradient_clusters(&self.bin_img, &mut self.uf, &mut self.clusters);
241
242        // Step 3: Quad Fitting
243        let mut quads = fit_quads(&self.bin_img, &mut self.clusters, &self.config);
244
245        // Step 4: Tag Decoding
246        Ok(decode_tags(
247            src,
248            &mut quads,
249            &mut self.config,
250            &mut self.gray_model_pair,
251        ))
252    }
253
254    /// Clears the internal state of the decoder for reuse.
255    pub fn clear(&mut self) {
256        self.uf.reset();
257        self.clusters.clear();
258        self.gray_model_pair.reset();
259    }
260
261    /// Returns a slice of tag families configured for detection.
262    pub fn tag_families(&self) -> &[TagFamily] {
263        &self.config.tag_families
264    }
265}
266
/// These tests are compiled out on aarch64 because running them there crashes the CI.
268#[cfg(all(test, not(target_arch = "aarch64")))]
269mod tests {
270    use kornia_io::png::read_image_png_mono8;
271
272    use crate::{family::TagFamilyKind, utils::Point2d, AprilTagDecoder, DecodeTagsConfig};
273
274    fn test_tags(
275        decoder: &mut AprilTagDecoder,
276        expected_tag: TagFamilyKind,
277        expected_quads: [Point2d<f32>; 4],
278        images_dir: &str,
279        file_name_starts_with: &str,
280    ) -> Result<(), Box<dyn std::error::Error>> {
281        let tag_images = std::fs::read_dir(images_dir)?;
282
283        for img in tag_images {
284            let img = img?;
285            let file_name = img.file_name();
286            let file_name = file_name
287                .to_str()
288                .ok_or("Failed to convert file name to str")?;
289
290            if file_name.starts_with(file_name_starts_with) {
291                let file_path = img.path();
292
293                let expected_id = file_name.strip_prefix(file_name_starts_with).unwrap();
294                let expected_id = expected_id.strip_suffix(".png").unwrap();
295                let Ok(expected_id) = expected_id.parse::<u16>() else {
296                    // Currently we only support decoding id upto 65535 (u16::MAX) while some tag families
297                    // like `TagCircle49H12` can support more than that.
298                    continue;
299                };
300
301                if expected_id == u16::MAX {
302                    continue;
303                }
304
305                let original_img = read_image_png_mono8(file_path)?;
306                let detection = decoder.decode(&original_img)?;
307
308                assert_eq!(detection.len(), 1, "Tag: {file_name}");
309                let detection = &detection[0];
310
311                assert_eq!(detection.id, expected_id);
312                assert_eq!(detection.tag_family_kind, expected_tag);
313
314                for (point, expected) in detection.quad.corners.iter().zip(expected_quads.iter()) {
315                    assert!(
316                        (point.y - expected.y).abs() <= 0.1,
317                        "Tag: {}, Got y: {}, Expected: {}",
318                        file_name,
319                        point.y,
320                        expected.y
321                    );
322                    assert!(
323                        (point.x - expected.x).abs() <= 0.1,
324                        "Tag: {}, Got x: {}, Expected: {}",
325                        file_name,
326                        point.x,
327                        expected.x
328                    );
329                }
330
331                decoder.clear();
332            }
333        }
334
335        Ok(())
336    }
337
338    #[test]
339    fn test_tag16_h5() -> Result<(), Box<dyn std::error::Error>> {
340        let config = DecodeTagsConfig::new(vec![TagFamilyKind::Tag16H5]);
341        let mut decoder = AprilTagDecoder::new(config, [50, 50].into())?;
342
343        let expected_quad = [
344            Point2d { x: 40.0, y: 10.0 },
345            Point2d { x: 40.0, y: 40.0 },
346            Point2d { x: 10.0, y: 40.0 },
347            Point2d { x: 10.0, y: 10.0 },
348        ];
349
350        test_tags(
351            &mut decoder,
352            TagFamilyKind::Tag16H5,
353            expected_quad,
354            "../../tests/data/apriltag-imgs/tag16h5/",
355            "tag16_05_",
356        )?;
357
358        Ok(())
359    }
360
361    #[test]
362    fn test_tag25_h9() -> Result<(), Box<dyn std::error::Error>> {
363        let config = DecodeTagsConfig::new(vec![TagFamilyKind::Tag25H9]);
364        let mut decoder = AprilTagDecoder::new(config, [55, 55].into())?;
365
366        let expected_quad = [
367            Point2d { x: 45.0, y: 10.0 },
368            Point2d { x: 45.0, y: 45.0 },
369            Point2d { x: 10.0, y: 45.0 },
370            Point2d { x: 10.0, y: 10.0 },
371        ];
372
373        test_tags(
374            &mut decoder,
375            TagFamilyKind::Tag25H9,
376            expected_quad,
377            "../../tests/data/apriltag-imgs/tag25h9/",
378            "tag25_09_",
379        )?;
380
381        Ok(())
382    }
383
384    #[test]
385    fn test_tag36_h11() -> Result<(), Box<dyn std::error::Error>> {
386        let config = DecodeTagsConfig::new(vec![TagFamilyKind::Tag36H11]);
387        let mut decoder = AprilTagDecoder::new(config, [60, 60].into())?;
388
389        let expected_quad = [
390            Point2d { x: 50.0, y: 10.0 },
391            Point2d { x: 50.0, y: 50.0 },
392            Point2d { x: 10.0, y: 50.0 },
393            Point2d { x: 10.0, y: 10.0 },
394        ];
395
396        test_tags(
397            &mut decoder,
398            TagFamilyKind::Tag36H11,
399            expected_quad,
400            "../../tests/data/apriltag-imgs/tag36h11/",
401            "tag36_11_",
402        )?;
403
404        Ok(())
405    }
406
407    #[test]
408    fn test_tagcircle21h7() -> Result<(), Box<dyn std::error::Error>> {
409        let config = DecodeTagsConfig::new(vec![TagFamilyKind::TagCircle21H7]);
410        let mut decoder = AprilTagDecoder::new(config, [55, 55].into())?;
411
412        let expected_quad = [
413            Point2d { x: 40.0, y: 15.0 },
414            Point2d { x: 40.0, y: 40.0 },
415            Point2d { x: 15.0, y: 40.0 },
416            Point2d { x: 15.0, y: 15.0 },
417        ];
418
419        test_tags(
420            &mut decoder,
421            TagFamilyKind::TagCircle21H7,
422            expected_quad,
423            "../../tests/data/apriltag-imgs/tagCircle21h7/",
424            "tag21_07_",
425        )?;
426
427        Ok(())
428    }
429
430    #[test]
431    fn test_tagcircle49h12() -> Result<(), Box<dyn std::error::Error>> {
432        let config = DecodeTagsConfig::new(vec![TagFamilyKind::TagCircle49H12]);
433        let mut decoder = AprilTagDecoder::new(config, [65, 65].into())?;
434
435        let expected_quad = [
436            Point2d { x: 45.0, y: 20.0 },
437            Point2d { x: 45.0, y: 45.0 },
438            Point2d { x: 20.0, y: 45.0 },
439            Point2d { x: 20.0, y: 20.0 },
440        ];
441
442        test_tags(
443            &mut decoder,
444            TagFamilyKind::TagCircle49H12,
445            expected_quad,
446            "../../tests/data/apriltag-imgs/tagCircle49h12/",
447            "tag49_12_",
448        )?;
449
450        Ok(())
451    }
452
453    #[test]
454    fn test_tagcustom48_h12() -> Result<(), Box<dyn std::error::Error>> {
455        let config = DecodeTagsConfig::new(vec![TagFamilyKind::TagCustom48H12]);
456        let mut decoder = AprilTagDecoder::new(config, [60, 60].into())?;
457
458        let expected_quad = [
459            Point2d { x: 45.0, y: 15.0 },
460            Point2d { x: 45.0, y: 45.0 },
461            Point2d { x: 15.0, y: 45.0 },
462            Point2d { x: 15.0, y: 15.0 },
463        ];
464
465        test_tags(
466            &mut decoder,
467            TagFamilyKind::TagCustom48H12,
468            expected_quad,
469            "../../tests/data/apriltag-imgs/tagCustom48h12/",
470            "tag48_12_",
471        )?;
472
473        Ok(())
474    }
475
476    #[test]
477    fn test_tagstandard41_h12() -> Result<(), Box<dyn std::error::Error>> {
478        let config = DecodeTagsConfig::new(vec![TagFamilyKind::TagStandard41H12]);
479        let mut decoder = AprilTagDecoder::new(config, [55, 55].into())?;
480
481        let expected_quad = [
482            Point2d { x: 40.0, y: 15.0 },
483            Point2d { x: 40.0, y: 40.0 },
484            Point2d { x: 15.0, y: 40.0 },
485            Point2d { x: 15.0, y: 15.0 },
486        ];
487
488        test_tags(
489            &mut decoder,
490            TagFamilyKind::TagStandard41H12,
491            expected_quad,
492            "../../tests/data/apriltag-imgs/tagStandard41h12/",
493            "tag41_12_",
494        )?;
495
496        Ok(())
497    }
498
499    #[test]
500    fn test_tagstandard52_h13() -> Result<(), Box<dyn std::error::Error>> {
501        let config = DecodeTagsConfig::new(vec![TagFamilyKind::TagStandard52H13]);
502        let mut decoder = AprilTagDecoder::new(config, [60, 60].into())?;
503
504        let expected_quad = [
505            Point2d { x: 45.0, y: 15.0 },
506            Point2d { x: 45.0, y: 45.0 },
507            Point2d { x: 15.0, y: 45.0 },
508            Point2d { x: 15.0, y: 15.0 },
509        ];
510
511        test_tags(
512            &mut decoder,
513            TagFamilyKind::TagStandard52H13,
514            expected_quad,
515            "../../tests/data/apriltag-imgs/tagStandard52h13/",
516            "tag52_13_",
517        )?;
518
519        Ok(())
520    }
521}