augmented_midi/
parser.rs

1// Augmented Audio: Audio libraries and applications
2// Copyright (c) 2022 Pedro Tacla Yamada
3//
4// The MIT License (MIT)
5//
6// Permission is hereby granted, free of charge, to any person obtaining a copy
7// of this software and associated documentation files (the "Software"), to deal
8// in the Software without restriction, including without limitation the rights
9// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10// copies of the Software, and to permit persons to whom the Software is
11// furnished to do so, subject to the following conditions:
12//
13// The above copyright notice and this permission notice shall be included in
14// all copies or substantial portions of the Software.
15//
16// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22// THE SOFTWARE.
23use std::borrow::Borrow;
24
25use nom::bytes::complete::tag;
26use nom::{
27    branch::alt,
28    bytes::complete::take,
29    bytes::complete::take_till,
30    error::FromExternalError,
31    error::{Error, ErrorKind},
32    multi::many0,
33    number::complete::*,
34    Err, IResult,
35};
36
37pub use crate::types::*;
38
// Bit-masks and status-byte values used to classify MIDI messages.
//
// For channel messages the first 4 bits of the status byte indicate the message type and the
// last 4 bits indicate the MIDI channel. `STATUS_BYTE_MASK` extracts the message-type section
// so that it can be matched against the per-message masks below.

/// Bit-mask that extracts the message-type nibble (upper 4 bits) of a status byte.
pub const STATUS_BYTE_MASK: u8 = 0b1111_0000;

// Message-type nibbles of the channel messages; the 2nd 4 bits indicate the MIDI channel.

/// Message-type nibble of a control change message.
pub const CONTROL_CHANGE_MASK: u8 = 0b1011_0000;
/// Message-type nibble of a note off message.
pub const NOTE_OFF_MASK: u8 = 0b1000_0000;
/// Message-type nibble of a note on message.
pub const NOTE_ON_MASK: u8 = 0b1001_0000;
/// Message-type nibble of a polyphonic key pressure message.
pub const POLYPHONIC_KEY_PRESSURE_MASK: u8 = 0b1010_0000;
/// Message-type nibble of a program change message.
pub const PROGRAM_CHANGE_MASK: u8 = 0b1100_0000;
/// Message-type nibble of a channel pressure message.
pub const CHANNEL_PRESSURE_MASK: u8 = 0b1101_0000;
/// Message-type nibble of a pitch wheel change message.
pub const PITCH_WHEEL_CHANGE_MASK: u8 = 0b1110_0000;

// All these messages start with 0b1111; the 2nd 4 bits are part of the status, so these
// values are compared against the whole status byte.

/// Status byte of a song position pointer message.
pub const SONG_POSITION_POINTER_MASK: u8 = 0b1111_0010;
/// Status byte of a song select message.
pub const SONG_SELECT_MASK: u8 = 0b1111_0011;
/// Status byte of a timing clock message.
pub const TIMING_CLOCK_MASK: u8 = 0b1111_1000;
/// Status byte of a start message.
pub const START_MASK: u8 = 0b1111_1010;
/// Status byte of a continue message.
pub const CONTINUE_MASK: u8 = 0b1111_1011;
/// Status byte of a stop message.
pub const STOP_MASK: u8 = 0b1111_1100;
/// Status byte of an active sensing message.
pub const ACTIVE_SENSING_MASK: u8 = 0b1111_1110;
/// Status byte of a reset message.
pub const RESET_MASK: u8 = 0b1111_1111;
/// Status byte of a tune request message.
pub const TUNE_REQUEST_MASK: u8 = 0b1111_0110;

