Skip to main content

maolan_engine/plugins/vst3/
midi.rs

1#![allow(clippy::unnecessary_cast)]
2
3use crate::midi::io::MidiEvent;
4use std::cell::UnsafeCell;
5use vst3::Steinberg::Vst::ControllerNumbers_::{kAfterTouch, kCtrlProgramChange, kPitchBend};
6use vst3::Steinberg::Vst::DataEvent_::DataTypes_;
7use vst3::Steinberg::Vst::Event_::EventTypes_;
8use vst3::Steinberg::Vst::{
9    CtrlNumber, DataEvent, Event, Event__type0, IEventList, IEventListTrait, IMidiMapping,
10    IMidiMappingTrait, IParamValueQueue, IParamValueQueueTrait, IParameterChanges,
11    IParameterChangesTrait, LegacyMIDICCOutEvent, NoteOffEvent, NoteOnEvent, ParamID,
12    PolyPressureEvent,
13};
14use vst3::Steinberg::{kInvalidArgument, kResultFalse, kResultOk};
15use vst3::{Class, ComPtr, ComWrapper};
16
/// In-memory event list exposed to VST3 components through `IEventList`.
///
/// Both fields sit in `UnsafeCell`s because the `IEventList` callbacks only
/// receive `&self` yet have to append events.
pub struct EventBuffer {
    /// Events in insertion order.
    events: UnsafeCell<Vec<Event>>,
    /// Owned copies of SysEx payloads; the `DataEvent::bytes` pointers of the
    /// SysEx events created by this module point into these buffers.
    sysex_data: UnsafeCell<Vec<Vec<u8>>>,
}
21
// Declares `IEventList` as the interface `EventBuffer` is exposed through.
impl Class for EventBuffer {
    type Interfaces = (IEventList,);
}
25
impl Default for EventBuffer {
    /// Returns an empty buffer; identical to [`EventBuffer::new`].
    fn default() -> Self {
        Self::new()
    }
}
31
32impl EventBuffer {
33    pub fn new() -> Self {
34        Self {
35            events: UnsafeCell::new(Vec::new()),
36            sysex_data: UnsafeCell::new(Vec::new()),
37        }
38    }
39
40    pub fn clear(&mut self) {
41        self.events_mut().clear();
42        self.sysex_data_mut().clear();
43    }
44
45    pub fn from_midi_events(midi_events: &[MidiEvent], bus_index: i32) -> Self {
46        let buffer = Self::new();
47        for midi_event in midi_events {
48            buffer.push_midi_event(midi_event, bus_index);
49        }
50        buffer
51    }
52
53    pub fn to_midi_events(&self) -> Vec<MidiEvent> {
54        self.events_ref()
55            .iter()
56            .filter_map(vst3_event_to_midi)
57            .collect()
58    }
59
60    pub fn event_count(&self) -> usize {
61        self.events_ref().len()
62    }
63
64    pub fn event_list_ptr(list: &ComWrapper<Self>) -> *mut IEventList {
65        list.as_com_ref::<IEventList>()
66            .map(|r| r.as_ptr())
67            .unwrap_or(std::ptr::null_mut())
68    }
69
70    #[allow(clippy::mut_from_ref)]
71    fn events_mut(&self) -> &mut Vec<Event> {
72        unsafe { &mut *self.events.get() }
73    }
74
75    fn events_ref(&self) -> &Vec<Event> {
76        unsafe { &*self.events.get() }
77    }
78
79    #[allow(clippy::mut_from_ref)]
80    fn sysex_data_mut(&self) -> &mut Vec<Vec<u8>> {
81        unsafe { &mut *self.sysex_data.get() }
82    }
83
84    fn push_midi_event(&self, midi_event: &MidiEvent, bus_index: i32) {
85        if let Some(event) = midi_to_vst3_event(midi_event, bus_index, self.sysex_data_mut()) {
86            self.events_mut().push(event);
87        }
88    }
89}
90
91impl IEventListTrait for EventBuffer {
92    unsafe fn getEventCount(&self) -> i32 {
93        self.event_count().min(i32::MAX as usize) as i32
94    }
95
96    unsafe fn getEvent(&self, index: i32, e: *mut Event) -> i32 {
97        if index < 0 || e.is_null() {
98            return kInvalidArgument;
99        }
100        let Some(event) = self.events_ref().get(index as usize).copied() else {
101            return kResultFalse;
102        };
103        unsafe {
104            *e = event;
105        }
106        kResultOk
107    }
108
109    #[allow(clippy::unnecessary_cast)]
110    unsafe fn addEvent(&self, e: *mut Event) -> i32 {
111        if e.is_null() {
112            return kInvalidArgument;
113        }
114        let event = unsafe { *e };
115        if event.r#type == EventTypes_::kDataEvent as u16
116            && let Some(bytes) = copy_sysex_event(&event)
117        {
118            self.sysex_data_mut().push(bytes);
119            if let Some(last) = self.sysex_data_mut().last() {
120                self.events_mut().push(Event {
121                    __field0: Event__type0 {
122                        data: DataEvent {
123                            size: last.len().min(u32::MAX as usize) as u32,
124                            r#type: DataTypes_::kMidiSysEx as u32,
125                            bytes: last.as_ptr(),
126                        },
127                    },
128                    ..event
129                });
130                return kResultOk;
131            }
132        }
133        self.events_mut().push(event);
134        kResultOk
135    }
136}
137
/// Collection of per-parameter value queues exposed through `IParameterChanges`.
///
/// The queue list sits in an `UnsafeCell` because queues are appended from
/// methods that only receive `&self`.
pub struct ParameterChanges {
    /// One queue per parameter ID, in the order the IDs were first added.
    queues: UnsafeCell<Vec<ComWrapper<ParameterValueQueue>>>,
}
141
// Declares `IParameterChanges` as the interface `ParameterChanges` is exposed through.
impl Class for ParameterChanges {
    type Interfaces = (IParameterChanges,);
}
145
impl Default for ParameterChanges {
    /// Returns an empty collection; identical to [`ParameterChanges::new`].
    fn default() -> Self {
        Self::new()
    }
}
151
152impl ParameterChanges {
153    pub fn new() -> Self {
154        Self {
155            queues: UnsafeCell::new(Vec::new()),
156        }
157    }
158
159    pub fn from_midi_events(
160        midi_events: &[MidiEvent],
161        mapping: &ComPtr<IMidiMapping>,
162        bus_index: i32,
163    ) -> Option<Self> {
164        let changes = Self::new();
165        for midi_event in midi_events {
166            let Some((channel, controller, value)) = midi_to_controller_change(midi_event) else {
167                continue;
168            };
169            let mut param_id: ParamID = 0;
170            let result = unsafe {
171                mapping.getMidiControllerAssignment(bus_index, channel, controller, &mut param_id)
172            };
173            if result != kResultOk {
174                continue;
175            }
176            changes.add_point(
177                param_id,
178                midi_event.frame.min(i32::MAX as u32) as i32,
179                value as f64,
180            );
181        }
182
183        (!changes.queues_ref().is_empty()).then_some(changes)
184    }
185
186    pub fn changes_ptr(changes: &ComWrapper<Self>) -> *mut IParameterChanges {
187        changes
188            .as_com_ref::<IParameterChanges>()
189            .map(|r| r.as_ptr())
190            .unwrap_or(std::ptr::null_mut())
191    }
192
193    fn add_point(&self, param_id: ParamID, sample_offset: i32, value: f64) {
194        for queue in self.queues_ref() {
195            if queue.parameter_id() == param_id {
196                queue.push_point(sample_offset, value);
197                return;
198            }
199        }
200
201        let queue = ComWrapper::new(ParameterValueQueue::new(param_id));
202        queue.push_point(sample_offset, value);
203        self.queues_mut().push(queue);
204    }
205
206    #[allow(clippy::mut_from_ref)]
207    fn queues_mut(&self) -> &mut Vec<ComWrapper<ParameterValueQueue>> {
208        unsafe { &mut *self.queues.get() }
209    }
210
211    fn queues_ref(&self) -> &Vec<ComWrapper<ParameterValueQueue>> {
212        unsafe { &*self.queues.get() }
213    }
214}
215
216impl IParameterChangesTrait for ParameterChanges {
217    unsafe fn getParameterCount(&self) -> i32 {
218        self.queues_ref().len().min(i32::MAX as usize) as i32
219    }
220
221    unsafe fn getParameterData(&self, index: i32) -> *mut IParamValueQueue {
222        self.queues_ref()
223            .get(index.max(0) as usize)
224            .and_then(|queue| queue.as_com_ref::<IParamValueQueue>())
225            .map(|queue| queue.as_ptr())
226            .unwrap_or(std::ptr::null_mut())
227    }
228
229    unsafe fn addParameterData(
230        &self,
231        id: *const ParamID,
232        index: *mut i32,
233    ) -> *mut IParamValueQueue {
234        if id.is_null() {
235            return std::ptr::null_mut();
236        }
237        let param_id = unsafe { *id };
238        if let Some(existing_idx) = self
239            .queues_ref()
240            .iter()
241            .position(|queue| queue.parameter_id() == param_id)
242        {
243            if !index.is_null() {
244                unsafe {
245                    *index = existing_idx as i32;
246                }
247            }
248            return self
249                .queues_ref()
250                .get(existing_idx)
251                .and_then(|queue| queue.as_com_ref::<IParamValueQueue>())
252                .map(|queue| queue.as_ptr())
253                .unwrap_or(std::ptr::null_mut());
254        }
255
256        let queue = ComWrapper::new(ParameterValueQueue::new(param_id));
257        self.queues_mut().push(queue);
258        let idx = self.queues_ref().len().saturating_sub(1);
259        if !index.is_null() {
260            unsafe {
261                *index = idx as i32;
262            }
263        }
264        self.queues_ref()[idx]
265            .as_com_ref::<IParamValueQueue>()
266            .map(|queue| queue.as_ptr())
267            .unwrap_or(std::ptr::null_mut())
268    }
269}
270
/// Value points for a single parameter, exposed through `IParamValueQueue`.
pub struct ParameterValueQueue {
    /// ID of the parameter these points belong to.
    param_id: ParamID,
    /// `(sample_offset, value)` pairs in insertion order; values are clamped
    /// to `0.0..=1.0` when they are pushed.
    points: UnsafeCell<Vec<(i32, f64)>>,
}
275
// Declares `IParamValueQueue` as the interface `ParameterValueQueue` is exposed through.
impl Class for ParameterValueQueue {
    type Interfaces = (IParamValueQueue,);
}
279
280impl ParameterValueQueue {
281    fn new(param_id: ParamID) -> Self {
282        Self {
283            param_id,
284            points: UnsafeCell::new(Vec::new()),
285        }
286    }
287
288    fn parameter_id(&self) -> ParamID {
289        self.param_id
290    }
291
292    fn push_point(&self, sample_offset: i32, value: f64) {
293        self.points_mut()
294            .push((sample_offset, value.clamp(0.0, 1.0)));
295    }
296
297    #[allow(clippy::mut_from_ref)]
298    fn points_mut(&self) -> &mut Vec<(i32, f64)> {
299        unsafe { &mut *self.points.get() }
300    }
301
302    fn points_ref(&self) -> &Vec<(i32, f64)> {
303        unsafe { &*self.points.get() }
304    }
305}
306
307impl IParamValueQueueTrait for ParameterValueQueue {
308    unsafe fn getParameterId(&self) -> ParamID {
309        self.param_id
310    }
311
312    unsafe fn getPointCount(&self) -> i32 {
313        self.points_ref().len().min(i32::MAX as usize) as i32
314    }
315
316    unsafe fn getPoint(&self, index: i32, sample_offset: *mut i32, value: *mut f64) -> i32 {
317        let Some((offset, point_value)) = self.points_ref().get(index.max(0) as usize).copied()
318        else {
319            return kResultFalse;
320        };
321        if !sample_offset.is_null() {
322            unsafe {
323                *sample_offset = offset;
324            }
325        }
326        if !value.is_null() {
327            unsafe {
328                *value = point_value;
329            }
330        }
331        kResultOk
332    }
333
334    unsafe fn addPoint(&self, sample_offset: i32, value: f64, index: *mut i32) -> i32 {
335        self.push_point(sample_offset, value);
336        if !index.is_null() {
337            unsafe {
338                *index = self.points_ref().len().saturating_sub(1) as i32;
339            }
340        }
341        kResultOk
342    }
343}
344
345#[allow(clippy::unnecessary_cast)]
346fn midi_to_vst3_event(
347    midi_event: &MidiEvent,
348    bus_index: i32,
349    sysex_storage: &mut Vec<Vec<u8>>,
350) -> Option<Event> {
351    let status = *midi_event.data.first()?;
352    let channel = (status & 0x0f) as i16;
353    let kind = status & 0xf0;
354    let sample_offset = midi_event.frame.min(i32::MAX as u32) as i32;
355
356    match kind {
357        0x80 => {
358            let pitch = *midi_event.data.get(1)? as i16;
359            let velocity = midi_velocity(midi_event.data.get(2).copied().unwrap_or(0));
360            Some(Event {
361                busIndex: bus_index,
362                sampleOffset: sample_offset,
363                ppqPosition: 0.0,
364                flags: 0,
365                r#type: EventTypes_::kNoteOffEvent as u16,
366                __field0: Event__type0 {
367                    noteOff: NoteOffEvent {
368                        channel,
369                        pitch,
370                        velocity,
371                        noteId: -1,
372                        tuning: 0.0,
373                    },
374                },
375            })
376        }
377        0x90 => {
378            let pitch = *midi_event.data.get(1)? as i16;
379            let velocity_byte = midi_event.data.get(2).copied().unwrap_or(0);
380            if velocity_byte == 0 {
381                return midi_to_vst3_event(
382                    &MidiEvent::new(midi_event.frame, vec![0x80 | channel as u8, pitch as u8, 0]),
383                    bus_index,
384                    sysex_storage,
385                );
386            }
387            Some(Event {
388                busIndex: bus_index,
389                sampleOffset: sample_offset,
390                ppqPosition: 0.0,
391                flags: 0,
392                r#type: EventTypes_::kNoteOnEvent as u16,
393                __field0: Event__type0 {
394                    noteOn: NoteOnEvent {
395                        channel,
396                        pitch,
397                        tuning: 0.0,
398                        velocity: midi_velocity(velocity_byte),
399                        length: 0,
400                        noteId: -1,
401                    },
402                },
403            })
404        }
405        0xA0 => {
406            let pitch = *midi_event.data.get(1)? as i16;
407            let pressure = midi_velocity(midi_event.data.get(2).copied().unwrap_or(0));
408            Some(Event {
409                busIndex: bus_index,
410                sampleOffset: sample_offset,
411                ppqPosition: 0.0,
412                flags: 0,
413                r#type: EventTypes_::kPolyPressureEvent as u16,
414                __field0: Event__type0 {
415                    polyPressure: PolyPressureEvent {
416                        channel,
417                        pitch,
418                        pressure,
419                        noteId: -1,
420                    },
421                },
422            })
423        }
424        0xF0 if midi_event.data.first().copied() == Some(0xF0) => {
425            sysex_storage.push(midi_event.data.clone());
426            let bytes = sysex_storage.last()?;
427            Some(Event {
428                busIndex: bus_index,
429                sampleOffset: sample_offset,
430                ppqPosition: 0.0,
431                flags: 0,
432                r#type: EventTypes_::kDataEvent as u16,
433                __field0: Event__type0 {
434                    data: DataEvent {
435                        size: bytes.len().min(u32::MAX as usize) as u32,
436                        r#type: DataTypes_::kMidiSysEx as u32,
437                        bytes: bytes.as_ptr(),
438                    },
439                },
440            })
441        }
442        _ => None,
443    }
444}
445
446#[allow(clippy::unnecessary_cast)]
447fn vst3_event_to_midi(event: &Event) -> Option<MidiEvent> {
448    let frame = event.sampleOffset.max(0) as u32;
449    let event_type = event.r#type as u32;
450    if event_type == EventTypes_::kNoteOnEvent as u32 {
451        let note = unsafe { event.__field0.noteOn };
452        Some(MidiEvent::new(
453            frame,
454            vec![
455                0x90 | (note.channel as u8 & 0x0f),
456                note.pitch.clamp(0, 127) as u8,
457                midi_byte(note.velocity),
458            ],
459        ))
460    } else if event_type == EventTypes_::kNoteOffEvent as u32 {
461        let note = unsafe { event.__field0.noteOff };
462        Some(MidiEvent::new(
463            frame,
464            vec![
465                0x80 | (note.channel as u8 & 0x0f),
466                note.pitch.clamp(0, 127) as u8,
467                midi_byte(note.velocity),
468            ],
469        ))
470    } else if event_type == EventTypes_::kPolyPressureEvent as u32 {
471        let pressure = unsafe { event.__field0.polyPressure };
472        Some(MidiEvent::new(
473            frame,
474            vec![
475                0xA0 | (pressure.channel as u8 & 0x0f),
476                pressure.pitch.clamp(0, 127) as u8,
477                midi_byte(pressure.pressure),
478            ],
479        ))
480    } else if event_type == EventTypes_::kDataEvent as u32 {
481        let data = unsafe { event.__field0.data };
482        (data.r#type == DataTypes_::kMidiSysEx as u32 && !data.bytes.is_null()).then(|| {
483            let bytes = unsafe {
484                std::slice::from_raw_parts(data.bytes, data.size.min(usize::MAX as u32) as usize)
485            };
486            MidiEvent::new(frame, bytes.to_vec())
487        })
488    } else if event_type == EventTypes_::kLegacyMIDICCOutEvent as u32 {
489        let cc = unsafe { event.__field0.midiCCOut };
490        legacy_cc_to_midi(frame, cc)
491    } else {
492        None
493    }
494}
495
496fn midi_to_controller_change(midi_event: &MidiEvent) -> Option<(i16, CtrlNumber, f32)> {
497    let status = *midi_event.data.first()?;
498    let channel = (status & 0x0f) as i16;
499    let kind = status & 0xf0;
500    match kind {
501        0xB0 => Some((
502            channel,
503            midi_event.data.get(1).copied()? as CtrlNumber,
504            midi_velocity(midi_event.data.get(2).copied().unwrap_or(0)),
505        )),
506        0xC0 => Some((
507            channel,
508            kCtrlProgramChange as CtrlNumber,
509            midi_velocity(midi_event.data.get(1).copied().unwrap_or(0)),
510        )),
511        0xD0 => Some((
512            channel,
513            kAfterTouch as CtrlNumber,
514            midi_velocity(midi_event.data.get(1).copied().unwrap_or(0)),
515        )),
516        0xE0 => {
517            let lsb = midi_event.data.get(1).copied().unwrap_or(0) as u16;
518            let msb = midi_event.data.get(2).copied().unwrap_or(0) as u16;
519            let value = ((msb << 7) | lsb).min(16383) as f32 / 16383.0;
520            Some((channel, kPitchBend as CtrlNumber, value))
521        }
522        _ => None,
523    }
524}
525
526fn legacy_cc_to_midi(frame: u32, cc: LegacyMIDICCOutEvent) -> Option<MidiEvent> {
527    let channel = (cc.channel as u8) & 0x0f;
528    let control = cc.controlNumber as u16;
529    let value = (cc.value as i16).clamp(0, 127) as u8;
530    let value2 = (cc.value2 as i16).clamp(0, 127) as u8;
531
532    let data = match control {
533        x if x == kPitchBend as u16 => vec![0xE0 | channel, value, value2],
534        x if x == kAfterTouch as u16 => vec![0xD0 | channel, value],
535        x if x == kCtrlProgramChange as u16 => vec![0xC0 | channel, value],
536        x if x <= 127 => vec![0xB0 | channel, x as u8, value],
537        _ => return None,
538    };
539    Some(MidiEvent::new(frame, data))
540}
541
542#[allow(clippy::unnecessary_cast)]
543fn copy_sysex_event(event: &Event) -> Option<Vec<u8>> {
544    let data = unsafe { event.__field0.data };
545    (data.r#type == DataTypes_::kMidiSysEx as u32 && !data.bytes.is_null()).then(|| unsafe {
546        std::slice::from_raw_parts(data.bytes, data.size.min(usize::MAX as u32) as usize).to_vec()
547    })
548}
549
/// Maps a MIDI data byte to `0.0..=1.0`; bytes above 127 are treated as 127.
fn midi_velocity(value: u8) -> f32 {
    let capped = value.min(127);
    f32::from(capped) / 127.0
}
553
/// Maps a normalized value to a MIDI data byte (`0..=127`), rounding to the
/// nearest step; input outside `0.0..=1.0` is clamped first.
fn midi_byte(value: f32) -> u8 {
    let normalized = value.clamp(0.0, 1.0);
    let scaled = (normalized * 127.0).round();
    scaled as u8
}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created buffer holds no events.
    #[test]
    fn test_event_buffer_creation() {
        let buffer = EventBuffer::new();
        assert_eq!(buffer.event_count(), 0);
    }

    /// Note-on and note-off survive a MIDI -> VST3 -> MIDI round trip.
    #[test]
    fn test_midi_events_conversion() {
        let midi = vec![
            MidiEvent::new(0, vec![0x90, 60, 100]),
            MidiEvent::new(100, vec![0x80, 60, 64]),
        ];

        let buffer = EventBuffer::from_midi_events(&midi, 0);
        assert_eq!(buffer.event_count(), 2);

        let output = buffer.to_midi_events();
        assert_eq!(output.len(), 2);
        assert_eq!(output[0].frame, 0);
        assert_eq!(output[0].data, vec![0x90, 60, 100]);
        // The second event must be checked as well, not just counted.
        assert_eq!(output[1].frame, 100);
        assert_eq!(output[1].data, vec![0x80, 60, 64]);
    }

    /// Control changes are not representable as events and are dropped.
    #[test]
    fn test_unsupported_events_are_ignored() {
        let midi = vec![MidiEvent::new(0, vec![0xB0, 74, 100])];
        let buffer = EventBuffer::from_midi_events(&midi, 0);
        assert_eq!(buffer.event_count(), 0);
    }

    /// SysEx payloads are copied into the buffer and read back unchanged.
    #[test]
    fn test_sysex_roundtrip() {
        let midi = vec![MidiEvent::new(12, vec![0xF0, 0x7D, 0x10, 0xF7])];
        let buffer = EventBuffer::from_midi_events(&midi, 0);
        let output = buffer.to_midi_events();
        assert_eq!(output, midi);
    }

    /// A mapped control change becomes one queue holding one point.
    #[test]
    fn test_controller_changes_map_to_parameter_changes() {
        // Mapping stub that assigns controller 74 to parameter 1234 and
        // rejects every other controller.
        struct TestMidiMapping;
        impl Class for TestMidiMapping {
            type Interfaces = (IMidiMapping,);
        }
        impl IMidiMappingTrait for TestMidiMapping {
            unsafe fn getMidiControllerAssignment(
                &self,
                _bus_index: i32,
                _channel: i16,
                midi_controller_number: CtrlNumber,
                id: *mut ParamID,
            ) -> i32 {
                if midi_controller_number == 74 {
                    unsafe {
                        *id = 1234;
                    }
                    kResultOk
                } else {
                    kResultFalse
                }
            }
        }

        let mapping = ComWrapper::new(TestMidiMapping)
            .to_com_ptr::<IMidiMapping>()
            .unwrap();
        let changes = ParameterChanges::from_midi_events(
            &[MidiEvent::new(64, vec![0xB0, 74, 100])],
            &mapping,
            0,
        )
        .unwrap();
        assert_eq!(unsafe { changes.getParameterCount() }, 1);
        let queue_ptr = unsafe { changes.getParameterData(0) };
        let queue = unsafe { vst3::ComRef::from_raw(queue_ptr) }.unwrap();
        assert_eq!(unsafe { queue.getParameterId() }, 1234);
        assert_eq!(unsafe { queue.getPointCount() }, 1);

        // The stored point carries the event frame and the normalized value.
        let mut offset = 0i32;
        let mut value = 0.0f64;
        assert_eq!(
            unsafe { queue.getPoint(0, &mut offset, &mut value) },
            kResultOk
        );
        assert_eq!(offset, 64);
        assert!((value - 100.0 / 127.0).abs() < 1e-6);
    }

    /// An empty MIDI slice produces an empty buffer in both directions.
    #[test]
    fn test_empty_buffer() {
        let buffer = EventBuffer::from_midi_events(&[], 0);
        assert_eq!(buffer.event_count(), 0);
        assert_eq!(buffer.to_midi_events().len(), 0);
    }
}