use crate::{AudexError, FileType, Result, StreamInfo, Tags};
use std::fs::File;
use std::io::{BufReader, Read, Seek};
use std::path::Path;
use std::time::Duration;
#[cfg(feature = "async")]
use crate::util::loadfile_read_async;
#[cfg(feature = "async")]
use std::io::SeekFrom;
#[cfg(feature = "async")]
use tokio::fs::File as TokioFile;
#[cfg(feature = "async")]
use tokio::io::{AsyncReadExt, AsyncSeekExt};
/// Kind of event retained while scanning a track.
///
/// The derived ordering follows declaration order, so `Tempo` compares less
/// than `Midi`; the track sort key `(deltasum, event_type)` therefore places a
/// tempo change before a channel event that occurs on the same tick.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum EventType {
/// A tempo change, produced for meta events of type `0x51`.
Tempo,
/// A channel event (status byte below `0xF0`, or running status).
Midi,
}
/// A timing-relevant event extracted from a track chunk.
#[derive(Debug, Clone)]
struct MidiEvent {
/// Saturating sum of all delta times from the start of the track up to and
/// including this event, in ticks.
deltasum: u64,
/// Whether this is a tempo change or a channel event.
event_type: EventType,
/// For `Tempo` events, the 24-bit big-endian value read from the meta event
/// (used as microseconds per quarter note when computing the duration);
/// for `Midi` events, the delta time that preceded the event.
data: u32,
}
/// Decodes a MIDI variable-length quantity (VLQ) from `data`, starting at `offset`.
///
/// Each byte contributes its low 7 bits, most significant group first; a byte
/// with the high bit clear terminates the value.
///
/// Returns the decoded value together with the offset of the first byte after
/// the quantity.
///
/// # Errors
///
/// Returns `AudexError::ParseError` when `data` ends before the quantity is
/// terminated, or when the quantity is longer than four bytes.
fn var_int(data: &[u8], mut offset: usize) -> Result<(u32, usize)> {
    const MAX_VLQ_BYTES: usize = 4;

    let mut value: u32 = 0;
    for _ in 0..MAX_VLQ_BYTES {
        let byte = match data.get(offset) {
            Some(&b) => b,
            None => {
                return Err(AudexError::ParseError(
                    "Not enough data for VLQ".to_string(),
                ))
            }
        };
        offset += 1;
        value = (value << 7) | u32::from(byte & 0x7F);
        if byte & 0x80 == 0 {
            return Ok((value, offset));
        }
    }

    // Four continuation bytes were consumed. Running out of input takes
    // precedence over the length violation when reporting the failure.
    if offset >= data.len() {
        Err(AudexError::ParseError(
            "Not enough data for VLQ".to_string(),
        ))
    } else {
        Err(AudexError::ParseError(
            "VLQ exceeds 4-byte MIDI limit".to_string(),
        ))
    }
}
/// Scans the body of one track chunk and collects the events needed to
/// compute its duration.
///
/// Every event's delta time is added (saturating) to a running tick total.
/// Tempo meta events (`0xFF 0x51`) and channel events are recorded; other meta
/// events, SysEx and system messages are skipped after their length has been
/// validated. A trailing delta time with no event after it ends the scan
/// without an error.
///
/// # Errors
///
/// Returns `AudexError::ParseError` when an event is truncated, a tempo
/// payload is not exactly three bytes, running status is used before any
/// status byte was seen, or more than `MAX_TRACK_EVENTS` events are recorded.
fn read_track(chunk: &[u8]) -> Result<Vec<MidiEvent>> {
    const MAX_TRACK_EVENTS: usize = 1_000_000;

    let end = chunk.len();
    let mut collected: Vec<MidiEvent> = Vec::new();
    let mut elapsed_ticks: u64 = 0;
    let mut running_status: u8 = 0;
    let mut pos: usize = 0;

    while pos < end {
        if collected.len() >= MAX_TRACK_EVENTS {
            return Err(AudexError::ParseError(format!(
                "Track exceeds maximum event count ({})",
                MAX_TRACK_EVENTS
            )));
        }

        let (delta, after_delta) = var_int(chunk, pos)?;
        pos = after_delta;
        elapsed_ticks = elapsed_ticks.saturating_add(u64::from(delta));

        // A delta time with nothing after it simply ends the track.
        let lead = match chunk.get(pos) {
            Some(&byte) => byte,
            None => break,
        };
        pos += 1;

        match lead {
            // Meta event: type byte, VLQ length, payload.
            0xFF => {
                let meta_kind = match chunk.get(pos) {
                    Some(&byte) => byte,
                    None => {
                        return Err(AudexError::ParseError("Truncated meta event".to_string()))
                    }
                };
                pos += 1;

                let (raw_len, after_len) = var_int(chunk, pos)?;
                pos = after_len;
                let payload_len = raw_len as usize;

                if meta_kind == 0x51 {
                    let payload = match chunk.get(pos..pos + payload_len) {
                        Some(bytes) => bytes,
                        None => {
                            return Err(AudexError::ParseError(
                                "Truncated tempo data".to_string(),
                            ))
                        }
                    };
                    let tempo = match payload {
                        &[hi, mid, lo] => {
                            (u32::from(hi) << 16) | (u32::from(mid) << 8) | u32::from(lo)
                        }
                        _ => {
                            return Err(AudexError::ParseError(
                                "Invalid tempo data length".to_string(),
                            ))
                        }
                    };
                    collected.push(MidiEvent {
                        deltasum: elapsed_ticks,
                        event_type: EventType::Tempo,
                        data: tempo,
                    });
                }

                if pos + payload_len > end {
                    return Err(AudexError::ParseError(format!(
                        "Meta event length {} exceeds remaining track data ({})",
                        payload_len,
                        end - pos
                    )));
                }
                pos += payload_len;
            }

            // SysEx (and escape) event: VLQ length followed by that many bytes.
            0xF0 | 0xF7 => {
                let (raw_len, after_len) = var_int(chunk, pos)?;
                pos = after_len;
                let skip = raw_len as usize;
                if pos + skip > end {
                    return Err(AudexError::ParseError(format!(
                        "SysEx event length {} exceeds remaining track data ({})",
                        skip,
                        end - pos
                    )));
                }
                pos += skip;
            }

            // System Common messages carrying one data byte.
            0xF1 | 0xF3 => {
                if pos + 1 > end {
                    return Err(AudexError::ParseError(format!(
                        "Truncated System Common message 0x{:02X} at offset {}",
                        lead, pos
                    )));
                }
                pos += 1;
            }

            // Song Position Pointer: two data bytes.
            0xF2 => {
                if pos + 2 > end {
                    return Err(AudexError::ParseError(format!(
                        "Truncated Song Position Pointer at offset {}",
                        pos
                    )));
                }
                pos += 2;
            }

            // Remaining system messages carry no data bytes and are ignored.
            0xF4..=0xF6 | 0xF8..=0xFE => {}

            // Channel event, either with an explicit status byte or using
            // running status (the byte just read is then the first data byte).
            _ => {
                let is_running = lead < 0x80;
                let status = match lead {
                    0x00..=0x7F => {
                        if running_status == 0 {
                            return Err(AudexError::ParseError(
                                "MIDI running status used before any valid status byte"
                                    .to_string(),
                            ));
                        }
                        running_status
                    }
                    0x80..=0xEF => {
                        running_status = lead;
                        lead
                    }
                    _ => return Err(AudexError::ParseError("Invalid event type".to_string())),
                };

                // Program change (0xC) and channel pressure (0xD) take one
                // data byte; every other channel message takes two.
                let data_len: usize = if matches!(status >> 4, 0xC | 0xD) {
                    1
                } else {
                    2
                };
                pos += if is_running { data_len - 1 } else { data_len };

                if pos > end {
                    return Err(AudexError::ParseError(
                        "MIDI event data exceeds track bounds".to_string(),
                    ));
                }
                collected.push(MidiEvent {
                    deltasum: elapsed_ticks,
                    event_type: EventType::Midi,
                    data: delta,
                });
            }
        }
    }

    Ok(collected)
}
/// Reads one chunk from `reader`: a 4-byte identifier, a 4-byte big-endian
/// length, and then that many bytes of payload.
///
/// Returns the identifier and the payload.
///
/// # Errors
///
/// Returns `AudexError::ParseError` when the header or payload cannot be read
/// in full, or when the declared length exceeds `MAX_CHUNK_SIZE` (64 MiB).
/// The size check happens before the payload buffer is allocated.
fn read_chunk<R: Read>(reader: &mut R) -> Result<([u8; 4], Vec<u8>)> {
    const MAX_CHUNK_SIZE: usize = 64 * 1024 * 1024;

    let mut header = [0u8; 8];
    if reader.read_exact(&mut header).is_err() {
        return Err(AudexError::ParseError("Truncated chunk header".to_string()));
    }

    let mut identifier = [0u8; 4];
    identifier.copy_from_slice(&header[..4]);
    let mut size_bytes = [0u8; 4];
    size_bytes.copy_from_slice(&header[4..]);
    let payload_len = u32::from_be_bytes(size_bytes) as usize;

    if payload_len > MAX_CHUNK_SIZE {
        return Err(AudexError::ParseError(format!(
            "Chunk size too large: {} bytes (max {})",
            payload_len, MAX_CHUNK_SIZE
        )));
    }

    let mut payload = vec![0u8; payload_len];
    match reader.read_exact(&mut payload) {
        Ok(()) => Ok((identifier, payload)),
        Err(_) => Err(AudexError::ParseError("Truncated chunk data".to_string())),
    }
}
/// Computes the playing time, in seconds, of the Standard MIDI File read from
/// `reader`.
///
/// The header chunk must be `MThd` with a 6-byte body holding the format,
/// the number of chunks to read, and the tick division. Only formats 0 and 1
/// with a non-zero, non-SMPTE tick division are accepted. Chunks whose
/// identifier is not `MTrk` are skipped (but still count toward the
/// cumulative size limit).
///
/// For format 1, every track's own tempo events are replaced by the tempo
/// events of the first `MTrk` chunk. Each track's duration runs from tick 0
/// to its last channel event, starting at a tempo of 500000 microseconds per
/// quarter note; the longest track duration is returned.
///
/// # Errors
///
/// Returns `AudexError::ParseError` for an invalid or unsupported header, a
/// chunk or track that fails to parse, cumulative chunk data above
/// `MAX_CUMULATIVE_TRACK_BYTES` (128 MiB), or a file without any `MTrk` chunk.
fn read_midi_length<R: Read>(reader: &mut R) -> Result<f64> {
    const MAX_CUMULATIVE_TRACK_BYTES: usize = 128 * 1024 * 1024;
    const DEFAULT_TEMPO: u32 = 500000;

    fn is_tempo(event: &MidiEvent) -> bool {
        event.event_type == EventType::Tempo
    }

    // --- Header -----------------------------------------------------------
    let (magic, header) = read_chunk(reader)?;
    if magic != *b"MThd" {
        return Err(AudexError::ParseError("Not a MIDI file".to_string()));
    }
    if header.len() != 6 {
        return Err(AudexError::ParseError("Invalid MIDI header".to_string()));
    }

    let be16 = |at: usize| u16::from_be_bytes([header[at], header[at + 1]]);
    let (smf_format, ntracks, tickdiv) = (be16(0), be16(2), be16(4));

    if smf_format > 1 {
        return Err(AudexError::ParseError(format!(
            "Unsupported MIDI format {}",
            smf_format
        )));
    }
    // A set top bit selects SMPTE-based timing.
    if tickdiv & 0x8000 != 0 {
        return Err(AudexError::ParseError(
            "SMPTE timing not supported".to_string(),
        ));
    }
    if tickdiv == 0 {
        return Err(AudexError::ParseError(
            "Invalid tick division: 0".to_string(),
        ));
    }

    // --- Tracks -----------------------------------------------------------
    let mut bytes_seen: usize = 0;
    let mut shared_tempos: Option<Vec<MidiEvent>> = None;
    let mut tracks: Vec<Vec<MidiEvent>> = Vec::new();

    for _ in 0..ntracks {
        let (tag, body) = read_chunk(reader)?;

        bytes_seen = bytes_seen.saturating_add(body.len());
        if bytes_seen > MAX_CUMULATIVE_TRACK_BYTES {
            return Err(AudexError::ParseError(format!(
                "Cumulative track data ({} bytes) exceeds {} byte limit",
                bytes_seen, MAX_CUMULATIVE_TRACK_BYTES
            )));
        }
        if tag != *b"MTrk" {
            continue;
        }

        let mut events = read_track(&body)?;

        // The first parsed track fixes the tempo list shared by format 1 files.
        let reference = shared_tempos
            .get_or_insert_with(|| events.iter().filter(|e| is_tempo(e)).cloned().collect());
        if smf_format == 1 {
            events.retain(|e| !is_tempo(e));
            events.extend(reference.iter().cloned());
        }

        // Stable sort: events sharing a tick keep their relative order, with
        // tempo changes ahead of channel events.
        events.sort_by(|a, b| {
            a.deltasum
                .cmp(&b.deltasum)
                .then(a.event_type.cmp(&b.event_type))
        });
        tracks.push(events);
    }

    // --- Durations --------------------------------------------------------
    let ticks_per_quarter = f64::from(tickdiv);
    tracks
        .into_iter()
        .map(|events| {
            let mut tempo = DEFAULT_TEMPO;
            let mut segment_start: u64 = 0;
            let mut last_midi_tick: u64 = 0;
            let mut micros: f64 = 0.0;

            for event in &events {
                match event.event_type {
                    EventType::Tempo => {
                        // Close the segment that ran at the previous tempo.
                        let span = event.deltasum.saturating_sub(segment_start);
                        micros += (span as f64 / ticks_per_quarter) * f64::from(tempo);
                        tempo = event.data;
                        segment_start = event.deltasum;
                    }
                    EventType::Midi => last_midi_tick = event.deltasum,
                }
            }

            // Final segment: from the last tempo change to the last channel event.
            let tail = last_midi_tick.saturating_sub(segment_start);
            micros += (tail as f64 / ticks_per_quarter) * f64::from(tempo);
            micros / 1_000_000.0
        })
        .max_by(|a, b| a.total_cmp(b))
        .ok_or_else(|| AudexError::ParseError("No valid tracks found".to_string()))
}
/// Stream information for a Standard MIDI File.
///
/// The default value has no known length.
#[derive(Debug, Clone, Default)]
pub struct SMFInfo {
    /// Playing time of the file, or `None` when it is unknown.
    length: Option<Duration>,
}

