pm_remez/extrema.rs
use crate::barycentric::{compute_error, compute_extrema_candidate};
use crate::compute_cheby_coefficients;
use crate::eigenvalues::EigenvalueBackend;
use crate::error::{Error, Result};
use crate::types::Band;
use ndarray::Array2;
use num_traits::{Float, FloatConst};

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Interval<T> {
    pub begin: T,
    pub end: T,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct ExtremaCandidate<T> {
    pub x: T,
    pub error: T,
    pub desired: T,
    pub weight: T,
}

// Initial guess for extremal frequencies: evenly spaced over bands
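//
// For example, with the bands [0.0, 0.2] and [0.3, 0.5] and num_functions = 4,
// the total band length is 0.4 and the spacing is 0.1, so the returned points
// are 0.0, 0.1, 0.2, 0.4, 0.5.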
pub fn initial_extremal_freqs<T: Float>(bands: &[Band<T>], num_functions: usize) -> Vec<T> {
    let mut total_band_length = T::zero();
    for band_length in bands.iter().map(|b| b.len()) {
        total_band_length = total_band_length + band_length;
    }
    let spacing = total_band_length / T::from(num_functions).unwrap();
    let mut consumed_length = T::zero();
    let num_bands = bands.len();
    let mut current_band = bands.iter().enumerate().peekable();
    (0..(num_functions + 1))
        .map(|j| {
            let s = T::from(j).unwrap() * spacing;
            debug_assert!(s >= consumed_length);
            let mut u = s - consumed_length;
            loop {
                let cband = current_band.peek().unwrap();
                let band_length = cband.1.len();
                // the second condition is to avoid going past the last band
                // due to numerical rounding
                if u <= band_length || cband.0 == num_bands - 1 {
                    break;
                }
                current_band.next();
                consumed_length = consumed_length + band_length;
                u = s - consumed_length;
            }
            let cband = current_band.peek().unwrap();
            (cband.1.begin() + u).min(cband.1.end())
        })
        .collect()
}

// Compute subintervals containing the extremal points and band edges (in the
// [-1, 1] domain).
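//
// The extremal points in x are assumed to be sorted in descending order, which
// is why they are iterated in reverse below: reversing them visits the points
// from the leftmost band to the rightmost.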
pub fn subdivide<T: Float>(x: &[T], bands_x: &[Interval<T>]) -> Vec<Interval<T>> {
    // reserve capacity for the worst case
    let mut subintervals = Vec::with_capacity(x.len() + bands_x.len());
    let mut xs = x.iter().rev().peekable();
    for band in bands_x {
        let mut begin = band.begin;
        loop {
            match xs.peek() {
                Some(&&a) => {
                    match a.partial_cmp(&band.end).unwrap() {
                        std::cmp::Ordering::Greater => {
                            // new point to the right of the band end: end
                            // interval at the right band edge, do not consume
                            // point.
                            subintervals.push(Interval {
                                begin,
                                end: band.end,
                            });
                            break;
                        }
                        std::cmp::Ordering::Equal => {
                            // new point exactly at the band end: end interval
                            // at the right band edge, consume point.
                            subintervals.push(Interval {
                                begin,
                                end: band.end,
                            });
                            xs.next();
                            break;
                        }
                        std::cmp::Ordering::Less => {
                            // new point inside the band: end interval at this
                            // point, consume point, the point is the begin of
                            // the next interval.
                            if begin != a {
                                subintervals.push(Interval { begin, end: a });
                                begin = a;
                            }
                            xs.next();
                        }
                    }
                }
                None => {
                    // no more points: end interval at the right band edge.
                    subintervals.push(Interval {
                        begin,
                        end: band.end,
                    });
                    break;
                }
            }
        }
    }
    // check that we have consumed all the points
    debug_assert!(xs.next().is_none());
    subintervals
}

// Find local extrema of error function in subinterval using Chebyshev proxy method
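//
// The error function in the subinterval is interpolated at Chebyshev nodes by
// a polynomial (the "proxy"). The local extrema of the error are then
// approximated by the real roots of the proxy's derivative, which are obtained
// as the eigenvalues of a colleague matrix.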
#[allow(clippy::too_many_arguments)]
pub fn find_extrema_in_subinterval<'a, T, D, W, B: EigenvalueBackend<T>>(
    interval: &Interval<T>,
    cheby_nodes: &[T],
    x: &'a [T],
    wk: &'a [T],
    yk: &'a [T],
    desired: D,
    weights: W,
    eigenvalue_backend: &B,
) -> Result<impl Iterator<Item = ExtremaCandidate<T>>>
where
    T: Float + FloatConst,
    D: Fn(T) -> T + 'a,
    W: Fn(T) -> T + 'a,
{
    // Compute Chebyshev proxy for error function in interval
    //
    // Scale Chebyshev nodes to interval and compute error function
    let mut cheby_nodes_errors: Vec<T> = {
        let scale = T::from(0.5).unwrap() * (interval.end - interval.begin);
        cheby_nodes
            .iter()
            .map(|&x0| {
                let cheby_node_scaled = (x0 + T::one()) * scale + interval.begin;
                compute_error(cheby_node_scaled, x, wk, yk, &desired, &weights)
            })
            .collect()
    };
    // Compute coefficients of the first-kind Chebyshev polynomial expansion
    let ak = compute_cheby_coefficients(&mut cheby_nodes_errors);

    // Compute derivative of Chebyshev proxy
    //
    // Compute coefficients of the second-kind Chebyshev polynomial expansion
    // of the derivative of the proxy.
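    //
    // This uses the identity d/dx T_k(x) = k * U_{k-1}(x), so the k-th
    // first-kind coefficient contributes k * ak[k] to the (k-1)-th second-kind
    // coefficient of the derivative.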
    let mut ck: Vec<T> = ak
        .iter()
        .enumerate()
        .skip(1)
        .map(|(k, &a)| T::from(k).unwrap() * a)
        .collect();

    // Remove high-order coefficients ck which are zero. The colleague matrix
    // definition needs the leading coefficient to be nonzero.
    let zero = T::zero();
    while *ck.last().unwrap() == zero {
        ck.pop();
        if ck.is_empty() {
            return Err(Error::ProxyDerivativeZero);
        }
    }

    // Compute colleague matrix of ck. Its eigenvalues are the zeros of the
    // derivative of the Chebyshev proxy.
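    //
    // The 1/2 entries on the off-diagonals come from the recurrence
    // x * U_k(x) = (U_{k+1}(x) + U_{k-1}(x)) / 2 (with x * U_0 = U_1 / 2). The
    // basis is ordered from the highest degree down, so the coefficients ck,
    // scaled by -1/(2 * ck[s]), end up in the first column.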
    let s = ck.len() - 1;
    let mut colleague = Array2::<T>::zeros((s, s));
    let half = T::from(0.5).unwrap();
    for j in 0..s - 1 {
        colleague[(j, j + 1)] = half;
    }
    for j in 2..s {
        colleague[(j, j - 1)] = half;
    }
    let scale = T::from(-0.5).unwrap() / *ck.last().unwrap();
    for j in 0..s {
        let c = ck[s - 1 - j] * scale;
        colleague[(j, 0)] = if j == 1 { c + half } else { c };
    }
    // Balance matrix for better numerical conditioning
    balance_matrix(&mut colleague);

    // Compute eigenvalues of colleague matrix. These are the roots of the
    // derivative of the proxy.
    let eig = eigenvalue_backend.eigenvalues(colleague)?;

    // Filter only the roots that are real and inside [-1, 1]. Map them to
    // the original interval.
    //
    // The eigenvalues are complex in general. A root is considered real when
    // its imaginary part is negligible, but the type system does not know
    // this, so we need to take its re part explicitly.
    let threshold = T::from(1e-20).unwrap();
    let limits = -T::one()..=T::one();
    let scale = T::from(0.5).unwrap() * (interval.end - interval.begin);
    let begin = interval.begin;
    Ok(eig.into_iter().filter_map(move |z| {
        if Float::abs(z.im) < threshold {
            let x0 = z.re;
            if limits.contains(&x0) {
                // map root to interval
                let y = (x0 + T::one()) * scale + begin;
                // evaluate error function
                Some(compute_extrema_candidate(y, x, wk, yk, &desired, &weights))
            } else {
                None
            }
        } else {
            None
        }
    }))
}

// Prune extrema candidates to leave only n of them. It assumes that the
// candidates are sorted by their position.
pub fn prune_extrema_candidates<T: Float>(
    candidates: &[ExtremaCandidate<T>],
    n: usize,
) -> Result<Vec<ExtremaCandidate<T>>> {
    assert!(!candidates.is_empty());
    let mut pruned = Vec::with_capacity(candidates.len());
    let zero = T::zero();

    // From each group of adjacent extrema whose errors have the same sign,
    // keep only the one with the largest absolute error.
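    //
    // For example, candidates with errors (+0.1, +0.3, -0.2) are reduced to
    // (+0.3, -0.2).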
    let mut b = candidates[0];
    let mut b_sign = b.error < zero;
    let mut b_abs = b.error.abs();
    for &a in candidates.iter().skip(1) {
        let a_sign = a.error < zero;
        let a_abs = a.error.abs();
        if a_sign != b_sign {
            pruned.push(b);
        }
        if a_sign != b_sign || a_abs > b_abs {
            b = a;
            b_sign = a_sign;
            b_abs = a_abs;
        }
    }
    pruned.push(b);

    if pruned.len() == n {
        return Ok(pruned);
    }
    if pruned.len() < n {
        return Err(Error::NotEnoughExtrema);
    }

    let to_remove = pruned.len() - n;
    // FIXME: This technique gives convergence problems in some cases, such as
    // a lowpass FIR with 1/f stopband response when the stopband weight is set
    // large enough. The last extremum is always removed and the problem starts
    // ignoring the error at the end of the stopband.
    //
    // if to_remove == 1 {
    //     // Only one extremum needs to be removed. Consider the cases of
    //     // removing the first extremum and the last extremum, and compute
    //     // delta for each of them. Choose the option that gives a larger
    //     // delta.
    //     let delta_keep_first = compute_delta_from_candidates(&pruned[..n]);
    //     let delta_keep_last = compute_delta_from_candidates(&pruned[1..]);
    //     if delta_keep_first >= delta_keep_last {
    //         // remove last candidate
    //         pruned.pop();
    //     } else {
    //         // remove first candidate
    //         pruned.remove(0);
    //     }
    //     return Ok(pruned);
    // }
    if to_remove % 2 == 1 {
        // An odd number of extrema needs to be removed. Reduce this to the
        // case of an even number of extrema for removal by removing either the
        // first or the last extremum, whichever has the smaller absolute
        // error.
        if pruned[0].error.abs() >= pruned[pruned.len() - 1].error.abs() {
            pruned.pop();
        } else {
            pruned.remove(0);
        }
    }
    while pruned.len() > n {
        // An even number of extrema needs to be removed. Find the adjacent
        // pair of elements whose smaller absolute error is the minimum over
        // all adjacent pairs and remove that pair.
        let idx = pruned
            .iter()
            .zip(pruned.iter().skip(1))
            .enumerate()
            .map(|(k, (a, b))| (k, a.error.abs().min(b.error.abs())))
            // the unwrap panics if there are NaNs
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap()
            .0;
        pruned.drain(idx..=idx + 1);
    }
    assert!(pruned.len() == n);
    Ok(pruned)
}

// Balance a matrix for eigenvalue calculation, as indicated in [5].
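//
// Balancing applies a diagonal similarity transform D^{-1} A D, which does not
// change the eigenvalues but makes the off-diagonal row and column norms of
// each index comparable in magnitude, improving the accuracy of the eigenvalue
// computation.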
fn balance_matrix<T: Float>(a: &mut Array2<T>) {
    // Some constants to be used below
    let gamma = T::from(0.95).unwrap();
    let two = T::from(2.0).unwrap();
    let four = T::from(4.0).unwrap();
    let half = T::from(0.5).unwrap();
    let one = T::one();
    let zero = T::zero();

    // The algorithm in [5] has a preliminary step where rows and columns that
    // isolate an eigenvalue (those that are zero except for the diagonal
    // element) are pushed to the left or bottom of the matrix respectively.
    // However, the colleague matrix does not have any such rows or columns, so
    // we do not need this step.

    let n = a.nrows();
    let mut converged = false;
    while !converged {
        converged = true;
        for j in 0..n {
            let mut row_norm = zero;
            let mut col_norm = zero;
            for k in 0..n {
                // ignore the diagonal term, because the algorithm only works
                // with the off-diagonal part of the matrix
                if k != j {
                    row_norm = row_norm + a[(j, k)].abs();
                    col_norm = col_norm + a[(k, j)].abs();
                }
            }
            if row_norm == zero || col_norm == zero {
                continue;
            }
            // Sum of original row norm and column norm. To be used in the
            // condition below.
            let norm_sum = row_norm + col_norm;
            // Implicitly finds the integer sigma such that
            //   2^{2*sigma - 1} < row_norm / col_norm <= 2^{2*sigma + 1}
            // and sets f = 2^sigma.
            let mut f = one;
            let row_norm_half = row_norm * half;
            // The is_normal check serves to stop the iteration if we run into
            // numerical trouble instead of looping forever.
            while col_norm.is_normal() && col_norm <= row_norm_half {
                f = f * two;
                col_norm = col_norm * four;
            }
            let row_norm_twice = row_norm * two;
            while col_norm.is_normal() && col_norm > row_norm_twice {
                f = f / two;
                col_norm = col_norm / four;
            }
            // By the end of these two loops col_norm has been replaced with
            // col_norm * f^2.

            // If we have run into trouble we just return
            if !col_norm.is_normal() {
                return;
            }

            // Check if
            //   col_norm * f + row_norm / f < gamma * (col_norm + row_norm)
            // Since at this point col_norm contains col_norm * f^2, we
            // multiply both sides of the inequality by f.
            if row_norm + col_norm < gamma * norm_sum * f {
                converged = false;
                let f_recip = f.recip();
                // Let D be a diagonal matrix that contains ones in all the
                // diagonal elements except the j-th, where it contains f.
                // Replace the matrix A by D^{-1}AD.
                for k in 0..n {
                    if k != j {
                        a[(j, k)] = a[(j, k)] * f_recip;
                        a[(k, j)] = a[(k, j)] * f;
                    }
                }
            }
        }
    }
}
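
// A minimal sketch of unit tests illustrating the behavior of `subdivide` and
// `prune_extrema_candidates`. These tests are an illustrative addition, not
// part of the original test suite, and assume that `Error` implements `Debug`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn subdivide_splits_band_at_interior_points() {
        // One band covering [-1, 1] with two interior extremal points given in
        // descending order, as `subdivide` expects.
        let x = [0.25, -0.5];
        let bands = [Interval {
            begin: -1.0,
            end: 1.0,
        }];
        let subintervals = subdivide(&x, &bands);
        assert_eq!(
            subintervals,
            vec![
                Interval {
                    begin: -1.0,
                    end: -0.5
                },
                Interval {
                    begin: -0.5,
                    end: 0.25
                },
                Interval {
                    begin: 0.25,
                    end: 1.0
                },
            ]
        );
    }

    #[test]
    fn prune_keeps_largest_extremum_per_sign_group() {
        // Two adjacent candidates with positive error followed by a negative
        // and a positive one; within the first group only the candidate with
        // the largest absolute error survives.
        let cand = |x: f64, error: f64| ExtremaCandidate {
            x,
            error,
            desired: 0.0,
            weight: 1.0,
        };
        let candidates = [
            cand(0.0, 1.0),
            cand(0.1, 2.0),
            cand(0.2, -3.0),
            cand(0.3, 4.0),
        ];
        let pruned = prune_extrema_candidates(&candidates, 3).unwrap();
        let errors: Vec<f64> = pruned.iter().map(|c| c.error).collect();
        assert_eq!(errors, vec![2.0, -3.0, 4.0]);
    }
}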