use crate::{generate_window, AnalysisConfig, AnalysisError, Result, WindowType};
use oxifft::Complex;
/// Note names for the twelve pitch classes, indexed by pitch-class number
/// (index 0 is "C", index 9 is "A"). Accidentals are spelled as sharps.
const PITCH_CLASS_NAMES: [&str; 12] = [
"C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B",
];
/// Per-pitch-class weights for a major key whose tonic sits at index 0.
///
/// `detect_key` rotates this profile to each of the twelve roots and
/// correlates it with the observed chroma.
// NOTE(review): the values look like the Krumhansl-Kessler major key profile — confirm the source.
const MAJOR_PROFILE: [f32; 12] = [
6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88,
];
/// Per-pitch-class weights for a minor key whose tonic sits at index 0.
///
/// `detect_key` rotates this profile to each of the twelve roots and
/// correlates it with the observed chroma.
// NOTE(review): the values look like the Krumhansl-Kessler minor key profile — confirm the source.
const MINOR_PROFILE: [f32; 12] = [
6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17,
];
/// A chord quality described by the pitch classes it contains.
struct ChordTemplate {
/// Suffix appended to the root note name when a match is labelled (e.g. "maj").
name: &'static str,
/// Semitone offsets from the root that belong to the chord; `match_chord`
/// reduces `root + interval` modulo 12 when building its binary template.
intervals: &'static [usize],
}
/// Chord qualities tried by `match_chord` for every one of the twelve roots.
///
/// Order matters for ties: `match_chord` only replaces its current best on a
/// strictly higher score, so an earlier entry wins over a later one that
/// scores the same.
const CHORD_TEMPLATES: &[ChordTemplate] = &[
// Triads.
ChordTemplate {
name: "maj",
intervals: &[0, 4, 7],
},
ChordTemplate {
name: "min",
intervals: &[0, 3, 7],
},
ChordTemplate {
name: "dim",
intervals: &[0, 3, 6],
},
ChordTemplate {
name: "aug",
intervals: &[0, 4, 8],
},
// Four-note seventh chords.
ChordTemplate {
name: "7",
intervals: &[0, 4, 7, 10],
},
ChordTemplate {
name: "maj7",
intervals: &[0, 4, 7, 11],
},
ChordTemplate {
name: "min7",
intervals: &[0, 3, 7, 10],
},
// Suspended triads (the third is replaced by a fourth or a second).
ChordTemplate {
name: "sus4",
intervals: &[0, 5, 7],
},
ChordTemplate {
name: "sus2",
intervals: &[0, 2, 7],
},
];
/// Key, chord and harmonic-complexity analyzer configured by an `AnalysisConfig`.
pub struct HarmonyAnalyzer {
/// Analysis settings; `fft_size` and `hop_size` determine how audio is framed.
config: AnalysisConfig,
}
impl HarmonyAnalyzer {
#[must_use]
pub fn new(config: AnalysisConfig) -> Self {
Self { config }
}
pub fn analyze(&self, samples: &[f32], sample_rate: f32) -> Result<HarmonyResult> {
if samples.len() < self.config.fft_size {
return Err(AnalysisError::InsufficientSamples {
needed: self.config.fft_size,
got: samples.len(),
});
}
let chroma_frames = self.extract_chroma_frames(samples, sample_rate)?;
if chroma_frames.is_empty() {
return Ok(HarmonyResult {
key: "Unknown".to_string(),
key_confidence: 0.0,
key_pitch_class: 0,
key_is_major: true,
chords: vec![],
chord_confidences: vec![],
harmonic_complexity: 0.0,
mean_chroma: [0.0; 12],
});
}
let mean_chroma = compute_mean_chroma(&chroma_frames);
let (key_pc, is_major, key_confidence) = detect_key(&mean_chroma);
let key_name = format!(
"{} {}",
PITCH_CLASS_NAMES[key_pc],
if is_major { "major" } else { "minor" }
);
let (chords, chord_confidences) = self.detect_chords(&chroma_frames);
let harmonic_complexity = compute_harmonic_complexity(&chroma_frames, &chords);
Ok(HarmonyResult {
key: key_name,
key_confidence,
key_pitch_class: key_pc,
key_is_major: is_major,
chords,
chord_confidences,
harmonic_complexity,
mean_chroma,
})
}
fn extract_chroma_frames(&self, samples: &[f32], sample_rate: f32) -> Result<Vec<[f32; 12]>> {
let fft_size = self.config.fft_size;
let hop_size = self.config.hop_size;
let window = generate_window(WindowType::Hann, fft_size);
let num_bins = fft_size / 2 + 1;
let num_frames = if samples.len() >= fft_size {
(samples.len() - fft_size) / hop_size + 1
} else {
0
};
let mut chroma_frames = Vec::with_capacity(num_frames);
for frame_idx in 0..num_frames {
let start = frame_idx * hop_size;
let end = start + fft_size;
if end > samples.len() {
break;
}
let complex_input: Vec<Complex<f64>> = samples[start..end]
.iter()
.zip(&window)
.map(|(&s, &w)| Complex::new(f64::from(s * w), 0.0))
.collect();
let fft_output = oxifft::fft(&complex_input);
let magnitude: Vec<f32> = fft_output[..num_bins]
.iter()
.map(|c| c.norm() as f32)
.collect();
let chroma = fold_to_chroma(&magnitude, sample_rate, fft_size);
chroma_frames.push(chroma);
}
Ok(chroma_frames)
}
#[allow(clippy::unused_self)]
fn detect_chords(&self, chroma_frames: &[[f32; 12]]) -> (Vec<String>, Vec<f32>) {
let mut chords = Vec::with_capacity(chroma_frames.len());
let mut confidences = Vec::with_capacity(chroma_frames.len());
for chroma in chroma_frames {
let (chord, confidence) = match_chord(chroma);
chords.push(chord);
confidences.push(confidence);
}
(chords, confidences)
}
}
/// Folds an FFT magnitude spectrum into a 12-bin chroma vector that sums to 1.
///
/// `magnitude[bin]` is taken to be the magnitude at frequency
/// `bin * sample_rate / fft_size`. Every bin inside the analysis band is
/// assigned to the pitch class of its nearest equal-tempered semitone
/// (A4 = 440 Hz, index 0 = C) and its magnitude is accumulated there.
///
/// Bins with a non-positive or non-finite magnitude, and bins whose frequency
/// is outside the band or not a number (for example with a NaN sample rate or
/// a zero `fft_size`), are ignored so that a single bad value cannot poison
/// the whole vector. If nothing is accumulated the result is all zeros.
fn fold_to_chroma(magnitude: &[f32], sample_rate: f32, fft_size: usize) -> [f32; 12] {
    // Analysis band: 27.5 Hz (A4 / 16) up to 4186 Hz.
    const MIN_FREQ: f32 = 27.5;
    const MAX_FREQ: f32 = 4186.0;
    // Tuning reference used to convert frequency to semitones.
    const A4_HZ: f32 = 440.0;

    let mut chroma = [0.0_f32; 12];
    for (bin, &mag) in magnitude.iter().enumerate() {
        // NaN and infinity would turn the sum (and thus every bin) into NaN.
        if !mag.is_finite() || mag <= 0.0 {
            continue;
        }

        let freq = bin as f32 * sample_rate / fft_size as f32;
        // `contains` is false for NaN, so degenerate frequencies are rejected
        // here instead of being silently rounded to "0 semitones from A4".
        if !(MIN_FREQ..=MAX_FREQ).contains(&freq) {
            continue;
        }

        let semitones_from_a4 = (12.0 * (freq / A4_HZ).log2()).round() as i32;
        // A is pitch class 9 when C is 0; `rem_euclid` keeps the index
        // non-negative for notes below A4.
        let pc = (semitones_from_a4 + 9).rem_euclid(12) as usize;
        chroma[pc] += mag;
    }

    let sum: f32 = chroma.iter().sum();
    if sum > f32::EPSILON {
        for v in &mut chroma {
            *v /= sum;
        }
    }
    chroma
}
/// Averages a sequence of chroma frames element-wise.
///
/// Returns an all-zero vector when `frames` is empty.
fn compute_mean_chroma(frames: &[[f32; 12]]) -> [f32; 12] {
    if frames.is_empty() {
        return [0.0; 12];
    }

    // Accumulate frame by frame, in input order.
    let totals = frames.iter().fold([0.0_f32; 12], |mut acc, frame| {
        acc.iter_mut()
            .zip(frame)
            .for_each(|(slot, &value)| *slot += value);
        acc
    });

    let count = frames.len() as f32;
    totals.map(|total| total / count)
}
/// Estimates the musical key of a chroma vector by profile correlation.
///
/// The chroma is correlated with the major and minor key profiles rotated to
/// each of the twelve roots; the candidate with the highest correlation wins.
/// Candidates are tried root by root, major before minor, and a later
/// candidate must score strictly higher to replace the current best.
///
/// Returns `(pitch_class, is_major, confidence)`, where the confidence is the
/// winning correlation mapped from `[-1, 1]` to `[0, 1]`. A chroma with no
/// variation (for example all-equal or all-zero) yields `(0, true, 0.0)`.
pub fn detect_key(chroma: &[f32; 12]) -> (usize, bool, f32) {
    /// Population mean and standard deviation of twelve values.
    fn mean_and_std(values: &[f32; 12]) -> (f32, f32) {
        let mean = values.iter().sum::<f32>() / 12.0;
        let variance = values.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / 12.0;
        (mean, variance.sqrt())
    }

    let (chroma_mean, chroma_std) = mean_and_std(chroma);
    if chroma_std < f32::EPSILON {
        // A flat chroma carries no key information.
        return (0, true, 0.0);
    }

    // Rotation does not change a profile's mean or spread, so compute the
    // statistics once per mode rather than once per root.
    let (major_mean, major_std) = mean_and_std(&MAJOR_PROFILE);
    let (minor_mean, minor_std) = mean_and_std(&MINOR_PROFILE);
    let modes = [
        (&MAJOR_PROFILE, true, major_mean, major_std),
        (&MINOR_PROFILE, false, minor_mean, minor_std),
    ];

    // (root, is_major, correlation) of the best candidate so far.
    let mut best = (0_usize, true, f32::NEG_INFINITY);
    for root in 0..12 {
        for &(profile, is_major, profile_mean, profile_std) in &modes {
            if profile_std > f32::EPSILON {
                let candidate = rotate_profile(profile, root);
                let corr = pearson_correlation(
                    chroma,
                    &candidate,
                    chroma_mean,
                    chroma_std,
                    profile_mean,
                    profile_std,
                );
                if corr > best.2 {
                    best = (root, is_major, corr);
                }
            }
        }
    }

    let (key, is_major, corr) = best;
    let confidence = ((corr + 1.0) / 2.0).clamp(0.0, 1.0);
    (key, is_major, confidence)
}
/// Rotates a 12-element profile so that the value at index 0 moves to index
/// `shift`, i.e. `out[i] == profile[(i - shift) mod 12]`.
///
/// `shift` is reduced modulo 12, so any value is accepted; shifts of 0 and 12
/// both return the profile unchanged.
fn rotate_profile(profile: &[f32; 12], shift: usize) -> [f32; 12] {
    let mut rotated = *profile;
    // `rotate_right` panics for a distance above the length, hence the modulo.
    rotated.rotate_right(shift % 12);
    rotated
}
/// Pearson correlation between a chroma vector and a key profile.
///
/// The means and population standard deviations are supplied by the caller so
/// they can be reused across many profile rotations. Returns `0.0` when
/// either standard deviation is below `f32::EPSILON`; otherwise the result is
/// clamped to `[-1, 1]`.
fn pearson_correlation(
    chroma: &[f32; 12],
    profile: &[f32; 12],
    chroma_mean: f32,
    chroma_std: f32,
    profile_mean: f32,
    profile_std: f32,
) -> f32 {
    let degenerate = chroma_std < f32::EPSILON || profile_std < f32::EPSILON;
    if degenerate {
        return 0.0;
    }

    let centered_products = chroma
        .iter()
        .zip(profile)
        .map(|(c, p)| (c - chroma_mean) * (p - profile_mean));
    let covariance = centered_products.sum::<f32>() / 12.0;

    let scale = chroma_std * profile_std;
    (covariance / scale).clamp(-1.0, 1.0)
}
fn match_chord(chroma: &[f32; 12]) -> (String, f32) {
let mut best_name = String::from("N"); let mut best_score = 0.0_f32;
for root in 0..12 {
for template in CHORD_TEMPLATES {
let mut tmpl = [0.0_f32; 12];
for &interval in template.intervals {
tmpl[(root + interval) % 12] = 1.0;
}
let dot: f32 = chroma.iter().zip(tmpl.iter()).map(|(&a, &b)| a * b).sum();
let norm_a: f32 = chroma.iter().map(|&v| v * v).sum::<f32>().sqrt();
let norm_b: f32 = tmpl.iter().map(|&v| v * v).sum::<f32>().sqrt();
let score = if norm_a > f32::EPSILON && norm_b > f32::EPSILON {
dot / (norm_a * norm_b)
} else {
0.0
};
if score > best_score {
best_score = score;
best_name = format!("{}{}", PITCH_CLASS_NAMES[root], template.name);
}
}
}
(best_name, best_score.clamp(0.0, 1.0))
}
/// Scores how harmonically varied a passage is, in `[0, 1]`.
///
/// The score is a weighted blend of three terms:
/// * 0.3 × chord diversity: `(distinct chords - 1) / number of chords`,
/// * 0.4 × entropy of the mean chroma, normalized by `log2(12)`,
/// * 0.3 × change rate: the fraction of adjacent chord pairs that differ.
///
/// Returns `0.0` when either input is empty.
fn compute_harmonic_complexity(chroma_frames: &[[f32; 12]], chords: &[String]) -> f32 {
    if chroma_frames.is_empty() || chords.is_empty() {
        return 0.0;
    }
    let total = chords.len();

    // Chord diversity: count distinct labels via sort + dedup.
    let mut labels: Vec<&String> = chords.iter().collect();
    labels.sort();
    labels.dedup();
    let chord_diversity = (labels.len() as f32 - 1.0).max(0.0) / (total as f32).max(1.0);

    // Entropy of the average pitch-class distribution.
    let mean_chroma = compute_mean_chroma(chroma_frames);
    let mass: f32 = mean_chroma.iter().sum();
    let entropy = if mass > f32::EPSILON {
        let bits = mean_chroma
            .iter()
            .map(|&v| v / mass)
            // Skip (near-)empty bins: they contribute nothing and log2(0) is -inf.
            .filter(|&p| p > f32::EPSILON)
            .fold(0.0_f32, |h, p| h - p * p.log2());
        bits / 12.0_f32.log2()
    } else {
        0.0
    };

    // Change rate: how often consecutive frames carry different chords.
    let transitions = chords.windows(2).filter(|pair| pair[0] != pair[1]).count();
    let change_rate = transitions as f32 / total.saturating_sub(1).max(1) as f32;

    let complexity = 0.3 * chord_diversity + 0.4 * entropy + 0.3 * change_rate;
    complexity.clamp(0.0, 1.0)
}
pub fn detect_key_from_audio(samples: &[f32], sample_rate: f32) -> Result<(String, f32)> {
let config = AnalysisConfig::default();
let analyzer = HarmonyAnalyzer::new(config);
let result = analyzer.analyze(samples, sample_rate)?;
Ok((result.key, result.key_confidence))
}
/// Output of `HarmonyAnalyzer::analyze`.
#[derive(Debug, Clone)]
pub struct HarmonyResult {
/// Key name formatted as "<note> major" or "<note> minor", or "Unknown"
/// when no chroma frame was available.
pub key: String,
/// Confidence of the key estimate in `[0, 1]`: the winning profile
/// correlation mapped from `[-1, 1]`.
pub key_confidence: f32,
/// Pitch class of the detected tonic, 0 = C through 11 = B.
pub key_pitch_class: usize,
/// `true` when the detected key is major, `false` when it is minor.
pub key_is_major: bool,
/// One chord label per analysis frame, in frame order.
pub chords: Vec<String>,
/// Match score in `[0, 1]` for each entry of `chords` (same length).
pub chord_confidences: Vec<f32>,
/// Blend of chord diversity, chroma entropy and chord change rate, in `[0, 1]`.
pub harmonic_complexity: f32,
/// Element-wise average of the per-frame chroma vectors (index 0 = C).
pub mean_chroma: [f32; 12],
}
#[cfg(test)]
mod tests {
use super::*;
use std::f32::consts::PI;
fn sine_wave(freq: f32, sample_rate: f32, duration: f32) -> Vec<f32> {
let n = (sample_rate * duration) as usize;
(0..n)
.map(|i| {
let t = i as f32 / sample_rate;
(2.0 * PI * freq * t).sin()
})
.collect()
}
fn chord_signal(freqs: &[f32], sample_rate: f32, duration: f32) -> Vec<f32> {
let n = (sample_rate * duration) as usize;
let amplitude = 1.0 / freqs.len() as f32;
(0..n)
.map(|i| {
let t = i as f32 / sample_rate;
freqs
.iter()
.map(|&f| amplitude * (2.0 * PI * f * t).sin())
.sum::<f32>()
})
.collect()
}
#[test]
fn test_harmony_analyzer_basic() {
let config = AnalysisConfig::default();
let analyzer = HarmonyAnalyzer::new(config);
let samples = sine_wave(440.0, 44100.0, 0.5);
let result = analyzer.analyze(&samples, 44100.0);
assert!(result.is_ok());
let result = result.expect("should succeed");
assert!(!result.key.is_empty());
assert!(result.key_confidence >= 0.0 && result.key_confidence <= 1.0);
}
#[test]
fn test_harmony_analyzer_insufficient_samples() {
let config = AnalysisConfig::default();
let analyzer = HarmonyAnalyzer::new(config);
let samples = vec![0.1; 100]; let result = analyzer.analyze(&samples, 44100.0);
assert!(result.is_err());
}
#[test]
fn test_detect_key_c_major_chord() {
let samples = chord_signal(&[261.63, 329.63, 392.00], 44100.0, 1.0);
let config = AnalysisConfig::default();
let analyzer = HarmonyAnalyzer::new(config);
let result = analyzer.analyze(&samples, 44100.0).expect("should succeed");
assert!(
result.key_confidence > 0.3,
"Key confidence should be reasonable: {}",
result.key_confidence
);
assert_eq!(
result.key_pitch_class, 0,
"Expected C major key from C major chord, got {}",
result.key
);
}
#[test]
fn test_detect_key_a_minor_chord() {
let samples = chord_signal(&[220.0, 261.63, 329.63], 44100.0, 1.0);
let config = AnalysisConfig::default();
let analyzer = HarmonyAnalyzer::new(config);
let result = analyzer.analyze(&samples, 44100.0).expect("should succeed");
assert!(
result.key_confidence > 0.3,
"Key detection confidence: {}",
result.key_confidence
);
}
#[test]
fn test_detect_key_function() {
let mut chroma = [0.0_f32; 12];
chroma[0] = 1.0; chroma[4] = 0.7; chroma[7] = 0.8; let (pc, is_major, conf) = detect_key(&chroma);
assert_eq!(pc, 0, "Expected C, got pitch class {pc}");
assert!(is_major, "Expected major key");
assert!(conf > 0.5, "Confidence should be high: {conf}");
}
#[test]
fn test_detect_key_g_major() {
let mut chroma = [0.0_f32; 12];
chroma[7] = 1.0; chroma[11] = 0.7; chroma[2] = 0.8; let (pc, is_major, _) = detect_key(&chroma);
assert_eq!(pc, 7, "Expected G (7), got {pc}");
assert!(is_major, "Expected major key");
}
#[test]
fn test_detect_key_flat_chroma() {
let chroma = [1.0_f32 / 12.0; 12];
let (_, _, conf) = detect_key(&chroma);
assert!(
conf < 0.7,
"Flat chroma should have low confidence, got {conf}"
);
}
#[test]
fn test_match_chord_c_major() {
let mut chroma = [0.0_f32; 12];
chroma[0] = 1.0; chroma[4] = 0.8; chroma[7] = 0.9; let (name, conf) = match_chord(&chroma);
assert!(
name.starts_with('C'),
"Expected chord starting with C, got {name}"
);
assert!(conf > 0.5, "Chord confidence should be high: {conf}");
}
#[test]
fn test_match_chord_a_minor() {
let mut chroma = [0.0_f32; 12];
chroma[9] = 1.0; chroma[0] = 0.8; chroma[4] = 0.9; let (name, conf) = match_chord(&chroma);
assert!(
name.contains('A') || name.contains('C'),
"Expected A minor or related, got {name}"
);
assert!(conf > 0.3, "Chord confidence: {conf}");
}
#[test]
fn test_harmonic_complexity_single_chord() {
let chroma = [[0.5, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0]; 10];
let chords: Vec<String> = vec!["Cmaj".to_string(); 10];
let complexity = compute_harmonic_complexity(&chroma, &chords);
assert!(
complexity < 0.5,
"Single repeated chord should have low complexity: {complexity}"
);
}
#[test]
fn test_harmonic_complexity_many_chords() {
let chroma_frames: Vec<[f32; 12]> = (0..12)
.map(|i| {
let mut c = [0.0_f32; 12];
c[i] = 1.0;
c[(i + 4) % 12] = 0.7;
c[(i + 7) % 12] = 0.8;
c
})
.collect();
let chords: Vec<String> = (0..12)
.map(|i| format!("{}maj", PITCH_CLASS_NAMES[i]))
.collect();
let complexity = compute_harmonic_complexity(&chroma_frames, &chords);
assert!(
complexity > 0.3,
"Many different chords should have higher complexity: {complexity}"
);
}
#[test]
fn test_detect_key_from_audio() {
let samples = chord_signal(&[261.63, 329.63, 392.00], 44100.0, 0.5);
let result = detect_key_from_audio(&samples, 44100.0);
assert!(result.is_ok());
let (key, conf) = result.expect("should succeed");
assert!(!key.is_empty());
assert!(conf >= 0.0 && conf <= 1.0);
}
#[test]
fn test_chord_detection_has_results() {
let samples = chord_signal(&[261.63, 329.63, 392.00], 44100.0, 1.0);
let config = AnalysisConfig::default();
let analyzer = HarmonyAnalyzer::new(config);
let result = analyzer.analyze(&samples, 44100.0).expect("should succeed");
assert!(
!result.chords.is_empty(),
"Should detect at least one chord"
);
assert_eq!(result.chords.len(), result.chord_confidences.len());
}
#[test]
fn test_rotate_profile_identity() {
let rotated = rotate_profile(&MAJOR_PROFILE, 0);
for i in 0..12 {
assert!((rotated[i] - MAJOR_PROFILE[i]).abs() < f32::EPSILON);
}
}
#[test]
fn test_rotate_profile_by_one() {
let rotated = rotate_profile(&MAJOR_PROFILE, 1);
assert!((rotated[0] - MAJOR_PROFILE[11]).abs() < f32::EPSILON);
assert!((rotated[1] - MAJOR_PROFILE[0]).abs() < f32::EPSILON);
}
#[test]
fn test_mean_chroma_computation() {
let frames = vec![
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
];
let mean = compute_mean_chroma(&frames);
assert!((mean[0] - 0.5).abs() < 1e-5);
assert!((mean[4] - 0.5).abs() < 1e-5);
}
}