1use crate::error::AiError;
8use image::{DynamicImage, ImageBuffer, Luma};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct PerceptualHash {
15 pub hash: u64,
17 pub algorithm: HashAlgorithm,
19}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
23pub enum HashAlgorithm {
24 DHash,
26 AHash,
28 PHash,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SimilarityScore {
35 pub hamming_distance: u32,
37 pub similarity_percent: f64,
39 pub is_similar: bool,
41 pub threshold: u32,
43}
44
45pub struct ImageSimilarityDetector {
47 threshold: u32,
50 algorithm: HashAlgorithm,
52 hash_size: u32,
54}
55
56impl Default for ImageSimilarityDetector {
57 fn default() -> Self {
58 Self {
59 threshold: 10, algorithm: HashAlgorithm::DHash,
61 hash_size: 8,
62 }
63 }
64}
65
66impl ImageSimilarityDetector {
67 #[must_use]
69 pub fn new(threshold: u32, algorithm: HashAlgorithm) -> Self {
70 Self {
71 threshold,
72 algorithm,
73 hash_size: 8,
74 }
75 }
76
77 #[must_use]
79 pub fn with_threshold(mut self, threshold: u32) -> Self {
80 self.threshold = threshold;
81 self
82 }
83
84 #[must_use]
86 pub fn with_algorithm(mut self, algorithm: HashAlgorithm) -> Self {
87 self.algorithm = algorithm;
88 self
89 }
90
91 pub fn hash_image(&self, image_bytes: &[u8]) -> Result<PerceptualHash, AiError> {
93 let img = image::load_from_memory(image_bytes)
94 .map_err(|e| AiError::ParseError(format!("Failed to load image: {e}")))?;
95
96 self.hash_dynamic_image(&img)
97 }
98
99 pub fn hash_dynamic_image(&self, img: &DynamicImage) -> Result<PerceptualHash, AiError> {
101 let hash = match self.algorithm {
102 HashAlgorithm::DHash => self.compute_dhash(img),
103 HashAlgorithm::AHash => self.compute_ahash(img),
104 HashAlgorithm::PHash => self.compute_phash(img),
105 };
106
107 Ok(PerceptualHash {
108 hash,
109 algorithm: self.algorithm,
110 })
111 }
112
113 fn compute_dhash(&self, img: &DynamicImage) -> u64 {
115 let size = self.hash_size + 1;
116 let resized = img.resize_exact(size, self.hash_size, image::imageops::FilterType::Lanczos3);
117 let gray = resized.to_luma8();
118
119 let mut hash: u64 = 0;
120 for y in 0..self.hash_size {
121 for x in 0..self.hash_size {
122 let left = gray.get_pixel(x, y)[0];
123 let right = gray.get_pixel(x + 1, y)[0];
124 if left > right {
125 let bit_position = u64::from(y * self.hash_size + x);
126 hash |= 1 << bit_position;
127 }
128 }
129 }
130
131 hash
132 }
133
134 fn compute_ahash(&self, img: &DynamicImage) -> u64 {
136 let resized = img.resize_exact(
137 self.hash_size,
138 self.hash_size,
139 image::imageops::FilterType::Lanczos3,
140 );
141 let gray = resized.to_luma8();
142
143 let mut sum: u64 = 0;
145 for pixel in gray.pixels() {
146 sum += u64::from(pixel[0]);
147 }
148 let avg = sum / u64::from(self.hash_size * self.hash_size);
149
150 let mut hash: u64 = 0;
152 for (i, pixel) in gray.pixels().enumerate() {
153 if u64::from(pixel[0]) > avg {
154 hash |= 1 << i;
155 }
156 }
157
158 hash
159 }
160
161 fn compute_phash(&self, img: &DynamicImage) -> u64 {
163 let size = 32; let resized = img.resize_exact(size, size, image::imageops::FilterType::Lanczos3);
165 let gray = resized.to_luma8();
166
167 let dct = self.simple_dct(&gray, 8, 8);
169
170 let mut values: Vec<f64> = Vec::new();
172 for row in &dct {
173 for &val in row {
174 values.push(val);
175 }
176 }
177 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
178 let median = values[values.len() / 2];
179
180 let mut hash: u64 = 0;
182 for (i, row) in dct.iter().enumerate() {
183 for (j, &val) in row.iter().enumerate() {
184 if i * 8 + j >= 64 {
185 break;
186 }
187 if val > median {
188 hash |= 1 << (i * 8 + j);
189 }
190 }
191 }
192
193 hash
194 }
195
196 fn simple_dct(
198 &self,
199 img: &ImageBuffer<Luma<u8>, Vec<u8>>,
200 rows: usize,
201 cols: usize,
202 ) -> Vec<Vec<f64>> {
203 let mut dct = vec![vec![0.0; cols]; rows];
204 let (width, height) = img.dimensions();
205 let block_width = width / cols as u32;
206 let block_height = height / rows as u32;
207
208 for (i, dct_row) in dct.iter_mut().enumerate() {
209 for (j, dct_val) in dct_row.iter_mut().enumerate() {
210 let mut sum = 0.0;
211 let mut count = 0.0;
212
213 for y in 0..block_height {
215 for x in 0..block_width {
216 let px = (j as u32) * block_width + x;
217 let py = (i as u32) * block_height + y;
218 if px < width && py < height {
219 sum += f64::from(img.get_pixel(px, py)[0]);
220 count += 1.0;
221 }
222 }
223 }
224
225 *dct_val = if count > 0.0 { sum / count } else { 0.0 };
226 }
227 }
228
229 dct
230 }
231
232 #[must_use]
234 pub fn hamming_distance(hash1: u64, hash2: u64) -> u32 {
235 (hash1 ^ hash2).count_ones()
236 }
237
238 pub fn compare_images(
240 &self,
241 image1_bytes: &[u8],
242 image2_bytes: &[u8],
243 ) -> Result<SimilarityScore, AiError> {
244 let hash1 = self.hash_image(image1_bytes)?;
245 let hash2 = self.hash_image(image2_bytes)?;
246
247 self.compare_hashes(&hash1, &hash2)
248 }
249
250 pub fn compare_hashes(
252 &self,
253 hash1: &PerceptualHash,
254 hash2: &PerceptualHash,
255 ) -> Result<SimilarityScore, AiError> {
256 if hash1.algorithm != hash2.algorithm {
257 return Err(AiError::InvalidInput(
258 "Cannot compare hashes from different algorithms".to_string(),
259 ));
260 }
261
262 let hamming_distance = Self::hamming_distance(hash1.hash, hash2.hash);
263 let similarity_percent = 100.0 * (1.0 - (f64::from(hamming_distance) / 64.0));
264 let is_similar = hamming_distance <= self.threshold;
265
266 Ok(SimilarityScore {
267 hamming_distance,
268 similarity_percent,
269 is_similar,
270 threshold: self.threshold,
271 })
272 }
273
274 pub fn find_similar_images(
276 &self,
277 query_image: &[u8],
278 image_collection: &[Vec<u8>],
279 ) -> Result<Vec<(usize, SimilarityScore)>, AiError> {
280 let query_hash = self.hash_image(query_image)?;
281 let mut results = Vec::new();
282
283 for (idx, img_bytes) in image_collection.iter().enumerate() {
284 let hash = self.hash_image(img_bytes)?;
285 let score = self.compare_hashes(&query_hash, &hash)?;
286
287 if score.is_similar {
288 results.push((idx, score));
289 }
290 }
291
292 results.sort_by(|a, b| a.1.hamming_distance.cmp(&b.1.hamming_distance));
294
295 Ok(results)
296 }
297}
298
299pub struct ImageDatabase {
301 hashes: HashMap<String, PerceptualHash>,
303 detector: ImageSimilarityDetector,
305}
306
307impl ImageDatabase {
308 #[must_use]
310 pub fn new(detector: ImageSimilarityDetector) -> Self {
311 Self {
312 hashes: HashMap::new(),
313 detector,
314 }
315 }
316
317 pub fn add_image(&mut self, id: String, image_bytes: &[u8]) -> Result<PerceptualHash, AiError> {
319 let hash = self.detector.hash_image(image_bytes)?;
320 self.hashes.insert(id, hash.clone());
321 Ok(hash)
322 }
323
324 pub fn is_duplicate(
326 &self,
327 image_bytes: &[u8],
328 ) -> Result<Option<(String, SimilarityScore)>, AiError> {
329 let query_hash = self.detector.hash_image(image_bytes)?;
330
331 for (id, hash) in &self.hashes {
332 let score = self.detector.compare_hashes(&query_hash, hash)?;
333 if score.is_similar {
334 return Ok(Some((id.clone(), score)));
335 }
336 }
337
338 Ok(None)
339 }
340
341 pub fn find_all_similar(
343 &self,
344 image_bytes: &[u8],
345 ) -> Result<Vec<(String, SimilarityScore)>, AiError> {
346 let query_hash = self.detector.hash_image(image_bytes)?;
347 let mut results = Vec::new();
348
349 for (id, hash) in &self.hashes {
350 let score = self.detector.compare_hashes(&query_hash, hash)?;
351 if score.is_similar {
352 results.push((id.clone(), score));
353 }
354 }
355
356 results.sort_by(|a, b| a.1.hamming_distance.cmp(&b.1.hamming_distance));
358
359 Ok(results)
360 }
361
362 #[must_use]
364 pub fn len(&self) -> usize {
365 self.hashes.len()
366 }
367
368 #[must_use]
370 pub fn is_empty(&self) -> bool {
371 self.hashes.is_empty()
372 }
373
374 pub fn clear(&mut self) {
376 self.hashes.clear();
377 }
378
379 #[must_use]
381 pub fn find_duplicates(&self) -> Vec<(String, String, f64)> {
382 let mut duplicates = Vec::new();
383 let ids: Vec<_> = self.hashes.keys().cloned().collect();
384
385 for i in 0..ids.len() {
386 for j in (i + 1)..ids.len() {
387 let hash1 = &self.hashes[&ids[i]];
388 let hash2 = &self.hashes[&ids[j]];
389
390 if let Ok(score) = self.detector.compare_hashes(hash1, hash2) {
391 if score.is_similar {
392 duplicates.push((ids[i].clone(), ids[j].clone(), score.similarity_percent));
393 }
394 }
395 }
396 }
397
398 duplicates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
400
401 duplicates
402 }
403
404 pub fn find_similar(
406 &self,
407 image_bytes: &[u8],
408 min_similarity: f64,
409 ) -> Result<Vec<(String, f64)>, AiError> {
410 let query_hash = self.detector.hash_image(image_bytes)?;
411 let mut results = Vec::new();
412
413 for (id, hash) in &self.hashes {
414 let score = self.detector.compare_hashes(&query_hash, hash)?;
415 if score.similarity_percent >= min_similarity {
416 results.push((id.clone(), score.similarity_percent));
417 }
418 }
419
420 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
422
423 Ok(results)
424 }
425
426 pub fn has_similar_image(&self, image_bytes: &[u8]) -> Result<bool, AiError> {
428 Ok(self.is_duplicate(image_bytes)?.is_some())
429 }
430
431 #[must_use]
433 pub fn get_all_ids(&self) -> Vec<String> {
434 self.hashes.keys().cloned().collect()
435 }
436
437 pub fn remove_image(&mut self, id: &str) -> Option<PerceptualHash> {
439 self.hashes.remove(id)
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Encodes a 100x100 single-colour RGB image as PNG bytes.
    fn create_test_image(color: [u8; 3]) -> Vec<u8> {
        let canvas: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_fn(100, 100, |_, _| Rgb(color));

        let mut encoded = Vec::new();
        let mut sink = std::io::Cursor::new(&mut encoded);
        canvas.write_to(&mut sink, image::ImageFormat::Png).unwrap();
        encoded
    }

    /// An empty database backed by the default detector.
    fn empty_database() -> ImageDatabase {
        ImageDatabase::new(ImageSimilarityDetector::default())
    }

    #[test]
    fn test_identical_images() {
        let red = create_test_image([255, 0, 0]);
        let red_copy = red.clone();

        let score = ImageSimilarityDetector::default()
            .compare_images(&red, &red_copy)
            .unwrap();

        assert_eq!(score.hamming_distance, 0);
        assert!((score.similarity_percent - 100.0).abs() < 0.01);
        assert!(score.is_similar);
    }

    #[test]
    fn test_different_images() {
        let red = create_test_image([255, 0, 0]);
        let blue = create_test_image([0, 0, 255]);

        let score = ImageSimilarityDetector::default()
            .compare_images(&red, &blue)
            .unwrap();

        // Only the range of the percentage is checked here.
        assert!(score.similarity_percent <= 100.0);
        assert!(score.similarity_percent >= 0.0);
    }

    #[test]
    fn test_hash_algorithms() {
        let gray = create_test_image([128, 128, 128]);

        let dhash = ImageSimilarityDetector::new(10, HashAlgorithm::DHash)
            .hash_image(&gray)
            .unwrap();
        let ahash = ImageSimilarityDetector::new(10, HashAlgorithm::AHash)
            .hash_image(&gray)
            .unwrap();
        let phash = ImageSimilarityDetector::new(10, HashAlgorithm::PHash)
            .hash_image(&gray)
            .unwrap();

        assert_eq!(dhash.algorithm, HashAlgorithm::DHash);
        assert_eq!(ahash.algorithm, HashAlgorithm::AHash);
        assert_eq!(phash.algorithm, HashAlgorithm::PHash);
    }

    #[test]
    fn test_hamming_distance() {
        assert_eq!(ImageSimilarityDetector::hamming_distance(0b1010, 0b1010), 0);
        assert_eq!(ImageSimilarityDetector::hamming_distance(0b1010, 0b1011), 1);
        assert_eq!(ImageSimilarityDetector::hamming_distance(0b1010, 0b0101), 4);
    }

    #[test]
    fn test_image_database() {
        let mut db = empty_database();
        let red = create_test_image([255, 0, 0]);
        let green = create_test_image([0, 255, 0]);

        db.add_image("img1".to_string(), &red).unwrap();
        assert_eq!(db.len(), 1);

        // The image that was just stored must be reported as a duplicate.
        let duplicate = db.is_duplicate(&red).unwrap();
        assert!(duplicate.is_some());

        db.add_image("img2".to_string(), &green).unwrap();
        assert_eq!(db.len(), 2);

        db.clear();
        assert!(db.is_empty());
    }

    #[test]
    fn test_find_similar_images() {
        let red = create_test_image([255, 0, 0]);
        let near_red = create_test_image([254, 0, 0]);
        let blue = create_test_image([0, 0, 255]);
        let collection = vec![red.clone(), near_red.clone(), blue];

        let results = ImageSimilarityDetector::default()
            .find_similar_images(&red, &collection)
            .unwrap();

        assert!(!results.is_empty());
        assert!(results[0].1.is_similar);
    }

    #[test]
    fn test_with_threshold() {
        let detector = ImageSimilarityDetector::default().with_threshold(5);
        assert_eq!(detector.threshold, 5);
    }

    #[test]
    fn test_with_algorithm() {
        let detector = ImageSimilarityDetector::default().with_algorithm(HashAlgorithm::PHash);
        assert_eq!(detector.algorithm, HashAlgorithm::PHash);
    }

    #[test]
    fn test_find_duplicates_in_database() {
        let mut db = empty_database();
        let red = create_test_image([255, 0, 0]);
        let red_copy = red.clone();
        let green = create_test_image([0, 255, 0]);

        db.add_image("img1".to_string(), &red).unwrap();
        db.add_image("img2".to_string(), &red_copy).unwrap();
        db.add_image("img3".to_string(), &green).unwrap();

        let duplicates = db.find_duplicates();
        assert!(!duplicates.is_empty());
    }

    #[test]
    fn test_find_similar_with_threshold() {
        let mut db = empty_database();
        let red = create_test_image([255, 0, 0]);
        let near_red = create_test_image([254, 0, 0]);

        db.add_image("img1".to_string(), &red).unwrap();
        db.add_image("img2".to_string(), &near_red).unwrap();

        let similar = db.find_similar(&red, 90.0).unwrap();
        assert!(!similar.is_empty());
    }

    #[test]
    fn test_has_similar_image() {
        let mut db = empty_database();
        let red = create_test_image([255, 0, 0]);
        db.add_image("img1".to_string(), &red).unwrap();

        assert!(db.has_similar_image(&red).unwrap());
    }

    #[test]
    fn test_get_all_ids() {
        let mut db = empty_database();
        let red = create_test_image([255, 0, 0]);
        let green = create_test_image([0, 255, 0]);

        db.add_image("img1".to_string(), &red).unwrap();
        db.add_image("img2".to_string(), &green).unwrap();

        let ids = db.get_all_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"img1".to_string()));
        assert!(ids.contains(&"img2".to_string()));
    }

    #[test]
    fn test_remove_image() {
        let mut db = empty_database();
        let red = create_test_image([255, 0, 0]);
        db.add_image("img1".to_string(), &red).unwrap();
        assert_eq!(db.len(), 1);

        let removed = db.remove_image("img1");
        assert!(removed.is_some());
        assert_eq!(db.len(), 0);
    }
}