// maolan_engine/plugins/vst3/state.rs
use serde::{Deserialize, Serialize};
use std::cell::UnsafeCell;
use vst3::Steinberg::{IBStreamTrait, kResultFalse, kResultOk};
use vst3::{Class, ComWrapper};

/// Result code returned by the `IBStream` callbacks in this file (`kResultOk`, `kResultFalse`).
type TResult = i32;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct Vst3PluginState {
11 pub plugin_id: String,
12 pub component_state: Vec<u8>,
13 pub controller_state: Vec<u8>,
14}
15
/// Growable in-memory byte stream that is handed to VST3 plugins as an `IBStream`.
///
/// Both fields live in `UnsafeCell` because the `IBStreamTrait` callbacks only receive
/// `&self`, yet `read`, `write` and `seek` must mutate the buffer and the cursor.
/// `UnsafeCell` also makes this type `!Sync`, so safe code cannot share one instance
/// across threads.
pub struct MemoryStream {
    /// Backing storage. It is grown (zero-filled) when a write extends past the current end.
    data: UnsafeCell<Vec<u8>>,
    /// Cursor: byte offset of the next read or write. A seek may leave it past `data.len()`.
    position: UnsafeCell<usize>,
}

23impl Class for MemoryStream {
24 type Interfaces = (vst3::Steinberg::IBStream,);
25}
26
27impl Default for MemoryStream {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl MemoryStream {
34 pub fn new() -> Self {
35 Self {
36 data: UnsafeCell::new(Vec::new()),
37 position: UnsafeCell::new(0),
38 }
39 }
40
41 pub fn from_bytes(data: &[u8]) -> Self {
42 Self {
43 data: UnsafeCell::new(data.to_vec()),
44 position: UnsafeCell::new(0),
45 }
46 }
47
48 pub fn bytes(&self) -> Vec<u8> {
49 unsafe { self.data_ref().clone() }
50 }
51
52 #[allow(clippy::mut_from_ref)]
54 unsafe fn data_mut(&self) -> &mut Vec<u8> {
55 unsafe { &mut *self.data.get() }
56 }
57
58 #[allow(clippy::mut_from_ref)]
59 unsafe fn position_mut(&self) -> &mut usize {
60 unsafe { &mut *self.position.get() }
61 }
62
63 unsafe fn data_ref(&self) -> &Vec<u8> {
64 unsafe { &*self.data.get() }
65 }
66
67 unsafe fn position_ref(&self) -> &usize {
68 unsafe { &*self.position.get() }
69 }
70}
71
72pub fn ibstream_ptr(stream: &ComWrapper<MemoryStream>) -> *mut vst3::Steinberg::IBStream {
73 stream
74 .as_com_ref::<vst3::Steinberg::IBStream>()
75 .map(|r| r.as_ptr())
76 .unwrap_or(std::ptr::null_mut())
77}
78
79impl IBStreamTrait for MemoryStream {
80 unsafe fn read(
81 &self,
82 buffer: *mut std::os::raw::c_void,
83 num_bytes: i32,
84 num_bytes_read: *mut i32,
85 ) -> TResult {
86 if buffer.is_null() || num_bytes < 0 {
87 return kResultFalse;
88 }
89
90 let bytes_to_read = num_bytes as usize;
91 let data = unsafe { self.data_ref() };
92 let position = unsafe { *self.position_ref() };
93 let available = data.len().saturating_sub(position);
94 let actual_read = bytes_to_read.min(available);
95
96 if actual_read == 0 {
97 if !num_bytes_read.is_null() {
98 unsafe {
99 *num_bytes_read = 0;
100 }
101 }
102 return kResultFalse;
103 }
104
105 let src_slice = &data[position..position + actual_read];
107 let dst_slice = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, actual_read) };
108 dst_slice.copy_from_slice(src_slice);
109
110 unsafe {
111 *self.position_mut() += actual_read;
112 }
113
114 if !num_bytes_read.is_null() {
115 unsafe {
116 *num_bytes_read = actual_read as i32;
117 }
118 }
119
120 kResultOk
121 }
122
123 unsafe fn write(
124 &self,
125 buffer: *mut std::os::raw::c_void,
126 num_bytes: i32,
127 num_bytes_written: *mut i32,
128 ) -> TResult {
129 if buffer.is_null() || num_bytes < 0 {
130 return kResultFalse;
131 }
132
133 let bytes_to_write = num_bytes as usize;
134 let src_slice = unsafe { std::slice::from_raw_parts(buffer as *mut u8, bytes_to_write) };
135
136 let data = unsafe { self.data_mut() };
137 let position = unsafe { *self.position_ref() };
138
139 let required_len = position + bytes_to_write;
141 if required_len > data.len() {
142 data.resize(required_len, 0);
143 }
144
145 data[position..position + bytes_to_write].copy_from_slice(src_slice);
147 unsafe {
148 *self.position_mut() += bytes_to_write;
149 }
150
151 if !num_bytes_written.is_null() {
152 unsafe {
153 *num_bytes_written = bytes_to_write as i32;
154 }
155 }
156
157 kResultOk
158 }
159
160 unsafe fn seek(&self, pos: i64, mode: i32, result: *mut i64) -> TResult {
161 let current_pos = unsafe { *self.position_ref() };
162 let data_len = unsafe { self.data_ref().len() };
163
164 let new_position = match mode {
165 0 => {
166 if pos < 0 {
168 return kResultFalse;
169 }
170 pos as usize
171 }
172 1 => {
173 if pos < 0 {
175 current_pos.saturating_sub((-pos) as usize)
176 } else {
177 current_pos.saturating_add(pos as usize)
178 }
179 }
180 2 => {
181 if pos > 0 {
183 return kResultFalse;
184 }
185 data_len.saturating_sub((-pos) as usize)
186 }
187 _ => return kResultFalse,
188 };
189
190 unsafe {
191 *self.position_mut() = new_position;
192 }
193
194 if !result.is_null() {
195 unsafe {
196 *result = new_position as i64;
197 }
198 }
199
200 kResultOk
201 }
202
203 unsafe fn tell(&self, pos: *mut i64) -> TResult {
204 if pos.is_null() {
205 return kResultFalse;
206 }
207
208 unsafe {
209 *pos = *self.position_ref() as i64;
210 }
211 kResultOk
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes a payload, rewinds to offset 0, and reads it back unchanged.
    #[test]
    fn test_memory_stream_write_read() {
        let stream = MemoryStream::new();
        let test_data = b"Hello, VST3!";

        unsafe {
            let mut written = 0;
            let result = stream.write(
                test_data.as_ptr() as *mut std::os::raw::c_void,
                test_data.len() as i32,
                &mut written,
            );
            assert_eq!(result, kResultOk);
            assert_eq!(written, test_data.len() as i32);
        }

        unsafe {
            // Start from a sentinel so the assertion proves `seek` actually wrote the result.
            let mut new_pos = -1;
            assert_eq!(stream.seek(0, 0, &mut new_pos), kResultOk);
            assert_eq!(new_pos, 0);
        }

        let mut read_buffer = vec![0u8; test_data.len()];
        unsafe {
            let mut read_count = 0;
            let result = stream.read(
                read_buffer.as_mut_ptr() as *mut _,
                test_data.len() as i32,
                &mut read_count,
            );
            assert_eq!(result, kResultOk);
            assert_eq!(read_count, test_data.len() as i32);
        }

        assert_eq!(&read_buffer, test_data);
        assert_eq!(stream.bytes(), test_data.to_vec());
    }

    /// Covers all three seek origins plus `tell`.
    #[test]
    fn test_memory_stream_seek() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        unsafe {
            let mut pos = -1;
            assert_eq!(stream.seek(5, 0, &mut pos), kResultOk);
            assert_eq!(pos, 5);
        }

        unsafe {
            let mut pos = -1;
            assert_eq!(stream.tell(&mut pos), kResultOk);
            assert_eq!(pos, 5);
        }

        unsafe {
            let mut pos = -1;
            assert_eq!(stream.seek(2, 1, &mut pos), kResultOk);
            assert_eq!(pos, 7);
        }

        unsafe {
            let mut pos = -1;
            assert_eq!(stream.seek(-3, 2, &mut pos), kResultOk);
            assert_eq!(pos, 7);
        }
    }

    /// Invalid seeks are rejected without touching `result`; relative seeks clamp at 0.
    #[test]
    fn test_memory_stream_seek_rejects_and_clamps() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        unsafe {
            let mut pos = -1;
            assert_eq!(stream.seek(-1, 0, &mut pos), kResultFalse);
            assert_eq!(stream.seek(1, 2, &mut pos), kResultFalse);
            assert_eq!(stream.seek(0, 3, &mut pos), kResultFalse);
            assert_eq!(pos, -1);

            assert_eq!(stream.seek(-100, 1, &mut pos), kResultOk);
            assert_eq!(pos, 0);
            assert_eq!(stream.seek(-100, 2, &mut pos), kResultOk);
            assert_eq!(pos, 0);
        }
    }

    /// A partial read succeeds; a read at the end reports 0 bytes and `kResultFalse`.
    #[test]
    fn test_memory_stream_read_at_end() {
        let stream = MemoryStream::from_bytes(b"ab");
        let mut buf = [0u8; 4];

        unsafe {
            let mut count = -1;
            assert_eq!(stream.read(buf.as_mut_ptr() as *mut _, 4, &mut count), kResultOk);
            assert_eq!(count, 2);

            let mut count = -1;
            assert_eq!(stream.read(buf.as_mut_ptr() as *mut _, 4, &mut count), kResultFalse);
            assert_eq!(count, 0);
        }

        assert_eq!(&buf[..2], b"ab");
    }

    /// Writing after seeking past the end zero-fills the gap and advances the cursor.
    #[test]
    fn test_memory_stream_write_past_end_zero_fills() {
        let stream = MemoryStream::new();
        let byte = [7u8];

        unsafe {
            let mut pos = -1;
            assert_eq!(stream.seek(3, 0, &mut pos), kResultOk);
            assert_eq!(pos, 3);

            let mut written = 0;
            let result = stream.write(
                byte.as_ptr() as *mut std::os::raw::c_void,
                1,
                &mut written,
            );
            assert_eq!(result, kResultOk);
            assert_eq!(written, 1);

            let mut pos = -1;
            assert_eq!(stream.tell(&mut pos), kResultOk);
            assert_eq!(pos, 4);
        }

        assert_eq!(stream.bytes(), vec![0u8, 0, 0, 7]);
    }

    /// `Vst3PluginState` survives a JSON round trip.
    #[test]
    fn test_plugin_state_serialization() {
        let state = Vst3PluginState {
            plugin_id: "com.example.plugin".to_string(),
            component_state: vec![1, 2, 3, 4],
            controller_state: vec![5, 6, 7, 8],
        };

        let json = serde_json::to_string(&state).unwrap();
        let deserialized: Vst3PluginState = serde_json::from_str(&json).unwrap();

        assert_eq!(state.plugin_id, deserialized.plugin_id);
        assert_eq!(state.component_state, deserialized.component_state);
        assert_eq!(state.controller_state, deserialized.controller_state);
    }
}