// maolan_engine/plugins/vst3/state.rs

1use serde::{Deserialize, Serialize};
2use std::cell::UnsafeCell;
3use vst3::Steinberg::{IBStreamTrait, kResultFalse, kResultOk};
4use vst3::{Class, ComWrapper};
5
/// Status code returned by this module's `IBStreamTrait` methods (`kResultOk`, `kResultFalse`).
type TResult = i32;
7
8/// VST3 plugin state snapshot
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct Vst3PluginState {
11    pub plugin_id: String,
12    pub component_state: Vec<u8>,
13    pub controller_state: Vec<u8>,
14}
15
/// Growable in-memory byte buffer that VST3 plugins read and write through `IBStream`.
///
/// The `IBStreamTrait` methods receive only `&self`, so both the buffer and the
/// cursor sit in `UnsafeCell`s to allow mutation through a shared reference.
/// Holding `UnsafeCell`s also makes this type `!Sync`: one stream must not be
/// used from several threads at the same time.
pub struct MemoryStream {
    /// Backing bytes; writes beyond the current end grow this buffer.
    data: UnsafeCell<Vec<u8>>,
    /// Byte offset of the cursor used by read/write/seek/tell.
    position: UnsafeCell<usize>,
}
22
23impl Class for MemoryStream {
24    type Interfaces = (vst3::Steinberg::IBStream,);
25}
26
27impl Default for MemoryStream {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl MemoryStream {
34    pub fn new() -> Self {
35        Self {
36            data: UnsafeCell::new(Vec::new()),
37            position: UnsafeCell::new(0),
38        }
39    }
40
41    pub fn from_bytes(data: &[u8]) -> Self {
42        Self {
43            data: UnsafeCell::new(data.to_vec()),
44            position: UnsafeCell::new(0),
45        }
46    }
47
48    pub fn bytes(&self) -> Vec<u8> {
49        unsafe { self.data_ref().clone() }
50    }
51
52    #[allow(clippy::mut_from_ref)]
53    unsafe fn data_mut(&self) -> &mut Vec<u8> {
54        unsafe { &mut *self.data.get() }
55    }
56
57    #[allow(clippy::mut_from_ref)]
58    unsafe fn position_mut(&self) -> &mut usize {
59        unsafe { &mut *self.position.get() }
60    }
61
62    unsafe fn data_ref(&self) -> &Vec<u8> {
63        unsafe { &*self.data.get() }
64    }
65
66    unsafe fn position_ref(&self) -> &usize {
67        unsafe { &*self.position.get() }
68    }
69}
70
71pub fn ibstream_ptr(stream: &ComWrapper<MemoryStream>) -> *mut vst3::Steinberg::IBStream {
72    stream
73        .as_com_ref::<vst3::Steinberg::IBStream>()
74        .map(|r| r.as_ptr())
75        .unwrap_or(std::ptr::null_mut())
76}
77
78impl IBStreamTrait for MemoryStream {
79    unsafe fn read(
80        &self,
81        buffer: *mut std::os::raw::c_void,
82        num_bytes: i32,
83        num_bytes_read: *mut i32,
84    ) -> TResult {
85        if buffer.is_null() || num_bytes < 0 {
86            return kResultFalse;
87        }
88
89        let bytes_to_read = num_bytes as usize;
90        let data = unsafe { self.data_ref() };
91        let position = unsafe { *self.position_ref() };
92        let available = data.len().saturating_sub(position);
93        let actual_read = bytes_to_read.min(available);
94
95        if actual_read == 0 {
96            if !num_bytes_read.is_null() {
97                unsafe {
98                    *num_bytes_read = 0;
99                }
100            }
101            return kResultFalse;
102        }
103
104        let src_slice = &data[position..position + actual_read];
105        let dst_slice = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, actual_read) };
106        dst_slice.copy_from_slice(src_slice);
107
108        unsafe {
109            *self.position_mut() += actual_read;
110        }
111
112        if !num_bytes_read.is_null() {
113            unsafe {
114                *num_bytes_read = actual_read as i32;
115            }
116        }
117
118        kResultOk
119    }
120
121    unsafe fn write(
122        &self,
123        buffer: *mut std::os::raw::c_void,
124        num_bytes: i32,
125        num_bytes_written: *mut i32,
126    ) -> TResult {
127        if buffer.is_null() || num_bytes < 0 {
128            return kResultFalse;
129        }
130
131        let bytes_to_write = num_bytes as usize;
132        let src_slice = unsafe { std::slice::from_raw_parts(buffer as *mut u8, bytes_to_write) };
133
134        let data = unsafe { self.data_mut() };
135        let position = unsafe { *self.position_ref() };
136
137        let required_len = position + bytes_to_write;
138        if required_len > data.len() {
139            data.resize(required_len, 0);
140        }
141
142        data[position..position + bytes_to_write].copy_from_slice(src_slice);
143        unsafe {
144            *self.position_mut() += bytes_to_write;
145        }
146
147        if !num_bytes_written.is_null() {
148            unsafe {
149                *num_bytes_written = bytes_to_write as i32;
150            }
151        }
152
153        kResultOk
154    }
155
156    unsafe fn seek(&self, pos: i64, mode: i32, result: *mut i64) -> TResult {
157        let current_pos = unsafe { *self.position_ref() };
158        let data_len = unsafe { self.data_ref().len() };
159
160        let new_position = match mode {
161            0 => {
162                if pos < 0 {
163                    return kResultFalse;
164                }
165                pos as usize
166            }
167            1 => {
168                if pos < 0 {
169                    current_pos.saturating_sub((-pos) as usize)
170                } else {
171                    current_pos.saturating_add(pos as usize)
172                }
173            }
174            2 => {
175                if pos > 0 {
176                    return kResultFalse;
177                }
178                data_len.saturating_sub((-pos) as usize)
179            }
180            _ => return kResultFalse,
181        };
182
183        unsafe {
184            *self.position_mut() = new_position;
185        }
186
187        if !result.is_null() {
188            unsafe {
189                *result = new_position as i64;
190            }
191        }
192
193        kResultOk
194    }
195
196    unsafe fn tell(&self, pos: *mut i64) -> TResult {
197        if pos.is_null() {
198            return kResultFalse;
199        }
200
201        unsafe {
202            *pos = *self.position_ref() as i64;
203        }
204        kResultOk
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Type-erases a byte slice for the `*mut c_void` buffer parameters of `IBStreamTrait`.
    fn as_void_ptr(bytes: &[u8]) -> *mut std::os::raw::c_void {
        bytes.as_ptr() as *mut std::os::raw::c_void
    }

    #[test]
    fn test_memory_stream_write_read() {
        let stream = MemoryStream::new();
        let test_data = b"Hello, VST3!";

        unsafe {
            let mut written = 0;
            let result = stream.write(as_void_ptr(test_data), test_data.len() as i32, &mut written);
            assert_eq!(result, kResultOk);
            assert_eq!(written, test_data.len() as i32);
        }

        unsafe {
            let mut new_pos = 0;
            assert_eq!(stream.seek(0, 0, &mut new_pos), kResultOk);
            assert_eq!(new_pos, 0);
        }

        let mut read_buffer = vec![0u8; test_data.len()];
        unsafe {
            let mut read_count = 0;
            let result = stream.read(
                read_buffer.as_mut_ptr() as *mut _,
                test_data.len() as i32,
                &mut read_count,
            );
            assert_eq!(result, kResultOk);
            assert_eq!(read_count, test_data.len() as i32);
        }

        assert_eq!(&read_buffer, test_data);
        assert_eq!(stream.bytes(), test_data.to_vec());
    }

    #[test]
    fn test_memory_stream_seek() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        unsafe {
            let mut pos = 0;
            assert_eq!(stream.seek(5, 0, &mut pos), kResultOk);
            assert_eq!(pos, 5);
        }

        unsafe {
            let mut pos = 0;
            assert_eq!(stream.tell(&mut pos), kResultOk);
            assert_eq!(pos, 5);
        }

        unsafe {
            let mut pos = 0;
            assert_eq!(stream.seek(2, 1, &mut pos), kResultOk);
            assert_eq!(pos, 7);
        }

        unsafe {
            let mut pos = 0;
            assert_eq!(stream.seek(-3, 2, &mut pos), kResultOk);
            assert_eq!(pos, 7);
        }
    }

    #[test]
    fn test_memory_stream_read_at_end_reports_zero() {
        let stream = MemoryStream::from_bytes(b"abc");
        let mut buffer = [0u8; 4];

        unsafe {
            // A request larger than the remaining data is a successful partial read.
            let mut read_count = 0;
            let result = stream.read(buffer.as_mut_ptr() as *mut _, 4, &mut read_count);
            assert_eq!(result, kResultOk);
            assert_eq!(read_count, 3);

            // Reading at the end copies nothing and reports 0 bytes.
            read_count = -1;
            let result = stream.read(buffer.as_mut_ptr() as *mut _, 4, &mut read_count);
            assert_eq!(result, kResultFalse);
            assert_eq!(read_count, 0);
        }

        assert_eq!(&buffer[..3], b"abc");
    }

    #[test]
    fn test_memory_stream_rejects_invalid_seeks() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        unsafe {
            let mut pos = 0;
            assert_eq!(stream.seek(4, 0, &mut pos), kResultOk);
            assert_eq!(stream.seek(-1, 0, &mut pos), kResultFalse);
            assert_eq!(stream.seek(1, 2, &mut pos), kResultFalse);
            assert_eq!(stream.seek(0, 3, &mut pos), kResultFalse);

            // Rejected seeks leave the cursor where it was.
            let mut current = 0;
            assert_eq!(stream.tell(&mut current), kResultOk);
            assert_eq!(current, 4);

            // A relative seek before the start clamps to offset 0.
            assert_eq!(stream.seek(-100, 1, &mut pos), kResultOk);
            assert_eq!(pos, 0);

            assert_eq!(stream.tell(std::ptr::null_mut()), kResultFalse);
        }
    }

    #[test]
    fn test_memory_stream_write_past_end_zero_fills() {
        let stream = MemoryStream::from_bytes(b"ab");

        unsafe {
            let mut pos = 0;
            assert_eq!(stream.seek(4, 0, &mut pos), kResultOk);
            let mut written = 0;
            assert_eq!(stream.write(as_void_ptr(b"z"), 1, &mut written), kResultOk);
            assert_eq!(written, 1);
        }

        assert_eq!(stream.bytes(), b"ab\0\0z".to_vec());
    }

    #[test]
    fn test_plugin_state_serialization() {
        let state = Vst3PluginState {
            plugin_id: "com.example.plugin".to_string(),
            component_state: vec![1, 2, 3, 4],
            controller_state: vec![5, 6, 7, 8],
        };

        let json = serde_json::to_string(&state).unwrap();
        let deserialized: Vst3PluginState = serde_json::from_str(&json).unwrap();

        assert_eq!(state.plugin_id, deserialized.plugin_id);
        assert_eq!(state.component_state, deserialized.component_state);
        assert_eq!(state.controller_state, deserialized.controller_state);
    }
}