use core::cell::UnsafeCell;
extern crate alloc;
use alloc::vec::Vec;
use crate::apo::{ApoCategory, ProcessInput, ProcessingObject, SystemEffect, SystemEffectState};
use crate::buffer::BufferFlags;
use crate::clsid::Clsid;
use crate::error::HResult;
use crate::format::{Format, FormatNegotiation};
use crate::realtime::{RealtimeContext, Refcount, State, StateCell};
/// The input/output format pair an [`ApoInstance`] was locked with.
///
/// Recorded by a successful `lock_for_process` and reported by
/// `locked_formats()` only while the instance is in the `Locked` state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LockedFormats {
    /// The `input` format passed to `lock_for_process`.
    pub input: Format,
    /// The `output` format passed to `lock_for_process`.
    pub output: Format,
}
/// Object-safe, type-erased view of an [`ApoInstance`].
///
/// Lets a host hold instances of different processing objects behind a single
/// `dyn AnyApoInstance` (for example `Arc<dyn AnyApoInstance>`). Implementors
/// must be shareable across threads (`Send + Sync`).
pub trait AnyApoInstance: Send + Sync {
    // ----- reference counting -----

    /// Increments the reference count and returns the new value.
    fn add_ref(&self) -> u32;
    /// Decrements the reference count and returns the new value.
    fn release(&self) -> u32;
    /// Returns the current reference count.
    fn refcount(&self) -> u32;

    // ----- lifecycle -----

    /// Returns the current lifecycle state.
    fn state(&self) -> State;
    /// Moves the instance from `Uninitialized` to `Initialized`.
    fn initialize(&self) -> Result<(), HResult>;
    /// Locks the instance for processing with the given formats.
    fn lock_for_process(&self, input: &Format, output: &Format) -> Result<(), HResult>;
    /// Releases a previous lock, returning the instance to `Initialized`.
    fn unlock_for_process(&self) -> Result<(), HResult>;
    /// Formats captured by the current lock, or `None` when not locked.
    fn locked_formats(&self) -> Option<LockedFormats>;

    // ----- format negotiation -----

    /// Asks the processing object whether it accepts `format` as input.
    fn is_input_format_supported(&self, format: &Format) -> FormatNegotiation;
    /// Asks the processing object whether it accepts `format` as output.
    fn is_output_format_supported(&self, format: &Format) -> FormatNegotiation;

    // ----- processing -----

    /// Runs one processing pass, returning the resulting buffer flags.
    fn process(
        &self,
        rt: &RealtimeContext,
        input: ProcessInput<'_>,
        output: &mut [f32],
    ) -> Result<BufferFlags, HResult>;

    // ----- static metadata -----

    /// Class identifier of the wrapped processing object.
    fn clsid(&self) -> Clsid;
    /// Human-readable name of the wrapped processing object.
    fn name(&self) -> &'static str;
    /// Copyright string of the wrapped processing object.
    fn copyright(&self) -> &'static str;
    /// APO category of the wrapped processing object.
    fn category(&self) -> ApoCategory;

    // ----- system effects -----

    /// Owned snapshot of the system effects the processing object exposes.
    fn system_effects(&self) -> Vec<SystemEffect>;
    /// Forwards a system-effect state change to the processing object.
    fn set_system_effect_state(&self, id: &Clsid, state: SystemEffectState);
}
/// Wraps a user [`ProcessingObject`] together with the reference count,
/// lifecycle state and locked-format bookkeeping every APO needs.
pub struct ApoInstance<T>
where
    T: ProcessingObject,
{
    /// The user's processing object; mutated through `&self` via raw pointers.
    inner: UnsafeCell<T>,
    /// Atomic lifecycle state (`Uninitialized` / `Initialized` / `Locked`).
    state: StateCell,
    /// Reference count reported through `add_ref` / `release` / `refcount`.
    refcount: Refcount,
    /// Formats captured by a successful lock; only read while `state` is `Locked`.
    locked_formats: UnsafeCell<Option<LockedFormats>>,
}

