pico_detect/localize/localizer/
mod.rs1use std::fmt::Debug;
2use std::io::{Error, Read};
3
4use image::Luma;
5use nalgebra::{Point2, Translation2, Vector2};
6use pixelutil_image::ExtendedImageView;
7
8use crate::geometry::Target;
9use crate::nodes::ComparisonNode;
10
11type Tree = Vec<ComparisonNode>;
12type Predictions = Vec<Vector2<f32>>;
13type Stage = Vec<(Tree, Predictions)>;
14
15#[derive(Clone)]
19pub struct Localizer {
20 depth: usize,
21 dsize: usize,
22 scale: f32,
23 stages: Vec<Stage>,
24}
25
26impl Debug for Localizer {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct(stringify!(Localizer))
29 .field("depth", &self.depth)
30 .field("dsize", &self.dsize)
31 .field("scale", &self.scale)
32 .field("stages", &self.stages.len())
33 .finish()
34 }
35}
36
37impl Localizer {
38 #[inline]
46 pub fn localize<I>(&self, image: &I, roi: Target) -> Point2<f32>
47 where
48 I: ExtendedImageView<Pixel = Luma<u8>>,
49 {
50 let Target {
51 mut point,
52 mut size,
53 } = roi;
54
55 for stage in self.stages.iter() {
56 let mut translation = Translation2::identity();
57 let p = unsafe { point.coords.try_cast::<i32>().unwrap_unchecked() }.into();
58 let s = size as u32;
59
60 for (codes, preds) in stage.iter() {
61 let idx = (0..self.depth).fold(0, |idx, _| {
62 2 * idx + 1 + codes[idx].bintest(image, p, s) as usize
63 });
64 let lutidx = (idx + 1) - self.dsize;
65
66 translation.vector += preds[lutidx];
67 }
68
69 translation.vector.scale_mut(size);
70 *point = *translation.transform_point(&point);
71 size *= self.scale;
72 }
73
74 point
75 }
76
77 #[inline]
79 pub fn load(mut readable: impl Read) -> Result<Self, Error> {
80 let mut buffer: [u8; 4] = [0u8; 4];
81 readable.read_exact(&mut buffer)?;
82 let nstages = i32::from_le_bytes(buffer) as usize;
83
84 readable.read_exact(&mut buffer)?;
85 let scale = f32::from_le_bytes(buffer);
86
87 readable.read_exact(&mut buffer)?;
88 let ntrees = i32::from_le_bytes(buffer) as usize;
89
90 readable.read_exact(&mut buffer)?;
91 let depth = i32::from_le_bytes(buffer) as usize;
92 let pred_size: usize = match 2usize.checked_pow(depth as u32) {
93 Some(value) => value,
94 None => return Err(Error::other("depth overflow")),
95 };
96 let code_size = pred_size - 1;
97
98 let mut stages = Vec::with_capacity(nstages);
99
100 for _ in 0..nstages {
101 let mut stage: Stage = Vec::with_capacity(ntrees);
102
103 for _ in 0..ntrees {
104 let mut tree: Tree = Vec::with_capacity(code_size);
105 let mut predictions: Predictions = Vec::with_capacity(pred_size);
106
107 for _ in 0..code_size {
108 readable.read_exact(&mut buffer)?;
109 let node = ComparisonNode::from(buffer);
110 tree.push(node);
111 }
112
113 for _ in 0..pred_size {
114 readable.read_exact(&mut buffer)?;
115 let y = f32::from_le_bytes(buffer);
116
117 readable.read_exact(&mut buffer)?;
118 let x = f32::from_le_bytes(buffer);
119
120 predictions.push(Vector2::new(x, y));
121 }
122
123 stage.push((tree, predictions));
124 }
125
126 stages.push(stage);
127 }
128
129 Ok(Self {
130 depth,
131 dsize: pred_size,
132 scale,
133 stages,
134 })
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the bundled pupil model and spot-checks the header values and
    /// the first/last node and prediction of the cascade.
    #[test]
    fn test_pupil_localizer_model_loading() {
        let bytes: &[u8] = include_bytes!("../../../models/pupil.localizer.bin");
        let model = dbg!(Localizer::load(bytes).expect("parsing failed"));

        let stages = &model.stages;
        let trees = stages[0].len();

        assert_eq!(5, stages.len());
        assert_eq!(20, trees);
        assert_eq!(10, model.depth);
        assert_eq!(80, (model.scale * 100.0) as u32);

        let leaves = 2usize.pow(model.depth as u32);
        let (first_codes, first_preds) = &stages[0][0];
        let (last_codes, last_preds) = &stages[stages.len() - 1][trees - 1];

        // A tree has `leaves - 1` internal nodes (last at `leaves - 2`) and
        // `leaves` predictions (last at `leaves - 1`).
        assert_eq!(
            ComparisonNode::from([30i8, -16i8, 125i8, 14i8]),
            first_codes[0]
        );
        assert_eq!(
            ComparisonNode::from([-125i8, 26i8, 15i8, 98i8]),
            last_codes[leaves - 2]
        );

        assert_abs_diff_eq!(Vector2::new(-0.08540829f32, 0.04436668f32), first_preds[0]);
        assert_abs_diff_eq!(Vector2::new(0.05820565f32, 0.02249731f32), last_preds[leaves - 1]);
    }
}