use crate::core::similarity::{hash_similarity, Similarity};
use crate::hash::algorithms::HashAlgorithm;
/// Default multiplier for the aHash score when the three per-algorithm
/// scores are combined (see [`MultiHashConfig::ahash_weight`]).
pub const DEFAULT_AHASH_WEIGHT: f32 = 0.10;
/// Default multiplier for the pHash score when the three per-algorithm
/// scores are combined (see [`MultiHashConfig::phash_weight`]).
pub const DEFAULT_PHASH_WEIGHT: f32 = 0.60;
/// Default multiplier for the dHash score when the three per-algorithm
/// scores are combined (see [`MultiHashConfig::dhash_weight`]).
///
/// Together with the aHash and pHash defaults this adds up to 1.0.
pub const DEFAULT_DHASH_WEIGHT: f32 = 0.30;
/// Default value of [`MultiHashConfig::global_weight`].
// NOTE(review): presumably the share given to the 64-bit global hash inside
// each per-algorithm score; the helper that consumes it is not visible here.
pub const DEFAULT_GLOBAL_WEIGHT: f32 = 0.40;
/// Default value of [`MultiHashConfig::block_weight`].
///
/// Together with [`DEFAULT_GLOBAL_WEIGHT`] this adds up to 1.0.
// NOTE(review): presumably the share given to the block hashes inside each
// per-algorithm score; confirm against the similarity helper.
pub const DEFAULT_BLOCK_WEIGHT: f32 = 0.60;
/// Default value of [`MultiHashConfig::block_distance_threshold`].
// NOTE(review): looks like a Hamming-distance cutoff (in bits) applied per
// block; confirm against the similarity helper.
pub const DEFAULT_BLOCK_DISTANCE_THRESHOLD: u32 = 32;
/// Tunable parameters for [`MultiHashFingerprint::compare_with_config`].
///
/// The weights are used as given: they are neither validated nor normalised
/// here. The final score produced by the comparison is clamped to
/// `[0.0, 1.0]`.
///
/// With the `serde` feature enabled the type can be (de)serialized, and
/// unknown fields are rejected on deserialization.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(deny_unknown_fields))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiHashConfig {
/// Multiplier applied to the aHash score and to the aHash global-hash
/// distance when the per-algorithm results are combined.
pub ahash_weight: f32,
/// Multiplier applied to the pHash score and to the pHash global-hash
/// distance when the per-algorithm results are combined.
pub phash_weight: f32,
/// Multiplier applied to the dHash score and to the dHash global-hash
/// distance when the per-algorithm results are combined.
pub dhash_weight: f32,
/// Forwarded unchanged to the per-algorithm similarity helper.
// NOTE(review): presumably the weight of the global-hash comparison.
pub global_weight: f32,
/// Forwarded unchanged to the per-algorithm similarity helper.
// NOTE(review): presumably the weight of the block-hash comparison.
pub block_weight: f32,
/// Forwarded unchanged to the per-algorithm similarity helper.
// NOTE(review): presumably a per-block Hamming-distance cutoff in bits.
pub block_distance_threshold: u32,
}
impl Default for MultiHashConfig {
fn default() -> Self {
Self {
ahash_weight: DEFAULT_AHASH_WEIGHT,
phash_weight: DEFAULT_PHASH_WEIGHT,
dhash_weight: DEFAULT_DHASH_WEIGHT,
global_weight: DEFAULT_GLOBAL_WEIGHT,
block_weight: DEFAULT_BLOCK_WEIGHT,
block_distance_threshold: DEFAULT_BLOCK_DISTANCE_THRESHOLD,
}
}
}
/// Fingerprint of a single image for one hashing algorithm.
///
/// Equality and hashing are derived, so they take all three fields into
/// account. With the `serde` feature enabled the type can be (de)serialized,
/// and unknown fields are rejected on deserialization.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(deny_unknown_fields))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ImageFingerprint {
/// 32-byte exact hash; two fingerprints with equal bytes are always
/// reported as similar.
// NOTE(review): the digest algorithm is not visible in this file.
pub(crate) exact: [u8; 32],
/// 64-bit hash compared bit-by-bit (Hamming distance) by `distance`.
pub(crate) global_hash: u64,
/// Sixteen additional 64-bit hashes.
// NOTE(review): presumably one hash per image region; confirm the layout
// against the code that produces fingerprints.
pub(crate) block_hashes: [u64; 16],
}
impl ImageFingerprint {
#[inline]
pub(crate) fn new(exact: [u8; 32], global_hash: u64, block_hashes: [u64; 16]) -> Self {
Self {
exact,
global_hash,
block_hashes,
}
}
#[inline]
#[must_use]
pub fn exact_hash(&self) -> &[u8; 32] {
&self.exact
}
#[inline]
#[must_use]
pub fn global_hash(&self) -> u64 {
self.global_hash
}
#[inline]
#[must_use]
pub fn block_hashes(&self) -> &[u64; 16] {
&self.block_hashes
}
#[inline]
#[must_use]
pub fn distance(&self, other: &ImageFingerprint) -> u32 {
(self.global_hash ^ other.global_hash).count_ones()
}
#[doc(alias = "compare")]
#[doc(alias = "match")]
#[must_use]
pub fn is_similar(&self, other: &ImageFingerprint, threshold: f32) -> bool {
debug_assert!(
(0.0..=1.0).contains(&threshold),
"threshold must be in range [0.0, 1.0], got {threshold}"
);
if self.exact == other.exact {
return true;
}
let clamped_threshold = threshold.clamp(0.0, 1.0);
let dist = self.distance(other);
let similarity = hash_similarity(dist);
similarity >= clamped_threshold
}
}
/// Fingerprint of one image that bundles an exact hash with one
/// [`ImageFingerprint`] per perceptual algorithm (aHash, pHash, dHash).
///
/// Equality and hashing are derived over all fields. With the `serde`
/// feature enabled the type can be (de)serialized, and unknown fields are
/// rejected on deserialization.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(deny_unknown_fields))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MultiHashFingerprint {
/// 32-byte exact hash; equal bytes make a comparison report an exact match.
pub(crate) exact: [u8; 32],
/// Fingerprint returned for `HashAlgorithm::AHash`.
pub(crate) ahash: ImageFingerprint,
/// Fingerprint returned for `HashAlgorithm::PHash`.
pub(crate) phash: ImageFingerprint,
/// Fingerprint returned for `HashAlgorithm::DHash`.
pub(crate) dhash: ImageFingerprint,
}
impl MultiHashFingerprint {
    /// Assembles a multi-hash fingerprint from the exact hash and the three
    /// per-algorithm fingerprints.
    pub(crate) fn new(
        exact: [u8; 32],
        ahash: ImageFingerprint,
        phash: ImageFingerprint,
        dhash: ImageFingerprint,
    ) -> Self {
        Self {
            exact,
            ahash,
            phash,
            dhash,
        }
    }

