#![deny(missing_docs)]
//! AprilTag detection.
//!
//! The crate root exposes [`DecodeTagsConfig`], which describes the tag
//! families to look for and the detection parameters, and
//! [`AprilTagDecoder`], which owns the working buffers and runs the pipeline
//! on a single-channel `u8` image: adaptive thresholding, connected-component
//! and gradient-cluster segmentation, quad fitting, and tag decoding.
2use std::collections::HashMap;
5
6use kornia_image::{
7 allocator::{CpuAllocator, ImageAllocator},
8 Image, ImageSize,
9};
10use kornia_imgproc::resize::resize_fast_mono;
11
12use crate::{
13 decoder::{decode_tags, Detection, GrayModelPair},
14 errors::AprilTagError,
15 family::{TagFamily, TagFamilyKind},
16 quad::{fit_quads, FitQuadConfig},
17 segmentation::{find_connected_components, find_gradient_clusters, GradientInfo},
18 threshold::{adaptive_threshold, TileMinMax},
19 union_find::UnionFind,
20 utils::Pixel,
21};
22
/// Error types of this crate, including [`errors::AprilTagError`].
pub mod errors;

/// Shared helper types such as [`utils::Pixel`] and [`utils::Point2d`].
pub mod utils;

/// Adaptive thresholding ([`threshold::adaptive_threshold`]) and its
/// [`threshold::TileMinMax`] working state.
pub mod threshold;

/// Crate-internal iteration helpers.
pub(crate) mod iter;

/// Image segmentation: [`segmentation::find_connected_components`] and
/// [`segmentation::find_gradient_clusters`].
pub mod segmentation;

/// The [`union_find::UnionFind`] structure used by the segmentation step.
pub mod union_find;

/// Tag family definitions: [`family::TagFamily`] and [`family::TagFamilyKind`].
pub mod family;

/// Quad fitting ([`quad::fit_quads`]) and its [`quad::FitQuadConfig`].
pub mod quad;

