Skip to main content

maolan_engine/plugins/vst3/
midi.rs

1// The vst3 crate uses platform-dependent types for enum constants, so explicit
2// casts are required for cross-platform compilation.
3#![allow(clippy::unnecessary_cast)]
4
5use crate::midi::io::MidiEvent;
6use std::cell::UnsafeCell;
7use vst3::Steinberg::Vst::ControllerNumbers_::{kAfterTouch, kCtrlProgramChange, kPitchBend};
8use vst3::Steinberg::Vst::DataEvent_::DataTypes_;
9use vst3::Steinberg::Vst::Event_::EventTypes_;
10use vst3::Steinberg::Vst::{
11    CtrlNumber, DataEvent, Event, Event__type0, IEventList, IEventListTrait, IMidiMapping,
12    IMidiMappingTrait, IParamValueQueue, IParamValueQueueTrait, IParameterChanges,
13    IParameterChangesTrait, LegacyMIDICCOutEvent, NoteOffEvent, NoteOnEvent, ParamID,
14    PolyPressureEvent,
15};
16use vst3::Steinberg::{kInvalidArgument, kResultFalse, kResultOk};
17use vst3::{Class, ComPtr, ComWrapper};
18
19pub struct EventBuffer {
20    events: UnsafeCell<Vec<Event>>,
21    sysex_data: UnsafeCell<Vec<Vec<u8>>>,
22}
23
24impl Class for EventBuffer {
25    type Interfaces = (IEventList,);
26}
27
28impl Default for EventBuffer {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl EventBuffer {
35    pub fn new() -> Self {
36        Self {
37            events: UnsafeCell::new(Vec::new()),
38            sysex_data: UnsafeCell::new(Vec::new()),
39        }
40    }
41
42    pub fn clear(&mut self) {
43        self.events_mut().clear();
44        self.sysex_data_mut().clear();
45    }
46
47    pub fn from_midi_events(midi_events: &[MidiEvent], bus_index: i32) -> Self {
48        let buffer = Self::new();
49        for midi_event in midi_events {
50            buffer.push_midi_event(midi_event, bus_index);
51        }
52        buffer
53    }
54
55    pub fn to_midi_events(&self) -> Vec<MidiEvent> {
56        self.events_ref()
57            .iter()
58            .filter_map(vst3_event_to_midi)
59            .collect()
60    }
61
62    pub fn event_count(&self) -> usize {
63        self.events_ref().len()
64    }
65
66    pub fn event_list_ptr(list: &ComWrapper<Self>) -> *mut IEventList {
67        list.as_com_ref::<IEventList>()
68            .map(|r| r.as_ptr())
69            .unwrap_or(std::ptr::null_mut())
70    }
71
72    #[allow(clippy::mut_from_ref)]
73    fn events_mut(&self) -> &mut Vec<Event> {
74        unsafe { &mut *self.events.get() }
75    }
76
77    fn events_ref(&self) -> &Vec<Event> {
78        unsafe { &*self.events.get() }
79    }
80
81    #[allow(clippy::mut_from_ref)]
82    fn sysex_data_mut(&self) -> &mut Vec<Vec<u8>> {
83        unsafe { &mut *self.sysex_data.get() }
84    }
85
86    fn push_midi_event(&self, midi_event: &MidiEvent, bus_index: i32) {
87        if let Some(event) = midi_to_vst3_event(midi_event, bus_index, self.sysex_data_mut()) {
88            self.events_mut().push(event);
89        }
90    }
91}
92
93impl IEventListTrait for EventBuffer {
94    unsafe fn getEventCount(&self) -> i32 {
95        self.event_count().min(i32::MAX as usize) as i32
96    }
97
98    unsafe fn getEvent(&self, index: i32, e: *mut Event) -> i32 {
99        if index < 0 || e.is_null() {
100            return kInvalidArgument;
101        }
102        let Some(event) = self.events_ref().get(index as usize).copied() else {
103            return kResultFalse;
104        };
105        unsafe {
106            *e = event;
107        }
108        kResultOk
109    }
110
111    #[allow(clippy::unnecessary_cast)]
112    unsafe fn addEvent(&self, e: *mut Event) -> i32 {
113        if e.is_null() {
114            return kInvalidArgument;
115        }
116        let event = unsafe { *e };
117        if event.r#type == EventTypes_::kDataEvent as u16
118            && let Some(bytes) = copy_sysex_event(&event)
119        {
120            self.sysex_data_mut().push(bytes);
121            if let Some(last) = self.sysex_data_mut().last() {
122                self.events_mut().push(Event {
123                    __field0: Event__type0 {
124                        data: DataEvent {
125                            size: last.len().min(u32::MAX as usize) as u32,
126                            r#type: DataTypes_::kMidiSysEx as u32,
127                            bytes: last.as_ptr(),
128                        },
129                    },
130                    ..event
131                });
132                return kResultOk;
133            }
134        }
135        self.events_mut().push(event);
136        kResultOk
137    }
138}
139
140pub struct ParameterChanges {
141    queues: UnsafeCell<Vec<ComWrapper<ParameterValueQueue>>>,
142}
143
144impl Class for ParameterChanges {
145    type Interfaces = (IParameterChanges,);
146}
147
148impl Default for ParameterChanges {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl ParameterChanges {
155    pub fn new() -> Self {
156        Self {
157            queues: UnsafeCell::new(Vec::new()),
158        }
159    }
160
161    pub fn from_midi_events(
162        midi_events: &[MidiEvent],
163        mapping: &ComPtr<IMidiMapping>,
164        bus_index: i32,
165    ) -> Option<Self> {
166        let changes = Self::new();
167        for midi_event in midi_events {
168            let Some((channel, controller, value)) = midi_to_controller_change(midi_event) else {
169                continue;
170            };
171            let mut param_id: ParamID = 0;
172            let result = unsafe {
173                mapping.getMidiControllerAssignment(bus_index, channel, controller, &mut param_id)
174            };
175            if result != kResultOk {
176                continue;
177            }
178            changes.add_point(
179                param_id,
180                midi_event.frame.min(i32::MAX as u32) as i32,
181                value as f64,
182            );
183        }
184
185        (!changes.queues_ref().is_empty()).then_some(changes)
186    }
187
188    pub fn changes_ptr(changes: &ComWrapper<Self>) -> *mut IParameterChanges {
189        changes
190            .as_com_ref::<IParameterChanges>()
191            .map(|r| r.as_ptr())
192            .unwrap_or(std::ptr::null_mut())
193    }
194
195    fn add_point(&self, param_id: ParamID, sample_offset: i32, value: f64) {
196        for queue in self.queues_ref() {
197            if queue.parameter_id() == param_id {
198                queue.push_point(sample_offset, value);
199                return;
200            }
201        }
202
203        let queue = ComWrapper::new(ParameterValueQueue::new(param_id));
204        queue.push_point(sample_offset, value);
205        self.queues_mut().push(queue);
206    }
207
208    #[allow(clippy::mut_from_ref)]
209    fn queues_mut(&self) -> &mut Vec<ComWrapper<ParameterValueQueue>> {
210        unsafe { &mut *self.queues.get() }
211    }
212
213    fn queues_ref(&self) -> &Vec<ComWrapper<ParameterValueQueue>> {
214        unsafe { &*self.queues.get() }
215    }
216}
217
218impl IParameterChangesTrait for ParameterChanges {
219    unsafe fn getParameterCount(&self) -> i32 {
220        self.queues_ref().len().min(i32::MAX as usize) as i32
221    }
222
223    unsafe fn getParameterData(&self, index: i32) -> *mut IParamValueQueue {
224        self.queues_ref()
225            .get(index.max(0) as usize)
226            .and_then(|queue| queue.as_com_ref::<IParamValueQueue>())
227            .map(|queue| queue.as_ptr())
228            .unwrap_or(std::ptr::null_mut())
229    }
230
231    unsafe fn addParameterData(
232        &self,
233        id: *const ParamID,
234        index: *mut i32,
235    ) -> *mut IParamValueQueue {
236        if id.is_null() {
237            return std::ptr::null_mut();
238        }
239        let param_id = unsafe { *id };
240        if let Some(existing_idx) = self
241            .queues_ref()
242            .iter()
243            .position(|queue| queue.parameter_id() == param_id)
244        {
245            if !index.is_null() {
246                unsafe {
247                    *index = existing_idx as i32;
248                }
249            }
250            return self
251                .queues_ref()
252                .get(existing_idx)
253                .and_then(|queue| queue.as_com_ref::<IParamValueQueue>())
254                .map(|queue| queue.as_ptr())
255                .unwrap_or(std::ptr::null_mut());
256        }
257
258        let queue = ComWrapper::new(ParameterValueQueue::new(param_id));
259        self.queues_mut().push(queue);
260        let idx = self.queues_ref().len().saturating_sub(1);
261        if !index.is_null() {
262            unsafe {
263                *index = idx as i32;
264            }
265        }
266        self.queues_ref()[idx]
267            .as_com_ref::<IParamValueQueue>()
268            .map(|queue| queue.as_ptr())
269            .unwrap_or(std::ptr::null_mut())
270    }
271}
272
273pub struct ParameterValueQueue {
274    param_id: ParamID,
275    points: UnsafeCell<Vec<(i32, f64)>>,
276}
277
278impl Class for ParameterValueQueue {
279    type Interfaces = (IParamValueQueue,);
280}
281
282impl ParameterValueQueue {
283    fn new(param_id: ParamID) -> Self {
284        Self {
285            param_id,
286            points: UnsafeCell::new(Vec::new()),
287        }
288    }
289
290    fn parameter_id(&self) -> ParamID {
291        self.param_id
292    }
293
294    fn push_point(&self, sample_offset: i32, value: f64) {
295        self.points_mut()
296            .push((sample_offset, value.clamp(0.0, 1.0)));
297    }
298
299    #[allow(clippy::mut_from_ref)]
300    fn points_mut(&self) -> &mut Vec<(i32, f64)> {
301        unsafe { &mut *self.points.get() }
302    }
303
304    fn points_ref(&self) -> &Vec<(i32, f64)> {
305        unsafe { &*self.points.get() }
306    }
307}
308
309impl IParamValueQueueTrait for ParameterValueQueue {
310    unsafe fn getParameterId(&self) -> ParamID {
311        self.param_id
312    }
313
314    unsafe fn getPointCount(&self) -> i32 {
315        self.points_ref().len().min(i32::MAX as usize) as i32
316    }
317
318    unsafe fn getPoint(&self, index: i32, sample_offset: *mut i32, value: *mut f64) -> i32 {
319        let Some((offset, point_value)) = self.points_ref().get(index.max(0) as usize).copied()
320        else {
321            return kResultFalse;
322        };
323        if !sample_offset.is_null() {
324            unsafe {
325                *sample_offset = offset;
326            }
327        }
328        if !value.is_null() {
329            unsafe {
330                *value = point_value;
331            }
332        }
333        kResultOk
334    }
335
336    unsafe fn addPoint(&self, sample_offset: i32, value: f64, index: *mut i32) -> i32 {
337        self.push_point(sample_offset, value);
338        if !index.is_null() {
339            unsafe {
340                *index = self.points_ref().len().saturating_sub(1) as i32;
341            }
342        }
343        kResultOk
344    }
345}
346
347#[allow(clippy::unnecessary_cast)]
348fn midi_to_vst3_event(
349    midi_event: &MidiEvent,
350    bus_index: i32,
351    sysex_storage: &mut Vec<Vec<u8>>,
352) -> Option<Event> {
353    let status = *midi_event.data.first()?;
354    let channel = (status & 0x0f) as i16;
355    let kind = status & 0xf0;
356    let sample_offset = midi_event.frame.min(i32::MAX as u32) as i32;
357
358    match kind {
359        0x80 => {
360            let pitch = *midi_event.data.get(1)? as i16;
361            let velocity = midi_velocity(midi_event.data.get(2).copied().unwrap_or(0));
362            Some(Event {
363                busIndex: bus_index,
364                sampleOffset: sample_offset,
365                ppqPosition: 0.0,
366                flags: 0,
367                r#type: EventTypes_::kNoteOffEvent as u16,
368                __field0: Event__type0 {
369                    noteOff: NoteOffEvent {
370                        channel,
371                        pitch,
372                        velocity,
373                        noteId: -1,
374                        tuning: 0.0,
375                    },
376                },
377            })
378        }
379        0x90 => {
380            let pitch = *midi_event.data.get(1)? as i16;
381            let velocity_byte = midi_event.data.get(2).copied().unwrap_or(0);
382            if velocity_byte == 0 {
383                return midi_to_vst3_event(
384                    &MidiEvent::new(midi_event.frame, vec![0x80 | channel as u8, pitch as u8, 0]),
385                    bus_index,
386                    sysex_storage,
387                );
388            }
389            Some(Event {
390                busIndex: bus_index,
391                sampleOffset: sample_offset,
392                ppqPosition: 0.0,
393                flags: 0,
394                r#type: EventTypes_::kNoteOnEvent as u16,
395                __field0: Event__type0 {
396                    noteOn: NoteOnEvent {
397                        channel,
398                        pitch,
399                        tuning: 0.0,
400                        velocity: midi_velocity(velocity_byte),
401                        length: 0,
402                        noteId: -1,
403                    },
404                },
405            })
406        }
407        0xA0 => {
408            let pitch = *midi_event.data.get(1)? as i16;
409            let pressure = midi_velocity(midi_event.data.get(2).copied().unwrap_or(0));
410            Some(Event {
411                busIndex: bus_index,
412                sampleOffset: sample_offset,
413                ppqPosition: 0.0,
414                flags: 0,
415                r#type: EventTypes_::kPolyPressureEvent as u16,
416                __field0: Event__type0 {
417                    polyPressure: PolyPressureEvent {
418                        channel,
419                        pitch,
420                        pressure,
421                        noteId: -1,
422                    },
423                },
424            })
425        }
426        0xF0 if midi_event.data.first().copied() == Some(0xF0) => {
427            sysex_storage.push(midi_event.data.clone());
428            let bytes = sysex_storage.last()?;
429            Some(Event {
430                busIndex: bus_index,
431                sampleOffset: sample_offset,
432                ppqPosition: 0.0,
433                flags: 0,
434                r#type: EventTypes_::kDataEvent as u16,
435                __field0: Event__type0 {
436                    data: DataEvent {
437                        size: bytes.len().min(u32::MAX as usize) as u32,
438                        r#type: DataTypes_::kMidiSysEx as u32,
439                        bytes: bytes.as_ptr(),
440                    },
441                },
442            })
443        }
444        _ => None,
445    }
446}
447
448#[allow(clippy::unnecessary_cast)]
449fn vst3_event_to_midi(event: &Event) -> Option<MidiEvent> {
450    let frame = event.sampleOffset.max(0) as u32;
451    let event_type = event.r#type as u32;
452    if event_type == EventTypes_::kNoteOnEvent as u32 {
453        let note = unsafe { event.__field0.noteOn };
454        Some(MidiEvent::new(
455            frame,
456            vec![
457                0x90 | (note.channel as u8 & 0x0f),
458                note.pitch.clamp(0, 127) as u8,
459                midi_byte(note.velocity),
460            ],
461        ))
462    } else if event_type == EventTypes_::kNoteOffEvent as u32 {
463        let note = unsafe { event.__field0.noteOff };
464        Some(MidiEvent::new(
465            frame,
466            vec![
467                0x80 | (note.channel as u8 & 0x0f),
468                note.pitch.clamp(0, 127) as u8,
469                midi_byte(note.velocity),
470            ],
471        ))
472    } else if event_type == EventTypes_::kPolyPressureEvent as u32 {
473        let pressure = unsafe { event.__field0.polyPressure };
474        Some(MidiEvent::new(
475            frame,
476            vec![
477                0xA0 | (pressure.channel as u8 & 0x0f),
478                pressure.pitch.clamp(0, 127) as u8,
479                midi_byte(pressure.pressure),
480            ],
481        ))
482    } else if event_type == EventTypes_::kDataEvent as u32 {
483        let data = unsafe { event.__field0.data };
484        (data.r#type == DataTypes_::kMidiSysEx as u32 && !data.bytes.is_null()).then(|| {
485            let bytes = unsafe {
486                std::slice::from_raw_parts(data.bytes, data.size.min(usize::MAX as u32) as usize)
487            };
488            MidiEvent::new(frame, bytes.to_vec())
489        })
490    } else if event_type == EventTypes_::kLegacyMIDICCOutEvent as u32 {
491        let cc = unsafe { event.__field0.midiCCOut };
492        legacy_cc_to_midi(frame, cc)
493    } else {
494        None
495    }
496}
497
498fn midi_to_controller_change(midi_event: &MidiEvent) -> Option<(i16, CtrlNumber, f32)> {
499    let status = *midi_event.data.first()?;
500    let channel = (status & 0x0f) as i16;
501    let kind = status & 0xf0;
502    match kind {
503        0xB0 => Some((
504            channel,
505            midi_event.data.get(1).copied()? as CtrlNumber,
506            midi_velocity(midi_event.data.get(2).copied().unwrap_or(0)),
507        )),
508        0xC0 => Some((
509            channel,
510            kCtrlProgramChange as CtrlNumber,
511            midi_velocity(midi_event.data.get(1).copied().unwrap_or(0)),
512        )),
513        0xD0 => Some((
514            channel,
515            kAfterTouch as CtrlNumber,
516            midi_velocity(midi_event.data.get(1).copied().unwrap_or(0)),
517        )),
518        0xE0 => {
519            let lsb = midi_event.data.get(1).copied().unwrap_or(0) as u16;
520            let msb = midi_event.data.get(2).copied().unwrap_or(0) as u16;
521            let value = ((msb << 7) | lsb).min(16383) as f32 / 16383.0;
522            Some((channel, kPitchBend as CtrlNumber, value))
523        }
524        _ => None,
525    }
526}
527
528fn legacy_cc_to_midi(frame: u32, cc: LegacyMIDICCOutEvent) -> Option<MidiEvent> {
529    let channel = (cc.channel as u8) & 0x0f;
530    let control = cc.controlNumber as u16;
531    let value = (cc.value as i16).clamp(0, 127) as u8;
532    let value2 = (cc.value2 as i16).clamp(0, 127) as u8;
533
534    let data = match control {
535        x if x == kPitchBend as u16 => vec![0xE0 | channel, value, value2],
536        x if x == kAfterTouch as u16 => vec![0xD0 | channel, value],
537        x if x == kCtrlProgramChange as u16 => vec![0xC0 | channel, value],
538        x if x <= 127 => vec![0xB0 | channel, x as u8, value],
539        _ => return None,
540    };
541    Some(MidiEvent::new(frame, data))
542}
543
544#[allow(clippy::unnecessary_cast)]
545fn copy_sysex_event(event: &Event) -> Option<Vec<u8>> {
546    let data = unsafe { event.__field0.data };
547    (data.r#type == DataTypes_::kMidiSysEx as u32 && !data.bytes.is_null()).then(|| unsafe {
548        std::slice::from_raw_parts(data.bytes, data.size.min(usize::MAX as u32) as usize).to_vec()
549    })
550}
551
/// Normalizes a 7-bit MIDI data byte to `0.0..=1.0`; bytes above 127 saturate to `1.0`.
fn midi_velocity(value: u8) -> f32 {
    let clamped = if value > 127 { 127 } else { value };
    f32::from(clamped) / 127.0
}
555
/// Converts a normalized value to a 7-bit MIDI data byte, rounding to the nearest step.
///
/// Inputs outside `0.0..=1.0` saturate to 0 or 127; NaN maps to 0.
fn midi_byte(value: f32) -> u8 {
    let normalized = value.clamp(0.0, 1.0);
    let steps = (normalized * 127.0).round();
    steps as u8
}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapping stub for the controller test: only CC 74 is assigned, to parameter 1234.
    struct TestMidiMapping;

    impl Class for TestMidiMapping {
        type Interfaces = (IMidiMapping,);
    }

    impl IMidiMappingTrait for TestMidiMapping {
        unsafe fn getMidiControllerAssignment(
            &self,
            _bus_index: i32,
            _channel: i16,
            midi_controller_number: CtrlNumber,
            id: *mut ParamID,
        ) -> i32 {
            if midi_controller_number != 74 {
                return kResultFalse;
            }
            // SAFETY: the code under test always passes a valid out-pointer.
            unsafe { *id = 1234 };
            kResultOk
        }
    }

    #[test]
    fn test_event_buffer_creation() {
        assert_eq!(EventBuffer::new().event_count(), 0);
    }

    #[test]
    fn test_midi_events_conversion() {
        let input = [
            MidiEvent::new(0, vec![0x90, 60, 100]),  // Note On C4
            MidiEvent::new(100, vec![0x80, 60, 64]), // Note Off C4
        ];
        let buffer = EventBuffer::from_midi_events(&input, 0);
        assert_eq!(buffer.event_count(), 2);

        let round_trip = buffer.to_midi_events();
        assert_eq!(round_trip.len(), 2);
        let first = &round_trip[0];
        assert_eq!(first.frame, 0);
        assert_eq!(first.data, vec![0x90, 60, 100]);
    }

    #[test]
    fn test_unsupported_events_are_ignored() {
        // Control changes are routed through ParameterChanges, not the event list.
        let control_change = [MidiEvent::new(0, vec![0xB0, 74, 100])];
        let buffer = EventBuffer::from_midi_events(&control_change, 0);
        assert_eq!(buffer.event_count(), 0);
    }

    #[test]
    fn test_sysex_roundtrip() {
        let input = vec![MidiEvent::new(12, vec![0xF0, 0x7D, 0x10, 0xF7])];
        let round_trip = EventBuffer::from_midi_events(&input, 0).to_midi_events();
        assert_eq!(round_trip, input);
    }

    #[test]
    fn test_controller_changes_map_to_parameter_changes() {
        let mapping = ComWrapper::new(TestMidiMapping)
            .to_com_ptr::<IMidiMapping>()
            .unwrap();
        let cc_events = [MidiEvent::new(64, vec![0xB0, 74, 100])];
        let changes = ParameterChanges::from_midi_events(&cc_events, &mapping, 0).unwrap();
        assert_eq!(unsafe { changes.getParameterCount() }, 1);

        let queue = unsafe { vst3::ComRef::from_raw(changes.getParameterData(0)) }.unwrap();
        assert_eq!(unsafe { queue.getParameterId() }, 1234);
        assert_eq!(unsafe { queue.getPointCount() }, 1);
    }

    #[test]
    fn test_empty_buffer() {
        let buffer = EventBuffer::from_midi_events(&[], 0);
        assert_eq!(buffer.event_count(), 0);
        assert!(buffer.to_midi_events().is_empty());
    }
}