1use std::{
16 collections::HashMap,
17 ffi::{c_char, c_void, CString},
18 pin::Pin,
19 sync::{Arc, LazyLock},
20 task::{Context, Poll},
21 time::Duration,
22};
23
24use futures_util::Stream;
25use libloading::{Library, Symbol};
26use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
27use parking_lot::RwLock;
28use serde::Serialize;
29use serde_json::json;
30
31#[derive(Debug, thiserror::Error)]
32pub enum PluginError {
33 #[error("dylib error: {0}")]
34 Library(#[from] libloading::Error),
35 #[error("dylib error: {0}")]
36 NotImplemented(String),
37 #[error("on_load failed with error: {0}")]
38 OnLoad(i32),
39}
40
/// `audio_filter_on_load`: receives a JSON options string, returns `0` on success.
type OnLoadFn = unsafe extern "C" fn(options: *const c_char) -> i32;
/// `audio_filter_create`: creates a session that uses one sample rate.
type CreateFn = unsafe extern "C" fn(
    sampling_rate: u32,
    options: *const c_char,
    stream_info: *const c_char,
) -> *mut c_void;
/// `audio_filter_create_v2`: creates a session with separate input and output sample rates.
type CreateV2Fn = unsafe extern "C" fn(
    input_sample_rate: u32,
    output_sample_rate: u32,
    options: *const c_char,
    stream_info: *const c_char,
) -> *mut c_void;
/// `audio_filter_destroy`: releases a session handle.
type DestroyFn = unsafe extern "C" fn(session: *const c_void);
/// `audio_filter_process_int16`: one sample count for the input and output buffers.
type ProcessI16Fn = unsafe extern "C" fn(
    session: *const c_void,
    num_samples: usize,
    input: *const i16,
    output: *mut i16,
);
/// `audio_filter_process_int16_v2`: independent input and output sample counts.
type ProcessI16V2Fn = unsafe extern "C" fn(
    session: *const c_void,
    in_num_samples: usize,
    input: *const i16,
    out_num_samples: usize,
    output: *mut i16,
);
/// `audio_filter_process_float`: one sample count for the input and output buffers.
type ProcessF32Fn = unsafe extern "C" fn(
    session: *const c_void,
    num_samples: usize,
    input: *const f32,
    output: *mut f32,
);
/// `audio_filter_process_float_v2`: independent input and output sample counts.
type ProcessF32V2Fn = unsafe extern "C" fn(
    session: *const c_void,
    in_num_samples: usize,
    input: *const f32,
    out_num_samples: usize,
    output: *mut f32,
);
/// `audio_filter_update_stream_info`: hands a session a JSON stream description.
type UpdateStreamInfoFn = unsafe extern "C" fn(session: *const c_void, stream_info: *const c_char);
/// `audio_filter_update_token`: passes a URL and token to the plugin.
type UpdateRefreshedTokenFn = unsafe extern "C" fn(url: *const c_char, token: *const c_char);
60
61static REGISTERED_PLUGINS: LazyLock<RwLock<HashMap<String, Arc<AudioFilterPlugin>>>> =
62 LazyLock::new(|| RwLock::new(HashMap::new()));
63
64pub fn register_audio_filter_plugin(id: String, plugin: Arc<AudioFilterPlugin>) {
65 REGISTERED_PLUGINS.write().insert(id, plugin);
66}
67
68pub fn registered_audio_filter_plugin(id: &str) -> Option<Arc<AudioFilterPlugin>> {
69 REGISTERED_PLUGINS.read().get(id).cloned()
70}
71
72pub fn registered_audio_filter_plugins() -> Vec<Arc<AudioFilterPlugin>> {
73 REGISTERED_PLUGINS.read().values().map(|v| v.clone()).collect()
74}
75
/// A loaded audio filter plugin: the library handle together with the raw addresses of
/// the entry points resolved from it.
///
/// Optional entry points that the library does not export are stored as null pointers,
/// so every call site checks for null before calling through a pointer.
pub struct AudioFilterPlugin {
    /// Handle of the plugin library. It must stay alive for as long as any of the
    /// pointers below may be called.
    lib: Library,
    /// Additional libraries loaded by `new_with_dependencies`; empty otherwise.
    /// Fields drop in declaration order, so `lib` is released before these.
    dependencies: Vec<Library>,
    /// Address of `audio_filter_on_load`.
    on_load_fn_ptr: *const c_void,
    /// Address of `audio_filter_create`.
    create_fn_ptr: *const c_void,
    /// Address of `audio_filter_create_v2`, or null when not exported.
    create_v2_fn_ptr: *const c_void,
    /// Address of `audio_filter_destroy`.
    destroy_fn_ptr: *const c_void,
    /// Address of `audio_filter_process_int16`.
    process_i16_fn_ptr: *const c_void,
    /// Address of `audio_filter_process_int16_v2`, or null when not exported.
    process_i16_v2_fn_ptr: *const c_void,
    /// Address of `audio_filter_process_float`.
    process_f32_fn_ptr: *const c_void,
    /// Address of `audio_filter_process_float_v2`, or null when not exported.
    process_f32_v2_fn_ptr: *const c_void,
    /// Address of `audio_filter_update_stream_info`.
    update_stream_info_fn_ptr: *const c_void,
    /// Address of `audio_filter_update_token`, or null when not exported.
    update_token_fn_ptr: *const c_void,
}
90
91impl AudioFilterPlugin {
92 pub fn new<P: AsRef<str>>(path: P) -> Result<Arc<Self>, PluginError> {
93 Ok(Arc::new(Self::_new(path)?))
94 }
95
96 pub fn new_with_dependencies<P: AsRef<str>>(
97 path: P,
98 dependencies: Vec<P>,
99 ) -> Result<Arc<Self>, PluginError> {
100 let mut libs = vec![];
101 for path in dependencies {
102 let lib = unsafe { Library::new(path.as_ref()) }?;
103 libs.push(lib);
104 }
105 let mut this = Self::_new(path)?;
106 this.dependencies = libs;
107 Ok(Arc::new(this))
108 }
109
110 fn _new<P: AsRef<str>>(path: P) -> Result<Self, PluginError> {
111 let lib = unsafe { Library::new(path.as_ref()) }?;
112
113 let on_load_fn_ptr = unsafe {
114 lib.get::<Symbol<OnLoadFn>>(b"audio_filter_on_load")?.try_as_raw_ptr().unwrap()
115 };
116
117 let create_fn_ptr = unsafe {
118 lib.get::<Symbol<CreateFn>>(b"audio_filter_create")?.try_as_raw_ptr().unwrap()
119 };
120 let create_v2_fn_ptr = unsafe {
121 match lib.get::<Symbol<CreateV2Fn>>(b"audio_filter_create_v2") {
122 Ok(sym) => sym.try_as_raw_ptr().unwrap(),
123 Err(_) => std::ptr::null(),
124 }
125 };
126 if create_fn_ptr.is_null() && create_v2_fn_ptr.is_null() {
127 return Err(PluginError::NotImplemented(
128 "audio_filter_create is not implemented".into(),
129 ));
130 }
131 let destroy_fn_ptr = unsafe {
132 lib.get::<Symbol<DestroyFn>>(b"audio_filter_destroy")?.try_as_raw_ptr().unwrap()
133 };
134 if destroy_fn_ptr.is_null() {
135 return Err(PluginError::NotImplemented(
136 "audio_filter_destroy is not implemented".into(),
137 ));
138 }
139 let process_i16_fn_ptr = unsafe {
140 lib.get::<Symbol<ProcessI16Fn>>(b"audio_filter_process_int16")?
141 .try_as_raw_ptr()
142 .unwrap()
143 };
144 let process_i16_v2_fn_ptr = unsafe {
145 match lib.get::<Symbol<ProcessI16V2Fn>>(b"audio_filter_process_int16_v2") {
146 Ok(sym) => sym.try_as_raw_ptr().unwrap(),
147 Err(_) => std::ptr::null(),
148 }
149 };
150 if process_i16_fn_ptr.is_null() && process_i16_v2_fn_ptr.is_null() {
151 return Err(PluginError::NotImplemented(
152 "audio_filter_process_int16 is not implemented".into(),
153 ));
154 }
155 let process_f32_fn_ptr = unsafe {
156 lib.get::<Symbol<ProcessF32Fn>>(b"audio_filter_process_float")?
157 .try_as_raw_ptr()
158 .unwrap()
159 };
160 let process_f32_v2_fn_ptr = unsafe {
161 match lib.get::<Symbol<ProcessF32V2Fn>>(b"audio_filter_process_float_v2") {
162 Ok(sym) => sym.try_as_raw_ptr().unwrap(),
163 Err(_) => std::ptr::null(),
164 }
165 };
166 let update_stream_info_fn_ptr = unsafe {
167 lib.get::<Symbol<UpdateStreamInfoFn>>(b"audio_filter_update_stream_info")?
168 .try_as_raw_ptr()
169 .unwrap()
170 };
171 let update_token_fn_ptr = unsafe {
172 match lib.get::<Symbol<UpdateRefreshedTokenFn>>(b"audio_filter_update_token") {
174 Ok(sym) => sym.try_as_raw_ptr().unwrap(),
175 Err(_) => std::ptr::null(),
176 }
177 };
178
179 Ok(Self {
180 lib,
181 dependencies: Default::default(),
182 on_load_fn_ptr,
183 create_fn_ptr,
184 create_v2_fn_ptr,
185 destroy_fn_ptr,
186 process_i16_fn_ptr,
187 process_i16_v2_fn_ptr,
188 process_f32_fn_ptr,
189 process_f32_v2_fn_ptr,
190 update_stream_info_fn_ptr,
191 update_token_fn_ptr,
192 })
193 }
194
195 pub fn on_load<S: AsRef<str>>(&self, url: S, token: S) -> Result<(), PluginError> {
196 if self.on_load_fn_ptr.is_null() {
197 return Ok(());
199 }
200
201 let options_json = json!({
202 "url": url.as_ref().to_string(),
203 "token": token.as_ref().to_string(),
204 });
205 let options = serde_json::to_string(&options_json).map_err(|e| {
206 eprintln!("failed to serialize option: {}", e);
207 PluginError::OnLoad(-1)
208 })?;
209
210 let options = CString::new(options).unwrap_or(CString::new("").unwrap());
211 let on_load_fn: OnLoadFn = unsafe { std::mem::transmute(self.on_load_fn_ptr) };
212
213 let res = unsafe { on_load_fn(options.as_ptr()) };
214 if res == 0 {
215 Ok(())
216 } else {
217 Err(PluginError::OnLoad(res))
218 }
219 }
220
221 pub fn update_token(&self, url: String, token: String) {
222 if self.update_token_fn_ptr.is_null() {
223 return;
224 }
225 let update_token_fn: UpdateRefreshedTokenFn =
226 unsafe { std::mem::transmute(self.update_token_fn_ptr) };
227 let url = CString::new(url).unwrap();
228 let token = CString::new(token).unwrap();
229 unsafe { update_token_fn(url.as_ptr(), token.as_ptr()) }
230 }
231
232 pub fn supports_separate_rates(&self) -> bool {
233 !self.create_v2_fn_ptr.is_null() && !self.process_i16_v2_fn_ptr.is_null()
234 }
235
236 pub fn new_session<S: AsRef<str>>(
237 self: Arc<Self>,
238 input_sample_rate: u32,
239 output_sample_rate: u32,
240 options: S,
241 stream_info: AudioFilterStreamInfo,
242 ) -> Option<AudioFilterSession> {
243 let options = CString::new(options.as_ref()).unwrap_or(CString::new("").unwrap());
244
245 let stream_info = serde_json::to_string(&stream_info).unwrap();
246 let stream_info = CString::new(stream_info).unwrap_or(CString::new("").unwrap());
247
248 let ptr = if !self.create_v2_fn_ptr.is_null() {
249 let create_fn: CreateV2Fn = unsafe { std::mem::transmute(self.create_v2_fn_ptr) };
250 unsafe {
251 create_fn(
252 input_sample_rate,
253 output_sample_rate,
254 options.as_ptr(),
255 stream_info.as_ptr(),
256 )
257 }
258 } else {
259 let create_fn: CreateFn = unsafe { std::mem::transmute(self.create_fn_ptr) };
260 unsafe { create_fn(input_sample_rate, options.as_ptr(), stream_info.as_ptr()) }
261 };
262 if ptr.is_null() {
263 return None;
264 }
265
266 Some(AudioFilterSession { plugin: self.clone(), ptr })
267 }
268}
269
/// A filter session created by a plugin's create entry point.
///
/// The session keeps its plugin alive and destroys the plugin-side handle when dropped.
pub struct AudioFilterSession {
    /// The plugin whose entry points are called for this session.
    plugin: Arc<AudioFilterPlugin>,
    /// Opaque handle returned by the plugin's create function; passed back on every call.
    ptr: *const c_void,
}
274
275impl AudioFilterSession {
276 pub fn destroy(&self) {
277 let destroy: DestroyFn = unsafe { std::mem::transmute(self.plugin.destroy_fn_ptr) };
278 unsafe { destroy(self.ptr) };
279 }
280
281 pub fn process_i16(
282 &self,
283 in_num_samples: usize,
284 input: &[i16],
285 out_num_samples: usize,
286 output: &mut [i16],
287 ) {
288 if !self.plugin.process_i16_v2_fn_ptr.is_null() {
289 let process: ProcessI16V2Fn =
290 unsafe { std::mem::transmute(self.plugin.process_i16_v2_fn_ptr) };
291 unsafe {
292 process(
293 self.ptr,
294 in_num_samples,
295 input.as_ptr(),
296 out_num_samples,
297 output.as_mut_ptr(),
298 )
299 };
300 } else {
301 let process: ProcessI16Fn =
302 unsafe { std::mem::transmute(self.plugin.process_i16_fn_ptr) };
303 unsafe { process(self.ptr, in_num_samples, input.as_ptr(), output.as_mut_ptr()) };
304 }
305 }
306
307 pub fn process_f32(
308 &self,
309 in_num_samples: usize,
310 input: &[f32],
311 out_num_samples: usize,
312 output: &mut [f32],
313 ) {
314 if !self.plugin.process_f32_v2_fn_ptr.is_null() {
315 let process: ProcessF32V2Fn =
316 unsafe { std::mem::transmute(self.plugin.process_f32_v2_fn_ptr) };
317 unsafe {
318 process(
319 self.ptr,
320 in_num_samples,
321 input.as_ptr(),
322 out_num_samples,
323 output.as_mut_ptr(),
324 )
325 };
326 } else {
327 let process: ProcessF32Fn =
328 unsafe { std::mem::transmute(self.plugin.process_f32_fn_ptr) };
329 unsafe { process(self.ptr, in_num_samples, input.as_ptr(), output.as_mut_ptr()) };
330 }
331 }
332
333 pub fn update_stream_info(&self, info: AudioFilterStreamInfo) {
334 if self.plugin.update_stream_info_fn_ptr.is_null() {
335 return;
336 }
337 let update_stream_info_fn: UpdateStreamInfoFn =
338 unsafe { std::mem::transmute(self.plugin.update_stream_info_fn_ptr) };
339 let info_json = serde_json::to_string(&info).unwrap();
340 let info_json = CString::new(info_json).unwrap_or(CString::new("").unwrap());
341 unsafe { update_stream_info_fn(self.ptr, info_json.as_ptr()) }
342 }
343}
344
345impl Drop for AudioFilterSession {
346 fn drop(&mut self) {
347 if !self.ptr.is_null() {
348 self.destroy();
349 }
350 }
351}
352
/// An audio stream that runs the frames of a `NativeAudioStream` through a filter
/// session and yields the filtered output in fixed-size frames.
pub struct AudioFilterAudioStream {
    /// Source of unfiltered audio frames.
    inner: NativeAudioStream,
    /// Filter session every frame is processed with.
    session: AudioFilterSession,
    /// Samples received from `inner` that have not been filtered yet.
    buffer: Vec<i16>,
    /// Sample rate reported on the frames this stream yields.
    output_sample_rate: u32,
    /// Channel count reported on the frames this stream yields.
    num_channels: u32,
    /// Number of samples (across all channels) handed to the filter per call.
    input_frame_size: usize,
    /// Number of samples (across all channels) in every frame this stream yields.
    output_frame_size: usize,
}
362
363impl AudioFilterAudioStream {
364 pub fn new(
365 inner: NativeAudioStream,
366 session: AudioFilterSession,
367 duration: Duration,
368 input_sample_rate: u32,
369 output_sample_rate: u32,
370 num_channels: u32,
371 ) -> Self {
372 let input_frame_size =
373 ((input_sample_rate as f64) * duration.as_secs_f64() * num_channels as f64) as usize;
374 let output_frame_size =
375 ((output_sample_rate as f64) * duration.as_secs_f64() * num_channels as f64) as usize;
376 Self {
377 inner,
378 session,
379 buffer: Vec::with_capacity(input_frame_size),
380 output_sample_rate,
381 num_channels,
382 input_frame_size,
383 output_frame_size,
384 }
385 }
386
387 pub fn update_stream_info(&mut self, info: AudioFilterStreamInfo) {
388 self.session.update_stream_info(info);
389 }
390}
391
392impl Stream for AudioFilterAudioStream {
393 type Item = AudioFrame<'static>;
394
395 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
396 let this = self.get_mut();
397
398 while let Poll::Ready(frame) = Pin::new(&mut this.inner).poll_next(cx) {
399 let Some(frame) = frame else {
400 return Poll::Ready(None);
401 };
402 this.buffer.extend_from_slice(&frame.data);
403
404 if this.buffer.len() >= this.input_frame_size {
405 let data = this.buffer.drain(..this.input_frame_size).collect::<Vec<_>>();
406 let mut out: Vec<i16> = vec![0; this.output_frame_size];
407
408 this.session.process_i16(
409 this.input_frame_size,
410 &data,
411 this.output_frame_size,
412 &mut out,
413 );
414
415 return Poll::Ready(Some(AudioFrame {
416 data: out.into(),
417 sample_rate: this.output_sample_rate,
418 num_channels: this.num_channels,
419 samples_per_channel: (this.output_frame_size / this.num_channels as usize)
420 as u32,
421 }));
422 }
423 }
424
425 Poll::Pending
426 }
427}
428
/// Description of the stream a filter session is attached to.
///
/// It is serialized to JSON with camelCase keys (`roomId`, `participantIdentity`, ...)
/// before being handed to the plugin.
// NOTE(review): the meaning of each field is inferred from its name; confirm against
// the plugin ABI documentation.
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct AudioFilterStreamInfo {
    pub url: String,
    pub room_id: String,
    pub room_name: String,
    pub participant_identity: String,
    pub participant_id: String,
    pub track_id: String,
}
439
// SAFETY (review note): `AudioFilterPlugin` holds library handles and raw symbol addresses
// that the code in this file does not mutate once the plugin is shared. Calling the
// entry points from several threads presumably relies on the plugin being thread-safe
// — TODO confirm against the plugin ABI.
unsafe impl Send for AudioFilterPlugin {}
unsafe impl Sync for AudioFilterPlugin {}
// SAFETY (review note): `AudioFilterSession` wraps an opaque plugin-owned handle. `Sync`
// allows concurrent `process_*` calls on the same handle through `&self`, which assumes
// the plugin tolerates that — TODO confirm.
unsafe impl Send for AudioFilterSession {}
unsafe impl Sync for AudioFilterSession {}