Skip to main content

bids_filter/
lib.rs

1#![deny(unsafe_code)]
2//! Butterworth IIR filter with zero-phase filtfilt.
3//!
4//! Implements `scipy.signal.butter` + `scipy.signal.filtfilt` for anti-alias
5//! filtering before downsampling, as used by PyBIDS.
6
7use std::f64::consts::PI;
8
9/// Butterworth low-pass filter coefficients.
10///
11/// Returns `(b, a)` — numerator and denominator of the IIR transfer function,
12/// matching `scipy.signal.butter(order, cutoff, btype='low', output='ba')`.
13///
14/// `cutoff` is normalized frequency: cutoff_hz / (fs/2), must be in (0, 1).
15#[must_use]
16pub fn butter_lowpass(order: usize, cutoff: f64) -> (Vec<f64>, Vec<f64>) {
17    assert!(
18        cutoff > 0.0 && cutoff < 1.0,
19        "cutoff must be in (0,1), got {cutoff}"
20    );
21
22    // Step 1: Analog Butterworth poles on the unit circle (left half-plane)
23    let mut poles_s: Vec<(f64, f64)> = Vec::with_capacity(order);
24    for k in 0..order {
25        let theta = PI * (2 * k + order + 1) as f64 / (2 * order) as f64;
26        poles_s.push((theta.cos(), theta.sin()));
27    }
28
29    // Step 2: Pre-warp cutoff for bilinear transform
30    let fs = 2.0; // normalized sampling rate
31    let warped = 2.0 * fs * (PI * cutoff / fs).tan();
32
33    // Step 3: Scale analog poles by warped cutoff
34    let poles_a: Vec<(f64, f64)> = poles_s
35        .iter()
36        .map(|(re, im)| (re * warped, im * warped))
37        .collect();
38
39    // Step 4: Bilinear transform: s -> (2*fs*(z-1))/(z+1)
40    // z-domain pole = (1 + s/(2*fs)) / (1 - s/(2*fs))
41    let mut poles_z: Vec<(f64, f64)> = Vec::new();
42    let c = 2.0 * fs;
43    for &(re, im) in &poles_a {
44        let denom_re = 1.0 - re / c;
45        let denom_im = -im / c;
46        let num_re = 1.0 + re / c;
47        let num_im = im / c;
48        let d2 = denom_re * denom_re + denom_im * denom_im;
49        poles_z.push((
50            (num_re * denom_re + num_im * denom_im) / d2,
51            (num_im * denom_re - num_re * denom_im) / d2,
52        ));
53    }
54
55    // Step 5: All zeros at z = -1 for low-pass
56    let zeros_z: Vec<(f64, f64)> = vec![(-1.0, 0.0); order];
57
58    // Step 6: Convert poles/zeros to polynomial coefficients
59    let a = poly_from_roots(&poles_z);
60    let b_unnorm = poly_from_roots(&zeros_z);
61
62    // Step 7: Normalize gain at DC (z=1)
63    let gain_a: f64 = a.iter().sum();
64    let gain_b: f64 = b_unnorm.iter().sum();
65    let gain = gain_a / gain_b;
66
67    let b: Vec<f64> = b_unnorm.iter().map(|&x| x * gain).collect();
68
69    (b, a)
70}
71
/// Direct-form II transposed IIR filter (single-pass).
///
/// Filters `x` starting from zero initial state, matching
/// `scipy.signal.lfilter(b, a, x)`. Coefficients are normalized by `a[0]`, and
/// the shorter of `b` / `a` is treated as zero-padded.
///
/// # Panics
///
/// Panics if `a` is empty or `a[0] == 0.0`.
#[must_use]
pub fn lfilter(b: &[f64], a: &[f64], x: &[f64]) -> Vec<f64> {
    let (b, a) = normalize_ba(b, a);
    let zero_state = vec![0.0; b.len() - 1];
    lfilter_with_state(&b, &a, x, &zero_state)
}