impl SMFInfo {
    /// Parses MIDI data from `reader` and records its playing time.
    ///
    /// The length is stored only when the computed number of seconds is
    /// finite and non-negative; otherwise it is left as `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `read_midi_length`.
    fn from_reader<R: Read>(reader: &mut R) -> Result<Self> {
        let length_secs = read_midi_length(reader)?;
        // `Duration::from_secs_f64` panics on NaN, infinite or negative input,
        // so those values are filtered out first.
        let length = (length_secs.is_finite() && length_secs >= 0.0)
            .then(|| Duration::from_secs_f64(length_secs));
        Ok(SMFInfo { length })
    }
}
/// Stream properties of a MIDI file: only the length is available; every
/// audio-specific property reports `None`.
impl StreamInfo for SMFInfo {
    /// Playing time, if known.
    fn length(&self) -> Option<Duration> {
        self.length
    }

    /// Always `None`.
    fn bitrate(&self) -> Option<u32> {
        None
    }

    /// Always `None`.
    fn sample_rate(&self) -> Option<u32> {
        None
    }

    /// Always `None`.
    fn channels(&self) -> Option<u16> {
        None
    }

    /// Always `None`.
    fn bits_per_sample(&self) -> Option<u16> {
        None
    }

    /// One-line summary: the length in seconds with two decimals, or a note
    /// that the length is unknown.
    fn pprint(&self) -> String {
        self.length.map_or_else(
            || "SMF, unknown length".to_string(),
            |length| format!("SMF, {:.2} seconds", length.as_secs_f64()),
        )
    }
}
/// Displays the same text as `pprint`.
impl std::fmt::Display for SMFInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let summary = self.pprint();
        f.write_str(&summary)
    }
}
/// Empty tag container for MIDI files.
///
/// Every lookup reports nothing and every mutation is ignored.
#[derive(Debug, Clone, Default)]
pub struct SMFTags;

