use std::f64::consts::PI;
use crate::spectral::SpectralAnalyzer;
use crate::{AnalysisConfig, Result};
/// A single frame reported by [`PhaseDiscontinuityDetector`] as a probable splice.
#[derive(Clone, Debug)]
pub struct SpliceProbability {
    /// Index of the reported STFT frame; frame `i` starts at sample `i * hop_size`.
    pub frame_idx: usize,
    /// Strength of the detection, clamped by the detector into `[0, 1]`.
    pub confidence: f64,
}
/// Flags STFT frames whose per-bin phase evolution is inconsistent with that of
/// their neighbours, e.g. where two recordings were spliced together.
///
/// Public configuration type, so it derives `Debug`/`Clone`/`PartialEq` to allow
/// logging, copying and comparing detector settings.
#[derive(Debug, Clone, PartialEq)]
pub struct PhaseDiscontinuityDetector {
    /// STFT frame length in samples (also the DFT length).
    pub window_size: usize,
    /// Advance in samples between the starts of consecutive frames.
    pub hop_size: usize,
    /// Minimum magnitude-weighted mean |second phase difference| (radians) for a
    /// frame to be reported.
    pub threshold: f64,
}
impl Default for PhaseDiscontinuityDetector {
    /// Returns a detector with 2048-sample frames, a 512-sample hop and a
    /// 0.1 rad detection threshold.
    fn default() -> Self {
        const DEFAULT_WINDOW: usize = 2048;
        const DEFAULT_HOP: usize = 512;
        const DEFAULT_THRESHOLD: f64 = 0.1;
        Self::new(DEFAULT_WINDOW, DEFAULT_HOP, DEFAULT_THRESHOLD)
    }
}
impl PhaseDiscontinuityDetector {
    /// Creates a detector with an explicit frame length, hop and threshold.
    ///
    /// * `window_size` – STFT frame length in samples (also the DFT length).
    /// * `hop_size` – advance in samples between consecutive frame starts.
    /// * `threshold` – minimum magnitude-weighted mean |second phase difference|
    ///   (radians) for a frame to be reported.
    #[must_use]
    pub fn new(window_size: usize, hop_size: usize, threshold: f64) -> Self {
        Self {
            window_size,
            hop_size,
            threshold,
        }
    }

    /// Scans `samples` for frames whose phase trajectory breaks continuity.
    ///
    /// For every interior frame `i`, each sufficiently loud bin's phase second
    /// difference across frames `i - 1`, `i`, `i + 1` is wrapped into `(-PI, PI]`;
    /// the magnitude-weighted mean of its absolute value is compared against
    /// `threshold`. Frames above it are returned with a confidence that maps
    /// `threshold..PI` linearly onto `0..1` (clamped).
    ///
    /// Returns an empty vector when the input is shorter than three windows, when
    /// fewer than three frames can be formed, or when the configuration is
    /// degenerate (`window_size < 2` or `hop_size == 0`).
    #[must_use]
    pub fn detect_phase_discontinuities(&self, samples: &[f32]) -> Vec<SpliceProbability> {
        let n = self.window_size;
        // A zero hop never advances through the signal, and windows shorter than
        // two samples have no bins above DC to examine.
        if n < 2 || self.hop_size == 0 || samples.len() < n.saturating_mul(3) {
            return Vec::new();
        }
        let phase_frames = self.compute_phase_frames(samples);
        if phase_frames.len() < 3 {
            return Vec::new();
        }
        let num_bins = n / 2 + 1;
        // Only bins at least 10% as loud as the loudest bin anywhere contribute;
        // the 1e-6 floor keeps near-silent input from admitting every bin.
        let global_peak: f64 = phase_frames
            .iter()
            .flat_map(|(_, mags)| mags.iter().copied())
            .fold(0.0_f64, f64::max);
        let mag_floor = (global_peak * 0.1).max(1e-6);

        let mut results = Vec::new();
        for (offset, triple) in phase_frames.windows(3).enumerate() {
            let (ph_prev, _) = &triple[0];
            let (ph_curr, mag_curr) = &triple[1];
            let (ph_next, _) = &triple[2];

            let mut weighted_d2 = 0.0_f64;
            let mut total_weight = 0.0_f64;
            // Bin 0 (DC) is skipped: for real input its phase is degenerate (0 or ±PI).
            for k in 1..num_bins {
                let mag = mag_curr[k];
                if mag < mag_floor {
                    continue;
                }
                let d1_prev = wrap_angle(ph_curr[k] - ph_prev[k]);
                let d1_next = wrap_angle(ph_next[k] - ph_curr[k]);
                let d2 = wrap_angle(d1_next - d1_prev).abs();
                weighted_d2 += d2 * mag;
                total_weight += mag;
            }

            if total_weight > 0.0 {
                let mean_d2 = weighted_d2 / total_weight;
                if mean_d2 > self.threshold {
                    let confidence =
                        ((mean_d2 - self.threshold) / (PI - self.threshold)).clamp(0.0, 1.0);
                    results.push(SpliceProbability {
                        // `windows(3)` starts at the frame before the one examined.
                        frame_idx: offset + 1,
                        confidence,
                    });
                }
            }
        }
        results
    }

