pico_detect/detect/detector/
mod.rs1mod tree;
2
3use std::fmt::Debug;
4use std::io::{Error, Read};
5
6use image::Luma;
7use pixelutil_image::ExtendedImageView;
8
9use crate::geometry::Square;
10use crate::traits::Region;
11
12use super::Detection;
13
14use tree::DetectorTree;
15
16#[derive(Clone)]
18pub struct Detector {
19 depth: usize,
20 dsize: usize,
21 threshold: f32,
22 forest: Vec<DetectorTree>,
23}
24
25impl Debug for Detector {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct(stringify!(Detector))
28 .field("depth", &self.depth)
29 .field("dsize", &self.dsize)
30 .field("threshold", &self.threshold)
31 .field("trees", &self.forest.len())
32 .finish()
33 }
34}
35
36impl Detector {
37 #[inline]
49 pub fn classify<I>(&self, image: &I, region: Square) -> Option<f32>
50 where
51 I: ExtendedImageView<Pixel = Luma<u8>>,
52 {
53 let mut result = 0.0f32;
54 let point = region.center();
55
56 for tree in self.forest.iter() {
57 let idx = (0..self.depth).fold(1, |idx, _| {
58 2 * idx + !tree.nodes[idx].bintest(image, point, region.size()) as usize
59 });
60 let lutidx = idx - self.dsize;
61 result += tree.predictions[lutidx];
62
63 if result < tree.threshold {
64 return None;
65 }
66 }
67 Some(result - self.threshold)
68 }
69
70 #[inline]
72 pub fn detect<I>(&self, image: &I, region: Square) -> Option<Detection<Square>>
73 where
74 I: ExtendedImageView<Pixel = Luma<u8>>,
75 {
76 self.classify(image, region)
77 .map(|score| Detection { region, score })
78 }
79
80 #[inline]
82 pub fn load(mut readable: impl Read) -> Result<Self, Error> {
83 let mut buffer: [u8; 4] = [0u8; 4];
84 readable.read_exact(&mut [0; 8])?;
86
87 readable.read_exact(&mut buffer)?;
88 let depth = i32::from_le_bytes(buffer) as usize;
89
90 let tree_size: usize = match 2usize.checked_pow(depth as u32) {
91 Some(value) => value,
92 None => return Err(Error::other("depth overflow")),
93 };
94
95 readable.read_exact(&mut buffer)?;
96 let ntrees = i32::from_le_bytes(buffer) as usize;
97
98 let mut trees: Vec<DetectorTree> = Vec::with_capacity(ntrees);
99
100 for _ in 0..ntrees {
101 trees.push(DetectorTree::load(&mut readable, tree_size)?);
102 }
103
104 let threshold = trees.last().ok_or(Error::other("No trees"))?.threshold;
105
106 Ok(Self {
107 depth,
108 dsize: tree_size,
109 forest: trees,
110 threshold,
111 })
112 }
113}
114
#[cfg(test)]
mod tests {
    use crate::nodes::ComparisonNode;

    use super::*;

    /// Builds a model header: 8 skipped bytes, then `depth` and `ntrees` as LE `i32`.
    fn header_bytes(depth: i32, ntrees: i32) -> Vec<u8> {
        let mut bytes = vec![0u8; 8];
        bytes.extend_from_slice(&depth.to_le_bytes());
        bytes.extend_from_slice(&ntrees.to_le_bytes());
        bytes
    }

    #[test]
    fn test_face_detector_model_loading() {
        // `&[u8]` implements `Read`, so the embedded model is parsed without copying it.
        let model: &[u8] = include_bytes!("../../../models/face.detector.bin");
        let facefinder = Detector::load(model).expect("parsing failed");

        assert_eq!(6, facefinder.depth);
        // `load` sets `dsize` to `2^depth`.
        assert_eq!(64, facefinder.dsize);
        assert_eq!(468, facefinder.forest.len());

        let second_node = ComparisonNode::from([-17i8, 36i8, -55i8, 7i8]);
        let last_node = ComparisonNode::from([-26i8, -84i8, -48i8, 0i8]);
        assert_eq!(second_node, facefinder.forest[0].nodes[1]);
        assert_eq!(
            last_node,
            *facefinder.forest.last().unwrap().nodes.last().unwrap(),
        );

        assert_abs_diff_eq!(facefinder.forest[0].threshold, -0.7550662159919739f32);
        assert_abs_diff_eq!(
            facefinder.forest.last().unwrap().threshold,
            -1.9176125526428223f32
        );
        // The detector's final threshold is taken from the last tree.
        assert_abs_diff_eq!(
            facefinder.threshold,
            facefinder.forest.last().unwrap().threshold
        );

        assert_abs_diff_eq!(facefinder.forest[0].predictions[0], -0.7820115089416504f32);
        assert_abs_diff_eq!(
            *facefinder
                .forest
                .last()
                .unwrap()
                .predictions
                .last()
                .unwrap(),
            0.07058460265398026f32
        );

        // The `Debug` impl summarises the forest by its size instead of dumping it.
        assert!(format!("{facefinder:?}").contains("trees: 468"));
    }

    #[test]
    fn test_load_rejects_malformed_headers() {
        // Stream ends right after the skipped header bytes.
        assert!(Detector::load(&[0u8; 8][..]).is_err());
        // Negative depth.
        assert!(Detector::load(header_bytes(-1, 1).as_slice()).is_err());
        // Well-formed header, but the forest is empty.
        assert!(Detector::load(header_bytes(6, 0).as_slice()).is_err());
    }
}