livekit/
plugin.rs

1// Copyright 2025 LiveKit, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    ffi::{c_char, c_void, CString},
18    pin::Pin,
19    sync::{Arc, LazyLock},
20    task::{Context, Poll},
21    time::Duration,
22};
23
24use futures_util::Stream;
25use libloading::{Library, Symbol};
26use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
27use parking_lot::RwLock;
28use serde::Serialize;
29use serde_json::json;
30
31#[derive(Debug, thiserror::Error)]
32pub enum PluginError {
33    #[error("dylib error: {0}")]
34    Library(#[from] libloading::Error),
35    #[error("dylib error: {0}")]
36    NotImplemented(String),
37    #[error("on_load failed with error: {0}")]
38    OnLoad(i32),
39}
40
/// `audio_filter_on_load(options_json)`: one-time plugin initialisation; 0 means success.
type OnLoadFn = unsafe extern "C" fn(options: *const c_char) -> i32;

/// `audio_filter_create(sampling_rate, options, stream_info_json)`: returns a new filter
/// instance handle, or null on failure.
type CreateFn = unsafe extern "C" fn(
    sampling_rate: u32,
    options: *const c_char,
    stream_info: *const c_char,
) -> *mut c_void;

/// `audio_filter_destroy(instance)`: releases a handle returned by `audio_filter_create`.
type DestroyFn = unsafe extern "C" fn(instance: *const c_void);

/// `audio_filter_process_int16(instance, num_samples, input, output)`.
type ProcessI16Fn = unsafe extern "C" fn(
    instance: *const c_void,
    num_samples: usize,
    input: *const i16,
    output: *mut i16,
);

/// `audio_filter_process_float(instance, num_samples, input, output)`.
type ProcessF32Fn = unsafe extern "C" fn(
    instance: *const c_void,
    num_samples: usize,
    input: *const f32,
    output: *mut f32,
);

/// `audio_filter_update_stream_info(instance, stream_info_json)`.
type UpdateStreamInfoFn = unsafe extern "C" fn(instance: *const c_void, info: *const c_char);

/// `audio_filter_update_token(url, token)`: plugin-wide, not tied to an instance.
type UpdateRefreshedTokenFn = unsafe extern "C" fn(url: *const c_char, token: *const c_char);
52
53static REGISTERED_PLUGINS: LazyLock<RwLock<HashMap<String, Arc<AudioFilterPlugin>>>> =
54    LazyLock::new(|| RwLock::new(HashMap::new()));
55
56pub fn register_audio_filter_plugin(id: String, plugin: Arc<AudioFilterPlugin>) {
57    REGISTERED_PLUGINS.write().insert(id, plugin);
58}
59
60pub fn registered_audio_filter_plugin(id: &str) -> Option<Arc<AudioFilterPlugin>> {
61    REGISTERED_PLUGINS.read().get(id).cloned()
62}
63
64pub fn registered_audio_filter_plugins() -> Vec<Arc<AudioFilterPlugin>> {
65    REGISTERED_PLUGINS.read().values().map(|v| v.clone()).collect()
66}
67
68pub struct AudioFilterPlugin {
69    lib: Library,
70    dependencies: Vec<Library>,
71    on_load_fn_ptr: *const c_void,
72    create_fn_ptr: *const c_void,
73    destroy_fn_ptr: *const c_void,
74    process_i16_fn_ptr: *const c_void,
75    process_f32_fn_ptr: *const c_void,
76    update_stream_info_fn_ptr: *const c_void,
77    update_token_fn_ptr: *const c_void,
78}
79
80impl AudioFilterPlugin {
81    pub fn new<P: AsRef<str>>(path: P) -> Result<Arc<Self>, PluginError> {
82        Ok(Arc::new(Self::_new(path)?))
83    }
84
85    pub fn new_with_dependencies<P: AsRef<str>>(
86        path: P,
87        dependencies: Vec<P>,
88    ) -> Result<Arc<Self>, PluginError> {
89        let mut libs = vec![];
90        for path in dependencies {
91            let lib = unsafe { Library::new(path.as_ref()) }?;
92            libs.push(lib);
93        }
94        let mut this = Self::_new(path)?;
95        this.dependencies = libs;
96        Ok(Arc::new(this))
97    }
98
99    fn _new<P: AsRef<str>>(path: P) -> Result<Self, PluginError> {
100        let lib = unsafe { Library::new(path.as_ref()) }?;
101
102        let on_load_fn_ptr = unsafe {
103            lib.get::<Symbol<OnLoadFn>>(b"audio_filter_on_load")?.try_as_raw_ptr().unwrap()
104        };
105
106        let create_fn_ptr = unsafe {
107            lib.get::<Symbol<CreateFn>>(b"audio_filter_create")?.try_as_raw_ptr().unwrap()
108        };
109        if create_fn_ptr.is_null() {
110            return Err(PluginError::NotImplemented(
111                "audio_filter_create is not implemented".into(),
112            ));
113        }
114        let destroy_fn_ptr = unsafe {
115            lib.get::<Symbol<DestroyFn>>(b"audio_filter_destroy")?.try_as_raw_ptr().unwrap()
116        };
117        if destroy_fn_ptr.is_null() {
118            return Err(PluginError::NotImplemented(
119                "audio_filter_destroy is not implemented".into(),
120            ));
121        }
122        let process_i16_fn_ptr = unsafe {
123            lib.get::<Symbol<ProcessI16Fn>>(b"audio_filter_process_int16")?
124                .try_as_raw_ptr()
125                .unwrap()
126        };
127        if process_i16_fn_ptr.is_null() {
128            return Err(PluginError::NotImplemented(
129                "audio_filter_process_int16 is not implemented".into(),
130            ));
131        }
132        let process_f32_fn_ptr = unsafe {
133            lib.get::<Symbol<ProcessF32Fn>>(b"audio_filter_process_float")?
134                .try_as_raw_ptr()
135                .unwrap()
136        };
137        let update_stream_info_fn_ptr = unsafe {
138            lib.get::<Symbol<UpdateStreamInfoFn>>(b"audio_filter_update_stream_info")?
139                .try_as_raw_ptr()
140                .unwrap()
141        };
142        let update_token_fn_ptr = unsafe {
143            // treat as optional function for now
144            match lib.get::<Symbol<UpdateRefreshedTokenFn>>(b"audio_filter_update_token") {
145                Ok(sym) => sym.try_as_raw_ptr().unwrap(),
146                Err(_) => std::ptr::null(),
147            }
148        };
149
150        Ok(Self {
151            lib,
152            dependencies: Default::default(),
153            on_load_fn_ptr,
154            create_fn_ptr,
155            destroy_fn_ptr,
156            process_i16_fn_ptr,
157            process_f32_fn_ptr,
158            update_stream_info_fn_ptr,
159            update_token_fn_ptr,
160        })
161    }
162
163    pub fn on_load<S: AsRef<str>>(&self, url: S, token: S) -> Result<(), PluginError> {
164        if self.on_load_fn_ptr.is_null() {
165            // on_load is optional function
166            return Ok(());
167        }
168
169        let options_json = json!({
170            "url": url.as_ref().to_string(),
171            "token": token.as_ref().to_string(),
172        });
173        let options = serde_json::to_string(&options_json).map_err(|e| {
174            eprintln!("failed to serialize option: {}", e);
175            PluginError::OnLoad(-1)
176        })?;
177
178        let options = CString::new(options).unwrap_or(CString::new("").unwrap());
179        let on_load_fn: OnLoadFn = unsafe { std::mem::transmute(self.on_load_fn_ptr) };
180
181        let res = unsafe { on_load_fn(options.as_ptr()) };
182        if res == 0 {
183            Ok(())
184        } else {
185            Err(PluginError::OnLoad(res))
186        }
187    }
188
189    pub fn update_token(&self, url: String, token: String) {
190        if self.update_token_fn_ptr.is_null() {
191            return;
192        }
193        let update_token_fn: UpdateRefreshedTokenFn =
194            unsafe { std::mem::transmute(self.update_token_fn_ptr) };
195        let url = CString::new(url).unwrap();
196        let token = CString::new(token).unwrap();
197        unsafe { update_token_fn(url.as_ptr(), token.as_ptr()) }
198    }
199
200    pub fn new_session<S: AsRef<str>>(
201        self: Arc<Self>,
202        sampling_rate: u32,
203        options: S,
204        stream_info: AudioFilterStreamInfo,
205    ) -> Option<AudioFilterSession> {
206        let create_fn: CreateFn = unsafe { std::mem::transmute(self.create_fn_ptr) };
207
208        let options = CString::new(options.as_ref()).unwrap_or(CString::new("").unwrap());
209
210        let stream_info = serde_json::to_string(&stream_info).unwrap();
211        let stream_info = CString::new(stream_info).unwrap_or(CString::new("").unwrap());
212
213        let ptr = unsafe { create_fn(sampling_rate, options.as_ptr(), stream_info.as_ptr()) };
214        if ptr.is_null() {
215            return None;
216        }
217
218        Some(AudioFilterSession { plugin: self.clone(), ptr })
219    }
220}
221
222pub struct AudioFilterSession {
223    plugin: Arc<AudioFilterPlugin>,
224    ptr: *const c_void,
225}
226
227impl AudioFilterSession {
228    pub fn destroy(&self) {
229        let destroy: DestroyFn = unsafe { std::mem::transmute(self.plugin.destroy_fn_ptr) };
230        unsafe { destroy(self.ptr) };
231    }
232
233    pub fn process_i16(&self, num_samples: usize, input: &[i16], output: &mut [i16]) {
234        let process: ProcessI16Fn = unsafe { std::mem::transmute(self.plugin.process_i16_fn_ptr) };
235        unsafe { process(self.ptr, num_samples, input.as_ptr(), output.as_mut_ptr()) };
236    }
237
238    pub fn process_f32(&self, num_samples: usize, input: &[f32], output: &mut [f32]) {
239        let process: ProcessF32Fn = unsafe { std::mem::transmute(self.plugin.process_f32_fn_ptr) };
240        unsafe { process(self.ptr, num_samples, input.as_ptr(), output.as_mut_ptr()) };
241    }
242
243    pub fn update_stream_info(&self, info: AudioFilterStreamInfo) {
244        if self.plugin.update_stream_info_fn_ptr.is_null() {
245            return;
246        }
247        let update_stream_info_fn: UpdateStreamInfoFn =
248            unsafe { std::mem::transmute(self.plugin.update_stream_info_fn_ptr) };
249        let info_json = serde_json::to_string(&info).unwrap();
250        let info_json = CString::new(info_json).unwrap_or(CString::new("").unwrap());
251        unsafe { update_stream_info_fn(self.ptr, info_json.as_ptr()) }
252    }
253}
254
255impl Drop for AudioFilterSession {
256    fn drop(&mut self) {
257        if !self.ptr.is_null() {
258            self.destroy();
259        }
260    }
261}
262
263pub struct AudioFilterAudioStream {
264    inner: NativeAudioStream,
265    session: AudioFilterSession,
266    buffer: Vec<i16>,
267    sample_rate: u32,
268    num_channels: u32,
269    frame_size: usize,
270}
271
272impl AudioFilterAudioStream {
273    pub fn new(
274        inner: NativeAudioStream,
275        session: AudioFilterSession,
276        duration: Duration,
277        sample_rate: u32,
278        num_channels: u32,
279    ) -> Self {
280        let frame_size =
281            ((sample_rate as f64) * duration.as_secs_f64() * num_channels as f64) as usize;
282        Self {
283            inner,
284            session,
285            buffer: Vec::with_capacity(frame_size),
286            sample_rate,
287            num_channels,
288            frame_size,
289        }
290    }
291
292    pub fn update_stream_info(&mut self, info: AudioFilterStreamInfo) {
293        self.session.update_stream_info(info);
294    }
295}
296
297impl Stream for AudioFilterAudioStream {
298    type Item = AudioFrame<'static>;
299
300    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
301        let this = self.get_mut();
302
303        while let Poll::Ready(frame) = Pin::new(&mut this.inner).poll_next(cx) {
304            let Some(frame) = frame else {
305                return Poll::Ready(None);
306            };
307            this.buffer.extend_from_slice(&frame.data);
308
309            if this.buffer.len() >= this.frame_size {
310                let data = this.buffer.drain(..this.frame_size).collect::<Vec<_>>();
311                let mut out: Vec<i16> = vec![0; this.frame_size];
312
313                this.session.process_i16(this.frame_size, &data, &mut out);
314
315                return Poll::Ready(Some(AudioFrame {
316                    data: out.into(),
317                    sample_rate: this.sample_rate,
318                    num_channels: this.num_channels,
319                    samples_per_channel: (this.frame_size / this.num_channels as usize) as u32,
320                }));
321            }
322        }
323
324        Poll::Pending
325    }
326}
327
328#[derive(Debug, Serialize, Default, Clone)]
329#[serde(rename_all = "camelCase")]
330pub struct AudioFilterStreamInfo {
331    pub url: String,
332    pub room_id: String,
333    pub room_name: String,
334    pub participant_identity: String,
335    pub participant_id: String,
336    pub track_id: String,
337}
338
339// The function pointers in this struct are initialized only once during construction
340// and remain read-only throughout the lifetime of the struct, ensuring thread safety.
341unsafe impl Send for AudioFilterPlugin {}
342unsafe impl Sync for AudioFilterPlugin {}
343unsafe impl Send for AudioFilterSession {}
344unsafe impl Sync for AudioFilterSession {}