scirs2-signal 0.1.0-rc.2

Signal processing module for SciRS2 (scirs2-signal)
Documentation
// Example of adaptive filtering using LMS, NLMS, and RLS algorithms
//
// This example demonstrates adaptive filter applications including:
// - System identification
// - Noise cancellation
// - Echo cancellation
// - Equalization

use scirs2_signal::adaptive::{LmsFilter, NlmsFilter, RlsFilter};
use std::f64::consts::PI;

#[allow(dead_code)]
fn main() {
    println!("Adaptive Filter Example");
    println!("======================\n");

    // Example 1: System Identification with LMS
    println!("1. System Identification with LMS Filter");
    println!("---------------------------------------");

    // Define an unknown system to identify: h = [0.5, -0.3, 0.2, 0.1]
    let unknown_system = vec![0.5, -0.3, 0.2, 0.1];
    println!("Unknown system coefficients: {:?}", unknown_system);

    // Create LMS filter
    let mut lms = LmsFilter::new(4, 0.01, 0.0).unwrap();
    println!("LMS filter: 4 taps, step size = 0.01");

    // Generate input signal (white noise-like)
    let num_samples = 500;
    let mut inputs = Vec::with_capacity(num_samples);
    let mut desired = Vec::with_capacity(num_samples);

    for i in 0..num_samples {
        // Input signal: mixture of sinusoids
        let input = 0.5 * (2.0 * PI * i as f64 * 0.01).sin()
            + 0.3 * (2.0 * PI * i as f64 * 0.02).cos()
            + 0.2 * (2.0 * PI * i as f64 * 0.03).sin();
        inputs.push(input);

        // Generate desired output by filtering input through unknown system
        let mut output = 0.0;
        for (j, &coeff) in unknown_system.iter().enumerate() {
            if i >= j {
                output += coeff * inputs[i - j];
            }
        }
        desired.push(output);
    }

    // Run adaptation
    let (_outputs, errors_mse) = lms.adapt_batch(&inputs, &desired).unwrap();

    println!("Final LMS weights: {:?}", lms.weights());
    println!("Target weights:    {:?}", unknown_system);

    // Calculate identification error
    let final_error: f64 = lms
        .weights()
        .iter()
        .zip(unknown_system.iter())
        .map(|(&w, &t)| (w - t).powi(2))
        .sum::<f64>()
        .sqrt();
    println!("Weight estimation error (RMSE): {:.6}", final_error);

    // Calculate final MSE
    let final_mse: f64 = errors.iter().rev().take(50).map(|&e| e * e).sum::<f64>() / 50.0;
    println!("Final MSE (last 50 samples): {:.6}", final_mse);

    // Example 2: Noise Cancellation with NLMS
    println!("\n\n2. Noise Cancellation with NLMS Filter");
    println!("-------------------------------------");

    // Create desired signal (speech-like)
    let signal_length = 300;
    let mut clean_signal = Vec::with_capacity(signal_length);
    let mut noise_signal = Vec::with_capacity(signal_length);
    let mut noisy_signal = Vec::with_capacity(signal_length);

    for i in 0..signal_length {
        // Clean signal: mixture of speech-like frequencies
        let clean = 0.8 * (2.0 * PI * i as f64 * 0.05).sin()
            + 0.4 * (2.0 * PI * i as f64 * 0.12).sin()
            + 0.2 * (2.0 * PI * i as f64 * 0.08).cos();
        clean_signal.push(clean);

        // Noise: correlated noise that we want to cancel
        let noise =
            0.5 * (2.0 * PI * i as f64 * 0.25).sin() + 0.3 * (2.0 * PI * i as f64 * 0.18).cos();
        noise_signal.push(noise);

        // Noisy signal = clean signal + noise
        noisy_signal.push(clean + noise);
    }

    // Create NLMS filter for noise cancellation
    let mut nlms = NlmsFilter::new(8, 0.5, 1e-6).unwrap();
    println!("NLMS filter: 8 taps, step size = 0.5");

    // Adaptive noise cancellation
    // Input: noise reference, Desired: noisy signal, Output: estimate of noise
    let mut clean_estimate = Vec::with_capacity(signal_length);
    let mut noise_estimate = Vec::with_capacity(signal_length);

    for i in 0..signal_length {
        let (noise_est_error_mse) = nlms.adapt(noise_signal[i], noisy_signal[i]).unwrap();
        noise_estimate.push(noise_est);

        // Clean estimate = noisy signal - noise estimate
        clean_estimate.push(noisy_signal[i] - noise_est);
    }

    // Calculate noise reduction performance
    let original_snr = calculate_snr(&clean_signal, &noise_signal);
    let enhanced_snr = calculate_snr(
        &clean_signal,
        &get_noise_from_clean(&clean_estimate, &clean_signal),
    );

    println!("Original SNR: {:.2} dB", original_snr);
    println!("Enhanced SNR: {:.2} dB", enhanced_snr);
    println!("SNR improvement: {:.2} dB", enhanced_snr - original_snr);

    // Example 3: RLS vs LMS Convergence Comparison
    println!("\n\n3. RLS vs LMS Convergence Comparison");
    println!("-----------------------------------");

    // Target system for comparison
    let target = vec![0.8, -0.6, 0.4];
    println!("Target system: {:?}", target);

    // Create both filters
    let mut lms_comp = LmsFilter::new(3, 0.05, 0.0).unwrap();
    let mut rls_comp = RlsFilter::new(3, 0.99, 100.0).unwrap();

    let convergence_samples = 100;
    let mut lms_weights_history = Vec::new();
    let mut rls_weights_history = Vec::new();
    let mut lms_errors = Vec::new();
    let mut rls_errors = Vec::new();

    for i in 0..convergence_samples {
        // Input signal
        let input = (i as f64 * 0.1).sin() + 0.5 * (i as f64 * 0.05).cos();

        // Desired output from target system
        let mut desired = 0.0;
        if i >= 2 {
            desired = target[0] * input
                + target[1] * ((i - 1) as f64 * 0.1).sin()
                + target[2] * ((i - 2) as f64 * 0.1).sin();
        }

        // Adapt both filters
        let (_out_lms, err_lms_) = lms_comp.adapt(input, desired).unwrap();
        let (_out_rls, err_rls_) = rls_comp.adapt(input, desired).unwrap();

        lms_weights_history.push(lms_comp.weights().to_vec());
        rls_weights_history.push(rls_comp.weights().to_vec());
        lms_errors.push(err_lms.abs());
        rls_errors.push(err_rls.abs());
    }

    println!("Final LMS weights: {:?}", lms_comp.weights());
    println!("Final RLS weights: {:?}", rls_comp.weights());

    // Calculate final convergence error for both
    let lms_final_error: f64 = lms_comp
        .weights()
        .iter()
        .zip(target.iter())
        .map(|(&w, &t)| (w - t).powi(2))
        .sum::<f64>()
        .sqrt();

    let rls_final_error: f64 = rls_comp
        .weights()
        .iter()
        .zip(target.iter())
        .map(|(&w, &t)| (w - t).powi(2))
        .sum::<f64>()
        .sqrt();

    println!("LMS final weight error: {:.6}", lms_final_error);
    println!("RLS final weight error: {:.6}", rls_final_error);

    // Show convergence speed (samples to reach 90% of final performance)
    let lms_90_percent = find_convergence_point(&lms_errors, 0.1);
    let rls_90_percent = find_convergence_point(&rls_errors, 0.1);

    println!("LMS convergence time (90%): {} samples", lms_90_percent);
    println!("RLS convergence time (90%): {} samples", rls_90_percent);

    // Example 4: Echo Cancellation Simulation
    println!("\n\n4. Echo Cancellation Simulation");
    println!("------------------------------");

    // Simulate echo path: h_echo = [0.0, 0.0, 0.3, -0.1, 0.05]
    let echo_path = vec![0.0, 0.0, 0.3, -0.1, 0.05];
    let _echo_delay = 2; // 2-sample delay
    println!("Echo path coefficients: {:?}", echo_path);

    // Create adaptive echo canceler
    let mut echo_canceler = LmsFilter::new(8, 0.01, 0.0).unwrap();

    let num_echo_samples = 200;
    let mut far_end_signal = Vec::with_capacity(num_echo_samples);
    let mut echo_signal = Vec::with_capacity(num_echo_samples);
    let mut near_end_signal = Vec::with_capacity(num_echo_samples);
    let mut microphone_signal = Vec::with_capacity(num_echo_samples);

    // Generate signals
    for i in 0..num_echo_samples {
        // Far-end signal (speaker output)
        let far_end =
            0.7 * (2.0 * PI * i as f64 * 0.03).sin() + 0.3 * (2.0 * PI * i as f64 * 0.07).cos();
        far_end_signal.push(far_end);

        // Echo signal (delayed and filtered far-end)
        let mut echo = 0.0;
        for (j, &coeff) in echo_path.iter().enumerate() {
            if i >= j {
                echo += coeff * far_end_signal[i - j];
            }
        }
        echo_signal.push(echo);

        // Near-end signal (local speaker)
        let near_end = if i > 50 && i < 150 {
            0.5 * (2.0 * PI * i as f64 * 0.02).sin()
        } else {
            0.0
        };
        near_end_signal.push(near_end);

        // Microphone picks up near-end + echo
        microphone_signal.push(near_end + echo);
    }

    // Perform echo cancellation
    let mut echo_canceled = Vec::with_capacity(num_echo_samples);

    for i in 0..num_echo_samples {
        // Adapt filter using far-end as input and microphone as desired
        let (echo_estimate_error_mse) = echo_canceler
            .adapt(far_end_signal[i], microphone_signal[i])
            .unwrap();

        // Cancel echo: output = microphone - echo_estimate
        echo_canceled.push(microphone_signal[i] - echo_estimate);
    }

    // Calculate echo cancellation performance
    let echo_power: f64 =
        echo_signal.iter().map(|&x| x * x).sum::<f64>() / echo_signal.len() as f64;
    let residual_power: f64 = echo_canceled
        .iter()
        .skip(50)
        .take(50)
        .map(|&x| x * x)
        .sum::<f64>()
        / 50.0;
    let echo_cancellation_db = 10.0 * (echo_power / residual_power).log10();

    println!("Echo cancellation achieved: {:.1} dB", echo_cancellation_db);
    println!("Final echo canceler weights: {:?}", echo_canceler.weights());

    // Example 5: Performance Characteristics
    println!("\n\n5. Adaptive Filter Performance Summary");
    println!("------------------------------------");

    println!("Algorithm Comparison:");
    println!("┌─────────┬─────────────┬─────────────┬──────────────┬─────────────┐");
    println!("│Algorithm│ Convergence │ Computational│ Tracking     │ Stability   │");
    println!("│         │ Speed       │ Complexity   │ Capability   │             │");
    println!("├─────────┼─────────────┼─────────────┼──────────────┼─────────────┤");
    println!("│ LMS     │ Slow        │ O(N)        │ Good         │ High        │");
    println!("│ NLMS    │ Medium      │ O(N)        │ Better       │ High        │");
    println!("│ RLS     │ Fast        │ O(N²)       │ Excellent    │ Medium      │");
    println!("└─────────┴─────────────┴─────────────┴──────────────┴─────────────┘");

    println!("\nApplication Guidelines:");
    println!("- LMS: General purpose, robust, low complexity");
    println!("- NLMS: Better for varying signal power levels");
    println!("- RLS: Fast convergence, non-stationary signals, higher complexity");

    println!("\nTypical Applications:");
    println!("- System identification: Channel estimation, inverse modeling");
    println!("- Noise cancellation: Active noise control, hearing aids");
    println!("- Echo cancellation: Telephony, video conferencing");
    println!("- Equalization: Digital communications, audio processing");
    println!("- Prediction: Time series forecasting, compression");
}

