use std::collections::HashMap;
use anyhow::Result;
use ring::digest::{Context, SHA256};
use tracing::{debug, info};
use crate::fft::FrequencyAnalyzer;
use crate::types::*;
/// Tuning parameters for fingerprint generation.
///
/// `fft_size` and `hop_size` are handed to the frequency analyzer; the other
/// fields steer peak picking and hash-pair generation.
#[derive(Debug, Clone)]
pub struct FingerprintConfig {
    /// FFT size passed to the frequency analyzer.
    pub fft_size: usize,
    /// Hop size passed to the frequency analyzer.
    pub hop_size: usize,
    /// Number of frequency bands each frame is split into when picking peaks.
    pub num_bands: usize,
    /// Number of targets after which an anchor stops collecting hash pairs.
    pub fan_out: usize,
    /// Largest anchor-to-target distance, in frames, allowed for a hash pair.
    pub target_zone_frames: usize,
    /// A band's strongest bin must exceed this magnitude to count as a peak.
    pub peak_threshold: f32,
}

impl Default for FingerprintConfig {
    /// Returns the stock configuration: 4096-point FFT with a hop of half
    /// that size, 6 bands, a fan-out of 5, a 50-frame target zone and a peak
    /// threshold of 0.1.
    fn default() -> Self {
        let fft_size = 4096;
        // The default hop is exactly half the FFT size (2048).
        let hop_size = fft_size / 2;

        FingerprintConfig {
            fft_size,
            hop_size,
            num_bands: 6,
            fan_out: 5,
            target_zone_frames: 50,
            peak_threshold: 0.1,
        }
    }
}
/// Generates, compares and verifies audio fingerprints.
pub struct Fingerprinter {
/// Parameters used for peak picking and hash-pair generation.
config: FingerprintConfig,
/// Spectrogram source, constructed from the config's `fft_size` and `hop_size`.
analyzer: FrequencyAnalyzer,
}
impl Fingerprinter {
    /// Creates a fingerprinter that uses `FingerprintConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(FingerprintConfig::default())
    }

    /// Creates a fingerprinter with an explicit configuration.
    ///
    /// The frequency analyzer is built from `config.fft_size` and
    /// `config.hop_size`.
    pub fn with_config(config: FingerprintConfig) -> Self {
        let analyzer = FrequencyAnalyzer::new(config.fft_size, config.hop_size);
        Self { config, analyzer }
    }

    /// Computes the fingerprint of `audio`.
    ///
    /// Pipeline: spectrogram -> per-band spectral peaks -> constellation
    /// points -> anchor/target hash pairs -> SHA-256 digest of those pairs.
    ///
    /// # Errors
    ///
    /// Returns an error if the spectrogram cannot be computed or contains no
    /// frames.
    pub fn fingerprint(&self, audio: &AudioData) -> Result<AudioFingerprint> {
        info!("Generating fingerprint for {} samples", audio.samples.len());

        let spectrogram = self.analyzer.compute_spectrogram(&audio.samples)?;
        debug!("Computed spectrogram with {} frames", spectrogram.len());

        let peaks = self.find_peaks(&spectrogram)?;
        debug!("Found {} spectral peaks", peaks.len());

        let points = self.create_constellation(&peaks);
        debug!("Created {} constellation points", points.len());

        let hash_pairs = self.generate_hash_pairs(&points);
        debug!("Generated {} hash pairs", hash_pairs.len());

        let hash = self.compute_hash(&hash_pairs);
        // The sample rate is not validated here; a zero rate yields a
        // non-finite duration.
        let duration_secs = audio.samples.len() as f64 / audio.sample_rate as f64;

        Ok(AudioFingerprint {
            hash,
            // Keep in sync with the version prefix hashed in `compute_hash`.
            version: 1,
            points,
            duration_secs,
        })
    }

    /// Picks at most one peak per frequency band in every spectrogram frame.
    ///
    /// Band edges are spaced quadratically over the bin range of the first
    /// frame, so low bands are narrower than high ones. A band yields a peak
    /// only when its strongest bin exceeds `peak_threshold`.
    ///
    /// # Errors
    ///
    /// Returns an error when `spectrogram` has no frames.
    fn find_peaks(&self, spectrogram: &[Vec<f32>]) -> Result<Vec<SpectralPeak>> {
        let spectrum_size = spectrogram
            .first()
            .map(|f| f.len())
            .ok_or_else(|| anyhow::anyhow!("Empty spectrogram"))?;

        let num_bands = self.config.num_bands;
        let band_edges: Vec<usize> = (0..=num_bands)
            .map(|i| {
                let t = i as f32 / num_bands as f32;
                (spectrum_size as f32 * t.powf(2.0)) as usize
            })
            .collect();

        let mut peaks = Vec::new();
        for (time_idx, frame) in spectrogram.iter().enumerate() {
            // Consecutive edges delimit one band: [edge[k], edge[k + 1]).
            for band in band_edges.windows(2) {
                let start = band[0];
                // Frames shorter than the first one simply lose their upper bins.
                let end = band[1].min(frame.len());
                if start >= end {
                    continue;
                }

                let strongest = frame[start..end]
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal));

                if let Some((local_max_idx, &max_val)) = strongest {
                    if max_val > self.config.peak_threshold {
                        peaks.push(SpectralPeak {
                            time_frame: time_idx as u32,
                            freq_bin: (start + local_max_idx) as u32,
                            magnitude: max_val,
                        });
                    }
                }
            }
        }
        Ok(peaks)
    }

    /// Converts spectral peaks into fingerprint points.
    ///
    /// The magnitude is scaled by 255 and capped at 255 before being stored
    /// as an 8-bit amplitude.
    fn create_constellation(&self, peaks: &[SpectralPeak]) -> Vec<FingerprintPoint> {
        peaks
            .iter()
            .map(|peak| FingerprintPoint {
                time_offset: peak.time_frame,
                freq_bin: peak.freq_bin,
                amplitude: (peak.magnitude * 255.0).min(255.0) as u8,
            })
            .collect()
    }

    /// Pairs each point (the anchor) with later points (targets).
    ///
    /// A target qualifies when it appears after the anchor in `points` and
    /// lies 1..=`target_zone_frames` frames later in time. Collection for an
    /// anchor stops once `fan_out` targets have been paired.
    fn generate_hash_pairs(&self, points: &[FingerprintPoint]) -> Vec<HashPair> {
        // Saturate rather than wrap when the configured zone exceeds u32,
        // and do the conversion once instead of per candidate target.
        let max_delta = self.config.target_zone_frames.min(u32::MAX as usize) as u32;
        let fan_out = self.config.fan_out;

        let mut pairs = Vec::new();
        for (i, anchor) in points.iter().enumerate() {
            let mut targets_found = 0;
            for target in &points[i + 1..] {
                let time_delta = target.time_offset.saturating_sub(anchor.time_offset);
                if time_delta > 0 && time_delta <= max_delta {
                    pairs.push(HashPair {
                        anchor_freq: anchor.freq_bin,
                        target_freq: target.freq_bin,
                        time_delta,
                        anchor_time: anchor.time_offset,
                    });
                    targets_found += 1;
                    if targets_found >= fan_out {
                        break;
                    }
                }
            }
        }
        pairs
    }

    /// Hashes the pair sequence into a lowercase hex SHA-256 digest.
    ///
    /// The digest covers a little-endian version prefix followed by each
    /// pair's anchor frequency, target frequency and time delta. Anchor times
    /// are not part of the digest.
    fn compute_hash(&self, pairs: &[HashPair]) -> String {
        let mut context = Context::new(&SHA256);
        // Version prefix; keep in sync with `version` set in `fingerprint`.
        context.update(&1u32.to_le_bytes());
        for pair in pairs {
            context.update(&pair.anchor_freq.to_le_bytes());
            context.update(&pair.target_freq.to_le_bytes());
            context.update(&pair.time_delta.to_le_bytes());
        }
        let digest = context.finish();
        hex::encode(digest.as_ref())
    }

    /// Compares two fingerprints by voting on the time offset between
    /// identical hash pairs.
    ///
    /// Every pair of `fp2` whose (anchor freq, target freq, time delta) also
    /// occurs in `fp1` votes for `fp2 anchor time - fp1 anchor time`. The
    /// offset with the most votes wins; ties are resolved deterministically
    /// in favour of the offset closest to zero (then the smaller offset).
    /// `similarity` is the winning vote count divided by the larger of the
    /// two pair counts, and the fingerprints match when it exceeds 0.1.
    pub fn match_fingerprints(&self, fp1: &AudioFingerprint, fp2: &AudioFingerprint) -> MatchResult {
        // Similarity above which two fingerprints are reported as matching.
        const MATCH_THRESHOLD: f32 = 0.1;

        let pairs1 = self.generate_hash_pairs(&fp1.points);
        let pairs2 = self.generate_hash_pairs(&fp2.points);

        let mut fp1_hashes: HashMap<(u32, u32, u32), Vec<u32>> = HashMap::new();
        for pair in &pairs1 {
            let key = (pair.anchor_freq, pair.target_freq, pair.time_delta);
            fp1_hashes.entry(key).or_default().push(pair.anchor_time);
        }

        let mut time_offsets: HashMap<i64, u32> = HashMap::new();
        for pair in &pairs2 {
            let key = (pair.anchor_freq, pair.target_freq, pair.time_delta);
            if let Some(fp1_times) = fp1_hashes.get(&key) {
                for &t1 in fp1_times {
                    let offset = i64::from(pair.anchor_time) - i64::from(t1);
                    *time_offsets.entry(offset).or_default() += 1;
                }
            }
        }

        // HashMap iteration order is unspecified, so the comparison must be a
        // total order; otherwise equal vote counts give run-to-run different
        // offsets.
        let (best_offset, aligned_matches) = time_offsets
            .iter()
            .map(|(&offset, &votes)| (offset, votes))
            .max_by(|a, b| {
                a.1.cmp(&b.1)
                    .then_with(|| b.0.abs().cmp(&a.0.abs()))
                    .then_with(|| b.0.cmp(&a.0))
            })
            .unwrap_or((0, 0));

        let total_pairs = pairs1.len().max(pairs2.len()) as f32;
        let similarity = if total_pairs > 0.0 {
            aligned_matches as f32 / total_pairs
        } else {
            0.0
        };

        MatchResult {
            is_match: similarity > MATCH_THRESHOLD,
            similarity,
            // Offsets span roughly +/- 2^32, so saturate instead of wrapping.
            time_offset_frames: best_offset.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32,
            matching_pairs: aligned_matches,
            total_pairs_checked: pairs2.len().min(u32::MAX as usize) as u32,
        }
    }

    /// Fingerprints `audio` and checks its hash against `expected_hash`.
    ///
    /// The comparison is an exact string equality.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Fingerprinter::fingerprint`].
    pub fn verify(&self, audio: &AudioData, expected_hash: &str) -> Result<VerificationResult> {
        let fingerprint = self.fingerprint(audio)?;
        let matches = fingerprint.hash == expected_hash;
        Ok(VerificationResult {
            verified: matches,
            computed_hash: fingerprint.hash,
            expected_hash: expected_hash.to_string(),
        })
    }
}
impl Default for Fingerprinter {
fn default() -> Self {
Self::new()
}
}
/// Strongest bin of one frequency band in one spectrogram frame.
#[derive(Debug, Clone)]
struct SpectralPeak {
/// Index of the spectrogram frame the peak was found in.
time_frame: u32,
/// Index of the peak bin within that frame.
freq_bin: u32,
/// Spectrogram value at the peak bin.
magnitude: f32,
}
/// An anchor point paired with a later target point.
///
/// `(anchor_freq, target_freq, time_delta)` is the lookup key used for
/// matching; `anchor_time` locates the pair within its recording.
#[derive(Debug, Clone)]
struct HashPair {
/// Frequency bin of the anchor point.
anchor_freq: u32,
/// Frequency bin of the target point.
target_freq: u32,
/// Frames between anchor and target (always greater than zero).
time_delta: u32,
/// Time offset of the anchor point, in frames.
anchor_time: u32,
}
/// Outcome of comparing two fingerprints with `Fingerprinter::match_fingerprints`.
#[derive(Debug, Clone)]
pub struct MatchResult {
/// Whether `similarity` exceeded the match threshold (0.1).
pub is_match: bool,
/// Votes at the winning offset divided by the larger of the two pair counts.
pub similarity: f32,
/// Winning offset in frames: second fingerprint's anchor time minus the first's.
pub time_offset_frames: i32,
/// Number of votes received by the winning offset.
pub matching_pairs: u32,
/// Number of hash pairs generated from the second fingerprint.
pub total_pairs_checked: u32,
}
/// Outcome of checking audio against an expected fingerprint hash.
#[derive(Debug, Clone)]
pub struct VerificationResult {
/// `true` when the computed hash equals the expected hash.
pub verified: bool,
/// Hash computed from the supplied audio.
pub computed_hash: String,
/// Hash the caller expected, copied from the request.
pub expected_hash: String,
}
/// In-memory inverted index from hash-pair keys to the content containing them.
pub struct FingerprintDatabase {
/// Maps `(anchor_freq, target_freq, time_delta)` to every
/// `(content_id, anchor_time)` at which that pair was indexed.
index: HashMap<(u32, u32, u32), Vec<(String, u32)>>,
}
impl FingerprintDatabase {
    /// Creates an empty database.
    pub fn new() -> Self {
        Self {
            index: HashMap::new(),
        }
    }