// SAFETY: the `UnsafeCell` fields make this type `!Sync` automatically. Sharing
// `&ApoInstance<T>` across threads is sound only if the `&mut` accesses to
// `inner` and `locked_formats` made from `&self` methods never overlap with
// other accesses. Apart from the atomic `state` gate, this type does not
// enforce that itself. It relies on callers honouring the APO lifecycle
// contract: control calls are not made concurrently, and `process` runs only
// while `Locked`. NOTE(review): confirm that contract is documented for hosts.
unsafe impl<T> Sync for ApoInstance<T> where T: ProcessingObject {}
impl<T: ProcessingObject> ApoInstance<T> {
    /// Creates an instance wrapping a freshly constructed `T::new()`.
    ///
    /// The instance starts in `State::Uninitialized` with a reference count of
    /// zero and no locked formats.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: UnsafeCell::new(T::new()),
            state: StateCell::new(),
            refcount: Refcount::new(),
            locked_formats: UnsafeCell::new(None),
        }
    }

    /// Returns the current lifecycle state.
    #[inline]
    #[must_use]
    pub fn state(&self) -> State {
        self.state.load()
    }

    /// Crate-internal access to the cell holding the wrapped processing object.
    ///
    /// Anyone dereferencing the returned cell takes on the same aliasing
    /// obligations as the `unsafe` blocks in this impl.
    // `dead_code` is allowed because this accessor may have no caller in some
    // builds. NOTE(review): confirm it is still needed.
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn inner_cell(&self) -> &UnsafeCell<T> {
        &self.inner
    }

    /// Returns the current reference count.
    #[inline]
    #[must_use]
    pub fn refcount(&self) -> u32 {
        self.refcount.count()
    }

    /// Increments the reference count and returns the new value.
    #[inline]
    pub fn add_ref(&self) -> u32 {
        self.refcount.add_ref()
    }

    /// Decrements the reference count and returns the new value.
    #[inline]
    pub fn release(&self) -> u32 {
        self.refcount.release()
    }

    /// Moves the instance from `Uninitialized` to `Initialized`.
    ///
    /// # Errors
    ///
    /// Returns [`HResult::APOERR_ALREADY_LOCKED`] whenever the state transition
    /// is rejected, including a second call on an already-initialized instance.
    pub fn initialize(&self) -> Result<(), HResult> {
        self.state
            .initialize()
            .map_err(|_| HResult::APOERR_ALREADY_LOCKED)
    }

    /// Asks the wrapped object whether it accepts `format` as input.
    pub fn is_input_format_supported(&self, format: &Format) -> FormatNegotiation {
        // SAFETY: shared borrow of `inner`. Nothing here prevents a concurrent
        // `&mut T` (e.g. from `process`); this relies on the caller's threading
        // contract. NOTE(review): confirm hosts never negotiate while processing.
        let inner = unsafe { &*self.inner.get() };
        inner.is_input_format_supported(format)
    }

    /// Asks the wrapped object whether it accepts `format` as output.
    pub fn is_output_format_supported(&self, format: &Format) -> FormatNegotiation {
        // SAFETY: same contract as `is_input_format_supported`.
        let inner = unsafe { &*self.inner.get() };
        inner.is_output_format_supported(format)
    }

    /// Locks the instance for processing with the given input/output formats.
    ///
    /// The state is moved from `Initialized` to `Locked` first, then the
    /// wrapped object's `lock_for_process` hook runs. On success the formats are
    /// cached for [`Self::locked_formats`]. On failure the state is rolled back
    /// to `Initialized`.
    ///
    /// # Errors
    ///
    /// - [`HResult::APOERR_NOT_LOCKED`] if the instance is still `Uninitialized`.
    /// - [`HResult::APOERR_ALREADY_LOCKED`] if it is already `Locked`.
    /// - [`HResult::E_FAIL`] if the transition fails while the state reads back
    ///   as `Initialized`.
    /// - Any error returned by the wrapped object's `lock_for_process`.
    pub fn lock_for_process(&self, input: &Format, output: &Format) -> Result<(), HResult> {
        self.state.lock().map_err(|err| match err.actual {
            State::Uninitialized => HResult::APOERR_NOT_LOCKED,
            State::Initialized => HResult::E_FAIL,
            State::Locked => HResult::APOERR_ALREADY_LOCKED,
        })?;
        // SAFETY: this call just won the Initialized -> Locked transition, so no
        // other lock/unlock call can be mutating `inner`, assuming lifecycle calls
        // go through `state`. NOTE(review): `state` already reads `Locked` here, so
        // a racing `process` would pass its gate. Soundness relies on the host not
        // calling `process` before this method returns.
        let inner = unsafe { &mut *self.inner.get() };
        match inner.lock_for_process(input, output) {
            Ok(()) => {
                // SAFETY: same exclusivity argument as for `inner` above.
                unsafe {
                    *self.locked_formats.get() = Some(LockedFormats {
                        input: *input,
                        output: *output,
                    });
                }
                Ok(())
            }
            Err(e) => {
                // Best-effort rollback of the transition made above. The user's
                // error is the one the caller needs to see.
                let _ = self.state.unlock();
                Err(e)
            }
        }
    }

    /// Releases a previous lock and returns the instance to `Initialized`.
    ///
    /// The Locked -> Initialized transition is claimed atomically *before* the
    /// wrapped object's `unlock_for_process` hook runs. If two callers race,
    /// only the one whose transition succeeds creates `&mut T`, so the hook
    /// never runs twice or through aliased mutable borrows.
    ///
    /// # Errors
    ///
    /// Returns [`HResult::APOERR_NOT_LOCKED`] if the instance is not `Locked`,
    /// or if another caller completed the unlock first.
    pub fn unlock_for_process(&self) -> Result<(), HResult> {
        // Early check that keeps the error for the plain "not locked" case
        // independent of how `StateCell::unlock` treats non-Locked states.
        if self.state.load() != State::Locked {
            return Err(HResult::APOERR_NOT_LOCKED);
        }
        // Claim the transition first; the load above is only a snapshot.
        // NOTE(review): this assumes `StateCell::unlock` is an atomic
        // compare-and-swap from `Locked` (its `Result` return suggests so).
        self.state
            .unlock()
            .map_err(|_| HResult::APOERR_NOT_LOCKED)?;
        // SAFETY: this call won the Locked -> Initialized transition, so no other
        // unlock can be touching `inner` or `locked_formats`.
        let inner = unsafe { &mut *self.inner.get() };
        inner.unlock_for_process();
        // SAFETY: as above. `locked_formats()` only reads this cell while the
        // state is `Locked`, and this call has already left `Locked`.
        unsafe {
            *self.locked_formats.get() = None;
        }
        Ok(())
    }

    /// Returns the formats captured by the current lock, or `None` when the
    /// instance is not `Locked`.
    #[inline]
    #[must_use]
    pub fn locked_formats(&self) -> Option<LockedFormats> {
        if self.state.load() != State::Locked {
            return None;
        }
        // SAFETY: the cell is only written by `lock_for_process` and
        // `unlock_for_process`. Reading it while `Locked` relies on those not
        // running concurrently with this call (caller contract).
        unsafe { *self.locked_formats.get() }
    }

    /// Runs one processing pass through the wrapped object.
    ///
    /// # Errors
    ///
    /// Returns [`HResult::APOERR_NOT_LOCKED`] unless the instance is `Locked`.
    /// Otherwise returns the buffer flags produced by the wrapped object.
    pub fn process(
        &self,
        rt: &RealtimeContext,
        input: ProcessInput<'_>,
        output: &mut [f32],
    ) -> Result<BufferFlags, HResult> {
        if !self.state.is_locked() {
            return Err(HResult::APOERR_NOT_LOCKED);
        }
        // SAFETY: only reached while `Locked`. Exclusivity of this `&mut T`
        // relies on the host not overlapping `process` with other calls that
        // touch `inner`.
        let inner = unsafe { &mut *self.inner.get() };
        Ok(inner.process(rt, input, output))
    }
}
impl<T: ProcessingObject> Default for ApoInstance<T> {
fn default() -> Self {
Self::new()
}
}
// Type-erased dispatch: lifecycle calls forward to the inherent methods of the
// same name; metadata comes from `T`'s associated constants.
impl<T> AnyApoInstance for ApoInstance<T>
where
    T: ProcessingObject,
{
    fn add_ref(&self) -> u32 {
        ApoInstance::<T>::add_ref(self)
    }

    fn release(&self) -> u32 {
        ApoInstance::<T>::release(self)
    }

    fn refcount(&self) -> u32 {
        ApoInstance::<T>::refcount(self)
    }

    fn state(&self) -> State {
        ApoInstance::<T>::state(self)
    }

    fn initialize(&self) -> Result<(), HResult> {
        ApoInstance::<T>::initialize(self)
    }

    fn is_input_format_supported(&self, format: &Format) -> FormatNegotiation {
        ApoInstance::<T>::is_input_format_supported(self, format)
    }

    fn is_output_format_supported(&self, format: &Format) -> FormatNegotiation {
        ApoInstance::<T>::is_output_format_supported(self, format)
    }

    fn lock_for_process(&self, input: &Format, output: &Format) -> Result<(), HResult> {
        ApoInstance::<T>::lock_for_process(self, input, output)
    }

    fn unlock_for_process(&self) -> Result<(), HResult> {
        ApoInstance::<T>::unlock_for_process(self)
    }

    fn process(
        &self,
        rt: &RealtimeContext,
        input: ProcessInput<'_>,
        output: &mut [f32],
    ) -> Result<BufferFlags, HResult> {
        ApoInstance::<T>::process(self, rt, input, output)
    }

    fn locked_formats(&self) -> Option<LockedFormats> {
        ApoInstance::<T>::locked_formats(self)
    }

    fn clsid(&self) -> Clsid {
        <T as ProcessingObject>::CLSID
    }

    fn name(&self) -> &'static str {
        <T as ProcessingObject>::NAME
    }

    fn copyright(&self) -> &'static str {
        <T as ProcessingObject>::COPYRIGHT
    }

    fn category(&self) -> ApoCategory {
        <T as ProcessingObject>::CATEGORY
    }

    fn system_effects(&self) -> Vec<SystemEffect> {
        // SAFETY: shared borrow of the wrapped object. This relies on no
        // concurrent `&mut T` (e.g. from `process`) being live — caller contract.
        let object: &T = unsafe { &*self.inner.get() };
        object.system_effects().to_vec()
    }

    fn set_system_effect_state(&self, id: &Clsid, state: SystemEffectState) {
        // SAFETY: creates `&mut T` without consulting `state`.
        // NOTE(review): nothing here excludes a concurrent `process` on another
        // thread. Confirm the host never changes effect state while streaming,
        // or route this through a realtime-safe handoff instead.
        let object: &mut T = unsafe { &mut *self.inner.get() };
        object.set_system_effect_state(id, state);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apo::{ApoCategory, ProcessInput, ProcessingObject};
    use crate::buffer::BufferFlags;
    use crate::clsid::Clsid;
    use crate::error::HResult;
    use crate::format::Format;
    use crate::realtime::{RealtimeContext, State};
    use core::cell::Cell;
    use static_assertions::assert_impl_all;

    /// Test fixture that records every hook the wrapper forwards to it.
    struct Trace {
        /// `(input, output)` sample rates seen by the last successful lock.
        lock_seen: Cell<Option<(u32, u32)>>,
        /// Number of `unlock_for_process` calls received.
        unlock_seen: Cell<u32>,
        /// Number of `process` calls received.
        process_seen: Cell<u32>,
        /// When set, `lock_for_process` rejects the format.
        lock_should_fail: Cell<bool>,
    }

    impl ProcessingObject for Trace {
        const CLSID: Clsid = Clsid::from_u128(0x01234567_89AB_CDEF_0123_456789ABCDEF);
        const NAME: &'static str = "tympan-apo trace";
        const COPYRIGHT: &'static str = "test fixture";
        const CATEGORY: ApoCategory = ApoCategory::Sfx;

        fn new() -> Self {
            Self {
                lock_seen: Cell::new(None),
                unlock_seen: Cell::new(0),
                process_seen: Cell::new(0),
                lock_should_fail: Cell::new(false),
            }
        }

        fn lock_for_process(&mut self, input: &Format, output: &Format) -> Result<(), HResult> {
            if self.lock_should_fail.get() {
                return Err(HResult::APOERR_FORMAT_NOT_SUPPORTED);
            }
            self.lock_seen
                .set(Some((input.sample_rate(), output.sample_rate())));
            Ok(())
        }

        fn unlock_for_process(&mut self) {
            self.unlock_seen.set(self.unlock_seen.get() + 1);
        }

        fn process(
            &mut self,
            _rt: &RealtimeContext,
            input: ProcessInput<'_>,
            output: &mut [f32],
        ) -> BufferFlags {
            self.process_seen.set(self.process_seen.get() + 1);
            output.copy_from_slice(input.samples());
            input.flags()
        }
    }

    assert_impl_all!(ApoInstance<Trace>: Sync);

    /// Realtime context for tests; no realtime guarantees are needed here.
    fn rt() -> RealtimeContext {
        // SAFETY: tests do not depend on the realtime invariants this type marks.
        unsafe { RealtimeContext::new_unchecked() }
    }

    #[test]
    fn new_starts_uninitialized_with_zero_refcount() {
        let apo = ApoInstance::<Trace>::new();
        assert_eq!(apo.state(), State::Uninitialized);
        assert_eq!(apo.refcount(), 0);
    }

    #[test]
    fn default_matches_new() {
        let apo: ApoInstance<Trace> = ApoInstance::default();
        assert_eq!(apo.state(), State::Uninitialized);
        assert_eq!(apo.refcount(), 0);
    }

    #[test]
    fn add_ref_release_delegate_to_refcount() {
        let apo = ApoInstance::<Trace>::new();
        assert_eq!(apo.add_ref(), 1);
        assert_eq!(apo.add_ref(), 2);
        assert_eq!(apo.refcount(), 2);
        assert_eq!(apo.release(), 1);
        assert_eq!(apo.release(), 0);
    }

    #[test]
    fn initialize_transitions_to_initialized() {
        let apo = ApoInstance::<Trace>::new();
        assert!(apo.initialize().is_ok());
        assert_eq!(apo.state(), State::Initialized);
    }

    #[test]
    fn double_initialize_returns_apoerr_already_locked() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        assert_eq!(apo.initialize(), Err(HResult::APOERR_ALREADY_LOCKED));
    }

    #[test]
    fn lock_requires_initialized() {
        let apo = ApoInstance::<Trace>::new();
        let f = Format::pcm_float32(48_000, 1);
        assert_eq!(
            apo.lock_for_process(&f, &f),
            Err(HResult::APOERR_NOT_LOCKED)
        );
        assert_eq!(apo.state(), State::Uninitialized);
    }

    #[test]
    fn lock_for_process_transitions_and_forwards_to_user() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let input = Format::pcm_float32(48_000, 1);
        let output = Format::pcm_float32(44_100, 2);
        apo.lock_for_process(&input, &output).unwrap();
        assert_eq!(apo.state(), State::Locked);
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.lock_seen.get(), Some((48_000, 44_100)));
    }

    #[test]
    fn double_lock_returns_apoerr_already_locked_and_keeps_first_formats() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let first = Format::pcm_float32(48_000, 1);
        let second = Format::pcm_float32(44_100, 2);
        apo.lock_for_process(&first, &first).unwrap();
        assert_eq!(
            apo.lock_for_process(&second, &second),
            Err(HResult::APOERR_ALREADY_LOCKED)
        );
        assert_eq!(apo.state(), State::Locked);
        let fmts = apo.locked_formats().unwrap();
        assert_eq!(fmts.input, first);
        assert_eq!(fmts.output, first);
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.lock_seen.get(), Some((48_000, 48_000)));
    }

    #[test]
    fn lock_failure_rolls_state_back_to_initialized() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        unsafe { &*apo.inner.get() }.lock_should_fail.set(true);
        let f = Format::pcm_float32(48_000, 1);
        assert_eq!(
            apo.lock_for_process(&f, &f),
            Err(HResult::APOERR_FORMAT_NOT_SUPPORTED)
        );
        assert_eq!(apo.state(), State::Initialized);
    }

    #[test]
    fn lock_failure_leaves_no_locked_formats() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        unsafe { &*apo.inner.get() }.lock_should_fail.set(true);
        let f = Format::pcm_float32(48_000, 1);
        assert!(apo.lock_for_process(&f, &f).is_err());
        assert_eq!(apo.locked_formats(), None);
    }

    #[test]
    fn unlock_for_process_returns_to_initialized() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let f = Format::pcm_float32(48_000, 1);
        apo.lock_for_process(&f, &f).unwrap();
        apo.unlock_for_process().unwrap();
        assert_eq!(apo.state(), State::Initialized);
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.unlock_seen.get(), 1);
    }

    #[test]
    fn unlock_without_lock_fails() {
        let apo = ApoInstance::<Trace>::new();
        assert_eq!(apo.unlock_for_process(), Err(HResult::APOERR_NOT_LOCKED));
    }

    #[test]
    fn double_unlock_fails_and_runs_user_hook_once() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let f = Format::pcm_float32(48_000, 1);
        apo.lock_for_process(&f, &f).unwrap();
        apo.unlock_for_process().unwrap();
        assert_eq!(apo.unlock_for_process(), Err(HResult::APOERR_NOT_LOCKED));
        assert_eq!(apo.state(), State::Initialized);
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.unlock_seen.get(), 1);
    }

    #[test]
    fn process_requires_locked_state() {
        let apo = ApoInstance::<Trace>::new();
        let samples = [0.0_f32; 4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let result = apo.process(
            &rt,
            ProcessInput::new(&samples, BufferFlags::VALID),
            &mut output,
        );
        assert_eq!(result, Err(HResult::APOERR_NOT_LOCKED));
    }

    #[test]
    fn process_after_unlock_fails_without_reaching_user() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let f = Format::pcm_float32(48_000, 1);
        apo.lock_for_process(&f, &f).unwrap();
        apo.unlock_for_process().unwrap();
        let samples = [0.0_f32; 4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let result = apo.process(
            &rt,
            ProcessInput::new(&samples, BufferFlags::VALID),
            &mut output,
        );
        assert_eq!(result, Err(HResult::APOERR_NOT_LOCKED));
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.process_seen.get(), 0);
    }

    #[test]
    fn process_after_lock_returns_user_flags_and_copies_samples() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let f = Format::pcm_float32(48_000, 1);
        apo.lock_for_process(&f, &f).unwrap();
        let samples = [0.1_f32, -0.2, 0.3, -0.4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let out = apo
            .process(
                &rt,
                ProcessInput::new(&samples, BufferFlags::SILENT),
                &mut output,
            )
            .unwrap();
        assert_eq!(out, BufferFlags::SILENT);
        assert_eq!(output, samples);
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.process_seen.get(), 1);
    }

    #[test]
    fn full_lifecycle_round_trip() {
        let apo = ApoInstance::<Trace>::new();
        apo.initialize().unwrap();
        let f = Format::pcm_float32(48_000, 1);
        apo.lock_for_process(&f, &f).unwrap();
        let samples = [0.5_f32; 4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        for _ in 0..3 {
            apo.process(
                &rt,
                ProcessInput::new(&samples, BufferFlags::VALID),
                &mut output,
            )
            .unwrap();
        }
        apo.unlock_for_process().unwrap();
        assert_eq!(apo.state(), State::Initialized);
        apo.lock_for_process(&f, &f).unwrap();
        apo.unlock_for_process().unwrap();
        assert_eq!(apo.state(), State::Initialized);
        let trace = unsafe { &*apo.inner.get() };
        assert_eq!(trace.process_seen.get(), 3);
        assert_eq!(trace.unlock_seen.get(), 2);
    }

    #[test]
    fn is_input_format_supported_uses_user_default() {
        let apo = ApoInstance::<Trace>::new();
        let f = Format::pcm_float32(48_000, 1);
        assert_eq!(
            apo.is_input_format_supported(&f),
            crate::format::FormatNegotiation::Accept
        );
        let f = Format::pcm_int16(48_000, 1);
        match apo.is_input_format_supported(&f) {
            crate::format::FormatNegotiation::Suggest(s) => {
                assert!(s.is_float());
                assert_eq!(s.bits_per_sample(), 32);
            }
            other => panic!("expected Suggest, got {other:?}"),
        }
    }

    #[test]
    fn locked_formats_cached_during_lock_for_process() {
        let apo = ApoInstance::<Trace>::new();
        assert_eq!(apo.locked_formats(), None);
        apo.initialize().unwrap();
        assert_eq!(apo.locked_formats(), None);
        let input = Format::pcm_float32(48_000, 1);
        let output = Format::pcm_float32(44_100, 2);
        apo.lock_for_process(&input, &output).unwrap();
        let fmts = apo.locked_formats().unwrap();
        assert_eq!(fmts.input, input);
        assert_eq!(fmts.output, output);
        apo.unlock_for_process().unwrap();
        assert_eq!(apo.locked_formats(), None);
    }

    #[test]
    fn type_erased_dispatch_drives_full_lifecycle() {
        use std::sync::Arc;
        let inst: Arc<dyn AnyApoInstance> = Arc::new(ApoInstance::<Trace>::new());
        assert_eq!(inst.state(), State::Uninitialized);
        assert_eq!(inst.refcount(), 0);
        assert_eq!(inst.add_ref(), 1);
        inst.initialize().unwrap();
        let f = Format::pcm_float32(48_000, 1);
        inst.lock_for_process(&f, &f).unwrap();
        assert_eq!(inst.state(), State::Locked);
        let samples = [0.25_f32, -0.5, 0.75, -1.0];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let out_flags = inst
            .process(
                &rt,
                ProcessInput::new(&samples, BufferFlags::VALID),
                &mut output,
            )
            .unwrap();
        assert_eq!(out_flags, BufferFlags::VALID);
        assert_eq!(output, samples);
        inst.unlock_for_process().unwrap();
        assert_eq!(inst.state(), State::Initialized);
        assert_eq!(inst.release(), 0);
    }
}