yscv_recognize/
recognizer.rs1use std::fs;
2use std::path::Path;
3
4use yscv_tensor::Tensor;
5
6use super::RecognizeError;
7use super::similarity::cosine_similarity_prevalidated;
8use super::snapshot::{IdentitySnapshot, RecognizerSnapshot};
9use super::types::{IdentityEmbedding, Recognition};
10use super::validate::{validate_embedding, validate_embedding_slice, validate_threshold};
11use super::vp_tree::VpTree;
12
13#[derive(Debug, Clone)]
14pub struct Recognizer {
15 threshold: f32,
16 entries: Vec<IdentityEmbedding>,
17 embedding_dim: Option<usize>,
18 index: Option<VpTree>,
19}
20
21impl Recognizer {
22 pub fn new(threshold: f32) -> Result<Self, RecognizeError> {
23 validate_threshold(threshold)?;
24 Ok(Self {
25 threshold,
26 entries: Vec::new(),
27 embedding_dim: None,
28 index: None,
29 })
30 }
31
32 pub fn threshold(&self) -> f32 {
33 self.threshold
34 }
35
36 pub fn set_threshold(&mut self, threshold: f32) -> Result<(), RecognizeError> {
37 validate_threshold(threshold)?;
38 self.threshold = threshold;
39 Ok(())
40 }
41
42 pub fn enroll(
43 &mut self,
44 id: impl Into<String>,
45 embedding: Tensor,
46 ) -> Result<(), RecognizeError> {
47 validate_embedding(&embedding)?;
48 let id = id.into();
49 if self.entries.iter().any(|entry| entry.id == id) {
50 return Err(RecognizeError::DuplicateIdentity { id });
51 }
52 self.enforce_dim(embedding.len())?;
53 self.entries.push(IdentityEmbedding { id, embedding });
54 Ok(())
55 }
56
57 pub fn enroll_or_replace(
58 &mut self,
59 id: impl Into<String>,
60 embedding: Tensor,
61 ) -> Result<(), RecognizeError> {
62 validate_embedding(&embedding)?;
63 self.enforce_dim(embedding.len())?;
64 let id = id.into();
65 if let Some(existing) = self.entries.iter_mut().find(|entry| entry.id == id) {
66 existing.embedding = embedding;
67 return Ok(());
68 }
69 self.entries.push(IdentityEmbedding { id, embedding });
70 Ok(())
71 }
72
73 pub fn remove(&mut self, id: &str) -> bool {
74 if let Some(position) = self.entries.iter().position(|entry| entry.id == id) {
75 self.entries.remove(position);
76 if self.entries.is_empty() {
77 self.embedding_dim = None;
78 }
79 true
80 } else {
81 false
82 }
83 }
84
85 pub fn identities(&self) -> &[IdentityEmbedding] {
86 &self.entries
87 }
88
89 pub fn clear(&mut self) {
90 self.entries.clear();
91 self.embedding_dim = None;
92 }
93
94 pub fn recognize(&self, embedding: &Tensor) -> Result<Recognition, RecognizeError> {
95 validate_embedding(embedding)?;
96 self.recognize_prevalidated(embedding.data())
97 }
98
99 pub fn recognize_slice(&self, embedding: &[f32]) -> Result<Recognition, RecognizeError> {
100 validate_embedding_slice(embedding)?;
101 self.recognize_prevalidated(embedding)
102 }
103
104 fn recognize_prevalidated(&self, embedding: &[f32]) -> Result<Recognition, RecognizeError> {
105 if let Some(expected_dim) = self.embedding_dim {
106 if expected_dim != embedding.len() {
107 return Err(RecognizeError::EmbeddingDimMismatch {
108 expected: expected_dim,
109 got: embedding.len(),
110 });
111 }
112 } else {
113 return Ok(Recognition {
114 identity: None,
115 score: 0.0,
116 });
117 }
118
119 let mut best_index = None::<usize>;
120 let mut best_score = -1.0f32;
121 for (index, entry) in self.entries.iter().enumerate() {
122 let score = cosine_similarity_prevalidated(embedding, entry.embedding.data())?;
123 if score > best_score {
124 best_score = score;
125 best_index = Some(index);
126 }
127 }
128
129 if best_score >= self.threshold {
130 Ok(Recognition {
131 identity: best_index.map(|index| self.entries[index].id.clone()),
132 score: best_score,
133 })
134 } else {
135 Ok(Recognition {
136 identity: None,
137 score: best_score,
138 })
139 }
140 }
141
142 pub fn to_snapshot(&self) -> RecognizerSnapshot {
143 let mut identities = Vec::with_capacity(self.entries.len());
144 for entry in &self.entries {
145 identities.push(IdentitySnapshot {
146 id: entry.id.clone(),
147 embedding: entry.embedding.data().to_vec(),
148 });
149 }
150
151 RecognizerSnapshot {
152 threshold: self.threshold,
153 identities,
154 }
155 }
156
157 pub fn from_snapshot(snapshot: RecognizerSnapshot) -> Result<Self, RecognizeError> {
158 let mut recognizer = Self::new(snapshot.threshold)?;
159 for entry in snapshot.identities {
160 let embedding = Tensor::from_vec(vec![entry.embedding.len()], entry.embedding)
161 .map_err(|err| RecognizeError::Serialization {
162 message: err.to_string(),
163 })?;
164 recognizer.enroll(entry.id, embedding)?;
165 }
166 Ok(recognizer)
167 }
168
169 pub fn to_json_pretty(&self) -> Result<String, RecognizeError> {
170 serde_json::to_string_pretty(&self.to_snapshot()).map_err(|err| {
171 RecognizeError::Serialization {
172 message: err.to_string(),
173 }
174 })
175 }
176
177 pub fn from_json(json: &str) -> Result<Self, RecognizeError> {
178 let snapshot: RecognizerSnapshot =
179 serde_json::from_str(json).map_err(|err| RecognizeError::Serialization {
180 message: err.to_string(),
181 })?;
182 Self::from_snapshot(snapshot)
183 }
184
185 pub fn save_json_file(&self, path: impl AsRef<Path>) -> Result<(), RecognizeError> {
186 let json = self.to_json_pretty()?;
187 fs::write(path, json).map_err(|err| RecognizeError::Io {
188 message: err.to_string(),
189 })
190 }
191
192 pub fn load_json_file(path: impl AsRef<Path>) -> Result<Self, RecognizeError> {
193 let json = fs::read_to_string(path).map_err(|err| RecognizeError::Io {
194 message: err.to_string(),
195 })?;
196 Self::from_json(&json)
197 }
198
199 pub fn build_index(&mut self) {
201 let entries: Vec<(String, Vec<f32>)> = self
202 .entries
203 .iter()
204 .map(|e| (e.id.clone(), e.embedding.data().to_vec()))
205 .collect();
206 self.index = Some(VpTree::build(entries));
207 }
208
209 pub fn search_indexed(
213 &self,
214 embedding: &Tensor,
215 k: usize,
216 ) -> Result<Vec<Recognition>, RecognizeError> {
217 validate_embedding(embedding)?;
218
219 if let Some(expected_dim) = self.embedding_dim {
220 if expected_dim != embedding.len() {
221 return Err(RecognizeError::EmbeddingDimMismatch {
222 expected: expected_dim,
223 got: embedding.len(),
224 });
225 }
226 } else {
227 return Ok(Vec::new());
228 }
229
230 if let Some(ref index) = self.index {
231 let results = index.query(embedding.data(), k);
232 Ok(results
233 .into_iter()
234 .filter_map(|r| {
235 let score = 1.0 - r.distance;
236 if score >= self.threshold {
237 Some(Recognition {
238 identity: Some(r.id),
239 score,
240 })
241 } else {
242 None
243 }
244 })
245 .collect())
246 } else {
247 let mut scored: Vec<(usize, f32)> = Vec::with_capacity(self.entries.len());
249 for (i, entry) in self.entries.iter().enumerate() {
250 let score =
251 cosine_similarity_prevalidated(embedding.data(), entry.embedding.data())?;
252 scored.push((i, score));
253 }
254 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
255 scored.truncate(k);
256 Ok(scored
257 .into_iter()
258 .filter_map(|(i, score)| {
259 if score >= self.threshold {
260 Some(Recognition {
261 identity: Some(self.entries[i].id.clone()),
262 score,
263 })
264 } else {
265 None
266 }
267 })
268 .collect())
269 }
270 }
271
272 fn enforce_dim(&mut self, dim: usize) -> Result<(), RecognizeError> {
273 if let Some(expected_dim) = self.embedding_dim {
274 if expected_dim != dim {
275 return Err(RecognizeError::EmbeddingDimMismatch {
276 expected: expected_dim,
277 got: dim,
278 });
279 }
280 } else {
281 self.embedding_dim = Some(dim);
282 }
283 Ok(())
284 }
285}