/// Status byte that starts a system exclusive (SysEx) message.
pub const SYSEX_MESSAGE_MASK: u8 = 0b1111_0000;
/// Byte that terminates a system exclusive (SysEx) message.
pub const SYSEX_MESSAGE_END_MASK: u8 = 0b1111_0111;
68
/// Result type returned by the parsers in this module: a nom [`IResult`] over the raw
/// [`Input`] that yields `Output` (plus the remaining input) on success.
pub type MIDIParseResult<'a, Output> = IResult<Input<'a>, Output>;
70
71/// Parses 3 16bit words. In order:
72///
73/// * File format
74/// * Number of tracks
75/// * Division
76pub fn parse_header_body(input: Input) -> MIDIParseResult<MIDIFileHeader> {
77    let (input, format) = parse_file_format(input)?;
78    let (input, num_tracks) = be_u16(input)?;
79    let (input, division_word) = be_u16(input)?;
80
81    let division_type = division_word >> 15;
82    let (input, division) = match division_type {
83        0 => {
84            let ticks_per_quarter_note = (division_word << 1) >> 1;
85            Ok((
86                input,
87                MIDIFileDivision::TicksPerQuarterNote {
88                    ticks_per_quarter_note,
89                },
90            ))
91        }
92        1 => {
93            let format = ((division_word << 1) >> 9) as u8;
94            let ticks_per_frame = ((division_word << 8) >> 8) as u8;
95            Ok((
96                input,
97                MIDIFileDivision::SMPTE {
98                    format,
99                    ticks_per_frame,
100                },
101            ))
102        }
103        _ => Err(Err::Error(Error::new(input, ErrorKind::Fail))),
104    }?;
105
106    Ok((
107        input,
108        MIDIFileHeader {
109            format,
110            num_tracks,
111            division,
112        },
113    ))
114}
115
116fn parse_file_format(input: Input) -> MIDIParseResult<MIDIFileFormat> {
117    let (input, format) = be_u16(input)?;
118    let format = match format {
119        0 => Ok(MIDIFileFormat::Single),
120        1 => Ok(MIDIFileFormat::Simultaneous),
121        2 => Ok(MIDIFileFormat::Sequential),
122        _ => Ok(MIDIFileFormat::Unknown),
123    }?;
124    Ok((input, format))
125}
126
127// https://en.wikipedia.org/wiki/Variable-length_quantity
128pub fn parse_variable_length_num(input: Input) -> MIDIParseResult<u32> {
129    use nom::bytes::complete::*;
130
131    let mut result: u32 = 0;
132
133    let (input, parts) = take_till(|b| b & 0b10000000 == 0)(input)?;
134    let (input, extra_part) = take(1u8)(input)?;
135
136    let mut i = parts.len() + 1;
137    for part in parts.iter().chain(extra_part.iter()) {
138        i -= 1;
139        let part = (part << 1) >> 1;
140        let part32 = part as u32;
141        result += part32 << (i * 7);
142    }
143
144    Ok((input, result))
145}
146
147pub fn parse_midi_event<'a, Buffer: Borrow<[u8]> + From<Input<'a>>>(
148    input: Input<'a>,
149    state: &mut ParserState,
150) -> MIDIParseResult<'a, MIDIMessage<Buffer>> {
151    let (tmp_input, tmp_status) = be_u8(input)?;
152    // Handle rolling status, this is look-ahead into the status byte and check
153    // if it's valid, otherwise try using the previous status.
154    let (input, status) = if tmp_status >= 0x7F {
155        state.last_status = Some(tmp_status);
156        Ok((tmp_input, tmp_status))
157    } else if let Some(status) = state.last_status {
158        Ok((input, status))
159    } else {
160        Err(Err::Error(Error::new(input, ErrorKind::Fail)))
161    }?;
162
163    let status_start = status & STATUS_BYTE_MASK;
164    let (input, message) = if status_start == NOTE_OFF_MASK {
165        let channel = parse_channel(status);
166        let (input, note) = be_u8(input)?;
167        let (input, velocity) = be_u8(input)?;
168        Ok((
169            input,
170            MIDIMessage::NoteOff(MIDIMessageNote {
171                channel,
172                note,
173                velocity,
174            }),
175        ))
176    } else if status_start == NOTE_ON_MASK {
177        let channel = parse_channel(status);
178        let (input, note) = be_u8(input)?;
179        let (input, velocity) = be_u8(input)?;
180        let note = MIDIMessageNote {
181            channel,
182            note,
183            velocity,
184        };
185        Ok((input, MIDIMessage::NoteOn(note)))
186    } else if status_start == POLYPHONIC_KEY_PRESSURE_MASK {
187        let channel = parse_channel(status);
188        let (input, note) = be_u8(input)?;
189        let (input, pressure) = be_u8(input)?;
190        Ok((
191            input,
192            MIDIMessage::PolyphonicKeyPressure {
193                channel,
194                note,
195                pressure,
196            },
197        ))
198    } else if status_start == CONTROL_CHANGE_MASK {
199        // Could potentially detect channel mode change here, but message is the same, the
200        // applications can handle this.
201        let channel = parse_channel(status);
202        let (input, controller_number) = be_u8(input)?;
203        let (input, value) = be_u8(input)?;
204        Ok((
205            input,
206            MIDIMessage::ControlChange {
207                channel,
208                controller_number,
209                value,
210            },
211        ))
212    } else if status_start == PROGRAM_CHANGE_MASK {
213        let channel = parse_channel(status);
214        let (input, program_number) = be_u8(input)?;
215        Ok((
216            input,
217            MIDIMessage::ProgramChange {
218                channel,
219                program_number,
220            },
221        ))
222    } else if status_start == CHANNEL_PRESSURE_MASK {
223        let channel = parse_channel(status);
224        let (input, pressure) = be_u8(input)?;
225        Ok((input, MIDIMessage::ChannelPressure { channel, pressure }))
226    } else if status_start == PITCH_WHEEL_CHANGE_MASK {
227        let channel = parse_channel(status);
228        let (input, value) = parse_14bit_midi_number(input)?;
229        Ok((input, MIDIMessage::PitchWheelChange { channel, value }))
230    } else if status == SYSEX_MESSAGE_MASK {
231        let (input, sysex_message) = take_till(|b| b == SYSEX_MESSAGE_END_MASK)(input)?;
232        let (input, _extra) = take(1u8)(input)?;
233        // assert!(extra.is_empty() && extra[0] == SYSEX_MESSAGE_END_MASK);
234        let sysex_message = MIDISysExEvent {
235            message: sysex_message.into(),
236        };
237        Ok((input, MIDIMessage::SysExMessage(sysex_message)))
238    } else if status == SONG_POSITION_POINTER_MASK {
239        let (input, value) = parse_14bit_midi_number(input)?;
240        Ok((input, MIDIMessage::SongPositionPointer { beats: value }))
241    } else if status == SONG_SELECT_MASK {
242        let (input, song) = be_u8(input)?;
243        Ok((input, MIDIMessage::SongSelect { song }))
244    } else if status == TIMING_CLOCK_MASK {
245        Ok((input, MIDIMessage::TimingClock))
246    } else if status == START_MASK {
247        Ok((input, MIDIMessage::Start))
248    } else if status == CONTINUE_MASK {
249        Ok((input, MIDIMessage::Continue))
250    } else if status == STOP_MASK {
251        Ok((input, MIDIMessage::Stop))
252    } else if status == ACTIVE_SENSING_MASK {
253        Ok((input, MIDIMessage::ActiveSensing))
254    } else if status == RESET_MASK {
255        Ok((input, MIDIMessage::Reset))
256    } else if status == TUNE_REQUEST_MASK {
257        Ok((input, MIDIMessage::TuneRequest))
258    } else {
259        Ok((input, MIDIMessage::Other { status }))
260    }?;
261
262    Ok((input, message))
263}
264
/// Returns the MIDI channel stored in the lower 4 bits of a status byte.
fn parse_channel(status: u8) -> u8 {
    const CHANNEL_BITS: u8 = 0x0F;
    status & CHANNEL_BITS
}
268
269/// Input is a 14-bit number
270/// 0b0lllllll - 1st 7 bits are the least significant bits
271/// 0b0mmmmmmm - 2nd 7 bits are the most significant bits
272pub(crate) fn parse_14bit_midi_number(input: Input) -> MIDIParseResult<u16> {
273    let (input, value1) = be_u8(input)?;
274    let (input, value2) = be_u8(input)?;
275    let value1 = (value1 & !0b1000_0000) as u16;
276    let value2 = ((value2 & !0b1000_0000) as u16) << 7;
277    let value = value1 + value2;
278    Ok((input, value))
279}
280
281pub fn parse_meta_event<'a, Buffer: Borrow<[u8]> + From<Input<'a>>>(
282    input: Input<'a>,
283) -> MIDIParseResult<'a, MIDIMetaEvent<Buffer>> {
284    let (input, _) = tag([0xFF])(input)?;
285    let (input, meta_type) = be_u8(input)?;
286    let (input, length) = parse_variable_length_num(input)?;
287    let (input, bytes) = take(length)(input)?;
288
289    Ok((
290        input,
291        MIDIMetaEvent {
292            meta_type,
293            length,
294            bytes: bytes.into(),
295        },
296    ))
297}
298
299pub fn parse_track_event<'a, Buffer: Borrow<[u8]> + From<Input<'a>>>(
300    input: Input<'a>,
301    state: &mut ParserState,
302) -> MIDIParseResult<'a, MIDITrackEvent<Buffer>> {
303    let (input, delta_time) = parse_variable_length_num(input)?;
304    let (input, event) = alt((
305        |input| parse_meta_event(input).map(|(input, event)| (input, MIDITrackInner::Meta(event))),
306        |input| {
307            parse_midi_event(input, state)
308                .map(|(input, event)| (input, MIDITrackInner::Message(event)))
309        },
310    ))(input)?;
311
312    match event {
313        MIDITrackInner::Meta(_) => {
314            state.last_status = None;
315        }
316        MIDITrackInner::Message(MIDIMessage::SysExMessage(_)) => {
317            state.last_status = None;
318        }
319        _ => {}
320    }
321
322    Ok((
323        input,
324        MIDITrackEvent {
325            delta_time,
326            inner: event,
327        },
328    ))
329}
330
/// State carried between consecutive calls to `parse_midi_event` / `parse_track_event`.
///
/// Create it with `ParserState::default()`, which starts with no remembered status byte.
#[derive(Debug, Default)]
pub struct ParserState {
    /// Status byte remembered from a previous message; reused when the next message starts
    /// with a data byte instead of a status byte (running status). `None` when there is
    /// nothing to reuse.
    last_status: Option<u8>,
}
335
336pub fn parse_chunk<
337    'a,
338    StringRepr: Borrow<str> + From<&'a str>,
339    Buffer: Borrow<[u8]> + From<Input<'a>>,
340>(
341    input: Input<'a>,
342) -> MIDIParseResult<'a, MIDIFileChunk<StringRepr, Buffer>> {
343    let (input, chunk_name) = take(4u32)(input)?;
344    let chunk_name: &str = std::str::from_utf8(chunk_name)
345        .map_err(|err| Err::Failure(Error::from_external_error(input, ErrorKind::Fail, err)))?;
346
347    let (input, chunk_length) = parse_chunk_length(input)?;
348    let (input, chunk_body) = take(chunk_length)(input)?;
349
350    let (_, chunk) = match chunk_name {
351        "MThd" => {
352            // assert_eq!(chunk_length, 6);
353            parse_header_body(chunk_body)
354                .map(|(rest, header)| (rest, MIDIFileChunk::Header(header)))
355        }
356        "MTrk" => {
357            let mut state = ParserState::default();
358            let mut parse = |input| parse_track_event(input, &mut state);
359            let mut events = Vec::with_capacity((chunk_length / 3) as usize);
360            let mut chunk_body = chunk_body;
361            loop {
362                let (new_chunk_body, event) = parse(chunk_body)?;
363                events.push(event);
364                chunk_body = new_chunk_body;
365
366                if chunk_body.is_empty() {
367                    break;
368                }
369            }
370            Ok((chunk_body, MIDIFileChunk::Track { events }))
371        }
372        _ => Ok((
373            chunk_body,
374            MIDIFileChunk::Unknown {
375                name: chunk_name.into(),
376                body: chunk_body.into(),
377            },
378        )),
379    }?;
380
381    Ok((input, chunk))
382}
383
384fn parse_chunk_length(input: Input) -> MIDIParseResult<u32> {
385    u32(nom::number::Endianness::Big)(input)
386}
387
388pub fn parse_midi_file<
389    'a,
390    StringRepr: Borrow<str> + From<&'a str>,
391    Buffer: Borrow<[u8]> + From<&'a [u8]>,
392>(
393    input: Input<'a>,
394) -> MIDIParseResult<'a, MIDIFile<StringRepr, Buffer>> {
395    let mut chunks = Vec::with_capacity(input.len() / 10);
396    let mut input = input;
397    loop {
398        let (new_input, chunk) = parse_chunk(input)?;
399        chunks.push(chunk);
400        input = new_input;
401
402        if input.is_empty() {
403            break;
404        }
405    }
406    Ok((input, MIDIFile { chunks }))
407}
408
409pub fn parse_midi<'a, Buffer: Borrow<[u8]> + From<Input<'a>>>(
410    input: Input<'a>,
411) -> MIDIParseResult<'a, Vec<MIDIMessage<Buffer>>> {
412    many0(|input| parse_midi_event(input, &mut ParserState::default()))(input)
413}
414
#[cfg(test)]
mod test {
    use super::*;