    /// Splits `samples` into `window_size`-long frames starting every `hop_size`
    /// samples (only frames that fit entirely) and returns `(phases, magnitudes)`
    /// for each.
    ///
    /// Returns no frames when `window_size` or `hop_size` is zero; the former
    /// manual loop spun forever on a zero hop.
    fn compute_phase_frames(&self, samples: &[f32]) -> Vec<(Vec<f64>, Vec<f64>)> {
        let n = self.window_size;
        let hop = self.hop_size;
        // `windows(0)` and `step_by(0)` both panic.
        if n == 0 || hop == 0 {
            return Vec::new();
        }
        let num_bins = n / 2 + 1;
        samples
            .windows(n)
            .step_by(hop)
            .map(|frame| compute_stft_phases_and_magnitudes(frame, n, num_bins))
            .collect()
    }
}
/// Computes per-bin phase and magnitude of a Hann-windowed DFT of `frame`.
///
/// * `frame` – samples to analyse; callers in this file pass exactly `n` samples.
/// * `n` – DFT length, also the length of the symmetric Hann window.
/// * `num_bins` – number of leading bins (from DC upward) to evaluate; callers
///   pass `n / 2 + 1`, the non-redundant half of a real-input spectrum.
///
/// Returns `(phases, magnitudes)`, each of length `num_bins`. Phases are radians
/// from `f64::atan2`.
///
/// This is a direct O(n · num_bins) DFT. The window and the table of complex
/// exponentials are built once per call, using O(n) trig evaluations. The
/// previous version recomputed them for every bin, costing O(n · num_bins).
/// `n == 0` yields all-zero output, and `n == 1` uses the conventional one-point
/// window `[1.0]`. Previously the former underflowed `n - 1` and the latter
/// divided by zero.
fn compute_stft_phases_and_magnitudes(
    frame: &[f32],
    n: usize,
    num_bins: usize,
) -> (Vec<f64>, Vec<f64>) {
    let mut phases = vec![0.0_f64; num_bins];
    let mut magnitudes = vec![0.0_f64; num_bins];
    if n == 0 {
        return (phases, magnitudes);
    }

    // Windowed samples, computed once instead of once per bin.
    let windowed: Vec<f64> = frame
        .iter()
        .enumerate()
        .map(|(j, &s)| {
            let w = if n == 1 {
                1.0
            } else {
                0.5 * (1.0 - (2.0 * PI * j as f64 / (n - 1) as f64).cos())
            };
            f64::from(s) * w
        })
        .collect();

    // (cos, sin) of -2π·m/n for m in 0..n; bin k at sample j uses m = (k·j) mod n.
    let twiddles: Vec<(f64, f64)> = (0..n)
        .map(|m| {
            let angle = -2.0 * PI * m as f64 / n as f64;
            (angle.cos(), angle.sin())
        })
        .collect();

    for (k, (phase, magnitude)) in phases.iter_mut().zip(magnitudes.iter_mut()).enumerate() {
        // Advance the table index by k (mod n) per sample instead of computing
        // k·j, which avoids both the multiplication and any overflow.
        let step = k % n;
        let mut idx = 0_usize;
        let mut re = 0.0_f64;
        let mut im = 0.0_f64;
        for &sv in &windowed {
            let (c, s) = twiddles[idx];
            re += sv * c;
            im += sv * s;
            idx += step;
            if idx >= n {
                idx -= n;
            }
        }
        *magnitude = re.hypot(im);
        *phase = im.atan2(re);
    }
    (phases, magnitudes)
}
/// Wraps `angle` (radians) into the half-open interval `(-PI, PI]`.
///
/// Uses a closed form rather than repeated subtraction, so the cost is constant
/// for any magnitude. Non-finite input (±inf, NaN) yields NaN; the former loop
/// never terminated on ±inf.
#[inline]
fn wrap_angle(angle: f64) -> f64 {
    let tau = 2.0 * PI;
    // `(PI - a).rem_euclid(tau)` lies in [0, tau], so `PI - r` lies in [-PI, PI].
    let wrapped = PI - (PI - angle).rem_euclid(tau);
    // rem_euclid can round up to exactly `tau` for tiny negative remainders,
    // giving -PI; fold that boundary back to PI to keep the interval half-open.
    if wrapped <= -PI {
        wrapped + tau
    } else {
        wrapped
    }
}
/// Locates probable edit points by combining frame-energy jumps, spectral
/// centroid jumps and phase-continuity breaks.
pub struct EditDetector {
    /// Feature extractor supplying the per-frame spectral centroid.
    spectral_analyzer: SpectralAnalyzer,
    /// Frame advance in samples, used to convert frame indices into times.
    hop_size: usize,
    /// Detector for phase-continuity breaks between STFT frames.
    phase_detector: PhaseDiscontinuityDetector,
}
impl EditDetector {
    /// Builds an edit detector from an analysis configuration.
    ///
    /// `config.hop_size` sets the frame advance for every pass. The phase pass
    /// analyses `config.fft_size`-sample frames with a 0.1 rad threshold. The
    /// configuration is then moved into the [`SpectralAnalyzer`].
    #[must_use]
    pub fn new(config: AnalysisConfig) -> Self {
        let hop_size = config.hop_size;
        let phase_detector = PhaseDiscontinuityDetector::new(config.fft_size, hop_size, 0.1);
        Self {
            spectral_analyzer: SpectralAnalyzer::new(config),
            hop_size,
            phase_detector,
        }
    }

