use crate::{
event::TrackEvent,
prelude::*,
primitive::{Format, Timing},
riff,
};
// Heuristic ratio used to estimate how many events a run of raw track bytes contains, so the
// event vector can be preallocated before parsing.
#[cfg(feature = "alloc")]
const BYTES_TO_EVENTS: f32 = 1.0 / 3.0;
// Heuristic ratio used to estimate how many bytes a number of events will serialize to. It sizes
// the track write buffer and feeds the parallel-write size check.
#[cfg(feature = "alloc")]
const EVENTS_TO_BYTES: f32 = 3.4;
// Size in bytes (3 KiB) that the unread input or the estimated output is compared against when
// deciding whether to take the rayon-based parallel code path.
#[cfg(feature = "parallel")]
const PARALLEL_ENABLE_THRESHOLD: usize = 3 * 1024;
/// A single track, stored as a vector of its events.
#[cfg(feature = "alloc")]
pub type Track<'a> = Vec<TrackEvent<'a>>;
/// A fully parsed file held in memory: one header plus every track collected into a vector.
///
/// The events carry the lifetime `'a` of the byte buffer they were parsed from.
#[cfg(feature = "alloc")]
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct Smf<'a> {
/// The file header (format and timing).
pub header: Header,
/// The tracks of the file, each one a vector of events.
pub tracks: Vec<Track<'a>>,
}
#[cfg(feature = "alloc")]
impl<'a> Smf<'a> {
    /// Build an `Smf` that has the given header and an empty track list.
    #[inline]
    pub fn new(header: Header) -> Smf<'a> {
        let tracks = Vec::new();
        Smf { header, tracks }
    }

    /// Parse a whole file from `raw`, collecting every track into a vector of events.
    ///
    /// After collecting, the header and track counts are passed through `validate_smf`.
    pub fn parse(raw: &[u8]) -> Result<Smf> {
        let (header, track_iter) = parse(raw)?;
        // The declared count must be read before the iterator is consumed below.
        let declared_tracks = track_iter.track_count_hint;
        let tracks = track_iter.collect_tracks()?;
        validate_smf(&header, declared_tracks, tracks.len())?;
        Ok(Smf { header, tracks })
    }

    /// Serialize the header and all tracks into `out` through the crate-level `write`.
    #[inline]
    pub fn write<W: Write>(&self, out: &mut W) -> WriteResult<W> {
        let Smf { header, tracks } = self;
        write(header, tracks, out)
    }

    /// Serialize the header and all tracks into a `std::io::Write` sink.
    #[cfg(feature = "std")]
    #[inline]
    pub fn write_std<W: io::Write>(&self, out: W) -> io::Result<()> {
        let Smf { header, tracks } = self;
        write_std(header, tracks, out)
    }

    /// Create (or truncate) the file at `path` and write this `Smf` into it.
    #[cfg(feature = "std")]
    #[inline]
    pub fn save<P: AsRef<Path>>(&self, path: P) -> io::Result<()> {
        // Non-generic inner function, so the body is not duplicated per `P`.
        fn save_impl(smf: &Smf, path: &Path) -> io::Result<()> {
            let file = File::create(path)?;
            let mut sink = IoWrap(file);
            smf.write(&mut sink)
        }
        save_impl(self, path.as_ref())
    }

    /// Clone this `Smf` and convert the clone into one with a `'static` lifetime.
    pub fn to_static(&self) -> Smf<'static> {
        let owned = self.clone();
        owned.make_static()
    }

    /// Convert this `Smf` into one with a `'static` lifetime, reusing its allocations.
    ///
    /// Every event is replaced in place by the result of its `to_static()` call.
    pub fn make_static(mut self) -> Smf<'static> {
        for event in self.tracks.iter_mut().flatten() {
            *event = event.to_static();
        }
        // SAFETY (NOTE(review): confirm against `TrackEvent::to_static`): every event was just
        // overwritten with a value produced by `to_static()`, so presumably nothing borrowed
        // for `'a` remains, and `Smf<'a>` and `Smf<'static>` differ only in that lifetime.
        unsafe { mem::transmute::<Smf<'a>, Smf<'static>>(self) }
    }
}
/// A track in which every event is paired with a byte slice.
// NOTE(review): the slice is presumably the raw source bytes the event was parsed from; confirm
// against `TrackEvent::read_bytemap`.
#[cfg(feature = "alloc")]
pub type BytemappedTrack<'a> = Vec<(&'a [u8], TrackEvent<'a>)>;
/// A fully parsed file held in memory whose tracks store `(bytes, event)` pairs instead of bare
/// events.
#[cfg(feature = "alloc")]
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct SmfBytemap<'a> {
/// The file header (format and timing).
pub header: Header,
/// The tracks of the file, each one a vector of `(bytes, event)` pairs.
pub tracks: Vec<BytemappedTrack<'a>>,
}
#[cfg(feature = "alloc")]
impl<'a> SmfBytemap<'a> {
    /// Build an `SmfBytemap` that has the given header and an empty track list.
    #[inline]
    pub fn new(header: Header) -> SmfBytemap<'a> {
        let tracks = Vec::new();
        SmfBytemap { header, tracks }
    }

    /// Parse a whole file from `raw`, collecting every track as `(bytes, event)` pairs.
    ///
    /// After collecting, the header and track counts are passed through `validate_smf`.
    pub fn parse(raw: &[u8]) -> Result<SmfBytemap> {
        let (header, track_iter) = parse(raw)?;
        // The declared count must be read before the iterator is consumed below.
        let declared_tracks = track_iter.track_count_hint;
        let tracks = track_iter.collect_bytemapped()?;
        validate_smf(&header, declared_tracks, tracks.len())?;
        Ok(SmfBytemap { header, tracks })
    }

    /// Serialize the header and the events of all tracks into `out`.
    ///
    /// Only the event half of each pair is written; the byte slices are ignored.
    #[inline]
    pub fn write<W: Write>(&self, out: &mut W) -> WriteResult<W> {
        let events_only = self
            .tracks
            .iter()
            .map(|track| track.iter().map(|(_bytes, event)| event));
        write(&self.header, events_only, out)
    }

    /// Serialize the header and the events of all tracks into a `std::io::Write` sink.
    ///
    /// Only the event half of each pair is written; the byte slices are ignored.
    #[cfg(feature = "std")]
    #[inline]
    pub fn write_std<W: io::Write>(&self, out: W) -> io::Result<()> {
        let events_only = self
            .tracks
            .iter()
            .map(|track| track.iter().map(|(_bytes, event)| event));
        write_std(&self.header, events_only, out)
    }

    /// Create (or truncate) the file at `path` and write this `SmfBytemap` into it.
    #[cfg(feature = "std")]
    #[inline]
    pub fn save<P: AsRef<Path>>(&self, path: P) -> io::Result<()> {
        // Non-generic inner function, so the body is not duplicated per `P`.
        fn save_impl(smf: &SmfBytemap, path: &Path) -> io::Result<()> {
            let file = File::create(path)?;
            let mut sink = IoWrap(file);
            smf.write(&mut sink)
        }
        save_impl(self, path.as_ref())
    }
}
/// Check the parsed file for consistency when the `strict` feature is enabled.
///
/// * `header` - the parsed file header.
/// * `track_count_hint` - the number of tracks declared by the header chunk.
/// * `track_count` - the number of tracks that were actually collected.
///
/// With `strict` enabled this fails if the declared and actual track counts differ, or if a
/// `Format::SingleTrack` file does not hold exactly one track (this includes holding none).
/// Without `strict` it always succeeds.
#[cfg(feature = "alloc")]
fn validate_smf(header: &Header, track_count_hint: u16, track_count: usize) -> Result<()> {
    if cfg!(feature = "strict") {
        ensure!(
            usize::from(track_count_hint) == track_count,
            err_malformed!("file has a different amount of tracks than declared")
        );
        // This check also rejects a single-track file with zero tracks, so the message states
        // the rule rather than claiming there are "multiple" tracks.
        ensure!(
            header.format != Format::SingleTrack || track_count == 1,
            err_malformed!("singletrack format file must contain exactly one track")
        );
    }
    Ok(())
}
pub fn parse(raw: &[u8]) -> Result<(Header, TrackIter)> {
let raw = match raw.get(..4) {
Some(b"RIFF") => riff::unwrap(raw)?,
Some(b"MThd") => raw,
_ => bail!(err_invalid!("not a midi file")),
};
let mut chunks = ChunkIter::new(raw);
let (header, track_count) = match chunks.next() {
Some(maybe_chunk) => match maybe_chunk.context(err_invalid!("invalid midi header"))? {
Chunk::Header(header, track_count) => Ok((header, track_count)),
Chunk::Track(_) => Err(err_invalid!("expected header, found track")),
},
None => Err(err_invalid!("no midi header chunk")),
}?;
let tracks = chunks.as_tracks(track_count);
Ok((header, tracks))
}
/// Serialize a header and a sequence of tracks into `out`.
///
/// `tracks` is an iterable of tracks, and each track is an iterable of event references. The
/// header chunk is written first, using the track iterator's `len()` as the track count; the
/// track chunks follow in iteration order.
///
/// How the tracks are encoded depends on the enabled features:
/// - `parallel`: if the estimated output size exceeds `PARALLEL_ENABLE_THRESHOLD`, every track
///   is encoded into its own buffer through rayon and the buffers are then written in order.
/// - `alloc`: every track is encoded into one reused buffer and written.
/// - neither: tracks are written through `Chunk::write_seek` if `out` can be made seekable,
///   otherwise through `Chunk::write_probe`.
pub fn write<'a, T, E, W>(header: &Header, tracks: T, out: &mut W) -> WriteResult<W>
where
T: IntoIterator<Item = E>,
T::IntoIter: ExactSizeIterator + Clone + Send,
E: IntoIterator<Item = &'a TrackEvent<'a>>,
E::IntoIter: Clone + Send,
W: Write,
{
let tracks = tracks.into_iter().map(|events| events.into_iter());
Chunk::write_header(header, tracks.len(), out)?;
#[cfg(feature = "parallel")]
{
// Estimate the total event count from the lower bound of each track's size hint.
let event_count = tracks
.clone()
.map(|track| track.into_iter().size_hint().0)
.sum::<usize>();
// Only go parallel when the estimated byte size is above the threshold.
if (event_count as f32 * EVENTS_TO_BYTES) > PARALLEL_ENABLE_THRESHOLD as f32 {
use rayon::prelude::*;
let mut track_chunks = Vec::new();
// Encode every track into its own buffer; results keep the original track order.
tracks
.collect::<Vec<_>>()
.into_par_iter()
.map(|track| {
let mut track_chunk = Vec::new();
Chunk::write_to_vec(track, &mut track_chunk)?;
Ok(track_chunk)
})
.collect_into_vec(&mut track_chunks);
// Write the encoded chunks sequentially, stopping at the first encoding or I/O error.
for result in track_chunks {
let track_chunk = result.map_err(W::invalid_input)?;
out.write(&track_chunk)?;
}
return Ok(());
}
}
#[cfg(feature = "alloc")]
{
// Sequential path: `write_to_vec` clears `buf`, so one buffer is reused for all tracks.
let mut buf = Vec::new();
for track in tracks {
Chunk::write_to_vec(track, &mut buf).map_err(|msg| W::invalid_input(msg))?;
out.write(&buf)?;
}
return Ok(());
}
// Fallback without `alloc`; unreachable (and therefore silenced) when `alloc` is enabled.
#[allow(unreachable_code)]
{
// Prefer patching the chunk length in place when the writer can seek.
if let Some(out) = out.make_seekable() {
for track in tracks {
Chunk::write_seek(track, out)?;
}
return Ok(());
}
// Otherwise each track is encoded twice: once to measure its length, once to write it.
for track in tracks {
Chunk::write_probe(track, out)?;
}
Ok(())
}
}
/// Serialize a header and a sequence of tracks into a `std::io::Write` sink.
///
/// Wraps `out` in `IoWrap` and forwards to the crate-level `write`.
#[cfg(feature = "std")]
#[inline]
pub fn write_std<'a, T, E, W>(header: &Header, tracks: T, out: W) -> io::Result<()>
where
    T: IntoIterator<Item = E>,
    T::IntoIter: ExactSizeIterator + Clone + Send,
    E: IntoIterator<Item = &'a TrackEvent<'a>>,
    E::IntoIter: Clone + Send,
    W: io::Write,
{
    let mut sink = IoWrap(out);
    write(header, tracks, &mut sink)
}
/// Iterator over the chunks found in a byte buffer.
#[derive(Clone, Debug)]
struct ChunkIter<'a> {
/// The bytes that have not been consumed yet.
raw: &'a [u8],
}
impl<'a> ChunkIter<'a> {
    /// Start iterating over the chunks stored in `raw`.
    #[inline]
    fn new(raw: &'a [u8]) -> ChunkIter {
        Self { raw }
    }

