// scirs2-signal 0.1.0-rc.2
//
// Signal processing module for SciRS2 (scirs2-signal)
// Documentation
use ndarray::{Array2, ArrayView2};
use scirs2_signal::dwt::Wavelet;
use scirs2_signal::dwt2d::{dwt2d_decompose, dwt2d_reconstruct, wavedec2, waverec2, Dwt2dResult};

#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("2D Wavelet Transform Example");
    println!("===========================\n");

    // Create a sample 8x8 image
    println!("Creating a simple 8x8 test image (gradient pattern)");
    let mut image = Array2::zeros((8, 8));
    for i in 0..8 {
        for j in 0..8 {
            image[[i, j]] = (i * j) as f64;
        }
    }

    // Print the original image
    println!("\nOriginal Image:");
    print_image(image.view());

    // Perform a single-level 2D DWT using Haar wavelet
    println!("\nPerforming single-level 2D DWT using Haar wavelet...");
    let decomposition = dwt2d_decompose(&image, Wavelet::Haar, None)?;

    // Print the subbands
    println!("\nApproximation Coefficients (LL band):");
    print_image(decomposition.approx.view());

    println!("\nHorizontal Detail Coefficients (LH band):");
    print_image(decomposition.detail_h.view());

    println!("\nVertical Detail Coefficients (HL band):");
    print_image(decomposition.detail_v.view());

    println!("\nDiagonal Detail Coefficients (HH band):");
    print_image(decomposition.detail_d.view());

    // Reconstruct the image
    println!("\nReconstructing the image from wavelet coefficients...");
    let reconstructed = dwt2d_reconstruct(&decomposition, Wavelet::Haar, None)?;

    println!("\nReconstructed Image:");
    print_image(reconstructed.view());

    // Calculate reconstruction error
    let mut max_error = 0.0;
    for i in 0..8 {
        for j in 0..8 {
            let error = (image[[i, j]] - reconstructed[[i, j]]).abs();
            if error > max_error {
                max_error = error;
            }
        }
    }
    println!("\nMaximum reconstruction error: {:.2e}", max_error);

    // Now perform a multi-level decomposition (just one level to avoid overflow issues)
    println!("\n\nPerforming multi-level 2D DWT (1 level) using DB4 wavelet...");
    let coeffs = wavedec2(&image, Wavelet::DB(4), 1, None)?;

    println!("\nNumber of decomposition levels: {}", coeffs.len());

    // Print approximation coefficients at the lowest level
    println!("\nApproximation Coefficients at level 1:");
    print_image(coeffs[0].approx.view());

    // Reconstruct from multi-level decomposition
    println!("\nReconstructing from multi-level decomposition...");
    let multi_reconstructed = waverec2(&coeffs, Wavelet::DB(4), None)?;

    // Calculate multi-level reconstruction error
    let mut max_multi_error = 0.0;
    for i in 0..8 {
        for j in 0..8 {
            let error = (image[[i, j]] - multi_reconstructed[[i, j]]).abs();
            if error > max_multi_error {
                max_multi_error = error;
            }
        }
    }
    println!(
        "\nMaximum multi-level reconstruction error: {:.2e}",
        max_multi_error
    );

    // Demonstrate image compression by thresholding small coefficients
    println!("\n\nDemonstrating wavelet-based compression by zeroing small coefficients...");

    // Apply a simple threshold to all detail coefficients
    let mut thresholded_coeffs = coeffs.clone();
    let threshold = 1.0;

    for level_coeffs in &mut thresholded_coeffs {
        threshold_coefficients(level_coeffs, threshold);
    }

    // Count non-zero coefficients before and after thresholding
    let original_nonzero = count_nonzero_coefficients(&coeffs);
    let thresholded_nonzero = count_nonzero_coefficients(&thresholded_coeffs);

    println!(
        "\nOriginal coefficients: {} non-zero values",
        original_nonzero
    );
    println!(
        "After thresholding: {} non-zero values",
        thresholded_nonzero
    );
    println!(
        "Compression ratio: {:.2}x",
        original_nonzero as f64 / thresholded_nonzero as f64
    );

    // Reconstruct from thresholded coefficients
    println!("\nReconstructing from thresholded coefficients...");
    let compressed_image = waverec2(&thresholded_coeffs, Wavelet::DB(4), None)?;

    println!("\nCompressed Image:");
    print_image(compressed_image.view());

    // Calculate compression error
    let mut max_compression_error = 0.0;
    let mut mean_squared_error = 0.0;

    for i in 0..8 {
        for j in 0..8 {
            let error = (image[[i, j]] - compressed_image[[i, j]]).abs();
            if error > max_compression_error {
                max_compression_error = error;
            }
            mean_squared_error += error * error;
        }
    }

    mean_squared_error /= 64.0; // 8x8 = 64 pixels

    println!("\nCompression quality metrics:");
    println!("Maximum error: {:.2e}", max_compression_error);
    println!("Mean squared error: {:.2e}", mean_squared_error);
    println!("PSNR (dB): {:.2}", -10.0 * mean_squared_error.log10());

    println!("\nExample complete!");
    Ok(())
}

// Helper function to print a 2D array in a nicely formatted way
#[allow(dead_code)]
fn print_image(image: ArrayView2<f64>) {
    let (rows, cols) = image.dim();

    for i in 0..rows {
        for j in 0..cols {
            print!("{:6.1} ", image[[i, j]]);
        }
        println!();
    }
}

// Apply a threshold to the detail coefficients of a decomposition
#[allow(dead_code)]
fn threshold_coefficients(decomp: &mut Dwt2dResult, threshold: f64) {
    // Apply threshold to all detail coefficients
    for h in decomp.detail_h.iter_mut() {
        if h.abs() < threshold {
            *h = 0.0;
        }
    }

    for v in decomp.detail_v.iter_mut() {
        if v.abs() < threshold {
            *v = 0.0;
        }
    }

    for d in decomp.detail_d.iter_mut() {
        if d.abs() < threshold {
            *d = 0.0;
        }
    }
}

// Count non-zero coefficients in a decomposition
#[allow(dead_code)]
fn count_nonzero_coefficients(coeffs: &[Dwt2dResult]) -> usize {
    let mut count = 0;

    for decomp in _coeffs {
        // Count non-zero values in approximation coefficients (only for the first level)
        if decomp == coeffs.first().unwrap_or(decomp) {
            for &val in decomp.approx.iter() {
                if val != 0.0 {
                    count += 1;
                }
            }
        }

        // Count non-zero values in detail coefficients
        for &val in decomp.detail_h.iter() {
            if val != 0.0 {
                count += 1;
            }
        }

        for &val in decomp.detail_v.iter() {
            if val != 0.0 {
                count += 1;
            }
        }

        for &val in decomp.detail_d.iter() {
            if val != 0.0 {
                count += 1;
            }
        }
    }

    count
}