    /// A format word of `0` is parsed as `MIDIFileFormat::Single`.
    #[test]
    fn test_parse_file_format_single() {
        let format = [0, 0];
        let (_rest, result) = parse_file_format(&format).unwrap();
        assert_eq!(result, MIDIFileFormat::Single);
    }

    /// A format word of `1` is parsed as `MIDIFileFormat::Simultaneous`.
    #[test]
    fn test_parse_file_format_simultaneous() {
        let format = [0, 1];
        let (_rest, result) = parse_file_format(&format).unwrap();
        assert_eq!(result, MIDIFileFormat::Simultaneous);
    }

    /// A format word of `2` is parsed as `MIDIFileFormat::Sequential`.
    #[test]
    fn test_parse_file_format_sequential() {
        let format = [0, 2];
        let (_rest, result) = parse_file_format(&format).unwrap();
        assert_eq!(result, MIDIFileFormat::Sequential);
    }

    /// Any other format word is accepted and parsed as `MIDIFileFormat::Unknown`.
    #[test]
    fn test_parse_file_format_unknown() {
        let format = [0, 8];
        let (_rest, result) = parse_file_format(&format).unwrap();
        assert_eq!(result, MIDIFileFormat::Unknown);
    }

    /// A division word with the top bit clear is parsed as ticks per quarter note.
    #[test]
    fn test_parse_header_body_tick_based() {
        let input = [
            // Single
            0,
            0,
            // 1 track
            0,
            1,
            // 2 ticks
            0b0000_0000,
            2,
        ];
        let (_rest, result) = parse_header_body(&input).unwrap();
        assert_eq!(
            result,
            MIDIFileHeader {
                format: MIDIFileFormat::Single,
                num_tracks: 1,
                division: MIDIFileDivision::TicksPerQuarterNote {
                    ticks_per_quarter_note: 2
                }
            }
        );
    }

