use crate::types::{AudioChannel, BinauraAudio, Position3D};
use crate::{Error, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
/// A pluggable spatial-audio effect that can be registered with [`PluginManager`].
///
/// Implementors must be `Send + Sync + Debug` so they can be stored as
/// `Box<dyn SpatialPlugin>` inside the manager's locked registry.
#[async_trait]
pub trait SpatialPlugin: Send + Sync + Debug {
    /// Plugin name; [`PluginManager`] uses it as the registry key, so it should be unique.
    fn name(&self) -> &str;
    /// Version string of the plugin.
    fn version(&self) -> &str;
    /// Human-readable description.
    fn description(&self) -> &str;
    /// Author or vendor.
    fn author(&self) -> &str;
    /// Feature flags this plugin advertises.
    fn capabilities(&self) -> PluginCapabilities;
    /// Applies `config` before the plugin is used.
    ///
    /// NOTE(review): `PluginManager::register_plugin` does not call this; callers must
    /// initialize plugins themselves.
    async fn initialize(&mut self, config: PluginConfig) -> Result<()>;
    /// Processes one block of samples for a source at `source_position` heard from
    /// `listener_position`, returning the processed samples.
    async fn process_audio(
        &self,
        audio: &[f32],
        listener_position: Position3D,
        source_position: Position3D,
        context: &ProcessingContext,
    ) -> Result<Vec<f32>>;
    /// Processes binaural audio.
    ///
    /// The default implementation returns a clone of `audio` unchanged.
    async fn process_binaural(
        &self,
        audio: &BinauraAudio,
        context: &ProcessingContext,
    ) -> Result<BinauraAudio> {
        Ok(audio.clone())
    }
    /// Applies updated parameters at runtime.
    async fn update_parameters(&mut self, parameters: PluginParameters) -> Result<()>;
    /// Returns a snapshot of the plugin's current lifecycle state.
    fn get_state(&self) -> PluginState;
    /// Releases plugin resources; `PluginManager::unregister_plugin` calls this after
    /// removing the plugin from the registry.
    async fn cleanup(&mut self) -> Result<()>;
    /// Returns `self` as `&dyn Any` (presumably so callers can downcast to the concrete type).
    fn as_any(&self) -> &dyn Any;
    /// Mutable counterpart of [`SpatialPlugin::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
/// Feature flags a plugin advertises through [`SpatialPlugin::capabilities`].
///
/// `PluginManager` only reports these (via `get_plugin_capabilities`); nothing in this module
/// checks them before processing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PluginCapabilities {
    /// Accepts single-channel audio.
    pub supports_mono: bool,
    /// Accepts two-channel audio.
    pub supports_stereo: bool,
    /// Accepts audio with more than two channels.
    pub supports_multichannel: bool,
    /// Implements `process_binaural` (rather than relying on the passthrough default).
    pub supports_binaural: bool,
    /// Suitable for real-time processing.
    pub supports_realtime: bool,
    /// Suitable for offline/batch processing.
    pub supports_batch: bool,
    /// Exposes runtime-tunable parameters via `update_parameters`.
    pub has_parameters: bool,
    /// Plugin state can be serialized.
    pub supports_serialization: bool,
    /// Needs a GPU to run.
    pub requires_gpu: bool,
    /// Uses listener/source positions when processing.
    pub supports_3d_positioning: bool,
    /// Performs HRTF-based rendering.
    pub supports_hrtf: bool,
    /// Simulates room acoustics.
    pub supports_room_simulation: bool,
}
impl Default for PluginCapabilities {
fn default() -> Self {
Self {
supports_mono: true,
supports_stereo: true,
supports_multichannel: false,
supports_binaural: false,
supports_realtime: true,
supports_batch: true,
has_parameters: false,
supports_serialization: false,
requires_gpu: false,
supports_3d_positioning: false,
supports_hrtf: false,
supports_room_simulation: false,
}
}
}
/// Configuration passed to [`SpatialPlugin::initialize`] and stored per plugin by
/// `PluginManager::register_plugin`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginConfig {
    /// Plugin-specific named parameters (e.g. `ReverbPlugin` reads `room_size`, `damping`,
    /// `wet_level`, `dry_level`).
    pub parameters: HashMap<String, PluginParameter>,
    /// Sample rate (presumably Hz; defaults to 44100.0).
    pub sample_rate: f32,
    /// Processing buffer size (presumably in samples; defaults to 1024).
    pub buffer_size: usize,
    /// Number of audio channels.
    pub channels: usize,
    /// Whether the plugin may use the GPU.
    pub use_gpu: bool,
    /// Whether the plugin runs in real-time mode.
    pub realtime_mode: bool,
    /// Quality setting (defaults to 0.8; presumably in 0.0..=1.0, TODO confirm).
    pub quality_level: f32,
}
impl Default for PluginConfig {
fn default() -> Self {
Self {
parameters: HashMap::new(),
sample_rate: 44100.0,
buffer_size: 1024,
channels: 2,
use_gpu: false,
realtime_mode: true,
quality_level: 0.8,
}
}
}
/// A dynamically typed plugin parameter value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PluginParameter {
    /// Boolean flag.
    Bool(bool),
    /// 32-bit signed integer.
    Int(i32),
    /// 32-bit float.
    Float(f32),
    /// Text value.
    String(String),
    /// List of floats.
    FloatArray(Vec<f32>),
    /// Nested map of named parameters.
    Object(HashMap<String, PluginParameter>),
}
/// A set of named parameters passed to [`SpatialPlugin::update_parameters`].
///
/// Use the typed `set_*`/`get_*` accessors, or the `parameters` map directly.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PluginParameters {
    /// Parameter values keyed by name.
    pub parameters: HashMap<String, PluginParameter>,
}
impl PluginParameters {
pub fn new() -> Self {
Self::default()
}
pub fn set_bool(&mut self, key: &str, value: bool) {
self.parameters
.insert(key.to_string(), PluginParameter::Bool(value));
}
pub fn set_int(&mut self, key: &str, value: i32) {
self.parameters
.insert(key.to_string(), PluginParameter::Int(value));
}
pub fn set_float(&mut self, key: &str, value: f32) {
self.parameters
.insert(key.to_string(), PluginParameter::Float(value));
}
pub fn set_string(&mut self, key: &str, value: String) {
self.parameters
.insert(key.to_string(), PluginParameter::String(value));
}
pub fn get_bool(&self, key: &str) -> Option<bool> {
match self.parameters.get(key)? {
PluginParameter::Bool(value) => Some(*value),
_ => None,
}
}
pub fn get_int(&self, key: &str) -> Option<i32> {
match self.parameters.get(key)? {
PluginParameter::Int(value) => Some(*value),
_ => None,
}
}
pub fn get_float(&self, key: &str) -> Option<f32> {
match self.parameters.get(key)? {
PluginParameter::Float(value) => Some(*value),
_ => None,
}
}
pub fn get_string(&self, key: &str) -> Option<&str> {
match self.parameters.get(key)? {
PluginParameter::String(value) => Some(value),
_ => None,
}
}
}
/// Per-call context passed to [`SpatialPlugin::process_audio`] and
/// [`SpatialPlugin::process_binaural`].
#[derive(Debug, Clone)]
pub struct ProcessingContext {
    /// Sample rate of `audio` (presumably Hz).
    pub sample_rate: f32,
    /// Buffer size; `ReverbPlugin` derives its delay tap from `buffer_size / 4`.
    pub buffer_size: usize,
    /// Number of audio channels.
    pub channels: usize,
    /// Monotonic time at which the context was created (`Instant::now()` by default).
    pub timestamp: std::time::Instant,
    /// Quality setting (presumably 0.0..=1.0, TODO confirm).
    pub quality_level: f32,
    /// Whether processing is happening in real time.
    pub realtime_mode: bool,
    /// Free-form extra data for plugins.
    pub context_data: HashMap<String, PluginParameter>,
}
impl Default for ProcessingContext {
fn default() -> Self {
Self {
sample_rate: 44100.0,
buffer_size: 1024,
channels: 2,
timestamp: std::time::Instant::now(),
quality_level: 0.8,
realtime_mode: true,
context_data: HashMap::new(),
}
}
}
/// Lifecycle state reported by [`SpatialPlugin::get_state`].
///
/// `ReverbPlugin` sets `Ready` in `initialize` and `Cleanup` in `cleanup`. Of these states,
/// only `Error` makes its `process_audio` refuse to run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PluginState {
    /// Created but `initialize` has not run.
    Uninitialized,
    /// Initialized and ready to process.
    Ready,
    /// Currently processing.
    Processing,
    /// Paused.
    Paused,
    /// Failed; carries a description of the failure.
    Error(String),
    /// `cleanup` has run.
    Cleanup,
}
/// Registry of named plugins, their configs, and named processing chains.
///
/// Each map sits behind its own `Arc<std::sync::RwLock<_>>`. Plugins and configs are keyed by
/// [`SpatialPlugin::name`]; chains are keyed by the name given to `create_chain`.
#[derive(Debug)]
pub struct PluginManager {
    /// Registered plugins, keyed by plugin name.
    plugins: Arc<RwLock<HashMap<String, Box<dyn SpatialPlugin>>>>,
    /// Config supplied at registration, keyed by plugin name.
    configs: Arc<RwLock<HashMap<String, PluginConfig>>>,
    /// Processing chains, keyed by chain name.
    chains: Arc<RwLock<HashMap<String, ProcessingChain>>>,
}
impl PluginManager {
pub fn new() -> Self {
Self {
plugins: Arc::new(RwLock::new(HashMap::new())),
configs: Arc::new(RwLock::new(HashMap::new())),
chains: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_plugin(
&self,
plugin: Box<dyn SpatialPlugin>,
config: PluginConfig,
) -> Result<()> {
let name = plugin.name().to_string();
{
let mut configs = self
.configs
.write()
.map_err(|_| Error::LegacyAudio("Plugin config lock poisoned".to_string()))?;
configs.insert(name.clone(), config);
}
{
let mut plugins = self
.plugins
.write()
.map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
plugins.insert(name, plugin);
}
Ok(())
}
pub async fn unregister_plugin(&self, name: &str) -> Result<()> {
let plugin_to_cleanup = {
let mut plugins = self
.plugins
.write()
.map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
plugins.remove(name)
};
if let Some(mut plugin) = plugin_to_cleanup {
plugin.cleanup().await?;
}
{
let mut configs = self
.configs
.write()
.map_err(|_| Error::LegacyAudio("Plugin config lock poisoned".to_string()))?;
configs.remove(name);
}
Ok(())
}
pub fn get_plugin_names(&self) -> Vec<String> {
match self.plugins.read() {
Ok(plugins) => plugins.keys().cloned().collect(),
Err(_) => {
tracing::warn!("Plugin lock poisoned, returning empty list");
Vec::new()
}
}
}
pub fn has_plugin(&self, name: &str) -> bool {
match self.plugins.read() {
Ok(plugins) => plugins.contains_key(name),
Err(_) => {
tracing::warn!("Plugin lock poisoned, assuming plugin doesn't exist");
false
}
}
}
#[allow(clippy::await_holding_lock)]
pub async fn process_with_plugin(
&self,
plugin_name: &str,
audio: &[f32],
listener_position: Position3D,
source_position: Position3D,
context: &ProcessingContext,
) -> Result<Vec<f32>> {
let plugins = self
.plugins
.read()
.map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
let plugin = plugins
.get(plugin_name)
.ok_or_else(|| Error::LegacyAudio(format!("Plugin {plugin_name} not found")))?;
plugin
.process_audio(audio, listener_position, source_position, context)
.await
}
#[allow(clippy::await_holding_lock)]
pub async fn process_with_chain(
&self,
chain_name: &str,
audio: &[f32],
listener_position: Position3D,
source_position: Position3D,
context: &ProcessingContext,
) -> Result<Vec<f32>> {
let chains = self
.chains
.read()
.map_err(|_| Error::LegacyAudio("Chain lock poisoned".to_string()))?;
let chain = chains.get(chain_name).ok_or_else(|| {
Error::LegacyAudio(format!("Processing chain {chain_name} not found"))
})?;
self.process_chain(chain, audio, listener_position, source_position, context)
.await
}
pub async fn create_chain(&self, name: &str, plugin_names: Vec<String>) -> Result<()> {
let chain = ProcessingChain {
name: name.to_string(),
plugins: plugin_names,
enabled: true,
};
let mut chains = self
.chains
.write()
.map_err(|_| Error::LegacyAudio("Chain lock poisoned".to_string()))?;
chains.insert(name.to_string(), chain);
Ok(())
}
pub async fn remove_chain(&self, name: &str) -> Result<()> {
let mut chains = self
.chains
.write()
.map_err(|_| Error::LegacyAudio("Chain lock poisoned".to_string()))?;
chains.remove(name);
Ok(())
}
async fn process_chain(
&self,
chain: &ProcessingChain,
mut audio: &[f32],
listener_position: Position3D,
source_position: Position3D,
context: &ProcessingContext,
) -> Result<Vec<f32>> {
if !chain.enabled {
return Ok(audio.to_vec());
}
let mut result = audio.to_vec();
for plugin_name in &chain.plugins {
result = self
.process_with_plugin(
plugin_name,
&result,
listener_position,
source_position,
context,
)
.await?;
}
Ok(result)
}
pub fn get_plugin_capabilities(&self, name: &str) -> Option<PluginCapabilities> {
match self.plugins.read() {
Ok(plugins) => plugins.get(name).map(|plugin| plugin.capabilities()),
Err(_) => {
tracing::warn!("Plugin lock poisoned, returning None for capabilities");
None
}
}
}
#[allow(clippy::await_holding_lock)]
pub async fn update_plugin_parameters(
&self,
plugin_name: &str,
parameters: PluginParameters,
) -> Result<()> {
let mut plugins = self
.plugins
.write()
.map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
let plugin = plugins
.get_mut(plugin_name)
.ok_or_else(|| Error::LegacyAudio(format!("Plugin {plugin_name} not found")))?;
plugin.update_parameters(parameters).await
}
}
impl Default for PluginManager {
fn default() -> Self {
Self::new()
}
}
/// A named, ordered list of plugin names run by `PluginManager::process_with_chain`.
///
/// Each plugin's output is the next plugin's input. When `enabled` is `false`, the input is
/// returned unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessingChain {
    /// Chain name.
    pub name: String,
    /// Names of the plugins to run, in order.
    pub plugins: Vec<String>,
    /// Whether the chain processes audio at all.
    pub enabled: bool,
}
/// Simple distance-aware reverb: the direct signal mixed with one delayed tap.
#[derive(Debug)]
pub struct ReverbPlugin {
    /// Registry name ("Spatial Reverb" by default).
    name: String,
    /// Version string.
    version: String,
    /// Scales the delayed tap.
    room_size: f32,
    /// NOTE(review): stored and clamped by `update_parameters`, but not read by `process_audio`.
    damping: f32,
    /// Gain applied to the delayed (wet) tap.
    wet_level: f32,
    /// Gain applied to the direct (dry) signal.
    dry_level: f32,
    /// Current lifecycle state.
    state: PluginState,
}
impl Default for ReverbPlugin {
fn default() -> Self {
Self::new()
}
}
impl ReverbPlugin {
pub fn new() -> Self {
Self {
name: "Spatial Reverb".to_string(),
version: "1.0.0".to_string(),
room_size: 0.5,
damping: 0.5,
wet_level: 0.3,
dry_level: 0.7,
state: PluginState::Uninitialized,
}
}
}
#[async_trait]
impl SpatialPlugin for ReverbPlugin {
    fn name(&self) -> &str {
        &self.name
    }