    /// Turn this chunk iterator into a track iterator, recording the declared track count.
    #[inline]
    fn as_tracks(self, track_count_hint: u16) -> TrackIter<'a> {
        let chunks = self;
        TrackIter {
            chunks,
            track_count_hint,
        }
    }
}
impl<'a> Iterator for ChunkIter<'a> {
    type Item = Result<Chunk<'a>>;

    /// Read the next chunk from the remaining bytes.
    ///
    /// A read error is yielded once; the remaining input is then discarded so that iteration
    /// ends afterwards.
    #[inline]
    fn next(&mut self) -> Option<Result<Chunk<'a>>> {
        let outcome = Chunk::read(&mut self.raw);
        if outcome.is_err() {
            // Drop whatever is left so the error is not followed by more chunks.
            self.raw = &[];
        }
        // `Ok(None)` becomes `None`; `Ok(Some(c))` and `Err(e)` become `Some(..)`.
        outcome.transpose()
    }
}
/// A single chunk read from a file.
#[derive(Copy, Clone, Debug)]
enum Chunk<'a> {
/// An `MThd` chunk: the parsed header plus the track count it declares.
Header(Header, u16),
/// An `MTrk` chunk: the still-unparsed bytes of the chunk body.
Track(&'a [u8]),
}
impl<'a> Chunk<'a> {
    /// Read the next header or track chunk from `raw`, advancing `raw` past it.
    ///
    /// Chunks whose id is neither `MThd` nor `MTrk` are skipped. Returns `Ok(None)` once `raw`
    /// is empty. If a chunk declares more bytes than remain, the `strict` feature turns that
    /// into an error; otherwise the chunk is given all the remaining bytes.
    fn read(raw: &mut &'a [u8]) -> Result<Option<Chunk<'a>>> {
        while !raw.is_empty() {
            let id = raw
                .split_checked(4)
                .ok_or(err_invalid!("failed to read chunkid"))?;
            let len = u32::read(raw).context(err_invalid!("failed to read chunklen"))?;
            let chunkdata = match raw.split_checked(len as usize) {
                Some(data) => data,
                None if cfg!(feature = "strict") => {
                    bail!(err_malformed!("reached eof before chunk ended"))
                }
                // Lenient mode: hand over whatever is left and leave `raw` empty.
                None => mem::replace(raw, &[]),
            };
            match id {
                b"MThd" => {
                    let (header, track_count) = Header::read(chunkdata)?;
                    return Ok(Some(Chunk::Header(header, track_count)));
                }
                b"MTrk" => return Ok(Some(Chunk::Track(chunkdata))),
                // Unknown chunk id: skip it and try the next chunk.
                _ => {}
            }
        }
        Ok(None)
    }

    /// Write the 14-byte `MThd` chunk: id, big-endian length, then the encoded header.
    ///
    /// Fails with an invalid-input error if `track_count` does not fit in 16 bits.
    fn write_header<W: Write>(header: &Header, track_count: usize, out: &mut W) -> WriteResult<W> {
        let track_count = u16::try_from(track_count)
            .map_err(|_| W::invalid_input("track count exceeds 16 bit range"))?;
        let body = header.encode(track_count);
        let mut header_chunk = [0u8; 4 + 4 + 6];
        header_chunk[..4].copy_from_slice(b"MThd");
        header_chunk[4..8].copy_from_slice(&(body.len() as u32).to_be_bytes());
        header_chunk[8..].copy_from_slice(&body[..]);
        out.write(&header_chunk)?;
        Ok(())
    }

    /// Write an `MTrk` chunk by encoding the track twice.
    ///
    /// The first pass goes into a `WriteCounter` to learn the body length; the second pass
    /// writes the real bytes after the chunk head.
    fn write_probe<W: Write>(
        track: impl Iterator<Item = &'a TrackEvent<'a>> + Clone,
        out: &mut W,
    ) -> WriteResult<W> {
        let mut counter = WriteCounter(0);
        Self::write_raw(track.clone(), &mut counter).map_err(W::invalid_input)?;
        let len = Self::check_len::<W, _>(counter.0)?;
        let mut head = *b"MTrk\0\0\0\0";
        head[4..].copy_from_slice(&len);
        out.write(&head)?;
        Self::write_raw(track, out)?;
        Ok(())
    }

    /// Write an `MTrk` chunk to a seekable writer.
    ///
    /// A placeholder length is written first and patched once the body size is known.
    fn write_seek<W: Write + Seek>(
        track: impl Iterator<Item = &'a TrackEvent<'a>>,
        out: &mut W,
    ) -> WriteResult<W> {
        out.write(b"MTrk\0\0\0\0")?;
        let start = out.tell()?;
        Self::write_raw(track, out)?;
        let written = out.tell()? - start;
        let len = Self::check_len::<W, _>(written)?;
        // The length field sits in the four bytes right before the body.
        out.write_at(&len, start - 4)?;
        Ok(())
    }

    /// Encode an `MTrk` chunk into `out`, replacing any previous contents of the vector.
    ///
    /// Capacity is reserved up front from the lower bound of the track's size hint.
    #[cfg(feature = "alloc")]
    fn write_to_vec(
        track: impl Iterator<Item = &'a TrackEvent<'a>>,
        out: &mut Vec<u8>,
    ) -> WriteResult<Vec<u8>> {
        let estimated_body = (track.size_hint().0 as f32 * EVENTS_TO_BYTES) as usize;
        out.clear();
        out.reserve(8 + estimated_body);
        out.extend_from_slice(b"MTrk\0\0\0\0");
        Self::write_raw(track, out)?;
        let body_len = out.len() - 8;
        let len = Self::check_len::<Vec<u8>, _>(body_len)?;
        out[4..8].copy_from_slice(&len);
        Ok(())
    }

    /// Write the events of a track back to back, without any chunk head.
    ///
    /// One running-status slot, starting at `None`, is shared by all events of the track.
    fn write_raw<W: Write>(
        track: impl Iterator<Item = &'a TrackEvent<'a>>,
        out: &mut W,
    ) -> WriteResult<W> {
        let mut status = None;
        for event in track {
            event.write(&mut status, out)?;
        }
        Ok(())
    }

    /// Convert a chunk body length into its 4-byte big-endian form.
    ///
    /// Fails with an invalid-input error if `len` does not fit in a `u32`.
    fn check_len<W, T>(len: T) -> StdResult<[u8; 4], W::Error>
    where
        u32: TryFrom<T>,
        W: Write,
    {
        let checked = u32::try_from(len)
            .map_err(|_| W::invalid_input("midi chunk size exceeds 32 bit range"))?;
        Ok(checked.to_be_bytes())
    }
}
/// The header of a file: its format and its timing information.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct Header {
/// The format of the file (for example `Format::SingleTrack`).
pub format: Format,
/// The timing information stored in the header.
pub timing: Timing,
}
impl Header {
    /// Build a header from a format and a timing.
    #[inline]
    pub fn new(format: Format, timing: Timing) -> Header {
        Self { format, timing }
    }

