maolan_plugin_host/vst3/
state.rs1use serde::{Deserialize, Serialize};
2use std::cell::UnsafeCell;
3use vst3::Steinberg::{IBStreamTrait, kResultFalse, kResultOk};
4use vst3::{Class, ComWrapper};
5
/// Result code returned by the `IBStreamTrait` methods in this module
/// (`kResultOk` on success, `kResultFalse` on failure).
type TResult = std::primitive::i32;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
9pub struct Vst3PluginState {
10 pub plugin_id: String,
11 pub component_state: Vec<u8>,
12 pub controller_state: Vec<u8>,
13}
14
/// Growable, seekable in-memory byte buffer exposed through the `IBStream`
/// COM interface.
///
/// The stream methods receive only `&self`, so both the cursor and the buffer
/// live in [`UnsafeCell`]s. As a consequence the type is not `Sync`.
pub struct MemoryStream {
    /// Current read/write offset; `seek` may move it past the end of `data`.
    position: UnsafeCell<usize>,
    /// Backing storage for the stream contents.
    data: UnsafeCell<Vec<u8>>,
}
19
20impl Class for MemoryStream {
21 type Interfaces = (vst3::Steinberg::IBStream,);
22}
23
24impl Default for MemoryStream {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl MemoryStream {
31 pub fn new() -> Self {
32 Self {
33 data: UnsafeCell::new(Vec::new()),
34 position: UnsafeCell::new(0),
35 }
36 }
37
38 pub fn from_bytes(data: &[u8]) -> Self {
39 Self {
40 data: UnsafeCell::new(data.to_vec()),
41 position: UnsafeCell::new(0),
42 }
43 }
44
45 pub fn bytes(&self) -> Vec<u8> {
46 unsafe { self.data_ref().clone() }
47 }
48
49 #[allow(clippy::mut_from_ref)]
50 unsafe fn data_mut(&self) -> &mut Vec<u8> {
51 unsafe { &mut *self.data.get() }
52 }
53
54 #[allow(clippy::mut_from_ref)]
55 unsafe fn position_mut(&self) -> &mut usize {
56 unsafe { &mut *self.position.get() }
57 }
58
59 unsafe fn data_ref(&self) -> &Vec<u8> {
60 unsafe { &*self.data.get() }
61 }
62
63 unsafe fn position_ref(&self) -> &usize {
64 unsafe { &*self.position.get() }
65 }
66}
67
68pub fn ibstream_ptr(stream: &ComWrapper<MemoryStream>) -> *mut vst3::Steinberg::IBStream {
69 stream
70 .as_com_ref::<vst3::Steinberg::IBStream>()
71 .map(|r| r.as_ptr())
72 .unwrap_or(std::ptr::null_mut())
73}
74
75impl IBStreamTrait for MemoryStream {
76 unsafe fn read(
77 &self,
78 buffer: *mut std::os::raw::c_void,
79 num_bytes: i32,
80 num_bytes_read: *mut i32,
81 ) -> TResult {
82 if buffer.is_null() || num_bytes < 0 {
83 return kResultFalse;
84 }
85
86 let bytes_to_read = num_bytes as usize;
87 let data = unsafe { self.data_ref() };
88 let position = unsafe { *self.position_ref() };
89 let available = data.len().saturating_sub(position);
90 let actual_read = bytes_to_read.min(available);
91
92 if actual_read == 0 {
93 if !num_bytes_read.is_null() {
94 unsafe {
95 *num_bytes_read = 0;
96 }
97 }
98 return kResultFalse;
99 }
100
101 let src_slice = &data[position..position + actual_read];
102 let dst_slice = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, actual_read) };
103 dst_slice.copy_from_slice(src_slice);
104
105 unsafe {
106 *self.position_mut() += actual_read;
107 }
108
109 if !num_bytes_read.is_null() {
110 unsafe {
111 *num_bytes_read = actual_read as i32;
112 }
113 }
114
115 kResultOk
116 }
117
118 unsafe fn write(
119 &self,
120 buffer: *mut std::os::raw::c_void,
121 num_bytes: i32,
122 num_bytes_written: *mut i32,
123 ) -> TResult {
124 if buffer.is_null() || num_bytes < 0 {
125 return kResultFalse;
126 }
127
128 let bytes_to_write = num_bytes as usize;
129 let src_slice = unsafe { std::slice::from_raw_parts(buffer as *mut u8, bytes_to_write) };
130
131 let data = unsafe { self.data_mut() };
132 let position = unsafe { *self.position_ref() };
133
134 let required_len = position + bytes_to_write;
135 if required_len > data.len() {
136 data.resize(required_len, 0);
137 }
138
139 data[position..position + bytes_to_write].copy_from_slice(src_slice);
140 unsafe {
141 *self.position_mut() += bytes_to_write;
142 }
143
144 if !num_bytes_written.is_null() {
145 unsafe {
146 *num_bytes_written = bytes_to_write as i32;
147 }
148 }
149
150 kResultOk
151 }
152
153 unsafe fn seek(&self, pos: i64, mode: i32, result: *mut i64) -> TResult {
154 let current_pos = unsafe { *self.position_ref() };
155 let data_len = unsafe { self.data_ref().len() };
156
157 let new_position = match mode {
158 0 => {
159 if pos < 0 {
160 return kResultFalse;
161 }
162 pos as usize
163 }
164 1 => {
165 if pos < 0 {
166 current_pos.saturating_sub((-pos) as usize)
167 } else {
168 current_pos.saturating_add(pos as usize)
169 }
170 }
171 2 => {
172 if pos > 0 {
173 return kResultFalse;
174 }
175 data_len.saturating_sub((-pos) as usize)
176 }
177 _ => return kResultFalse,
178 };
179
180 unsafe {
181 *self.position_mut() = new_position;
182 }
183
184 if !result.is_null() {
185 unsafe {
186 *result = new_position as i64;
187 }
188 }
189
190 kResultOk
191 }
192
193 unsafe fn tell(&self, pos: *mut i64) -> TResult {
194 if pos.is_null() {
195 return kResultFalse;
196 }
197
198 unsafe {
199 *pos = *self.position_ref() as i64;
200 }
201 kResultOk
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes all of `bytes` at the cursor and checks the reported result and count.
    fn write_all(stream: &MemoryStream, bytes: &[u8]) {
        let mut written = 0;
        let result = unsafe {
            stream.write(
                bytes.as_ptr() as *mut std::os::raw::c_void,
                bytes.len() as i32,
                &mut written,
            )
        };
        assert_eq!(result, kResultOk);
        assert_eq!(written, bytes.len() as i32);
    }

    /// Seeks and returns `(result code, reported position)`. The reported
    /// position stays -1 if the stream did not write it.
    fn seek_to(stream: &MemoryStream, pos: i64, mode: i32) -> (TResult, i64) {
        let mut new_pos = -1;
        let result = unsafe { stream.seek(pos, mode, &mut new_pos) };
        (result, new_pos)
    }

    /// Returns the cursor as reported by `tell`, asserting that the call succeeded.
    fn tell(stream: &MemoryStream) -> i64 {
        let mut pos = -1;
        assert_eq!(unsafe { stream.tell(&mut pos) }, kResultOk);
        pos
    }

    /// Reads up to `len` bytes and returns `(result code, bytes read, reported count)`.
    fn read_up_to(stream: &MemoryStream, len: usize) -> (TResult, Vec<u8>, i32) {
        let mut buffer = vec![0u8; len];
        let mut count = -1;
        let result = unsafe { stream.read(buffer.as_mut_ptr() as *mut _, len as i32, &mut count) };
        buffer.truncate(count.max(0) as usize);
        (result, buffer, count)
    }

    #[test]
    fn test_memory_stream_write_read() {
        let stream = MemoryStream::new();
        let test_data = b"Hello, VST3!";

        write_all(&stream, test_data);
        assert_eq!(stream.bytes(), test_data.to_vec());

        assert_eq!(seek_to(&stream, 0, 0), (kResultOk, 0));
        let (result, bytes, count) = read_up_to(&stream, test_data.len());
        assert_eq!(result, kResultOk);
        assert_eq!(count, test_data.len() as i32);
        assert_eq!(bytes, test_data.to_vec());
    }

    #[test]
    fn test_memory_stream_read_past_end() {
        let stream = MemoryStream::from_bytes(b"0123456789");
        assert_eq!(seek_to(&stream, 8, 0), (kResultOk, 8));

        // Only the remaining two bytes are returned.
        let (result, bytes, count) = read_up_to(&stream, 5);
        assert_eq!(result, kResultOk);
        assert_eq!(count, 2);
        assert_eq!(bytes, b"89".to_vec());

        // At the end of the data nothing more can be read.
        let (_, bytes, count) = read_up_to(&stream, 5);
        assert_eq!(count, 0);
        assert!(bytes.is_empty());
        assert_eq!(tell(&stream), 10);
    }

    #[test]
    fn test_memory_stream_seek() {
        let stream = MemoryStream::from_bytes(b"0123456789");

        assert_eq!(seek_to(&stream, 5, 0), (kResultOk, 5));
        assert_eq!(tell(&stream), 5);
        assert_eq!(seek_to(&stream, 2, 1), (kResultOk, 7));
        assert_eq!(seek_to(&stream, -3, 2), (kResultOk, 7));
        // Relative seeks before the start clamp to offset 0.
        assert_eq!(seek_to(&stream, -100, 1), (kResultOk, 0));
    }

    #[test]
    fn test_memory_stream_seek_rejects_invalid_requests() {
        let stream = MemoryStream::from_bytes(b"0123456789");
        assert_eq!(seek_to(&stream, 4, 0).0, kResultOk);

        assert_eq!(seek_to(&stream, -1, 0).0, kResultFalse);
        assert_eq!(seek_to(&stream, 1, 2).0, kResultFalse);
        assert_eq!(seek_to(&stream, 0, 3).0, kResultFalse);

        // Failed seeks leave the cursor untouched.
        assert_eq!(tell(&stream), 4);
    }

    #[test]
    fn test_memory_stream_write_past_end_zero_fills() {
        let stream = MemoryStream::new();
        assert_eq!(seek_to(&stream, 3, 0), (kResultOk, 3));

        write_all(&stream, b"ab");
        assert_eq!(stream.bytes(), vec![0, 0, 0, b'a', b'b']);
        assert_eq!(tell(&stream), 5);
    }

    #[test]
    fn test_memory_stream_overwrite_in_place() {
        let stream = MemoryStream::from_bytes(b"0123456789");
        assert_eq!(seek_to(&stream, 2, 0), (kResultOk, 2));

        write_all(&stream, b"ab");
        assert_eq!(stream.bytes(), b"01ab456789".to_vec());
        assert_eq!(tell(&stream), 4);
    }

    #[test]
    fn test_plugin_state_serialization() {
        let state = Vst3PluginState {
            plugin_id: "com.example.plugin".to_string(),
            component_state: vec![1, 2, 3, 4],
            controller_state: vec![5, 6, 7, 8],
        };

        let json = serde_json::to_string(&state).unwrap();
        let deserialized: Vst3PluginState = serde_json::from_str(&json).unwrap();

        assert_eq!(state.plugin_id, deserialized.plugin_id);
        assert_eq!(state.component_state, deserialized.component_state);
        assert_eq!(state.controller_state, deserialized.controller_state);
    }
}
290}