// maolan_engine/plugins/vst3/midi.rs

1use crate::midi::io::MidiEvent;
2use std::cell::UnsafeCell;
3use vst3::Steinberg::Vst::ControllerNumbers_::{kAfterTouch, kCtrlProgramChange, kPitchBend};
4use vst3::Steinberg::Vst::DataEvent_::DataTypes_;
5use vst3::Steinberg::Vst::Event_::EventTypes_;
6use vst3::Steinberg::Vst::{
7    CtrlNumber, DataEvent, Event, Event__type0, IEventList, IEventListTrait, IMidiMapping,
8    IMidiMappingTrait, IParamValueQueue, IParamValueQueueTrait, IParameterChanges,
9    IParameterChangesTrait, LegacyMIDICCOutEvent, NoteOffEvent, NoteOnEvent, ParamID,
10    PolyPressureEvent,
11};
12use vst3::Steinberg::{kInvalidArgument, kResultFalse, kResultOk};
13use vst3::{Class, ComPtr, ComWrapper};
14
15pub struct EventBuffer {
16    events: UnsafeCell<Vec<Event>>,
17    sysex_data: UnsafeCell<Vec<Vec<u8>>>,
18}
19
20impl Class for EventBuffer {
21    type Interfaces = (IEventList,);
22}
23
24impl Default for EventBuffer {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl EventBuffer {
31    pub fn new() -> Self {
32        Self {
33            events: UnsafeCell::new(Vec::new()),
34            sysex_data: UnsafeCell::new(Vec::new()),
35        }
36    }
37
38    pub fn clear(&mut self) {
39        self.events_mut().clear();
40        self.sysex_data_mut().clear();
41    }
42
43    pub fn from_midi_events(midi_events: &[MidiEvent], bus_index: i32) -> Self {
44        let buffer = Self::new();
45        for midi_event in midi_events {
46            buffer.push_midi_event(midi_event, bus_index);
47        }
48        buffer
49    }
50
51    pub fn to_midi_events(&self) -> Vec<MidiEvent> {
52        self.events_ref()
53            .iter()
54            .filter_map(vst3_event_to_midi)
55            .collect()
56    }
57
58    pub fn event_count(&self) -> usize {
59        self.events_ref().len()
60    }
61
62    pub fn event_list_ptr(list: &ComWrapper<Self>) -> *mut IEventList {
63        list.as_com_ref::<IEventList>()
64            .map(|r| r.as_ptr())
65            .unwrap_or(std::ptr::null_mut())
66    }
67
68    #[allow(clippy::mut_from_ref)]
69    fn events_mut(&self) -> &mut Vec<Event> {
70        unsafe { &mut *self.events.get() }
71    }
72
73    fn events_ref(&self) -> &Vec<Event> {
74        unsafe { &*self.events.get() }
75    }
76
77    #[allow(clippy::mut_from_ref)]
78    fn sysex_data_mut(&self) -> &mut Vec<Vec<u8>> {
79        unsafe { &mut *self.sysex_data.get() }
80    }
81
82    fn push_midi_event(&self, midi_event: &MidiEvent, bus_index: i32) {
83        if let Some(event) = midi_to_vst3_event(midi_event, bus_index, self.sysex_data_mut()) {
84            self.events_mut().push(event);
85        }
86    }
87}
88
89impl IEventListTrait for EventBuffer {
90    unsafe fn getEventCount(&self) -> i32 {
91        self.event_count().min(i32::MAX as usize) as i32
92    }
93
94    unsafe fn getEvent(&self, index: i32, e: *mut Event) -> i32 {
95        if index < 0 || e.is_null() {
96            return kInvalidArgument;
97        }
98        let Some(event) = self.events_ref().get(index as usize).copied() else {
99            return kResultFalse;
100        };
101        unsafe {
102            *e = event;
103        }
104        kResultOk
105    }
106
107    unsafe fn addEvent(&self, e: *mut Event) -> i32 {
108        if e.is_null() {
109            return kInvalidArgument;
110        }
111        let event = unsafe { *e };
112        if event.r#type as u32 == EventTypes_::kDataEvent
113            && let Some(bytes) = copy_sysex_event(&event)
114        {
115            self.sysex_data_mut().push(bytes);
116            if let Some(last) = self.sysex_data_mut().last() {
117                self.events_mut().push(Event {
118                    __field0: Event__type0 {
119                        data: DataEvent {
120                            size: last.len().min(u32::MAX as usize) as u32,
121                            r#type: DataTypes_::kMidiSysEx,
122                            bytes: last.as_ptr(),
123                        },
124                    },
125                    ..event
126                });
127                return kResultOk;
128            }
129        }
130        self.events_mut().push(event);
131        kResultOk
132    }
133}
134
135pub struct ParameterChanges {
136    queues: UnsafeCell<Vec<ComWrapper<ParameterValueQueue>>>,
137}
138
139impl Class for ParameterChanges {
140    type Interfaces = (IParameterChanges,);
141}
142
143impl Default for ParameterChanges {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149impl ParameterChanges {
150    pub fn new() -> Self {
151        Self {
152            queues: UnsafeCell::new(Vec::new()),
153        }
154    }
155
156    pub fn from_midi_events(
157        midi_events: &[MidiEvent],
158        mapping: &ComPtr<IMidiMapping>,
159        bus_index: i32,
160    ) -> Option<Self> {
161        let changes = Self::new();
162        for midi_event in midi_events {
163            let Some((channel, controller, value)) = midi_to_controller_change(midi_event) else {
164                continue;
165            };
166            let mut param_id: ParamID = 0;
167            let result = unsafe {
168                mapping.getMidiControllerAssignment(bus_index, channel, controller, &mut param_id)
169            };
170            if result != kResultOk {
171                continue;
172            }
173            changes.add_point(
174                param_id,
175                midi_event.frame.min(i32::MAX as u32) as i32,
176                value as f64,
177            );
178        }
179
180        (!changes.queues_ref().is_empty()).then_some(changes)
181    }
182
183    pub fn changes_ptr(changes: &ComWrapper<Self>) -> *mut IParameterChanges {
184        changes
185            .as_com_ref::<IParameterChanges>()
186            .map(|r| r.as_ptr())
187            .unwrap_or(std::ptr::null_mut())
188    }
189
190    fn add_point(&self, param_id: ParamID, sample_offset: i32, value: f64) {
191        for queue in self.queues_ref() {
192            if queue.parameter_id() == param_id {
193                queue.push_point(sample_offset, value);
194                return;
195            }
196        }
197
198        let queue = ComWrapper::new(ParameterValueQueue::new(param_id));
199        queue.push_point(sample_offset, value);
200        self.queues_mut().push(queue);
201    }
202
203    #[allow(clippy::mut_from_ref)]
204    fn queues_mut(&self) -> &mut Vec<ComWrapper<ParameterValueQueue>> {
205        unsafe { &mut *self.queues.get() }
206    }
207
208    fn queues_ref(&self) -> &Vec<ComWrapper<ParameterValueQueue>> {
209        unsafe { &*self.queues.get() }
210    }
211}
212
213impl IParameterChangesTrait for ParameterChanges {
214    unsafe fn getParameterCount(&self) -> i32 {
215        self.queues_ref().len().min(i32::MAX as usize) as i32
216    }
217
218    unsafe fn getParameterData(&self, index: i32) -> *mut IParamValueQueue {
219        self.queues_ref()
220            .get(index.max(0) as usize)
221            .and_then(|queue| queue.as_com_ref::<IParamValueQueue>())
222            .map(|queue| queue.as_ptr())
223            .unwrap_or(std::ptr::null_mut())
224    }
225
226    unsafe fn addParameterData(
227        &self,
228        id: *const ParamID,
229        index: *mut i32,
230    ) -> *mut IParamValueQueue {
231        if id.is_null() {
232            return std::ptr::null_mut();
233        }
234        let param_id = unsafe { *id };
235        if let Some(existing_idx) = self
236            .queues_ref()
237            .iter()
238            .position(|queue| queue.parameter_id() == param_id)
239        {
240            if !index.is_null() {
241                unsafe {
242                    *index = existing_idx as i32;
243                }
244            }
245            return self
246                .queues_ref()
247                .get(existing_idx)
248                .and_then(|queue| queue.as_com_ref::<IParamValueQueue>())
249                .map(|queue| queue.as_ptr())
250                .unwrap_or(std::ptr::null_mut());
251        }
252
253        let queue = ComWrapper::new(ParameterValueQueue::new(param_id));
254        self.queues_mut().push(queue);
255        let idx = self.queues_ref().len().saturating_sub(1);
256        if !index.is_null() {
257            unsafe {
258                *index = idx as i32;
259            }
260        }
261        self.queues_ref()[idx]
262            .as_com_ref::<IParamValueQueue>()
263            .map(|queue| queue.as_ptr())
264            .unwrap_or(std::ptr::null_mut())
265    }
266}
267
268pub struct ParameterValueQueue {
269    param_id: ParamID,
270    points: UnsafeCell<Vec<(i32, f64)>>,
271}
272
273impl Class for ParameterValueQueue {
274    type Interfaces = (IParamValueQueue,);
275}
276
277impl ParameterValueQueue {
278    fn new(param_id: ParamID) -> Self {
279        Self {
280            param_id,
281            points: UnsafeCell::new(Vec::new()),
282        }
283    }
284
285    fn parameter_id(&self) -> ParamID {
286        self.param_id
287    }
288
289    fn push_point(&self, sample_offset: i32, value: f64) {
290        self.points_mut()
291            .push((sample_offset, value.clamp(0.0, 1.0)));
292    }
293
294    #[allow(clippy::mut_from_ref)]
295    fn points_mut(&self) -> &mut Vec<(i32, f64)> {
296        unsafe { &mut *self.points.get() }
297    }
298
299    fn points_ref(&self) -> &Vec<(i32, f64)> {
300        unsafe { &*self.points.get() }
301    }
302}
303
304impl IParamValueQueueTrait for ParameterValueQueue {
305    unsafe fn getParameterId(&self) -> ParamID {
306        self.param_id
307    }
308
309    unsafe fn getPointCount(&self) -> i32 {
310        self.points_ref().len().min(i32::MAX as usize) as i32
311    }
312
313    unsafe fn getPoint(&self, index: i32, sample_offset: *mut i32, value: *mut f64) -> i32 {
314        let Some((offset, point_value)) = self.points_ref().get(index.max(0) as usize).copied()
315        else {
316            return kResultFalse;
317        };
318        if !sample_offset.is_null() {
319            unsafe {
320                *sample_offset = offset;
321            }
322        }
323        if !value.is_null() {
324            unsafe {
325                *value = point_value;
326            }
327        }
328        kResultOk
329    }
330
331    unsafe fn addPoint(&self, sample_offset: i32, value: f64, index: *mut i32) -> i32 {
332        self.push_point(sample_offset, value);
333        if !index.is_null() {
334            unsafe {
335                *index = self.points_ref().len().saturating_sub(1) as i32;
336            }
337        }
338        kResultOk
339    }
340}
341
342fn midi_to_vst3_event(
343    midi_event: &MidiEvent,
344    bus_index: i32,
345    sysex_storage: &mut Vec<Vec<u8>>,
346) -> Option<Event> {
347    let status = *midi_event.data.first()?;
348    let channel = (status & 0x0f) as i16;
349    let kind = status & 0xf0;
350    let sample_offset = midi_event.frame.min(i32::MAX as u32) as i32;
351
352    match kind {
353        0x80 => {
354            let pitch = *midi_event.data.get(1)? as i16;
355            let velocity = midi_velocity(midi_event.data.get(2).copied().unwrap_or(0));
356            Some(Event {
357                busIndex: bus_index,
358                sampleOffset: sample_offset,
359                ppqPosition: 0.0,
360                flags: 0,
361                r#type: EventTypes_::kNoteOffEvent as u16,
362                __field0: Event__type0 {
363                    noteOff: NoteOffEvent {
364                        channel,
365                        pitch,
366                        velocity,
367                        noteId: -1,
368                        tuning: 0.0,
369                    },
370                },
371            })
372        }
373        0x90 => {
374            let pitch = *midi_event.data.get(1)? as i16;
375            let velocity_byte = midi_event.data.get(2).copied().unwrap_or(0);
376            if velocity_byte == 0 {
377                return midi_to_vst3_event(
378                    &MidiEvent::new(midi_event.frame, vec![0x80 | channel as u8, pitch as u8, 0]),
379                    bus_index,
380                    sysex_storage,
381                );
382            }
383            Some(Event {
384                busIndex: bus_index,
385                sampleOffset: sample_offset,
386                ppqPosition: 0.0,
387                flags: 0,
388                r#type: EventTypes_::kNoteOnEvent as u16,
389                __field0: Event__type0 {
390                    noteOn: NoteOnEvent {
391                        channel,
392                        pitch,
393                        tuning: 0.0,
394                        velocity: midi_velocity(velocity_byte),
395                        length: 0,
396                        noteId: -1,
397                    },
398                },
399            })
400        }
401        0xA0 => {
402            let pitch = *midi_event.data.get(1)? as i16;
403            let pressure = midi_velocity(midi_event.data.get(2).copied().unwrap_or(0));
404            Some(Event {
405                busIndex: bus_index,
406                sampleOffset: sample_offset,
407                ppqPosition: 0.0,
408                flags: 0,
409                r#type: EventTypes_::kPolyPressureEvent as u16,
410                __field0: Event__type0 {
411                    polyPressure: PolyPressureEvent {
412                        channel,
413                        pitch,
414                        pressure,
415                        noteId: -1,
416                    },
417                },
418            })
419        }
420        0xF0 if midi_event.data.first().copied() == Some(0xF0) => {
421            sysex_storage.push(midi_event.data.clone());
422            let bytes = sysex_storage.last()?;
423            Some(Event {
424                busIndex: bus_index,
425                sampleOffset: sample_offset,
426                ppqPosition: 0.0,
427                flags: 0,
428                r#type: EventTypes_::kDataEvent as u16,
429                __field0: Event__type0 {
430                    data: DataEvent {
431                        size: bytes.len().min(u32::MAX as usize) as u32,
432                        r#type: DataTypes_::kMidiSysEx,
433                        bytes: bytes.as_ptr(),
434                    },
435                },
436            })
437        }
438        _ => None,
439    }
440}
441
442fn vst3_event_to_midi(event: &Event) -> Option<MidiEvent> {
443    let frame = event.sampleOffset.max(0) as u32;
444    match event.r#type as u32 {
445        EventTypes_::kNoteOnEvent => {
446            let note = unsafe { event.__field0.noteOn };
447            Some(MidiEvent::new(
448                frame,
449                vec![
450                    0x90 | (note.channel as u8 & 0x0f),
451                    note.pitch.clamp(0, 127) as u8,
452                    midi_byte(note.velocity),
453                ],
454            ))
455        }
456        EventTypes_::kNoteOffEvent => {
457            let note = unsafe { event.__field0.noteOff };
458            Some(MidiEvent::new(
459                frame,
460                vec![
461                    0x80 | (note.channel as u8 & 0x0f),
462                    note.pitch.clamp(0, 127) as u8,
463                    midi_byte(note.velocity),
464                ],
465            ))
466        }
467        EventTypes_::kPolyPressureEvent => {
468            let pressure = unsafe { event.__field0.polyPressure };
469            Some(MidiEvent::new(
470                frame,
471                vec![
472                    0xA0 | (pressure.channel as u8 & 0x0f),
473                    pressure.pitch.clamp(0, 127) as u8,
474                    midi_byte(pressure.pressure),
475                ],
476            ))
477        }
478        EventTypes_::kDataEvent => {
479            let data = unsafe { event.__field0.data };
480            (data.r#type == DataTypes_::kMidiSysEx && !data.bytes.is_null()).then(|| {
481                let bytes = unsafe {
482                    std::slice::from_raw_parts(
483                        data.bytes,
484                        data.size.min(usize::MAX as u32) as usize,
485                    )
486                };
487                MidiEvent::new(frame, bytes.to_vec())
488            })
489        }
490        EventTypes_::kLegacyMIDICCOutEvent => {
491            let cc = unsafe { event.__field0.midiCCOut };
492            legacy_cc_to_midi(frame, cc)
493        }
494        _ => None,
495    }
496}
497
498fn midi_to_controller_change(midi_event: &MidiEvent) -> Option<(i16, CtrlNumber, f32)> {
499    let status = *midi_event.data.first()?;
500    let channel = (status & 0x0f) as i16;
501    let kind = status & 0xf0;
502    match kind {
503        0xB0 => Some((
504            channel,
505            midi_event.data.get(1).copied()? as CtrlNumber,
506            midi_velocity(midi_event.data.get(2).copied().unwrap_or(0)),
507        )),
508        0xC0 => Some((
509            channel,
510            kCtrlProgramChange as CtrlNumber,
511            midi_velocity(midi_event.data.get(1).copied().unwrap_or(0)),
512        )),
513        0xD0 => Some((
514            channel,
515            kAfterTouch as CtrlNumber,
516            midi_velocity(midi_event.data.get(1).copied().unwrap_or(0)),
517        )),
518        0xE0 => {
519            let lsb = midi_event.data.get(1).copied().unwrap_or(0) as u16;
520            let msb = midi_event.data.get(2).copied().unwrap_or(0) as u16;
521            let value = ((msb << 7) | lsb).min(16383) as f32 / 16383.0;
522            Some((channel, kPitchBend as CtrlNumber, value))
523        }
524        _ => None,
525    }
526}
527
528fn legacy_cc_to_midi(frame: u32, cc: LegacyMIDICCOutEvent) -> Option<MidiEvent> {
529    let channel = (cc.channel as u8) & 0x0f;
530    let control = cc.controlNumber as u16;
531    let value = (cc.value as i16).clamp(0, 127) as u8;
532    let value2 = (cc.value2 as i16).clamp(0, 127) as u8;
533
534    let data = match control {
535        x if x == kPitchBend as u16 => vec![0xE0 | channel, value, value2],
536        x if x == kAfterTouch as u16 => vec![0xD0 | channel, value],
537        x if x == kCtrlProgramChange as u16 => vec![0xC0 | channel, value],
538        x if x <= 127 => vec![0xB0 | channel, x as u8, value],
539        _ => return None,
540    };
541    Some(MidiEvent::new(frame, data))
542}
543
544fn copy_sysex_event(event: &Event) -> Option<Vec<u8>> {
545    let data = unsafe { event.__field0.data };
546    (data.r#type == DataTypes_::kMidiSysEx && !data.bytes.is_null()).then(|| unsafe {
547        std::slice::from_raw_parts(data.bytes, data.size.min(usize::MAX as u32) as usize).to_vec()
548    })
549}
550
/// Scales a MIDI data byte to `0.0..=1.0`; bytes above 127 are treated as 127.
fn midi_velocity(value: u8) -> f32 {
    let seven_bit = value.min(127);
    f32::from(seven_bit) / 127.0
}
554
/// Maps a normalized value to a MIDI data byte (`0..=127`), rounding to nearest.
///
/// Input outside `0.0..=1.0` is clamped first.
fn midi_byte(value: f32) -> u8 {
    let scaled = 127.0 * value.clamp(0.0, 1.0);
    scaled.round() as u8
}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// A new buffer holds no events.
    #[test]
    fn test_event_buffer_creation() {
        let empty = EventBuffer::new();
        assert_eq!(empty.event_count(), 0);
    }

    /// Note on/off messages are stored and the note-on comes back unchanged.
    #[test]
    fn test_midi_events_conversion() {
        let note_on_c4 = MidiEvent::new(0, vec![0x90, 60, 100]);
        let note_off_c4 = MidiEvent::new(100, vec![0x80, 60, 64]);
        let input = vec![note_on_c4, note_off_c4];

        let buffer = EventBuffer::from_midi_events(&input, 0);
        assert_eq!(buffer.event_count(), 2);

        let round_tripped = buffer.to_midi_events();
        assert_eq!(round_tripped.len(), 2);
        assert_eq!(round_tripped[0].frame, 0);
        assert_eq!(round_tripped[0].data, vec![0x90, 60, 100]);
    }

    /// Control changes are not stored in the event buffer.
    #[test]
    fn test_unsupported_events_are_ignored() {
        let control_change = vec![MidiEvent::new(0, vec![0xB0, 74, 100])];
        let buffer = EventBuffer::from_midi_events(&control_change, 0);
        assert_eq!(buffer.event_count(), 0);
    }

    /// A SysEx message keeps its frame and payload through the buffer.
    #[test]
    fn test_sysex_roundtrip() {
        let sysex = vec![MidiEvent::new(12, vec![0xF0, 0x7D, 0x10, 0xF7])];
        let buffer = EventBuffer::from_midi_events(&sysex, 0);
        let round_tripped = buffer.to_midi_events();
        assert_eq!(round_tripped, sysex);
    }

    /// A mapped controller produces one queue with one point for the mapped parameter.
    #[test]
    fn test_controller_changes_map_to_parameter_changes() {
        /// Mapping stub: controller 74 maps to parameter 1234, nothing else maps.
        struct TestMidiMapping;
        impl Class for TestMidiMapping {
            type Interfaces = (IMidiMapping,);
        }
        impl IMidiMappingTrait for TestMidiMapping {
            unsafe fn getMidiControllerAssignment(
                &self,
                _bus_index: i32,
                _channel: i16,
                midi_controller_number: CtrlNumber,
                id: *mut ParamID,
            ) -> i32 {
                if midi_controller_number != 74 {
                    return kResultFalse;
                }
                unsafe {
                    *id = 1234;
                }
                kResultOk
            }
        }

        let mapping = ComWrapper::new(TestMidiMapping)
            .to_com_ptr::<IMidiMapping>()
            .unwrap();
        let input = [MidiEvent::new(64, vec![0xB0, 74, 100])];
        let changes = ParameterChanges::from_midi_events(&input, &mapping, 0).unwrap();

        assert_eq!(unsafe { changes.getParameterCount() }, 1);
        let queue_ptr = unsafe { changes.getParameterData(0) };
        let queue = unsafe { vst3::ComRef::from_raw(queue_ptr) }.unwrap();
        assert_eq!(unsafe { queue.getParameterId() }, 1234);
        assert_eq!(unsafe { queue.getPointCount() }, 1);
    }

    /// An empty input slice yields an empty buffer in both directions.
    #[test]
    fn test_empty_buffer() {
        let buffer = EventBuffer::from_midi_events(&[], 0);
        assert_eq!(buffer.event_count(), 0);
        assert_eq!(buffer.to_midi_events().len(), 0);
    }
}