/// Zero-phase filtering: apply filter forward, then backward.
///
/// Matches `scipy.signal.filtfilt(b, a, x)` with its defaults (`padtype='odd'`,
/// `padlen = 3 * max(len(a), len(b))`, `method='pad'`):
/// - The signal is extended by an odd reflection about each endpoint.
/// - Each pass starts from the filter's steady state (`scipy.signal.lfilter_zi`)
///   scaled by that pass's first sample, so a constant input yields a constant
///   output without start-up transients.
///
/// Unlike scipy, which rejects inputs not longer than `padlen`, short inputs are
/// accepted: the padding is shortened to `x.len() - 1` samples on each side. The
/// output always has the same length as `x`. An empty `x` returns an empty vector.
///
/// # Panics
///
/// For non-empty `x`, panics if `a` is empty or `a[0] == 0.0`.
#[must_use]
pub fn filtfilt(b: &[f64], a: &[f64], x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return vec![];
    }
    let (b, a) = normalize_ba(b, a);
    // Effective pad length: used for BOTH the extension and the final trim, so
    // clipping for short inputs keeps the output aligned with `x`.
    let edge = (3 * b.len()).min(x.len() - 1);

    // Odd extension: reflect about the first/last sample values.
    let first = x[0];
    let last = x[x.len() - 1];
    let mut ext = Vec::with_capacity(x.len() + 2 * edge);
    ext.extend((1..=edge).rev().map(|i| 2.0 * first - x[i]));
    ext.extend_from_slice(x);
    ext.extend((1..=edge).map(|i| 2.0 * last - x[x.len() - 1 - i]));

    let zi = lfilter_zi(&b, &a);
    let scaled_zi = |scale: f64| -> Vec<f64> { zi.iter().map(|&z| z * scale).collect() };

    // Forward pass, then a backward pass over the reversed forward output.
    let mut y = lfilter_with_state(&b, &a, &ext, &scaled_zi(ext[0]));
    y.reverse();
    let mut y = lfilter_with_state(&b, &a, &y, &scaled_zi(y[0]));
    y.reverse();

    // Drop the padding so the result lines up sample-for-sample with `x`.
    y.drain(..edge);
    y.truncate(x.len());
    y
}

/// Normalizes `(b, a)` so that `a[0] == 1` and both vectors have length
/// `max(b.len(), a.len())`, zero-padding the shorter one.
fn normalize_ba(b: &[f64], a: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let a0 = *a
        .first()
        .expect("denominator coefficients `a` must not be empty");
    assert!(a0 != 0.0, "a[0] must be nonzero");
    let len = b.len().max(a.len());
    let pad = |coeffs: &[f64]| -> Vec<f64> {
        (0..len)
            .map(|i| coeffs.get(i).map_or(0.0, |&v| v / a0))
            .collect()
    };
    (pad(b), pad(a))
}

/// Runs the direct-form II transposed recursion starting from state `zi`.
///
/// Expects normalized, equal-length `b` and `a` (see `normalize_ba`) and
/// `zi.len() == b.len() - 1`.
fn lfilter_with_state(b: &[f64], a: &[f64], x: &[f64], zi: &[f64]) -> Vec<f64> {
    let mut z = zi.to_vec();
    let order = z.len();
    x.iter()
        .map(|&xn| {
            let yn = b[0] * xn + z.first().copied().unwrap_or(0.0);
            // Ascending k: each delay reads its successor's value before that is overwritten.
            for k in 0..order {
                let carry = z.get(k + 1).copied().unwrap_or(0.0);
                z[k] = b[k + 1] * xn - a[k + 1] * yn + carry;
            }
            yn
        })
        .collect()
}