    /// Indexes every hash pair of `fingerprint` under `content_id`.
    ///
    /// Pairs are generated with a default-configured [`Fingerprinter`], the
    /// same way [`FingerprintDatabase::query`] generates them.
    pub fn add(&mut self, content_id: &str, fingerprint: &AudioFingerprint) {
        // NOTE(review): this builds a full default `Fingerprinter` (analyzer
        // included) only to reach `generate_hash_pairs` — consider sharing one.
        let fingerprinter = Fingerprinter::new();
        for pair in fingerprinter.generate_hash_pairs(&fingerprint.points) {
            let key = (pair.anchor_freq, pair.target_freq, pair.time_delta);
            self.index
                .entry(key)
                .or_default()
                .push((content_id.to_string(), pair.anchor_time));
        }
    }

    /// Looks up content whose indexed hash pairs line up with `fingerprint`.
    ///
    /// For each content id, every matching pair votes for a time offset
    /// (query anchor time minus indexed anchor time). The similarity is the
    /// largest vote count divided by the number of query pairs. Entries with
    /// `similarity >= threshold` are returned, best first; equal similarities
    /// are ordered by `content_id` so the output is deterministic.
    pub fn query(&self, fingerprint: &AudioFingerprint, threshold: f32) -> Vec<DatabaseMatch> {
        let fingerprinter = Fingerprinter::new();
        let pairs = fingerprinter.generate_hash_pairs(&fingerprint.points);
        if pairs.is_empty() {
            // Nothing to look up, and it keeps the division below well-defined.
            return Vec::new();
        }
        let total_pairs = pairs.len() as f32;

        // Keys borrow the ids stored in the index, so tallying a hit does not
        // allocate a new `String`.
        let mut content_matches: HashMap<&str, HashMap<i64, u32>> = HashMap::new();
        for pair in &pairs {
            let key = (pair.anchor_freq, pair.target_freq, pair.time_delta);
            if let Some(entries) = self.index.get(&key) {
                for (content_id, db_time) in entries {
                    let offset = i64::from(pair.anchor_time) - i64::from(*db_time);
                    *content_matches
                        .entry(content_id.as_str())
                        .or_default()
                        .entry(offset)
                        .or_default() += 1;
                }
            }
        }

        let mut results: Vec<DatabaseMatch> = content_matches
            .into_iter()
            .filter_map(|(content_id, offsets)| {
                let best_count = offsets.values().copied().max().unwrap_or(0);
                let similarity = best_count as f32 / total_pairs;
                if similarity >= threshold {
                    Some(DatabaseMatch {
                        content_id: content_id.to_string(),
                        similarity,
                        matching_pairs: best_count,
                    })
                } else {
                    None
                }
            })
            .collect();

        // Highest similarity first; the id tie-break removes the dependence
        // on HashMap iteration order.
        results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.content_id.cmp(&b.content_id))
        });
        results
    }
}
impl Default for FingerprintDatabase {
fn default() -> Self {
Self::new()
}
}
/// One hit returned by `FingerprintDatabase::query`.
#[derive(Debug, Clone)]
pub struct DatabaseMatch {
/// Identifier the matching content was indexed under.
pub content_id: String,
/// Best offset's vote count divided by the number of query hash pairs.
pub similarity: f32,
/// Vote count of the best time offset for this content.
pub matching_pairs: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pure sine tone of `freq` Hz lasting `duration_secs` seconds,
    /// sampled at 44.1 kHz.
    fn generate_test_audio(freq: f32, duration_secs: f32) -> AudioData {
        let sample_rate = 44100;
        let num_samples = (sample_rate as f32 * duration_secs) as usize;
        let samples: Vec<f32> = (0..num_samples)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();
        AudioData::new(samples, sample_rate)
    }