/// Tag decoding ([`decoder::decode_tags`]) and the resulting
/// [`decoder::Detection`] values.
pub mod decoder;
49
/// Parameters that control which tag families are searched for and how the
/// detection pipeline behaves.
#[derive(Debug, Clone, PartialEq)]
pub struct DecodeTagsConfig {
    /// Tag families the decoder tries to match.
    pub tag_families: Vec<TagFamily>,
    /// Settings for the quad-fitting step; [`DecodeTagsConfig::new`] uses
    /// `FitQuadConfig::default()`.
    pub fit_quad_config: FitQuadConfig,
    /// Edge-refinement switch; [`DecodeTagsConfig::new`] sets it to `true`.
    // NOTE(review): consumed outside this file (presumably by the decoding
    // step) - confirm the exact effect there.
    pub refine_edges_enabled: bool,
    /// Sharpening amount used while decoding; [`DecodeTagsConfig::new`] sets
    /// it to `0.25`.
    // NOTE(review): consumed outside this file - confirm the exact effect there.
    pub decode_sharpening: f32,
    /// `true` when at least one configured family has `reversed_border == false`.
    pub normal_border: bool,
    /// `true` when at least one configured family has `reversed_border == true`.
    pub reversed_border: bool,
    /// Smallest tag width the pipeline should consider.
    ///
    /// [`DecodeTagsConfig::new`] derives it from the smallest
    /// `width_at_border` of the configured families, divided by the default
    /// downscale factor and never less than 3.
    pub min_tag_width: usize,
    /// Value handed to the adaptive threshold as its minimum white/black
    /// difference; [`DecodeTagsConfig::new`] sets it to `5`.
    pub min_white_black_difference: u8,
    /// Factor by which [`AprilTagDecoder::new`] divides the image size for
    /// its working buffers. Values `<= 1` disable downscaling;
    /// [`DecodeTagsConfig::new`] sets it to `2`.
    pub downscale_factor: usize,
}
72
73impl DecodeTagsConfig {
74 pub fn new(tag_family_kinds: Vec<TagFamilyKind>) -> Self {
76 const DEFAULT_DOWNSCALE_FACTOR: usize = 2;
77
78 let mut tag_families = Vec::with_capacity(tag_family_kinds.len());
79 let mut normal_border = false;
80 let mut reversed_border = false;
81 let mut min_tag_width = usize::MAX;
82
83 tag_family_kinds.iter().for_each(|family_kind| {
84 let family: TagFamily = family_kind.into();
85 if family.width_at_border < min_tag_width {
86 min_tag_width = family.width_at_border;
87 }
88 normal_border |= !family.reversed_border;
89 reversed_border |= family.reversed_border;
90
91 tag_families.push(family);
92 });
93
94 min_tag_width /= DEFAULT_DOWNSCALE_FACTOR;
95
96 if min_tag_width < 3 {
97 min_tag_width = 3;
98 }
99
100 Self {
101 tag_families,
102 fit_quad_config: Default::default(),
103 normal_border,
104 refine_edges_enabled: true,
105 decode_sharpening: 0.25,
106 reversed_border,
107 min_tag_width,
108 min_white_black_difference: 5,
109 downscale_factor: DEFAULT_DOWNSCALE_FACTOR,
110 }
111 }
112
113 pub fn all() -> Self {
115 Self::new(TagFamilyKind::all())
116 }
117
118 pub fn add(&mut self, family: TagFamily) {
120 if family.width_at_border < self.min_tag_width {
121 self.min_tag_width = family.width_at_border;
122 }
123 self.normal_border |= !family.reversed_border;
124 self.reversed_border |= family.reversed_border;
125
126 self.tag_families.push(family);
127 }
128}
129
/// AprilTag detector that owns its configuration and the working buffers
/// reused across calls to [`AprilTagDecoder::decode`].
pub struct AprilTagDecoder {
    /// Detection parameters and the tag families to look for.
    config: DecodeTagsConfig,
    /// Target buffer for the downscaled input; `None` when the configured
    /// downscale factor is `<= 1`.
    downscale_img: Option<Image<u8, 1, CpuAllocator>>,
    /// Output buffer of the adaptive threshold, sized like the working image.
    bin_img: Image<Pixel, 1, CpuAllocator>,
    /// Working state passed to the adaptive threshold (built with tile size 4).
    tile_min_max: TileMinMax,
    /// Union-find with one entry per working-image pixel, used by the
    /// segmentation step.
    uf: UnionFind,
    /// Gradient clusters filled by `find_gradient_clusters` and consumed by
    /// `fit_quads`.
    // NOTE(review): the `(usize, usize)` key presumably identifies a pair of
    // connected components - confirm in `segmentation`.
    clusters: HashMap<(usize, usize), Vec<GradientInfo>>,
    /// Working state passed to `decode_tags`.
    gray_model_pair: GrayModelPair,
}
140
141impl AprilTagDecoder {
142 #[inline]
144 pub fn config(&self) -> &DecodeTagsConfig {
145 &self.config
146 }
147
148 #[inline]
150 pub fn add(&mut self, family: TagFamily) {
151 self.config.add(family);
152 }
153
154 pub fn new(config: DecodeTagsConfig, img_size: ImageSize) -> Result<Self, AprilTagError> {
165 let (img_size, downscale_img) = if config.downscale_factor <= 1 {
166 (img_size, None)
167 } else {
168 let new_size = ImageSize {
169 width: img_size.width / config.downscale_factor,
170 height: img_size.height / config.downscale_factor,
171 };
172
173 (
174 new_size,
175 Some(Image::from_size_val(new_size, 0, CpuAllocator)?),
176 )
177 };
178
179 let bin_img = Image::from_size_val(img_size, Pixel::Skip, CpuAllocator)?;
180 let tile_min_max = TileMinMax::new(img_size, 4);
181 let uf = UnionFind::new(img_size.width * img_size.height);
182
183 Ok(Self {
184 config,
185 downscale_img,
186 bin_img,
187 tile_min_max,
188 uf,
189 clusters: HashMap::new(),
190 gray_model_pair: GrayModelPair::new(),
191 })
192 }
193
194 pub fn decode<A: ImageAllocator>(
209 &mut self,
210 src: &Image<u8, 1, A>,
211 ) -> Result<Vec<Detection>, AprilTagError> {
212 if let Some(downscale_img) = self.downscale_img.as_mut() {
213 resize_fast_mono(
214 src,
215 downscale_img,
216 kornia_imgproc::interpolation::InterpolationMode::Nearest,
217 )?;
218
219 adaptive_threshold(
221 downscale_img,
222 &mut self.bin_img,
223 &mut self.tile_min_max,
224 self.config.min_white_black_difference,
225 )?;
226 } else {
227 adaptive_threshold(
229 src,
230 &mut self.bin_img,
231 &mut self.tile_min_max,
232 self.config.min_white_black_difference,
233 )?;
234 }
235
236 find_connected_components(&self.bin_img, &mut self.uf)?;
238
239 find_gradient_clusters(&self.bin_img, &mut self.uf, &mut self.clusters);
241
242 let mut quads = fit_quads(&self.bin_img, &mut self.clusters, &self.config);
244
245 Ok(decode_tags(
247 src,
248 &mut quads,
249 &mut self.config,
250 &mut self.gray_model_pair,
251 ))
252 }
253
254 pub fn clear(&mut self) {
256 self.uf.reset();
257 self.clusters.clear();
258 self.gray_model_pair.reset();
259 }
260
261 pub fn tag_families(&self) -> &[TagFamily] {
263 &self.config.tag_families
264 }
265}
266
// NOTE(review): these tests are compiled out on aarch64; the reason is not
// visible in this file.
#[cfg(all(test, not(target_arch = "aarch64")))]
mod tests {
    use kornia_io::png::read_image_png_mono8;

    use crate::{family::TagFamilyKind, utils::Point2d, AprilTagDecoder, DecodeTagsConfig};

    /// Largest allowed per-axis distance between a detected and an expected corner.
    const CORNER_TOLERANCE: f32 = 0.1;