    fn version(&self) -> &str {
        &self.version
    }

    fn description(&self) -> &str {
        "Spatial reverb effect for room simulation"
    }

    fn author(&self) -> &str {
        "VoiRS Team"
    }

    fn capabilities(&self) -> PluginCapabilities {
        PluginCapabilities {
            supports_mono: true,
            supports_stereo: true,
            supports_multichannel: true,
            supports_binaural: true,
            supports_realtime: true,
            supports_batch: true,
            has_parameters: true,
            supports_serialization: true,
            requires_gpu: false,
            supports_3d_positioning: true,
            supports_hrtf: false,
            supports_room_simulation: true,
        }
    }

    /// Reads `room_size`, `damping`, `wet_level` and `dry_level` from `config.parameters`
    /// (each only when present as [`PluginParameter::Float`]) and marks the plugin `Ready`.
    ///
    /// Values are clamped to `[0.0, 1.0]`, the same range `update_parameters` enforces.
    /// A parameter is therefore accepted identically whichever entry point supplies it.
    async fn initialize(&mut self, config: PluginConfig) -> Result<()> {
        if let Some(PluginParameter::Float(size)) = config.parameters.get("room_size") {
            self.room_size = size.clamp(0.0, 1.0);
        }
        if let Some(PluginParameter::Float(damping)) = config.parameters.get("damping") {
            self.damping = damping.clamp(0.0, 1.0);
        }
        if let Some(PluginParameter::Float(wet)) = config.parameters.get("wet_level") {
            self.wet_level = wet.clamp(0.0, 1.0);
        }
        if let Some(PluginParameter::Float(dry)) = config.parameters.get("dry_level") {
            self.dry_level = dry.clamp(0.0, 1.0);
        }
        self.state = PluginState::Ready;
        Ok(())
    }