    /// Decode a header chunk body: format, declared track count, then timing.
    ///
    /// Returns the header together with the declared track count.
    fn read(mut raw: &[u8]) -> Result<(Header, u16)> {
        let format = Format::read(&mut raw)?;
        let declared_tracks = u16::read(&mut raw)?;
        let timing = Timing::read(&mut raw)?;
        let header = Header { format, timing };
        Ok((header, declared_tracks))
    }

    /// Encode the 6-byte header chunk body: format, big-endian track count, then timing.
    fn encode(&self, track_count: u16) -> [u8; 6] {
        let mut bytes = [0; 6];
        {
            let (format_field, rest) = bytes.split_at_mut(2);
            let (count_field, timing_field) = rest.split_at_mut(2);
            format_field.copy_from_slice(&self.format.encode()[..]);
            count_field.copy_from_slice(&track_count.to_be_bytes()[..]);
            timing_field.copy_from_slice(&self.timing.encode()[..]);
        }
        bytes
    }
}
/// Iterator over the tracks of a file, yielding one event iterator per track chunk.
#[derive(Clone, Debug)]
pub struct TrackIter<'a> {
/// The underlying chunk iterator.
chunks: ChunkIter<'a>,
/// Number of tracks still expected; reported by `size_hint` and decremented on every chunk read.
track_count_hint: u16,
}
impl<'a> TrackIter<'a> {
    /// Create a track iterator over raw chunk bytes, with a track count hint of zero.
    #[inline]
    pub fn new(raw: &[u8]) -> TrackIter {
        ChunkIter::new(raw).as_tracks(0)
    }