    /// A division word with the top bit set is parsed as an SMPTE format / ticks-per-frame
    /// pair.
    #[test]
    fn test_parse_header_body_smpte_time_based() {
        let input = [
            // Single
            0b0_u8,
            0,
            // 3 tracks
            0,
            3,
            // SMPTE, format is 1 (which isn't valid, but for simplicity)
            // ticks is 2
            0b1000_0001,
            2,
        ];
        let (_rest, result) = parse_header_body(&input).unwrap();
        assert_eq!(
            result,
            MIDIFileHeader {
                format: MIDIFileFormat::Single,
                num_tracks: 3,
                division: MIDIFileDivision::SMPTE {
                    format: 1,
                    ticks_per_frame: 2
                }
            }
        );
    }

    /// A single byte without the continuation bit is its own value.
    #[test]
    fn test_parse_variable_length_quantity_length_1() {
        assert_eq!(127, parse_variable_length_num(&[0x7F]).unwrap().1);
    }

    /// Multi-byte quantities contribute 7 bits per byte, most significant group first.
    #[test]
    fn test_parse_variable_length_quantity_length_more_than_2() {
        assert_eq!(128, parse_variable_length_num(&[0x81, 0x00]).unwrap().1);
        assert_eq!(
            16384,
            parse_variable_length_num(&[0x81, 0x80, 0x00]).unwrap().1
        );
    }

