Skip to main content

maolan_plugin_host/vst3/
state.rs

1use serde::{Deserialize, Serialize};
2use std::cell::UnsafeCell;
3use vst3::Steinberg::{IBStreamTrait, kResultFalse, kResultOk};
4use vst3::{Class, ComWrapper};
5
6type TResult = i32;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
9pub struct Vst3PluginState {
10    pub plugin_id: String,
11    pub component_state: Vec<u8>,
12    pub controller_state: Vec<u8>,
13}
14
/// Growable in-memory byte buffer exposed to plugins as a VST3 `IBStream`.
///
/// The `IBStreamTrait` callbacks only receive `&self`, so both the buffer and
/// the cursor live in `UnsafeCell`s to permit mutation through a shared
/// reference. Nothing in this type synchronizes access.
pub struct MemoryStream {
    /// Current read/write cursor as a byte offset. After a seek it may point
    /// past the end of `data`.
    position: UnsafeCell<usize>,
    /// Backing bytes. Writes that run past the end grow this buffer.
    data: UnsafeCell<Vec<u8>>,
}
19
20impl Class for MemoryStream {
21    type Interfaces = (vst3::Steinberg::IBStream,);
22}
23
24impl Default for MemoryStream {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl MemoryStream {
31    pub fn new() -> Self {
32        Self {
33            data: UnsafeCell::new(Vec::new()),
34            position: UnsafeCell::new(0),
35        }
36    }
37
38    pub fn from_bytes(data: &[u8]) -> Self {
39        Self {
40            data: UnsafeCell::new(data.to_vec()),
41            position: UnsafeCell::new(0),
42        }
43    }
44
45    pub fn bytes(&self) -> Vec<u8> {
46        unsafe { self.data_ref().clone() }
47    }
48
49    #[allow(clippy::mut_from_ref)]
50    unsafe fn data_mut(&self) -> &mut Vec<u8> {
51        unsafe { &mut *self.data.get() }
52    }
53
54    #[allow(clippy::mut_from_ref)]
55    unsafe fn position_mut(&self) -> &mut usize {
56        unsafe { &mut *self.position.get() }
57    }
58
59    unsafe fn data_ref(&self) -> &Vec<u8> {
60        unsafe { &*self.data.get() }
61    }
62
63    unsafe fn position_ref(&self) -> &usize {
64        unsafe { &*self.position.get() }
65    }
66}
67
68pub fn ibstream_ptr(stream: &ComWrapper<MemoryStream>) -> *mut vst3::Steinberg::IBStream {
69    stream
70        .as_com_ref::<vst3::Steinberg::IBStream>()
71        .map(|r| r.as_ptr())
72        .unwrap_or(std::ptr::null_mut())
73}
74
75impl IBStreamTrait for MemoryStream {
76    unsafe fn read(
77        &self,
78        buffer: *mut std::os::raw::c_void,
79        num_bytes: i32,
80        num_bytes_read: *mut i32,
81    ) -> TResult {
82        if buffer.is_null() || num_bytes < 0 {
83            return kResultFalse;
84        }
85
86        let bytes_to_read = num_bytes as usize;
87        let data = unsafe { self.data_ref() };
88        let position = unsafe { *self.position_ref() };
89        let available = data.len().saturating_sub(position);
90        let actual_read = bytes_to_read.min(available);
91
92        if actual_read == 0 {
93            if !num_bytes_read.is_null() {
94                unsafe {
95                    *num_bytes_read = 0;
96                }
97            }
98            return kResultFalse;
99        }
100
101        let src_slice = &data[position..position + actual_read];
102        let dst_slice = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, actual_read) };
103        dst_slice.copy_from_slice(src_slice);
104
105        unsafe {
106            *self.position_mut() += actual_read;
107        }
108
109        if !num_bytes_read.is_null() {
110            unsafe {
111                *num_bytes_read = actual_read as i32;
112            }
113        }
114
115        kResultOk
116    }
117
118    unsafe fn write(
119        &self,
120        buffer: *mut std::os::raw::c_void,
121        num_bytes: i32,
122        num_bytes_written: *mut i32,
123    ) -> TResult {
124        if buffer.is_null() || num_bytes < 0 {
125            return kResultFalse;
126        }
127
128        let bytes_to_write = num_bytes as usize;
129        let src_slice = unsafe { std::slice::from_raw_parts(buffer as *mut u8, bytes_to_write) };
130
131        let data = unsafe { self.data_mut() };
132        let position = unsafe { *self.position_ref() };
133
134        let required_len = position + bytes_to_write;
135        if required_len > data.len() {
136            data.resize(required_len, 0);
137        }
138
139        data[position..position + bytes_to_write].copy_from_slice(src_slice);
140        unsafe {
141            *self.position_mut() += bytes_to_write;
142        }
143
144        if !num_bytes_written.is_null() {
145            unsafe {
146                *num_bytes_written = bytes_to_write as i32;
147            }
148        }
149
150        kResultOk
151    }
152
153    unsafe fn seek(&self, pos: i64, mode: i32, result: *mut i64) -> TResult {
154        let current_pos = unsafe { *self.position_ref() };
155        let data_len = unsafe { self.data_ref().len() };
156
157        let new_position = match mode {
158            0 => {
159                if pos < 0 {
160                    return kResultFalse;
161                }
162                pos as usize
163            }
164            1 => {
165                if pos < 0 {
166                    current_pos.saturating_sub((-pos) as usize)
167                } else {
168                    current_pos.saturating_add(pos as usize)
169                }
170            }
171            2 => {
172                if pos > 0 {
173                    return kResultFalse;
174                }
175                data_len.saturating_sub((-pos) as usize)
176            }
177            _ => return kResultFalse,
178        };
179
180        unsafe {
181            *self.position_mut() = new_position;
182        }
183
184        if !result.is_null() {
185            unsafe {
186                *result = new_position as i64;
187            }
188        }
189
190        kResultOk
191    }
192
193    unsafe fn tell(&self, pos: *mut i64) -> TResult {
194        if pos.is_null() {
195            return kResultFalse;
196        }
197
198        unsafe {
199            *pos = *self.position_ref() as i64;
200        }
201        kResultOk
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips bytes through write, rewind and read.
    #[test]
    fn test_memory_stream_write_read() {
        let stream = MemoryStream::new();
        let test_data = b"Hello, VST3!";

        unsafe {
            let mut written: i32 = 0;
            let result = stream.write(
                test_data.as_ptr() as *mut std::os::raw::c_void,
                test_data.len() as i32,
                &mut written,
            );
            assert_eq!(result, kResultOk);
            assert_eq!(written, test_data.len() as i32);
        }

        assert_eq!(stream.bytes(), test_data.to_vec());

        unsafe {
            let mut new_pos: i64 = -1;
            assert_eq!(stream.seek(0, 0, &mut new_pos), kResultOk);
            assert_eq!(new_pos, 0);
        }

        let mut read_buffer = vec![0u8; test_data.len()];
        unsafe {
            let mut read_count: i32 = 0;
            let result = stream.read(
                read_buffer.as_mut_ptr() as *mut _,
                test_data.len() as i32,
                &mut read_count,
            );
            assert_eq!(result, kResultOk);
            assert_eq!(read_count, test_data.len() as i32);
        }

        assert_eq!(&read_buffer, test_data);
    }

    /// Exercises all three seek modes plus tell, checking return codes too.
    #[test]
    fn test_memory_stream_seek() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        unsafe {
            let mut pos: i64 = 0;
            assert_eq!(stream.seek(5, 0, &mut pos), kResultOk);
            assert_eq!(pos, 5);
        }

        unsafe {
            let mut pos: i64 = 0;
            assert_eq!(stream.tell(&mut pos), kResultOk);
            assert_eq!(pos, 5);
        }

        unsafe {
            let mut pos: i64 = 0;
            assert_eq!(stream.seek(2, 1, &mut pos), kResultOk);
            assert_eq!(pos, 7);
        }

        unsafe {
            let mut pos: i64 = 0;
            assert_eq!(stream.seek(-3, 2, &mut pos), kResultOk);
            assert_eq!(pos, 7);
        }

        unsafe {
            // Seeking backwards past the start clamps to 0.
            let mut pos: i64 = -1;
            assert_eq!(stream.seek(-100, 1, &mut pos), kResultOk);
            assert_eq!(pos, 0);
        }
    }

    /// Invalid seeks are rejected and leave the cursor untouched.
    #[test]
    fn test_memory_stream_rejects_invalid_seeks() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        unsafe {
            let mut pos: i64 = 0;
            assert_eq!(stream.seek(-1, 0, &mut pos), kResultFalse);
            assert_eq!(stream.seek(1, 2, &mut pos), kResultFalse);
            assert_eq!(stream.seek(0, 3, &mut pos), kResultFalse);
            assert_eq!(stream.tell(std::ptr::null_mut()), kResultFalse);

            assert_eq!(stream.tell(&mut pos), kResultOk);
            assert_eq!(pos, 0);
        }
    }

    /// A short read succeeds; a read at the end fails and reports 0 bytes.
    #[test]
    fn test_memory_stream_read_past_end() {
        let stream = MemoryStream::from_bytes(b"abc");
        let mut buf = [0u8; 8];

        unsafe {
            let mut count: i32 = -1;
            assert_eq!(
                stream.read(buf.as_mut_ptr() as *mut _, buf.len() as i32, &mut count),
                kResultOk
            );
            assert_eq!(count, 3);

            assert_eq!(
                stream.read(buf.as_mut_ptr() as *mut _, buf.len() as i32, &mut count),
                kResultFalse
            );
            assert_eq!(count, 0);
        }

        assert_eq!(&buf[..3], b"abc");
    }

    /// Writing after a seek past the end zero-fills the gap.
    #[test]
    fn test_memory_stream_write_after_seek_past_end_zero_fills() {
        let stream = MemoryStream::from_bytes(b"ab");

        unsafe {
            let mut pos: i64 = 0;
            assert_eq!(stream.seek(4, 0, &mut pos), kResultOk);
            assert_eq!(pos, 4);

            let payload = b"z";
            let mut written: i32 = 0;
            assert_eq!(
                stream.write(
                    payload.as_ptr() as *mut std::os::raw::c_void,
                    payload.len() as i32,
                    &mut written,
                ),
                kResultOk
            );
            assert_eq!(written, 1);
        }

        assert_eq!(stream.bytes(), b"ab\0\0z".to_vec());
    }

    /// The plugin state survives a JSON round trip unchanged.
    #[test]
    fn test_plugin_state_serialization() {
        let state = Vst3PluginState {
            plugin_id: "com.example.plugin".to_string(),
            component_state: vec![1, 2, 3, 4],
            controller_state: vec![5, 6, 7, 8],
        };

        let json = serde_json::to_string(&state).unwrap();
        let deserialized: Vst3PluginState = serde_json::from_str(&json).unwrap();

        assert_eq!(state.plugin_id, deserialized.plugin_id);
        assert_eq!(state.component_state, deserialized.component_state);
        assert_eq!(state.controller_state, deserialized.controller_state);
    }
}