use std::path::Path;
use midly::{Format, MetaMessage, Smf, TrackEventKind};
use super::click_analysis::BeatGrid;
use super::tempo_guess::{GuessedTempo, GuessedTempoChange};
/// A tempo meta event placed at an absolute tick position.
struct TempoEvent {
/// Absolute tick of the event (sum of the track's event deltas up to it).
tick: u64,
/// Tempo value carried by the event, in microseconds per beat.
micros_per_beat: u32,
}
/// A time-signature meta event placed at an absolute tick position.
struct TimeSigEvent {
/// Absolute tick of the event (sum of the track's event deltas up to it).
tick: u64,
/// Numerator of the time signature.
numerator: u8,
/// Denominator stored as an exponent; callers compute the value as `1 << denominator_power`.
denominator_power: u8, }
/// Builds a [`GuessedTempo`] from the tempo and time-signature meta events of a MIDI file.
///
/// The start offset is chosen as follows:
/// * a `beat_grid` with at least four beats: the offset found by `find_best_offset`;
/// * a shorter grid: its first beat (or `0.0` when it has none);
/// * no grid: `0.0`.
///
/// The tempo and time signature in effect at that offset become the base values;
/// every later event that actually changes the BPM or the time signature is
/// reported as a [`GuessedTempoChange`], positioned by counting `tpb` ticks per
/// beat and the time-signature numerator as beats per measure.
///
/// `alignment_rms_ms` is only computed for grids with at least four beats.
///
/// Returns `None` when the file cannot be read or parsed, when it uses SMPTE
/// timecode timing, when its ticks-per-beat value is zero, or when it contains
/// no usable tempo event.
pub fn extract_tempo_from_midi(
    midi_path: &Path,
    beat_grid: Option<&BeatGrid>,
) -> Option<GuessedTempo> {
    /// Converts the stored power-of-two exponent into a denominator.
    /// Exponents that do not fit in a `u32` come from malformed files; they
    /// fall back to 4 instead of overflowing the shift.
    fn denominator_from_power(power: u8) -> u32 {
        1u32.checked_shl(u32::from(power)).unwrap_or(4)
    }

    let data = std::fs::read(midi_path).ok()?;
    let smf = Smf::parse(&data).ok()?;
    let tpb = match smf.header.timing {
        midly::Timing::Metrical(tpb) => tpb.as_int() as u64,
        midly::Timing::Timecode(_, _) => return None,
    };
    // A zero resolution makes every beat step zero ticks long, which would
    // stall the beat generation used by the alignment search.
    if tpb == 0 {
        return None;
    }

    let (mut tempo_events, time_sig_events) = extract_events(&smf);
    // Zero microseconds per beat means "infinite BPM": time would never advance
    // past such an event, so it is ignored as invalid.
    tempo_events.retain(|e| e.micros_per_beat != 0);
    if tempo_events.is_empty() {
        return None;
    }

    let start_offset_seconds = match beat_grid {
        Some(grid) if grid.beats.len() >= 4 => {
            find_best_offset(&tempo_events, &time_sig_events, tpb, grid)
        }
        Some(grid) => grid.beats.first().copied().unwrap_or(0.0),
        None => 0.0,
    };
    let offset_tick = seconds_to_tick(start_offset_seconds, &tempo_events, tpb);

    // Base values: the last event at or before the offset wins.
    let base_micros_per_beat = tempo_events
        .iter()
        .rev()
        .find(|e| e.tick <= offset_tick)
        .map(|e| e.micros_per_beat)
        .unwrap_or(tempo_events[0].micros_per_beat);
    let base_bpm = (60_000_000.0 / base_micros_per_beat as f64).round() as u32;
    let base_time_sig = time_sig_events
        .iter()
        .rev()
        .find(|e| e.tick <= offset_tick)
        .map(|e| {
            [
                e.numerator as u32,
                denominator_from_power(e.denominator_power),
            ]
        })
        .unwrap_or([4, 4]);

    let mut changes = Vec::new();
    let mut current_bpm = base_bpm;
    let mut current_ts = base_time_sig;

    // Merge both event kinds that lie after the offset. Tempo events are
    // collected first and the sort is stable, so a tempo event precedes a
    // time-signature event on the same tick.
    let mut all_events: Vec<(u64, EventKind)> = tempo_events
        .iter()
        .filter(|te| te.tick > offset_tick)
        .map(|te| (te.tick, EventKind::Tempo(te.micros_per_beat)))
        .chain(
            time_sig_events
                .iter()
                .filter(|ts| ts.tick > offset_tick)
                .map(|ts| (ts.tick, EventKind::TimeSig(ts.numerator, ts.denominator_power))),
        )
        .collect();
    all_events.sort_by_key(|(tick, _)| *tick);

    let mut tick_cursor: u64 = offset_tick;
    let mut measure: u32 = 1;
    let mut beat_in_measure: f64 = 0.0;
    let mut beats_per_measure = base_time_sig[0];
    let ticks_per_beat = tpb;
    for (tick, event) in &all_events {
        // Advance the measure/beat position by the beats elapsed since the
        // previous event.
        let delta_ticks = tick - tick_cursor;
        let delta_beats = delta_ticks as f64 / ticks_per_beat as f64;
        beat_in_measure += delta_beats;
        while beat_in_measure >= beats_per_measure as f64 {
            beat_in_measure -= beats_per_measure as f64;
            measure += 1;
        }
        tick_cursor = *tick;
        let beat_number = beat_in_measure.floor() as u32 + 1;
        match event {
            EventKind::Tempo(micros_per_beat) => {
                let bpm = (60_000_000.0 / *micros_per_beat as f64).round() as u32;
                // Only report events that change the rounded BPM.
                if bpm != current_bpm {
                    changes.push(GuessedTempoChange {
                        measure,
                        beat: beat_number,
                        bpm,
                        time_signature: [beats_per_measure, current_ts[1]],
                        transition_beats: None,
                    });
                    current_bpm = bpm;
                }
            }
            EventKind::TimeSig(numerator, denom_power) => {
                let new_ts = [*numerator as u32, denominator_from_power(*denom_power)];
                if new_ts != current_ts {
                    changes.push(GuessedTempoChange {
                        measure,
                        beat: beat_number,
                        bpm: current_bpm,
                        time_signature: new_ts,
                        transition_beats: None,
                    });
                    beats_per_measure = new_ts[0];
                    current_ts = new_ts;
                }
            }
        }
    }
    dedup_changes(&mut changes);
    collapse_ramps(&mut changes);

    // `alignment_score` returns the negated mean squared error in seconds²,
    // so the RMS in milliseconds is sqrt(-score) * 1000.
    let alignment_rms_ms = beat_grid.and_then(|grid| {
        if grid.beats.len() < 4 {
            return None;
        }
        let max_time = grid.beats.last().copied().unwrap_or(0.0);
        let midi_beats = midi_beat_times(
            &tempo_events,
            &time_sig_events,
            tpb,
            start_offset_seconds,
            max_time,
        );
        let score = alignment_score(&midi_beats, &grid.beats);
        if score == f64::NEG_INFINITY {
            return None;
        }
        Some((-score).sqrt() * 1000.0)
    });

    Some(GuessedTempo {
        start_seconds: start_offset_seconds,
        bpm: base_bpm,
        time_signature: base_time_sig,
        changes,
        alignment_rms_ms,
    })
}
/// Converts a time in seconds into the nearest absolute tick of the tempo map.
///
/// `tempo_events` must be sorted by tick. The first event's tempo is applied
/// from tick 0, and each later event takes effect at its own tick. An empty
/// map falls back to 500 000 µs per beat (120 BPM, the Standard MIDI File
/// default) instead of panicking.
///
/// Non-positive targets map to tick 0. The result saturates at `u64::MAX`
/// when the target can never be reached (for example after a zero-length
/// tempo), rather than overflowing.
fn seconds_to_tick(target_seconds: f64, tempo_events: &[TempoEvent], tpb: u64) -> u64 {
    const DEFAULT_MICROS_PER_BEAT: u32 = 500_000;

    if target_seconds <= 0.0 {
        return 0;
    }
    let seconds_per_tick =
        |micros_per_beat: u32| micros_per_beat as f64 / 1_000_000.0 / tpb as f64;

    let mut elapsed_seconds = 0.0;
    let mut current_tick: u64 = 0;
    let mut current_micros_per_beat = tempo_events
        .first()
        .map_or(DEFAULT_MICROS_PER_BEAT, |e| e.micros_per_beat);

    for te in tempo_events.iter().skip(1) {
        let per_tick = seconds_per_tick(current_micros_per_beat);
        let delta_seconds = (te.tick - current_tick) as f64 * per_tick;
        // The target lies inside the segment that ends at this event.
        if elapsed_seconds + delta_seconds >= target_seconds {
            let remaining = target_seconds - elapsed_seconds;
            return current_tick.saturating_add((remaining / per_tick).round() as u64);
        }
        elapsed_seconds += delta_seconds;
        current_tick = te.tick;
        current_micros_per_beat = te.micros_per_beat;
    }

    // The target lies after the last tempo event.
    let remaining = target_seconds - elapsed_seconds;
    let per_tick = seconds_per_tick(current_micros_per_beat);
    current_tick.saturating_add((remaining / per_tick).round() as u64)
}
/// Converts an absolute tick into elapsed seconds according to the tempo map.
///
/// `tempo_events` must be sorted by tick. The first event's tempo is applied
/// from tick 0, and each later event takes effect at its own tick. An empty
/// map falls back to 500 000 µs per beat (120 BPM, the Standard MIDI File
/// default) instead of panicking.
fn tick_to_seconds(target_tick: u64, tempo_events: &[TempoEvent], tpb: u64) -> f64 {
    const DEFAULT_MICROS_PER_BEAT: u32 = 500_000;

    if target_tick == 0 {
        return 0.0;
    }
    let seconds_per_tick =
        |micros_per_beat: u32| micros_per_beat as f64 / 1_000_000.0 / tpb as f64;

    let mut elapsed_seconds = 0.0;
    let mut current_tick: u64 = 0;
    let mut current_micros_per_beat = tempo_events
        .first()
        .map_or(DEFAULT_MICROS_PER_BEAT, |e| e.micros_per_beat);

    for te in tempo_events.iter().skip(1) {
        // Events at or after the target do not affect the elapsed time.
        if te.tick >= target_tick {
            break;
        }
        elapsed_seconds +=
            (te.tick - current_tick) as f64 * seconds_per_tick(current_micros_per_beat);
        current_tick = te.tick;
        current_micros_per_beat = te.micros_per_beat;
    }

    let remaining_ticks = target_tick - current_tick;
    elapsed_seconds + remaining_ticks as f64 * seconds_per_tick(current_micros_per_beat)
}
/// Returns the number of ticks between consecutive beats for a time signature.
///
/// `denominator_power` is the exponent of the denominator (1 → /2, 2 → /4,
/// 3 → /8, 4 → /16). Eighth-note meters whose numerator is a multiple of three
/// are treated as compound and step by a dotted quarter. Any other exponent,
/// including out-of-range values from malformed files, steps by `tpb`.
///
/// The exponent is matched directly rather than shifted, so large values
/// cannot overflow a shift.
fn beat_step_ticks(numerator: u8, denominator_power: u8, tpb: u64) -> u64 {
    match denominator_power {
        1 => 2 * tpb,
        2 => tpb,
        3 if numerator.is_multiple_of(3) => 3 * tpb / 2,
        3 => tpb / 2,
        4 => tpb / 4,
        _ => tpb,
    }
}
/// Generates the beat times (in seconds) predicted by the MIDI tempo map.
///
/// Starting at the tick nearest to `start_seconds`, the function steps one
/// beat at a time until the beat time exceeds `max_seconds`. The step comes
/// from `beat_step_ticks` for the time signature in effect at the current
/// tick (4/4 when no time-signature event applies yet).
///
/// Each step is at least one tick, so a step that would round down to zero
/// (very small `tpb`) cannot stall the loop. Generation also stops when the
/// time comparison is undefined (NaN) or the tick counter would overflow.
/// A tempo map in which time never advances (zero microseconds per beat) is
/// still unbounded; callers are expected to filter such events.
fn midi_beat_times(
    tempo_events: &[TempoEvent],
    time_sig_events: &[TimeSigEvent],
    tpb: u64,
    start_seconds: f64,
    max_seconds: f64,
) -> Vec<f64> {
    use std::cmp::Ordering;

    let mut beats = Vec::new();
    let mut tick = seconds_to_tick(start_seconds, tempo_events, tpb);
    loop {
        let time = tick_to_seconds(tick, tempo_events, tpb);
        // Continue only while `time <= max_seconds`; NaN ends the loop too.
        if !matches!(
            time.partial_cmp(&max_seconds),
            Some(Ordering::Less | Ordering::Equal)
        ) {
            break;
        }
        beats.push(time);

        let (num, denom_pow) = time_sig_events
            .iter()
            .rev()
            .find(|e| e.tick <= tick)
            .map(|e| (e.numerator, e.denominator_power))
            .unwrap_or((4, 2));
        // Guarantee forward progress even when the nominal step rounds to 0.
        let step = beat_step_ticks(num, denom_pow, tpb).max(1);
        match tick.checked_add(step) {
            Some(next) => tick = next,
            None => break,
        }
    }
    beats
}
/// Scores how well `midi_beats` explain `grid_beats`.
///
/// Every grid beat inside the time span covered by `midi_beats` is matched to
/// its nearest MIDI beat. The score is the negated mean of the squared
/// distances, so `0.0` is a perfect match and more negative is worse.
/// `midi_beats` is searched with a binary partition and is therefore expected
/// to be sorted ascending.
///
/// Returns `f64::NEG_INFINITY` when either slice is empty or when no grid beat
/// falls inside the MIDI span.
fn alignment_score(midi_beats: &[f64], grid_beats: &[f64]) -> f64 {
    let (Some(&span_start), Some(&span_end)) = (midi_beats.first(), midi_beats.last()) else {
        return f64::NEG_INFINITY;
    };

    let mut squared_error_total = 0.0;
    let mut matched = 0u32;
    for grid_beat in grid_beats.iter().copied() {
        if grid_beat < span_start || grid_beat > span_end {
            continue;
        }
        // First MIDI beat that is not earlier than the grid beat.
        let upper = midi_beats.partition_point(|&mb| mb < grid_beat);
        let after = midi_beats.get(upper);
        let before = upper.checked_sub(1).map(|below| &midi_beats[below]);
        let nearest = [after, before]
            .into_iter()
            .flatten()
            .fold(f64::MAX, |best, &mb| best.min((mb - grid_beat).abs()));
        squared_error_total += nearest * nearest;
        matched += 1;
    }

    match matched {
        0 => f64::NEG_INFINITY,
        n => -(squared_error_total / n as f64),
    }
}
/// Latest grid-beat time (in seconds) that `find_best_offset` still considers as a start candidate.
const MAX_LEADIN_SECONDS: f64 = 10.0;
/// Finds the start offset (in seconds) at which the MIDI beat sequence best
/// matches the click `grid`.
///
/// The search runs in two passes:
/// 1. a coarse pass over every grid beat up to `MAX_LEADIN_SECONDS`, plus
///    `0.0` when the grid does not already start there;
/// 2. a fine pass in 1 ms steps within one beat (at the first tempo) on either
///    side of the coarse winner.
///
/// A candidate replaces the current best when its score is better by more than
/// a small epsilon, or when it is within that epsilon and lies closer to the
/// first grid beat.
///
/// `tempo_events` must be non-empty. An empty grid yields `0.0` instead of
/// panicking.
fn find_best_offset(
    tempo_events: &[TempoEvent],
    time_sig_events: &[TimeSigEvent],
    tpb: u64,
    grid: &BeatGrid,
) -> f64 {
    let (Some(&grid_first), Some(&max_time)) = (grid.beats.first(), grid.beats.last()) else {
        return 0.0;
    };

    let mut candidates: Vec<f64> = Vec::new();
    for &beat_time in &grid.beats {
        if beat_time > MAX_LEADIN_SECONDS {
            break;
        }
        candidates.push(beat_time);
    }
    if candidates.first().is_none_or(|&t| t > 0.001) {
        candidates.insert(0, 0.0);
    }

    let score_epsilon = 5e-5_f64;
    let score_at = |start: f64| {
        let midi_beats = midi_beat_times(tempo_events, time_sig_events, tpb, start, max_time);
        alignment_score(&midi_beats, &grid.beats)
    };
    // Acceptance rule shared by both passes.
    let improves = |score: f64, candidate: f64, best_score: f64, best_start: f64| {
        let closer = (candidate - grid_first).abs() < (best_start - grid_first).abs();
        let meaningfully_better = score > best_score + score_epsilon;
        let tied_and_closer = score >= best_score - score_epsilon && closer;
        meaningfully_better || tied_and_closer
    };

    let mut best_score = f64::NEG_INFINITY;
    let mut best_start = grid_first;

    // Coarse pass.
    for &candidate in &candidates {
        let score = score_at(candidate);
        if improves(score, candidate, best_score, best_start) {
            best_score = score;
            best_start = candidate;
        }
    }

    // Fine pass around the coarse winner; the window is fixed before the
    // sweep starts.
    let beat_duration = tempo_events[0].micros_per_beat as f64 / 1_000_000.0;
    let fine_start = (best_start - beat_duration).max(0.0);
    let fine_end = best_start + beat_duration;
    let step = 0.001;
    let mut t = fine_start;
    while t <= fine_end {
        let score = score_at(t);
        if improves(score, t, best_score, best_start) {
            best_score = score;
            best_start = t;
        }
        t += step;
    }
    best_start
}
/// Payload of a tempo-map event once tempo and time-signature events are merged into one list.
enum EventKind {
/// Tempo change, carrying the new microseconds-per-beat value.
Tempo(u32),
/// Time-signature change, carrying (numerator, denominator exponent).
TimeSig(u8, u8),
}
/// Collects the tempo and time-signature meta events of every track.
///
/// Ticks are accumulated per track from the event deltas. When several events
/// of the same kind land on the same absolute tick, the first one encountered
/// (earlier track, then earlier position in the track) is kept. Both returned
/// lists are sorted by tick.
fn extract_events(smf: &Smf) -> (Vec<TempoEvent>, Vec<TimeSigEvent>) {
    use std::collections::BTreeMap;

    // Every file format is scanned the same way today. The match stays
    // exhaustive so that a new `Format` variant forces a decision here.
    let tracks = match smf.header.format {
        Format::SingleTrack | Format::Parallel | Format::Sequential => &smf.tracks,
    };

    // Keyed by absolute tick: gives the first-wins rule and the final tick
    // ordering without rescanning the collected events for each new one.
    let mut tempo_by_tick: BTreeMap<u64, u32> = BTreeMap::new();
    let mut time_sig_by_tick: BTreeMap<u64, (u8, u8)> = BTreeMap::new();

    for track in tracks {
        let mut tick: u64 = 0;
        for event in track {
            tick += event.delta.as_int() as u64;
            match event.kind {
                TrackEventKind::Meta(MetaMessage::Tempo(tempo)) => {
                    tempo_by_tick.entry(tick).or_insert(tempo.as_int());
                }
                TrackEventKind::Meta(MetaMessage::TimeSignature(
                    numerator,
                    denominator,
                    _clocks_per_click,
                    _thirty_seconds_per_quarter,
                )) => {
                    time_sig_by_tick
                        .entry(tick)
                        .or_insert((numerator, denominator));
                }
                _ => {}
            }
        }
    }

    let tempo_events = tempo_by_tick
        .into_iter()
        .map(|(tick, micros_per_beat)| TempoEvent {
            tick,
            micros_per_beat,
        })
        .collect();
    let time_sig_events = time_sig_by_tick
        .into_iter()
        .map(|(tick, (numerator, denominator_power))| TimeSigEvent {
            tick,
            numerator,
            denominator_power,
        })
        .collect();
    (tempo_events, time_sig_events)
}
fn collapse_ramps(changes: &mut Vec<GuessedTempoChange>) {
if changes.len() < 2 {
return;
}
const MAX_GAP_MEASURES: u32 = 4;
let mut i = 0;
while i + 1 < changes.len() {
let bpm_i = changes[i].bpm as i32;
let bpm_next = changes[i + 1].bpm as i32;
let going_down = bpm_next < bpm_i;
let going_up = bpm_next > bpm_i;
if !going_down && !going_up {
i += 1;
continue;
}
if changes[i].time_signature != changes[i + 1].time_signature {
i += 1;
continue;
}
let first_gap = changes[i + 1].measure.saturating_sub(changes[i].measure);
if first_gap > MAX_GAP_MEASURES {
i += 1;
continue;
}
let mut run_end = i + 1;
while run_end + 1 < changes.len() {
let prev_bpm = changes[run_end].bpm as i32;
let next_bpm = changes[run_end + 1].bpm as i32;
let same_dir = if going_down {
next_bpm < prev_bpm
} else {
next_bpm > prev_bpm
};
let same_ts = changes[run_end + 1].time_signature == changes[i].time_signature;
let measure_gap = changes[run_end + 1]
.measure
.saturating_sub(changes[run_end].measure);
let close = measure_gap <= MAX_GAP_MEASURES;
if same_dir && same_ts && close {
run_end += 1;
} else {
break;
}
}
let run_len = run_end - i + 1;
if run_len < 2 {
i += 1;
continue;
}
let first = &changes[i];
let last = &changes[run_end];
let bpm_ts = first.time_signature[0];
let first_beat_abs = (first.measure - 1) * bpm_ts + (first.beat - 1);
let last_beat_abs = (last.measure - 1) * bpm_ts + (last.beat - 1);
let transition_beats = last_beat_abs.saturating_sub(first_beat_abs);
let final_bpm = changes[run_end].bpm;
changes[i].bpm = final_bpm;
changes[i].transition_beats = if transition_beats > 0 {
Some(transition_beats)
} else {
None
};
changes.drain((i + 1)..=run_end);
i += 1;
}
}
/// Merges consecutive changes that fall on the same measure and beat.
///
/// The earlier entry is kept and takes over the BPM and time signature of the
/// later one; its other fields are left as they were.
fn dedup_changes(changes: &mut Vec<GuessedTempoChange>) {
    // `dedup_by` passes (later, earlier); returning `true` drops `later`.
    changes.dedup_by(|later, earlier| {
        let same_position = earlier.measure == later.measure && earlier.beat == later.beat;
        if same_position {
            earlier.bpm = later.bpm;
            earlier.time_signature = later.time_signature;
        }
        same_position
    });
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
/// Serialises `events` (absolute tick, meta event) into a format-0,
/// single-track MIDI file with 480 ticks per beat.
///
/// Ticks must be non-decreasing, because each delta is computed by subtraction
/// from the previous tick.
fn build_midi(events: &[(u32, MidiMetaEvent)]) -> Vec<u8> {
    const TICKS_PER_BEAT: u16 = 480;

    let mut track = Vec::new();
    let mut previous_tick = 0u32;
    for (tick, event) in events {
        write_vlq(&mut track, tick - previous_tick);
        match event {
            MidiMetaEvent::Tempo(micros) => {
                // Set-tempo meta event: the low three bytes, big-endian.
                track.extend_from_slice(&[0xFF, 0x51, 0x03]);
                track.extend_from_slice(&micros.to_be_bytes()[1..]);
            }
            MidiMetaEvent::TimeSig(num, denom_pow) => {
                track.extend_from_slice(&[0xFF, 0x58, 0x04, *num, *denom_pow, 24, 8]);
            }
        }
        previous_tick = *tick;
    }
    // End-of-track meta event with a zero delta.
    write_vlq(&mut track, 0);
    track.extend_from_slice(&[0xFF, 0x2F, 0x00]);

    let header_fields: [&[u8]; 5] = [
        b"MThd",
        &6u32.to_be_bytes(),
        &0u16.to_be_bytes(),
        &1u16.to_be_bytes(),
        &TICKS_PER_BEAT.to_be_bytes(),
    ];
    let mut midi: Vec<u8> = header_fields.concat();
    midi.extend_from_slice(b"MTrk");
    midi.extend_from_slice(&(track.len() as u32).to_be_bytes());
    midi.extend_from_slice(&track);
    midi
}
/// Meta events that `build_midi` can serialise into a test file.
enum MidiMetaEvent {
// `Tempo` carries microseconds per beat; `TimeSig` carries (numerator, denominator exponent).
Tempo(u32), TimeSig(u8, u8), }
/// Converts beats per minute into whole microseconds per beat (integer
/// division). `bpm` must be non-zero.
fn bpm_to_micros(bpm: u32) -> u32 {
    const MICROS_PER_MINUTE: u32 = 60_000_000;
    MICROS_PER_MINUTE / bpm
}
/// Appends `value` to `buf` as a MIDI variable-length quantity: 7 bits per
/// byte, most significant group first, with the high bit set on every byte
/// except the last.
fn write_vlq(buf: &mut Vec<u8>, mut value: u32) {
    // Groups are gathered least-significant first. The first one gathered is
    // the final byte on the wire, so it is the only one without the
    // continuation bit.
    let mut groups = vec![(value & 0x7F) as u8];
    value >>= 7;
    while value > 0 {
        groups.push((value & 0x7F) as u8 | 0x80);
        value >>= 7;
    }
    buf.extend(groups.into_iter().rev());
}
/// Writes the MIDI bytes built from `events` to a fresh temporary file and
/// returns its handle.
fn write_test_midi(events: &[(u32, MidiMetaEvent)]) -> tempfile::NamedTempFile {
    let midi_bytes = build_midi(events);
    let mut temp_file = tempfile::NamedTempFile::new().unwrap();
    temp_file.write_all(&midi_bytes).unwrap();
    temp_file.flush().unwrap();
    temp_file
}
/// Builds an evenly spaced beat grid: `num_beats` beats at `bpm`, the first at
/// `start` seconds, with a measure start every `beats_per_measure` beats.
fn make_beat_grid(
    bpm: f64,
    start: f64,
    num_beats: usize,
    beats_per_measure: usize,
) -> BeatGrid {
    let seconds_per_beat = 60.0 / bpm;
    BeatGrid {
        beats: (0..num_beats)
            .map(|index| start + index as f64 * seconds_per_beat)
            .collect(),
        measure_starts: (0..num_beats).step_by(beats_per_measure).collect(),
    }
}
/// Without a beat grid there is nothing to align against, so no RMS is reported.
#[test]
fn alignment_rms_ms_none_without_beat_grid() {
    let midi_file = write_test_midi(&[(0, MidiMetaEvent::Tempo(bpm_to_micros(120)))]);
    let tempo = extract_tempo_from_midi(midi_file.path(), None).unwrap();
    let rms = tempo.alignment_rms_ms;
    assert!(
        rms.is_none(),
        "alignment_rms_ms should be None when no beat grid provided"
    );
}
/// A 120 BPM grid against a 120 BPM MIDI file should align almost perfectly.
#[test]
fn alignment_rms_ms_low_for_well_matched_grid() {
    let midi_file = write_test_midi(&[(0, MidiMetaEvent::Tempo(bpm_to_micros(120)))]);
    let matching_grid = make_beat_grid(120.0, 0.0, 64, 4);
    let tempo = extract_tempo_from_midi(midi_file.path(), Some(&matching_grid)).unwrap();
    let rms = tempo
        .alignment_rms_ms
        .expect("alignment_rms_ms should be Some");
    assert!(
        rms < 2.0,
        "RMS should be near-zero for a perfectly matched grid, got {rms:.3}ms"
    );
}
/// A 90 BPM grid against a 120 BPM MIDI file should produce a large RMS error.
#[test]
fn alignment_rms_ms_high_for_mismatched_grid() {
    let midi_file = write_test_midi(&[(0, MidiMetaEvent::Tempo(bpm_to_micros(120)))]);
    let mismatched_grid = make_beat_grid(90.0, 0.0, 48, 4);
    let tempo = extract_tempo_from_midi(midi_file.path(), Some(&mismatched_grid)).unwrap();
    let rms = tempo
        .alignment_rms_ms
        .expect("alignment_rms_ms should be Some");
    assert!(
        rms > 15.0,
        "RMS should exceed warning threshold for mismatched BPM, got {rms:.3}ms"
    );
}
/// With no grid, the base tempo is the tempo at tick 0 and nothing is reported as a change.
#[test]
fn no_beat_grid_uses_zero_offset() {
    let midi_file = write_test_midi(&[(0, MidiMetaEvent::Tempo(bpm_to_micros(120)))]);
    let tempo = extract_tempo_from_midi(midi_file.path(), None).unwrap();
    assert_eq!(tempo.bpm, 120);
    assert!(tempo.changes.is_empty());
}
/// A 100 BPM lead-in followed by 140 BPM at tick 1600: the grid starts at 2.0 s,
/// so the detected base tempo should be the post-lead-in one.
#[test]
fn beat_grid_auto_detects_leadin_offset() {
    let leadin_end_tick = 1600;
    let midi_file = write_test_midi(&[
        (0, MidiMetaEvent::Tempo(bpm_to_micros(100))),
        (leadin_end_tick, MidiMetaEvent::Tempo(bpm_to_micros(140))),
    ]);
    let click_grid = make_beat_grid(140.0, 2.0, 32, 4);
    let tempo = extract_tempo_from_midi(midi_file.path(), Some(&click_grid)).unwrap();
    assert_eq!(tempo.bpm, 140);
    assert!(tempo.changes.is_empty(), "No changes after beat 1");
    let offset_error = (tempo.start_seconds - 2.0).abs();
    assert!(
        offset_error < 0.01,
        "Offset should be ~2.0s, got {}",
        tempo.start_seconds
    );
}
/// A grid that starts at 0.0 s should yield a zero start offset.
#[test]
fn beat_grid_detects_no_leadin() {
    let midi_file = write_test_midi(&[(0, MidiMetaEvent::Tempo(bpm_to_micros(120)))]);
    let click_grid = make_beat_grid(120.0, 0.0, 32, 4);
    let tempo = extract_tempo_from_midi(midi_file.path(), Some(&click_grid)).unwrap();
    assert_eq!(tempo.bpm, 120);
    assert!(tempo.start_seconds.abs() < 0.01, "Offset should be ~0.0s");
}
/// Two-beat lead-in at 120 BPM, then 16 beats at 120 BPM and 16 at 150 BPM:
/// the tempo change must be reported at measure 5 relative to the detected start.
#[test]
fn beat_grid_with_tempo_change_after_start() {
    let offset_ticks = 960;
    let change_ticks = offset_ticks + 16 * 480;
    let midi_file = write_test_midi(&[
        (0, MidiMetaEvent::Tempo(bpm_to_micros(120))),
        (change_ticks, MidiMetaEvent::Tempo(bpm_to_micros(150))),
    ]);

    let start = 1.0;
    // 16 beats spaced 0.5 s apart, followed by 16 beats spaced 0.4 s apart.
    let slow_section = (0..16).map(|i| start + i as f64 * 0.5);
    let fast_section = (0..16).map(|i| start + 16.0 * 0.5 + i as f64 * 0.4);
    let beats: Vec<f64> = slow_section.chain(fast_section).collect();
    // Every fourth beat opens a measure.
    let measure_starts: Vec<usize> = (0..beats.len()).step_by(4).collect();
    let click_grid = BeatGrid {
        beats,
        measure_starts,
    };

    let tempo = extract_tempo_from_midi(midi_file.path(), Some(&click_grid)).unwrap();
    assert_eq!(tempo.bpm, 120);
    assert!((tempo.start_seconds - 1.0).abs() < 0.01);
    assert_eq!(tempo.changes.len(), 1);
    let only_change = &tempo.changes[0];
    assert_eq!(only_change.measure, 5);
    assert_eq!(only_change.bpm, 150);
}
/// The time signature in effect at the detected start (3/4 from tick 1920)
/// must become the base time signature.
#[test]
fn beat_grid_with_time_sig_at_start() {
    let midi_file = write_test_midi(&[
        (0, MidiMetaEvent::Tempo(bpm_to_micros(120))),
        (0, MidiMetaEvent::TimeSig(4, 2)),
        (1920, MidiMetaEvent::TimeSig(3, 2)),
    ]);
    let click_grid = make_beat_grid(120.0, 2.0, 24, 3);
    let tempo = extract_tempo_from_midi(midi_file.path(), Some(&click_grid)).unwrap();
    assert_eq!(tempo.bpm, 120);
    assert_eq!(tempo.time_signature, [3, 4]);
    assert!((tempo.start_seconds - 2.0).abs() < 0.01);
}
/// Zero seconds always maps to tick 0.
#[test]
fn seconds_to_tick_zero_offset() {
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    assert_eq!(seconds_to_tick(0.0, &tempo_map, 480), 0);
}
/// At 120 BPM and 480 ticks per beat, two seconds is four beats (1920 ticks).
#[test]
fn seconds_to_tick_simple() {
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    assert_eq!(seconds_to_tick(2.0, &tempo_map, 480), 1920);
}
/// One second at 120 BPM (960 ticks) plus half a second at 60 BPM (240 ticks).
#[test]
fn seconds_to_tick_with_tempo_change() {
    let fast = TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    };
    let slow = TempoEvent {
        tick: 960,
        micros_per_beat: bpm_to_micros(60),
    };
    let tempo_map = [fast, slow];
    assert_eq!(seconds_to_tick(1.5, &tempo_map, 480), 1200);
}
/// Checks the base tempo of a local "Saxon Shore" MIDI file; skipped when the
/// file is not present under `$HOME`.
#[test]
fn saxon_shore_midi() {
    let home = std::env::var("HOME").unwrap_or_default();
    let midi_path = std::path::Path::new(&home)
        .join("src/backing-tracks/Isenmor/Saxon Shore/Saxon Shore.mid");
    if !midi_path.exists() {
        eprintln!("Skipping: MIDI not found");
        return;
    }
    let tempo = extract_tempo_from_midi(&midi_path, None).unwrap();
    eprintln!("Base: {} BPM, {:?}", tempo.bpm, tempo.time_signature);
    for change in &tempo.changes {
        eprintln!(
            " m{}/{} {} BPM ts={}/{}",
            change.measure,
            change.beat,
            change.bpm,
            change.time_signature[0],
            change.time_signature[1]
        );
    }
    assert_eq!(tempo.bpm, 150, "Base should be 150");
}
/// Diagnostic helper: compares the beats predicted from a song's MIDI tempo map
/// with the beats detected in its click track and prints an error report to stderr.
///
/// Paths are resolved under `$HOME/src/backing-tracks/<song_dir>`. The function
/// returns early (printing a notice) when either file is missing or when the
/// MIDI beats and the click grid do not overlap. It asserts nothing itself, but
/// panics if click analysis or MIDI extraction fails.
fn eval_midi_vs_click(
song_dir: &str,
midi_filename: &str,
click_filename: &str,
click_channel: u16,
) {
let base = std::path::Path::new(&std::env::var("HOME").unwrap_or_default())
.join("src/backing-tracks")
.join(song_dir);
let midi_path = base.join(midi_filename);
let click_path = base.join(click_filename);
if !midi_path.exists() || !click_path.exists() {
eprintln!("Skipping {song_dir}: files not found");
return;
}
let grid =
crate::audio::click_analysis::analyze_click_track_default(&click_path, click_channel)
.expect("click analysis failed");
eprintln!("\n=== {song_dir} ===");
eprintln!(
"Click grid: {} beats, {} measures",
grid.beat_count(),
grid.measure_count()
);
if let (Some(&first), Some(&last)) = (grid.beats.first(), grid.beats.last()) {
eprintln!(
" Span: {first:.3}s – {last:.3}s (duration {:.1}s)",
last - first
);
}
let result =
extract_tempo_from_midi(&midi_path, Some(&grid)).expect("MIDI extraction failed");
eprintln!(
"MIDI result: {} BPM, ts={}/{}, start_offset={:.4}s",
result.bpm, result.time_signature[0], result.time_signature[1], result.start_seconds
);
for c in &result.changes {
eprintln!(
" m{}/{} {} BPM ts={}/{} transition={:?}",
c.measure,
c.beat,
c.bpm,
c.time_signature[0],
c.time_signature[1],
c.transition_beats
);
}
// Re-parse the file so the beat sequence can be regenerated from the detected start offset.
let data = std::fs::read(&midi_path).unwrap();
let smf = midly::Smf::parse(&data).unwrap();
let tpb = match smf.header.timing {
midly::Timing::Metrical(tpb) => tpb.as_int() as u64,
_ => panic!("SMPTE timing not supported"),
};
let (tempo_events, time_sig_events) = extract_events(&smf);
// Generate MIDI beats up to one second past the last click beat.
let max_time = grid.beats.last().copied().unwrap_or(0.0) + 1.0;
let midi_beats = midi_beat_times(
&tempo_events,
&time_sig_events,
tpb,
result.start_seconds,
max_time,
);
eprintln!(
"MIDI predicted {} beats vs {} click beats",
midi_beats.len(),
grid.beats.len()
);
let click_min = grid.beats.first().copied().unwrap_or(0.0);
let click_max = grid.beats.last().copied().unwrap_or(0.0);
// Absolute error per MIDI beat (ms) and the signed error paired with the beat time.
let mut errors_ms: Vec<f64> = Vec::new();
let mut drift_series: Vec<(f64, f64)> = Vec::new();
for &mb in &midi_beats {
// Only MIDI beats inside the click grid's time span are evaluated.
if mb < click_min || mb > click_max {
continue;
}
// Nearest click beat: compare the neighbours on both sides of the partition point.
let idx = grid.beats.partition_point(|&cb| cb < mb);
let mut best_dist = f64::MAX;
let mut best_signed = 0.0f64;
if idx < grid.beats.len() {
let d = (mb - grid.beats[idx]).abs();
if d < best_dist {
best_dist = d;
best_signed = mb - grid.beats[idx];
}
}
if idx > 0 {
let d = (mb - grid.beats[idx - 1]).abs();
if d < best_dist {
best_dist = d;
best_signed = mb - grid.beats[idx - 1];
}
}
errors_ms.push(best_dist * 1000.0);
drift_series.push((mb, best_signed * 1000.0));
}
if errors_ms.is_empty() {
eprintln!("No overlap between MIDI beats and click grid — cannot evaluate.");
return;
}
let mean_err = errors_ms.iter().sum::<f64>() / errors_ms.len() as f64;
let rmse = (errors_ms.iter().map(|e| e * e).sum::<f64>() / errors_ms.len() as f64).sqrt();
let max_err = errors_ms.iter().cloned().fold(0.0f64, f64::max);
// Median taken as the element at index len / 2 of the sorted errors.
let median = {
let mut s = errors_ms.clone();
s.sort_by(|a, b| a.partial_cmp(b).unwrap());
s[s.len() / 2]
};
eprintln!("\nAlignment errors ({} beats evaluated):", errors_ms.len());
eprintln!(" Mean: {mean_err:.2}ms");
eprintln!(" Median: {median:.2}ms");
eprintln!(" RMSE: {rmse:.2}ms");
eprintln!(" Max: {max_err:.2}ms");
// Histogram over half-open ranges [prev, threshold).
let buckets = [1.0, 2.0, 5.0, 10.0, 20.0, f64::INFINITY];
let labels = ["<1ms", "<2ms", "<5ms", "<10ms", "<20ms", "≥20ms"];
let total = errors_ms.len() as f64;
eprintln!("\n Distribution:");
let mut prev = 0.0f64;
for (threshold, label) in buckets.iter().zip(labels.iter()) {
let count = errors_ms
.iter()
.filter(|&&e| e >= prev && e < *threshold)
.count();
let pct = count as f64 / total * 100.0;
eprintln!(" {label:>6}: {count:4} beats ({pct:5.1}%)");
prev = *threshold;
}
eprintln!("\n Signed drift by 30s window (positive = MIDI ahead of click):");
eprintln!(
" {:>8} {:>8} {:>8} {:>8} {:>8}",
"window", "beats", "mean", "min", "max"
);
// Walk the evaluated beats in consecutive 30-second windows starting at the first one.
let window = 30.0f64;
let song_end = drift_series.last().map(|(t, _)| *t).unwrap_or(0.0);
let mut t = drift_series.first().map(|(t, _)| *t).unwrap_or(0.0);
while t < song_end {
let slice: Vec<f64> = drift_series
.iter()
.filter(|(mt, _)| *mt >= t && *mt < t + window)
.map(|(_, d)| *d)
.collect();
if !slice.is_empty() {
let wmean = slice.iter().sum::<f64>() / slice.len() as f64;
let wmin = slice.iter().cloned().fold(f64::INFINITY, f64::min);
let wmax = slice.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
eprintln!(
" {:>7.0}s {:>8} {:>7.1}ms {:>7.1}ms {:>7.1}ms",
t,
slice.len(),
wmean,
wmin,
wmax
);
}
t += window;
}
}
// Prints the MIDI-vs-click report for "Saxon Shore"; the helper skips the song when its files are missing.
#[test]
fn saxon_shore_beat_grid_alignment() {
eval_midi_vs_click("Isenmor/Saxon Shore", "Saxon Shore.mid", "Click.flac", 0);
}
// Diagnostic test: prints how the detected start offset for "Saxon Shore" compares
// with a zero offset, the click beats before that offset, and the alignment score
// of each candidate in the first ten seconds. Skipped when the files are missing;
// it makes no assertions of its own.
#[test]
fn saxon_shore_leadin_analysis() {
let base = std::path::Path::new(&std::env::var("HOME").unwrap_or_default())
.join("src/backing-tracks/Isenmor/Saxon Shore");
let midi_path = base.join("Saxon Shore.mid");
let click_path = base.join("Click.flac");
if !midi_path.exists() || !click_path.exists() {
eprintln!("Skipping: files not found");
return;
}
let grid = crate::audio::click_analysis::analyze_click_track_default(&click_path, 0)
.expect("click analysis failed");
let data = std::fs::read(&midi_path).unwrap();
let smf = midly::Smf::parse(&data).unwrap();
let tpb = match smf.header.timing {
midly::Timing::Metrical(tpb) => tpb.as_int() as u64,
_ => panic!("SMPTE not supported"),
};
let (tempo_events, time_sig_events) = extract_events(&smf);
let found_offset = find_best_offset(&tempo_events, &time_sig_events, tpb, &grid);
let max_time = grid.beats.last().copied().unwrap_or(0.0);
// Score the detected offset and a zero offset with the same scoring function.
let score_found = {
let mb = midi_beat_times(&tempo_events, &time_sig_events, tpb, found_offset, max_time);
alignment_score(&mb, &grid.beats)
};
let score_zero = {
let mb = midi_beat_times(&tempo_events, &time_sig_events, tpb, 0.0, max_time);
alignment_score(&mb, &grid.beats)
};
eprintln!("Found offset: {found_offset:.4}s score={score_found:.6}");
eprintln!("Zero offset: 0.0000s score={score_zero:.6}");
eprintln!(
"Score delta (found - zero): {:.6}",
score_found - score_zero
);
// Click beats that fall before the detected offset (the lead-in).
let pre_beats: Vec<f64> = grid
.beats
.iter()
.copied()
.filter(|&t| t < found_offset)
.collect();
eprintln!(
"\nClick beats before found offset ({found_offset:.3}s): {} beats",
pre_beats.len()
);
// Only list individual beats when there are at most 30 of them.
if pre_beats.len() <= 30 {
for (i, &t) in pre_beats.iter().enumerate() {
let spacing = if i > 0 {
format!(" (+{:.3}s)", t - pre_beats[i - 1])
} else {
String::new()
};
eprintln!(" beat {i:>2}: {t:.4}s{spacing}");
}
}
let spacings_leadin: Vec<f64> = pre_beats.windows(2).map(|w| w[1] - w[0]).collect();
// First 20 click beats at or after the detected offset.
let post_beats: Vec<f64> = grid
.beats
.iter()
.copied()
.filter(|&t| t >= found_offset)
.take(20)
.collect();
let spacings_post: Vec<f64> = post_beats.windows(2).map(|w| w[1] - w[0]).collect();
if !spacings_leadin.is_empty() {
let mean_li = spacings_leadin.iter().sum::<f64>() / spacings_leadin.len() as f64;
eprintln!(
"\nLead-in beat spacing: mean={:.4}s ({:.1} BPM)",
mean_li,
60.0 / mean_li
);
}
if !spacings_post.is_empty() {
let mean_post = spacings_post.iter().sum::<f64>() / spacings_post.len() as f64;
eprintln!(
"Post-offset beat spacing (first 20): mean={:.4}s ({:.1} BPM)",
mean_post,
60.0 / mean_post
);
}
eprintln!("\nAlignment scores for candidates in first 10s:");
eprintln!(" {:>8} {:>12} note", "offset", "score");
// Candidates: 0.0 plus every click beat up to 10 s, with near-duplicates (< 1 ms apart) removed.
let mut candidates: Vec<f64> = vec![0.0];
candidates.extend(grid.beats.iter().copied().filter(|&t| t <= 10.0));
candidates.dedup_by(|a, b| (*a - *b).abs() < 0.001);
for c in &candidates {
let mb = midi_beat_times(&tempo_events, &time_sig_events, tpb, *c, max_time);
let score = alignment_score(&mb, &grid.beats);
let marker = if (*c - found_offset).abs() < 0.002 {
" ← chosen"
} else {
""
};
eprintln!(" {:>8.4}s {:>12.6}{marker}", c, score);
}
}
// Prints the MIDI-vs-click report for "Sigurd's Song"; the helper skips the song when its files are missing.
#[test]
fn sigurds_song_beat_grid_alignment() {
eval_midi_vs_click("Isenmor/Sigurd's Song", "Midi.mid", "Click.flac", 0);
}
// Prints the MIDI-vs-click report for "Operation Orcinianus Copia"; the helper skips the song when its files are missing.
#[test]
fn operation_orcinianus_copia_beat_grid_alignment() {
eval_midi_vs_click(
"Recently Vacated Graves/Operation Orcinianus Copia",
"automation.mid",
"Click.flac",
0,
);
}
/// Prints a table of `alignment_rms_ms` for a set of local songs, flagging
/// values above 15 ms. Songs whose files are missing are reported as skipped;
/// the test makes no assertions.
#[test]
fn alignment_rms_values() {
    let home = std::env::var("HOME").unwrap_or_default();
    let base = std::path::Path::new(&home).join("src/backing-tracks");
    let songs: &[(&str, &str, &str)] = &[
        ("Isenmor/Beornulf", "Beornulf.mid", "Click.flac"),
        ("Isenmor/Battle Scarred", "Battle Scarred.mid", "Click.flac"),
        ("Isenmor/Jotunheim", "Jotunheim.mid", "Click.flac"),
        ("Isenmor/Afar", "Afar.mid", "Click.flac"),
        ("Isenmor/The Pursuit of Vikings", "Midi.mid", "Click.flac"),
        ("Isenmor/Saxon Shore", "Saxon Shore.mid", "Click.flac"),
        ("Isenmor/Sigurd's Song", "Midi.mid", "Click.flac"),
        (
            "Recently Vacated Graves/Operation Orcinianus Copia",
            "automation.mid",
            "Click.flac",
        ),
    ];
    eprintln!("\n{:<50} {:>14} note", "Song", "alignment_rms_ms");
    eprintln!("{}", "-".repeat(72));
    for (dir, midi, click) in songs {
        let midi_path = base.join(dir).join(midi);
        let click_path = base.join(dir).join(click);
        if !(midi_path.exists() && click_path.exists()) {
            eprintln!("{:<50} (skipped — files not found)", dir);
            continue;
        }
        let grid = crate::audio::click_analysis::analyze_click_track_default(&click_path, 0);
        let Some(tempo) = extract_tempo_from_midi(&midi_path, grid.as_ref()) else {
            eprintln!("{:<50} (extraction failed)", dir);
            continue;
        };
        let (rms, flag): (String, &str) = match tempo.alignment_rms_ms {
            Some(v) => (
                format!("{v:>10.2}ms"),
                if v > 15.0 { " ⚠ poor" } else { "" },
            ),
            None => (" n/a".into(), ""),
        };
        eprintln!("{:<50} {rms}{flag}", dir);
    }
}
/// Runs the MIDI-vs-click evaluation report for a batch of local songs, in
/// listed order. Songs whose files are missing are skipped by the helper.
#[test]
fn batch_beat_grid_alignment() {
    const CLICK_CHANNEL: u16 = 0;
    let songs: &[(&str, &str, &str)] = &[
        ("Isenmor/Afar", "Afar.mid", "Click.flac"),
        ("Isenmor/Beornulf", "Beornulf.mid", "Click.flac"),
        ("Isenmor/Battle Scarred", "Battle Scarred.mid", "Click.flac"),
        (
            "Isenmor/Death is a Fine Companion",
            "Death is a Fine Companion.mid",
            "Click.flac",
        ),
        ("Isenmor/Jotunheim", "Jotunheim.mid", "Click.flac"),
        ("Isenmor/The Pursuit of Vikings", "Midi.mid", "Click.flac"),
        ("Isenmor/Throneless", "Throneless.mid", "Click.flac"),
        ("Isenmor/Wanderlust", "Wanderlust.mid", "Click.flac"),
        (
            "Recently Vacated Graves/Bored to Undeath",
            "automation.mid",
            "Click.flac",
        ),
        (
            "Recently Vacated Graves/Hurricane Zombie",
            "automation.mid",
            "Click.flac",
        ),
        (
            "Recently Vacated Graves/Send More Cops",
            "automation.mid",
            "Click.flac",
        ),
        (
            "Recently Vacated Graves/Zombie Ritual",
            "automation.mid",
            "Click.flac",
        ),
        (
            "Recently Vacated Graves/Devoured in Decay",
            "automation.mid",
            "Click.flac",
        ),
        (
            "Recently Vacated Graves/You Die",
            "automation.mid",
            "Click.flac",
        ),
    ];
    songs
        .iter()
        .for_each(|(song_dir, midi_file, click_file)| {
            eval_midi_vs_click(song_dir, midi_file, click_file, CLICK_CHANNEL)
        });
}
/// Checks the base tempo of a local "Sigurd's Song" MIDI file. The song-named
/// file is preferred and `Midi.mid` is the fallback; the test is skipped when
/// neither exists.
#[test]
fn sigurds_song_midi() {
    let home = std::env::var("HOME").unwrap_or_default();
    let candidates = [
        "src/backing-tracks/Isenmor/Sigurd's Song/Sigurd's Song.mid",
        "src/backing-tracks/Isenmor/Sigurd's Song/Midi.mid",
    ]
    .map(|relative| std::path::Path::new(&home).join(relative));
    let Some(midi_path) = candidates.iter().find(|path| path.exists()) else {
        eprintln!("Skipping: MIDI not found");
        return;
    };

    let tempo = extract_tempo_from_midi(midi_path, None).unwrap();
    eprintln!("Base: {} BPM, {:?}", tempo.bpm, tempo.time_signature);
    for change in &tempo.changes {
        eprintln!(
            " m{}/{} {} BPM ts={}/{}",
            change.measure,
            change.beat,
            change.bpm,
            change.time_signature[0],
            change.time_signature[1]
        );
    }
    assert_eq!(tempo.bpm, 120, "Base should be 120");
}
/// Expects a collapsed tempo ramp (a change with `transition_beats`) between
/// measures 24 and 26 of a local automation MIDI file; skipped when the file
/// is not present.
#[test]
fn operation_orcinianus_copia_midi() {
    let home = std::env::var("HOME").unwrap_or_default();
    let midi_path = std::path::Path::new(&home).join(
        "src/backing-tracks/Recently Vacated Graves/Operation Orcinianus Copia/automation.mid",
    );
    if !midi_path.exists() {
        eprintln!("Skipping: MIDI not found");
        return;
    }
    let tempo = extract_tempo_from_midi(&midi_path, None).unwrap();
    eprintln!("Base: {} BPM, {:?}", tempo.bpm, tempo.time_signature);
    for change in &tempo.changes {
        eprintln!(
            " m{}/{} {} BPM ts={}/{} transition={:?}",
            change.measure,
            change.beat,
            change.bpm,
            change.time_signature[0],
            change.time_signature[1],
            change.transition_beats
        );
    }
    let has_rit = tempo
        .changes
        .iter()
        .any(|change| (24..=26).contains(&change.measure) && change.transition_beats.is_some());
    assert!(
        has_rit,
        "Expected a rit transition near measure 25, got: {:?}",
        tempo.changes
    );
}
/// Tick 0 is always at 0.0 seconds.
#[test]
fn tick_to_seconds_zero() {
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    assert_eq!(tick_to_seconds(0, &tempo_map, 480), 0.0);
}
/// Four beats (1920 ticks) at 120 BPM last two seconds.
#[test]
fn tick_to_seconds_simple() {
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    let seconds = tick_to_seconds(1920, &tempo_map, 480);
    assert!((seconds - 2.0).abs() < 1e-9);
}
/// 960 ticks at 120 BPM (1.0 s) plus 240 ticks at 60 BPM (0.5 s).
#[test]
fn tick_to_seconds_with_tempo_change() {
    let fast = TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    };
    let slow = TempoEvent {
        tick: 960,
        micros_per_beat: bpm_to_micros(60),
    };
    let tempo_map = [fast, slow];
    let seconds = tick_to_seconds(1200, &tempo_map, 480);
    assert!((seconds - 1.5).abs() < 1e-9);
}
/// Seconds → tick → seconds must come back within 2 ms across a tempo change.
#[test]
fn tick_to_seconds_roundtrip() {
    let tempo_map = [
        TempoEvent {
            tick: 0,
            micros_per_beat: bpm_to_micros(120),
        },
        TempoEvent {
            tick: 960,
            micros_per_beat: bpm_to_micros(90),
        },
    ];
    let targets = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0];
    for target_secs in targets {
        let tick = seconds_to_tick(target_secs, &tempo_map, 480);
        let back = tick_to_seconds(tick, &tempo_map, 480);
        let error = (back - target_secs).abs();
        assert!(
            error < 0.002,
            "Roundtrip failed for {}s: got {}s (via tick {})",
            target_secs,
            back,
            tick,
        );
    }
}
/// Identical beat lists have zero squared error.
#[test]
fn alignment_score_perfect() {
    let beat_times = [0.0, 0.5, 1.0, 1.5, 2.0];
    let score = alignment_score(&beat_times, &beat_times);
    let deviation = (score - 0.0).abs();
    assert!(deviation < 1e-12, "Perfect alignment should score 0.0");
}
/// A grid shifted by 50 ms must score worse than an exactly matching grid.
#[test]
fn alignment_score_decreases_with_drift() {
    let on_beat = |shift: f64| -> Vec<f64> { (0..20).map(|i| i as f64 * 0.5 + shift).collect() };
    let midi_beats: Vec<f64> = (0..20).map(|i| i as f64 * 0.5).collect();
    let good_grid: Vec<f64> = (0..20).map(|i| i as f64 * 0.5).collect();
    let bad_grid = on_beat(0.05);
    let good_score = alignment_score(&midi_beats, &good_grid);
    let bad_score = alignment_score(&midi_beats, &bad_grid);
    assert!(good_score > bad_score, "Drifted grid should score worse");
}
/// A grid starting at 0.0 s should produce an offset of about zero.
#[test]
fn find_offset_no_leadin() {
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    let click_grid = make_beat_grid(120.0, 0.0, 32, 4);
    let offset = find_best_offset(&tempo_map, &[], 480, &click_grid);
    assert!(offset.abs() < 0.01, "Expected ~0.0, got {}", offset);
}
/// A grid starting at 2.0 s should produce an offset of about two seconds.
#[test]
fn find_offset_with_leadin() {
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    let click_grid = make_beat_grid(120.0, 2.0, 32, 4);
    let offset = find_best_offset(&tempo_map, &[], 480, &click_grid);
    assert!((offset - 2.0).abs() < 0.01, "Expected ~2.0, got {}", offset);
}
/// The lead-in runs at 100 BPM until tick 1600, after which the song is at 140 BPM;
/// the offset should still land on the grid start at 2.0 s.
#[test]
fn find_offset_with_tempo_change_in_leadin() {
    let leadin = TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(100),
    };
    let song = TempoEvent {
        tick: 1600,
        micros_per_beat: bpm_to_micros(140),
    };
    let tempo_map = [leadin, song];
    let click_grid = make_beat_grid(140.0, 2.0, 32, 4);
    let offset = find_best_offset(&tempo_map, &[], 480, &click_grid);
    assert!((offset - 2.0).abs() < 0.02, "Expected ~2.0, got {}", offset);
}
/// Simple meters step by the note value named in the denominator.
#[test]
fn beat_step_ticks_simple_meters() {
    let tpb = 480_u64;
    // (numerator, denominator exponent, expected step, failure message)
    let cases: [(u8, u8, u64, &str); 5] = [
        (2, 1, 2 * tpb, "2/2 should step by half notes"),
        (4, 2, tpb, "4/4 should step by quarter notes"),
        (3, 2, tpb, "3/4 should step by quarter notes"),
        (7, 3, tpb / 2, "7/8 should step by eighth notes"),
        (4, 4, tpb / 4, "4/16 should step by sixteenth notes"),
    ];
    for (numerator, denominator_power, expected, why) in cases {
        assert_eq!(
            beat_step_ticks(numerator, denominator_power, tpb),
            expected,
            "{why}"
        );
    }
}
/// Eighth-note meters with a numerator divisible by three step by dotted quarters.
#[test]
fn beat_step_ticks_compound_meters() {
    let tpb = 480_u64;
    let dotted_quarter = 3 * tpb / 2;
    // (numerator, failure message); every case uses a /8 denominator.
    let cases: [(u8, &str); 4] = [
        (6, "6/8 should step by dotted quarters"),
        (9, "9/8 should step by dotted quarters"),
        (12, "12/8 should step by dotted quarters"),
        (3, "3/8 treated as compound dotted-quarter"),
    ];
    for (numerator, why) in cases {
        assert_eq!(beat_step_ticks(numerator, 3, tpb), dotted_quarter, "{why}");
    }
}
/// With a grid that starts at 0.0 s many offsets tie on score; the one at the
/// first grid beat must win.
#[test]
fn find_offset_prefers_earlier_when_scores_are_tied() {
    let tpb = 480_u64;
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    let click_grid = make_beat_grid(120.0, 0.0, 64, 4);
    let offset = find_best_offset(&tempo_map, &[], tpb, &click_grid);
    assert!(
        offset.abs() < 0.01,
        "Expected offset ~0.0 (no lead-in), got {offset:.4}s — epsilon tie-breaking may have failed",
    );
}
/// A grid that starts one beat late (0.5 s) should move the offset to 0.5 s.
#[test]
fn find_offset_uses_later_offset_when_meaningfully_better() {
    let tpb = 480_u64;
    let tempo_map = [TempoEvent {
        tick: 0,
        micros_per_beat: bpm_to_micros(120),
    }];
    let click_grid = make_beat_grid(120.0, 0.5, 32, 4);
    let offset = find_best_offset(&tempo_map, &[], tpb, &click_grid);
    assert!(
        (offset - 0.5).abs() < 0.01,
        "Expected offset ~0.5s (1-beat lead-in), got {offset:.4}s",
    );
}
}