    /// The bytes that have not been consumed yet.
    #[inline]
    pub fn unread(&self) -> &'a [u8] {
        let remaining = &self.chunks;
        remaining.raw
    }

    /// Collect every remaining track into a vector of events.
    #[cfg(feature = "alloc")]
    pub fn collect_tracks(self) -> Result<Vec<Track<'a>>> {
        self.generic_collect(|events| events.into_vec())
    }

    /// Collect every remaining track into a vector of `(bytes, event)` pairs.
    #[cfg(feature = "alloc")]
    pub fn collect_bytemapped(self) -> Result<Vec<BytemappedTrack<'a>>> {
        self.generic_collect(|events| EventBytemapIter::into_vec(events.bytemapped()))
    }

    /// Collect every remaining track by running `collect` on its event iterator.
    ///
    /// With the `parallel` feature, and at least `PARALLEL_ENABLE_THRESHOLD` unread bytes, the
    /// track chunks are gathered first and then collected through rayon. Otherwise tracks are
    /// collected one after another. The first error aborts the collection.
    #[cfg(feature = "alloc")]
    #[inline]
    fn generic_collect<T: Send + 'a>(
        self,
        collect: impl Fn(EventIter<'a>) -> Result<Vec<T>> + Send + Sync,
    ) -> Result<Vec<Vec<T>>> {
        #[cfg(feature = "parallel")]
        {
            let large_enough = self.unread().len() >= PARALLEL_ENABLE_THRESHOLD;
            if large_enough {
                use rayon::prelude::*;
                // Splitting into chunks is cheap; event parsing is what runs in parallel.
                let track_chunks = self.collect::<Result<Vec<_>>>()?;
                return track_chunks
                    .into_par_iter()
                    .map(collect)
                    .collect::<Result<Vec<Vec<T>>>>();
            }
        }
        self.map(|track| match track {
            Ok(events) => collect(events),
            Err(err) => Err(err),
        })
        .collect::<Result<Vec<Vec<T>>>>()
    }
}
impl<'a> Iterator for TrackIter<'a> {
    type Item = Result<EventIter<'a>>;