    /// Borrows the 32-byte exact hash.
    #[inline]
    #[must_use]
    pub fn exact_hash(&self) -> &[u8; 32] {
        &self.exact
    }

    /// Borrows the aHash fingerprint.
    #[inline]
    #[must_use]
    pub fn ahash(&self) -> &ImageFingerprint {
        &self.ahash
    }

    /// Borrows the pHash fingerprint.
    #[inline]
    #[must_use]
    pub fn phash(&self) -> &ImageFingerprint {
        &self.phash
    }

    /// Borrows the dHash fingerprint.
    #[inline]
    #[must_use]
    pub fn dhash(&self) -> &ImageFingerprint {
        &self.dhash
    }

    /// Borrows the fingerprint that belongs to `algorithm`.
    #[must_use]
    pub fn get(&self, algorithm: HashAlgorithm) -> &ImageFingerprint {
        match algorithm {
            HashAlgorithm::AHash => &self.ahash,
            HashAlgorithm::PHash => &self.phash,
            HashAlgorithm::DHash => &self.dhash,
        }
    }

    /// Compares two fingerprints using [`MultiHashConfig::default`].
    #[must_use]
    pub fn compare(&self, other: &MultiHashFingerprint) -> Similarity {
        // Go through the default config rather than repeating its block
        // threshold as a literal, so this can never drift from the defaults.
        self.compare_with_config(other, &MultiHashConfig::default())
    }

    /// Compares two fingerprints with a custom block-distance threshold.
    ///
    /// Every other setting keeps its [`MultiHashConfig::default`] value.
    #[must_use]
    pub fn compare_with_threshold(
        &self,
        other: &MultiHashFingerprint,
        block_threshold: u32,
    ) -> Similarity {
        let cfg = MultiHashConfig {
            block_distance_threshold: block_threshold,
            ..MultiHashConfig::default()
        };
        self.compare_with_config(other, &cfg)
    }