/// Calculate Signal-to-Noise Ratio in dB
#[allow(dead_code)]
fn calculate_snr(signal: &[f64], noise: &[f64]) -> f64 {
    let _signal_power: f64 = signal.iter().map(|&x| x * x).sum::<f64>() / signal.len() as f64;
    let noise_power: f64 = noise.iter().map(|&x| x * x).sum::<f64>() / noise.len() as f64;

    if noise_power > 1e-10 {
        10.0 * (signal_power / noise_power).log10()
    } else {
        100.0 // Very high SNR
    }
}

/// Extract noise from difference between clean estimate and true clean signal
#[allow(dead_code)]
fn get_noise_from_clean(_clean_estimate: &[f64], trueclean: &[f64]) -> Vec<f64> {
    _clean_estimate
        .iter()
        .zip(true_clean.iter())
        .map(|(&est, &true_val)| est - true_val)
        .collect()
}

/// Find convergence point (sample number where error drops below threshold)
#[allow(dead_code)]
fn find_convergence_point(_errors: &[f64], thresholdratio: f64) -> usize {
    if errors.is_empty() {
        return 0;
    }

    let final_error = errors[_errors.len() - 1];
    let threshold = final_error + threshold_ratio * errors[0];

    for (i, &error) in errors.iter().enumerate() {
        if error <= threshold {
            return i;
        }
    }

    errors.len() - 1
}