    /// Report the remaining track count hint as both the lower and the upper bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let expected = self.track_count_hint as usize;
        (expected, Some(expected))
    }

    /// Yield an event iterator for the next track chunk.
    ///
    /// Duplicate header chunks and unreadable chunks are reported as errors when the `strict`
    /// feature is enabled, and silently skipped otherwise. The track count hint is decremented
    /// (saturating at zero) for every chunk read, including skipped ones.
    #[inline]
    fn next(&mut self) -> Option<Result<EventIter<'a>>> {
        for chunk in self.chunks.by_ref() {
            self.track_count_hint = self.track_count_hint.saturating_sub(1);
            match chunk {
                Ok(Chunk::Track(track)) => return Some(Ok(EventIter::new(track))),
                Ok(Chunk::Header(..)) if cfg!(feature = "strict") => {
                    return Some(Err(err_malformed!("found duplicate header").into()));
                }
                Err(err) if cfg!(feature = "strict") => {
                    return Some(Err(err).context(err_malformed!("invalid chunk")));
                }
                // Lenient mode: ignore the chunk and move on to the next one.
                Ok(Chunk::Header(..)) | Err(_) => {}
            }
        }
        None
    }
}
/// Selects what an `EventIterGeneric` yields and how a single item is read.
trait EventKind<'a> {
/// The item type produced for each event.
type Event: 'a;
/// Read one event from the front of `raw`, advancing `raw` and updating `running_status`.
fn read_ev(raw: &mut &'a [u8], running_status: &mut Option<u8>) -> Result<Self::Event>;
}
/// Shared implementation behind the event iterators, parameterized by an `EventKind`.
#[derive(Clone, Debug)]
struct EventIterGeneric<'a, T> {
/// The bytes of the track that have not been consumed yet.
raw: &'a [u8],
/// The running status carried from one event read to the next.
running_status: Option<u8>,
/// Marker for the event kind; no value of `T` is stored.
_kind: PhantomData<T>,
}
impl<'a, T: EventKind<'a>> EventIterGeneric<'a, T> {
    /// Start reading events from `raw` with no running status.
    #[inline]
    fn new(raw: &[u8]) -> EventIterGeneric<T> {
        EventIterGeneric {
            _kind: PhantomData,
            running_status: None,
            raw,
        }
    }

    /// The bytes that have not been consumed yet.
    #[inline]
    fn unread(&self) -> &'a [u8] {
        let remaining = self.raw;
        remaining
    }

    /// The current running status.
    #[inline]
    fn running_status(&self) -> Option<u8> {
        let status = self.running_status;
        status
    }

    /// Mutable access to the current running status.
    #[inline]
    fn running_status_mut(&mut self) -> &mut Option<u8> {
        let status = &mut self.running_status;
        status
    }

    /// Estimate the number of events left from the number of unread bytes.
    #[cfg(feature = "alloc")]
    #[inline]
    fn estimate_events(&self) -> usize {
        let unread_bytes = self.raw.len() as f32;
        (unread_bytes * BYTES_TO_EVENTS) as usize
    }

    /// Read every remaining event into a vector.
    ///
    /// On the first read failure, the `strict` feature turns it into a "malformed event"
    /// error; otherwise reading stops and the events read so far are returned.
    #[cfg(feature = "alloc")]
    #[inline]
    fn into_vec(mut self) -> Result<Vec<T::Event>> {
        let mut events = Vec::with_capacity(self.estimate_events());
        while !self.raw.is_empty() {
            let event = match T::read_ev(&mut self.raw, &mut self.running_status) {
                Ok(event) => event,
                Err(err) => {
                    self.raw = &[];
                    if cfg!(feature = "strict") {
                        return Err(err).context(err_malformed!("malformed event"));
                    }
                    // Lenient mode: keep what was read and stop silently.
                    break;
                }
            };
            events.push(event);
        }
        Ok(events)
    }
}
impl<'a, T: EventKind<'a>> Iterator for EventIterGeneric<'a, T> {
    type Item = Result<T::Event>;

    /// Read the next event, or return `None` once no bytes are left.
    ///
    /// A read failure discards the remaining bytes. It is yielded as a "malformed event" error
    /// when the `strict` feature is enabled, and ends the iteration silently otherwise.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.raw.is_empty() {
            return None;
        }
        let outcome = T::read_ev(&mut self.raw, &mut self.running_status);
        if outcome.is_err() {
            // Nothing after a failed read can be trusted, so drop the rest of the track.
            self.raw = &[];
        }
        match outcome {
            Ok(event) => Some(Ok(event)),
            Err(err) if cfg!(feature = "strict") => {
                Some(Err(err).context(err_malformed!("malformed event")))
            }
            Err(_) => None,
        }
    }
}
/// Iterator over the events of a single track, yielding `TrackEvent`s.
#[derive(Clone, Debug)]
pub struct EventIter<'a> {
/// The generic iterator that holds the unread bytes and the running status.
inner: EventIterGeneric<'a, Self>,
}
impl<'a> EventKind<'a> for EventIter<'a> {
    type Event = TrackEvent<'a>;

    /// Read one plain event by delegating to `TrackEvent::read`.
    #[inline]
    fn read_ev(raw: &mut &'a [u8], running_status: &mut Option<u8>) -> Result<Self::Event> {
        TrackEvent::read(raw, running_status)
    }
}
impl<'a> EventIter<'a> {
    /// Create an event iterator over the raw bytes of a track, with no running status.
    #[inline]
    pub fn new(raw: &[u8]) -> EventIter {
        let inner = EventIterGeneric::new(raw);
        EventIter { inner }
    }

