aicheck 0.2.0

Detect AI-generated content via provenance signals (C2PA, XMP/IPTC, EXIF)
Documentation
use anyhow::Result;
use rustfft::{num_complex::Complex, FftPlanner};
use std::fs;
use std::path::Path;

use super::wav_metadata;
use super::{Confidence, Signal, SignalBuilder, SignalSource};

/// Number of samples per FFT frame; the averaged spectrum has `FFT_SIZE / 2` bins.
const FFT_SIZE: usize = 2048;
/// Upper bound on the number of frames averaged by `compute_avg_spectrum`.
const MAX_FRAMES: usize = 64;
/// A roll-off is only reported when the 99%-energy point lies below this fraction of Nyquist.
const BANDWIDTH_THRESHOLD: f64 = 0.7;
/// Maximum ratio of energy above the cutoff to energy below it for a roll-off to be reported.
const CUTOFF_ENERGY_RATIO: f64 = 0.02;

/// Decodes interleaved 16-bit little-endian PCM, keeping only the first channel.
///
/// Each frame is `2 * channels` bytes; the first sample of every complete frame is converted
/// to `f64` and scaled by `1 / 32768`, so `i16` values map into `[-1.0, 1.0)`. Multichannel
/// input is not downmixed. Trailing bytes that do not form a full frame are ignored.
///
/// Returns an empty vector when `channels` is 0 (a malformed header) instead of dividing by
/// zero.
fn decode_pcm_16le(data: &[u8], channels: u16) -> Vec<f64> {
    const BYTES_PER_SAMPLE: usize = 2;
    if channels == 0 {
        return Vec::new();
    }
    let block_align = BYTES_PER_SAMPLE * usize::from(channels);
    // `chunks_exact` yields only complete frames, each at least 2 bytes long, so indexing
    // the first two bytes cannot go out of bounds.
    data.chunks_exact(block_align)
        .map(|frame| f64::from(i16::from_le_bytes([frame[0], frame[1]])) / 32768.0)
        .collect()
}

fn compute_avg_spectrum(samples: &[f64], fft_size: usize) -> Vec<f64> {
    if samples.len() < fft_size {
        return vec![];
    }
    let mut planner = FftPlanner::<f64>::new();
    let fft = planner.plan_fft_forward(fft_size);
    let mid = samples.len() / 2;
    let half_window = (MAX_FRAMES * fft_size) / 2;
    let start = mid.saturating_sub(half_window);
    let available = &samples[start..];
    let num_bins = fft_size / 2;
    let mut avg_power = vec![0.0f64; num_bins];
    let mut frame_count = 0usize;
    let hop = fft_size / 2;
    let mut pos = 0;
    while pos + fft_size <= available.len() && frame_count < MAX_FRAMES {
        let mut buffer: Vec<Complex<f64>> = available[pos..pos + fft_size]
            .iter()
            .enumerate()
            .map(|(i, &s)| {
                let w = 0.5
                    * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / (fft_size - 1) as f64).cos());
                Complex::new(s * w, 0.0)
            })
            .collect();
        fft.process(&mut buffer);
        for (bin, power) in avg_power.iter_mut().enumerate() {
            *power += buffer[bin].norm_sqr();
        }
        frame_count += 1;
        pos += hop;
    }
    if frame_count == 0 {
        return vec![];
    }
    for power in avg_power.iter_mut() {
        *power /= frame_count as f64;
    }
    avg_power
}

/// Looks for a hard high-frequency roll-off in an averaged power spectrum.
///
/// `spectrum` is treated as bins evenly spaced from 0 Hz up to Nyquist (`sample_rate / 2`).
/// The cutoff is the upper edge of the first bin at which the running energy reaches 99% of
/// the total.
///
/// Returns `Some((cutoff_hz, cutoff_hz / nyquist))` when that edge lies below
/// `BANDWIDTH_THRESHOLD` of Nyquist and the energy above it is less than
/// `CUTOFF_ENERGY_RATIO` times the energy below it. Otherwise returns `None`, including for
/// an empty or all-zero spectrum.
fn find_bandwidth_cutoff(spectrum: &[f64], sample_rate: u32) -> Option<(f64, f64)> {
    if spectrum.is_empty() {
        return None;
    }
    let total_energy: f64 = spectrum.iter().sum();
    if total_energy == 0.0 {
        return None;
    }

    let bin_count = spectrum.len();
    let nyquist = sample_rate as f64 / 2.0;
    let hz_per_bin = nyquist / bin_count as f64;

    // One past the bin where cumulative energy first reaches 99% of the total. If that never
    // happens (e.g. NaN input), the whole band counts as used.
    let target = total_energy * 0.99;
    let cutoff_bin = spectrum
        .iter()
        .scan(0.0, |running, &power| {
            *running += power;
            Some(*running)
        })
        .position(|running| running >= target)
        .map_or(bin_count, |idx| idx + 1);

    let cutoff_freq = cutoff_bin as f64 * hz_per_bin;
    let bandwidth_ratio = cutoff_freq / nyquist;
    // Written as a positive comparison so a NaN ratio (sample_rate == 0) yields None.
    let narrow_band = bandwidth_ratio < BANDWIDTH_THRESHOLD;
    if !narrow_band {
        return None;
    }

    // NOTE(review): `cutoff_bin` already encloses >= 99% of the energy, so the energy above it
    // is at most ~1% of the total and this ratio is at most ~0.0101. With the current
    // CUTOFF_ENERGY_RATIO of 0.02 this check effectively always passes — confirm whether a
    // stricter test was intended.
    let (below, above) = spectrum.split_at(cutoff_bin);
    let below_energy: f64 = below.iter().sum();
    let above_energy: f64 = above.iter().sum();
    let leakage = if below_energy > 0.0 {
        above_energy / below_energy
    } else {
        0.0
    };

    if leakage < CUTOFF_ENERGY_RATIO {
        Some((cutoff_freq, bandwidth_ratio))
    } else {
        None
    }
}

