cute_dsp/
phase_rotation.rs

//! Phase Rotation and Hilbert Transform
//!
//! This module provides utilities for phase manipulation, including:
//! - Hilbert transform (90-degree phase shift)
//! - Analytic signal generation
//! - Oscillator-driven phase rotation (mixing a signal with a cosine/sine carrier)
//! - Instantaneous phase, magnitude, and frequency analysis (with per-sample phase unwrapping)
8
9#[cfg(feature = "std")]
10use std::vec::Vec;
11
12#[cfg(all(not(feature = "std"), feature = "alloc"))]
13use alloc::vec::Vec;
14
15use num_traits::{Float, FromPrimitive};
16use num_complex::Complex;
17use crate::fft::SimpleFFT;
18
19/// Hilbert Transform processor
20///
21/// Produces a 90-degree phase-shifted version of a signal.
22/// Can be used to generate analytic signals.
23pub struct HilbertTransform<T: Float + FromPrimitive> {
24    fft: SimpleFFT<T>,
25    fft_size: usize,
26}
27
28impl<T: Float + FromPrimitive> HilbertTransform<T> {
29    /// Create a new Hilbert transform processor with the specified FFT size
30    pub fn new(fft_size: usize) -> Self {
31        let fft = SimpleFFT::new(fft_size);
32        Self { fft, fft_size }
33    }
34
35    /// Compute the Hilbert transform of a signal
36    ///
37    /// Returns the 90-degree phase-shifted version of the input signal.
38    pub fn transform(&mut self, signal: &[T]) -> Vec<T> {
39        if signal.is_empty() {
40            return Vec::new();
41        }
42
43        // Pad to FFT size
44        let mut padded = vec![Complex::new(T::zero(), T::zero()); self.fft_size];
45        for (i, &val) in signal.iter().enumerate().take(self.fft_size) {
46            padded[i] = Complex::new(val, T::zero());
47        }
48
49        // Forward FFT
50        let mut freq = vec![Complex::new(T::zero(), T::zero()); self.fft_size];
51        self.fft.fft(&padded, &mut freq);
52
53        // Apply Hilbert filter: multiply positive frequencies by -j, negative by j
54        let mid = self.fft_size / 2;
55        
56        // DC and Nyquist components stay zero
57        freq[0] = Complex::new(T::zero(), T::zero());
58        if mid < self.fft_size {
59            freq[mid] = Complex::new(T::zero(), T::zero());
60        }
61
62        // Multiply positive frequencies by -j (rotate by -90 degrees)
63        for k in 1..mid {
64            let j_mult = Complex::new(T::zero(), -T::one());
65            freq[k] = freq[k] * j_mult;
66        }
67
68        // Multiply negative frequencies by j (rotate by +90 degrees)
69        for k in (mid + 1)..self.fft_size {
70            let j_mult = Complex::new(T::zero(), T::one());
71            freq[k] = freq[k] * j_mult;
72        }
73
74        // Inverse FFT
75        let mut hilbert = vec![Complex::new(T::zero(), T::zero()); self.fft_size];
76        self.fft.ifft(&freq, &mut hilbert);
77
78        // Extract real part and scale
79        let scale = T::from_f64(2.0).unwrap();
80        hilbert
81            .iter()
82            .take(signal.len())
83            .map(|c| c.re * scale)
84            .collect()
85    }
86
87    /// Create an analytic signal from the input
88    ///
89    /// Returns a vector of complex numbers where:
90    /// - Real part = original signal
91    /// - Imaginary part = Hilbert transform (90° phase-shifted version)
92    pub fn analytic_signal(&mut self, signal: &[T]) -> Vec<Complex<T>> {
93        if signal.is_empty() {
94            return Vec::new();
95        }
96
97        // Pad to FFT size
98        let mut padded = vec![Complex::new(T::zero(), T::zero()); self.fft_size];
99        for (i, &val) in signal.iter().enumerate().take(self.fft_size) {
100            padded[i] = Complex::new(val, T::zero());
101        }
102
103        // Forward FFT
104        let mut freq = vec![Complex::new(T::zero(), T::zero()); self.fft_size];
105        self.fft.fft(&padded, &mut freq);
106
107        // Zero out negative frequencies and double positive ones
108        let mid = self.fft_size / 2;
109        for k in (mid + 1)..self.fft_size {
110            freq[k] = Complex::new(T::zero(), T::zero());
111        }
112        
113        // Double positive frequencies (except DC and Nyquist)
114        let two = T::from_f64(2.0).unwrap();
115        for k in 1..mid {
116            freq[k] = freq[k] * two;
117        }
118
119        // Inverse FFT
120        let mut analytic = vec![Complex::new(T::zero(), T::zero()); self.fft_size];
121        self.fft.ifft(&freq, &mut analytic);
122
123        analytic.iter().take(signal.len()).copied().collect()
124    }
125}
126
127/// Phase Rotator for applying phase shifts to signals
128pub struct PhaseRotator<T: Float> {
129    /// Current phase accumulator
130    phase: T,
131    /// Phase increment per sample
132    phase_increment: T,
133    /// Two pi constant
134    two_pi: T,
135}
136
137impl<T: Float + FromPrimitive> PhaseRotator<T> {
138    /// Create a new phase rotator with specified frequency
139    ///
140    /// # Arguments
141    /// * `frequency` - The frequency of the oscillation in Hz
142    /// * `sample_rate` - The sample rate in Hz
143    pub fn new(frequency: T, sample_rate: T) -> Self {
144        let two_pi = T::from_f64(std::f64::consts::PI * 2.0).unwrap();
145        let phase_increment = (two_pi * frequency) / sample_rate;
146        
147        Self {
148            phase: T::zero(),
149            phase_increment,
150            two_pi,
151        }
152    }
153
154    /// Process a sample and apply phase rotation
155    /// Returns the rotated sample
156    pub fn process(&mut self, sample: T) -> T {
157        let output = sample * self.phase.cos();
158        self.advance_phase();
159        output
160    }
161
162    /// Process a sample with quadrature output (real and imaginary)
163    /// Returns (in_phase, quadrature)
164    pub fn process_quadrature(&mut self, sample: T) -> (T, T) {
165        let in_phase = sample * self.phase.cos();
166        let quadrature = sample * self.phase.sin();
167        self.advance_phase();
168        (in_phase, quadrature)
169    }
170
171    /// Process a block of samples
172    pub fn process_block(&mut self, input: &[T]) -> Vec<T> {
173        input.iter().map(|&s| self.process(s)).collect()
174    }
175
176    /// Rotate all samples in a vector by a fixed phase angle
177    pub fn rotate_by_angle(signal: &[T], angle: T) -> Vec<T> {
178        let cos_angle = angle.cos();
179        signal.iter().map(|&s| s * cos_angle).collect()
180    }
181
182    /// Advance the phase by one sample
183    fn advance_phase(&mut self) {
184        self.phase = self.phase + self.phase_increment;
185        
186        // Wrap phase to [0, 2π)
187        let two_pi = self.two_pi;
188        if self.phase >= two_pi {
189            let cycles = (self.phase / two_pi).floor();
190            self.phase = self.phase - (cycles * two_pi);
191        }
192    }
193
194    /// Reset phase to zero
195    pub fn reset(&mut self) {
196        self.phase = T::zero();
197    }
198
199    /// Get current phase
200    pub fn get_phase(&self) -> T {
201        self.phase
202    }
203
204    /// Set current phase
205    pub fn set_phase(&mut self, phase: T) {
206        self.phase = phase;
207    }
208
209    /// Set frequency
210    pub fn set_frequency(&mut self, frequency: T, sample_rate: T) {
211        self.phase_increment = (self.two_pi * frequency) / sample_rate;
212    }
213}
214
215/// Compute instantaneous phase of a signal using analytic signal
216pub fn instantaneous_phase(analytic: &[Complex<f32>]) -> Vec<f32> {
217    analytic
218        .iter()
219        .map(|c| c.im.atan2(c.re))
220        .collect()
221}
222
223/// Compute instantaneous magnitude (amplitude) of a signal
224pub fn instantaneous_magnitude(analytic: &[Complex<f32>]) -> Vec<f32> {
225    analytic
226        .iter()
227        .map(|c| (c.re * c.re + c.im * c.im).sqrt())
228        .collect()
229}
230
/// Compute the instantaneous frequency (in Hz) from a sequence of phase values.
///
/// Element `i` is the phase step `phase[i + 1] - phase[i]` converted to Hz via
/// `step * sample_rate / (2π)`. Each step is folded into `[-π, π]` with a single ±2π
/// correction, which is sufficient for phases in `[-π, π]` such as those produced by
/// [`instantaneous_phase`].
///
/// The output has the same length as `phase`: the final step is repeated so the last
/// sample also has a value. Fewer than two phase samples yield an empty vector, since
/// no step can be formed.
pub fn instantaneous_frequency(
    phase: &[f32],
    sample_rate: f32,
) -> Vec<f32> {
    if phase.len() < 2 {
        return Vec::new();
    }

    // `core` rather than `std` so this also compiles in `no_std` + `alloc` builds.
    let pi = core::f32::consts::PI;
    let two_pi = pi * 2.0;

    let mut freq = Vec::with_capacity(phase.len());
    freq.extend(phase.windows(2).map(|pair| {
        let mut step = pair[1] - pair[0];
        // Unwrap jumps across the ±π boundary.
        if step > pi {
            step -= two_pi;
        } else if step < -pi {
            step += two_pi;
        }
        (step * sample_rate) / two_pi
    }));

    // Replicate the last value so the output lines up with the input.
    if let Some(&last) = freq.last() {
        freq.push(last);
    }

    freq
}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hilbert_transform_basic() {
        let mut hilbert = HilbertTransform::new(64);
        let signal = vec![1.0; 10];
        let result = hilbert.transform(&signal);
        assert_eq!(result.len(), 10);
    }

    #[test]
    fn test_hilbert_transform_of_constant_is_zero() {
        // A DC-only block has no positive/negative-frequency content to phase-shift.
        let mut hilbert = HilbertTransform::new(64);
        let signal = vec![1.0f64; 64];
        let result = hilbert.transform(&signal);
        assert_eq!(result.len(), 64);
        assert!(result.iter().all(|v| v.abs() < 1e-6));
    }

    #[test]
    fn test_analytic_signal() {
        let mut hilbert = HilbertTransform::new(64);
        let signal = vec![1.0, 0.5, 0.25, 0.125];
        let analytic = hilbert.analytic_signal(&signal);
        assert_eq!(analytic.len(), 4);
        // Just verify that the analytic signal was computed
        assert!(analytic.iter().any(|c| c.re != 0.0 || c.im != 0.0));
    }

    #[test]
    fn test_phase_rotator() {
        let mut rotator = PhaseRotator::new(1.0, 10.0);
        let sample = 1.0;
        let output = rotator.process(sample);
        assert!(output.is_finite());
    }

    #[test]
    fn test_phase_rotator_quadrature() {
        let mut rotator = PhaseRotator::new(1.0, 10.0);
        // The phase starts at 0, so the first sample maps to (sample, 0).
        let (i, q) = rotator.process_quadrature(2.0);
        assert_eq!((i, q), (2.0, 0.0));
    }

    #[test]
    fn test_phase_rotator_stays_wrapped() {
        let two_pi = core::f64::consts::PI * 2.0;
        let mut rotator = PhaseRotator::new(3.0, 10.0);
        for _ in 0..100 {
            rotator.process(1.0);
            let phase = rotator.get_phase();
            assert!(phase >= 0.0 && phase < two_pi);
        }
    }

    #[test]
    fn test_rotate_by_angle_zero_is_identity() {
        let signal = [1.0, -2.0, 0.5];
        assert_eq!(
            PhaseRotator::<f64>::rotate_by_angle(&signal, 0.0),
            signal.to_vec()
        );
    }

    #[test]
    fn test_instantaneous_phase() {
        let analytic = vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.707, 0.707),
            Complex::new(0.0, 1.0),
        ];
        let phase = instantaneous_phase(&analytic);
        assert_eq!(phase.len(), 3);
        assert!(phase[0].abs() < 1e-6);
        assert!((phase[1] - core::f32::consts::FRAC_PI_4).abs() < 1e-6);
        assert!((phase[2] - core::f32::consts::FRAC_PI_2).abs() < 1e-6);
    }

    #[test]
    fn test_instantaneous_magnitude() {
        let analytic = vec![Complex::new(3.0, 4.0), Complex::new(0.0, 0.0)];
        assert_eq!(instantaneous_magnitude(&analytic), vec![5.0, 0.0]);
    }

    #[test]
    fn test_instantaneous_frequency() {
        let two_pi = core::f32::consts::PI * 2.0;
        // A constant phase step of 0.5 rad at a sample rate of 2π Hz is 0.5 Hz.
        let freq = instantaneous_frequency(&[0.0, 0.5, 1.0], two_pi);
        assert_eq!(freq.len(), 3);
        assert!(freq.iter().all(|&f| (f - 0.5).abs() < 1e-5));
        // A jump from +3 to -3 rad is unwrapped to a small forward step.
        let wrapped = instantaneous_frequency(&[3.0, -3.0], two_pi);
        assert!((wrapped[0] - (two_pi - 6.0)).abs() < 1e-4);
        assert!(instantaneous_frequency(&[1.0], 48_000.0).is_empty());
    }

    #[test]
    fn test_phase_rotator_reset() {
        let mut rotator = PhaseRotator::new(1.0, 10.0);
        rotator.process(1.0);
        assert!(rotator.get_phase() > 0.0);
        rotator.reset();
        assert_eq!(rotator.get_phase(), 0.0);
    }
}
322}