    /// The bytes that have not been consumed yet.
    #[inline]
    pub fn unread(&self) -> &'a [u8] {
        let inner = &self.inner;
        inner.unread()
    }

    /// The current running status.
    #[inline]
    pub fn running_status(&self) -> Option<u8> {
        let inner = &self.inner;
        inner.running_status()
    }

    /// Mutable access to the current running status.
    #[inline]
    pub fn running_status_mut(&mut self) -> &mut Option<u8> {
        let inner = &mut self.inner;
        inner.running_status_mut()
    }

    /// Continue from the current position, but yield `(bytes, event)` pairs instead.
    ///
    /// The unread bytes and the running status are carried over unchanged.
    #[inline]
    pub fn bytemapped(self) -> EventBytemapIter<'a> {
        let EventIterGeneric {
            raw,
            running_status,
            ..
        } = self.inner;
        let inner = EventIterGeneric {
            raw,
            running_status,
            _kind: PhantomData,
        };
        EventBytemapIter { inner }
    }

    /// Read every remaining event into a vector.
    #[cfg(feature = "alloc")]
    #[inline]
    pub fn into_vec(self) -> Result<Track<'a>> {
        let EventIter { inner } = self;
        inner.into_vec()
    }
}
impl<'a> Iterator for EventIter<'a> {
    type Item = Result<TrackEvent<'a>>;

