#![allow(missing_docs)]
extern crate alloc;
use windows::Win32::Media::Audio::Apo::{
IAudioMediaType, IAudioMediaType_Impl, UNCOMPRESSEDAUDIOFORMAT,
};
use windows::Win32::Media::Audio::{WAVEFORMATEX, WAVEFORMATEXTENSIBLE};
use windows_core::{implement, ComObject, Ref, BOOL, HRESULT};
use crate::error::HResult;
use crate::format::{Format, FormatNegotiation};
use crate::instance::AnyApoInstance;
/// COM object that implements `IAudioMediaType` for one fixed wave format.
///
/// The format is held inline as a `WAVEFORMATEXTENSIBLE`. `GetAudioFormat` hands out a pointer
/// into that storage, so the pointer is only meaningful while this object is alive.
#[implement(IAudioMediaType)]
pub struct FormatMediaType {
/// Backing format block. For non-extensible formats only the leading `Format` header is
/// populated; the remaining bytes are zero.
wfx: WAVEFORMATEXTENSIBLE,
}
impl FormatMediaType {
    /// Captures `format` as an owned wave-format block.
    ///
    /// Extensible formats are stored as the full `WAVEFORMATEXTENSIBLE` produced by
    /// `Format::to_waveformatextensible`. Every other format fills only the leading
    /// `WAVEFORMATEX` header (from `Format::to_waveformatex`). Its extension fields are left
    /// zeroed, so the object never holds uninitialised bytes.
    #[must_use]
    pub fn new(format: &Format) -> Self {
        let wfx = if format.is_extensible() {
            format.to_waveformatextensible()
        } else {
            // SAFETY: WAVEFORMATEXTENSIBLE is a plain-old-data C struct (integer fields, an
            // integer union and a GUID), so the all-zero bit pattern is a valid value.
            let mut header_only: WAVEFORMATEXTENSIBLE = unsafe { core::mem::zeroed() };
            header_only.Format = format.to_waveformatex();
            header_only
        };
        Self { wfx }
    }
}
impl IAudioMediaType_Impl for FormatMediaType_Impl {
fn IsCompressedFormat(&self) -> windows_core::Result<BOOL> {
Ok(BOOL::from(false))
}
fn IsEqual(&self, _piaudiotype: Ref<IAudioMediaType>) -> windows_core::Result<u32> {
Err(windows_core::Error::new(
HRESULT::from(HResult::E_NOTIMPL),
"FormatMediaType::IsEqual is not part of the bridge surface",
))
}
fn GetAudioFormat(&self) -> *mut WAVEFORMATEX {
core::ptr::addr_of!(self.wfx.Format) as *mut WAVEFORMATEX
}
fn GetUncompressedAudioFormat(
&self,
_puncompressedaudioformat: *mut UNCOMPRESSEDAUDIOFORMAT,
) -> windows_core::Result<()> {
Err(windows_core::Error::new(
HRESULT::from(HResult::E_NOTIMPL),
"FormatMediaType::GetUncompressedAudioFormat is not part of the bridge surface",
))
}
}
#[must_use]
pub fn media_type_from_format(format: &Format) -> IAudioMediaType {
ComObject::new(FormatMediaType::new(format)).into_interface()
}
pub fn format_from_media_type(media: Ref<'_, IAudioMediaType>) -> windows_core::Result<Format> {
let Some(mt) = media.as_ref() else {
return Err(windows_core::Error::new(
HRESULT::from(HResult::E_POINTER),
"IAudioMediaType reference was null",
));
};
let wf_ptr = unsafe { mt.GetAudioFormat() };
if wf_ptr.is_null() {
return Err(windows_core::Error::new(
HRESULT::from(HResult::APOERR_FORMAT_NOT_SUPPORTED),
"IAudioMediaType::GetAudioFormat returned null",
));
}
Ok(unsafe { Format::from_waveformatex_ptr(wf_ptr) })
}
/// Which side of the processing object a format request is negotiated for.
///
/// Selects whether [`negotiate_format`] consults the instance's input-format or output-format
/// support query. Derives equality and hashing so callers can compare directions or use them as
/// map keys.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum NegotiationDirection {
    /// Negotiate the format of data fed into the object (`is_input_format_supported`).
    Input,
    /// Negotiate the format of data produced by the object (`is_output_format_supported`).
    Output,
}
pub fn negotiate_format(
instance: &dyn AnyApoInstance,
requested: Ref<'_, IAudioMediaType>,
direction: NegotiationDirection,
) -> windows_core::Result<IAudioMediaType> {
let requested_format = format_from_media_type(requested)?;
let decision = match direction {
NegotiationDirection::Input => instance.is_input_format_supported(&requested_format),
NegotiationDirection::Output => instance.is_output_format_supported(&requested_format),
};
match decision {
FormatNegotiation::Accept => Ok(media_type_from_format(&requested_format)),
FormatNegotiation::Suggest(alt) => Ok(media_type_from_format(&alt)),
FormatNegotiation::Reject => Err(windows_core::Error::new(
HRESULT::from(HResult::APOERR_FORMAT_NOT_SUPPORTED),
"ProcessingObject rejected the requested format with no alternative",
)),
}
}
// Unit tests for the media-type bridge. They cover `FormatMediaType` round-trips, null handling
// in `format_from_media_type`, and each `FormatNegotiation` outcome of `negotiate_format`.
#[cfg(test)]
mod tests {
use super::*;
use crate::apo::{ApoCategory, ProcessInput, ProcessingObject};
use crate::buffer::BufferFlags;
use crate::clsid::Clsid;
use crate::format::WAVE_FORMAT_IEEE_FLOAT;
use crate::instance::ApoInstance;
use crate::realtime::RealtimeContext;
use alloc::sync::Arc;
// Stub processing object that overrides only `process`, so format negotiation uses the
// `ProcessingObject` trait defaults. The tests below rely on those defaults accepting float32
// and suggesting a float32 alternative for int16. NOTE(review): confirm against the trait.
struct AcceptFloat32;
impl ProcessingObject for AcceptFloat32 {
const CLSID: Clsid = Clsid::from_u128(0x11111111_2222_3333_4444_555555555555);
const NAME: &'static str = "accept-float32";
const COPYRIGHT: &'static str = "test";
const CATEGORY: ApoCategory = ApoCategory::Sfx;
fn new() -> Self {
Self
}
// Pass-through: copies the input samples into `output` (panics if the lengths differ) and
// forwards the input buffer's flags unchanged.
fn process(
&mut self,
_rt: &RealtimeContext,
input: ProcessInput<'_>,
output: &mut [f32],
) -> BufferFlags {
output.copy_from_slice(input.samples());
input.flags()
}
}
// Stub processing object that rejects every format on both the input and the output side. It is
// used to exercise the `FormatNegotiation::Reject` path.
struct RejectEverything;
impl ProcessingObject for RejectEverything {
const CLSID: Clsid = Clsid::from_u128(0x66666666_7777_8888_9999_AAAAAAAAAAAA);
const NAME: &'static str = "reject-everything";
const COPYRIGHT: &'static str = "test";
const CATEGORY: ApoCategory = ApoCategory::Sfx;
fn new() -> Self {
Self
}
fn is_input_format_supported(&self, _format: &Format) -> FormatNegotiation {
FormatNegotiation::Reject
}
fn is_output_format_supported(&self, _format: &Format) -> FormatNegotiation {
FormatNegotiation::Reject
}
fn process(
&mut self,
_rt: &RealtimeContext,
_input: ProcessInput<'_>,
output: &mut [f32],
) -> BufferFlags {
// Emits silence and flags the buffer as silent; `process` is not exercised by these tests.
output.fill(0.0);
BufferFlags::SILENT
}
}
// Reads the format that a media type exposes through `GetAudioFormat`. Asserts that the pointer
// is non-null before dereferencing it.
fn read_format(media: &IAudioMediaType) -> Format {
unsafe {
let wf = media.GetAudioFormat();
assert!(!wf.is_null());
Format::from_waveformatex(&*wf)
}
}
// A float32 format comes back unchanged from `GetAudioFormat`.
#[test]
fn format_media_type_round_trips_via_get_audio_format() {
let f = Format::pcm_float32(48_000, 1);
let media = media_type_from_format(&f);
let echoed = read_format(&media);
assert_eq!(echoed, f);
}
// `IsCompressedFormat` reports false for a wrapped int16 format.
#[test]
fn format_media_type_reports_uncompressed() {
let f = Format::pcm_int16(44_100, 2);
let media = media_type_from_format(&f);
let compressed = unsafe { media.IsCompressedFormat() }.unwrap();
assert!(!compressed.as_bool());
}
// `format_from_media_type` parses back exactly the format the media type was built from.
#[test]
fn format_from_media_type_reads_back_the_requested_format() {
let requested = media_type_from_format(&Format::pcm_int16(48_000, 1));
let r = Ref::from(&requested);
let parsed = format_from_media_type(r).unwrap();
assert_eq!(parsed, Format::pcm_int16(48_000, 1));
}
// A null (default) `Ref` is reported as `E_POINTER`.
#[test]
fn format_from_media_type_rejects_null_reference() {
let r: Ref<'_, IAudioMediaType> = Ref::default();
let err = format_from_media_type(r).unwrap_err();
assert_eq!(err.code(), HRESULT::from(HResult::E_POINTER));
}
// Accept path: a float32 input request is echoed back unchanged.
#[test]
fn negotiate_format_accept_echoes_requested_for_float32() {
let inst: Arc<dyn AnyApoInstance> = Arc::new(ApoInstance::<AcceptFloat32>::new());
let requested = media_type_from_format(&Format::pcm_float32(48_000, 1));
let r = Ref::from(&requested);
let answer = negotiate_format(inst.as_ref(), r, NegotiationDirection::Input).unwrap();
assert_eq!(read_format(&answer), Format::pcm_float32(48_000, 1));
}
// Suggest path: an int16 request yields a 32-bit IEEE-float alternative that keeps the
// sample rate and the channel count.
#[test]
fn negotiate_format_suggest_returns_float32_alternative_for_int16() {
let inst: Arc<dyn AnyApoInstance> = Arc::new(ApoInstance::<AcceptFloat32>::new());
let requested = media_type_from_format(&Format::pcm_int16(48_000, 1));
let r = Ref::from(&requested);
let answer = negotiate_format(inst.as_ref(), r, NegotiationDirection::Input).unwrap();
let suggested = read_format(&answer);
assert_eq!(suggested.format_tag(), WAVE_FORMAT_IEEE_FLOAT);
assert_eq!(suggested.bits_per_sample(), 32);
assert_eq!(suggested.sample_rate(), 48_000);
assert_eq!(suggested.channels(), 1);
}
// `NegotiationDirection::Output` also produces an accepted echo for a float32 request.
#[test]
fn negotiate_format_output_direction_routes_through_is_output() {
let inst: Arc<dyn AnyApoInstance> = Arc::new(ApoInstance::<AcceptFloat32>::new());
let requested = media_type_from_format(&Format::pcm_float32(44_100, 2));
let r = Ref::from(&requested);
let answer = negotiate_format(inst.as_ref(), r, NegotiationDirection::Output).unwrap();
assert_eq!(read_format(&answer), Format::pcm_float32(44_100, 2));
}
// Reject path: the error carries `APOERR_FORMAT_NOT_SUPPORTED`.
#[test]
fn negotiate_format_reject_surfaces_apoerr_format_not_supported() {
let inst: Arc<dyn AnyApoInstance> = Arc::new(ApoInstance::<RejectEverything>::new());
let requested = media_type_from_format(&Format::pcm_float32(48_000, 1));
let r = Ref::from(&requested);
let err = negotiate_format(inst.as_ref(), r, NegotiationDirection::Input).unwrap_err();
assert_eq!(
err.code(),
HRESULT::from(HResult::APOERR_FORMAT_NOT_SUPPORTED)
);
}
// A null requested media type surfaces as `E_POINTER` from `format_from_media_type`.
#[test]
fn negotiate_format_propagates_null_requested_as_e_pointer() {
let inst: Arc<dyn AnyApoInstance> = Arc::new(ApoInstance::<AcceptFloat32>::new());
let r: Ref<'_, IAudioMediaType> = Ref::default();
let err = negotiate_format(inst.as_ref(), r, NegotiationDirection::Input).unwrap_err();
assert_eq!(err.code(), HRESULT::from(HResult::E_POINTER));
}
}