impl Tags for SMFTags {
    /// Always `None`.
    fn get(&self, _key: &str) -> Option<&[String]> {
        None
    }

    /// Does nothing.
    fn set(&mut self, _key: &str, _values: Vec<String>) {}

    /// Does nothing.
    fn remove(&mut self, _key: &str) {}

    /// Always an empty list.
    fn keys(&self) -> Vec<String> {
        vec![]
    }

    /// Always an empty string.
    fn pprint(&self) -> String {
        String::default()
    }
}
/// A Standard MIDI File together with its parsed stream information.
#[derive(Debug)]
pub struct SMF {
/// Stream information produced by the parser, or the default value for an
/// instance created with `new`.
info: SMFInfo,
/// Lossy string form of the path the file was loaded from; `None` for an
/// instance created with `new` or loaded from a reader.
_path: Option<String>,
}
impl SMF {
    /// Creates an empty instance with default stream information and no path.
    pub fn new() -> Self {
        Self {
            info: SMFInfo::default(),
            _path: None,
        }
    }

    /// Parses MIDI data from `reader`.
    ///
    /// # Errors
    ///
    /// Returns `AudexError::ParseError` whose message is the underlying
    /// parser error prefixed with "Failed to parse MIDI file: ".
    pub fn load_from_reader<R: Read + Seek>(reader: &mut R) -> Result<Self> {
        let info = SMFInfo::from_reader(reader)
            .map_err(|e| AudexError::ParseError(format!("Failed to parse MIDI file: {}", e)))?;
        Ok(SMF { info, _path: None })
    }