/// Computes spectral flatness: the geometric mean of the power bins divided by their
/// arithmetic mean.
///
/// The result is close to 1.0 for an even (noise-like) spectrum and close to 0.0 for a
/// spectrum dominated by a few bins. Every bin takes part: values at or below a tiny floor
/// are clamped up to that floor so `ln` stays finite. Near-silent bins are kept rather than
/// dropped; dropping them would make a single-tone spectrum look perfectly flat.
///
/// Returns 0.0 for an empty spectrum or when no bin exceeds the floor (silence).
fn spectral_flatness(spectrum: &[f64]) -> f64 {
    const POWER_FLOOR: f64 = 1e-20;
    if !spectrum.iter().any(|&x| x > POWER_FLOOR) {
        return 0.0;
    }
    let n = spectrum.len() as f64;
    let (log_sum, sum) = spectrum
        .iter()
        .map(|&x| x.max(POWER_FLOOR))
        .fold((0.0f64, 0.0f64), |(log_acc, acc), x| (log_acc + x.ln(), acc + x));
    let geometric_mean = (log_sum / n).exp();
    // Every clamped bin is >= POWER_FLOOR, so the arithmetic mean is strictly positive here.
    let arithmetic_mean = sum / n;
    geometric_mean / arithmetic_mean
}

/// Runs spectral heuristics on a 16-bit WAV file and returns any low-confidence signals found.
///
/// Returns an empty list instead of an error in these cases:
/// - `wav_metadata::parse_wav_full` cannot parse the file;
/// - the audio is not 16-bit;
/// - the header reports zero channels or a zero sample rate;
/// - the PCM data range is empty, inverted, or extends past the end of the file;
/// - the clip is shorter than one FFT frame.
///
/// Two signals may be emitted:
/// - `signal_audio_cutoff` when `find_bandwidth_cutoff` reports a high-frequency roll-off;
/// - `signal_audio_flatness` for mono audio with Nyquist <= 12 kHz whose spectral flatness
///   is below 0.05.
///
/// # Errors
/// Returns an error if the file cannot be read.
pub fn detect(path: &Path) -> Result<Vec<Signal>> {
    let data = fs::read(path)?;
    let wav = match wav_metadata::parse_wav_full(&data) {
        Some(w) => w,
        None => return Ok(vec![]),
    };
    // Only 16-bit audio is decoded. A zero channel count or sample rate makes the header
    // unusable for spectral analysis, and zero channels would break frame decoding.
    if wav.fmt.bits_per_sample != 16 || wav.fmt.channels == 0 || wav.fmt.sample_rate == 0 {
        return Ok(vec![]);
    }
    // Checked slice: a data range that is empty, inverted, or runs past the end of the file
    // yields no signals instead of panicking.
    // NOTE(review): parse_wav_full may already clamp pcm_end; that is not visible here.
    let pcm_data = match data.get(wav.pcm_start..wav.pcm_end) {
        Some(pcm) if !pcm.is_empty() => pcm,
        _ => return Ok(vec![]),
    };
    let samples = decode_pcm_16le(pcm_data, wav.fmt.channels);
    if samples.len() < FFT_SIZE {
        return Ok(vec![]);
    }
    let spectrum = compute_avg_spectrum(&samples, FFT_SIZE);
    if spectrum.is_empty() {
        return Ok(vec![]);
    }

    let sample_rate = wav.fmt.sample_rate;
    let nyquist = sample_rate as f64 / 2.0;
    let mut signals = Vec::new();

    if let Some((cutoff_freq, bandwidth_ratio)) = find_bandwidth_cutoff(&spectrum, sample_rate) {
        signals.push(
            SignalBuilder::new(
                SignalSource::AudioSpectral,
                Confidence::Low,
                "signal_audio_cutoff",
            )
            .param("freq", format!("{:.0}", cutoff_freq))
            .param("pct", format!("{:.0}", bandwidth_ratio * 100.0))
            .param("nyquist", format!("{:.0}", nyquist))
            .detail("cutoff_frequency", format!("{:.0}Hz", cutoff_freq))
            .detail("nyquist", format!("{:.0}Hz", nyquist))
            .detail("bandwidth_used", format!("{:.1}%", bandwidth_ratio * 100.0))
            .build(),
        );
    }

    // The flatness signal is restricted to mono audio with Nyquist <= 12 kHz
    // (sample rate <= 24 kHz).
    let flatness = spectral_flatness(&spectrum);
    if nyquist <= 12000.0 && wav.fmt.channels == 1 && flatness < 0.05 {
        signals.push(
            SignalBuilder::new(
                SignalSource::AudioSpectral,
                Confidence::Low,
                "signal_audio_flatness",
            )
            .param("value", format!("{:.4}", flatness))
            .detail("spectral_flatness", format!("{:.4}", flatness))
            .build(),
        );
    }

    Ok(signals)
}

