1use objc2_audio_toolbox::{
2 kAudioOutputUnitProperty_SetInputCallback, kAudioUnitProperty_SetRenderCallback,
3 kAudioUnitProperty_StreamFormat, AURenderCallbackStruct, AudioUnitRender,
4 AudioUnitRenderActionFlags,
5};
6use objc2_core_audio_types::{AudioBuffer, AudioBufferList, AudioTimeStamp};
7
8use super::audio_format::LinearPcmFlags;
9use super::{AudioUnit, Element, Scope};
10use crate::error::{self, Error};
11use crate::OSStatus;
12use std::mem;
13use std::os::raw::c_void;
14use std::ptr::NonNull;
15use std::slice;
16
17pub use self::action_flags::ActionFlags;
18pub use self::data::Data;
19
/// Signature of the closures that `input_proc` forwards to.
///
/// The arguments are those of the C callback registered as `inputProc`, minus the leading
/// `inRefCon` pointer (which `input_proc` uses to find the closure itself).
pub type InputProcFn = dyn FnMut(
    // io_action_flags: render action flags, readable and writable via `action_flags::Handle`.
    NonNull<AudioUnitRenderActionFlags>,
    // in_time_stamp: timestamp of this render cycle.
    NonNull<AudioTimeStamp>,
    // in_bus_number
    u32,
    // in_number_frames
    u32,
    // io_data: the buffer list for this cycle (unused by the input-callback closure).
    *mut AudioBufferList,
) -> OSStatus;
32
/// Heap-allocated holder for a user callback closure.
///
/// A raw pointer to this (sized) wrapper is handed to Core Audio as `inputProcRefCon`; because
/// the wrapper is sized, `Box<InputProcFnWrapper>` is a thin pointer that fits in a
/// `*mut c_void`, unlike the fat `Box<InputProcFn>` it contains. `input_proc` casts the pointer
/// back and invokes `callback`.
pub struct InputProcFnWrapper {
    callback: Box<InputProcFn>,
}
37
/// Arguments passed to a user render or input callback for one render cycle.
#[derive(Debug)]
pub struct Args<D> {
    /// The cycle's audio buffer(s), viewed as `D` (see the `data` module).
    pub data: D,
    /// Copy of the timestamp Core Audio passed to the callback.
    pub time_stamp: AudioTimeStamp,
    /// The bus number Core Audio passed to the callback.
    pub bus_number: u32,
    /// Number of frames to process in this cycle.
    pub num_frames: usize,
    /// Handle to Core Audio's render action flags; `insert`/`remove`/`toggle` write through
    /// to the flags Core Audio passed in.
    pub flags: action_flags::Handle,
}
57
58pub mod data {
60 use objc2_core_audio_types::AudioBuffer;
61 use objc2_core_audio_types::AudioBufferList;
62
63 use super::super::Sample;
64 use super::super::StreamFormat;
65 use crate::audio_unit::audio_format::LinearPcmFlags;
66 use std::marker::PhantomData;
67 use std::slice;
68
69 pub trait Data {
71 fn does_stream_format_match(stream_format: &StreamFormat) -> bool;
73 unsafe fn from_input_proc_args(num_frames: u32, io_data: *mut AudioBufferList) -> Self;
77 }
78
79 #[derive(Debug)]
81 pub struct Raw {
82 pub data: *mut AudioBufferList,
83 }
84
85 impl Data for Raw {
86 fn does_stream_format_match(_: &StreamFormat) -> bool {
87 true
88 }
89 unsafe fn from_input_proc_args(_num_frames: u32, io_data: *mut AudioBufferList) -> Self {
90 Raw { data: io_data }
91 }
92 }
93
94 pub struct Interleaved<S: 'static> {
96 pub buffer: &'static mut [S],
98 pub channels: usize,
99 sample_format: PhantomData<S>,
100 }
101
102 pub struct InterleavedBytes<S: 'static> {
104 pub buffer: &'static mut [u8],
106 pub channels: usize,
107 sample_format: PhantomData<S>,
108 }
109
110 pub struct NonInterleaved<S> {
112 buffers: &'static mut [AudioBuffer],
114 frames: usize,
116 sample_format: PhantomData<S>,
117 }
118
119 pub struct Channels<'a, S: 'a> {
121 buffers: slice::Iter<'a, AudioBuffer>,
122 frames: usize,
123 sample_format: PhantomData<S>,
124 }
125
126 pub struct ChannelsMut<'a, S: 'a> {
128 buffers: slice::IterMut<'a, AudioBuffer>,
129 frames: usize,
130 sample_format: PhantomData<S>,
131 }
132
133 unsafe impl<S> Send for NonInterleaved<S> where S: Send {}
134
135 impl<'a, S> Iterator for Channels<'a, S> {
136 type Item = &'a [S];
137 #[allow(non_snake_case)]
138 fn next(&mut self) -> Option<Self::Item> {
139 self.buffers.next().map(
140 |&AudioBuffer {
141 mNumberChannels,
142 mData,
143 ..
144 }| {
145 let len = mNumberChannels as usize * self.frames;
146 let ptr = mData as *mut S;
147 unsafe { slice::from_raw_parts(ptr, len) }
148 },
149 )
150 }
151 }
152
153 impl<'a, S> Iterator for ChannelsMut<'a, S> {
154 type Item = &'a mut [S];
155 #[allow(non_snake_case)]
156 fn next(&mut self) -> Option<Self::Item> {
157 self.buffers.next().map(
158 |&mut AudioBuffer {
159 mNumberChannels,
160 mData,
161 ..
162 }| {
163 let len = mNumberChannels as usize * self.frames;
164 let ptr = mData as *mut S;
165 unsafe { slice::from_raw_parts_mut(ptr, len) }
166 },
167 )
168 }
169 }
170
171 impl<S> NonInterleaved<S> {
172 pub fn channels(&self) -> Channels<'_, S> {
174 Channels {
175 buffers: self.buffers.iter(),
176 frames: self.frames,
177 sample_format: PhantomData,
178 }
179 }
180
181 pub fn channels_mut(&mut self) -> ChannelsMut<'_, S> {
183 ChannelsMut {
184 buffers: self.buffers.iter_mut(),
185 frames: self.frames,
186 sample_format: PhantomData,
187 }
188 }
189 }
190
191 impl<S> Data for NonInterleaved<S>
193 where
194 S: Sample,
195 {
196 fn does_stream_format_match(stream_format: &StreamFormat) -> bool {
197 stream_format
198 .flags
199 .contains(LinearPcmFlags::IS_NON_INTERLEAVED)
200 && S::sample_format().does_match_flags(stream_format.flags)
201 }
202
203 #[allow(non_snake_case)]
204 unsafe fn from_input_proc_args(frames: u32, io_data: *mut AudioBufferList) -> Self {
205 let ptr = (*io_data).mBuffers.as_ptr() as *mut AudioBuffer;
206 let len = (*io_data).mNumberBuffers as usize;
207 let buffers = slice::from_raw_parts_mut(ptr, len);
208 NonInterleaved {
209 buffers,
210 frames: frames as usize,
211 sample_format: PhantomData,
212 }
213 }
214 }
215
216 impl<S> Data for Interleaved<S>
218 where
219 S: Sample,
220 {
221 fn does_stream_format_match(stream_format: &StreamFormat) -> bool {
222 !stream_format
223 .flags
224 .contains(LinearPcmFlags::IS_NON_INTERLEAVED)
225 && S::sample_format().does_match_flags(stream_format.flags)
226 }
227
228 #[allow(non_snake_case)]
229 unsafe fn from_input_proc_args(frames: u32, io_data: *mut AudioBufferList) -> Self {
230 let AudioBuffer {
232 mNumberChannels,
233 mDataByteSize,
234 mData,
235 } = (*io_data).mBuffers[0];
236 let buffer_len = frames as usize * mNumberChannels as usize;
241 let expected_size = ::std::mem::size_of::<S>() * buffer_len;
242 assert!(mDataByteSize as usize == expected_size);
243
244 let buffer: &mut [S] = {
245 let buffer_ptr = mData as *mut S;
246 slice::from_raw_parts_mut(buffer_ptr, buffer_len)
247 };
248
249 Interleaved {
250 buffer,
251 channels: mNumberChannels as usize,
252 sample_format: PhantomData,
253 }
254 }
255 }
256
257 impl<S> Data for InterleavedBytes<S>
259 where
260 S: Sample,
261 {
262 fn does_stream_format_match(stream_format: &StreamFormat) -> bool {
263 !stream_format
264 .flags
265 .contains(LinearPcmFlags::IS_NON_INTERLEAVED)
266 && S::sample_format().does_match_flags(stream_format.flags)
267 }
268
269 #[allow(non_snake_case)]
270 unsafe fn from_input_proc_args(frames: u32, io_data: *mut AudioBufferList) -> Self {
271 let AudioBuffer {
273 mNumberChannels,
274 mDataByteSize,
275 mData,
276 } = (*io_data).mBuffers[0];
277 let buffer_len = frames as usize * mNumberChannels as usize;
282 let expected_size = ::std::mem::size_of::<S>() * buffer_len;
283 assert!(mDataByteSize as usize == expected_size);
284
285 let buffer: &mut [u8] = {
286 let buffer_ptr = mData as *mut u8;
287 slice::from_raw_parts_mut(buffer_ptr, mDataByteSize as usize)
288 };
289
290 InterleavedBytes {
291 buffer,
292 channels: mNumberChannels as usize,
293 sample_format: PhantomData,
294 }
295 }
296 }
297}
298
299pub mod action_flags {
300 use objc2_audio_toolbox::AudioUnitRenderActionFlags;
301
302 use std::fmt;
303
304 bitflags! {
305 pub struct ActionFlags: u32 {
306 const PRE_RENDER = AudioUnitRenderActionFlags::UnitRenderAction_PreRender.0;
312 const POST_RENDER = AudioUnitRenderActionFlags::UnitRenderAction_PostRender.0;
318 const OUTPUT_IS_SILENCE = AudioUnitRenderActionFlags::UnitRenderAction_OutputIsSilence.0;
325 const OFFLINE_PREFLIGHT = AudioUnitRenderActionFlags::OfflineUnitRenderAction_Preflight.0;
334 const OFFLINE_RENDER = AudioUnitRenderActionFlags::OfflineUnitRenderAction_Render.0;
340 const OFFLINE_COMPLETE = AudioUnitRenderActionFlags::OfflineUnitRenderAction_Complete.0;
345 const POST_RENDER_ERROR = AudioUnitRenderActionFlags::UnitRenderAction_PostRenderError.0;
352 const DO_NOT_CHECK_RENDER_ARGS = AudioUnitRenderActionFlags::UnitRenderAction_DoNotCheckRenderArgs.0;
359 }
360 }
361
362 pub struct Handle {
370 ptr: *mut AudioUnitRenderActionFlags,
371 }
372
373 impl fmt::Debug for Handle {
374 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
375 if self.ptr.is_null() {
376 write!(f, "{:?}", self.ptr)
377 } else {
378 unsafe { write!(f, "{:?}", *self.ptr) }
379 }
380 }
381 }
382
383 impl Handle {
384 pub fn get(&self) -> ActionFlags {
386 ActionFlags::from_bits_truncate(unsafe { *self.ptr }.0)
387 }
388
389 fn set(&mut self, flags: ActionFlags) {
390 unsafe { (*self.ptr).0 = flags.bits() }
391 }
392
393 pub fn bits(&self) -> u32 {
395 self.get().bits()
396 }
397
398 pub fn is_empty(&self) -> bool {
400 self.get().is_empty()
401 }
402
403 pub fn is_all(&self) -> bool {
405 self.get().is_all()
406 }
407
408 pub fn intersects(&self, other: ActionFlags) -> bool {
410 self.get().intersects(other)
411 }
412
413 pub fn contains(&self, other: ActionFlags) -> bool {
415 self.get().contains(other)
416 }
417
418 pub fn insert(&mut self, other: ActionFlags) {
420 let mut flags = self.get();
421 flags.insert(other);
422 self.set(flags);
423 }
424
425 pub fn remove(&mut self, other: ActionFlags) {
427 let mut flags = self.get();
428 flags.remove(other);
429 self.set(flags);
430 }
431
432 pub fn toggle(&mut self, other: ActionFlags) {
434 let mut flags = self.get();
435 flags.toggle(other);
436 self.set(flags);
437 }
438
439 pub fn from_ptr(ptr: *mut AudioUnitRenderActionFlags) -> Self {
441 Handle { ptr }
442 }
443 }
444
445 unsafe impl Send for Handle {}
446
447 impl ::std::fmt::Display for ActionFlags {
448 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
449 write!(
450 f,
451 "{:?}",
452 match AudioUnitRenderActionFlags(self.bits()) {
453 AudioUnitRenderActionFlags::UnitRenderAction_PreRender => "PRE_RENDER",
454 AudioUnitRenderActionFlags::UnitRenderAction_PostRender => "POST_RENDER",
455 AudioUnitRenderActionFlags::UnitRenderAction_OutputIsSilence =>
456 "OUTPUT_IS_SILENCE",
457 AudioUnitRenderActionFlags::OfflineUnitRenderAction_Preflight =>
458 "OFFLINE_PREFLIGHT",
459 AudioUnitRenderActionFlags::OfflineUnitRenderAction_Render => "OFFLINE_RENDER",
460 AudioUnitRenderActionFlags::OfflineUnitRenderAction_Complete =>
461 "OFFLINE_COMPLETE",
462 AudioUnitRenderActionFlags::UnitRenderAction_PostRenderError =>
463 "POST_RENDER_ERROR",
464 AudioUnitRenderActionFlags::UnitRenderAction_DoNotCheckRenderArgs =>
465 "DO_NOT_CHECK_RENDER_ARGS",
466 _ => "<Unknown ActionFlags>",
467 }
468 )
469 }
470 }
471}
472
473impl AudioUnit {
474 pub fn set_render_callback<F, D>(&mut self, mut f: F) -> Result<(), Error>
476 where
477 F: FnMut(Args<D>) -> Result<(), ()> + 'static,
478 D: Data,
479 {
480 let stream_format = self.output_stream_format()?;
483
484 if !D::does_stream_format_match(&stream_format) {
486 return Err(Error::RenderCallbackBufferFormatDoesNotMatchAudioUnitStreamFormat);
487 }
488
489 let input_proc_fn = move |io_action_flags: NonNull<AudioUnitRenderActionFlags>,
495 in_time_stamp: NonNull<AudioTimeStamp>,
496 in_bus_number: u32,
497 in_number_frames: u32,
498 io_data: *mut AudioBufferList|
499 -> OSStatus {
500 let args = unsafe {
501 let data = D::from_input_proc_args(in_number_frames, io_data);
502 let flags = action_flags::Handle::from_ptr(io_action_flags.as_ptr());
503 Args {
504 data,
505 time_stamp: in_time_stamp.read(),
506 flags,
507 bus_number: in_bus_number,
508 num_frames: in_number_frames as usize,
509 }
510 };
511
512 match f(args) {
513 Ok(()) => 0,
514 Err(()) => error::Error::Unspecified.as_os_status(),
515 }
516 };
517
518 let input_proc_fn_wrapper = Box::new(InputProcFnWrapper {
519 callback: Box::new(input_proc_fn),
520 });
521
522 let input_proc_fn_wrapper_ptr = Box::into_raw(input_proc_fn_wrapper) as *mut c_void;
527
528 let render_callback = AURenderCallbackStruct {
529 inputProc: Some(input_proc),
530 inputProcRefCon: input_proc_fn_wrapper_ptr,
531 };
532
533 self.set_property(
534 kAudioUnitProperty_SetRenderCallback,
535 Scope::Input,
536 Element::Output,
537 Some(&render_callback),
538 )?;
539
540 self.free_render_callback();
541 self.maybe_render_callback = Some(input_proc_fn_wrapper_ptr as *mut InputProcFnWrapper);
542 Ok(())
543 }
544
545 pub fn set_input_callback<F, D>(&mut self, mut f: F) -> Result<(), Error>
547 where
548 F: FnMut(Args<D>) -> Result<(), ()> + 'static,
549 D: Data,
550 {
551 let stream_format = self.input_stream_format()?;
554
555 if !D::does_stream_format_match(&stream_format) {
557 return Err(Error::RenderCallbackBufferFormatDoesNotMatchAudioUnitStreamFormat);
558 }
559
560 let non_interleaved = stream_format
562 .flags
563 .contains(LinearPcmFlags::IS_NON_INTERLEAVED);
564
565 #[cfg(target_os = "macos")]
569 let mut buffer_frame_size: u32 = {
570 let id = objc2_core_audio::kAudioDevicePropertyBufferFrameSize;
571 let buffer_frame_size: u32 = self.get_property(id, Scope::Global, Element::Output)?;
572 buffer_frame_size
573 };
574 #[cfg(any(target_os = "ios", target_os = "tvos", target_os = "visionos"))]
575 let mut buffer_frame_size: u32 = {
576 let id = objc2_audio_toolbox::kAudioSessionProperty_CurrentHardwareIOBufferDuration;
577 let seconds: f32 = super::audio_session_get_property(id)?;
578 let id = objc2_audio_toolbox::kAudioSessionProperty_CurrentHardwareSampleRate;
579 let sample_rate: f64 = super::audio_session_get_property(id)?;
580 (sample_rate * seconds as f64).round() as u32
581 };
582 let sample_bytes = stream_format.sample_format.size_in_bytes();
583 let n_channels = stream_format.channels;
584 if non_interleaved && n_channels > 1 {
585 return Err(Error::NonInterleavedInputOnlySupportsMono);
586 }
587
588 let data_byte_size = buffer_frame_size * sample_bytes as u32 * n_channels;
589 let mut data = vec![0u8; data_byte_size as usize];
590 let mut buffer_capacity = data_byte_size as usize;
591 let audio_buffer = AudioBuffer {
592 mDataByteSize: data_byte_size,
593 mNumberChannels: n_channels,
594 mData: data.as_mut_ptr() as *mut _,
595 };
596 mem::forget(data);
599
600 let audio_buffer_list = Box::new(AudioBufferList {
601 mNumberBuffers: 1,
602 mBuffers: [audio_buffer],
603 });
604
605 let audio_buffer_list_ptr = Box::into_raw(audio_buffer_list);
608
609 let audio_unit = self.instance;
615 let input_proc_fn = move |io_action_flags: NonNull<AudioUnitRenderActionFlags>,
616 in_time_stamp: NonNull<AudioTimeStamp>,
617 in_bus_number: u32,
618 in_number_frames: u32,
619 _io_data: *mut AudioBufferList|
620 -> OSStatus {
621 if buffer_frame_size != in_number_frames {
623 unsafe {
624 let id = kAudioUnitProperty_StreamFormat;
626 let asbd =
627 match super::get_property(audio_unit, id, Scope::Output, Element::Input) {
628 Err(err) => return err.as_os_status(),
629 Ok(asbd) => asbd,
630 };
631 let stream_format = match super::StreamFormat::from_asbd(asbd) {
632 Err(err) => return err.as_os_status(),
633 Ok(fmt) => fmt,
634 };
635 let sample_bytes = stream_format.sample_format.size_in_bytes();
636 let n_channels = stream_format.channels;
637 let data_byte_size =
638 in_number_frames as usize * sample_bytes * n_channels as usize;
639 let ptr = (*audio_buffer_list_ptr).mBuffers.as_ptr() as *mut AudioBuffer;
640 let len = (*audio_buffer_list_ptr).mNumberBuffers as usize;
641
642 let buffers: &mut [AudioBuffer] = slice::from_raw_parts_mut(ptr, len);
643 let old_capacity = buffer_capacity;
644 for buffer in buffers {
645 let current_len = buffer.mDataByteSize as usize;
646 let audio_buffer_ptr = buffer.mData as *mut u8;
647 let mut vec: Vec<u8> =
648 Vec::from_raw_parts(audio_buffer_ptr, current_len, old_capacity);
649 vec.resize(data_byte_size, 0u8);
650
651 buffer_capacity = vec.capacity();
652 buffer.mData = vec.as_mut_ptr() as *mut _;
653 buffer.mDataByteSize = data_byte_size as u32;
654 mem::forget(vec);
655 }
656 }
657 buffer_frame_size = in_number_frames;
658 }
659
660 unsafe {
661 let status = AudioUnitRender(
662 audio_unit,
663 io_action_flags.as_ptr(),
664 in_time_stamp,
665 in_bus_number,
666 in_number_frames,
667 NonNull::new(audio_buffer_list_ptr).unwrap(),
668 );
669 if status != 0 {
670 return status;
671 }
672 }
673
674 let args = unsafe {
675 let data = D::from_input_proc_args(in_number_frames, audio_buffer_list_ptr);
676 let flags = action_flags::Handle::from_ptr(io_action_flags.as_ptr());
677 Args {
678 data,
679 time_stamp: in_time_stamp.read(),
680 flags,
681 bus_number: in_bus_number,
682 num_frames: in_number_frames as usize,
683 }
684 };
685
686 match f(args) {
687 Ok(()) => 0,
688 Err(()) => error::Error::Unspecified.as_os_status(),
689 }
690 };
691
692 let input_proc_fn_wrapper = Box::new(InputProcFnWrapper {
693 callback: Box::new(input_proc_fn),
694 });
695
696 let input_proc_fn_wrapper_ptr = Box::into_raw(input_proc_fn_wrapper) as *mut c_void;
701
702 let render_callback = AURenderCallbackStruct {
703 inputProc: Some(input_proc),
704 inputProcRefCon: input_proc_fn_wrapper_ptr,
705 };
706
707 self.set_property(
708 kAudioOutputUnitProperty_SetInputCallback,
709 Scope::Global,
710 Element::Output,
711 Some(&render_callback),
712 )?;
713
714 let input_callback = super::InputCallback {
715 buffer_list: audio_buffer_list_ptr,
716 callback: input_proc_fn_wrapper_ptr as *mut InputProcFnWrapper,
717 };
718 self.free_input_callback();
719 self.maybe_input_callback = Some(input_callback);
720 Ok(())
721 }
722
723 pub fn free_render_callback(&mut self) -> Option<Box<InputProcFnWrapper>> {
726 if let Some(callback) = self.maybe_render_callback.take() {
727 let callback: Box<InputProcFnWrapper> = unsafe { Box::from_raw(callback) };
730 return Some(callback);
731 }
732 None
733 }
734
735 pub fn free_input_callback(&mut self) -> Option<Box<InputProcFnWrapper>> {
738 if let Some(input_callback) = self.maybe_input_callback.take() {
739 let super::InputCallback {
740 buffer_list,
741 callback,
742 } = input_callback;
743 unsafe {
744 let buffer_list: Box<AudioBufferList> = Box::from_raw(buffer_list);
746 let ptr = buffer_list.mBuffers.as_ptr();
748 let len = buffer_list.mNumberBuffers as usize;
749 let buffers: &[AudioBuffer] = slice::from_raw_parts(ptr, len);
750 for &buffer in buffers {
751 let ptr = buffer.mData as *mut u8;
752 let len = buffer.mDataByteSize as usize;
753 let cap = len;
754 let _ = Vec::from_raw_parts(ptr, len, cap);
755 }
756 let callback: Box<InputProcFnWrapper> = Box::from_raw(callback);
758 return Some(callback);
759 }
760 }
761 None
762 }
763}
764
765extern "C-unwind" fn input_proc(
767 in_ref_con: NonNull<c_void>,
768 io_action_flags: NonNull<AudioUnitRenderActionFlags>,
769 in_time_stamp: NonNull<AudioTimeStamp>,
770 in_bus_number: u32,
771 in_number_frames: u32,
772 io_data: *mut AudioBufferList,
773) -> OSStatus {
774 let wrapper = unsafe { in_ref_con.cast::<InputProcFnWrapper>().as_mut() };
775 (wrapper.callback)(
776 io_action_flags,
777 in_time_stamp,
778 in_bus_number,
779 in_number_frames,
780 io_data,
781 )
782}