/// Steady-state filter state for a unit step input (`scipy.signal.lfilter_zi`).
///
/// Expects normalized, equal-length `b` and `a`. Returns `b.len() - 1` values
/// such that filtering a constant `c` from state `c * zi` immediately produces
/// the constant `c * H(1)`. If `sum(a) == 0` (pole at `z = 1`) there is no
/// finite steady state, and a zero state is returned instead.
fn lfilter_zi(b: &[f64], a: &[f64]) -> Vec<f64> {
    let n = b.len();
    if n < 2 {
        return Vec::new();
    }
    let b0 = b[0];
    // Closed-form solution of (I - A^T) zi = b[1..] - a[1..] * b0 (as noted in scipy's source).
    let a_sum: f64 = a.iter().sum();
    let rhs_sum: f64 = b[1..]
        .iter()
        .zip(&a[1..])
        .map(|(&bk, &ak)| bk - ak * b0)
        .sum();
    let z0 = rhs_sum / a_sum;
    if !z0.is_finite() {
        return vec![0.0; n - 1];
    }
    let mut zi = Vec::with_capacity(n - 1);
    zi.push(z0);
    let (mut a_acc, mut c_acc) = (1.0, 0.0);
    for k in 1..n - 1 {
        a_acc += a[k];
        c_acc += b[k] - a[k] * b0;
        zi.push(a_acc * z0 - c_acc);
    }
    zi
}
134
135/// Butterworth high-pass filter coefficients.
136///
137/// Returns `(b, a)` matching `scipy.signal.butter(order, cutoff, btype='high')`.
138/// `cutoff` is normalized frequency in (0, 1).
139#[must_use]
140pub fn butter_highpass(order: usize, cutoff: f64) -> (Vec<f64>, Vec<f64>) {
141    assert!(
142        cutoff > 0.0 && cutoff < 1.0,
143        "cutoff must be in (0,1), got {cutoff}"
144    );
145    // High-pass = transform low-pass: s → 1/s before bilinear.
146    // Equivalently: zeros at z=+1 (not -1), and gain at Nyquist=1.
147    let (b_lp, a_lp) = butter_lowpass(order, cutoff);
148    // Spectral inversion: negate every other coefficient of b, flipping the response
149    let b_hp: Vec<f64> = b_lp
150        .iter()
151        .enumerate()
152        .map(|(i, &v)| if i % 2 == 0 { v } else { -v })
153        .collect();
154    let a_hp: Vec<f64> = a_lp
155        .iter()
156        .enumerate()
157        .map(|(i, &v)| if i % 2 == 0 { v } else { -v })
158        .collect();
159    // Re-normalize: gain at Nyquist (z = -1) should be 1
160    let gain_a: f64 = a_hp
161        .iter()
162        .enumerate()
163        .map(|(i, &v)| v * (-1.0f64).powi(i as i32))
164        .sum();
165    let gain_b: f64 = b_hp
166        .iter()
167        .enumerate()
168        .map(|(i, &v)| v * (-1.0f64).powi(i as i32))
169        .sum();
170    let gain = gain_a / gain_b;
171    let b: Vec<f64> = b_hp.iter().map(|&v| v * gain).collect();
172    (b, a_hp)
173}
174
175/// Butterworth band-pass filter coefficients.
176///
177/// Returns `(b, a)` matching `scipy.signal.butter(order, [low, high], btype='band')`.
178/// `low` and `high` are normalized frequencies in (0, 1).
179#[must_use]
180pub fn butter_bandpass(order: usize, low: f64, high: f64) -> (Vec<f64>, Vec<f64>) {
181    assert!(
182        low > 0.0 && low < 1.0 && high > 0.0 && high < 1.0 && low < high,
183        "frequencies must be 0 < low < high < 1, got low={low}, high={high}"
184    );
185    // Cascade: highpass at `low` then lowpass at `high`
186    let (b_hp, a_hp) = butter_highpass(order, low);
187    let (b_lp, a_lp) = butter_lowpass(order, high);
188    // Convolve the transfer functions: B(z) = B_hp(z) * B_lp(z), A(z) = A_hp(z) * A_lp(z)
189    let b = convolve(&b_hp, &b_lp);
190    let a = convolve(&a_hp, &a_lp);
191    (b, a)
192}
193
194/// Notch (band-stop) filter using second-order IIR sections.
195///
196/// Removes a narrow frequency band around `freq_hz` (e.g., 50 or 60 Hz power line noise).
197/// `quality` controls the notch width (default ~30). Higher = narrower.
198///
199/// Like MNE's `raw.notch_filter()` or `scipy.signal.iirnotch`.
200#[must_use]
201pub fn notch_filter(x: &[f64], freq_hz: f64, fs: f64, quality: f64) -> Vec<f64> {
202    let w0 = 2.0 * PI * freq_hz / fs;
203    let bw = w0 / quality;
204    let r = 1.0 - bw / 2.0; // pole radius
205
206    // Second-order IIR notch: zeros on unit circle at ±w0, poles just inside
207    let b: &[f64] = &[1.0, -2.0 * w0.cos(), 1.0];
208    let a: &[f64] = &[1.0, -2.0 * r * w0.cos(), r * r];
209
210    // Normalize gain at DC to 1
211    let dc_b: f64 = b.iter().sum();
212    let dc_a: f64 = a.iter().sum();
213    let gain = dc_a / dc_b;
214    let b_norm: Vec<f64> = b.iter().map(|&v| v * gain).collect();
215
216    filtfilt(&b_norm, a, x)
217}
218
219/// Resample a signal from `fs_old` to `fs_new` using linear interpolation.
220///
221/// For anti-aliasing when downsampling, applies a lowpass filter at the new
222/// Nyquist frequency before decimation (like MNE's `raw.resample()`).
223#[must_use]
224pub fn resample(x: &[f64], fs_old: f64, fs_new: f64) -> Vec<f64> {
225    if x.is_empty() || fs_old <= 0.0 || fs_new <= 0.0 {
226        return vec![];
227    }
228    if (fs_old - fs_new).abs() < 1e-10 {
229        return x.to_vec();
230    }
231
232    let ratio = fs_new / fs_old;
233    let n_out = (x.len() as f64 * ratio).round() as usize;
234    if n_out == 0 {
235        return vec![];
236    }
237
238    // Anti-alias filter when downsampling
239    let src = if fs_new < fs_old {
240        let cutoff = fs_new / fs_old; // normalized Nyquist of new rate
241        let cutoff = cutoff.clamp(0.01, 0.99);
242        let (b, a) = butter_lowpass(8, cutoff);
243        filtfilt(&b, &a, x)
244    } else {
245        x.to_vec()
246    };
247
248    // Linear interpolation
249    let mut out = Vec::with_capacity(n_out);
250    for i in 0..n_out {
251        let t = i as f64 / ratio;
252        let idx = t.floor() as usize;
253        let frac = t - idx as f64;
254        if idx + 1 < src.len() {
255            out.push(src[idx] * (1.0 - frac) + src[idx + 1] * frac);
256        } else if idx < src.len() {
257            out.push(src[idx]);
258        }
259    }
260    out
261}
262
/// Convolves two polynomial coefficient vectors (polynomial multiplication).
///
/// `out[k] = sum_{i + j = k} a[i] * b[j]`. The result has length
/// `a.len() + b.len() - 1`. If either input is empty, the result is empty
/// (like an empty/zero polynomial), which also avoids the `usize` underflow in
/// the length computation.
fn convolve(a: &[f64], b: &[f64]) -> Vec<f64> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut out = vec![0.0; a.len() + b.len() - 1];
    for (i, &av) in a.iter().enumerate() {
        // Accumulate a[i] * b into the window starting at out[i].
        for (slot, &bv) in out[i..].iter_mut().zip(b) {
            *slot += av * bv;
        }
    }
    out
}
274
/// Expands `prod_k (z - r_k)` for complex roots given as `(re, im)` pairs.
///
/// Returns the real parts of the coefficients, highest power first:
/// `[1, c1, c2, ...]`. The imaginary parts are dropped. They cancel (up to
/// rounding) only when the roots come in complex-conjugate pairs, which callers
/// must ensure.
fn poly_from_roots(roots: &[(f64, f64)]) -> Vec<f64> {
    roots
        .iter()
        .fold(vec![(1.0, 0.0)], |poly: Vec<(f64, f64)>, &(root_re, root_im)| {
            let mut next = vec![(0.0, 0.0); poly.len() + 1];
            for (k, &(p_re, p_im)) in poly.iter().enumerate() {
                // z * p_k keeps the same slot in the one-longer coefficient vector.
                next[k].0 += p_re;
                next[k].1 += p_im;
                // -root * p_k lands one slot further (one degree lower).
                next[k + 1].0 -= p_re * root_re - p_im * root_im;
                next[k + 1].1 -= p_re * root_im + p_im * root_re;
            }
            next
        })
        .into_iter()
        .map(|(re, _im)| re)
        .collect()
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// DC gain `B(1) / A(1)` of a transfer function given in powers of `z^-1`.
    fn dc_gain(b: &[f64], a: &[f64]) -> f64 {
        b.iter().sum::<f64>() / a.iter().sum::<f64>()
    }

    /// Polynomial in `z^-1` evaluated at Nyquist (`z = -1`): `sum_i c[i] * (-1)^i`.
    fn at_nyquist(c: &[f64]) -> f64 {
        c.iter()
            .enumerate()
            .map(|(i, &v)| v * (-1.0f64).powi(i as i32))
            .sum()
    }

    #[test]
    fn test_butter_lowpass_order1() {
        let (b, a) = butter_lowpass(1, 0.5);
        // Order 1 at half the Nyquist frequency: scipy gives b = [0.5, 0.5], a = [1, 0].
        assert_eq!(b.len(), 2);
        assert_eq!(a.len(), 2);
        assert!((a[0] - 1.0).abs() < 1e-10);
        assert!(a[1].abs() < 1e-10, "a = {a:?}");
        assert!(b.iter().all(|v| (v - 0.5).abs() < 1e-10), "b = {b:?}");
        // DC gain should be 1.0
        let dc = dc_gain(&b, &a);
        assert!((dc - 1.0).abs() < 1e-10, "DC gain = {}", dc);
    }

    #[test]
    fn test_butter_lowpass_order5() {
        let (b, a) = butter_lowpass(5, 0.25);
        assert_eq!(b.len(), 6);
        assert_eq!(a.len(), 6);
        let dc = dc_gain(&b, &a);
        assert!((dc - 1.0).abs() < 1e-10, "DC gain = {}", dc);
    }

    #[test]
    fn test_lfilter_passthrough() {
        // With b=[1], a=[1], output = input
        let x: Vec<f64> = (0..10).map(|i| i as f64).collect();
        let y = lfilter(&[1.0], &[1.0], &x);
        assert_eq!(y, x);
    }

    #[test]
    fn test_butter_highpass() {
        let (b, a) = butter_highpass(3, 0.1);
        assert_eq!(b.len(), 4);
        assert_eq!(a.len(), 4);
        // DC gain should be ~0 (high-pass blocks DC)
        let dc = dc_gain(&b, &a);
        assert!(dc.abs() < 0.01, "HP DC gain should be ~0, got {}", dc);
        // Nyquist gain should be ~1
        let ny_gain = at_nyquist(&b) / at_nyquist(&a);
        assert!(
            (ny_gain - 1.0).abs() < 0.01,
            "HP Nyquist gain should be ~1, got {}",
            ny_gain
        );
    }

    #[test]
    fn test_butter_bandpass() {
        let (b, a) = butter_bandpass(2, 0.1, 0.4);
        // DC gain should be ~0
        assert!(dc_gain(&b, &a).abs() < 0.01, "BP DC gain should be ~0");
        // Nyquist gain should be ~0
        assert!(
            (at_nyquist(&b) / at_nyquist(&a)).abs() < 0.01,
            "BP Nyquist gain should be ~0"
        );
    }

    #[test]
    fn test_notch_filter() {
        // Signal with 50 Hz hum + 10 Hz signal
        let fs = 500.0;
        let n = 1000;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin() + 0.5 * (2.0 * PI * 50.0 * t).sin()
            })
            .collect();

        let filtered = notch_filter(&signal, 50.0, fs, 30.0);
        assert_eq!(filtered.len(), n);

        // 50 Hz energy should be reduced; 10 Hz preserved.
        // Compare energy in second half (after transient).
        let half = n / 2;
        let orig_energy: f64 = signal[half..].iter().map(|v| v * v).sum::<f64>();
        let filt_energy: f64 = filtered[half..].iter().map(|v| v * v).sum::<f64>();
        // Mean power: 0.5 (10 Hz, amplitude 1) + 0.125 (50 Hz, amplitude 0.5).
        // Removing the hum should leave ~80% of the original energy.
        assert!(
            filt_energy < orig_energy * 0.9,
            "Notch should reduce energy: orig={:.3}, filt={:.3}",
            orig_energy,
            filt_energy
        );
    }

    #[test]
    fn test_resample_downsample() {
        let fs = 1000.0;
        let n = 1000;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin()
            })
            .collect();

        let resampled = resample(&signal, fs, 250.0);
        assert_eq!(resampled.len(), 250); // 1000 * 250/1000

        // The 10 Hz signal should be preserved at 250 Hz sampling.
        // Check approximate peak.
        let max = resampled.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(max > 0.8, "10 Hz should be preserved, peak={:.3}", max);
    }

    #[test]
    fn test_resample_upsample() {
        let signal = vec![0.0, 1.0, 0.0, -1.0, 0.0];
        let up = resample(&signal, 100.0, 200.0);
        assert_eq!(up.len(), 10);
        // Even output indices land exactly on input samples: up[2k] == signal[k].
        assert!((up[0] - 0.0).abs() < 0.1);
        assert!((up[2] - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_resample_identity() {
        let signal: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let out = resample(&signal, 256.0, 256.0);
        assert_eq!(out.len(), signal.len());
        for (a, b) in signal.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_filtfilt_removes_high_freq() {
        // Generate signal: low freq + high freq
        let n = 200;
        let fs = 100.0;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 5.0 * t).sin() + (2.0 * PI * 40.0 * t).sin()
            })
            .collect();

        // Low-pass at 10 Hz (cutoff = 10 / (100/2) = 0.2)
        let (b, a) = butter_lowpass(5, 0.2);
        let filtered = filtfilt(&b, &a, &signal);

        assert_eq!(filtered.len(), signal.len());
        // The 40 Hz component should be attenuated significantly:
        // the filtered signal should have much lower energy than the original.
        let orig_energy: f64 = signal.iter().map(|v| v * v).sum::<f64>() / n as f64;
        let filt_energy: f64 = filtered.iter().map(|v| v * v).sum::<f64>() / n as f64;
        assert!(
            filt_energy < orig_energy * 0.7,
            "Filtered energy {} should be much less than original {}",
            filt_energy,
            orig_energy
        );
    }
}