    /// Detects probable edit points in `samples` and reports them in seconds.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the spectral analyzer while analysing a
    /// frame.
    pub fn detect(&self, samples: &[f32], sample_rate: f32) -> Result<EditResult> {
        let edit_times = self.detect_discontinuities(samples, sample_rate)?;
        Ok(EditResult {
            num_edits: edit_times.len(),
            edit_times,
        })
    }

    /// Returns the sorted, de-duplicated start times (seconds) of frames flagged
    /// by any of three passes:
    ///
    /// 1. **Energy:** a sudden jump in frame energy between consecutive frames.
    /// 2. **Centroid:** a spectral-centroid change of more than 500 between
    ///    consecutive frames.
    /// 3. **Phase:** a phase-discontinuity candidate with confidence ≥ 0.3.
    ///
    /// Inputs shorter than two 2048-sample windows, or a zero hop size, produce
    /// no edits.
    fn detect_discontinuities(&self, samples: &[f32], sample_rate: f32) -> Result<Vec<f32>> {
        // Frame length for the energy and centroid passes. NOTE(review): this is
        // independent of `config.fft_size` (used by the phase pass) — confirm the
        // spectral analyzer expects 2048-sample frames.
        const WINDOW_SIZE: usize = 2048;
        // A frame pair is flagged when the louder frame's energy exceeds the
        // quieter one's by more than this multiple of the quieter energy (> 4x).
        const ENERGY_JUMP_RATIO: f32 = 3.0;
        // Pairs whose louder frame is below this fraction of the loudest frame
        // are ignored, so ratio swings in near-silence do not fire.
        const ENERGY_FLOOR_RATIO: f32 = 1e-6;
        // Minimum phase-detector confidence for a candidate to count.
        const MIN_PHASE_CONFIDENCE: f64 = 0.3;

        let hop = self.hop_size;
        // A zero hop cannot advance through the signal (it used to divide by zero).
        if hop == 0 || samples.len() < WINDOW_SIZE * 2 {
            return Ok(Vec::new());
        }

        // Frame k covers samples[k * hop .. k * hop + WINDOW_SIZE]. `windows` +
        // `step_by` yields every frame that fits, including the last one, which
        // the former `(len - WINDOW_SIZE) / hop` count dropped.
        let frame_count = (samples.len() - WINDOW_SIZE) / hop + 1;
        let mut energies: Vec<f32> = Vec::with_capacity(frame_count);
        let mut spectral_centroids = Vec::with_capacity(frame_count);
        for frame in samples.windows(WINDOW_SIZE).step_by(hop) {
            let energy: f32 = frame.iter().map(|&x| x * x).sum();
            energies.push(energy);
            let features = self.spectral_analyzer.analyze_frame(frame, sample_rate)?;
            spectral_centroids.push(features.centroid);
        }

        // Ordered set of flagged frame indices: de-duplicates across the passes
        // and yields them sorted. This replaces the O(n²) `contains` scan on floats.
        let mut flagged = std::collections::BTreeSet::new();

        // |a - b| / mean(a, b) never exceeds 2 for non-negative energies, so the
        // former `> 3.0` test could never fire. Measure the jump against the
        // quieter frame instead.
        let peak_energy = energies.iter().copied().fold(0.0_f32, f32::max);
        let energy_floor = peak_energy * ENERGY_FLOOR_RATIO;
        for (i, pair) in energies.windows(2).enumerate() {
            let quieter = pair[0].min(pair[1]);
            let louder = pair[0].max(pair[1]);
            if louder > energy_floor && louder - quieter > ENERGY_JUMP_RATIO * quieter {
                flagged.insert(i + 1);
            }
        }

        for (i, pair) in spectral_centroids.windows(2).enumerate() {
            if (pair[1] - pair[0]).abs() > 500.0 {
                flagged.insert(i + 1);
            }
        }

        // The phase detector frames with the same hop, so its indices share
        // the time base used above.
        for candidate in self.phase_detector.detect_phase_discontinuities(samples) {
            if candidate.confidence >= MIN_PHASE_CONFIDENCE {
                flagged.insert(candidate.frame_idx);
            }
        }

        Ok(flagged
            .into_iter()
            .map(|idx| (idx * hop) as f32 / sample_rate)
            .collect())
    }
}
/// Outcome of [`EditDetector::detect`].
#[derive(Clone, Debug)]
pub struct EditResult {
    /// Number of detected edits; always equal to `edit_times.len()`.
    pub num_edits: usize,
    /// Detected edit times in seconds, in ascending order.
    pub edit_times: Vec<f32>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `len` samples of `amp * sin(2π · freq · i / sample_rate + phase)`.
    fn tone(freq: f32, sample_rate: f32, len: usize, phase: f32, amp: f32) -> Vec<f32> {
        (0..len)
            .map(|i| {
                (2.0 * std::f32::consts::PI * freq * i as f32 / sample_rate + phase).sin() * amp
            })
            .collect()
    }