    /// Builds the expected corners of an axis-aligned square spanning
    /// `lo..hi` on both axes, in the order
    /// `(hi, lo)`, `(hi, hi)`, `(lo, hi)`, `(lo, lo)`.
    fn square_quad(lo: f32, hi: f32) -> [Point2d<f32>; 4] {
        [
            Point2d { x: hi, y: lo },
            Point2d { x: hi, y: hi },
            Point2d { x: lo, y: hi },
            Point2d { x: lo, y: lo },
        ]
    }

    /// Decodes every `<prefix><id>.png` image in `images_dir` and checks that
    /// exactly one tag with the id from the file name, the expected family
    /// and the expected corners is detected.
    ///
    /// Files whose id part is not a `u16`, or equals `u16::MAX`, are skipped.
    fn test_tags(
        decoder: &mut AprilTagDecoder,
        expected_tag: TagFamilyKind,
        expected_quads: [Point2d<f32>; 4],
        images_dir: &str,
        file_name_starts_with: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        for entry in std::fs::read_dir(images_dir)? {
            let entry = entry?;
            let os_name = entry.file_name();
            let file_name = os_name
                .to_str()
                .ok_or("Failed to convert file name to str")?;

            // Only files carrying the family prefix are test images.
            let Some(id_part) = file_name.strip_prefix(file_name_starts_with) else {
                continue;
            };
            let id_part = id_part.strip_suffix(".png").unwrap();

            let expected_id = match id_part.parse::<u16>() {
                Ok(id) if id != u16::MAX => id,
                _ => continue,
            };

            let original_img = read_image_png_mono8(entry.path())?;
            let detections = decoder.decode(&original_img)?;

            assert_eq!(detections.len(), 1, "Tag: {file_name}");
            let detection = &detections[0];

            assert_eq!(detection.id, expected_id);
            assert_eq!(detection.tag_family_kind, expected_tag);

            for (point, expected) in detection.quad.corners.iter().zip(expected_quads.iter()) {
                let dy = (point.y - expected.y).abs();
                assert!(
                    dy <= CORNER_TOLERANCE,
                    "Tag: {}, Got y: {}, Expected: {}",
                    file_name,
                    point.y,
                    expected.y
                );

                let dx = (point.x - expected.x).abs();
                assert!(
                    dx <= CORNER_TOLERANCE,
                    "Tag: {}, Got x: {}, Expected: {}",
                    file_name,
                    point.x,
                    expected.x
                );
            }

            // Reset the working state before the next image.
            decoder.clear();
        }

        Ok(())
    }

    #[test]
    fn test_tag16_h5() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::Tag16H5]),
            [50, 50].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::Tag16H5,
            square_quad(10.0, 40.0),
            "../../tests/data/apriltag-imgs/tag16h5/",
            "tag16_05_",
        )
    }

    #[test]
    fn test_tag25_h9() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::Tag25H9]),
            [55, 55].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::Tag25H9,
            square_quad(10.0, 45.0),
            "../../tests/data/apriltag-imgs/tag25h9/",
            "tag25_09_",
        )
    }

    #[test]
    fn test_tag36_h11() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::Tag36H11]),
            [60, 60].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::Tag36H11,
            square_quad(10.0, 50.0),
            "../../tests/data/apriltag-imgs/tag36h11/",
            "tag36_11_",
        )
    }

    #[test]
    fn test_tagcircle21h7() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::TagCircle21H7]),
            [55, 55].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::TagCircle21H7,
            square_quad(15.0, 40.0),
            "../../tests/data/apriltag-imgs/tagCircle21h7/",
            "tag21_07_",
        )
    }

    #[test]
    fn test_tagcircle49h12() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::TagCircle49H12]),
            [65, 65].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::TagCircle49H12,
            square_quad(20.0, 45.0),
            "../../tests/data/apriltag-imgs/tagCircle49h12/",
            "tag49_12_",
        )
    }

    #[test]
    fn test_tagcustom48_h12() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::TagCustom48H12]),
            [60, 60].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::TagCustom48H12,
            square_quad(15.0, 45.0),
            "../../tests/data/apriltag-imgs/tagCustom48h12/",
            "tag48_12_",
        )
    }

    #[test]
    fn test_tagstandard41_h12() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::TagStandard41H12]),
            [55, 55].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::TagStandard41H12,
            square_quad(15.0, 40.0),
            "../../tests/data/apriltag-imgs/tagStandard41h12/",
            "tag41_12_",
        )
    }

    #[test]
    fn test_tagstandard52_h13() -> Result<(), Box<dyn std::error::Error>> {
        let mut decoder = AprilTagDecoder::new(
            DecodeTagsConfig::new(vec![TagFamilyKind::TagStandard52H13]),
            [60, 60].into(),
        )?;

        test_tags(
            &mut decoder,
            TagFamilyKind::TagStandard52H13,
            square_quad(15.0, 45.0),
            "../../tests/data/apriltag-imgs/tagStandard52h13/",
            "tag52_13_",
        )
    }
}