maolan_engine/plugins/vst3/
state.rs1use serde::{Deserialize, Serialize};
2use std::cell::UnsafeCell;
3use vst3::Steinberg::{IBStreamTrait, kResultFalse, kResultOk};
4use vst3::{Class, ComWrapper};
5
/// Status code returned by the `IBStream` methods below (`kResultOk`, `kResultFalse`).
type TResult = i32;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct Vst3PluginState {
11 pub plugin_id: String,
12 pub component_state: Vec<u8>,
13 pub controller_state: Vec<u8>,
14}
15
/// Growable in-memory byte stream exposed to plugins as a VST3 `IBStream`.
///
/// The `IBStream` methods take `&self`, so both the buffer and the cursor live in
/// `UnsafeCell`s and are mutated through shared references. Containing an
/// `UnsafeCell` also makes this type `!Sync`.
pub struct MemoryStream {
    /// Stream contents. A seek may place `position` beyond `data.len()`.
    data: UnsafeCell<Vec<u8>>,
    /// Read/write cursor, as a byte offset from the start of `data`.
    position: UnsafeCell<usize>,
}
22
23impl Class for MemoryStream {
24 type Interfaces = (vst3::Steinberg::IBStream,);
25}
26
27impl Default for MemoryStream {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl MemoryStream {
34 pub fn new() -> Self {
35 Self {
36 data: UnsafeCell::new(Vec::new()),
37 position: UnsafeCell::new(0),
38 }
39 }
40
41 pub fn from_bytes(data: &[u8]) -> Self {
42 Self {
43 data: UnsafeCell::new(data.to_vec()),
44 position: UnsafeCell::new(0),
45 }
46 }
47
48 pub fn bytes(&self) -> Vec<u8> {
49 unsafe { self.data_ref().clone() }
50 }
51
52 #[allow(clippy::mut_from_ref)]
53 unsafe fn data_mut(&self) -> &mut Vec<u8> {
54 unsafe { &mut *self.data.get() }
55 }
56
57 #[allow(clippy::mut_from_ref)]
58 unsafe fn position_mut(&self) -> &mut usize {
59 unsafe { &mut *self.position.get() }
60 }
61
62 unsafe fn data_ref(&self) -> &Vec<u8> {
63 unsafe { &*self.data.get() }
64 }
65
66 unsafe fn position_ref(&self) -> &usize {
67 unsafe { &*self.position.get() }
68 }
69}
70
71pub fn ibstream_ptr(stream: &ComWrapper<MemoryStream>) -> *mut vst3::Steinberg::IBStream {
72 stream
73 .as_com_ref::<vst3::Steinberg::IBStream>()
74 .map(|r| r.as_ptr())
75 .unwrap_or(std::ptr::null_mut())
76}
77
78impl IBStreamTrait for MemoryStream {
79 unsafe fn read(
80 &self,
81 buffer: *mut std::os::raw::c_void,
82 num_bytes: i32,
83 num_bytes_read: *mut i32,
84 ) -> TResult {
85 if buffer.is_null() || num_bytes < 0 {
86 return kResultFalse;
87 }
88
89 let bytes_to_read = num_bytes as usize;
90 let data = unsafe { self.data_ref() };
91 let position = unsafe { *self.position_ref() };
92 let available = data.len().saturating_sub(position);
93 let actual_read = bytes_to_read.min(available);
94
95 if actual_read == 0 {
96 if !num_bytes_read.is_null() {
97 unsafe {
98 *num_bytes_read = 0;
99 }
100 }
101 return kResultFalse;
102 }
103
104 let src_slice = &data[position..position + actual_read];
105 let dst_slice = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, actual_read) };
106 dst_slice.copy_from_slice(src_slice);
107
108 unsafe {
109 *self.position_mut() += actual_read;
110 }
111
112 if !num_bytes_read.is_null() {
113 unsafe {
114 *num_bytes_read = actual_read as i32;
115 }
116 }
117
118 kResultOk
119 }
120
121 unsafe fn write(
122 &self,
123 buffer: *mut std::os::raw::c_void,
124 num_bytes: i32,
125 num_bytes_written: *mut i32,
126 ) -> TResult {
127 if buffer.is_null() || num_bytes < 0 {
128 return kResultFalse;
129 }
130
131 let bytes_to_write = num_bytes as usize;
132 let src_slice = unsafe { std::slice::from_raw_parts(buffer as *mut u8, bytes_to_write) };
133
134 let data = unsafe { self.data_mut() };
135 let position = unsafe { *self.position_ref() };
136
137 let required_len = position + bytes_to_write;
138 if required_len > data.len() {
139 data.resize(required_len, 0);
140 }
141
142 data[position..position + bytes_to_write].copy_from_slice(src_slice);
143 unsafe {
144 *self.position_mut() += bytes_to_write;
145 }
146
147 if !num_bytes_written.is_null() {
148 unsafe {
149 *num_bytes_written = bytes_to_write as i32;
150 }
151 }
152
153 kResultOk
154 }
155
156 unsafe fn seek(&self, pos: i64, mode: i32, result: *mut i64) -> TResult {
157 let current_pos = unsafe { *self.position_ref() };
158 let data_len = unsafe { self.data_ref().len() };
159
160 let new_position = match mode {
161 0 => {
162 if pos < 0 {
163 return kResultFalse;
164 }
165 pos as usize
166 }
167 1 => {
168 if pos < 0 {
169 current_pos.saturating_sub((-pos) as usize)
170 } else {
171 current_pos.saturating_add(pos as usize)
172 }
173 }
174 2 => {
175 if pos > 0 {
176 return kResultFalse;
177 }
178 data_len.saturating_sub((-pos) as usize)
179 }
180 _ => return kResultFalse,
181 };
182
183 unsafe {
184 *self.position_mut() = new_position;
185 }
186
187 if !result.is_null() {
188 unsafe {
189 *result = new_position as i64;
190 }
191 }
192
193 kResultOk
194 }
195
196 unsafe fn tell(&self, pos: *mut i64) -> TResult {
197 if pos.is_null() {
198 return kResultFalse;
199 }
200
201 unsafe {
202 *pos = *self.position_ref() as i64;
203 }
204 kResultOk
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `bytes` at the stream's position and returns (status, bytes written).
    fn write_bytes(stream: &MemoryStream, bytes: &[u8]) -> (TResult, i32) {
        let mut written = -1;
        // SAFETY: `bytes` is readable for its full length and `written` is a live `i32`.
        let status = unsafe {
            stream.write(
                bytes.as_ptr() as *mut std::os::raw::c_void,
                bytes.len() as i32,
                &mut written,
            )
        };
        (status, written)
    }

    /// Reads into `buffer` from the stream's position and returns (status, bytes read).
    fn read_into(stream: &MemoryStream, buffer: &mut [u8]) -> (TResult, i32) {
        let mut read = -1;
        // SAFETY: `buffer` is writable for its full length and `read` is a live `i32`.
        let status = unsafe {
            stream.read(
                buffer.as_mut_ptr() as *mut std::os::raw::c_void,
                buffer.len() as i32,
                &mut read,
            )
        };
        (status, read)
    }

    /// Seeks and returns (status, reported position). The position is -1 if it was not written.
    fn seek_to(stream: &MemoryStream, pos: i64, mode: i32) -> (TResult, i64) {
        let mut new_pos = -1;
        // SAFETY: `new_pos` is a live `i64`.
        let status = unsafe { stream.seek(pos, mode, &mut new_pos) };
        (status, new_pos)
    }

    /// Returns (status, current position) as reported by `tell`.
    fn tell_pos(stream: &MemoryStream) -> (TResult, i64) {
        let mut pos = -1;
        // SAFETY: `pos` is a live `i64`.
        let status = unsafe { stream.tell(&mut pos) };
        (status, pos)
    }

    #[test]
    fn test_memory_stream_write_read() {
        let stream = MemoryStream::new();
        let test_data = b"Hello, VST3!";

        assert_eq!(write_bytes(&stream, test_data), (kResultOk, test_data.len() as i32));
        assert_eq!(seek_to(&stream, 0, 0), (kResultOk, 0));

        let mut read_buffer = vec![0u8; test_data.len()];
        assert_eq!(read_into(&stream, &mut read_buffer), (kResultOk, test_data.len() as i32));
        assert_eq!(&read_buffer, test_data);
        assert_eq!(stream.bytes(), test_data);
    }

    #[test]
    fn test_memory_stream_seek() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        assert_eq!(seek_to(&stream, 5, 0), (kResultOk, 5));
        assert_eq!(tell_pos(&stream), (kResultOk, 5));
        assert_eq!(seek_to(&stream, 2, 1), (kResultOk, 7));
        assert_eq!(seek_to(&stream, -3, 2), (kResultOk, 7));
        // Relative seeks before the start clamp to 0.
        assert_eq!(seek_to(&stream, -100, 1), (kResultOk, 0));
    }

    #[test]
    fn test_memory_stream_rejects_invalid_seeks() {
        let stream = MemoryStream::from_bytes(b"abc");

        assert_eq!(seek_to(&stream, -1, 0).0, kResultFalse);
        assert_eq!(seek_to(&stream, 1, 2).0, kResultFalse);
        assert_eq!(seek_to(&stream, 0, 3).0, kResultFalse);
        // Failed seeks leave the position unchanged.
        assert_eq!(tell_pos(&stream), (kResultOk, 0));
    }

    #[test]
    fn test_memory_stream_partial_read_and_eof() {
        let stream = MemoryStream::from_bytes(b"abc");
        let mut buffer = [0u8; 8];

        assert_eq!(read_into(&stream, &mut buffer), (kResultOk, 3));
        assert_eq!(&buffer[..3], b"abc");
        // Nothing left: the read fails and reports zero bytes.
        assert_eq!(read_into(&stream, &mut buffer), (kResultFalse, 0));
    }

    #[test]
    fn test_memory_stream_write_past_end_zero_fills() {
        let stream = MemoryStream::from_bytes(b"ab");

        assert_eq!(seek_to(&stream, 4, 0), (kResultOk, 4));
        assert_eq!(write_bytes(&stream, b"z"), (kResultOk, 1));
        assert_eq!(stream.bytes(), b"ab\0\0z");
        assert_eq!(tell_pos(&stream), (kResultOk, 5));
    }

    #[test]
    fn test_plugin_state_serialization() {
        let state = Vst3PluginState {
            plugin_id: "com.example.plugin".to_string(),
            component_state: vec![1, 2, 3, 4],
            controller_state: vec![5, 6, 7, 8],
        };

        let json = serde_json::to_string(&state).unwrap();
        let deserialized: Vst3PluginState = serde_json::from_str(&json).unwrap();

        assert_eq!(state.plugin_id, deserialized.plugin_id);
        assert_eq!(state.component_state, deserialized.component_state);
        assert_eq!(state.controller_state, deserialized.controller_state);
    }
}