    /// A tone produces a non-empty, version-1 fingerprint.
    #[test]
    fn test_fingerprint_generation() {
        let audio = generate_test_audio(440.0, 5.0);
        let fingerprinter = Fingerprinter::new();
        let fp = fingerprinter.fingerprint(&audio).unwrap();
        assert!(!fp.hash.is_empty());
        assert!(!fp.points.is_empty());
        assert_eq!(fp.version, 1);
    }

    /// Fingerprinting the same audio twice yields the same hash.
    #[test]
    fn test_fingerprint_consistency() {
        let audio = generate_test_audio(440.0, 5.0);
        let fingerprinter = Fingerprinter::new();
        let fp1 = fingerprinter.fingerprint(&audio).unwrap();
        let fp2 = fingerprinter.fingerprint(&audio).unwrap();
        assert_eq!(fp1.hash, fp2.hash);
    }

    /// Identical tones match, and score higher than a different tone does.
    #[test]
    fn test_fingerprint_matching() {
        let audio1 = generate_test_audio(440.0, 5.0);
        let audio2 = generate_test_audio(440.0, 5.0);
        let audio3 = generate_test_audio(880.0, 5.0);
        let fingerprinter = Fingerprinter::new();
        let fp1 = fingerprinter.fingerprint(&audio1).unwrap();
        let fp2 = fingerprinter.fingerprint(&audio2).unwrap();
        let fp3 = fingerprinter.fingerprint(&audio3).unwrap();

        let match_same = fingerprinter.match_fingerprints(&fp1, &fp2);
        assert!(match_same.is_match);

        let match_diff = fingerprinter.match_fingerprints(&fp1, &fp3);
        assert!(match_same.similarity > match_diff.similarity);
    }