    /// The first byte holds the low 7 bits and the second byte the high 7 bits:
    /// `0x54 + (0x39 << 7) == 7380`.
    #[test]
    fn test_parse_14bit_midi_number() {
        // Example pitch change on channel 3
        // let pitch_wheel_message = [0xE3, 0x54, 0x39];
        let (_, result) = parse_14bit_midi_number(&[0x54, 0x39]).unwrap();
        assert_eq!(result, 7380);
    }

    /// A full pitch wheel message yields both the channel (low nibble of the status byte) and
    /// the 14-bit value.
    #[test]
    fn test_parse_pitch_wheel_event() {
        // Example pitch change on channel 3
        let pitch_wheel_message = [0xE3, 0x54, 0x39];
        let (_, result) =
            parse_midi_event::<Vec<u8>>(&pitch_wheel_message, &mut ParserState::default()).unwrap();
        assert_eq!(
            result,
            MIDIMessage::PitchWheelChange {
                channel: 3,
                value: 7380
            }
        );
    }

    /// Smoke test: `bach_846.mid` (read from the crate root) parses without error into owned
    /// `String` / `Vec<u8>` buffers.
    ///
    /// NOTE(review): the parse runs inside `assert_no_alloc` even though owned buffers are
    /// requested - confirm that this is intended.
    #[test]
    fn test_parse_midi_file_smoke_test() {
        let input_path = format!("{}/bach_846.mid", env!("CARGO_MANIFEST_DIR"));
        let file_contents = std::fs::read(input_path).unwrap();
        // let file_contents: Vec<u8> = file_contents.into_iter().take(8000).collect();
        let (_rest, _midi_stream) = assert_no_alloc::assert_no_alloc(|| {
            parse_midi_file::<String, Vec<u8>>(&file_contents).unwrap()
        });
    }

