pm_remez/
extrema.rs

1use crate::barycentric::{compute_error, compute_extrema_candidate};
2use crate::compute_cheby_coefficients;
3use crate::eigenvalues::EigenvalueBackend;
4use crate::error::{Error, Result};
5use crate::types::Band;
6use ndarray::Array2;
7use num_traits::{Float, FloatConst};
8
/// An interval described by its two endpoints.
///
/// The type does not enforce any ordering between `begin` and `end`; that is
/// left to the code constructing the interval.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Interval<T> {
    /// Left endpoint of the interval.
    pub begin: T,
    /// Right endpoint of the interval.
    pub end: T,
}
14
/// A candidate extremum of the error function.
///
/// NOTE(review): the field meanings below are inferred from the field names
/// and from how the values are used in this file; confirm against
/// `compute_extrema_candidate`, which builds these values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ExtremaCandidate<T> {
    /// Abscissa of the candidate.
    pub x: T,
    /// Signed value of the error function at `x`. Its sign and magnitude drive
    /// the pruning of candidates.
    pub error: T,
    /// Presumably the desired response evaluated at `x`.
    pub desired: T,
    /// Presumably the error weight evaluated at `x`.
    pub weight: T,
}
22
23// Initial guess for extremal frequencies: evenly spaced over bands
24pub fn initial_extremal_freqs<T: Float>(bands: &[Band<T>], num_functions: usize) -> Vec<T> {
25    let mut total_band_length = T::zero();
26    for band_length in bands.iter().map(|b| b.len()) {
27        total_band_length = total_band_length + band_length;
28    }
29    let spacing = total_band_length / T::from(num_functions).unwrap();
30    let mut consumed_length = T::zero();
31    let num_bands = bands.len();
32    let mut current_band = bands.iter().enumerate().peekable();
33    (0..(num_functions + 1))
34        .map(|j| {
35            let s = T::from(j).unwrap() * spacing;
36            debug_assert!(s >= consumed_length);
37            let mut u = s - consumed_length;
38            loop {
39                let cband = current_band.peek().unwrap();
40                let band_length = cband.1.len();
41                // the second condition is to avoid going past the last band due to numerical rounding
42                if u <= band_length || cband.0 == num_bands - 1 {
43                    break;
44                }
45                current_band.next();
46                consumed_length = consumed_length + band_length;
47                u = s - consumed_length;
48            }
49            let cband = current_band.peek().unwrap();
50            (cband.1.begin() + u).min(cband.1.end())
51        })
52        .collect()
53}
54
55// Compute subintervals containing the extremal points and band edges (in the
56// [-1, 1] domain).
57pub fn subdivide<T: Float>(x: &[T], bands_x: &[Interval<T>]) -> Vec<Interval<T>> {
58    // reserve capacity for the worst case
59    let mut subintervals = Vec::with_capacity(x.len() + bands_x.len());
60    let mut xs = x.iter().rev().peekable();
61    for band in bands_x {
62        let mut begin = band.begin;
63        loop {
64            match xs.peek() {
65                Some(&&a) => {
66                    match a.partial_cmp(&band.end).unwrap() {
67                        std::cmp::Ordering::Greater => {
68                            // new point to the right of the band end: end
69                            // interval at the right band edge, do not consume
70                            // point.
71                            subintervals.push(Interval {
72                                begin,
73                                end: band.end,
74                            });
75                            break;
76                        }
77                        std::cmp::Ordering::Equal => {
78                            // new point exactly at the band end: end interval
79                            // at the right band edge, consume point.
80                            subintervals.push(Interval {
81                                begin,
82                                end: band.end,
83                            });
84                            xs.next();
85                            break;
86                        }
87                        std::cmp::Ordering::Less => {
88                            // new point inside the band: end interval at this
89                            // point, consume point, the point is the begin of
90                            // the next interval.
91                            if begin != a {
92                                subintervals.push(Interval { begin, end: a });
93                                begin = a;
94                            }
95                            xs.next();
96                        }
97                    }
98                }
99                None => {
100                    // no more points: end interval at the right band edge.
101                    subintervals.push(Interval {
102                        begin,
103                        end: band.end,
104                    });
105                    break;
106                }
107            }
108        }
109    }
110    // check that we have consumed all the points
111    debug_assert!(xs.next().is_none());
112    subintervals
113}
114
115// Find local extrema of error function in subinterval using Chebyshev proxy method
116#[allow(clippy::too_many_arguments)]
117pub fn find_extrema_in_subinterval<'a, T, D, W, B: EigenvalueBackend<T>>(
118    interval: &Interval<T>,
119    cheby_nodes: &[T],
120    x: &'a [T],
121    wk: &'a [T],
122    yk: &'a [T],
123    desired: D,
124    weights: W,
125    eigenvalue_backend: &B,
126) -> Result<impl Iterator<Item = ExtremaCandidate<T>>>
127where
128    T: Float + FloatConst,
129    D: Fn(T) -> T + 'a,
130    W: Fn(T) -> T + 'a,
131{
132    // Compute Chebyshev proxy for error function in interval
133    //
134    // Scale Chebyshev nodes to interval and compute error function
135    let mut cheby_nodes_errors: Vec<T> = {
136        let scale = T::from(0.5).unwrap() * (interval.end - interval.begin);
137        cheby_nodes
138            .iter()
139            .map(|&x0| {
140                let cheby_node_scaled = (x0 + T::one()) * scale + interval.begin;
141                compute_error(cheby_node_scaled, x, wk, yk, &desired, &weights)
142            })
143            .collect()
144    };
145    // Compute coefficients of first-order Chebyshev polynomial expansion
146    let ak = compute_cheby_coefficients(&mut cheby_nodes_errors);
147
148    // Compute derivative of Chebyshev proxy
149    //
150    // Compute coefficients of second-order Chevyshev polynomial expansion of
151    // the derivative of the proxy.
152    let mut ck: Vec<T> = ak
153        .iter()
154        .enumerate()
155        .skip(1)
156        .map(|(k, &a)| T::from(k).unwrap() * a)
157        .collect();
158
159    // Remove high-order coefficients ck which are zero. The colleague matrix
160    // definition needs the leading coefficient to be nonzero.
161    let zero = T::zero();
162    while *ck.last().unwrap() == zero {
163        ck.pop();
164        if ck.is_empty() {
165            return Err(Error::ProxyDerivativeZero);
166        }
167    }
168
169    // Compute colleague matrix of ck. Its eigenvalues are the zeros of the
170    // derivative of the Chebyshev proxy.
171    let s = ck.len() - 1;
172    let mut colleague = Array2::<T>::zeros((s, s));
173    let half = T::from(0.5).unwrap();
174    for j in 0..s - 1 {
175        colleague[(j, j + 1)] = half;
176    }
177    for j in 2..s {
178        colleague[(j, j - 1)] = half;
179    }
180    let scale = T::from(-0.5).unwrap() / *ck.last().unwrap();
181    for j in 0..s {
182        let c = ck[s - 1 - j] * scale;
183        colleague[(j, 0)] = if j == 1 { c + half } else { c };
184    }
185    // Balance matrix for better numerical conditioning
186    balance_matrix(&mut colleague);
187
188    // Compute eigenvalues of colleague matrix. These are the roots of the
189    // derivative of the proxy.
190    let eig = eigenvalue_backend.eigenvalues(colleague)?;
191
192    // Filter only the roots that are real and inside [-1, 1]. Map them to
193    // the original interval.
194    //
195    // The threshold scalar is real, but the type system doesn't know, so we
196    // need to call re.
197    let threshold = T::from(1e-20).unwrap();
198    let limits = -T::one()..=T::one();
199    let scale = T::from(0.5).unwrap() * (interval.end - interval.begin);
200    let begin = interval.begin;
201    Ok(eig.into_iter().filter_map(move |z| {
202        if Float::abs(z.im) < threshold {
203            let x0 = z.re;
204            if limits.contains(&x0) {
205                // map root to interval
206                let y = (x0 + T::one()) * scale + begin;
207                // evaluate error function
208                Some(compute_extrema_candidate(y, x, wk, yk, &desired, &weights))
209            } else {
210                None
211            }
212        } else {
213            None
214        }
215    }))
216}
217
218// Prune extrema candidates to leave only n of them. It assumes that the candidates are sorted.
219pub fn prune_extrema_candidates<T: Float>(
220    candidates: &[ExtremaCandidate<T>],
221    n: usize,
222) -> Result<Vec<ExtremaCandidate<T>>> {
223    assert!(!candidates.is_empty());
224    let mut pruned = Vec::with_capacity(candidates.len());
225    let zero = T::zero();
226
227    // From groups of adjacent extrema with the same sign, leave only the largest
228    let mut b = candidates[0];
229    let mut b_sign = b.error < zero;
230    let mut b_abs = b.error.abs();
231    for &a in candidates.iter().skip(1) {
232        let a_sign = a.error < zero;
233        let a_abs = a.error.abs();
234        if a_sign != b_sign {
235            pruned.push(b);
236        }
237        if a_sign != b_sign || a_abs > b_abs {
238            b = a;
239            b_sign = a_sign;
240            b_abs = a_abs;
241        }
242    }
243    pruned.push(b);
244
245    if pruned.len() == n {
246        return Ok(pruned);
247    }
248    if pruned.len() < n {
249        return Err(Error::NotEnoughExtrema);
250    }
251
252    let to_remove = pruned.len() - n;
253    // FIXME: This technique gives convergence problems in some cases, such as a
254    // lowpass FIR with 1/f stopband response when the stopband weigth is set
255    // large enough. The last extrema is always removed and the problem starts
256    // ignoring the error at the end of the stopband.
257    //
258    // if to_remove == 1 {
259    //     // Only one extrema needs to be removed. Consider the cases of removing
260    //     // the first extrema and the last extrema, and compute delta for each of
261    //     // them. Choose the option that gives a larger delta.
262    //     let delta_keep_first = compute_delta_from_candidates(&pruned[..n]);
263    //     let delta_keep_last = compute_delta_from_candidates(&pruned[1..]);
264    //     if delta_keep_first >= delta_keep_last {
265    //         // remove last candidate
266    //         pruned.pop();
267    //     } else {
268    //         // remove first candidate
269    //         pruned.remove(0);
270    //     }
271    //     return Ok(pruned);
272    // }
273    if to_remove % 2 == 1 {
274        // An odd number of extrema need to be removed. Reduce this to the case
275        // of an even number of extrema for removal by removing either the first
276        // or last extrema, whichever has smaller error.
277        if pruned[0].error.abs() >= pruned[pruned.len() - 1].error.abs() {
278            pruned.pop();
279        } else {
280            pruned.remove(0);
281        }
282    }
283    while pruned.len() > n {
284        // An even number of extrema need to be removed. Find the pair of
285        // elements that has smaller minimum absolute value among the two
286        // elements of the pair and remove that pair.
287        let idx = pruned
288            .iter()
289            .zip(pruned.iter().skip(1))
290            .enumerate()
291            .map(|(k, (a, b))| (k, a.error.abs().min(b.error.abs())))
292            // unwrap will fail if there are NaN's
293            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
294            .unwrap()
295            .0;
296        pruned.drain(idx..=idx + 1);
297    }
298    assert!(pruned.len() == n);
299    Ok(pruned)
300}
301
302// Balance a matrix for eigenvalue calculation, as indicated in [5].
303fn balance_matrix<T: Float>(a: &mut Array2<T>) {
304    // Some constants to be used below
305    let gamma = T::from(0.95).unwrap();
306    let two = T::from(2.0).unwrap();
307    let four = T::from(4.0).unwrap();
308    let half = T::from(0.5).unwrap();
309    let one = T::one();
310    let zero = T::zero();
311
312    // The algorithm in [5] has a preliminary step where rows and columns that
313    // isolate an eignevalue (those that zero except on the diagonal element)
314    // are pushed to the left or bottom of the matrix respectively. However, the
315    // colleague matrix does not have any such rows or columns, so we don't need
316    // this step.
317
318    let n = a.nrows();
319    let mut converged = false;
320    while !converged {
321        converged = true;
322        for j in 0..n {
323            let mut row_norm = zero;
324            let mut col_norm = zero;
325            for k in 0..n {
326                // ignore the diagonal term, because the algorithm only works with
327                // the off-diagonal matrix
328                if k != j {
329                    row_norm = row_norm + a[(j, k)].abs();
330                    col_norm = col_norm + a[(k, j)].abs();
331                }
332            }
333            if row_norm == zero || col_norm == zero {
334                continue;
335            }
336            // Sum of original row norm and column norm. To be used in the
337            // condition below.
338            let norm_sum = row_norm + col_norm;
339            // Implicitly finds the integer sigma such that
340            // 2^{2*sigma - 1} < row_norm / col_norm <= 2^{2*sigma + 1}
341            // and sets f = 2^sigma.
342            let mut f = one;
343            let row_norm_half = row_norm * half;
344            // The is_normal serves to stop iteration if we run into numerical
345            // trouble instead of looping forever.
346            while col_norm.is_normal() && col_norm <= row_norm_half {
347                f = f * two;
348                col_norm = col_norm * four;
349            }
350            let row_norm_twice = row_norm * two;
351            while col_norm.is_normal() && col_norm > row_norm_twice {
352                f = f / two;
353                col_norm = col_norm / four;
354            }
355            // By the end of these two loops col_norm has been replaced with
356            // col_norm * f^2.
357
358            // If we have run into trouble we just return
359            if !col_norm.is_normal() {
360                return;
361            }
362
363            // Check if
364            // col_norm * f + row_norm / f < gamma * (col_norm + row_norm)
365            // Since at this point col_norm contains col_norm * f^2, we multiply
366            // both sides of the equation by f.
367            if row_norm + col_norm < gamma * norm_sum * f {
368                converged = false;
369                let f_recip = f.recip();
370                // Let D be a diagonal matrix that contains ones in all the
371                // diagonal elements excepth the j-th, where it contains
372                // f. Replace the matrix A by D^{-1}AD.
373                for k in 0..n {
374                    if k != j {
375                        a[(j, k)] = a[(j, k)] * f_recip;
376                        a[(k, j)] = a[(k, j)] * f;
377                    }
378                }
379            }
380        }
381    }
382}