    /// Verifying audio against its own hash succeeds.
    #[test]
    fn test_verification() {
        let audio = generate_test_audio(440.0, 5.0);
        let fingerprinter = Fingerprinter::new();
        let fp = fingerprinter.fingerprint(&audio).unwrap();
        let result = fingerprinter.verify(&audio, &fp.hash).unwrap();
        assert!(result.verified);
        assert_eq!(result.computed_hash, result.expected_hash);
    }

    /// Querying with a 440 Hz tone ranks the indexed 440 Hz content first.
    #[test]
    fn test_database_query() {
        let audio1 = generate_test_audio(440.0, 5.0);
        let audio2 = generate_test_audio(880.0, 5.0);
        let query_audio = generate_test_audio(440.0, 5.0);
        let fingerprinter = Fingerprinter::new();
        let fp1 = fingerprinter.fingerprint(&audio1).unwrap();
        let fp2 = fingerprinter.fingerprint(&audio2).unwrap();
        let query_fp = fingerprinter.fingerprint(&query_audio).unwrap();

        let mut db = FingerprintDatabase::new();
        db.add("content_1", &fp1);
        db.add("content_2", &fp2);

        let results = db.query(&query_fp, 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].content_id, "content_1");
    }
}
/// Minimal lowercase hexadecimal encoder.
mod hex {
    /// Lowercase digits indexed by nibble value.
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `bytes` as lowercase hex, two characters per byte.
    ///
    /// The output buffer is sized up front and filled from a lookup table, so
    /// no temporary `String` is allocated per byte.
    pub fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}