    /// Compares two fingerprints using the weights and threshold in `config`.
    ///
    /// * If the exact hashes are equal the result is an exact match with
    ///   score `1.0` and distance `0`, regardless of `config`.
    /// * Otherwise each algorithm's pair of fingerprints is scored by
    ///   `compute_similarity_with_weights`. The three scores are multiplied
    ///   by their algorithm weights, summed, and clamped to `[0.0, 1.0]`.
    ///
    /// `perceptual_distance` is the same weighted sum applied to the three
    /// global-hash Hamming distances, truncated to `u32`. It is a weighted
    /// *average* only when the three algorithm weights add up to 1, which
    /// is the case for the defaults.
    #[must_use]
    pub fn compare_with_config(
        &self,
        other: &MultiHashFingerprint,
        config: &MultiHashConfig,
    ) -> Similarity {
        use crate::core::similarity::{compute_similarity_with_weights, hamming_distance};
        use subtle::ConstantTimeEq;

        // The exact hashes are compared in constant time.
        let exact_match: bool = self.exact.ct_eq(&other.exact).into();
        if exact_match {
            return Similarity {
                score: 1.0,
                exact_match: true,
                perceptual_distance: 0,
            };
        }

        // All three algorithms share the same global/block settings, so the
        // per-algorithm score is computed through a single closure.
        let algorithm_score = |mine: &ImageFingerprint, theirs: &ImageFingerprint| {
            compute_similarity_with_weights(
                mine,
                theirs,
                config.global_weight,
                config.block_weight,
                config.block_distance_threshold,
            )
            .score
        };

        let ahash_sim = algorithm_score(&self.ahash, &other.ahash);
        let phash_sim = algorithm_score(&self.phash, &other.phash);
        let dhash_sim = algorithm_score(&self.dhash, &other.dhash);

        let weighted_score = ahash_sim * config.ahash_weight
            + phash_sim * config.phash_weight
            + dhash_sim * config.dhash_weight;

        let ahash_dist = hamming_distance(self.ahash.global_hash, other.ahash.global_hash);
        let phash_dist = hamming_distance(self.phash.global_hash, other.phash.global_hash);
        let dhash_dist = hamming_distance(self.dhash.global_hash, other.dhash.global_hash);

        // The casts are intentional: the distances are small bit counts that
        // `f32` holds exactly, and the float-to-int `as` cast saturates
        // (negative and NaN results become 0) instead of wrapping.
        #[allow(
            clippy::cast_precision_loss,
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss
        )]
        let avg_distance = ((ahash_dist as f32 * config.ahash_weight)
            + (phash_dist as f32 * config.phash_weight)
            + (dhash_dist as f32 * config.dhash_weight)) as u32;

        Similarity {
            score: weighted_score.clamp(0.0, 1.0),
            exact_match: false,
            perceptual_distance: avg_distance,
        }
    }

    /// Decides whether `other` is similar enough to `self`.
    ///
    /// Returns `true` when the score from [`compare`](Self::compare) is at
    /// least `threshold`.
    ///
    /// `threshold` is expected to lie in `[0.0, 1.0]`. Debug builds assert
    /// this; release builds clamp the value into that range instead.
    #[must_use]
    pub fn is_similar(&self, other: &MultiHashFingerprint, threshold: f32) -> bool {
        debug_assert!(
            (0.0..=1.0).contains(&threshold),
            "threshold must be in range [0.0, 1.0], got {threshold}"
        );
        let clamped_threshold = threshold.clamp(0.0, 1.0);
        self.compare(other).score >= clamped_threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ImageFingerprint` with an all-zero exact hash, the given
    /// global hash, and `blocks_word` repeated in every block slot.
    fn fp(global: u64, blocks_word: u64) -> ImageFingerprint {
        ImageFingerprint::new([0u8; 32], global, [blocks_word; 16])
    }

    /// Builds a `MultiHashFingerprint` in which each algorithm's block
    /// hashes repeat that algorithm's global hash.
    fn multi(exact: [u8; 32], a_global: u64, p_global: u64, d_global: u64) -> MultiHashFingerprint {
        MultiHashFingerprint::new(
            exact,
            ImageFingerprint::new(exact, a_global, [a_global; 16]),
            ImageFingerprint::new(exact, p_global, [p_global; 16]),
            ImageFingerprint::new(exact, d_global, [d_global; 16]),
        )
    }

    /// `compare` must give the same score as an explicit default config.
    #[test]
    fn multi_hash_config_default_matches_compare() {
        let a = multi([1u8; 32], 0xAAAA, 0xBBBB, 0xCCCC);
        let b = multi([2u8; 32], 0xAAAA, 0xBBB0, 0xCCC0);
        let default_score = a.compare(&b).score;
        let cfg_score = a.compare_with_config(&b, &MultiHashConfig::default()).score;
        assert!(
            (default_score - cfg_score).abs() < 1e-6,
            "{default_score} vs {cfg_score}"
        );
    }

    /// With all weight on pHash, mismatching aHash/dHash data is irrelevant.
    #[test]
    fn multi_hash_config_phash_only_ignores_other_algorithms() {
        let a = multi([1u8; 32], 0x0000_0000, 0x1234_5678, 0x0000_0000);
        let b = multi([2u8; 32], u64::MAX, 0x1234_5678, u64::MAX);
        let default_score = a.compare(&b).score;
        let phash_only = MultiHashConfig {
            ahash_weight: 0.0,
            phash_weight: 1.0,
            dhash_weight: 0.0,
            ..MultiHashConfig::default()
        };
        let phash_score = a.compare_with_config(&b, &phash_only).score;
        assert!((phash_score - 1.0).abs() < 1e-6, "got {phash_score}");
        assert!(
            default_score < phash_score,
            "{default_score} >= {phash_score}"
        );
    }

    /// Equal exact hashes win even when every weight is zero.
    #[test]
    fn multi_hash_config_exact_match_is_always_one() {
        let a = multi([7u8; 32], 0xAAAA, 0xBBBB, 0xCCCC);
        let weird = MultiHashConfig {
            ahash_weight: 0.0,
            phash_weight: 0.0,
            dhash_weight: 0.0,
            global_weight: 0.0,
            block_weight: 0.0,
            block_distance_threshold: 0,
        };
        let s = a.compare_with_config(&a, &weird);
        assert!(s.exact_match);
        assert_eq!(s.score, 1.0);
    }

    /// Oversized weights must not push the score outside `[0, 1]`.
    #[test]
    fn multi_hash_config_score_clamped_to_unit_interval() {
        let a = multi([1u8; 32], 0, 0, 0);
        let b = multi([2u8; 32], 0, 0, 0);
        let cfg = MultiHashConfig {
            ahash_weight: 5.0,
            phash_weight: 5.0,
            dhash_weight: 5.0,
            global_weight: 10.0,
            block_weight: 10.0,
            block_distance_threshold: 32,
        };
        let s = a.compare_with_config(&b, &cfg);
        assert!(s.score <= 1.0 && s.score >= 0.0, "got {}", s.score);
    }

    /// The accessors hand back exactly what the constructor was given.
    #[test]
    fn fingerprint_accessors_return_constructor_inputs() {
        let f = fp(0x1234, 0xABCD);
        assert_eq!(f.exact_hash(), &[0u8; 32]);
        assert_eq!(f.global_hash(), 0x1234);
        assert_eq!(f.block_hashes(), &[0xABCD_u64; 16]);
    }

    /// `distance` counts differing bits of the global hash only.
    #[test]
    fn fingerprint_distance_counts_differing_global_bits() {
        assert_eq!(fp(0xFF, 0).distance(&fp(0x0F, 0)), 4);
        assert_eq!(fp(0, 0).distance(&fp(u64::MAX, 0)), 64);
        // Block hashes differ here, but they do not affect the distance.
        assert_eq!(fp(0x1234, 0).distance(&fp(0x1234, 0xFFFF)), 0);
    }

    /// Equal exact hashes make fingerprints similar whatever the distance.
    #[test]
    fn fingerprint_is_similar_short_circuits_on_equal_exact_hash() {
        // `fp` always uses an all-zero exact hash, so these two match
        // exactly even though every global-hash bit differs.
        assert!(fp(0, 0).is_similar(&fp(u64::MAX, 0), 1.0));
    }

    /// A multi-hash fingerprint is similar to itself at the strictest threshold.
    #[test]
    fn multi_hash_is_similar_to_itself() {
        let a = multi([7u8; 32], 0xAAAA, 0xBBBB, 0xCCCC);
        assert!(a.is_similar(&a, 1.0));
    }
}