    /// Smoke test: `bach_846.mid` (read from the crate root) parses without error into
    /// borrowed `&str` / `&[u8]` buffers, inside `assert_no_alloc`.
    #[test]
    fn test_parse_midi_file_smoke_test_no_alloc() {
        let input_path = format!("{}/bach_846.mid", env!("CARGO_MANIFEST_DIR"));
        let file_contents = std::fs::read(input_path).unwrap();
        // let file_contents: Vec<u8> = file_contents.into_iter().take(8000).collect();
        let (_rest, _midi_stream) = assert_no_alloc::assert_no_alloc(|| {
            parse_midi_file::<&str, &[u8]>(&file_contents).unwrap()
        });
        // println!("{:?}", midi_stream);
    }

    /// End-to-end check against `test-files/c1_4over4_1bar.mid`: verifies the header, the
    /// ticks per quarter note, and the delta time / note / velocity of the 4 note on and
    /// 4 note off events in the file.
    #[test]
    fn test() {
        let input_path = format!(
            "{}/test-files/c1_4over4_1bar.mid",
            env!("CARGO_MANIFEST_DIR")
        );
        let file_contents = std::fs::read(input_path).unwrap();
        let (_rest, midi_file) = parse_midi_file::<String, Vec<u8>>(&file_contents).unwrap();
        assert_eq!(midi_file.ticks_per_quarter_note(), 96);
        let quarter_length = midi_file.ticks_per_quarter_note() as u32;
        let sixteenth_length = quarter_length / 4;

        let header = midi_file.header().unwrap();
        assert_eq!(header.format, MIDIFileFormat::Single);
        assert_eq!(header.num_tracks, 1);

        let events: Vec<MIDITrackEvent<Vec<u8>>> = midi_file.track_chunks().cloned().collect();
        // Keep only the note on events, paired with their delta time.
        let note_on_events: Vec<(u32, MIDIMessageNote)> = events
            .iter()
            .filter_map(|event| match event {
                MIDITrackEvent {
                    delta_time,
                    inner: MIDITrackInner::Message(MIDIMessage::NoteOn(note)),
                } => Some((*delta_time, note.clone())),
                _ => None,
            })
            .collect();
        assert_eq!(note_on_events.len(), 4);
        assert_eq!(note_on_events[0].0, 0);
        assert_eq!(note_on_events[1].0, quarter_length - sixteenth_length);
        assert_eq!(note_on_events[2].0, quarter_length - sixteenth_length);
        assert_eq!(note_on_events[3].0, quarter_length - sixteenth_length);
        for (_, evt) in &note_on_events {
            assert_eq!(evt.velocity, 100);
            assert_eq!(evt.note, 36);
        }

        // Keep only the note off events, paired with their delta time.
        let note_off_events: Vec<(u32, MIDIMessageNote)> = events
            .iter()
            .filter_map(|event| match event {
                MIDITrackEvent {
                    delta_time,
                    inner: MIDITrackInner::Message(MIDIMessage::NoteOff(note)),
                } => Some((*delta_time, note.clone())),
                _ => None,
            })
            .collect();

        assert_eq!(note_off_events.len(), 4);
        assert_eq!(note_off_events[0].0, sixteenth_length);
        assert_eq!(note_off_events[1].0, sixteenth_length);
        assert_eq!(note_off_events[2].0, sixteenth_length);
        assert_eq!(note_off_events[3].0, sixteenth_length);
    }
}