// maolan_engine/plugins/vst3/state.rs

1use serde::{Deserialize, Serialize};
2use std::cell::UnsafeCell;
3use vst3::Steinberg::{IBStreamTrait, kResultFalse, kResultOk};
4use vst3::{Class, ComWrapper};
5
/// Integer result code returned by the `IBStreamTrait` methods in this file
/// (compared against `kResultOk` / `kResultFalse`).
type TResult = i32;
7
8/// VST3 plugin state snapshot
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct Vst3PluginState {
11    pub plugin_id: String,
12    pub component_state: Vec<u8>,
13    pub controller_state: Vec<u8>,
14}
15
/// Growable in-memory byte buffer that plugins read and write as a VST3 `IBStream`.
///
/// The `IBStreamTrait` methods receive `&self`, so the buffer and the cursor
/// sit in `UnsafeCell`s to allow mutation through a shared reference. The
/// `UnsafeCell`s also make this type `!Sync`. NOTE(review): whoever hands the
/// raw `IBStream` pointer to a plugin must still ensure it is used from one
/// thread at a time — confirm against the host's threading model.
pub struct MemoryStream {
    /// Backing storage; extended (zero-filled) by writes that reach past its end.
    data: UnsafeCell<Vec<u8>>,
    /// Cursor: byte offset of the next read or write. May lie beyond
    /// `data.len()` after a forward seek.
    position: UnsafeCell<usize>,
}
22
23impl Class for MemoryStream {
24    type Interfaces = (vst3::Steinberg::IBStream,);
25}
26
27impl Default for MemoryStream {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl MemoryStream {
34    pub fn new() -> Self {
35        Self {
36            data: UnsafeCell::new(Vec::new()),
37            position: UnsafeCell::new(0),
38        }
39    }
40
41    pub fn from_bytes(data: &[u8]) -> Self {
42        Self {
43            data: UnsafeCell::new(data.to_vec()),
44            position: UnsafeCell::new(0),
45        }
46    }
47
48    pub fn bytes(&self) -> Vec<u8> {
49        unsafe { self.data_ref().clone() }
50    }
51
52    // Helper methods for safe access (used in unsafe blocks)
53    #[allow(clippy::mut_from_ref)]
54    unsafe fn data_mut(&self) -> &mut Vec<u8> {
55        unsafe { &mut *self.data.get() }
56    }
57
58    #[allow(clippy::mut_from_ref)]
59    unsafe fn position_mut(&self) -> &mut usize {
60        unsafe { &mut *self.position.get() }
61    }
62
63    unsafe fn data_ref(&self) -> &Vec<u8> {
64        unsafe { &*self.data.get() }
65    }
66
67    unsafe fn position_ref(&self) -> &usize {
68        unsafe { &*self.position.get() }
69    }
70}
71
72pub fn ibstream_ptr(stream: &ComWrapper<MemoryStream>) -> *mut vst3::Steinberg::IBStream {
73    stream
74        .as_com_ref::<vst3::Steinberg::IBStream>()
75        .map(|r| r.as_ptr())
76        .unwrap_or(std::ptr::null_mut())
77}
78
79impl IBStreamTrait for MemoryStream {
80    unsafe fn read(
81        &self,
82        buffer: *mut std::os::raw::c_void,
83        num_bytes: i32,
84        num_bytes_read: *mut i32,
85    ) -> TResult {
86        if buffer.is_null() || num_bytes < 0 {
87            return kResultFalse;
88        }
89
90        let bytes_to_read = num_bytes as usize;
91        let data = unsafe { self.data_ref() };
92        let position = unsafe { *self.position_ref() };
93        let available = data.len().saturating_sub(position);
94        let actual_read = bytes_to_read.min(available);
95
96        if actual_read == 0 {
97            if !num_bytes_read.is_null() {
98                unsafe {
99                    *num_bytes_read = 0;
100                }
101            }
102            return kResultFalse;
103        }
104
105        // Copy data from internal buffer to provided buffer
106        let src_slice = &data[position..position + actual_read];
107        let dst_slice = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, actual_read) };
108        dst_slice.copy_from_slice(src_slice);
109
110        unsafe {
111            *self.position_mut() += actual_read;
112        }
113
114        if !num_bytes_read.is_null() {
115            unsafe {
116                *num_bytes_read = actual_read as i32;
117            }
118        }
119
120        kResultOk
121    }
122
123    unsafe fn write(
124        &self,
125        buffer: *mut std::os::raw::c_void,
126        num_bytes: i32,
127        num_bytes_written: *mut i32,
128    ) -> TResult {
129        if buffer.is_null() || num_bytes < 0 {
130            return kResultFalse;
131        }
132
133        let bytes_to_write = num_bytes as usize;
134        let src_slice = unsafe { std::slice::from_raw_parts(buffer as *mut u8, bytes_to_write) };
135
136        let data = unsafe { self.data_mut() };
137        let position = unsafe { *self.position_ref() };
138
139        // Ensure capacity
140        let required_len = position + bytes_to_write;
141        if required_len > data.len() {
142            data.resize(required_len, 0);
143        }
144
145        // Write data
146        data[position..position + bytes_to_write].copy_from_slice(src_slice);
147        unsafe {
148            *self.position_mut() += bytes_to_write;
149        }
150
151        if !num_bytes_written.is_null() {
152            unsafe {
153                *num_bytes_written = bytes_to_write as i32;
154            }
155        }
156
157        kResultOk
158    }
159
160    unsafe fn seek(&self, pos: i64, mode: i32, result: *mut i64) -> TResult {
161        let current_pos = unsafe { *self.position_ref() };
162        let data_len = unsafe { self.data_ref().len() };
163
164        let new_position = match mode {
165            0 => {
166                // kIBSeekSet - absolute position from start
167                if pos < 0 {
168                    return kResultFalse;
169                }
170                pos as usize
171            }
172            1 => {
173                // kIBSeekCur - relative to current position
174                if pos < 0 {
175                    current_pos.saturating_sub((-pos) as usize)
176                } else {
177                    current_pos.saturating_add(pos as usize)
178                }
179            }
180            2 => {
181                // kIBSeekEnd - relative to end
182                if pos > 0 {
183                    return kResultFalse;
184                }
185                data_len.saturating_sub((-pos) as usize)
186            }
187            _ => return kResultFalse,
188        };
189
190        unsafe {
191            *self.position_mut() = new_position;
192        }
193
194        if !result.is_null() {
195            unsafe {
196                *result = new_position as i64;
197            }
198        }
199
200        kResultOk
201    }
202
203    unsafe fn tell(&self, pos: *mut i64) -> TResult {
204        if pos.is_null() {
205            return kResultFalse;
206        }
207
208        unsafe {
209            *pos = *self.position_ref() as i64;
210        }
211        kResultOk
212    }
213}
214
215#[cfg(test)]
216mod tests {
217    use super::*;
218
219    #[test]
220    fn test_memory_stream_write_read() {
221        let stream = MemoryStream::new();
222        let test_data = b"Hello, VST3!";
223
224        unsafe {
225            let mut written = 0;
226            let result = stream.write(
227                test_data.as_ptr() as *mut std::os::raw::c_void,
228                test_data.len() as i32,
229                &mut written,
230            );
231            assert_eq!(result, kResultOk);
232            assert_eq!(written, test_data.len() as i32);
233        }
234
235        // Seek back to start
236        unsafe {
237            let mut new_pos = 0;
238            stream.seek(0, 0, &mut new_pos);
239            assert_eq!(new_pos, 0);
240        }
241
242        // Read back
243        let mut read_buffer = vec![0u8; test_data.len()];
244        unsafe {
245            let mut read_count = 0;
246            let result = stream.read(
247                read_buffer.as_mut_ptr() as *mut _,
248                test_data.len() as i32,
249                &mut read_count,
250            );
251            assert_eq!(result, kResultOk);
252            assert_eq!(read_count, test_data.len() as i32);
253        }
254
255        assert_eq!(&read_buffer, test_data);
256    }
257
258    #[test]
259    fn test_memory_stream_seek() {
260        let stream = MemoryStream::from_bytes(b"0123456789");
261
262        // Seek to position 5
263        unsafe {
264            let mut pos = 0;
265            stream.seek(5, 0, &mut pos);
266            assert_eq!(pos, 5);
267        }
268
269        // Tell should return 5
270        unsafe {
271            let mut pos = 0;
272            stream.tell(&mut pos);
273            assert_eq!(pos, 5);
274        }
275
276        // Seek relative forward
277        unsafe {
278            let mut pos = 0;
279            stream.seek(2, 1, &mut pos);
280            assert_eq!(pos, 7);
281        }
282
283        // Seek from end
284        unsafe {
285            let mut pos = 0;
286            stream.seek(-3, 2, &mut pos);
287            assert_eq!(pos, 7);
288        }
289    }
290
291    #[test]
292    fn test_plugin_state_serialization() {
293        let state = Vst3PluginState {
294            plugin_id: "com.example.plugin".to_string(),
295            component_state: vec![1, 2, 3, 4],
296            controller_state: vec![5, 6, 7, 8],
297        };
298
299        let json = serde_json::to_string(&state).unwrap();
300        let deserialized: Vst3PluginState = serde_json::from_str(&json).unwrap();
301
302        assert_eq!(state.plugin_id, deserialized.plugin_id);
303        assert_eq!(state.component_state, deserialized.component_state);
304        assert_eq!(state.controller_state, deserialized.controller_state);
305    }
306}