pico_detect/detect/detector/mod.rs

1mod tree;
2
3use std::fmt::Debug;
4use std::io::{Error, Read};
5
6use image::Luma;
7use pixelutil_image::ExtendedImageView;
8
9use crate::geometry::Square;
10use crate::traits::Region;
11
12use super::Detection;
13
14use tree::DetectorTree;
15
16/// Implements object detection using a cascade of decision tree classifiers.
17#[derive(Clone)]
18pub struct Detector {
19    depth: usize,
20    dsize: usize,
21    threshold: f32,
22    forest: Vec<DetectorTree>,
23}
24
25impl Debug for Detector {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct(stringify!(Detector))
28            .field("depth", &self.depth)
29            .field("dsize", &self.dsize)
30            .field("threshold", &self.threshold)
31            .field("trees", &self.forest.len())
32            .finish()
33    }
34}
35
36impl Detector {
37    /// Estimate detection score for the rectangular region.
38    ///
39    /// ### Arguments
40    ///
41    /// * `image` -- target image;
42    /// * `region` -- rectangular region to classify.
43    ///
44    /// ### Returns
45    ///
46    /// * `Some(f32)` passed region is an object with score;
47    /// * `None` -- if passed region is not an object.
48    #[inline]
49    pub fn classify<I>(&self, image: &I, region: Square) -> Option<f32>
50    where
51        I: ExtendedImageView<Pixel = Luma<u8>>,
52    {
53        let mut result = 0.0f32;
54        let point = region.center();
55
56        for tree in self.forest.iter() {
57            let idx = (0..self.depth).fold(1, |idx, _| {
58                2 * idx + !tree.nodes[idx].bintest(image, point, region.size()) as usize
59            });
60            let lutidx = idx - self.dsize;
61            result += tree.predictions[lutidx];
62
63            if result < tree.threshold {
64                return None;
65            }
66        }
67        Some(result - self.threshold)
68    }
69
70    /// Detect an object in the rectangular region.
71    #[inline]
72    pub fn detect<I>(&self, image: &I, region: Square) -> Option<Detection<Square>>
73    where
74        I: ExtendedImageView<Pixel = Luma<u8>>,
75    {
76        self.classify(image, region)
77            .map(|score| Detection { region, score })
78    }
79
80    /// Create a detector object from a readable source.
81    #[inline]
82    pub fn load(mut readable: impl Read) -> Result<Self, Error> {
83        let mut buffer: [u8; 4] = [0u8; 4];
84        // skip first 8 bytes;
85        readable.read_exact(&mut [0; 8])?;
86
87        readable.read_exact(&mut buffer)?;
88        let depth = i32::from_le_bytes(buffer) as usize;
89
90        let tree_size: usize = match 2usize.checked_pow(depth as u32) {
91            Some(value) => value,
92            None => return Err(Error::other("depth overflow")),
93        };
94
95        readable.read_exact(&mut buffer)?;
96        let ntrees = i32::from_le_bytes(buffer) as usize;
97
98        let mut trees: Vec<DetectorTree> = Vec::with_capacity(ntrees);
99
100        for _ in 0..ntrees {
101            trees.push(DetectorTree::load(&mut readable, tree_size)?);
102        }
103
104        let threshold = trees.last().ok_or(Error::other("No trees"))?.threshold;
105
106        Ok(Self {
107            depth,
108            dsize: tree_size,
109            forest: trees,
110            threshold,
111        })
112    }
113}
114
#[cfg(test)]
mod tests {
    use crate::nodes::ComparisonNode;

    use super::*;

    /// Build a model header: 8 ignored bytes followed by `depth` and `ntrees`
    /// encoded as little-endian `i32`.
    fn header(depth: i32, ntrees: i32) -> Vec<u8> {
        let mut bytes = vec![0u8; 8];
        bytes.extend_from_slice(&depth.to_le_bytes());
        bytes.extend_from_slice(&ntrees.to_le_bytes());
        bytes
    }

    #[test]
    fn test_face_detector_model_loading() {
        let facefinder =
            Detector::load(include_bytes!("../../../models/face.detector.bin").as_slice())
                .expect("parsing failed");

        assert_eq!(6, facefinder.depth);
        assert_eq!(64, facefinder.dsize);
        assert_eq!(468, facefinder.forest.len());

        let second_node = ComparisonNode::from([-17i8, 36i8, -55i8, 7i8]);
        let last_node = ComparisonNode::from([-26i8, -84i8, -48i8, 0i8]);
        assert_eq!(second_node, facefinder.forest[0].nodes[1]);
        assert_eq!(
            last_node,
            *facefinder.forest.last().unwrap().nodes.last().unwrap(),
        );

        assert_abs_diff_eq!(facefinder.forest[0].threshold, -0.7550662159919739f32);
        assert_abs_diff_eq!(
            facefinder.forest.last().unwrap().threshold,
            -1.9176125526428223f32
        );

        assert_abs_diff_eq!(facefinder.forest[0].predictions[0], -0.7820115089416504f32);
        assert_abs_diff_eq!(
            *facefinder
                .forest
                .last()
                .unwrap()
                .predictions
                .last()
                .unwrap(),
            0.07058460265398026f32
        );

        // The Debug impl summarizes the forest by its size.
        let debug = format!("{:?}", facefinder);
        assert!(debug.contains("trees: 468"), "{}", debug);
    }

    #[test]
    fn test_load_rejects_truncated_input() {
        assert!(Detector::load([0u8; 4].as_slice()).is_err());
    }

    #[test]
    fn test_load_rejects_empty_forest() {
        assert!(Detector::load(header(6, 0).as_slice()).is_err());
    }

    #[test]
    fn test_load_rejects_negative_depth() {
        assert!(Detector::load(header(-1, 1).as_slice()).is_err());
    }
}