1use std::{
16 collections::HashMap,
17 ffi::{c_char, c_void, CString},
18 pin::Pin,
19 sync::{Arc, LazyLock},
20 task::{Context, Poll},
21 time::Duration,
22};
23
24use futures_util::Stream;
25use libloading::{Library, Symbol};
26use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
27use parking_lot::RwLock;
28use serde::Serialize;
29use serde_json::json;
30
31#[derive(Debug, thiserror::Error)]
32pub enum PluginError {
33 #[error("dylib error: {0}")]
34 Library(#[from] libloading::Error),
35 #[error("dylib error: {0}")]
36 NotImplemented(String),
37 #[error("on_load failed with error: {0}")]
38 OnLoad(i32),
39}
40
/// `audio_filter_on_load(options_json)`: one-time plugin initialisation; `0` means success.
type OnLoadFn = unsafe extern "C" fn(options: *const c_char) -> i32;

/// `audio_filter_create(...)`: returns an opaque session handle, or null on failure.
type CreateFn = unsafe extern "C" fn(
    sampling_rate: u32,
    options: *const c_char,
    stream_info: *const c_char,
) -> *mut c_void;

/// `audio_filter_destroy(session)`: releases a handle obtained from [`CreateFn`].
type DestroyFn = unsafe extern "C" fn(session: *const c_void);

/// `audio_filter_process_int16(session, num_samples, input, output)`.
type ProcessI16Fn = unsafe extern "C" fn(
    session: *const c_void,
    num_samples: usize,
    input: *const i16,
    output: *mut i16,
);

/// `audio_filter_process_float(session, num_samples, input, output)`.
type ProcessF32Fn = unsafe extern "C" fn(
    session: *const c_void,
    num_samples: usize,
    input: *const f32,
    output: *mut f32,
);

/// `audio_filter_update_stream_info(session, stream_info_json)`.
type UpdateStreamInfoFn = unsafe extern "C" fn(session: *const c_void, stream_info: *const c_char);

/// `audio_filter_update_token(url, token)`: plugin-wide, not tied to a session.
type UpdateRefreshedTokenFn = unsafe extern "C" fn(url: *const c_char, token: *const c_char);
52
53static REGISTERED_PLUGINS: LazyLock<RwLock<HashMap<String, Arc<AudioFilterPlugin>>>> =
54 LazyLock::new(|| RwLock::new(HashMap::new()));
55
56pub fn register_audio_filter_plugin(id: String, plugin: Arc<AudioFilterPlugin>) {
57 REGISTERED_PLUGINS.write().insert(id, plugin);
58}
59
60pub fn registered_audio_filter_plugin(id: &str) -> Option<Arc<AudioFilterPlugin>> {
61 REGISTERED_PLUGINS.read().get(id).cloned()
62}
63
64pub fn registered_audio_filter_plugins() -> Vec<Arc<AudioFilterPlugin>> {
65 REGISTERED_PLUGINS.read().values().map(|v| v.clone()).collect()
66}
67
68pub struct AudioFilterPlugin {
69 lib: Library,
70 dependencies: Vec<Library>,
71 on_load_fn_ptr: *const c_void,
72 create_fn_ptr: *const c_void,
73 destroy_fn_ptr: *const c_void,
74 process_i16_fn_ptr: *const c_void,
75 process_f32_fn_ptr: *const c_void,
76 update_stream_info_fn_ptr: *const c_void,
77 update_token_fn_ptr: *const c_void,
78}
79
80impl AudioFilterPlugin {
81 pub fn new<P: AsRef<str>>(path: P) -> Result<Arc<Self>, PluginError> {
82 Ok(Arc::new(Self::_new(path)?))
83 }
84
85 pub fn new_with_dependencies<P: AsRef<str>>(
86 path: P,
87 dependencies: Vec<P>,
88 ) -> Result<Arc<Self>, PluginError> {
89 let mut libs = vec![];
90 for path in dependencies {
91 let lib = unsafe { Library::new(path.as_ref()) }?;
92 libs.push(lib);
93 }
94 let mut this = Self::_new(path)?;
95 this.dependencies = libs;
96 Ok(Arc::new(this))
97 }
98
99 fn _new<P: AsRef<str>>(path: P) -> Result<Self, PluginError> {
100 let lib = unsafe { Library::new(path.as_ref()) }?;
101
102 let on_load_fn_ptr = unsafe {
103 lib.get::<Symbol<OnLoadFn>>(b"audio_filter_on_load")?.try_as_raw_ptr().unwrap()
104 };
105
106 let create_fn_ptr = unsafe {
107 lib.get::<Symbol<CreateFn>>(b"audio_filter_create")?.try_as_raw_ptr().unwrap()
108 };
109 if create_fn_ptr.is_null() {
110 return Err(PluginError::NotImplemented(
111 "audio_filter_create is not implemented".into(),
112 ));
113 }
114 let destroy_fn_ptr = unsafe {
115 lib.get::<Symbol<DestroyFn>>(b"audio_filter_destroy")?.try_as_raw_ptr().unwrap()
116 };
117 if destroy_fn_ptr.is_null() {
118 return Err(PluginError::NotImplemented(
119 "audio_filter_destroy is not implemented".into(),
120 ));
121 }
122 let process_i16_fn_ptr = unsafe {
123 lib.get::<Symbol<ProcessI16Fn>>(b"audio_filter_process_int16")?
124 .try_as_raw_ptr()
125 .unwrap()
126 };
127 if process_i16_fn_ptr.is_null() {
128 return Err(PluginError::NotImplemented(
129 "audio_filter_process_int16 is not implemented".into(),
130 ));
131 }
132 let process_f32_fn_ptr = unsafe {
133 lib.get::<Symbol<ProcessF32Fn>>(b"audio_filter_process_float")?
134 .try_as_raw_ptr()
135 .unwrap()
136 };
137 let update_stream_info_fn_ptr = unsafe {
138 lib.get::<Symbol<UpdateStreamInfoFn>>(b"audio_filter_update_stream_info")?
139 .try_as_raw_ptr()
140 .unwrap()
141 };
142 let update_token_fn_ptr = unsafe {
143 match lib.get::<Symbol<UpdateRefreshedTokenFn>>(b"audio_filter_update_token") {
145 Ok(sym) => sym.try_as_raw_ptr().unwrap(),
146 Err(_) => std::ptr::null(),
147 }
148 };
149
150 Ok(Self {
151 lib,
152 dependencies: Default::default(),
153 on_load_fn_ptr,
154 create_fn_ptr,
155 destroy_fn_ptr,
156 process_i16_fn_ptr,
157 process_f32_fn_ptr,
158 update_stream_info_fn_ptr,
159 update_token_fn_ptr,
160 })
161 }
162
163 pub fn on_load<S: AsRef<str>>(&self, url: S, token: S) -> Result<(), PluginError> {
164 if self.on_load_fn_ptr.is_null() {
165 return Ok(());
167 }
168
169 let options_json = json!({
170 "url": url.as_ref().to_string(),
171 "token": token.as_ref().to_string(),
172 });
173 let options = serde_json::to_string(&options_json).map_err(|e| {
174 eprintln!("failed to serialize option: {}", e);
175 PluginError::OnLoad(-1)
176 })?;
177
178 let options = CString::new(options).unwrap_or(CString::new("").unwrap());
179 let on_load_fn: OnLoadFn = unsafe { std::mem::transmute(self.on_load_fn_ptr) };
180
181 let res = unsafe { on_load_fn(options.as_ptr()) };
182 if res == 0 {
183 Ok(())
184 } else {
185 Err(PluginError::OnLoad(res))
186 }
187 }
188
189 pub fn update_token(&self, url: String, token: String) {
190 if self.update_token_fn_ptr.is_null() {
191 return;
192 }
193 let update_token_fn: UpdateRefreshedTokenFn =
194 unsafe { std::mem::transmute(self.update_token_fn_ptr) };
195 let url = CString::new(url).unwrap();
196 let token = CString::new(token).unwrap();
197 unsafe { update_token_fn(url.as_ptr(), token.as_ptr()) }
198 }
199
200 pub fn new_session<S: AsRef<str>>(
201 self: Arc<Self>,
202 sampling_rate: u32,
203 options: S,
204 stream_info: AudioFilterStreamInfo,
205 ) -> Option<AudioFilterSession> {
206 let create_fn: CreateFn = unsafe { std::mem::transmute(self.create_fn_ptr) };
207
208 let options = CString::new(options.as_ref()).unwrap_or(CString::new("").unwrap());
209
210 let stream_info = serde_json::to_string(&stream_info).unwrap();
211 let stream_info = CString::new(stream_info).unwrap_or(CString::new("").unwrap());
212
213 let ptr = unsafe { create_fn(sampling_rate, options.as_ptr(), stream_info.as_ptr()) };
214 if ptr.is_null() {
215 return None;
216 }
217
218 Some(AudioFilterSession { plugin: self.clone(), ptr })
219 }
220}
221
222pub struct AudioFilterSession {
223 plugin: Arc<AudioFilterPlugin>,
224 ptr: *const c_void,
225}
226
227impl AudioFilterSession {
228 pub fn destroy(&self) {
229 let destroy: DestroyFn = unsafe { std::mem::transmute(self.plugin.destroy_fn_ptr) };
230 unsafe { destroy(self.ptr) };
231 }
232
233 pub fn process_i16(&self, num_samples: usize, input: &[i16], output: &mut [i16]) {
234 let process: ProcessI16Fn = unsafe { std::mem::transmute(self.plugin.process_i16_fn_ptr) };
235 unsafe { process(self.ptr, num_samples, input.as_ptr(), output.as_mut_ptr()) };
236 }
237
238 pub fn process_f32(&self, num_samples: usize, input: &[f32], output: &mut [f32]) {
239 let process: ProcessF32Fn = unsafe { std::mem::transmute(self.plugin.process_f32_fn_ptr) };
240 unsafe { process(self.ptr, num_samples, input.as_ptr(), output.as_mut_ptr()) };
241 }
242
243 pub fn update_stream_info(&self, info: AudioFilterStreamInfo) {
244 if self.plugin.update_stream_info_fn_ptr.is_null() {
245 return;
246 }
247 let update_stream_info_fn: UpdateStreamInfoFn =
248 unsafe { std::mem::transmute(self.plugin.update_stream_info_fn_ptr) };
249 let info_json = serde_json::to_string(&info).unwrap();
250 let info_json = CString::new(info_json).unwrap_or(CString::new("").unwrap());
251 unsafe { update_stream_info_fn(self.ptr, info_json.as_ptr()) }
252 }
253}
254
255impl Drop for AudioFilterSession {
256 fn drop(&mut self) {
257 if !self.ptr.is_null() {
258 self.destroy();
259 }
260 }
261}
262
263pub struct AudioFilterAudioStream {
264 inner: NativeAudioStream,
265 session: AudioFilterSession,
266 buffer: Vec<i16>,
267 sample_rate: u32,
268 num_channels: u32,
269 frame_size: usize,
270}
271
272impl AudioFilterAudioStream {
273 pub fn new(
274 inner: NativeAudioStream,
275 session: AudioFilterSession,
276 duration: Duration,
277 sample_rate: u32,
278 num_channels: u32,
279 ) -> Self {
280 let frame_size =
281 ((sample_rate as f64) * duration.as_secs_f64() * num_channels as f64) as usize;
282 Self {
283 inner,
284 session,
285 buffer: Vec::with_capacity(frame_size),
286 sample_rate,
287 num_channels,
288 frame_size,
289 }
290 }
291
292 pub fn update_stream_info(&mut self, info: AudioFilterStreamInfo) {
293 self.session.update_stream_info(info);
294 }
295}
296
297impl Stream for AudioFilterAudioStream {
298 type Item = AudioFrame<'static>;
299
300 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
301 let this = self.get_mut();
302
303 while let Poll::Ready(frame) = Pin::new(&mut this.inner).poll_next(cx) {
304 let Some(frame) = frame else {
305 return Poll::Ready(None);
306 };
307 this.buffer.extend_from_slice(&frame.data);
308
309 if this.buffer.len() >= this.frame_size {
310 let data = this.buffer.drain(..this.frame_size).collect::<Vec<_>>();
311 let mut out: Vec<i16> = vec![0; this.frame_size];
312
313 this.session.process_i16(this.frame_size, &data, &mut out);
314
315 return Poll::Ready(Some(AudioFrame {
316 data: out.into(),
317 sample_rate: this.sample_rate,
318 num_channels: this.num_channels,
319 samples_per_channel: (this.frame_size / this.num_channels as usize) as u32,
320 }));
321 }
322 }
323
324 Poll::Pending
325 }
326}
327
328#[derive(Debug, Serialize, Default, Clone)]
329#[serde(rename_all = "camelCase")]
330pub struct AudioFilterStreamInfo {
331 pub url: String,
332 pub room_id: String,
333 pub room_name: String,
334 pub participant_identity: String,
335 pub participant_id: String,
336 pub track_id: String,
337}
338
339unsafe impl Send for AudioFilterPlugin {}
342unsafe impl Sync for AudioFilterPlugin {}
343unsafe impl Send for AudioFilterSession {}
344unsafe impl Sync for AudioFilterSession {}