    #[test]
    fn test_edit_detector() {
        let sample_rate = 44100.0_f32;
        let detector = EditDetector::new(AnalysisConfig::default());
        // Half a second at 440 Hz followed by half a second at 880 Hz.
        let mut samples = tone(440.0, sample_rate, 22050, 0.0, 0.5);
        samples.extend(tone(880.0, sample_rate, 22050, 0.0, 0.5));
        assert!(detector.detect(&samples, sample_rate).is_ok());
    }

    #[test]
    fn test_phase_continuity_clean_signal() {
        let detector = PhaseDiscontinuityDetector::new(512, 128, 0.1);
        let samples = tone(440.0, 8000.0, 8000, 0.0, 0.8);
        let detections = detector.detect_phase_discontinuities(&samples);
        assert!(
            detections.is_empty(),
            "Clean signal should produce 0 phase detections, got {}",
            detections.len()
        );
    }

    #[test]
    fn test_phase_discontinuity_detected_at_splice() {
        let hop_size = 128;
        let half = 4096_usize;
        let sample_rate = 44100.0_f32;
        let detector = PhaseDiscontinuityDetector::new(512, hop_size, 0.1);

        // Two 440 Hz segments joined with a half-cycle (PI) phase jump.
        let mut samples = tone(440.0, sample_rate, half, 0.0, 0.8);
        samples.extend_from_slice(&tone(440.0, sample_rate, half, std::f32::consts::PI, 0.8));

        let detections = detector.detect_phase_discontinuities(&samples);
        assert!(
            !detections.is_empty(),
            "Expected at least one phase discontinuity at the splice, got none"
        );

        let splice_frame = half / hop_size;
        let frames: Vec<usize> = detections.iter().map(|d| d.frame_idx).collect();
        assert!(
            frames.iter().any(|&f| f.abs_diff(splice_frame) <= 4),
            "Expected a detection near frame {} (splice), detections at frames: {:?}",
            splice_frame,
            frames
        );
    }
}