    /// Mixes the direct signal with one delayed tap whose level grows with source distance.
    ///
    /// Output sample `i` is `audio[i] * dry_level + audio[i - d] * room_size * wet_level * s`,
    /// where `d = context.buffer_size / 4` and `s = min(distance / 10, 1)`. The tap is silent
    /// for `i < d`. The output has the same length as `audio`.
    ///
    /// # Errors
    /// Returns `Error::LegacyAudio` if the plugin is in [`PluginState::Error`].
    async fn process_audio(
        &self,
        audio: &[f32],
        listener_position: Position3D,
        source_position: Position3D,
        context: &ProcessingContext,
    ) -> Result<Vec<f32>> {
        if matches!(self.state, PluginState::Error(_)) {
            return Err(Error::LegacyAudio("Plugin is in error state".to_string()));
        }
        // More distant sources get more reverb, saturating at a distance of 10
        // (in whatever unit Position3D uses).
        let distance = listener_position.distance_to(&source_position);
        let reverb_scale = (distance / 10.0).min(1.0);
        // NOTE(review): when buffer_size < 4 the delay is 0 and the "tap" is the current sample.
        let delay = context.buffer_size / 4;
        // The distance scale is applied exactly once, so wet level is linear in distance.
        let wet_gain = self.room_size * self.wet_level * reverb_scale;
        let output: Vec<f32> = audio
            .iter()
            .enumerate()
            .map(|(i, &sample)| {
                let delayed = i.checked_sub(delay).map_or(0.0, |j| audio[j]);
                sample * self.dry_level + delayed * wet_gain
            })
            .collect();
        Ok(output)
    }

