kaccy_ai/
image_similarity.rs

1//! Image similarity detection using perceptual hashing
2//!
3//! Provides duplicate and near-duplicate image detection for fraud prevention.
4//! Uses perceptual hashing (pHash) to create fingerprints of images that are
5//! resistant to minor modifications like resizing, compression, and slight color changes.
6
7use crate::error::AiError;
8use image::{DynamicImage, ImageBuffer, Luma};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Perceptual hash of an image
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct PerceptualHash {
15    /// Hash value (64-bit)
16    pub hash: u64,
17    /// Hash algorithm used
18    pub algorithm: HashAlgorithm,
19}
20
21/// Hash algorithm type
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
23pub enum HashAlgorithm {
24    /// Difference hash (dHash) - fast, good for exact and near-duplicates
25    DHash,
26    /// Average hash (aHash) - very fast, good for exact duplicates
27    AHash,
28    /// Perceptual hash (pHash) - slower, best for detecting modifications
29    PHash,
30}
31
32/// Image similarity score
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SimilarityScore {
35    /// Hamming distance between hashes
36    pub hamming_distance: u32,
37    /// Similarity percentage (0-100)
38    pub similarity_percent: f64,
39    /// Whether images are considered similar
40    pub is_similar: bool,
41    /// Threshold used for comparison
42    pub threshold: u32,
43}
44
45/// Image similarity detector
46pub struct ImageSimilarityDetector {
47    /// Similarity threshold (Hamming distance)
48    /// Lower = more strict (0 = identical, 64 = completely different)
49    threshold: u32,
50    /// Hash algorithm to use
51    algorithm: HashAlgorithm,
52    /// Size for hash computation
53    hash_size: u32,
54}
55
56impl Default for ImageSimilarityDetector {
57    fn default() -> Self {
58        Self {
59            threshold: 10, // Default: allow 10 bits difference
60            algorithm: HashAlgorithm::DHash,
61            hash_size: 8,
62        }
63    }
64}
65
66impl ImageSimilarityDetector {
67    /// Create a new detector with custom threshold
68    #[must_use]
69    pub fn new(threshold: u32, algorithm: HashAlgorithm) -> Self {
70        Self {
71            threshold,
72            algorithm,
73            hash_size: 8,
74        }
75    }
76
77    /// Set similarity threshold
78    #[must_use]
79    pub fn with_threshold(mut self, threshold: u32) -> Self {
80        self.threshold = threshold;
81        self
82    }
83
84    /// Set hash algorithm
85    #[must_use]
86    pub fn with_algorithm(mut self, algorithm: HashAlgorithm) -> Self {
87        self.algorithm = algorithm;
88        self
89    }
90
91    /// Compute perceptual hash from image bytes
92    pub fn hash_image(&self, image_bytes: &[u8]) -> Result<PerceptualHash, AiError> {
93        let img = image::load_from_memory(image_bytes)
94            .map_err(|e| AiError::ParseError(format!("Failed to load image: {e}")))?;
95
96        self.hash_dynamic_image(&img)
97    }
98
99    /// Compute perceptual hash from `DynamicImage`
100    pub fn hash_dynamic_image(&self, img: &DynamicImage) -> Result<PerceptualHash, AiError> {
101        let hash = match self.algorithm {
102            HashAlgorithm::DHash => self.compute_dhash(img),
103            HashAlgorithm::AHash => self.compute_ahash(img),
104            HashAlgorithm::PHash => self.compute_phash(img),
105        };
106
107        Ok(PerceptualHash {
108            hash,
109            algorithm: self.algorithm,
110        })
111    }
112
113    /// Compute difference hash (dHash)
114    fn compute_dhash(&self, img: &DynamicImage) -> u64 {
115        let size = self.hash_size + 1;
116        let resized = img.resize_exact(size, self.hash_size, image::imageops::FilterType::Lanczos3);
117        let gray = resized.to_luma8();
118
119        let mut hash: u64 = 0;
120        for y in 0..self.hash_size {
121            for x in 0..self.hash_size {
122                let left = gray.get_pixel(x, y)[0];
123                let right = gray.get_pixel(x + 1, y)[0];
124                if left > right {
125                    let bit_position = u64::from(y * self.hash_size + x);
126                    hash |= 1 << bit_position;
127                }
128            }
129        }
130
131        hash
132    }
133
134    /// Compute average hash (aHash)
135    fn compute_ahash(&self, img: &DynamicImage) -> u64 {
136        let resized = img.resize_exact(
137            self.hash_size,
138            self.hash_size,
139            image::imageops::FilterType::Lanczos3,
140        );
141        let gray = resized.to_luma8();
142
143        // Calculate average pixel value
144        let mut sum: u64 = 0;
145        for pixel in gray.pixels() {
146            sum += u64::from(pixel[0]);
147        }
148        let avg = sum / u64::from(self.hash_size * self.hash_size);
149
150        // Build hash based on whether pixels are above or below average
151        let mut hash: u64 = 0;
152        for (i, pixel) in gray.pixels().enumerate() {
153            if u64::from(pixel[0]) > avg {
154                hash |= 1 << i;
155            }
156        }
157
158        hash
159    }
160
161    /// Compute perceptual hash (pHash) using DCT
162    fn compute_phash(&self, img: &DynamicImage) -> u64 {
163        let size = 32; // Use larger size for DCT
164        let resized = img.resize_exact(size, size, image::imageops::FilterType::Lanczos3);
165        let gray = resized.to_luma8();
166
167        // Simplified DCT for 32x32 image (using 8x8 DCT grid)
168        let dct = self.simple_dct(&gray, 8, 8);
169
170        // Calculate median of DCT values (excluding DC component)
171        let mut values: Vec<f64> = Vec::new();
172        for row in &dct {
173            for &val in row {
174                values.push(val);
175            }
176        }
177        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
178        let median = values[values.len() / 2];
179
180        // Build hash based on median
181        let mut hash: u64 = 0;
182        for (i, row) in dct.iter().enumerate() {
183            for (j, &val) in row.iter().enumerate() {
184                if i * 8 + j >= 64 {
185                    break;
186                }
187                if val > median {
188                    hash |= 1 << (i * 8 + j);
189                }
190            }
191        }
192
193        hash
194    }
195
196    /// Simple 2D DCT (Discrete Cosine Transform) approximation
197    fn simple_dct(
198        &self,
199        img: &ImageBuffer<Luma<u8>, Vec<u8>>,
200        rows: usize,
201        cols: usize,
202    ) -> Vec<Vec<f64>> {
203        let mut dct = vec![vec![0.0; cols]; rows];
204        let (width, height) = img.dimensions();
205        let block_width = width / cols as u32;
206        let block_height = height / rows as u32;
207
208        for (i, dct_row) in dct.iter_mut().enumerate() {
209            for (j, dct_val) in dct_row.iter_mut().enumerate() {
210                let mut sum = 0.0;
211                let mut count = 0.0;
212
213                // Average pixel values in block
214                for y in 0..block_height {
215                    for x in 0..block_width {
216                        let px = (j as u32) * block_width + x;
217                        let py = (i as u32) * block_height + y;
218                        if px < width && py < height {
219                            sum += f64::from(img.get_pixel(px, py)[0]);
220                            count += 1.0;
221                        }
222                    }
223                }
224
225                *dct_val = if count > 0.0 { sum / count } else { 0.0 };
226            }
227        }
228
229        dct
230    }
231
232    /// Calculate Hamming distance between two hashes
233    #[must_use]
234    pub fn hamming_distance(hash1: u64, hash2: u64) -> u32 {
235        (hash1 ^ hash2).count_ones()
236    }
237
238    /// Compare two images for similarity
239    pub fn compare_images(
240        &self,
241        image1_bytes: &[u8],
242        image2_bytes: &[u8],
243    ) -> Result<SimilarityScore, AiError> {
244        let hash1 = self.hash_image(image1_bytes)?;
245        let hash2 = self.hash_image(image2_bytes)?;
246
247        self.compare_hashes(&hash1, &hash2)
248    }
249
250    /// Compare two perceptual hashes
251    pub fn compare_hashes(
252        &self,
253        hash1: &PerceptualHash,
254        hash2: &PerceptualHash,
255    ) -> Result<SimilarityScore, AiError> {
256        if hash1.algorithm != hash2.algorithm {
257            return Err(AiError::InvalidInput(
258                "Cannot compare hashes from different algorithms".to_string(),
259            ));
260        }
261
262        let hamming_distance = Self::hamming_distance(hash1.hash, hash2.hash);
263        let similarity_percent = 100.0 * (1.0 - (f64::from(hamming_distance) / 64.0));
264        let is_similar = hamming_distance <= self.threshold;
265
266        Ok(SimilarityScore {
267            hamming_distance,
268            similarity_percent,
269            is_similar,
270            threshold: self.threshold,
271        })
272    }
273
274    /// Find similar images in a collection
275    pub fn find_similar_images(
276        &self,
277        query_image: &[u8],
278        image_collection: &[Vec<u8>],
279    ) -> Result<Vec<(usize, SimilarityScore)>, AiError> {
280        let query_hash = self.hash_image(query_image)?;
281        let mut results = Vec::new();
282
283        for (idx, img_bytes) in image_collection.iter().enumerate() {
284            let hash = self.hash_image(img_bytes)?;
285            let score = self.compare_hashes(&query_hash, &hash)?;
286
287            if score.is_similar {
288                results.push((idx, score));
289            }
290        }
291
292        // Sort by similarity (most similar first)
293        results.sort_by(|a, b| a.1.hamming_distance.cmp(&b.1.hamming_distance));
294
295        Ok(results)
296    }
297}
298
299/// Image database for deduplication
300pub struct ImageDatabase {
301    /// Stored hashes with IDs
302    hashes: HashMap<String, PerceptualHash>,
303    /// Detector instance
304    detector: ImageSimilarityDetector,
305}
306
307impl ImageDatabase {
308    /// Create a new image database
309    #[must_use]
310    pub fn new(detector: ImageSimilarityDetector) -> Self {
311        Self {
312            hashes: HashMap::new(),
313            detector,
314        }
315    }
316
317    /// Add an image to the database
318    pub fn add_image(&mut self, id: String, image_bytes: &[u8]) -> Result<PerceptualHash, AiError> {
319        let hash = self.detector.hash_image(image_bytes)?;
320        self.hashes.insert(id, hash.clone());
321        Ok(hash)
322    }
323
324    /// Check if an image is a duplicate
325    pub fn is_duplicate(
326        &self,
327        image_bytes: &[u8],
328    ) -> Result<Option<(String, SimilarityScore)>, AiError> {
329        let query_hash = self.detector.hash_image(image_bytes)?;
330
331        for (id, hash) in &self.hashes {
332            let score = self.detector.compare_hashes(&query_hash, hash)?;
333            if score.is_similar {
334                return Ok(Some((id.clone(), score)));
335            }
336        }
337
338        Ok(None)
339    }
340
341    /// Find all similar images
342    pub fn find_all_similar(
343        &self,
344        image_bytes: &[u8],
345    ) -> Result<Vec<(String, SimilarityScore)>, AiError> {
346        let query_hash = self.detector.hash_image(image_bytes)?;
347        let mut results = Vec::new();
348
349        for (id, hash) in &self.hashes {
350            let score = self.detector.compare_hashes(&query_hash, hash)?;
351            if score.is_similar {
352                results.push((id.clone(), score));
353            }
354        }
355
356        // Sort by similarity
357        results.sort_by(|a, b| a.1.hamming_distance.cmp(&b.1.hamming_distance));
358
359        Ok(results)
360    }
361
362    /// Get number of images in database
363    #[must_use]
364    pub fn len(&self) -> usize {
365        self.hashes.len()
366    }
367
368    /// Check if database is empty
369    #[must_use]
370    pub fn is_empty(&self) -> bool {
371        self.hashes.is_empty()
372    }
373
374    /// Clear the database
375    pub fn clear(&mut self) {
376        self.hashes.clear();
377    }
378
379    /// Find all duplicates in the database
380    #[must_use]
381    pub fn find_duplicates(&self) -> Vec<(String, String, f64)> {
382        let mut duplicates = Vec::new();
383        let ids: Vec<_> = self.hashes.keys().cloned().collect();
384
385        for i in 0..ids.len() {
386            for j in (i + 1)..ids.len() {
387                let hash1 = &self.hashes[&ids[i]];
388                let hash2 = &self.hashes[&ids[j]];
389
390                if let Ok(score) = self.detector.compare_hashes(hash1, hash2) {
391                    if score.is_similar {
392                        duplicates.push((ids[i].clone(), ids[j].clone(), score.similarity_percent));
393                    }
394                }
395            }
396        }
397
398        // Sort by similarity (highest first)
399        duplicates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
400
401        duplicates
402    }
403
404    /// Find similar images to a query
405    pub fn find_similar(
406        &self,
407        image_bytes: &[u8],
408        min_similarity: f64,
409    ) -> Result<Vec<(String, f64)>, AiError> {
410        let query_hash = self.detector.hash_image(image_bytes)?;
411        let mut results = Vec::new();
412
413        for (id, hash) in &self.hashes {
414            let score = self.detector.compare_hashes(&query_hash, hash)?;
415            if score.similarity_percent >= min_similarity {
416                results.push((id.clone(), score.similarity_percent));
417            }
418        }
419
420        // Sort by similarity (highest first)
421        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
422
423        Ok(results)
424    }
425
426    /// Check if similar image exists in database
427    pub fn has_similar_image(&self, image_bytes: &[u8]) -> Result<bool, AiError> {
428        Ok(self.is_duplicate(image_bytes)?.is_some())
429    }
430
431    /// Get all stored image IDs
432    #[must_use]
433    pub fn get_all_ids(&self) -> Vec<String> {
434        self.hashes.keys().cloned().collect()
435    }
436
437    /// Remove an image from the database
438    pub fn remove_image(&mut self, id: &str) -> Option<PerceptualHash> {
439        self.hashes.remove(id)
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Encode a 100x100 solid-colour PNG.
    ///
    /// Solid images carry no structure, so different colours may hash identically;
    /// tests only rely on identical inputs producing identical hashes.
    fn create_test_image(color: [u8; 3]) -> Vec<u8> {
        let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_fn(100, 100, |_, _| Rgb(color));

        let mut bytes = Vec::new();
        img.write_to(
            &mut std::io::Cursor::new(&mut bytes),
            image::ImageFormat::Png,
        )
        .unwrap();
        bytes
    }

    #[test]
    fn test_identical_images() {
        let detector = ImageSimilarityDetector::default();
        let img1 = create_test_image([255, 0, 0]);
        let img2 = img1.clone();

        let score = detector.compare_images(&img1, &img2).unwrap();
        assert_eq!(score.hamming_distance, 0);
        assert!((score.similarity_percent - 100.0).abs() < 0.01);
        assert!(score.is_similar);
    }

    #[test]
    fn test_different_images() {
        let detector = ImageSimilarityDetector::default();
        let img1 = create_test_image([255, 0, 0]); // Red
        let img2 = create_test_image([0, 0, 255]); // Blue

        let score = detector.compare_images(&img1, &img2).unwrap();
        // Solid color images may have same hash due to lack of detail
        // Just verify the comparison works
        assert!(score.similarity_percent <= 100.0);
        assert!(score.similarity_percent >= 0.0);
    }

    #[test]
    fn test_hash_algorithms() {
        let img = create_test_image([128, 128, 128]);

        let dhash_detector = ImageSimilarityDetector::new(10, HashAlgorithm::DHash);
        let ahash_detector = ImageSimilarityDetector::new(10, HashAlgorithm::AHash);
        let phash_detector = ImageSimilarityDetector::new(10, HashAlgorithm::PHash);

        let dhash = dhash_detector.hash_image(&img).unwrap();
        let ahash = ahash_detector.hash_image(&img).unwrap();
        let phash = phash_detector.hash_image(&img).unwrap();

        assert_eq!(dhash.algorithm, HashAlgorithm::DHash);
        assert_eq!(ahash.algorithm, HashAlgorithm::AHash);
        assert_eq!(phash.algorithm, HashAlgorithm::PHash);

        // Hashing is deterministic for every algorithm.
        assert_eq!(dhash, dhash_detector.hash_image(&img).unwrap());
        assert_eq!(ahash, ahash_detector.hash_image(&img).unwrap());
        assert_eq!(phash, phash_detector.hash_image(&img).unwrap());
    }

    #[test]
    fn test_hamming_distance() {
        assert_eq!(ImageSimilarityDetector::hamming_distance(0b1010, 0b1010), 0);
        assert_eq!(ImageSimilarityDetector::hamming_distance(0b1010, 0b1011), 1);
        assert_eq!(ImageSimilarityDetector::hamming_distance(0b1010, 0b0101), 4);
        assert_eq!(ImageSimilarityDetector::hamming_distance(0, u64::MAX), 64);
    }

    #[test]
    fn test_compare_hashes_rejects_mixed_algorithms() {
        let detector = ImageSimilarityDetector::default();
        let a = PerceptualHash {
            hash: 42,
            algorithm: HashAlgorithm::DHash,
        };
        let b = PerceptualHash {
            hash: 42,
            algorithm: HashAlgorithm::PHash,
        };
        assert!(detector.compare_hashes(&a, &b).is_err());
    }

    #[test]
    fn test_threshold_is_inclusive() {
        let detector = ImageSimilarityDetector::new(3, HashAlgorithm::AHash);
        let base = PerceptualHash {
            hash: 0,
            algorithm: HashAlgorithm::AHash,
        };
        let at_threshold = PerceptualHash {
            hash: 0b111,
            algorithm: HashAlgorithm::AHash,
        };
        let beyond = PerceptualHash {
            hash: 0b1111,
            algorithm: HashAlgorithm::AHash,
        };

        let score = detector.compare_hashes(&base, &at_threshold).unwrap();
        assert_eq!(score.hamming_distance, 3);
        assert_eq!(score.threshold, 3);
        assert!(score.is_similar);

        let score = detector.compare_hashes(&base, &beyond).unwrap();
        assert_eq!(score.hamming_distance, 4);
        assert!(!score.is_similar);
        assert!((score.similarity_percent - 93.75).abs() < 1e-9);
    }

    #[test]
    fn test_invalid_image_bytes() {
        let detector = ImageSimilarityDetector::default();
        assert!(detector.hash_image(b"definitely not an image").is_err());

        let mut db = ImageDatabase::new(detector);
        assert!(db
            .add_image("bad".to_string(), b"definitely not an image")
            .is_err());
        assert!(db.is_empty());
    }

    #[test]
    fn test_image_database() {
        let detector = ImageSimilarityDetector::default();
        let mut db = ImageDatabase::new(detector);

        let img1 = create_test_image([255, 0, 0]);
        let img2 = create_test_image([0, 255, 0]);

        db.add_image("img1".to_string(), &img1).unwrap();
        assert_eq!(db.len(), 1);

        // Exact match against the only stored image.
        let (id, score) = db.is_duplicate(&img1).unwrap().expect("exact copy must match");
        assert_eq!(id, "img1");
        assert_eq!(score.hamming_distance, 0);

        db.add_image("img2".to_string(), &img2).unwrap();
        assert_eq!(db.len(), 2);

        db.clear();
        assert!(db.is_empty());
    }

    #[test]
    fn test_find_similar_images() {
        let detector = ImageSimilarityDetector::default();
        let img1 = create_test_image([255, 0, 0]);
        let img2 = create_test_image([254, 0, 0]); // Very similar
        let img3 = create_test_image([0, 0, 255]); // Different

        let collection = vec![img1.clone(), img2.clone(), img3];
        let results = detector.find_similar_images(&img1, &collection).unwrap();

        // The query itself (index 0) must be found at distance 0.
        assert!(results
            .iter()
            .any(|(idx, score)| *idx == 0 && score.hamming_distance == 0));
        assert!(results.iter().all(|(_, score)| score.is_similar));
        // Results are ordered by ascending distance.
        assert!(results
            .windows(2)
            .all(|w| w[0].1.hamming_distance <= w[1].1.hamming_distance));
    }

    #[test]
    fn test_with_threshold() {
        let detector = ImageSimilarityDetector::default().with_threshold(5);
        assert_eq!(detector.threshold, 5);
    }

    #[test]
    fn test_with_algorithm() {
        let detector = ImageSimilarityDetector::default().with_algorithm(HashAlgorithm::PHash);
        assert_eq!(detector.algorithm, HashAlgorithm::PHash);
    }

    #[test]
    fn test_find_duplicates_in_database() {
        let detector = ImageSimilarityDetector::default();
        let mut db = ImageDatabase::new(detector);

        let img1 = create_test_image([255, 0, 0]);
        let img2 = img1.clone(); // Exact duplicate
        let img3 = create_test_image([0, 255, 0]);

        db.add_image("img1".to_string(), &img1).unwrap();
        db.add_image("img2".to_string(), &img2).unwrap();
        db.add_image("img3".to_string(), &img3).unwrap();

        let duplicates = db.find_duplicates();
        // The exact duplicate pair must be reported at 100% similarity.
        assert!(duplicates.iter().any(|(a, b, sim)| {
            let pair = (a.as_str(), b.as_str());
            (pair == ("img1", "img2") || pair == ("img2", "img1")) && (sim - 100.0).abs() < 0.01
        }));
    }

    #[test]
    fn test_find_similar_with_threshold() {
        let detector = ImageSimilarityDetector::default();
        let mut db = ImageDatabase::new(detector);

        let img1 = create_test_image([255, 0, 0]);
        let img2 = create_test_image([254, 0, 0]);

        db.add_image("img1".to_string(), &img1).unwrap();
        db.add_image("img2".to_string(), &img2).unwrap();

        // Find similar with 90% threshold
        let similar = db.find_similar(&img1, 90.0).unwrap();
        assert!(similar
            .iter()
            .any(|(id, sim)| id.as_str() == "img1" && (sim - 100.0).abs() < 0.01));
        assert!(similar.iter().all(|(_, sim)| *sim >= 90.0));
    }

    #[test]
    fn test_has_similar_image() {
        let detector = ImageSimilarityDetector::default();
        let mut db = ImageDatabase::new(detector);

        let img1 = create_test_image([255, 0, 0]);
        assert!(!db.has_similar_image(&img1).unwrap());

        db.add_image("img1".to_string(), &img1).unwrap();
        assert!(db.has_similar_image(&img1).unwrap());
    }

    #[test]
    fn test_get_all_ids() {
        let detector = ImageSimilarityDetector::default();
        let mut db = ImageDatabase::new(detector);

        let img1 = create_test_image([255, 0, 0]);
        let img2 = create_test_image([0, 255, 0]);

        db.add_image("img1".to_string(), &img1).unwrap();
        db.add_image("img2".to_string(), &img2).unwrap();

        let ids = db.get_all_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"img1".to_string()));
        assert!(ids.contains(&"img2".to_string()));
    }

    #[test]
    fn test_remove_image() {
        let detector = ImageSimilarityDetector::default();
        let mut db = ImageDatabase::new(detector);

        let img1 = create_test_image([255, 0, 0]);
        db.add_image("img1".to_string(), &img1).unwrap();
        assert_eq!(db.len(), 1);

        let removed = db.remove_image("img1");
        assert!(removed.is_some());
        assert_eq!(db.len(), 0);

        assert!(db.remove_image("img1").is_none());
    }
}