pico_detect/shape/
mod.rs

1mod delta;
2mod forest;
3mod tree;
4mod utils;
5
6use std::{
7    fmt::Debug,
8    io::{Error, ErrorKind, Read},
9};
10
11use image::Luma;
12use imageproc::rect::Rect;
13use nalgebra::{Affine2, DimName, Matrix3, Point2, SimilarityMatrix2, U2};
14
15use forest::ShaperForest;
16use pixelutil_image::ExtendedImageView;
17
18/// Implements object alignment using an ensemble of regression trees.
19#[derive(Clone)]
20pub struct Shaper {
21    depth: usize,
22    dsize: usize,
23    shape: Vec<Point2<f32>>,
24    forests: Vec<ShaperForest>,
25}
26
27impl Debug for Shaper {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct(stringify!(Shaper))
30            .field("depth", &self.depth)
31            .field("dsize", &self.dsize)
32            .field("shape", &self.shape.len())
33            .field("forests", &self.forests.len())
34            .finish()
35    }
36}
37
38impl Shaper {
39    #[inline]
40    pub fn size(&self) -> usize {
41        self.shape.len()
42    }
43
44    #[inline]
45    pub fn init_points(&self) -> &[Point2<f32>] {
46        self.shape.as_ref()
47    }
48
49    /// Create a shaper object from a readable source.
50    #[inline]
51    pub fn load<R: Read>(mut reader: R) -> Result<Self, Error> {
52        let mut buf = [0u8; 4];
53        reader.read_exact(&mut buf[0..1])?;
54        let version = buf[0];
55        if version != 1 {
56            return Err(Error::new(ErrorKind::InvalidData, "wrong version"));
57        }
58
59        reader.read_exact(&mut buf)?;
60        let nrows = u32::from_be_bytes(buf) as usize;
61
62        reader.read_exact(&mut buf)?;
63        let ncols = u32::from_be_bytes(buf) as usize;
64
65        let shape_size = nrows * ncols / U2::DIM;
66
67        reader.read_exact(&mut buf)?;
68        let nforests = u32::from_be_bytes(buf) as usize;
69
70        reader.read_exact(&mut buf)?;
71        let forest_size = u32::from_be_bytes(buf) as usize;
72
73        reader.read_exact(&mut buf)?;
74        let tree_depth = u32::from_be_bytes(buf);
75
76        reader.read_exact(&mut buf)?;
77        let nfeatures = u32::from_be_bytes(buf) as usize;
78
79        let shifts_count = 2u32.pow(tree_depth) as usize;
80        let nodes_count = shifts_count - 1;
81
82        // dbg!(nrows, ncols, nforests, forest_size, tree_depth, nfeatures);
83        let shape: Vec<Point2<f32>> = utils::read_shape(reader.by_ref(), shape_size)?
84            .column_iter()
85            .map(|col| Point2::new(col.x, col.y))
86            .collect();
87
88        let mut forests = Vec::with_capacity(nforests);
89        for _ in 0..nforests {
90            forests.push(ShaperForest::load(
91                reader.by_ref(),
92                forest_size,
93                nodes_count,
94                shifts_count,
95                shape_size,
96                nfeatures,
97            )?);
98        }
99
100        Ok(Self {
101            depth: tree_depth as usize,
102            dsize: nodes_count,
103            shape,
104            forests,
105        })
106    }
107
108    // TODO:
109    /// Estimate object shape on the image
110    ///
111    /// ### Arguments
112    ///
113    /// * `image` - Target image.
114    ///
115    /// ### Returns
116    ///
117    /// A collection of points each one corresponds to landmark location.
118    /// Points count is defined by a loaded shaper model.
119    #[inline]
120    pub fn shape<I>(&self, image: &I, rect: Rect) -> Vec<Point2<f32>>
121    where
122        I: ExtendedImageView<Pixel = Luma<u8>>,
123    {
124        let mut shape = self.shape.clone();
125
126        let transform_to_image = find_transform_to_image(rect);
127
128        for forest in self.forests.iter() {
129            let transform_to_shape = Self::find_transform(self, shape.as_slice());
130
131            let features =
132                forest.extract_features(image, &transform_to_shape, &transform_to_image, &shape);
133
134            for tree in forest.trees_slice().iter() {
135                let idx = (0..self.depth).fold(0, |idx, _| {
136                    2 * idx
137                        + 1
138                        + unsafe { tree.node_unchecked(idx) }.bintest(features.as_slice()) as usize
139                }) - self.dsize;
140
141                shape
142                    .iter_mut()
143                    .zip(unsafe { tree.shift_unchecked(idx) }.iter())
144                    .for_each(|(shape_point, shift_vector)| {
145                        *shape_point += shift_vector;
146                    });
147            }
148        }
149
150        shape
151            .iter_mut()
152            .for_each(|point| *point = transform_to_image * *point);
153
154        shape
155    }
156
157    #[inline]
158    fn find_transform(&self, shape: &[Point2<f32>]) -> SimilarityMatrix2<f32> {
159        unsafe {
160            similarity_least_squares::from_point_slices(
161                self.shape.as_slice(),
162                shape,
163                f32::EPSILON,
164                0,
165            )
166            .unwrap_unchecked()
167        }
168    }
169}
170
171#[inline]
172fn find_transform_to_image(rect: Rect) -> Affine2<f32> {
173    Affine2::from_matrix_unchecked(Matrix3::new(
174        rect.width() as f32,
175        0.0,
176        rect.left() as f32,
177        0.0,
178        rect.height() as f32,
179        rect.top() as f32,
180        0.0,
181        0.0,
182        1.0,
183    ))
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the bundled 5-point face model and checks the parsed dimensions.
    #[test]
    fn test_face_landmarks_model_loading() {
        // `&[u8]` implements `Read`, so the embedded bytes can be parsed
        // directly without copying them to the heap first.
        let model: &[u8] = include_bytes!("../../models/face-5.shaper.bin");
        let shaper = Shaper::load(model).expect("parsing failed");

        assert_eq!(shaper.forests.len(), 15);
        assert_eq!(shaper.forests[0].trees(), 500);

        assert_eq!(shaper.forests[0].tree(0).nodes(), 15);
        assert_eq!(shaper.forests[0].tree(0).shifts(), 16);

        assert_eq!(shaper.forests[0].deltas(), 800);

        // Keep the `Debug` implementation exercised without printing noise.
        assert!(format!("{:?}", shaper).contains("forests: 15"));
    }
}