1use crate::{ComputationStatus, StatefulUnaryKernel, TapsAccessor};
2
3extern crate alloc;
4use alloc::vec::Vec;
5
6pub struct IirKernel<InputType, OutputType, TapsType: TapsAccessor> {
28 a_taps: TapsType,
29 b_taps: TapsType,
30 memory: Vec<InputType>,
31 _input_type: core::marker::PhantomData<InputType>,
32 _output_type: core::marker::PhantomData<OutputType>,
33}
34
35impl<InputType, OutputType, TapType, TapsType: TapsAccessor<TapType = TapType>>
36 IirKernel<InputType, OutputType, TapsType>
37{
38 pub fn new(a_taps: TapsType, b_taps: TapsType) -> Self {
39 Self {
40 a_taps,
41 b_taps,
42 memory: Vec::new(),
43 _input_type: core::marker::PhantomData,
44 _output_type: core::marker::PhantomData,
45 }
46 }
47}
48
49impl<TapsType: TapsAccessor<TapType = f32>> StatefulUnaryKernel<f32, f32>
50 for IirKernel<f32, f32, TapsType>
51{
52 fn work(&mut self, i: &[f32], o: &mut [f32]) -> (usize, usize, ComputationStatus) {
53 if i.is_empty() {
54 return (
55 0,
56 0,
57 if o.is_empty() {
58 ComputationStatus::BothSufficient
59 } else {
60 ComputationStatus::InsufficientInput
61 },
62 );
63 }
64
65 let mut num_filled = 0;
67 while self.memory.len() < self.a_taps.num_taps() {
68 if i.len() <= self.memory.len() {
69 return (
70 0,
71 0,
72 if o.is_empty() {
73 ComputationStatus::BothSufficient
74 } else {
75 ComputationStatus::InsufficientInput
76 },
77 );
78 }
79 self.memory.push(i[self.memory.len()]);
80 num_filled += 1;
81 }
82 if num_filled == i.len() {
83 return (
84 0,
85 0,
86 if o.is_empty() {
87 ComputationStatus::BothSufficient
88 } else {
89 ComputationStatus::InsufficientInput
90 },
91 );
92 }
93
94 assert_eq!(self.a_taps.num_taps(), self.memory.len());
95 assert!(self.b_taps.num_taps() > 0);
96
97 let mut n_consumed = 0;
98 let mut n_produced = 0;
99 while n_consumed + self.b_taps.num_taps() - 1 < i.len() && n_produced < o.len() {
100 let o: &mut f32 = &mut o[n_produced];
101
102 *o = 0.0;
103
104 for b_tap in 0..self.b_taps.num_taps() {
106 *o += unsafe { self.b_taps.get(b_tap) }
108 * i[n_consumed + self.b_taps.num_taps() - b_tap - 1];
109 }
110
111 for a_tap in 0..self.a_taps.num_taps() {
113 *o += unsafe { self.a_taps.get(a_tap) } * self.memory[a_tap];
115 }
116
117 for idx in 1..self.memory.len() {
119 self.memory[idx] = self.memory[idx - 1];
120 }
121 if !self.memory.is_empty() {
122 self.memory[0] = *o;
123 }
124
125 n_produced += 1;
126 n_consumed += 1;
127 }
128
129 (
130 n_consumed,
131 n_produced,
132 if n_consumed == i.len() && n_produced == o.len() {
133 ComputationStatus::BothSufficient
134 } else if n_consumed < i.len() {
135 ComputationStatus::InsufficientOutput
136 } else {
137 assert!(n_produced < o.len());
138 ComputationStatus::InsufficientInput
139 },
140 )
141 }
142}
143
144#[cfg(test)]
145mod test {
146 use super::*;
147
148 use alloc::vec;
149
150 struct Feeder {
151 filter: IirKernel<f32, f32, Vec<f32>>,
152 input: Vec<f32>,
153 }
154
155 impl Feeder {
156 fn feed(&mut self, input: f32) -> Option<f32> {
157 self.input.push(input);
158
159 let mut out = [0.0];
160 let (n_consumed, n_produced, _status) = self.filter.work(&self.input[..], &mut out);
161 assert_eq!(n_consumed, n_produced); if n_consumed > 0 {
163 self.input.drain(0..n_consumed);
164 }
165 if n_produced > 0 {
166 Some(out[0])
167 } else {
168 None
169 }
170 }
171 }
172
173 fn make_filter(a_taps: Vec<f32>, b_taps: Vec<f32>) -> Feeder {
174 Feeder {
175 filter: IirKernel {
176 a_taps,
177 b_taps,
178 memory: vec![],
179 _input_type: core::marker::PhantomData,
180 _output_type: core::marker::PhantomData,
181 },
182 input: vec![],
183 }
184 }
185
186 #[test]
187 fn test_iir_b_taps_algorithm() {
188 let mut iir = make_filter(vec![], vec![1.0, 2.0, 3.0]);
189
190 assert_eq!(iir.feed(10.0), None);
191 assert_eq!(iir.feed(20.0), None);
192 assert_eq!(iir.feed(30.0), Some(30.0 + 40.0 + 30.0));
193 assert_eq!(iir.feed(40.0), Some(40.0 + 60.0 + 60.0));
194 }
195
196 #[test]
197 fn test_iir_single_a_tap_algorithm() {
198 let mut iir = make_filter(vec![0.5], vec![1.0]);
199
200 assert_eq!(iir.feed(10.0), None);
201 assert_eq!(iir.feed(10.0), Some(15.0));
202 assert_eq!(iir.feed(10.0), Some(17.5));
203 assert_eq!(iir.feed(10.0), Some(18.75));
204 }
205}