insta_fun/
input.rs

1use fundsp::hacker::AudioUnit;
2
3/// Input provided to the audio unit
4pub enum InputSource {
5    /// No input
6    None,
7    /// Input provided by a channel vec
8    ///
9    /// - First vec contains all **channels**
10    /// - Second vec contains **samples** per channel
11    VecByChannel(Vec<Vec<f32>>),
12    /// Input provided by a tick vec
13    ///
14    /// - First vec contains all **ticks**
15    /// - Second vec contains **samples** for all **channels** per tick
16    VecByTick(Vec<Vec<f32>>),
17    /// Input **repeated** on every tick
18    ///
19    /// - Vector contains **samples** for all **channels** for **one** tick
20    Flat(Vec<f32>),
21    /// Input provided by a generator function
22    ///
23    /// - First argument is the sample index
24    /// - Second argument is the channel index
25    Generator(Box<dyn Fn(usize, usize) -> f32>),
26    /// Input provided by an audio unit
27    ///
28    /// Number of outputs of the audio unit must match
29    /// the number of inputs to the test target
30    ///
31    /// * if you need to set sample rate on input unit do that upfront
32    Unit(Box<dyn AudioUnit>),
33}
34
35impl InputSource {
36    pub fn impulse() -> Self {
37        Self::Generator(Box::new(|i, _| if i == 0 { 1.0 } else { 0.0 }))
38    }
39    pub fn sine(freq: f32, sample_rate: f32) -> Self {
40        Self::Generator(Box::new(move |i, _| {
41            let phase = 2.0 * std::f32::consts::PI * freq * i as f32 / sample_rate;
42            phase.sin()
43        }))
44    }
45
46    pub fn into_data(self, num_inputs: usize, num_samples: usize) -> Vec<Vec<f32>> {
47        match self {
48            InputSource::None => vec![vec![0.0; num_samples]; num_inputs],
49            InputSource::VecByChannel(data) => {
50                assert_eq!(
51                    data.len(),
52                    num_inputs,
53                    "Input vec size mismatch. Expected {} channels, got {}",
54                    num_inputs,
55                    data.len()
56                );
57                assert!(
58                    data.iter().all(|v| v.len() == num_samples),
59                    "Input vec size mismatch. Expected {} samples per channel, got {}",
60                    num_samples,
61                    data.iter().map(|v| v.len()).max().unwrap_or(0)
62                );
63                data.to_vec()
64            }
65            InputSource::VecByTick(data) => {
66                assert!(
67                    data.iter().all(|v| v.len() == num_inputs),
68                    "Input vec size mismatch. Expected {} channels, got {}",
69                    num_inputs,
70                    data.iter().map(|v| v.len()).max().unwrap_or(0)
71                );
72                assert_eq!(
73                    data.len(),
74                    num_samples,
75                    "Input vec size mismatch. Expected {} samples, got {}",
76                    num_samples,
77                    data.len()
78                );
79                (0..num_inputs)
80                    .map(|ch| (0..num_samples).map(|i| data[i][ch]).collect())
81                    .collect()
82            }
83            InputSource::Flat(data) => {
84                assert_eq!(
85                    data.len(),
86                    num_inputs,
87                    "Input vec size mismatch. Expected {} channels, got {}",
88                    num_inputs,
89                    data.len()
90                );
91                (0..num_inputs)
92                    .map(|ch| (0..num_samples).map(|_| data[ch]).collect())
93                    .collect()
94            }
95            InputSource::Generator(generator_fn) => (0..num_inputs)
96                .map(|ch| (0..num_samples).map(|i| generator_fn(i, ch)).collect())
97                .collect(),
98            InputSource::Unit(mut unit) => {
99                let mut data = vec![vec![0.0; num_samples]; num_inputs];
100                for i in 0..num_samples {
101                    let mut outputs = vec![0.0; num_inputs];
102                    unit.tick(&[], &mut outputs);
103                    for slc in data.iter_mut() {
104                        slc[i] = outputs[i];
105                    }
106                }
107                data
108            }
109        }
110    }
111}