pico_detect/localize/localizer/
mod.rs

1use std::fmt::Debug;
2use std::io::{Error, Read};
3
4use image::Luma;
5use nalgebra::{Point2, Translation2, Vector2};
6use pixelutil_image::ExtendedImageView;
7
8use crate::geometry::Target;
9use crate::nodes::ComparisonNode;
10
/// A decision tree stored as a flat node list; `Localizer::localize` walks it with
/// heap-style indexing (the child of node `i` is `2 * i + 1 + test_result`).
type Tree = Vec<ComparisonNode>;
/// Per-tree lookup table of offsets; `Localizer::localize` adds one entry per tree
/// to the translation accumulated for a stage.
type Predictions = Vec<Vector2<f32>>;
/// One stage: a list of trees, each paired with its prediction table.
type Stage = Vec<(Tree, Predictions)>;
14
/// Implements object localization using decision trees.
///
/// Details available [here](https://tehnokv.com/posts/puploc-with-trees/).
#[derive(Clone)]
pub struct Localizer {
    /// Number of binary tests evaluated per tree while descending it.
    depth: usize,
    /// Number of predictions stored per tree (`2^depth`, as computed in `load`).
    dsize: usize,
    /// Factor the region size is multiplied by after every stage.
    scale: f32,
    /// Stages, applied in order by `localize`.
    stages: Vec<Stage>,
}
25
26impl Debug for Localizer {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct(stringify!(Localizer))
29            .field("depth", &self.depth)
30            .field("dsize", &self.dsize)
31            .field("scale", &self.scale)
32            .field("stages", &self.stages.len())
33            .finish()
34    }
35}
36
37impl Localizer {
38    // TODO:
39    /// Estimate object location on the image
40    ///
41    /// ### Arguments
42    ///
43    /// * `image` - Target image.
44    /// * `roi` - Region of interest, which is the initial guess of the object location.
45    #[inline]
46    pub fn localize<I>(&self, image: &I, roi: Target) -> Point2<f32>
47    where
48        I: ExtendedImageView<Pixel = Luma<u8>>,
49    {
50        let Target {
51            mut point,
52            mut size,
53        } = roi;
54
55        for stage in self.stages.iter() {
56            let mut translation = Translation2::identity();
57            let p = unsafe { point.coords.try_cast::<i32>().unwrap_unchecked() }.into();
58            let s = size as u32;
59
60            for (codes, preds) in stage.iter() {
61                let idx = (0..self.depth).fold(0, |idx, _| {
62                    2 * idx + 1 + codes[idx].bintest(image, p, s) as usize
63                });
64                let lutidx = (idx + 1) - self.dsize;
65
66                translation.vector += preds[lutidx];
67            }
68
69            translation.vector.scale_mut(size);
70            *point = *translation.transform_point(&point);
71            size *= self.scale;
72        }
73
74        point
75    }
76
77    /// Load localizer from a readable source.
78    #[inline]
79    pub fn load(mut readable: impl Read) -> Result<Self, Error> {
80        let mut buffer: [u8; 4] = [0u8; 4];
81        readable.read_exact(&mut buffer)?;
82        let nstages = i32::from_le_bytes(buffer) as usize;
83
84        readable.read_exact(&mut buffer)?;
85        let scale = f32::from_le_bytes(buffer);
86
87        readable.read_exact(&mut buffer)?;
88        let ntrees = i32::from_le_bytes(buffer) as usize;
89
90        readable.read_exact(&mut buffer)?;
91        let depth = i32::from_le_bytes(buffer) as usize;
92        let pred_size: usize = match 2usize.checked_pow(depth as u32) {
93            Some(value) => value,
94            None => return Err(Error::other("depth overflow")),
95        };
96        let code_size = pred_size - 1;
97
98        let mut stages = Vec::with_capacity(nstages);
99
100        for _ in 0..nstages {
101            let mut stage: Stage = Vec::with_capacity(ntrees);
102
103            for _ in 0..ntrees {
104                let mut tree: Tree = Vec::with_capacity(code_size);
105                let mut predictions: Predictions = Vec::with_capacity(pred_size);
106
107                for _ in 0..code_size {
108                    readable.read_exact(&mut buffer)?;
109                    let node = ComparisonNode::from(buffer);
110                    tree.push(node);
111                }
112
113                for _ in 0..pred_size {
114                    readable.read_exact(&mut buffer)?;
115                    let y = f32::from_le_bytes(buffer);
116
117                    readable.read_exact(&mut buffer)?;
118                    let x = f32::from_le_bytes(buffer);
119
120                    predictions.push(Vector2::new(x, y));
121                }
122
123                stage.push((tree, predictions));
124            }
125
126            stages.push(stage);
127        }
128
129        Ok(Self {
130            depth,
131            dsize: pred_size,
132            scale,
133            stages,
134        })
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the bundled pupil model and spot-checks the header values plus the
    /// first and last node and prediction.
    #[test]
    fn test_pupil_localizer_model_loading() {
        let model_bytes = include_bytes!("../../../models/pupil.localizer.bin");
        let puploc = dbg!(Localizer::load(model_bytes.as_slice()).expect("parsing failed"));

        let stages = &puploc.stages;
        let trees = stages[0].len();

        assert_eq!(5, stages.len());
        assert_eq!(20, trees);
        assert_eq!(10, puploc.depth);
        assert_eq!(80, (puploc.scale * 100.0) as u32);

        // Number of predictions per tree; each tree has one node fewer.
        let dsize = 1usize << puploc.depth;

        let (first_tree, first_preds) = &stages[0][0];
        let (last_tree, last_preds) = &stages[stages.len() - 1][trees - 1];

        let first_node = ComparisonNode::from([30i8, -16i8, 125i8, 14i8]);
        let last_node = ComparisonNode::from([-125i8, 26i8, 15i8, 98i8]);
        assert_eq!(first_node, first_tree[0]);
        assert_eq!(last_node, last_tree[dsize - 2]);

        let first_pred_test = Vector2::new(-0.08540829f32, 0.04436668f32);
        let last_pred_test = Vector2::new(0.05820565f32, 0.02249731f32);
        let first_pred = first_preds[0];
        let last_pred = last_preds[dsize - 1];
        assert_abs_diff_eq!(first_pred_test, first_pred);
        assert_abs_diff_eq!(last_pred_test, last_pred);
    }
}