    /// Updates any of `room_size`, `damping`, `wet_level`, `dry_level` present as floats,
    /// clamping each to `[0.0, 1.0]`.
    async fn update_parameters(&mut self, parameters: PluginParameters) -> Result<()> {
        if let Some(size) = parameters.get_float("room_size") {
            self.room_size = size.clamp(0.0, 1.0);
        }
        if let Some(damping) = parameters.get_float("damping") {
            self.damping = damping.clamp(0.0, 1.0);
        }
        if let Some(wet) = parameters.get_float("wet_level") {
            self.wet_level = wet.clamp(0.0, 1.0);
        }
        if let Some(dry) = parameters.get_float("dry_level") {
            self.dry_level = dry.clamp(0.0, 1.0);
        }
        Ok(())
    }

    fn get_state(&self) -> PluginState {
        self.state.clone()
    }

    /// Marks the plugin as cleaned up; there are no resources to release.
    async fn cleanup(&mut self) -> Result<()> {
        self.state = PluginState::Cleanup;
        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const REVERB: &str = "Spatial Reverb";

    /// Returns a manager with one default-configured `ReverbPlugin` registered.
    async fn manager_with_reverb() -> PluginManager {
        let manager = PluginManager::new();
        manager
            .register_plugin(Box::new(ReverbPlugin::new()), PluginConfig::default())
            .await
            .unwrap();
        manager
    }

    /// Listener at the origin, source one unit along x.
    fn positions() -> (Position3D, Position3D) {
        (
            Position3D::new(0.0, 0.0, 0.0),
            Position3D::new(1.0, 0.0, 0.0),
        )
    }

    #[tokio::test]
    async fn test_plugin_manager_creation() {
        let manager = PluginManager::new();
        assert!(manager.get_plugin_names().is_empty());
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let manager = manager_with_reverb().await;
        assert_eq!(manager.get_plugin_names().len(), 1);
        assert!(manager.has_plugin(REVERB));
    }

    #[test]
    fn test_plugin_capabilities() {
        let caps = ReverbPlugin::new().capabilities();
        assert!(caps.supports_mono);
        assert!(caps.supports_stereo);
        assert!(caps.supports_room_simulation);
        assert!(caps.has_parameters);
    }

    #[test]
    fn test_plugin_parameters() {
        let mut params = PluginParameters::new();
        params.set_float("room_size", 0.8);
        params.set_bool("enabled", true);
        params.set_string("preset", "Hall".to_string());
        assert_eq!(params.get_float("room_size"), Some(0.8));
        assert_eq!(params.get_bool("enabled"), Some(true));
        assert_eq!(params.get_string("preset"), Some("Hall"));
        // Typed getters do not convert between variants.
        assert_eq!(params.get_int("room_size"), None);
        assert_eq!(params.get_float("missing"), None);
    }

    #[tokio::test]
    async fn test_update_parameters_clamps() {
        let mut plugin = ReverbPlugin::new();
        let mut params = PluginParameters::new();
        params.set_float("room_size", 2.0);
        params.set_float("wet_level", -1.0);
        plugin.update_parameters(params).await.unwrap();
        assert_eq!(plugin.room_size, 1.0);
        assert_eq!(plugin.wet_level, 0.0);
    }

    #[tokio::test]
    async fn test_plugin_audio_processing() {
        let mut plugin = ReverbPlugin::new();
        plugin.initialize(PluginConfig::default()).await.unwrap();
        let audio = vec![0.5; 1024];
        let (listener_pos, source_pos) = positions();
        let context = ProcessingContext::default();
        let result = plugin
            .process_audio(&audio, listener_pos, source_pos, &context)
            .await
            .unwrap();
        assert_eq!(result.len(), audio.len());
    }

    #[tokio::test]
    async fn test_processing_chain() {
        let manager = manager_with_reverb().await;
        manager
            .create_chain("test_chain", vec![REVERB.to_string()])
            .await
            .unwrap();
        let audio = vec![0.5; 1024];
        let (listener_pos, source_pos) = positions();
        let context = ProcessingContext::default();
        let result = manager
            .process_with_chain("test_chain", &audio, listener_pos, source_pos, &context)
            .await
            .unwrap();
        assert_eq!(result.len(), audio.len());
    }

    #[tokio::test]
    async fn test_disabled_chain_passes_audio_through() {
        let manager = manager_with_reverb().await;
        manager
            .create_chain("off", vec![REVERB.to_string()])
            .await
            .unwrap();
        manager
            .chains
            .write()
            .unwrap()
            .get_mut("off")
            .unwrap()
            .enabled = false;
        let audio = vec![0.25, -0.5, 1.0];
        let (listener_pos, source_pos) = positions();
        let context = ProcessingContext::default();
        let result = manager
            .process_with_chain("off", &audio, listener_pos, source_pos, &context)
            .await
            .unwrap();
        assert_eq!(result, audio);
    }

    #[tokio::test]
    async fn test_missing_plugin_and_chain_are_errors() {
        let manager = PluginManager::new();
        let audio = [0.5f32; 8];
        let (listener_pos, source_pos) = positions();
        let context = ProcessingContext::default();
        assert!(manager
            .process_with_plugin("missing", &audio, listener_pos, source_pos, &context)
            .await
            .is_err());
        assert!(manager
            .process_with_chain("missing", &audio, listener_pos, source_pos, &context)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_plugin_cleanup() {
        let manager = manager_with_reverb().await;
        assert!(manager.has_plugin(REVERB));
        manager.unregister_plugin(REVERB).await.unwrap();
        assert!(!manager.has_plugin(REVERB));
        // The stored config goes away together with the plugin.
        assert!(manager.configs.read().unwrap().is_empty());
    }
}