    /// Asynchronously opens the file at `path` and parses its stream information.
    ///
    /// # Errors
    ///
    /// Propagates errors from opening or reading the file, and returns
    /// `AudexError::ParseError` (with the same "Failed to parse MIDI file: "
    /// prefix used by the synchronous loaders) when the data is not valid MIDI.
    #[cfg(feature = "async")]
    pub async fn load_async<P: AsRef<Path>>(path: P) -> Result<Self> {
        let mut file = loadfile_read_async(&path).await?;
        let mut smf = SMF::new();
        smf._path = Some(path.as_ref().to_string_lossy().to_string());
        smf.info = Self::parse_info_async(&mut file).await?;
        Ok(smf)
    }

    /// Reads up to `MAX_MIDI_READ` bytes from the start of `file` into memory
    /// and runs the synchronous parser over that buffer.
    ///
    /// NOTE(review): only the first 10 MiB are read, while the synchronous
    /// path accepts larger files; a bigger file will surface here as a
    /// truncation parse error — confirm this cap is intentional.
    #[cfg(feature = "async")]
    async fn parse_info_async(file: &mut TokioFile) -> Result<SMFInfo> {
        const MAX_MIDI_READ: u64 = 10 * 1024 * 1024;
        let file_size = file.seek(SeekFrom::End(0)).await?;
        // Bounded by MAX_MIDI_READ, so the cast to usize cannot truncate.
        let read_size = std::cmp::min(file_size, MAX_MIDI_READ) as usize;
        file.seek(SeekFrom::Start(0)).await?;
        let mut data = vec![0u8; read_size];
        file.read_exact(&mut data).await?;
        let mut cursor = std::io::Cursor::new(&data[..]);
        // Wrap parser failures exactly like the synchronous entry points do,
        // so callers see one error format regardless of how the file was loaded.
        SMFInfo::from_reader(&mut cursor)
            .map_err(|e| AudexError::ParseError(format!("Failed to parse MIDI file: {}", e)))
    }
}
impl Default for SMF {
fn default() -> Self {
Self::new()
}
}
/// File-type integration for Standard MIDI Files: loading and format
/// detection are supported; anything tag-related is reported as unsupported.
impl FileType for SMF {
    type Tags = SMFTags;
    type Info = SMFInfo;