    /// Yield the next event by delegating to the inner generic iterator.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.inner)
    }
}
/// Iterator over the events of a single track, yielding `(bytes, event)` pairs.
#[derive(Clone, Debug)]
pub struct EventBytemapIter<'a> {
/// The generic iterator that holds the unread bytes and the running status.
inner: EventIterGeneric<'a, Self>,
}
impl<'a> EventKind<'a> for EventBytemapIter<'a> {
    type Event = (&'a [u8], TrackEvent<'a>);

    /// Read one `(bytes, event)` pair by delegating to `TrackEvent::read_bytemap`.
    #[inline]
    fn read_ev(raw: &mut &'a [u8], running_status: &mut Option<u8>) -> Result<Self::Event> {
        TrackEvent::read_bytemap(raw, running_status)
    }
}
impl<'a> EventBytemapIter<'a> {
    /// Create a bytemapped event iterator over the raw bytes of a track, with no running status.
    #[inline]
    pub fn new(raw: &[u8]) -> EventBytemapIter {
        let inner = EventIterGeneric::new(raw);
        EventBytemapIter { inner }
    }

    /// The bytes that have not been consumed yet.
    #[inline]
    pub fn unread(&self) -> &'a [u8] {
        let inner = &self.inner;
        inner.unread()
    }

    /// The current running status.
    #[inline]
    pub fn running_status(&self) -> Option<u8> {
        let inner = &self.inner;
        inner.running_status()
    }

    /// Mutable access to the current running status.
    #[inline]
    pub fn running_status_mut(&mut self) -> &mut Option<u8> {
        let inner = &mut self.inner;
        inner.running_status_mut()
    }

    /// Continue from the current position, but yield plain events instead of pairs.
    ///
    /// The unread bytes and the running status are carried over unchanged.
    #[inline]
    pub fn not_bytemapped(self) -> EventIter<'a> {
        let EventIterGeneric {
            raw,
            running_status,
            ..
        } = self.inner;
        let inner = EventIterGeneric {
            raw,
            running_status,
            _kind: PhantomData,
        };
        EventIter { inner }
    }

    /// Read every remaining `(bytes, event)` pair into a vector.
    #[cfg(feature = "alloc")]
    #[inline]
    pub fn into_vec(self) -> Result<Vec<(&'a [u8], TrackEvent<'a>)>> {
        let EventBytemapIter { inner } = self;
        inner.into_vec()
    }
}
impl<'a> Iterator for EventBytemapIter<'a> {
    type Item = Result<(&'a [u8], TrackEvent<'a>)>;

    /// Yield the next `(bytes, event)` pair by delegating to the inner generic iterator.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.inner)
    }
}