Skip to main content

livekit/
plugin.rs

// Copyright 2025 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    ffi::{c_char, c_void, CString},
18    pin::Pin,
19    sync::{Arc, LazyLock},
20    task::{Context, Poll},
21    time::Duration,
22};
23
24use futures_util::Stream;
25use libloading::{Library, Symbol};
26use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
27use parking_lot::RwLock;
28use serde::Serialize;
29use serde_json::json;
30
31#[derive(Debug, thiserror::Error)]
32pub enum PluginError {
33    #[error("dylib error: {0}")]
34    Library(#[from] libloading::Error),
35    #[error("dylib error: {0}")]
36    NotImplemented(String),
37    #[error("on_load failed with error: {0}")]
38    OnLoad(i32),
39}
40
/// `audio_filter_on_load`: receives a NUL-terminated options string and returns a
/// status code; `0` is treated as success by the caller.
type OnLoadFn = unsafe extern "C" fn(options: *const c_char) -> i32;

/// `audio_filter_create`: creates a session for a single sample rate and returns an
/// opaque session pointer (null is treated as failure by the caller).
type CreateFn = unsafe extern "C" fn(
    sampling_rate: u32,
    options: *const c_char,
    stream_info: *const c_char,
) -> *mut c_void;

/// `audio_filter_create_v2`: like [`CreateFn`] but with separate input and output
/// sample rates.
type CreateV2Fn = unsafe extern "C" fn(
    input_sample_rate: u32,
    output_sample_rate: u32,
    options: *const c_char,
    stream_info: *const c_char,
) -> *mut c_void;

/// `audio_filter_destroy`: releases a session pointer returned by a create function.
type DestroyFn = unsafe extern "C" fn(session: *const c_void);

/// `audio_filter_process_int16`: one sample count shared by the input and output buffers.
type ProcessI16Fn = unsafe extern "C" fn(
    session: *const c_void,
    num_samples: usize,
    input: *const i16,
    output: *mut i16,
);

/// `audio_filter_process_int16_v2`: separate sample counts for input and output.
type ProcessI16V2Fn = unsafe extern "C" fn(
    session: *const c_void,
    in_num_samples: usize,
    input: *const i16,
    out_num_samples: usize,
    output: *mut i16,
);

/// `audio_filter_process_float`: `f32` counterpart of [`ProcessI16Fn`].
type ProcessF32Fn = unsafe extern "C" fn(
    session: *const c_void,
    num_samples: usize,
    input: *const f32,
    output: *mut f32,
);

/// `audio_filter_process_float_v2`: `f32` counterpart of [`ProcessI16V2Fn`].
type ProcessF32V2Fn = unsafe extern "C" fn(
    session: *const c_void,
    in_num_samples: usize,
    input: *const f32,
    out_num_samples: usize,
    output: *mut f32,
);

/// `audio_filter_update_stream_info`: hands a session a NUL-terminated JSON document.
type UpdateStreamInfoFn = unsafe extern "C" fn(session: *const c_void, info: *const c_char);

/// `audio_filter_update_token`: receives the URL and the refreshed token, in that order.
type UpdateRefreshedTokenFn = unsafe extern "C" fn(url: *const c_char, token: *const c_char);
60
61static REGISTERED_PLUGINS: LazyLock<RwLock<HashMap<String, Arc<AudioFilterPlugin>>>> =
62    LazyLock::new(|| RwLock::new(HashMap::new()));
63
64pub fn register_audio_filter_plugin(id: String, plugin: Arc<AudioFilterPlugin>) {
65    REGISTERED_PLUGINS.write().insert(id, plugin);
66}
67
68pub fn registered_audio_filter_plugin(id: &str) -> Option<Arc<AudioFilterPlugin>> {
69    REGISTERED_PLUGINS.read().get(id).cloned()
70}
71
72pub fn registered_audio_filter_plugins() -> Vec<Arc<AudioFilterPlugin>> {
73    REGISTERED_PLUGINS.read().values().map(|v| v.clone()).collect()
74}
75
76pub struct AudioFilterPlugin {
77    lib: Library,
78    dependencies: Vec<Library>,
79    on_load_fn_ptr: *const c_void,
80    create_fn_ptr: *const c_void,
81    create_v2_fn_ptr: *const c_void,
82    destroy_fn_ptr: *const c_void,
83    process_i16_fn_ptr: *const c_void,
84    process_i16_v2_fn_ptr: *const c_void,
85    process_f32_fn_ptr: *const c_void,
86    process_f32_v2_fn_ptr: *const c_void,
87    update_stream_info_fn_ptr: *const c_void,
88    update_token_fn_ptr: *const c_void,
89}
90
91impl AudioFilterPlugin {
92    pub fn new<P: AsRef<str>>(path: P) -> Result<Arc<Self>, PluginError> {
93        Ok(Arc::new(Self::_new(path)?))
94    }
95
96    pub fn new_with_dependencies<P: AsRef<str>>(
97        path: P,
98        dependencies: Vec<P>,
99    ) -> Result<Arc<Self>, PluginError> {
100        let mut libs = vec![];
101        for path in dependencies {
102            let lib = unsafe { Library::new(path.as_ref()) }?;
103            libs.push(lib);
104        }
105        let mut this = Self::_new(path)?;
106        this.dependencies = libs;
107        Ok(Arc::new(this))
108    }
109
110    fn _new<P: AsRef<str>>(path: P) -> Result<Self, PluginError> {
111        let lib = unsafe { Library::new(path.as_ref()) }?;
112
113        let on_load_fn_ptr = unsafe {
114            lib.get::<Symbol<OnLoadFn>>(b"audio_filter_on_load")?.try_as_raw_ptr().unwrap()
115        };
116
117        let create_fn_ptr = unsafe {
118            lib.get::<Symbol<CreateFn>>(b"audio_filter_create")?.try_as_raw_ptr().unwrap()
119        };
120        let create_v2_fn_ptr = unsafe {
121            match lib.get::<Symbol<CreateV2Fn>>(b"audio_filter_create_v2") {
122                Ok(sym) => sym.try_as_raw_ptr().unwrap(),
123                Err(_) => std::ptr::null(),
124            }
125        };
126        if create_fn_ptr.is_null() && create_v2_fn_ptr.is_null() {
127            return Err(PluginError::NotImplemented(
128                "audio_filter_create is not implemented".into(),
129            ));
130        }
131        let destroy_fn_ptr = unsafe {
132            lib.get::<Symbol<DestroyFn>>(b"audio_filter_destroy")?.try_as_raw_ptr().unwrap()
133        };
134        if destroy_fn_ptr.is_null() {
135            return Err(PluginError::NotImplemented(
136                "audio_filter_destroy is not implemented".into(),
137            ));
138        }
139        let process_i16_fn_ptr = unsafe {
140            lib.get::<Symbol<ProcessI16Fn>>(b"audio_filter_process_int16")?
141                .try_as_raw_ptr()
142                .unwrap()
143        };
144        let process_i16_v2_fn_ptr = unsafe {
145            match lib.get::<Symbol<ProcessI16V2Fn>>(b"audio_filter_process_int16_v2") {
146                Ok(sym) => sym.try_as_raw_ptr().unwrap(),
147                Err(_) => std::ptr::null(),
148            }
149        };
150        if process_i16_fn_ptr.is_null() && process_i16_v2_fn_ptr.is_null() {
151            return Err(PluginError::NotImplemented(
152                "audio_filter_process_int16 is not implemented".into(),
153            ));
154        }
155        let process_f32_fn_ptr = unsafe {
156            lib.get::<Symbol<ProcessF32Fn>>(b"audio_filter_process_float")?
157                .try_as_raw_ptr()
158                .unwrap()
159        };
160        let process_f32_v2_fn_ptr = unsafe {
161            match lib.get::<Symbol<ProcessF32V2Fn>>(b"audio_filter_process_float_v2") {
162                Ok(sym) => sym.try_as_raw_ptr().unwrap(),
163                Err(_) => std::ptr::null(),
164            }
165        };
166        let update_stream_info_fn_ptr = unsafe {
167            lib.get::<Symbol<UpdateStreamInfoFn>>(b"audio_filter_update_stream_info")?
168                .try_as_raw_ptr()
169                .unwrap()
170        };
171        let update_token_fn_ptr = unsafe {
172            // treat as optional function for now
173            match lib.get::<Symbol<UpdateRefreshedTokenFn>>(b"audio_filter_update_token") {
174                Ok(sym) => sym.try_as_raw_ptr().unwrap(),
175                Err(_) => std::ptr::null(),
176            }
177        };
178
179        Ok(Self {
180            lib,
181            dependencies: Default::default(),
182            on_load_fn_ptr,
183            create_fn_ptr,
184            create_v2_fn_ptr,
185            destroy_fn_ptr,
186            process_i16_fn_ptr,
187            process_i16_v2_fn_ptr,
188            process_f32_fn_ptr,
189            process_f32_v2_fn_ptr,
190            update_stream_info_fn_ptr,
191            update_token_fn_ptr,
192        })
193    }
194
195    pub fn on_load<S: AsRef<str>>(&self, url: S, token: S) -> Result<(), PluginError> {
196        if self.on_load_fn_ptr.is_null() {
197            // on_load is optional function
198            return Ok(());
199        }
200
201        let options_json = json!({
202            "url": url.as_ref().to_string(),
203            "token": token.as_ref().to_string(),
204        });
205        let options = serde_json::to_string(&options_json).map_err(|e| {
206            eprintln!("failed to serialize option: {}", e);
207            PluginError::OnLoad(-1)
208        })?;
209
210        let options = CString::new(options).unwrap_or(CString::new("").unwrap());
211        let on_load_fn: OnLoadFn = unsafe { std::mem::transmute(self.on_load_fn_ptr) };
212
213        let res = unsafe { on_load_fn(options.as_ptr()) };
214        if res == 0 {
215            Ok(())
216        } else {
217            Err(PluginError::OnLoad(res))
218        }
219    }
220
221    pub fn update_token(&self, url: String, token: String) {
222        if self.update_token_fn_ptr.is_null() {
223            return;
224        }
225        let update_token_fn: UpdateRefreshedTokenFn =
226            unsafe { std::mem::transmute(self.update_token_fn_ptr) };
227        let url = CString::new(url).unwrap();
228        let token = CString::new(token).unwrap();
229        unsafe { update_token_fn(url.as_ptr(), token.as_ptr()) }
230    }
231
232    pub fn supports_separate_rates(&self) -> bool {
233        !self.create_v2_fn_ptr.is_null() && !self.process_i16_v2_fn_ptr.is_null()
234    }
235
236    pub fn new_session<S: AsRef<str>>(
237        self: Arc<Self>,
238        input_sample_rate: u32,
239        output_sample_rate: u32,
240        options: S,
241        stream_info: AudioFilterStreamInfo,
242    ) -> Option<AudioFilterSession> {
243        let options = CString::new(options.as_ref()).unwrap_or(CString::new("").unwrap());
244
245        let stream_info = serde_json::to_string(&stream_info).unwrap();
246        let stream_info = CString::new(stream_info).unwrap_or(CString::new("").unwrap());
247
248        let ptr = if !self.create_v2_fn_ptr.is_null() {
249            let create_fn: CreateV2Fn = unsafe { std::mem::transmute(self.create_v2_fn_ptr) };
250            unsafe {
251                create_fn(
252                    input_sample_rate,
253                    output_sample_rate,
254                    options.as_ptr(),
255                    stream_info.as_ptr(),
256                )
257            }
258        } else {
259            let create_fn: CreateFn = unsafe { std::mem::transmute(self.create_fn_ptr) };
260            unsafe { create_fn(input_sample_rate, options.as_ptr(), stream_info.as_ptr()) }
261        };
262        if ptr.is_null() {
263            return None;
264        }
265
266        Some(AudioFilterSession { plugin: self.clone(), ptr })
267    }
268}
269
270pub struct AudioFilterSession {
271    plugin: Arc<AudioFilterPlugin>,
272    ptr: *const c_void,
273}
274
275impl AudioFilterSession {
276    pub fn destroy(&self) {
277        let destroy: DestroyFn = unsafe { std::mem::transmute(self.plugin.destroy_fn_ptr) };
278        unsafe { destroy(self.ptr) };
279    }
280
281    pub fn process_i16(
282        &self,
283        in_num_samples: usize,
284        input: &[i16],
285        out_num_samples: usize,
286        output: &mut [i16],
287    ) {
288        if !self.plugin.process_i16_v2_fn_ptr.is_null() {
289            let process: ProcessI16V2Fn =
290                unsafe { std::mem::transmute(self.plugin.process_i16_v2_fn_ptr) };
291            unsafe {
292                process(
293                    self.ptr,
294                    in_num_samples,
295                    input.as_ptr(),
296                    out_num_samples,
297                    output.as_mut_ptr(),
298                )
299            };
300        } else {
301            let process: ProcessI16Fn =
302                unsafe { std::mem::transmute(self.plugin.process_i16_fn_ptr) };
303            unsafe { process(self.ptr, in_num_samples, input.as_ptr(), output.as_mut_ptr()) };
304        }
305    }
306
307    pub fn process_f32(
308        &self,
309        in_num_samples: usize,
310        input: &[f32],
311        out_num_samples: usize,
312        output: &mut [f32],
313    ) {
314        if !self.plugin.process_f32_v2_fn_ptr.is_null() {
315            let process: ProcessF32V2Fn =
316                unsafe { std::mem::transmute(self.plugin.process_f32_v2_fn_ptr) };
317            unsafe {
318                process(
319                    self.ptr,
320                    in_num_samples,
321                    input.as_ptr(),
322                    out_num_samples,
323                    output.as_mut_ptr(),
324                )
325            };
326        } else {
327            let process: ProcessF32Fn =
328                unsafe { std::mem::transmute(self.plugin.process_f32_fn_ptr) };
329            unsafe { process(self.ptr, in_num_samples, input.as_ptr(), output.as_mut_ptr()) };
330        }
331    }
332
333    pub fn update_stream_info(&self, info: AudioFilterStreamInfo) {
334        if self.plugin.update_stream_info_fn_ptr.is_null() {
335            return;
336        }
337        let update_stream_info_fn: UpdateStreamInfoFn =
338            unsafe { std::mem::transmute(self.plugin.update_stream_info_fn_ptr) };
339        let info_json = serde_json::to_string(&info).unwrap();
340        let info_json = CString::new(info_json).unwrap_or(CString::new("").unwrap());
341        unsafe { update_stream_info_fn(self.ptr, info_json.as_ptr()) }
342    }
343}
344
345impl Drop for AudioFilterSession {
346    fn drop(&mut self) {
347        if !self.ptr.is_null() {
348            self.destroy();
349        }
350    }
351}
352
353pub struct AudioFilterAudioStream {
354    inner: NativeAudioStream,
355    session: AudioFilterSession,
356    buffer: Vec<i16>,
357    output_sample_rate: u32,
358    num_channels: u32,
359    input_frame_size: usize,
360    output_frame_size: usize,
361}
362
363impl AudioFilterAudioStream {
364    pub fn new(
365        inner: NativeAudioStream,
366        session: AudioFilterSession,
367        duration: Duration,
368        input_sample_rate: u32,
369        output_sample_rate: u32,
370        num_channels: u32,
371    ) -> Self {
372        let input_frame_size =
373            ((input_sample_rate as f64) * duration.as_secs_f64() * num_channels as f64) as usize;
374        let output_frame_size =
375            ((output_sample_rate as f64) * duration.as_secs_f64() * num_channels as f64) as usize;
376        Self {
377            inner,
378            session,
379            buffer: Vec::with_capacity(input_frame_size),
380            output_sample_rate,
381            num_channels,
382            input_frame_size,
383            output_frame_size,
384        }
385    }
386
387    pub fn update_stream_info(&mut self, info: AudioFilterStreamInfo) {
388        self.session.update_stream_info(info);
389    }
390}
391
392impl Stream for AudioFilterAudioStream {
393    type Item = AudioFrame<'static>;
394
395    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
396        let this = self.get_mut();
397
398        while let Poll::Ready(frame) = Pin::new(&mut this.inner).poll_next(cx) {
399            let Some(frame) = frame else {
400                return Poll::Ready(None);
401            };
402            this.buffer.extend_from_slice(&frame.data);
403
404            if this.buffer.len() >= this.input_frame_size {
405                let data = this.buffer.drain(..this.input_frame_size).collect::<Vec<_>>();
406                let mut out: Vec<i16> = vec![0; this.output_frame_size];
407
408                this.session.process_i16(
409                    this.input_frame_size,
410                    &data,
411                    this.output_frame_size,
412                    &mut out,
413                );
414
415                return Poll::Ready(Some(AudioFrame {
416                    data: out.into(),
417                    sample_rate: this.output_sample_rate,
418                    num_channels: this.num_channels,
419                    samples_per_channel: (this.output_frame_size / this.num_channels as usize)
420                        as u32,
421                }));
422            }
423        }
424
425        Poll::Pending
426    }
427}
428
429#[derive(Debug, Serialize, Default, Clone)]
430#[serde(rename_all = "camelCase")]
431pub struct AudioFilterStreamInfo {
432    pub url: String,
433    pub room_id: String,
434    pub room_name: String,
435    pub participant_identity: String,
436    pub participant_id: String,
437    pub track_id: String,
438}
439
440// The function pointers in this struct are initialized only once during construction
441// and remain read-only throughout the lifetime of the struct, ensuring thread safety.
442unsafe impl Send for AudioFilterPlugin {}
443unsafe impl Sync for AudioFilterPlugin {}
444unsafe impl Send for AudioFilterSession {}
445unsafe impl Sync for AudioFilterSession {}