    /// Short identifier of this format.
    fn format_id() -> &'static str {
        "SMF"
    }

    /// Opens the file at `path` and parses its stream information.
    ///
    /// # Errors
    ///
    /// Propagates the error from opening the file; parser failures are
    /// returned as `AudexError::ParseError` prefixed with
    /// "Failed to parse MIDI file: ".
    fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        debug_event!("parsing SMF stream info");
        let path = path.as_ref();
        let mut reader = BufReader::new(File::open(path)?);
        match SMFInfo::from_reader(&mut reader) {
            Ok(info) => Ok(SMF {
                info,
                _path: Some(path.to_string_lossy().into_owned()),
            }),
            Err(e) => Err(AudexError::ParseError(format!(
                "Failed to parse MIDI file: {}",
                e
            ))),
        }
    }

    /// Parses stream information from an already opened reader; no path is recorded.
    ///
    /// # Errors
    ///
    /// Parser failures are returned as `AudexError::ParseError` prefixed with
    /// "Failed to parse MIDI file: ".
    fn load_from_reader(reader: &mut dyn crate::ReadSeek) -> Result<Self> {
        debug_event!("parsing SMF stream info from reader");
        // `from_reader` needs a sized reader type, so hand it a mutable
        // borrow of the `&mut dyn` reference itself.
        let mut source = reader;
        match SMFInfo::from_reader(&mut source) {
            Ok(info) => Ok(SMF { info, _path: None }),
            Err(e) => Err(AudexError::ParseError(format!(
                "Failed to parse MIDI file: {}",
                e
            ))),
        }
    }

    /// Always fails with `AudexError::Unsupported`.
    fn save(&mut self) -> Result<()> {
        Err(AudexError::Unsupported(String::from(
            "MIDI files don't support tags",
        )))
    }

    /// Always fails with `AudexError::Unsupported`.
    fn clear(&mut self) -> Result<()> {
        Err(AudexError::Unsupported(String::from(
            "MIDI files don't support tags",
        )))
    }

    /// Always fails with `AudexError::Unsupported`.
    fn add_tags(&mut self) -> Result<()> {
        Err(AudexError::Unsupported(String::from(
            "MIDI files don't support tags",
        )))
    }

    /// Always `None`.
    fn tags(&self) -> Option<&Self::Tags> {
        None
    }

    /// Always `None`.
    fn tags_mut(&mut self) -> Option<&mut Self::Tags> {
        None
    }

    /// The parsed stream information.
    fn info(&self) -> &Self::Info {
        &self.info
    }

    /// Rates how likely the given file is a MIDI file.
    ///
    /// Returns 100 when both the `MThd` magic and a known extension
    /// (`.mid`, `.midi`, `.kar`, case-insensitive) match, 50 for the magic
    /// alone, 10 for the extension alone, and 0 otherwise.
    fn score(filename: &str, header: &[u8]) -> i32 {
        let lowered = filename.to_lowercase();
        let extension_matches = [".mid", ".midi", ".kar"]
            .iter()
            .any(|ext| lowered.ends_with(*ext));
        let magic_matches = header.starts_with(b"MThd");

        match (magic_matches, extension_matches) {
            (true, true) => 100,
            (true, false) => 50,
            (false, true) => 10,
            (false, false) => 0,
        }
    }

    /// MIME types associated with this format.
    fn mime_types() -> &'static [&'static str] {
        &["audio/midi", "audio/x-midi"]
    }
}