// Unit tests for the PCM decoder, the spectrum helpers, and the `detect` entry point.
#[cfg(test)]
mod tests {
    use super::*;

    // 200 zero bytes, mono: 100 two-byte frames, all decoding to 0.0.
    #[test]
    fn test_decode_pcm_16le() {
        let data = vec![0u8; 200];
        let samples = decode_pcm_16le(&data, 1);
        assert_eq!(samples.len(), 100);
        assert!(samples.iter().all(|&s| s == 0.0));
    }

    // 400 bytes, stereo: 100 four-byte frames, one decoded sample per frame.
    #[test]
    fn test_decode_pcm_16le_stereo() {
        let data = vec![0u8; 400];
        let samples = decode_pcm_16le(&data, 2);
        assert_eq!(samples.len(), 100);
    }

    // Single non-zero bin. Only the upper bound of the flatness range is checked here.
    #[test]
    fn test_spectral_flatness_pure_tone() {
        let mut spectrum = vec![0.0; 1024];
        spectrum[100] = 1.0;
        let flatness = spectral_flatness(&spectrum);
        assert!(flatness <= 1.0);
    }

    // A uniform spectrum has equal geometric and arithmetic means, so flatness is ~1.0.
    #[test]
    fn test_spectral_flatness_white_noise() {
        let spectrum = vec![1.0; 1024];
        let flatness = spectral_flatness(&spectrum);
        assert!((flatness - 1.0).abs() < 0.01);
    }

    // Energy spread across the whole band: no cutoff should be reported.
    #[test]
    fn test_find_bandwidth_cutoff_full() {
        let spectrum = vec![1.0; 1024];
        let result = find_bandwidth_cutoff(&spectrum, 48000);
        assert!(result.is_none());
    }

    // Energy only in the first 300 of 1024 bins at 48 kHz: cutoff below 12 kHz.
    #[test]
    fn test_find_bandwidth_cutoff_half() {
        let mut spectrum = vec![0.0; 1024];
        for i in 0..300 {
            spectrum[i] = 1.0;
        }
        let result = find_bandwidth_cutoff(&spectrum, 48000);
        assert!(result.is_some());
        let (freq, ratio) = result.unwrap();
        assert!(freq < 12000.0);
        assert!(ratio < BANDWIDTH_THRESHOLD);
    }

    // Digital silence: a non-empty spectrum whose bins are all ~0.
    #[test]
    fn test_compute_avg_spectrum_silence() {
        let samples = vec![0.0; FFT_SIZE * 4];
        let spectrum = compute_avg_spectrum(&samples, FFT_SIZE);
        assert!(!spectrum.is_empty());
        assert!(spectrum.iter().all(|&x| x < 1e-10));
    }

    // Fewer samples than one FFT frame: empty spectrum.
    #[test]
    fn test_compute_avg_spectrum_too_short() {
        let samples = vec![0.0; 100];
        let spectrum = compute_avg_spectrum(&samples, FFT_SIZE);
        assert!(spectrum.is_empty());
    }

    // A plain-text file (written via the `tempfile` crate) yields no signals and no error.
    #[test]
    fn test_detect_non_wav() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), b"not a wav").unwrap();
        let signals = detect(tmp.path()).unwrap();
        assert!(signals.is_empty());
    }
}